Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve performance of recursion guard #1156

Merged
merged 7 commits into from Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
144 changes: 118 additions & 26 deletions src/recursion_guard.rs
@@ -1,4 +1,5 @@
use ahash::AHashSet;
use std::mem::MaybeUninit;

type RecursionKey = (
// Identifier for the input object, e.g. the id() of a Python dict
Expand All @@ -13,56 +14,147 @@ type RecursionKey = (
/// It's used in `validators/definition` to detect when a reference is reused within itself.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
ids: Option<AHashSet<RecursionKey>>,
ids: RecursionStack,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
depth: u8,
}

// A hard limit to avoid stack overflows when rampant recursion occurs
pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
// wasm and windows PyPy have very limited stack sizes
50
49
} else if cfg!(any(PyPy, windows)) {
// PyPy and Windows in general have more restricted stack space
100
99
} else {
255
};

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
match self.ids {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Some(ref mut set) => !set.insert((obj_id, node_id)),
None => {
let mut set: AHashSet<RecursionKey> = AHashSet::with_capacity(10);
set.insert((obj_id, node_id));
self.ids = Some(set);
false
}
}
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
self.ids.insert((obj_id, node_id))
}

// see #143 this is used as a backup in case the identity check recursion guard fails
#[must_use]
#[cfg(any(target_family = "wasm", windows, PyPy))]
pub fn incr_depth(&mut self) -> bool {
self.depth += 1;
self.depth >= RECURSION_GUARD_LIMIT
// use saturating_add as it's faster (since there's no error path)
// and the RECURSION_GUARD_LIMIT check will be hit before it overflows
debug_assert!(RECURSION_GUARD_LIMIT < 255);
self.depth = self.depth.saturating_add(1);
self.depth > RECURSION_GUARD_LIMIT
}

#[must_use]
#[cfg(not(any(target_family = "wasm", windows, PyPy)))]
pub fn incr_depth(&mut self) -> bool {
debug_assert_eq!(RECURSION_GUARD_LIMIT, 255);
// use checked_add to check if we've hit the limit
if let Some(depth) = self.depth.checked_add(1) {
self.depth = depth;
false
} else {
true
}
}

pub fn decr_depth(&mut self) {
self.depth -= 1;
// for the same reason as incr_depth, use saturating_sub
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize) {
match self.ids {
Some(ref mut set) => {
set.remove(&(obj_id, node_id));
self.ids.remove(&(obj_id, node_id));
}
}

// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
const ARRAY_SIZE: usize = 16;

#[derive(Debug, Clone)]
enum RecursionStack {
Array {
data: [MaybeUninit<RecursionKey>; ARRAY_SIZE],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't like the MaybeUninit here we could probably use Option with only a small perf cost.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaybeUninit looks graet.

len: usize,
},
Set(AHashSet<RecursionKey>),
}

impl Default for RecursionStack {
fn default() -> Self {
Self::Array {
data: std::array::from_fn(|_| MaybeUninit::uninit()),
len: 0,
}
}
}

impl RecursionStack {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, v: RecursionKey) -> bool {
match self {
Self::Array { data, len } => {
if *len < ARRAY_SIZE {
for value in data.iter().take(*len) {
// Safety: reading values within bounds
if unsafe { value.assume_init() } == v {
return false;
}
}

data[*len].write(v);
*len += 1;
true
} else {
let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in data.iter() {
// Safety: the array is fully initialized
set.insert(unsafe { existing.assume_init() });
}
let inserted = set.insert(v);
*self = Self::Set(set);
inserted
}
}
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Self::Set(set) => set.insert(v),
}
}

pub fn remove(&mut self, v: &RecursionKey) {
match self {
Self::Array { data, len } => {
*len = len.checked_sub(1).expect("remove from empty recursion guard");
// Safety: this is reading what was the back of the initialized array
let removed = unsafe { data.get_unchecked_mut(*len) };
assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert");
// this should compile away to a noop
unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) }
}
Self::Set(set) => {
set.remove(v);
}
}
}
}

impl Drop for RecursionStack {
fn drop(&mut self) {
// This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed
// desirable to leave this in for safety in case that should change in the future
if let Self::Array { data, len } = self {
for value in data.iter_mut().take(*len) {
// Safety: reading values within bounds
unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) };
}
None => unreachable!(),
};
}
}
}
14 changes: 7 additions & 7 deletions src/serializers/extra.rs
Expand Up @@ -346,17 +346,17 @@ pub struct SerRecursionGuard {

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if guard.contains_or_insert(id, def_ref_id) {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
if guard.insert(id, def_ref_id) {
if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok(id)
}
} else {
Ok(id)
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/validators/definitions.rs
Expand Up @@ -76,17 +76,17 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
}
} else {
validator.validate(py, input, state)
Expand All @@ -105,17 +105,17 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
}
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
Expand Down
2 changes: 1 addition & 1 deletion tests/serializers/test_any.py
Expand Up @@ -371,7 +371,7 @@ def fallback_func(obj):
f = FoobarCount(0)
v = 0
# when recursion is detected and we're in mode python, we just return the value
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
expected_visits = pydantic_core._pydantic_core._recursion_limit
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')

with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
Expand Down