Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow f32 and f64 map keys #1027

Merged
merged 11 commits into from Jul 11, 2023
48 changes: 28 additions & 20 deletions src/de.rs
Expand Up @@ -2118,20 +2118,26 @@ struct MapKey<'a, R: 'a> {
de: &'a mut Deserializer<R>,
}

macro_rules! deserialize_integer_key {
($method:ident => $visit:ident) => {
macro_rules! deserialize_numeric_key {
($method:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.eat_char();
self.de.scratch.clear();
let string = tri!(self.de.read.parse_str(&mut self.de.scratch));
match (string.parse(), string) {
(Ok(integer), _) => visitor.$visit(integer),
(Err(_), Reference::Borrowed(s)) => visitor.visit_borrowed_str(s),
(Err(_), Reference::Copied(s)) => visitor.visit_str(s),

if let Some(b' ') | Some(b'\n') | Some(b'\r') | Some(b'\t') = tri!(self.de.peek()) {
return Err(self.de.peek_error(ErrorCode::UnexpectedWhitespaceInKey));
}

let value = tri!(self.de.$method(visitor));

match self.de.peek()? {
Some(b'"') => self.de.eat_char(),
_ => return Err(self.de.peek_error(ErrorCode::ExpectedDoubleQuote)),
}

Ok(value)
}
};
}
Expand All @@ -2155,16 +2161,18 @@ where
}
}

deserialize_integer_key!(deserialize_i8 => visit_i8);
deserialize_integer_key!(deserialize_i16 => visit_i16);
deserialize_integer_key!(deserialize_i32 => visit_i32);
deserialize_integer_key!(deserialize_i64 => visit_i64);
deserialize_integer_key!(deserialize_i128 => visit_i128);
deserialize_integer_key!(deserialize_u8 => visit_u8);
deserialize_integer_key!(deserialize_u16 => visit_u16);
deserialize_integer_key!(deserialize_u32 => visit_u32);
deserialize_integer_key!(deserialize_u64 => visit_u64);
deserialize_integer_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_i8);
deserialize_numeric_key!(deserialize_i16);
deserialize_numeric_key!(deserialize_i32);
deserialize_numeric_key!(deserialize_i64);
deserialize_numeric_key!(deserialize_i128);
deserialize_numeric_key!(deserialize_u8);
deserialize_numeric_key!(deserialize_u16);
deserialize_numeric_key!(deserialize_u32);
deserialize_numeric_key!(deserialize_u64);
deserialize_numeric_key!(deserialize_u128);
deserialize_numeric_key!(deserialize_f32);
deserialize_numeric_key!(deserialize_f64);

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -2221,8 +2229,8 @@ where
}

forward_to_deserialize_any! {
bool f32 f64 char str string unit unit_struct seq tuple tuple_struct map
struct identifier ignored_any
bool char str string unit unit_struct seq tuple tuple_struct map struct
identifier ignored_any
}
}

