Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve performance of recursion guard #1156

Merged
merged 7 commits into from Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Expand Up @@ -57,6 +57,7 @@ extension-module = ["pyo3/extension-module"]
lto = "fat"
codegen-units = 1
strip = true
#debug = true
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved

[profile.bench]
debug = true
Expand Down
129 changes: 102 additions & 27 deletions src/recursion_guard.rs
@@ -1,4 +1,5 @@
use ahash::AHashSet;
use std::hash::Hash;

type RecursionKey = (
// Identifier for the input object, e.g. the id() of a Python dict
Expand All @@ -13,56 +14,130 @@ type RecursionKey = (
/// It's used in `validators/definition` to detect when a reference is reused within itself.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
ids: Option<AHashSet<RecursionKey>>,
ids: SmallContainer<RecursionKey>,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
depth: u8,
}

// A hard limit to avoid stack overflows when rampant recursion occurs
pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
// wasm and windows PyPy have very limited stack sizes
50
49
} else if cfg!(any(PyPy, windows)) {
// PyPy and Windows in general have more restricted stack space
100
99
} else {
255
};

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
match self.ids {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Some(ref mut set) => !set.insert((obj_id, node_id)),
None => {
let mut set: AHashSet<RecursionKey> = AHashSet::with_capacity(10);
set.insert((obj_id, node_id));
self.ids = Some(set);
false
}
}
// insert a new value
// * return `None` if the array/set already had it in it
// * return `Some(index)` if the array didn't have it in it and it was inserted
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> Option<usize> {
self.ids.contains_or_insert((obj_id, node_id))
}

// see #143 this is used as a backup in case the identity check recursion guard fails
#[must_use]
#[cfg(any(target_family = "wasm", windows, PyPy))]
pub fn incr_depth(&mut self) -> bool {
// use saturating_add as it's faster (since there's no error path)
// and the RECURSION_GUARD_LIMIT check will be hit before it overflows
debug_assert!(RECURSION_GUARD_LIMIT < 255);
self.depth = self.depth.saturating_add(1);
self.depth > RECURSION_GUARD_LIMIT
}

#[must_use]
#[cfg(not(any(target_family = "wasm", windows, PyPy)))]
pub fn incr_depth(&mut self) -> bool {
self.depth += 1;
self.depth >= RECURSION_GUARD_LIMIT
debug_assert_eq!(RECURSION_GUARD_LIMIT, 255);
// use checked_add to check if we've hit the limit
if let Some(depth) = self.depth.checked_add(1) {
self.depth = depth;
false
} else {
true
}
}

pub fn decr_depth(&mut self) {
self.depth -= 1;
// for the same reason as incr_depth, use saturating_sub
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize, index: usize) {
self.ids.remove(&(obj_id, node_id), index);
}
}

// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
const ARRAY_SIZE: usize = 16;

/// Storage for a small collection of values with two representations.
///
/// Values start out in a fixed-size array of `ARRAY_SIZE` optional slots; when an
/// insert finds no free slot, the contents are moved into a hash set instead.
#[derive(Debug, Clone)]
enum SmallContainer<T> {
// Fixed-size storage: `Some(value)` marks an occupied slot, `None` a free one.
Array([Option<T>; ARRAY_SIZE]),
// Hash-set storage used once the array has no free slots left.
Set(AHashSet<T>),
}

impl<T: Copy> Default for SmallContainer<T> {
fn default() -> Self {
Self::Array([None; ARRAY_SIZE])
}
}

