Skip to content

Commit

Permalink
improve performance of recursion guard (#1156)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <mail@davidhewitt.dev>
Co-authored-by: David Hewitt <david.hewitt@pydantic.dev>
  • Loading branch information
3 people committed Jan 15, 2024
1 parent d7cf72d commit e1cb0eb
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 42 deletions.
144 changes: 118 additions & 26 deletions src/recursion_guard.rs
@@ -1,4 +1,5 @@
use ahash::AHashSet;
use std::mem::MaybeUninit;

type RecursionKey = (
// Identifier for the input object, e.g. the id() of a Python dict
Expand All @@ -13,56 +14,147 @@ type RecursionKey = (
/// It's used in `validators/definition` to detect when a reference is reused within itself.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
ids: Option<AHashSet<RecursionKey>>,
ids: RecursionStack,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
depth: u8,
}

// A hard limit to avoid stack overflows when rampant recursion occurs
pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
// wasm and windows PyPy have very limited stack sizes
50
49
} else if cfg!(any(PyPy, windows)) {
// PyPy and Windows in general have more restricted stack space
100
99
} else {
255
};

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
match self.ids {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Some(ref mut set) => !set.insert((obj_id, node_id)),
None => {
let mut set: AHashSet<RecursionKey> = AHashSet::with_capacity(10);
set.insert((obj_id, node_id));
self.ids = Some(set);
false
}
}
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
self.ids.insert((obj_id, node_id))
}

// see #143 this is used as a backup in case the identity check recursion guard fails
#[must_use]
#[cfg(any(target_family = "wasm", windows, PyPy))]
pub fn incr_depth(&mut self) -> bool {
self.depth += 1;
self.depth >= RECURSION_GUARD_LIMIT
// use saturating_add as it's faster (since there's no error path)
// and the RECURSION_GUARD_LIMIT check will be hit before it overflows
debug_assert!(RECURSION_GUARD_LIMIT < 255);
self.depth = self.depth.saturating_add(1);
self.depth > RECURSION_GUARD_LIMIT
}

#[must_use]
#[cfg(not(any(target_family = "wasm", windows, PyPy)))]
pub fn incr_depth(&mut self) -> bool {
debug_assert_eq!(RECURSION_GUARD_LIMIT, 255);
// use checked_add to check if we've hit the limit
if let Some(depth) = self.depth.checked_add(1) {
self.depth = depth;
false
} else {
true
}
}

pub fn decr_depth(&mut self) {
self.depth -= 1;
// for the same reason as incr_depth, use saturating_sub
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize) {
match self.ids {
Some(ref mut set) => {
set.remove(&(obj_id, node_id));
self.ids.remove(&(obj_id, node_id));
}
}

// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
const ARRAY_SIZE: usize = 16;

#[derive(Debug, Clone)]
enum RecursionStack {
Array {
data: [MaybeUninit<RecursionKey>; ARRAY_SIZE],
len: usize,
},
Set(AHashSet<RecursionKey>),
}

impl Default for RecursionStack {
fn default() -> Self {
Self::Array {
data: std::array::from_fn(|_| MaybeUninit::uninit()),
len: 0,
}
}
}

impl RecursionStack {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, v: RecursionKey) -> bool {
match self {
Self::Array { data, len } => {
if *len < ARRAY_SIZE {
for value in data.iter().take(*len) {
// Safety: reading values within bounds
if unsafe { value.assume_init() } == v {
return false;
}
}

data[*len].write(v);
*len += 1;
true
} else {
let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in data.iter() {
// Safety: the array is fully initialized
set.insert(unsafe { existing.assume_init() });
}
let inserted = set.insert(v);
*self = Self::Set(set);
inserted
}
}
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Self::Set(set) => set.insert(v),
}
}

pub fn remove(&mut self, v: &RecursionKey) {
match self {
Self::Array { data, len } => {
*len = len.checked_sub(1).expect("remove from empty recursion guard");
// Safety: this is reading what was the back of the initialized array
let removed = unsafe { data.get_unchecked_mut(*len) };
assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert");
// this should compile away to a noop
unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) }
}
Self::Set(set) => {
set.remove(v);
}
}
}
}

impl Drop for RecursionStack {
fn drop(&mut self) {
// This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed
// desirable to leave this in for safety in case that should change in the future
if let Self::Array { data, len } = self {
for value in data.iter_mut().take(*len) {
// Safety: reading values within bounds
unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) };
}
None => unreachable!(),
};
}
}
}
14 changes: 7 additions & 7 deletions src/serializers/extra.rs
Expand Up @@ -346,17 +346,17 @@ pub struct SerRecursionGuard {

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if guard.contains_or_insert(id, def_ref_id) {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
if guard.insert(id, def_ref_id) {
if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok(id)
}
} else {
Ok(id)
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/validators/definitions.rs
Expand Up @@ -76,17 +76,17 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
}
} else {
validator.validate(py, input, state)
Expand All @@ -105,17 +105,17 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
}
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
Expand Down
2 changes: 1 addition & 1 deletion tests/serializers/test_any.py
Expand Up @@ -371,7 +371,7 @@ def fallback_func(obj):
f = FoobarCount(0)
v = 0
# when recursion is detected and we're in mode python, we just return the value
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
expected_visits = pydantic_core._pydantic_core._recursion_limit
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')

with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
Expand Down

0 comments on commit e1cb0eb

Please sign in to comment.