From c0bbe1ad4288dc4a558cd73389d8a32346a8800c Mon Sep 17 00:00:00 2001 From: Eli Snow Date: Sat, 27 Oct 2018 23:22:12 -0600 Subject: [PATCH] Allow YAML tags to be used to specify an enum variant. Closes #115 --- src/de.rs | 33 ++++++++++++++++++++++++++------- tests/test_de.rs | 25 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/de.rs b/src/de.rs index e7ee6b00..e58f3888 100644 --- a/src/de.rs +++ b/src/de.rs @@ -380,6 +380,7 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> { struct EnumAccess<'a: 'r, 'r> { de: &'r mut Deserializer<'a>, name: &'static str, + tag: Option<&'static str>, } impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> { @@ -405,12 +406,16 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> { } } - let variant = match *self.de.next()?.0 { - Event::Scalar(ref s, _, _) => &**s, - _ => { - *self.de.pos -= 1; - let bad = BadKey { name: self.name }; - return Err(de::Deserializer::deserialize_any(&mut *self.de, bad).unwrap_err()); + let variant = if let Some(tag) = self.tag { + tag + } else { + match *self.de.next()?.0 { + Event::Scalar(ref s, _, _) => &**s, + _ => { + *self.de.pos -= 1; + let bad = BadKey { name: self.name }; + return Err(de::Deserializer::deserialize_any(&mut *self.de, bad).unwrap_err()); + } } }; @@ -938,12 +943,26 @@ impl<'de, 'a, 'r> de::Deserializer<'de> for &'r mut Deserializer<'a> { self.jump(&mut pos)? .deserialize_enum(name, variants, visitor) } - Event::Scalar(_, _, _) => visitor.visit_enum(UnitVariantAccess { de: self }), + Event::Scalar(_, _, ref t) => { + if let Some(TokenType::Tag(ref handle, ref suffix)) = t { + if handle == "!" { + if let Some(tag) = variants.iter().find(|v| *v == suffix) { + return visitor.visit_enum(EnumAccess { + de: self, + name: name, + tag: Some(tag), + }); + } + } + } + visitor.visit_enum(UnitVariantAccess { de: self }) + }, Event::MappingStart => { *self.pos += 1; let value = visitor.visit_enum(EnumAccess { de: self, name: name, + tag: None, })?; self.end_mapping(1)?; Ok(value) diff --git a/tests/test_de.rs b/tests/test_de.rs index 17482463..aa2f66e5 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -162,6 +162,31 @@ fn test_enum_alias() { test_de(&yaml, &expected); } +#[test] +fn test_enum_tag() { + #[derive(Deserialize, PartialEq, Debug)] + enum E { + A(String), + B(String), + } + #[derive(Deserialize, PartialEq, Debug)] + struct Data { + a: E, + b: E, + } + let yaml = unindent( + " + --- + a: !A foo + b: !B bar" + ); + let expected = Data { + a: E::A("foo".into()), + b: E::B("bar".into()), + }; + test_de(&yaml, &expected); +} + #[test] fn test_number_as_string() { #[derive(Deserialize, PartialEq, Debug)]