impl<T: Eq + Hash + Clone> SmallContainer<T> {
// insert a new value
// * return `None` if the array/set already had it in it
// * return `Some(index)` if the array didn't have it in it and it was inserted
pub fn contains_or_insert(&mut self, v: T) -> Option<usize> {
match self {
Self::Array(array) => {
for (index, op_value) in array.iter_mut().enumerate() {
if let Some(existing) = op_value {
if existing == &v {
return None;
}
} else {
*op_value = Some(v);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is safe, because there's a case where an earlier item has been removed from the array but later values are still set.

E.g.

array = [123, 456, None, 789]

If you're looking up 789 it wouldn't be found in this case.

I had this before, but changed it to be safer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if we can be sure that new items are always added to the array (and removed) inside parent calls, we can be sure that there will never be `None` gaps in the array.

@davidhewitt if you're confident of that, we can go ahead with this change. Worst case, we just fall back on the depth guard.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed another commit which should work more like a stack, and we'll panic if the code is wrong.

return Some(index);
}
}

// No array slots exist; convert to set
let mut set: AHashSet<T> = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in array.iter_mut() {
set.insert(existing.take().unwrap());
}
set.insert(v);
*self = Self::Set(set);
// id doesn't matter here as we'll be removing from a set
Some(0)
}
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Self::Set(set) => {
if set.insert(v) {
// again id doesn't matter here as we'll be removing from a set
Some(0)
} else {
None
}
}
}
}

pub fn remove(&mut self, obj_id: usize, node_id: usize) {
match self.ids {
Some(ref mut set) => {
set.remove(&(obj_id, node_id));
pub fn remove(&mut self, v: &T, index: usize) {
match self {
Self::Array(array) => {
debug_assert!(array[index].as_ref() == Some(v), "remove did not match insert");
array[index] = None;
}
Self::Set(set) => {
set.remove(v);
}
None => unreachable!(),
};
}
}
}
20 changes: 10 additions & 10 deletions src/serializers/extra.rs
Expand Up @@ -345,24 +345,24 @@ pub struct SerRecursionGuard {
}

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<(usize, usize)> {
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if guard.contains_or_insert(id, def_ref_id) {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
if let Some(insert_index) = guard.contains_or_insert(id, def_ref_id) {
if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok((id, insert_index))
}
} else {
Ok(id)
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
}
}

pub fn pop(&self, id: usize, def_ref_id: usize) {
pub fn pop(&self, id: usize, def_ref_id: usize, insert_index: usize) {
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id);
guard.remove(id, def_ref_id, insert_index);
}
}
14 changes: 7 additions & 7 deletions src/serializers/infer.rs
Expand Up @@ -45,7 +45,7 @@ pub(crate) fn infer_to_python_known(
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
Ok(id) => id,
Err(e) => {
return match extra.mode {
Expand Down Expand Up @@ -226,7 +226,7 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
return next_result;
} else if extra.serialize_unknown {
serialize_unknown(value).into_py(py)
Expand Down Expand Up @@ -284,15 +284,15 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
return next_result;
}
value.into_py(py)
}
_ => value.into_py(py),
},
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
Ok(value)
}

Expand Down Expand Up @@ -351,7 +351,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
Ok(v) => v,
Err(e) => {
return if extra.serialize_unknown {
Expand Down Expand Up @@ -534,7 +534,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
return next_result;
} else if extra.serialize_unknown {
serializer.serialize_str(&serialize_unknown(value))
Expand All @@ -548,7 +548,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
}
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
ser_result
}

Expand Down
8 changes: 4 additions & 4 deletions src/serializers/type_serializers/definitions.rs
Expand Up @@ -70,9 +70,9 @@ impl TypeSerializer for DefinitionRefSerializer {
) -> PyResult<PyObject> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let (value_id, insert_index) = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
extra.rec_guard.pop(value_id, self.definition.id(), insert_index);
r
})
}
Expand All @@ -91,12 +91,12 @@ impl TypeSerializer for DefinitionRefSerializer {
) -> Result<S::Ok, S::Error> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra
let (value_id, insert_index) = extra
.rec_guard
.add(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
extra.rec_guard.pop(value_id, self.definition.id(), insert_index);
r
})
}
Expand Down
20 changes: 10 additions & 10 deletions src/validators/definitions.rs
Expand Up @@ -76,17 +76,17 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.remove(id, self.definition.id(), insert_index);
state.recursion_guard.decr_depth();
output
} else {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
}
} else {
validator.validate(py, input, state)
Expand All @@ -105,17 +105,17 @@ impl Validator for DefinitionRefValidator {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.remove(id, self.definition.id(), insert_index);
state.recursion_guard.decr_depth();
output
} else {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
}
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
Expand Down
2 changes: 1 addition & 1 deletion tests/serializers/test_any.py
Expand Up @@ -371,7 +371,7 @@ def fallback_func(obj):
f = FoobarCount(0)
v = 0
# when recursion is detected and we're in mode python, we just return the value
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
expected_visits = pydantic_core._pydantic_core._recursion_limit
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')

with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
Expand Down