Skip to content
This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

Prevent too deep recursion #105

Merged
merged 1 commit into from Sep 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/de.rs
Expand Up @@ -79,6 +79,7 @@ struct Deserializer<'a> {
aliases: &'a BTreeMap<usize, usize>,
pos: &'a mut usize,
path: Path<'a>,
remaining_depth: u8,
}

impl<'a> Deserializer<'a> {
Expand Down Expand Up @@ -109,6 +110,7 @@ impl<'a> Deserializer<'a> {
aliases: self.aliases,
pos: pos,
path: Path::Alias { parent: &self.path },
remaining_depth: self.remaining_depth,
})
}
None => panic!("unresolved alias: {}", *pos),
Expand Down Expand Up @@ -161,11 +163,11 @@ impl<'a> Deserializer<'a> {
where
V: Visitor<'de>,
{
let (value, len) = {
let mut seq = SeqAccess { de: self, len: 0 };
let (value, len) = self.recursion_check(|de| {
let mut seq = SeqAccess { de: de, len: 0 };
let value = visitor.visit_seq(&mut seq)?;
(value, seq.len)
};
Ok((value, seq.len))
})?;
self.end_sequence(len)?;
Ok(value)
}
Expand All @@ -174,15 +176,15 @@ impl<'a> Deserializer<'a> {
where
V: Visitor<'de>,
{
let (value, len) = {
let (value, len) = self.recursion_check(|de| {
let mut map = MapAccess {
de: &mut *self,
de: de,
len: 0,
key: None,
};
let value = visitor.visit_map(&mut map)?;
(value, map.len)
};
Ok((value, map.len))
})?;
self.end_mapping(len)?;
Ok(value)
}
Expand Down Expand Up @@ -238,6 +240,16 @@ impl<'a> Deserializer<'a> {
Err(de::Error::invalid_length(total, &ExpectedMap(len)))
}
}

fn recursion_check<F: FnOnce(&mut Self) -> Result<T>, T>(&mut self, f: F) -> Result<T> {
let previous_depth = self.remaining_depth;
self.remaining_depth = previous_depth
.checked_sub(1)
.ok_or_else(Error::recursion_limit_exceeded)?;
let result = f(self);
self.remaining_depth = previous_depth;
result
}
}

fn visit_scalar<'de, V>(
Expand Down Expand Up @@ -303,6 +315,7 @@ impl<'de, 'a, 'r> de::SeqAccess<'de> for SeqAccess<'a, 'r> {
parent: &self.de.path,
index: self.len,
},
remaining_depth: self.de.remaining_depth,
};
self.len += 1;
seed.deserialize(&mut element_de).map(Some)
Expand Down Expand Up @@ -357,6 +370,7 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> {
parent: &self.de.path,
}
},
remaining_depth: self.de.remaining_depth,
};
seed.deserialize(&mut value_de)
}
Expand Down Expand Up @@ -409,6 +423,7 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> {
parent: &self.de.path,
key: variant,
},
remaining_depth: self.de.remaining_depth,
};
Ok((ret, variant_visitor))
}
Expand Down Expand Up @@ -949,6 +964,7 @@ where
aliases: &loader.aliases,
pos: &mut pos,
path: Path::Root,
remaining_depth: 128,
})?;
if pos == loader.events.len() {
Ok(t)
Expand Down
12 changes: 12 additions & 0 deletions src/error.rs
Expand Up @@ -41,6 +41,7 @@ pub enum ErrorImpl {

EndOfStream,
MoreThanOneDocument,
RecursionLimitExceeded,
}

#[derive(Debug)]
Expand Down Expand Up @@ -157,6 +158,12 @@ impl Error {
Error(Box::new(ErrorImpl::FromUtf8(err)))
}

// Not public API. Should be pub(crate).
#[doc(hidden)]
pub fn recursion_limit_exceeded() -> Error {
Error(Box::new(ErrorImpl::RecursionLimitExceeded))
}

// Not public API. Should be pub(crate).
#[doc(hidden)]
pub fn fix_marker(mut self, marker: Marker, path: Path) -> Self {
Expand All @@ -183,6 +190,7 @@ impl error::Error for Error {
ErrorImpl::MoreThanOneDocument => {
"deserializing from YAML containing more than one document is not supported"
}
ErrorImpl::RecursionLimitExceeded => "recursion limit exceeded",
}
}

Expand Down Expand Up @@ -218,6 +226,7 @@ impl Display for Error {
ErrorImpl::MoreThanOneDocument => f.write_str(
"deserializing from YAML containing more than one document is not supported",
),
ErrorImpl::RecursionLimitExceeded => f.write_str("recursion limit exceeded"),
}
}
}
Expand All @@ -241,6 +250,9 @@ impl Debug for Error {
}
ErrorImpl::EndOfStream => formatter.debug_tuple("EndOfStream").finish(),
ErrorImpl::MoreThanOneDocument => formatter.debug_tuple("MoreThanOneDocument").finish(),
ErrorImpl::RecursionLimitExceeded => {
formatter.debug_tuple("RecursionLimitExceeded").finish()
}
}
}
}
Expand Down
48 changes: 48 additions & 0 deletions tests/test_error.rs
Expand Up @@ -257,3 +257,51 @@ fn test_invalid_scalar_type() {
let expected = "x: invalid type: unit value, expected an array of length 1 at line 2 column 1";
test_error::<S>(yaml, expected);
}

#[test]
fn test_infinite_recursion_objects() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "&a {x: *a}";
let expected = "recursion limit exceeded";
test_error::<S>(yaml, expected);
}

#[test]
fn test_infinite_recursion_arrays() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "&a [*a]";
let expected = "recursion limit exceeded";
test_error::<S>(yaml, expected);
}

#[test]
fn test_finite_recursion_objects() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "{x:".repeat(1_000) + &"}".repeat(1_000);
let expected = "recursion limit exceeded";
test_error::<i32>(&yaml, expected);
}

#[test]
fn test_finite_recursion_arrays() {
#[derive(Deserialize, Debug)]
struct S {
x: Option<Box<S>>,
}

let yaml = "[".repeat(1_000) + &"]".repeat(1_000);
let expected = "recursion limit exceeded";
test_error::<S>(&yaml, expected);
}