Expand Down
19 changes: 19 additions & 0 deletions src/error.rs
Expand Up @@ -64,12 +64,15 @@ impl Error {
| ErrorCode::ExpectedObjectCommaOrEnd
| ErrorCode::ExpectedSomeIdent
| ErrorCode::ExpectedSomeValue
| ErrorCode::ExpectedDoubleQuote
| ErrorCode::InvalidEscape
| ErrorCode::InvalidNumber
| ErrorCode::NumberOutOfRange
| ErrorCode::InvalidUnicodeCodePoint
| ErrorCode::ControlCharacterWhileParsingString
| ErrorCode::KeyMustBeAString
| ErrorCode::FloatKeyMustBeFinite
| ErrorCode::UnexpectedWhitespaceInKey
| ErrorCode::LoneLeadingSurrogateInHexEscape
| ErrorCode::TrailingComma
| ErrorCode::TrailingCharacters
Expand Down Expand Up @@ -264,6 +267,9 @@ pub(crate) enum ErrorCode {
/// Expected this character to start a JSON value.
ExpectedSomeValue,

/// Expected this character to be a `"`.
ExpectedDoubleQuote,

/// Invalid hex escape code.
InvalidEscape,

Expand All @@ -282,6 +288,12 @@ pub(crate) enum ErrorCode {
/// Object key is not a string.
KeyMustBeAString,

/// Object key is a non-finite float value.
FloatKeyMustBeFinite,

/// Unexpected whitespace in a numeric key.
UnexpectedWhitespaceInKey,

/// Lone leading surrogate in hex escape.
LoneLeadingSurrogateInHexEscape,

Expand Down Expand Up @@ -348,6 +360,7 @@ impl Display for ErrorCode {
ErrorCode::ExpectedObjectCommaOrEnd => f.write_str("expected `,` or `}`"),
ErrorCode::ExpectedSomeIdent => f.write_str("expected ident"),
ErrorCode::ExpectedSomeValue => f.write_str("expected value"),
ErrorCode::ExpectedDoubleQuote => f.write_str("expected `\"`"),
ErrorCode::InvalidEscape => f.write_str("invalid escape"),
ErrorCode::InvalidNumber => f.write_str("invalid number"),
ErrorCode::NumberOutOfRange => f.write_str("number out of range"),
Expand All @@ -356,6 +369,12 @@ impl Display for ErrorCode {
f.write_str("control character (\\u0000-\\u001F) found while parsing a string")
}
ErrorCode::KeyMustBeAString => f.write_str("key must be a string"),
ErrorCode::FloatKeyMustBeFinite => {
f.write_str("float key must be finite (got NaN or +/-inf)")
}
ErrorCode::UnexpectedWhitespaceInKey => {
f.write_str("unexpected whitespace in object key")
}
ErrorCode::LoneLeadingSurrogateInHexEscape => {
f.write_str("lone leading surrogate in hex escape")
}
Expand Down
46 changes: 42 additions & 4 deletions src/ser.rs
Expand Up @@ -789,6 +789,10 @@ fn key_must_be_a_string() -> Error {
Error::syntax(ErrorCode::KeyMustBeAString, 0, 0)
}

fn float_key_must_be_finite() -> Error {
Error::syntax(ErrorCode::FloatKeyMustBeFinite, 0, 0)
}

impl<'a, W, F> ser::Serializer for MapKeySerializer<'a, W, F>
where
W: io::Write,
Expand Down Expand Up @@ -1002,12 +1006,46 @@ where
.map_err(Error::io)
}

fn serialize_f32(self, _value: f32) -> Result<()> {
Err(key_must_be_a_string())
fn serialize_f32(self, value: f32) -> Result<()> {
if !value.is_finite() {
return Err(float_key_must_be_finite());
}

tri!(self
.ser
.formatter
.begin_string(&mut self.ser.writer)
.map_err(Error::io));
tri!(self
.ser
.formatter
.write_f32(&mut self.ser.writer, value)
.map_err(Error::io));
dtolnay marked this conversation as resolved.
Show resolved Hide resolved
self.ser
.formatter
.end_string(&mut self.ser.writer)
.map_err(Error::io)
}

fn serialize_f64(self, _value: f64) -> Result<()> {
Err(key_must_be_a_string())
fn serialize_f64(self, value: f64) -> Result<()> {
if !value.is_finite() {
return Err(float_key_must_be_finite());
}

tri!(self
.ser
.formatter
.begin_string(&mut self.ser.writer)
.map_err(Error::io));
tri!(self
.ser
.formatter
.write_f64(&mut self.ser.writer, value)
.map_err(Error::io));
self.ser
.formatter
.end_string(&mut self.ser.writer)
.map_err(Error::io)
}

fn serialize_char(self, value: char) -> Result<()> {
Expand Down
29 changes: 16 additions & 13 deletions src/value/de.rs
Expand Up @@ -1120,13 +1120,14 @@ struct MapKeyDeserializer<'de> {
key: Cow<'de, str>,
}

macro_rules! deserialize_integer_key {
macro_rules! deserialize_numeric_key {
($method:ident => $visit:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value, Error>
where
V: Visitor<'de>,
{
match (self.key.parse(), self.key) {
let parsed = crate::from_str(&self.key);
match (parsed, self.key) {
(Ok(integer), _) => visitor.$visit(integer),
(Err(_), Cow::Borrowed(s)) => visitor.visit_borrowed_str(s),
#[cfg(any(feature = "std", feature = "alloc"))]
Expand All @@ -1146,16 +1147,18 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> {
BorrowedCowStrDeserializer::new(self.key).deserialize_any(visitor)
}

deserialize_integer_key!(deserialize_i8 => visit_i8);
deserialize_integer_key!(deserialize_i16 => visit_i16);
deserialize_integer_key!(deserialize_i32 => visit_i32);
deserialize_integer_key!(deserialize_i64 => visit_i64);
deserialize_integer_key!(deserialize_i128 => visit_i128);
deserialize_integer_key!(deserialize_u8 => visit_u8);
deserialize_integer_key!(deserialize_u16 => visit_u16);
deserialize_integer_key!(deserialize_u32 => visit_u32);
deserialize_integer_key!(deserialize_u64 => visit_u64);
deserialize_integer_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_i8 => visit_i8);
deserialize_numeric_key!(deserialize_i16 => visit_i16);
deserialize_numeric_key!(deserialize_i32 => visit_i32);
deserialize_numeric_key!(deserialize_i64 => visit_i64);
deserialize_numeric_key!(deserialize_i128 => visit_i128);
deserialize_numeric_key!(deserialize_u8 => visit_u8);
deserialize_numeric_key!(deserialize_u16 => visit_u16);
deserialize_numeric_key!(deserialize_u32 => visit_u32);
deserialize_numeric_key!(deserialize_u64 => visit_u64);
deserialize_numeric_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_f32 => visit_f32);
deserialize_numeric_key!(deserialize_f64 => visit_f64);
dtolnay marked this conversation as resolved.
Show resolved Hide resolved

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Error>
Expand Down Expand Up @@ -1193,7 +1196,7 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> {
}

forward_to_deserialize_any! {
bool f32 f64 char str string bytes byte_buf unit unit_struct seq tuple
bool char str string bytes byte_buf unit unit_struct seq tuple
tuple_struct map struct identifier ignored_any
}
}
Expand Down
20 changes: 16 additions & 4 deletions src/value/ser.rs
Expand Up @@ -451,6 +451,10 @@ fn key_must_be_a_string() -> Error {
Error::syntax(ErrorCode::KeyMustBeAString, 0, 0)
}

fn float_key_must_be_finite() -> Error {
Error::syntax(ErrorCode::FloatKeyMustBeFinite, 0, 0)
}

impl serde::Serializer for MapKeySerializer {
type Ok = String;
type Error = Error;
Expand Down Expand Up @@ -517,12 +521,20 @@ impl serde::Serializer for MapKeySerializer {
Ok(value.to_string())
}

fn serialize_f32(self, _value: f32) -> Result<String> {
Err(key_must_be_a_string())
fn serialize_f32(self, value: f32) -> Result<String> {
if value.is_finite() {
Ok(ryu::Buffer::new().format_finite(value).to_owned())
} else {
Err(float_key_must_be_finite())
}
}

fn serialize_f64(self, _value: f64) -> Result<String> {
Err(key_must_be_a_string())
fn serialize_f64(self, value: f64) -> Result<String> {
if value.is_finite() {
Ok(ryu::Buffer::new().format_finite(value).to_owned())
} else {
Err(float_key_must_be_finite())
}
}

#[inline]
Expand Down