Skip to content

Commit

Permalink
Merge pull request #1027 from overdrivenpotato/float-key
Browse files Browse the repository at this point in the history
Allow `f32` and `f64` map keys
  • Loading branch information
dtolnay committed Jul 11, 2023
2 parents 4514365 + a4e2719 commit b8d8d10
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 46 deletions.
48 changes: 28 additions & 20 deletions src/de.rs
Expand Up @@ -2118,20 +2118,26 @@ struct MapKey<'a, R: 'a> {
de: &'a mut Deserializer<R>,
}

macro_rules! deserialize_integer_key {
($method:ident => $visit:ident) => {
macro_rules! deserialize_numeric_key {
($method:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.eat_char();
self.de.scratch.clear();
let string = tri!(self.de.read.parse_str(&mut self.de.scratch));
match (string.parse(), string) {
(Ok(integer), _) => visitor.$visit(integer),
(Err(_), Reference::Borrowed(s)) => visitor.visit_borrowed_str(s),
(Err(_), Reference::Copied(s)) => visitor.visit_str(s),

if let Some(b' ') | Some(b'\n') | Some(b'\r') | Some(b'\t') = tri!(self.de.peek()) {
return Err(self.de.peek_error(ErrorCode::UnexpectedWhitespaceInKey));
}

let value = tri!(self.de.$method(visitor));

match self.de.peek()? {
Some(b'"') => self.de.eat_char(),
_ => return Err(self.de.peek_error(ErrorCode::ExpectedDoubleQuote)),
}

Ok(value)
}
};
}
Expand All @@ -2155,16 +2161,18 @@ where
}
}

deserialize_integer_key!(deserialize_i8 => visit_i8);
deserialize_integer_key!(deserialize_i16 => visit_i16);
deserialize_integer_key!(deserialize_i32 => visit_i32);
deserialize_integer_key!(deserialize_i64 => visit_i64);
deserialize_integer_key!(deserialize_i128 => visit_i128);
deserialize_integer_key!(deserialize_u8 => visit_u8);
deserialize_integer_key!(deserialize_u16 => visit_u16);
deserialize_integer_key!(deserialize_u32 => visit_u32);
deserialize_integer_key!(deserialize_u64 => visit_u64);
deserialize_integer_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_i8);
deserialize_numeric_key!(deserialize_i16);
deserialize_numeric_key!(deserialize_i32);
deserialize_numeric_key!(deserialize_i64);
deserialize_numeric_key!(deserialize_i128);
deserialize_numeric_key!(deserialize_u8);
deserialize_numeric_key!(deserialize_u16);
deserialize_numeric_key!(deserialize_u32);
deserialize_numeric_key!(deserialize_u64);
deserialize_numeric_key!(deserialize_u128);
deserialize_numeric_key!(deserialize_f32);
deserialize_numeric_key!(deserialize_f64);

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -2221,8 +2229,8 @@ where
}

forward_to_deserialize_any! {
bool f32 f64 char str string unit unit_struct seq tuple tuple_struct map
struct identifier ignored_any
bool char str string unit unit_struct seq tuple tuple_struct map struct
identifier ignored_any
}
}

