diff --git a/src/de.rs b/src/de.rs index b17ce21d..fa186cfe 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,7 +1,3 @@ -//! YAML Deserialization -//! -//! This module provides YAML deserialization with the type `Deserializer`. - use crate::path::Path; use crate::{error, Error, Result}; use serde::de::{ @@ -18,6 +14,336 @@ use std::str; use yaml_rust::parser::{Event as YamlEvent, MarkedEventReceiver, Parser}; use yaml_rust::scanner::{Marker, TScalarStyle, TokenType}; +/// A structure that deserializes YAML into Rust values. +pub struct Deserializer<'a> { + input: Input<'a>, +} + +enum Input<'a> { + Str(&'a str), + Slice(&'a [u8]), + Read(Box), +} + +impl<'a> Deserializer<'a> { + /// Creates a YAML deserializer from a `&str`. + pub fn from_str(s: &'a str) -> Self { + let input = Input::Str(s); + Deserializer { input } + } + + /// Creates a YAML deserializer from a `&[u8]`. + pub fn from_slice(v: &'a [u8]) -> Self { + let input = Input::Slice(v); + Deserializer { input } + } + + /// Creates a YAML deserializer from an `io::Read`. + /// + /// Reader-based deserializers do not support deserializing borrowed types + /// like `&str`, since the `std::io::Read` trait has no non-copying methods + /// -- everything it does involves copying bytes out of the data source. + pub fn from_reader(rdr: R) -> Self + where + R: io::Read + 'a, + { + let input = Input::Read(Box::new(rdr)); + Deserializer { input } + } + + fn de(self, f: impl FnOnce(&mut DeserializerFromEvents) -> Result) -> Result { + let loader = loader(self.input)?; + let mut pos = 0; + let t = f(&mut DeserializerFromEvents { + events: &loader.events, + aliases: &loader.aliases, + pos: &mut pos, + path: Path::Root, + remaining_depth: 128, + })?; + if pos == loader.events.len() { + Ok(t) + } else { + Err(error::more_than_one_document()) + } + } +} + +fn loader(input: Input) -> Result { + enum Input2<'a> { + Str(&'a str), + Slice(&'a [u8]), + } + + let mut buffer; + let input = match input { + Input::Str(s) => Input2::Str(s), + Input::Slice(bytes) => Input2::Slice(bytes), + Input::Read(mut rdr) => { + buffer = Vec::new(); + rdr.read_to_end(&mut buffer).map_err(error::io)?; + Input2::Slice(&buffer) + } + }; + + let input = match input { + Input2::Str(s) => s, + Input2::Slice(bytes) => str::from_utf8(bytes).map_err(error::str_utf8)?, + }; + + let mut parser = Parser::new(input.chars()); + let mut loader = Loader { + events: Vec::new(), + aliases: BTreeMap::new(), + }; + parser.load(&mut loader, true).map_err(error::scanner)?; + if loader.events.is_empty() { + Err(error::end_of_stream()) + } else { + Ok(loader) + } +} + +impl<'de> de::Deserializer<'de> for Deserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_any(visitor)) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_bool(visitor)) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_i8(visitor)) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_i16(visitor)) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_i32(visitor)) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_i64(visitor)) + } + + serde_if_integer128! { + fn deserialize_i128(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_i128(visitor)) + } + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_u8(visitor)) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_u16(visitor)) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_u32(visitor)) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_u64(visitor)) + } + + serde_if_integer128! { + fn deserialize_u128(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_u128(visitor)) + } + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_f32(visitor)) + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_f64(visitor)) + } + + fn deserialize_char(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_char(visitor)) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_str(visitor)) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_string(visitor)) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_bytes(visitor)) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_byte_buf(visitor)) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_option(visitor)) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_unit(visitor)) + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_unit_struct(name, visitor)) + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_newtype_struct(name, visitor)) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_seq(visitor)) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_tuple(len, visitor)) + } + + fn deserialize_tuple_struct( + self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_tuple_struct(name, len, visitor)) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_map(visitor)) + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_struct(name, fields, visitor)) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_enum(name, variants, visitor)) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_identifier(visitor)) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.de(|state| state.deserialize_ignored_any(visitor)) + } +} + pub struct Loader { events: Vec<(Event, Marker)>, /// Map from alias id to index in events. @@ -63,7 +389,7 @@ enum Event { MappingEnd, } -struct Deserializer<'a> { +struct DeserializerFromEvents<'a> { events: &'a [(Event, Marker)], /// Map from alias id to index in events. aliases: &'a BTreeMap, @@ -72,7 +398,7 @@ struct Deserializer<'a> { remaining_depth: u8, } -impl<'a> Deserializer<'a> { +impl<'a> DeserializerFromEvents<'a> { fn peek(&self) -> Result<(&'a Event, Marker)> { match self.events.get(*self.pos) { Some(event) => Ok((&event.0, event.1)), @@ -91,11 +417,11 @@ impl<'a> Deserializer<'a> { }) } - fn jump(&'a self, pos: &'a mut usize) -> Result> { + fn jump(&'a self, pos: &'a mut usize) -> Result> { match self.aliases.get(pos) { Some(&found) => { *pos = found; - Ok(Deserializer { + Ok(DeserializerFromEvents { events: self.events, aliases: self.aliases, pos, @@ -281,7 +607,7 @@ where } struct SeqAccess<'a: 'r, 'r> { - de: &'r mut Deserializer<'a>, + de: &'r mut DeserializerFromEvents<'a>, len: usize, } @@ -295,7 +621,7 @@ impl<'de, 'a, 'r> de::SeqAccess<'de> for SeqAccess<'a, 'r> { match self.de.peek()?.0 { Event::SequenceEnd => Ok(None), _ => { - let mut element_de = Deserializer { + let mut element_de = DeserializerFromEvents { events: self.de.events, aliases: self.de.aliases, pos: self.de.pos, @@ -313,7 +639,7 @@ impl<'de, 'a, 'r> de::SeqAccess<'de> for SeqAccess<'a, 'r> { } struct MapAccess<'a: 'r, 'r> { - de: &'r mut Deserializer<'a>, + de: &'r mut DeserializerFromEvents<'a>, len: usize, key: Option<&'a str>, } @@ -344,7 +670,7 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> { where V: DeserializeSeed<'de>, { - let mut value_de = Deserializer { + let mut value_de = DeserializerFromEvents { events: self.de.events, aliases: self.de.aliases, pos: self.de.pos, @@ -365,14 +691,14 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> { } struct EnumAccess<'a: 'r, 'r> { - de: &'r mut Deserializer<'a>, + de: &'r mut DeserializerFromEvents<'a>, name: &'static str, tag: Option<&'static str>, } impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> { type Error = Error; - type Variant = Deserializer<'r>; + type Variant = DeserializerFromEvents<'r>; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> where @@ -408,7 +734,7 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> { let str_de = IntoDeserializer::::into_deserializer(variant); let ret = seed.deserialize(str_de)?; - let variant_visitor = Deserializer { + let variant_visitor = DeserializerFromEvents { events: self.de.events, aliases: self.de.aliases, pos: self.de.pos, @@ -422,7 +748,7 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> { } } -impl<'de, 'a> de::VariantAccess<'de> for Deserializer<'a> { +impl<'de, 'a> de::VariantAccess<'de> for DeserializerFromEvents<'a> { type Error = Error; fn unit_variant(mut self) -> Result<()> { @@ -452,7 +778,7 @@ impl<'de, 'a> de::VariantAccess<'de> for Deserializer<'a> { } struct UnitVariantAccess<'a: 'r, 'r> { - de: &'r mut Deserializer<'a>, + de: &'r mut DeserializerFromEvents<'a>, } impl<'de, 'a, 'r> de::EnumAccess<'de> for UnitVariantAccess<'a, 'r> { @@ -634,7 +960,7 @@ fn invalid_type(event: &Event, exp: &dyn Expected) -> Error { } } -impl<'a> Deserializer<'a> { +impl<'a> DeserializerFromEvents<'a> { fn deserialize_scalar<'de, V>(&mut self, visitor: V) -> Result where V: Visitor<'de>, @@ -649,7 +975,7 @@ impl<'a> Deserializer<'a> { } } -impl<'de, 'a, 'r> de::Deserializer<'de> for &'r mut Deserializer<'a> { +impl<'de, 'a, 'r> de::Deserializer<'de> for &'r mut DeserializerFromEvents<'a> { type Error = Error; fn deserialize_any(self, visitor: V) -> Result @@ -1034,29 +1360,7 @@ pub fn from_str_seed(s: &str, seed: S) -> Result where S: for<'de> DeserializeSeed<'de, Value = T>, { - let mut parser = Parser::new(s.chars()); - let mut loader = Loader { - events: Vec::new(), - aliases: BTreeMap::new(), - }; - parser.load(&mut loader, true).map_err(error::scanner)?; - if loader.events.is_empty() { - Err(error::end_of_stream()) - } else { - let mut pos = 0; - let t = seed.deserialize(&mut Deserializer { - events: &loader.events, - aliases: &loader.aliases, - pos: &mut pos, - path: Path::Root, - remaining_depth: 128, - })?; - if pos == loader.events.len() { - Ok(t) - } else { - Err(error::more_than_one_document()) - } - } + seed.deserialize(Deserializer::from_str(s)) } /// Deserialize an instance of type `T` from an IO stream of YAML. @@ -1085,14 +1389,12 @@ where /// is wrong with the data, for example required struct fields are missing from /// the YAML map or some number is too big to fit in the expected primitive /// type. -pub fn from_reader_seed(mut rdr: R, seed: S) -> Result +pub fn from_reader_seed(rdr: R, seed: S) -> Result where R: io::Read, S: for<'de> DeserializeSeed<'de, Value = T>, { - let mut bytes = Vec::new(); - rdr.read_to_end(&mut bytes).map_err(error::io)?; - from_slice_seed(&bytes, seed) + seed.deserialize(Deserializer::from_reader(rdr)) } /// Deserialize an instance of type `T` from bytes of YAML text. @@ -1128,6 +1430,5 @@ pub fn from_slice_seed(v: &[u8], seed: S) -> Result where S: for<'de> DeserializeSeed<'de, Value = T>, { - let s = str::from_utf8(v).map_err(error::str_utf8)?; - from_str_seed(s, seed) + seed.deserialize(Deserializer::from_slice(v)) } diff --git a/src/lib.rs b/src/lib.rs index 433a5839..45f46b55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,7 +97,7 @@ clippy::match_like_matches_macro, )] -pub use crate::de::{from_reader, from_slice, from_str}; +pub use crate::de::{from_reader, from_slice, from_str, Deserializer}; pub use crate::error::{Error, Location, Result}; pub use crate::ser::{to_string, to_vec, to_writer}; pub use crate::value::{from_value, to_value, Index, Number, Sequence, Value};