Add SchemaSerializer.__reduce__ method to enable pickle serialization #1006

Merged
merged 19 commits on Oct 9, 2023
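For orientation: the change relies on Python's standard reduce protocol. `__reduce__` returns a callable plus an argument tuple, and unpickling calls that callable with those arguments to rebuild the object, which is what the new methods below do with the saved schema and config. A minimal pure-Python sketch of the idea (illustrative only; `Reconstructible` is a made-up class, not part of pydantic-core):

import pickle

class Reconstructible:
    def __init__(self, schema, config=None):
        self.schema = schema
        self.config = config

    def __reduce__(self):
        # pickle stores (callable, args) and calls Reconstructible(schema, config) on load
        return (type(self), (self.schema, self.config))

obj = Reconstructible({'type': 'dict'}, {'ser_json_timedelta': 'float'})
restored = pickle.loads(pickle.dumps(obj))
assert restored.schema == obj.schema and restored.config == obj.config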
26 changes: 23 additions & 3 deletions src/serializers/mod.rs
@@ -26,13 +26,17 @@ mod ob_type;
mod shared;
mod type_serializers;

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for pickle support (see `__reduce__`).
py_schema: Py<PyDict>,
py_config: Option<Py<PyDict>>,
}

impl SchemaSerializer {
@@ -71,15 +75,19 @@ impl SchemaSerializer {
#[pymethods]
impl SchemaSerializer {
#[new]
pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();

let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
py_schema: schema.into_py(py),
py_config: match config {
Some(c) if !c.is_empty() => Some(c.into_py(py)),
_ => None,
},
})
}

@@ -174,6 +182,14 @@ impl SchemaSerializer {
Ok(py_bytes.into())
}

pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
// Enables support for `pickle` serialization.
let py = slf.py();
let cls = slf.get_type().into();
let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py));
Ok((cls, init_args))
}

pub fn __repr__(&self) -> String {
format!(
"SchemaSerializer(serializer={:#?}, definitions={:#?})",
@@ -182,6 +198,10 @@ impl SchemaSerializer {
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.py_schema)?;
if let Some(ref py_config) = self.py_config {
visit.call(py_config)?;
}
self.serializer.py_gc_traverse(&visit)?;
self.definitions.py_gc_traverse(&visit)?;
Ok(())
44 changes: 29 additions & 15 deletions src/validators/mod.rs
@@ -97,14 +97,17 @@ impl PySome {
}
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaValidator {
validator: CombinedValidator,
definitions: Definitions<CombinedValidator>,
schema: PyObject,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for cloudpickle support (see `__reduce__`).
py_schema: Py<PyAny>,
py_config: Option<Py<PyDict>>,
#[pyo3(get)]
title: PyObject,
py_title: Py<PyAny>,
hide_input_in_errors: bool,
validation_error_cause: bool,
}
@@ -121,11 +124,16 @@ impl SchemaValidator {
for val in definitions.values() {
val.get().unwrap().complete()?;
}
let py_schema = schema.into_py(py);
let py_config = match config {
Some(c) if !c.is_empty() => Some(c.into_py(py)),
_ => None,
};
let config_title = match config {
Some(c) => c.get_item("title"),
None => None,
};
let title = match config_title {
let py_title = match config_title {
Some(t) => t.into_py(py),
None => validator.get_name().into_py(py),
};
@@ -134,18 +142,20 @@ impl SchemaValidator {
Ok(Self {
validator,
definitions,
schema: schema.into_py(py),
title,
py_schema,
py_config,
py_title,
hide_input_in_errors,
validation_error_cause,
})
}

pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<PyObject> {
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
// Enables support for `pickle` serialization.
let py = slf.py();
let args = (slf.try_borrow()?.schema.to_object(py),);
let cls = slf.getattr("__class__")?;
Ok((cls, args).into_py(py))
let cls = slf.get_type().into();
let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py));
Ok((cls, init_args))
}

#[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))]
@@ -299,15 +309,18 @@ impl SchemaValidator {
pub fn __repr__(&self, py: Python) -> String {
format!(
"SchemaValidator(title={:?}, validator={:#?}, definitions={:#?})",
self.title.extract::<&str>(py).unwrap(),
self.py_title.extract::<&str>(py).unwrap(),
self.validator,
self.definitions,
)
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.validator.py_gc_traverse(&visit)?;
visit.call(&self.schema)?;
visit.call(&self.py_schema)?;
if let Some(ref py_config) = self.py_config {
visit.call(py_config)?;
}
Ok(())
}
}
@@ -338,7 +351,7 @@ impl SchemaValidator {
fn prepare_validation_err(&self, py: Python, error: ValError, input_type: InputType) -> PyErr {
ValidationError::from_val_error(
py,
self.title.clone_ref(py),
self.py_title.clone_ref(py),
input_type,
error,
None,
@@ -396,8 +409,9 @@ impl<'py> SelfValidator<'py> {
Ok(SchemaValidator {
validator,
definitions,
schema: py.None(),
title: "Self Schema".into_py(py),
py_schema: py.None(),
py_config: None,
py_title: "Self Schema".into_py(py),
hide_input_in_errors: false,
validation_error_cause: false,
})
2 changes: 1 addition & 1 deletion src/validators/url.rs
@@ -498,7 +498,7 @@ fn check_sub_defaults(
if let Some(default_port) = default_port {
lib_url
.set_port(Some(default_port))
.map_err(|_| map_parse_err(ParseError::EmptyHost))?;
.map_err(|()| map_parse_err(ParseError::EmptyHost))?;
}
}
if let Some(ref default_path) = default_path {
50 changes: 50 additions & 0 deletions tests/serializers/test_pickling.py
@@ -0,0 +1,50 @@
import json
import pickle
from datetime import timedelta

import pytest

from pydantic_core import core_schema
from pydantic_core._pydantic_core import SchemaSerializer


def repr_function(value, _info):
return repr(value)


def test_basic_schema_serializer():
s = SchemaSerializer(core_schema.dict_schema())
s = pickle.loads(pickle.dumps(s))
assert s.to_python({'a': 1, b'b': 2, 33: 3}) == {'a': 1, b'b': 2, 33: 3}
assert s.to_python({'a': 1, b'b': 2, 33: 3, True: 4}, mode='json') == {'a': 1, 'b': 2, '33': 3, 'true': 4}
assert s.to_json({'a': 1, b'b': 2, 33: 3, True: 4}) == b'{"a":1,"b":2,"33":3,"true":4}'

assert s.to_python({(1, 2): 3}) == {(1, 2): 3}
assert s.to_python({(1, 2): 3}, mode='json') == {'1,2': 3}
assert s.to_json({(1, 2): 3}) == b'{"1,2":3}'


@pytest.mark.parametrize(
'value,expected_python,expected_json',
[(None, 'None', b'"None"'), (1, '1', b'"1"'), ([1, 2, 3], '[1, 2, 3]', b'"[1, 2, 3]"')],
)
def test_schema_serializer_capturing_function(value, expected_python, expected_json):
# Test a SchemaSerializer that captures a function.
s = SchemaSerializer(
core_schema.any_schema(
serialization=core_schema.plain_serializer_function_ser_schema(repr_function, info_arg=True)
)
)
s = pickle.loads(pickle.dumps(s))
assert s.to_python(value) == expected_python
assert s.to_json(value) == expected_json
assert s.to_python(value, mode='json') == json.loads(expected_json)


def test_schema_serializer_containing_config():
s = SchemaSerializer(core_schema.timedelta_schema(), config={'ser_json_timedelta': 'float'})
s = pickle.loads(pickle.dumps(s))

assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000)
assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5
assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5'
4 changes: 2 additions & 2 deletions tests/test.rs
@@ -46,7 +46,7 @@ mod tests {
]
}"#;
let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap();
SchemaSerializer::py_new(schema, None).unwrap();
SchemaSerializer::py_new(py, schema, None).unwrap();
});
}

@@ -77,7 +77,7 @@ a = A()
py.run(code, None, Some(locals)).unwrap();
let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap();
let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap();
let serialized: Vec<u8> = SchemaSerializer::py_new(schema, None)
let serialized: Vec<u8> = SchemaSerializer::py_new(py, schema, None)
.unwrap()
.to_json(py, a, None, None, None, true, false, false, false, false, true, None)
.unwrap()
4 changes: 3 additions & 1 deletion tests/test_garbage_collection.py
@@ -27,7 +27,9 @@ class BaseModel:
__schema__: SchemaSerializer

def __init_subclass__(cls) -> None:
cls.__schema__ = SchemaSerializer(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER))
cls.__schema__ = SchemaSerializer(
core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), config={'ser_json_timedelta': 'float'}
)
Comment on lines +30 to +32

Member: Curious why this was necessary?

Contributor Author: Wasn't necessary; I just wanted to ensure that there was coverage of the garbage collection path when a config was passed, given that I'm holding a reference to it now (the existing tests didn't take a config).
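For context, a rough sketch of the reference cycle this garbage-collection test exercises (illustrative only: the empty model_fields_schema stands in for GC_TEST_SCHEMA_INNER, and the config value is arbitrary, it just needs to be non-empty so the serializer keeps a reference to it):

import gc

from pydantic_core import core_schema
from pydantic_core._pydantic_core import SchemaSerializer

class Model:
    pass

# Model holds the serializer, the serializer holds the schema dict (py_schema)
# and the config dict (py_config), and the schema dict references Model again,
# so the cycle is only collectable if __traverse__ visits those references.
Model.__schema__ = SchemaSerializer(
    core_schema.model_schema(Model, core_schema.model_fields_schema({})),
    config={'ser_json_timedelta': 'float'},
)

del Model
gc.collect()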


cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()

12 changes: 0 additions & 12 deletions tests/validators/test_datetime.py
@@ -1,6 +1,5 @@
import copy
import json
import pickle
import platform
import re
from datetime import date, datetime, time, timedelta, timezone, tzinfo
@@ -480,17 +479,6 @@ def test_tz_constraint_wrong():
validate_core_schema(core_schema.datetime_schema(tz_constraint='wrong'))


def test_tz_pickle() -> None:
Contributor Author: This is moved to test_pickling.py.

"""
https://github.com/pydantic/pydantic-core/issues/589
"""
v = SchemaValidator(core_schema.datetime_schema())
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
validated = v.validate_python('2022-06-08T12:13:14-12:15')
assert validated == original
assert pickle.loads(pickle.dumps(validated)) == validated == original


def test_tz_hash() -> None:
v = SchemaValidator(core_schema.datetime_schema())
lookup: Dict[datetime, str] = {}
55 changes: 55 additions & 0 deletions tests/validators/test_pickling.py
@@ -0,0 +1,55 @@
import pickle
import re
from datetime import datetime, timedelta, timezone

import pytest

from pydantic_core import core_schema
from pydantic_core._pydantic_core import SchemaValidator, ValidationError

from ..conftest import PyAndJson


def test_basic_schema_validator(py_and_json: PyAndJson):
v = py_and_json({'type': 'dict', 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}})
v = pickle.loads(pickle.dumps(v))
assert v.validate_test({'1': 2, '3': 4}) == {1: 2, 3: 4}

v = py_and_json({'type': 'dict', 'strict': True, 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}})
v = pickle.loads(pickle.dumps(v))
assert v.validate_test({'1': 2, '3': 4}) == {1: 2, 3: 4}
assert v.validate_test({}) == {}
with pytest.raises(ValidationError, match=re.escape('[type=dict_type, input_value=[], input_type=list]')):
v.validate_test([])


def test_schema_validator_containing_config():
Contributor Author: I verified that this newly-added test case fails without my changes to SchemaValidator.__reduce__ (config object is lost).

"""
Verify that the config object is not lost during (de)serialization.
"""
v = SchemaValidator(
core_schema.model_fields_schema({'f': core_schema.model_field(core_schema.str_schema())}),
config=core_schema.CoreConfig(extra_fields_behavior='allow'),
)
v = pickle.loads(pickle.dumps(v))

m, model_extra, fields_set = v.validate_python({'f': 'x', 'extra_field': '123'})
assert m == {'f': 'x'}
# If the config was lost during (de)serialization, the below checks would fail as
# the default behavior is to ignore extra fields.
assert model_extra == {'extra_field': '123'}
assert fields_set == {'f', 'extra_field'}

v.validate_assignment(m, 'f', 'y')
assert m == {'f': 'y'}
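
To make the verification note above concrete, the same round trip can be done by hand via the reduce protocol; a sketch reusing `v` from this test (illustrative only, not part of the test suite):

cls, args = v.__reduce__()  # with this change, args carries both the schema and the config
v2 = cls(*args)
m2, extra2, set2 = v2.validate_python({'f': 'x', 'extra_field': '123'})
assert extra2 == {'extra_field': '123'}  # the config survived, so extra fields are still allowed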


def test_schema_validator_tz_pickle() -> None:
"""
https://github.com/pydantic/pydantic-core/issues/589
"""
v = SchemaValidator(core_schema.datetime_schema())
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
validated = v.validate_python('2022-06-08T12:13:14-12:15')
assert validated == original
assert pickle.loads(pickle.dumps(validated)) == validated == original