Expand Down
19 changes: 19 additions & 0 deletions src/error.rs
Expand Up @@ -64,12 +64,15 @@ impl Error {
| ErrorCode::ExpectedObjectCommaOrEnd
| ErrorCode::ExpectedSomeIdent
| ErrorCode::ExpectedSomeValue
| ErrorCode::ExpectedDoubleQuote
| ErrorCode::InvalidEscape
| ErrorCode::InvalidNumber
| ErrorCode::NumberOutOfRange
| ErrorCode::InvalidUnicodeCodePoint
| ErrorCode::ControlCharacterWhileParsingString
| ErrorCode::KeyMustBeAString
| ErrorCode::FloatKeyMustBeFinite
| ErrorCode::UnexpectedWhitespaceInKey
| ErrorCode::LoneLeadingSurrogateInHexEscape
| ErrorCode::TrailingComma
| ErrorCode::TrailingCharacters
Expand Down Expand Up @@ -264,6 +267,9 @@ pub(crate) enum ErrorCode {
/// Expected this character to start a JSON value.
ExpectedSomeValue,

/// Expected this character to be a `"`.
ExpectedDoubleQuote,

/// Invalid hex escape code.
InvalidEscape,

Expand All @@ -282,6 +288,12 @@ pub(crate) enum ErrorCode {
/// Object key is not a string.
KeyMustBeAString,

/// Object key is a non-finite float value.
FloatKeyMustBeFinite,

/// Unexpected whitespace in a numeric key.
UnexpectedWhitespaceInKey,

/// Lone leading surrogate in hex escape.
LoneLeadingSurrogateInHexEscape,

Expand Down Expand Up @@ -348,6 +360,7 @@ impl Display for ErrorCode {
ErrorCode::ExpectedObjectCommaOrEnd => f.write_str("expected `,` or `}`"),
ErrorCode::ExpectedSomeIdent => f.write_str("expected ident"),
ErrorCode::ExpectedSomeValue => f.write_str("expected value"),
ErrorCode::ExpectedDoubleQuote => f.write_str("expected `\"`"),
ErrorCode::InvalidEscape => f.write_str("invalid escape"),
ErrorCode::InvalidNumber => f.write_str("invalid number"),
ErrorCode::NumberOutOfRange => f.write_str("number out of range"),
Expand All @@ -356,6 +369,12 @@ impl Display for ErrorCode {
f.write_str("control character (\\u0000-\\u001F) found while parsing a string")
}
ErrorCode::KeyMustBeAString => f.write_str("key must be a string"),
ErrorCode::FloatKeyMustBeFinite => {
f.write_str("float key must be finite (got NaN or +/-inf)")
}
ErrorCode::UnexpectedWhitespaceInKey => {
f.write_str("unexpected whitespace in object key")
}
ErrorCode::LoneLeadingSurrogateInHexEscape => {
f.write_str("lone leading surrogate in hex escape")
}
Expand Down
46 changes: 42 additions & 4 deletions src/ser.rs
Expand Up @@ -789,6 +789,10 @@ fn key_must_be_a_string() -> Error {
Error::syntax(ErrorCode::KeyMustBeAString, 0, 0)
}

fn float_key_must_be_finite() -> Error {
Error::syntax(ErrorCode::FloatKeyMustBeFinite, 0, 0)
}

impl<'a, W, F> ser::Serializer for MapKeySerializer<'a, W, F>
where
W: io::Write,
Expand Down Expand Up @@ -1002,12 +1006,46 @@ where
.map_err(Error::io)
}

fn serialize_f32(self, _value: f32) -> Result<()> {
Err(key_must_be_a_string())
fn serialize_f32(self, value: f32) -> Result<()> {
if !value.is_finite() {
return Err(float_key_must_be_finite());
}

tri!(self
.ser
.formatter
.begin_string(&mut self.ser.writer)
.map_err(Error::io));
tri!(self
.ser
.formatter
.write_f32(&mut self.ser.writer, value)
.map_err(Error::io));
self.ser
.formatter
.end_string(&mut self.ser.writer)
.map_err(Error::io)
}

fn serialize_f64(self, _value: f64) -> Result<()> {
Err(key_must_be_a_string())
fn serialize_f64(self, value: f64) -> Result<()> {
if !value.is_finite() {
return Err(float_key_must_be_finite());
}

tri!(self
.ser
.formatter
.begin_string(&mut self.ser.writer)
.map_err(Error::io));
tri!(self
.ser
.formatter
.write_f64(&mut self.ser.writer, value)
.map_err(Error::io));
self.ser
.formatter
.end_string(&mut self.ser.writer)
.map_err(Error::io)
}

fn serialize_char(self, value: char) -> Result<()> {
Expand Down
29 changes: 16 additions & 13 deletions src/value/de.rs
Expand Up @@ -1120,13 +1120,14 @@ struct MapKeyDeserializer<'de> {
key: Cow<'de, str>,
}

macro_rules! deserialize_integer_key {
macro_rules! deserialize_numeric_key {
($method:ident => $visit:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value, Error>
where
V: Visitor<'de>,
{
match (self.key.parse(), self.key) {
let parsed = crate::from_str(&self.key);
match (parsed, self.key) {
(Ok(integer), _) => visitor.$visit(integer),
(Err(_), Cow::Borrowed(s)) => visitor.visit_borrowed_str(s),
#[cfg(any(feature = "std", feature = "alloc"))]
Expand All @@ -1146,16 +1147,18 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> {
BorrowedCowStrDeserializer::new(self.key).deserialize_any(visitor)
}

deserialize_integer_key!(deserialize_i8 => visit_i8);
deserialize_integer_key!(deserialize_i16 => visit_i16);
deserialize_integer_key!(deserialize_i32 => visit_i32);
deserialize_integer_key!(deserialize_i64 => visit_i64);
deserialize_integer_key!(deserialize_i128 => visit_i128);
deserialize_integer_key!(deserialize_u8 => visit_u8);
deserialize_integer_key!(deserialize_u16 => visit_u16);
deserialize_integer_key!(deserialize_u32 => visit_u32);
deserialize_integer_key!(deserialize_u64 => visit_u64);
deserialize_integer_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_i8 => visit_i8);
deserialize_numeric_key!(deserialize_i16 => visit_i16);
deserialize_numeric_key!(deserialize_i32 => visit_i32);
deserialize_numeric_key!(deserialize_i64 => visit_i64);
deserialize_numeric_key!(deserialize_i128 => visit_i128);
deserialize_numeric_key!(deserialize_u8 => visit_u8);
deserialize_numeric_key!(deserialize_u16 => visit_u16);
deserialize_numeric_key!(deserialize_u32 => visit_u32);
deserialize_numeric_key!(deserialize_u64 => visit_u64);
deserialize_numeric_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_f32 => visit_f32);
deserialize_numeric_key!(deserialize_f64 => visit_f64);

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Error>
Expand Down Expand Up @@ -1193,7 +1196,7 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> {
}

forward_to_deserialize_any! {
bool f32 f64 char str string bytes byte_buf unit unit_struct seq tuple
bool char str string bytes byte_buf unit unit_struct seq tuple
tuple_struct map struct identifier ignored_any
}
}
Expand Down
20 changes: 16 additions & 4 deletions src/value/ser.rs
Expand Up @@ -449,6 +449,10 @@ fn key_must_be_a_string() -> Error {
Error::syntax(ErrorCode::KeyMustBeAString, 0, 0)
}

fn float_key_must_be_finite() -> Error {
Error::syntax(ErrorCode::FloatKeyMustBeFinite, 0, 0)
}

impl serde::Serializer for MapKeySerializer {
type Ok = String;
type Error = Error;
Expand Down Expand Up @@ -515,12 +519,20 @@ impl serde::Serializer for MapKeySerializer {
Ok(value.to_string())
}

fn serialize_f32(self, _value: f32) -> Result<String> {
Err(key_must_be_a_string())
fn serialize_f32(self, value: f32) -> Result<String> {
if value.is_finite() {
Ok(ryu::Buffer::new().format_finite(value).to_owned())
} else {
Err(float_key_must_be_finite())
}
}

fn serialize_f64(self, _value: f64) -> Result<String> {
Err(key_must_be_a_string())
fn serialize_f64(self, value: f64) -> Result<String> {
if value.is_finite() {
Ok(ryu::Buffer::new().format_finite(value).to_owned())
} else {
Err(float_key_must_be_finite())
}
}

#[inline]
Expand Down

0 comments on commit b8d8d10

Please sign in to comment.