Skip to content

Commit

Permalink
tweak depth limit logic
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Jan 14, 2024
1 parent f98b085 commit 5d14e70
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
32 changes: 24 additions & 8 deletions src/recursion_guard.rs
Expand Up @@ -17,16 +17,16 @@ pub struct RecursionGuard {
ids: SmallContainer<RecursionKey>,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
depth: u8,
}

// A hard limit to avoid stack overflows when rampant recursion occurs
pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
// wasm and windows PyPy have very limited stack sizes
50
49
} else if cfg!(any(PyPy, windows)) {
// PyPy and Windows in general have more restricted stack space
100
99
} else {
255
};
Expand All @@ -41,11 +41,26 @@ impl RecursionGuard {

// see #143 this is used as a backup in case the identity check recursion guard fails
#[must_use]
#[cfg(any(target_family = "wasm", windows, PyPy))]
pub fn incr_depth(&mut self) -> bool {
// use saturating_add as it's faster (since there's no error path)
// and the RECURSION_GUARD_LIMIT check will be hit before it overflows
debug_assert!(RECURSION_GUARD_LIMIT < 255);
self.depth = self.depth.saturating_add(1);
self.depth >= RECURSION_GUARD_LIMIT
self.depth > RECURSION_GUARD_LIMIT
}

#[must_use]
#[cfg(not(any(target_family = "wasm", windows, PyPy)))]
pub fn incr_depth(&mut self) -> bool {
debug_assert_eq!(RECURSION_GUARD_LIMIT, 255);
// use checked_add to check if we've hit the limit
if let Some(depth) = self.depth.checked_add(1) {
self.depth = depth;
false
} else {
true
}
}

pub fn decr_depth(&mut self) {
Expand Down Expand Up @@ -90,9 +105,9 @@ impl<T: Eq + Hash + Clone> SmallContainer<T> {
first_slot = first_slot.or(Some(index));
}
}
if let Some(first_slot) = first_slot {
array[first_slot] = Some(v);
Some(first_slot)
if let Some(index) = first_slot {
array[index] = Some(v);
first_slot
} else {
let mut set: AHashSet<T> = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in array.iter_mut() {
Expand All @@ -108,6 +123,7 @@ impl<T: Eq + Hash + Clone> SmallContainer<T> {
// "If the set did not have this value present, `true` is returned."
Self::Set(set) => {
if set.insert(v) {
// again id doesn't matter here as we'll be removing from a set
Some(0)
} else {
None
Expand Down
2 changes: 1 addition & 1 deletion tests/serializers/test_any.py
Expand Up @@ -371,7 +371,7 @@ def fallback_func(obj):
f = FoobarCount(0)
v = 0
# when recursion is detected and we're in mode python, we just return the value
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
expected_visits = pydantic_core._pydantic_core._recursion_limit
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')

with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
Expand Down

0 comments on commit 5d14e70

Please sign in to comment.