Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dataclass serialization speedups #1162

Merged
merged 5 commits into from Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
257 changes: 162 additions & 95 deletions src/serializers/fields.rs
Expand Up @@ -100,6 +100,15 @@ pub struct GeneralFieldsSerializer {
required_fields: usize,
}

macro_rules! option_length {
($op_has_len:expr) => {
match $op_has_len {
Some(ref has_len) => has_len.len(),
None => 0,
}
};
}

impl GeneralFieldsSerializer {
pub(super) fn new(
fields: AHashMap<String, SerField>,
Expand Down Expand Up @@ -136,50 +145,21 @@ impl GeneralFieldsSerializer {
}
}
}
}

macro_rules! option_length {
($op_has_len:expr) => {
match $op_has_len {
Some(ref has_len) => has_len.len(),
None => 0,
}
};
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
pub fn main_to_python<'py>(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let td_extra = Extra {
model: extra.model.map_or_else(|| Some(value), Some),
..*extra
};
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
return infer_to_python(value, include, exclude, &td_extra);
};

py: Python<'py>,
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generic function for building a dict from an iterator, so we can use both in the GeneralFieldsSerializer serializer and directly from the dataclass serializer.

include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
extra: Extra,
) -> PyResult<&'py PyDict> {
let output_dict = PyDict::new(py);
let mut used_req_fields: usize = 0;

// NOTE! we maintain the order of the input dict assuming that's right
for (key, value) in main_dict {
for result in main_iter {
let (key, value) = result?;
let key_str = key_str(key)?;
let op_field = self.fields.get(key_str);
if extra.exclude_none && value.is_none() {
Expand All @@ -190,16 +170,16 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
continue;
}
let extra = Extra {
let field_extra = Extra {
field_name: Some(key_str),
..td_extra
..extra
};
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
if let Some(field) = op_field {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &extra, serializer)? {
let value = serializer.to_python(value, next_include, next_exclude, &extra)?;
let output_key = field.get_key_py(output_dict.py(), &extra);
if !exclude_default(value, &field_extra, serializer)? {
let value = serializer.to_python(value, next_include, next_exclude, &field_extra)?;
let output_key = field.get_key_py(output_dict.py(), &field_extra);
output_dict.set_item(output_key, value)?;
}
}
Expand All @@ -209,23 +189,140 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
} else if self.mode == FieldsMode::TypedDictAllow {
let value = match &self.extra_serializer {
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?,
None => infer_to_python(value, next_include, next_exclude, &extra)?,
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &field_extra)?,
None => infer_to_python(value, next_include, next_exclude, &field_extra)?,
};
output_dict.set_item(key, value)?;
} else if extra.check == SerCheck::Strict {
} else if field_extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
}
}
}
if td_extra.check.enabled()

if extra.check.enabled()
// If any of these are true we can't count fields
&& !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none)
// Check for missing fields, we can't have extra fields here
&& self.required_fields > used_req_fields
{
return Err(PydanticSerializationUnexpectedValue::new_err(None));
Err(PydanticSerializationUnexpectedValue::new_err(None))
} else {
Ok(output_dict)
}
}

pub fn main_serde_serialize<'py, S: serde::ser::Serializer>(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for serde/json.

&self,
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
expected_len: usize,
serializer: S,
include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
extra: Extra,
) -> Result<S::SerializeMap, S::Error> {
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = serializer.serialize_map(Some(expected_len))?;

for result in main_iter {
let (key, value) = result.map_err(py_err_se_err)?;
if extra.exclude_none && value.is_none() {
continue;
}
let key_str = key_str(key).map_err(py_err_se_err)?;
let field_extra = Extra {
field_name: Some(key_str),
..extra
};

let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? {
let s =
PydanticSerializer::new(value, serializer, next_include, next_exclude, &field_extra);
let output_key = field.get_key_json(key_str, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
Ok(map)
}

pub fn add_computed_fields_python(
&self,
model: Option<&PyAny>,
output_dict: &PyDict,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<()> {
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model_value) = model {
let cf_extra = Extra { model, ..*extra };
computed_fields.to_python(model_value, output_dict, &self.filter, include, exclude, &cf_extra)?;
}
}
Ok(())
}

pub fn add_computed_fields_json<S: serde::ser::Serializer>(
&self,
model: Option<&PyAny>,
map: &mut S::SerializeMap,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<(), S::Error> {
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = model {
computed_fields.serde_serialize::<S>(model, map, &self.filter, include, exclude, extra)?;
}
}
Ok(())
}

pub fn computed_field_count(&self) -> usize {
option_length!(self.computed_fields)
}
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = extra.model.map_or_else(|| Some(value), Some);
let td_extra = Extra { model, ..*extra };
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
return infer_to_python(value, include, exclude, &td_extra);
};

let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
for (key, value) in extra_dict {
Expand All @@ -241,11 +338,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
}
}
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = td_extra.model {
computed_fields.to_python(model, output_dict, &self.filter, include, exclude, &td_extra)?;
}
}
self.add_computed_fields_python(model, output_dict, include, exclude, extra)?;
Ok(output_dict.into_py(py))
}

Expand All @@ -271,46 +364,23 @@ impl TypeSerializer for GeneralFieldsSerializer {
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let td_extra = Extra {
model: extra.model.map_or_else(|| Some(value), Some),
..*extra
};
let model = extra.model.map_or_else(|| Some(value), Some);
let td_extra = Extra { model, ..*extra };
let expected_len = match self.mode {
FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields),
_ => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields),
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
};
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = serializer.serialize_map(Some(expected_len))?;

for (key, value) in main_dict {
if extra.exclude_none && value.is_none() {
continue;
}
let key_str = key_str(key).map_err(py_err_se_err)?;
let extra = Extra {
field_name: Some(key_str),
..td_extra
};
let mut map = self.main_serde_serialize(
main_dict.iter().map(Ok),
expected_len,
serializer,
include,
exclude,
td_extra,
)?;

let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
let output_key = field.get_key_json(key_str, &extra);
map.serialize_entry(&output_key, &s)?;
}
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
map.serialize_entry(&output_key, &s)?;
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
for (key, value) in extra_dict {
Expand All @@ -319,17 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
let output_key = infer_json_key(key, &td_extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &td_extra);
let output_key = infer_json_key(key, extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, extra);
map.serialize_entry(&output_key, &s)?;
}
}
}
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = td_extra.model {
computed_fields.serde_serialize::<S>(model, &mut map, &self.filter, include, exclude, &td_extra)?;
}
}

self.add_computed_fields_json::<S>(model, &mut map, include, exclude, extra)?;
map.end()
}

Expand Down