Skip to content

Commit

Permalink
dataclass serialization speedups (#1162)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Jan 16, 2024
1 parent e1cb0eb commit 5a1385b
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 187 deletions.
257 changes: 162 additions & 95 deletions src/serializers/fields.rs
Expand Up @@ -100,6 +100,15 @@ pub struct GeneralFieldsSerializer {
required_fields: usize,
}

macro_rules! option_length {
($op_has_len:expr) => {
match $op_has_len {
Some(ref has_len) => has_len.len(),
None => 0,
}
};
}

impl GeneralFieldsSerializer {
pub(super) fn new(
fields: AHashMap<String, SerField>,
Expand Down Expand Up @@ -136,50 +145,21 @@ impl GeneralFieldsSerializer {
}
}
}
}

macro_rules! option_length {
($op_has_len:expr) => {
match $op_has_len {
Some(ref has_len) => has_len.len(),
None => 0,
}
};
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
pub fn main_to_python<'py>(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let td_extra = Extra {
model: extra.model.map_or_else(|| Some(value), Some),
..*extra
};
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
return infer_to_python(value, include, exclude, &td_extra);
};

py: Python<'py>,
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
extra: Extra,
) -> PyResult<&'py PyDict> {
let output_dict = PyDict::new(py);
let mut used_req_fields: usize = 0;

// NOTE! we maintain the order of the input dict assuming that's right
for (key, value) in main_dict {
for result in main_iter {
let (key, value) = result?;
let key_str = key_str(key)?;
let op_field = self.fields.get(key_str);
if extra.exclude_none && value.is_none() {
Expand All @@ -190,16 +170,16 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
continue;
}
let extra = Extra {
let field_extra = Extra {
field_name: Some(key_str),
..td_extra
..extra
};
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
if let Some(field) = op_field {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &extra, serializer)? {
let value = serializer.to_python(value, next_include, next_exclude, &extra)?;
let output_key = field.get_key_py(output_dict.py(), &extra);
if !exclude_default(value, &field_extra, serializer)? {
let value = serializer.to_python(value, next_include, next_exclude, &field_extra)?;
let output_key = field.get_key_py(output_dict.py(), &field_extra);
output_dict.set_item(output_key, value)?;
}
}
Expand All @@ -209,23 +189,140 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
} else if self.mode == FieldsMode::TypedDictAllow {
let value = match &self.extra_serializer {
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?,
None => infer_to_python(value, next_include, next_exclude, &extra)?,
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &field_extra)?,
None => infer_to_python(value, next_include, next_exclude, &field_extra)?,
};
output_dict.set_item(key, value)?;
} else if extra.check == SerCheck::Strict {
} else if field_extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
}
}
}
if td_extra.check.enabled()

if extra.check.enabled()
// If any of these are true we can't count fields
&& !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none)
// Check for missing fields, we can't have extra fields here
&& self.required_fields > used_req_fields
{
return Err(PydanticSerializationUnexpectedValue::new_err(None));
Err(PydanticSerializationUnexpectedValue::new_err(None))
} else {
Ok(output_dict)
}
}

pub fn main_serde_serialize<'py, S: serde::ser::Serializer>(
&self,
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
expected_len: usize,
serializer: S,
include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
extra: Extra,
) -> Result<S::SerializeMap, S::Error> {
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = serializer.serialize_map(Some(expected_len))?;

for result in main_iter {
let (key, value) = result.map_err(py_err_se_err)?;
if extra.exclude_none && value.is_none() {
continue;
}
let key_str = key_str(key).map_err(py_err_se_err)?;
let field_extra = Extra {
field_name: Some(key_str),
..extra
};

let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? {
let s =
PydanticSerializer::new(value, serializer, next_include, next_exclude, &field_extra);
let output_key = field.get_key_json(key_str, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
Ok(map)
}

pub fn add_computed_fields_python(
&self,
model: Option<&PyAny>,
output_dict: &PyDict,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<()> {
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model_value) = model {
let cf_extra = Extra { model, ..*extra };
computed_fields.to_python(model_value, output_dict, &self.filter, include, exclude, &cf_extra)?;
}
}
Ok(())
}

pub fn add_computed_fields_json<S: serde::ser::Serializer>(
&self,
model: Option<&PyAny>,
map: &mut S::SerializeMap,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<(), S::Error> {
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = model {
computed_fields.serde_serialize::<S>(model, map, &self.filter, include, exclude, extra)?;
}
}
Ok(())
}

pub fn computed_field_count(&self) -> usize {
option_length!(self.computed_fields)
}
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = extra.model.map_or_else(|| Some(value), Some);
let td_extra = Extra { model, ..*extra };
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
return infer_to_python(value, include, exclude, &td_extra);
};

let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
for (key, value) in extra_dict {
Expand All @@ -241,11 +338,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
}
}
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = td_extra.model {
computed_fields.to_python(model, output_dict, &self.filter, include, exclude, &td_extra)?;
}
}
self.add_computed_fields_python(model, output_dict, include, exclude, extra)?;
Ok(output_dict.into_py(py))
}

Expand All @@ -271,46 +364,23 @@ impl TypeSerializer for GeneralFieldsSerializer {
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let td_extra = Extra {
model: extra.model.map_or_else(|| Some(value), Some),
..*extra
};
let model = extra.model.map_or_else(|| Some(value), Some);
let td_extra = Extra { model, ..*extra };
let expected_len = match self.mode {
FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields),
_ => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields),
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
};
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = serializer.serialize_map(Some(expected_len))?;

for (key, value) in main_dict {
if extra.exclude_none && value.is_none() {
continue;
}
let key_str = key_str(key).map_err(py_err_se_err)?;
let extra = Extra {
field_name: Some(key_str),
..td_extra
};
let mut map = self.main_serde_serialize(
main_dict.iter().map(Ok),
expected_len,
serializer,
include,
exclude,
td_extra,
)?;

let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
let output_key = field.get_key_json(key_str, &extra);
map.serialize_entry(&output_key, &s)?;
}
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
map.serialize_entry(&output_key, &s)?;
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
for (key, value) in extra_dict {
Expand All @@ -319,17 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
let output_key = infer_json_key(key, &td_extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &td_extra);
let output_key = infer_json_key(key, extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, extra);
map.serialize_entry(&output_key, &s)?;
}
}
}
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = td_extra.model {
computed_fields.serde_serialize::<S>(model, &mut map, &self.filter, include, exclude, &td_extra)?;
}
}

self.add_computed_fields_json::<S>(model, &mut map, include, exclude, extra)?;
map.end()
}

Expand Down

0 comments on commit 5a1385b

Please sign in to comment.