Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SchemaSerializer.__reduce__ method to enable pickle serialization #1006

Merged
merged 19 commits into from Oct 9, 2023
Merged
25 changes: 23 additions & 2 deletions src/serializers/mod.rs
Expand Up @@ -3,6 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};

use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use pyo3::PyTypeInfo;
use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::{Definitions, DefinitionsBuilder};
Expand Down Expand Up @@ -33,6 +34,10 @@ pub struct SchemaSerializer {
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for cloudpickle support (see `__reduce__`).
py_schema: Py<PyDict>,
py_config: Option<Py<PyDict>>,
}

impl SchemaSerializer {
Expand Down Expand Up @@ -71,15 +76,19 @@ impl SchemaSerializer {
#[pymethods]
impl SchemaSerializer {
#[new]
pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();

let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
py_schema: schema.into_py(py),
py_config: match config {
Some(d) if !d.is_empty() => Some(d.into_py(py)),
_ => None,
},
})
}

Expand Down Expand Up @@ -174,6 +183,14 @@ impl SchemaSerializer {
Ok(py_bytes.into())
}

pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (PyObject, PyObject))> {
// Enables support for `cloudpickle` serialization.
Ok((
SchemaSerializer::type_object(py).to_object(py),
(self.py_schema.to_object(py), self.py_config.to_object(py)),
))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's my take:

Suggested change
pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (PyObject, PyObject))> {
// Enables support for `cloudpickle` serialization.
Ok((
SchemaSerializer::type_object(py).to_object(py),
(self.py_schema.to_object(py), self.py_config.to_object(py)),
))
}
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(&PyType, (PyObject, PyObject))> {
let py = slf.py();
let args = (slf.get().schema.to_object(py), slf.get().py_config.to_object(py));
let cls = slf.get_type();
Ok((cls, args))
}

You'll need to update the #[pyclass] definition above for SchemaSerializer to include #[pyclass(frozen)].

Justifications:

  1. slf: &PyCell<Self> is equivalent to self Python object
  2. slf.get() gives read access to the Rust data as long as it's frozen
  3. I think we want to look up type(self) rather than hard-code SchemaSerializer type in case we allow subclassing later. slf.get_type() is a cleaner way to do it than trying to use getattr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the pointers, I'll update!


pub fn __repr__(&self) -> String {
format!(
"SchemaSerializer(serializer={:#?}, definitions={:#?})",
Expand All @@ -182,6 +199,10 @@ impl SchemaSerializer {
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.py_schema)?;
if let Some(ref py_config) = self.py_config {
visit.call(py_config)?;
}
self.serializer.py_gc_traverse(&visit)?;
self.definitions.py_gc_traverse(&visit)?;
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/validators/url.rs
Expand Up @@ -498,7 +498,7 @@ fn check_sub_defaults(
if let Some(default_port) = default_port {
lib_url
.set_port(Some(default_port))
.map_err(|_| map_parse_err(ParseError::EmptyHost))?;
.map_err(|()| map_parse_err(ParseError::EmptyHost))?;
edoakes marked this conversation as resolved.
Show resolved Hide resolved
}
}
if let Some(ref default_path) = default_path {
Expand Down
4 changes: 2 additions & 2 deletions tests/test.rs
Expand Up @@ -46,7 +46,7 @@ mod tests {
]
}"#;
let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap();
SchemaSerializer::py_new(schema, None).unwrap();
SchemaSerializer::py_new(py, schema, None).unwrap();
edoakes marked this conversation as resolved.
Show resolved Hide resolved
});
}

Expand Down Expand Up @@ -77,7 +77,7 @@ a = A()
py.run(code, None, Some(locals)).unwrap();
let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap();
let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap();
let serialized: Vec<u8> = SchemaSerializer::py_new(schema, None)
let serialized: Vec<u8> = SchemaSerializer::py_new(py, schema, None)
.unwrap()
.to_json(py, a, None, None, None, true, false, false, false, false, true, None)
.unwrap()
Expand Down