Skip to content

Commit

Permalink
caching strings from JSON (#1240)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
samuelcolvin and davidhewitt committed Mar 27, 2024
1 parent c607fd8 commit 2e2c139
Show file tree
Hide file tree
Showing 27 changed files with 195 additions and 59 deletions.
2 changes: 2 additions & 0 deletions .mypy-stubtest-allowlist
@@ -1,2 +1,4 @@
# TODO: don't want to expose this staticmethod, requires https://github.com/PyO3/pyo3/issues/2384
pydantic_core._pydantic_core.PydanticUndefinedType.new
# As per #1240, from_json has custom logic to coverage the `cache_strings` kwarg
pydantic_core._pydantic_core.from_json
15 changes: 2 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -44,7 +44,7 @@ base64 = "0.21.7"
num-bigint = "0.4.4"
python3-dll-a = "0.2.7"
uuid = "1.7.0"
jiter = { version = "0.1.0", features = ["python"] }
jiter = { version = "0.1.1", features = ["python"] }

[lib]
name = "_pydantic_core"
Expand Down
15 changes: 12 additions & 3 deletions python/pydantic_core/_pydantic_core.pyi
Expand Up @@ -390,17 +390,26 @@ def to_json(
JSON bytes.
"""

def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any:
def from_json(
data: str | bytes | bytearray,
*,
allow_inf_nan: bool = True,
cache_strings: bool | Literal['all', 'keys', 'none'] = True,
allow_partial: bool = False,
) -> Any:
"""
Deserialize JSON data to a Python object.
This is effectively a faster version of `json.loads()`.
This is effectively a faster version of `json.loads()`, with some extra functionality.
Arguments:
data: The JSON data to deserialize.
allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default.
cache_strings: Whether to cache strings to avoid constructing new Python objects,
this should have a significant impact on performance while increasing memory usage slightly.
this should have a significant impact on performance while increasing memory usage slightly,
`all/True` means cache all strings, `keys` means cache only dict keys, `none/False` means no caching.
allow_partial: Whether to allow partial deserialization, if `True` JSON data is returned if the end of the
input is reached before the full object is deserialized, e.g. `["aa", "bb", "c` would return `['aa', 'bb']`.
Raises:
ValueError: If deserialization fails.
Expand Down
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Expand Up @@ -75,6 +75,8 @@ class CoreConfig(TypedDict, total=False):
Requires exceptiongroup backport pre Python 3.11.
coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode).
regex_engine: The regex engine to use for regex pattern validation. Default is 'rust-regex'. See `StringSchema`.
cache_strings: Whether to cache strings. Default is `True`, `True` or `'all'` is required to cache strings
during general validation since validators don't know if they're in a key or a value.
"""

title: str
Expand Down Expand Up @@ -110,6 +112,7 @@ class CoreConfig(TypedDict, total=False):
validation_error_cause: bool # default: False
coerce_numbers_to_str: bool # default: False
regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex'
cache_strings: Union[bool, Literal['all', 'keys', 'none']] # default: 'True'


IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None'
Expand Down
18 changes: 9 additions & 9 deletions src/input/return_enums.rs
Expand Up @@ -3,7 +3,7 @@ use std::cmp::Ordering;
use std::ops::Rem;
use std::str::FromStr;

use jiter::{JsonArray, JsonValue};
use jiter::{JsonArray, JsonValue, StringCacheMode};
use num_bigint::BigInt;

use pyo3::exceptions::PyTypeError;
Expand Down Expand Up @@ -435,9 +435,15 @@ impl<'a> EitherString<'a> {
}
}

pub fn as_py_string(&'a self, py: Python<'a>) -> Bound<'a, PyString> {
pub fn as_py_string(&'a self, py: Python<'a>, cache_str: StringCacheMode) -> Bound<'a, PyString> {
match self {
Self::Cow(cow) => PyString::new_bound(py, cow),
Self::Cow(cow) => {
if matches!(cache_str, StringCacheMode::All) {
jiter::cached_py_string(py, cow.as_ref())
} else {
PyString::new_bound(py, cow.as_ref())
}
}
Self::Py(py_string) => py_string.clone(),
}
}
Expand All @@ -461,12 +467,6 @@ impl<'a> From<Bound<'a, PyString>> for EitherString<'a> {
}
}

impl<'a> IntoPy<PyObject> for EitherString<'a> {
fn into_py(self, py: Python<'_>) -> PyObject {
self.as_py_string(py).into_py(py)
}
}

pub fn py_string_str<'a>(py_str: &'a Bound<'_, PyString>) -> ValResult<&'a str> {
py_str.to_str().map_err(|_| {
ValError::new_custom_input(
Expand Down
19 changes: 16 additions & 3 deletions src/lib.rs
Expand Up @@ -4,6 +4,7 @@ extern crate core;

use std::sync::OnceLock;

use jiter::StringCacheMode;
use pyo3::exceptions::PyTypeError;
use pyo3::{prelude::*, sync::GILOnceCell};

Expand Down Expand Up @@ -38,19 +39,31 @@ pub use validators::{validate_core_schema, PySome, SchemaValidator};

use crate::input::Input;

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))]
#[derive(FromPyObject)]
pub enum CacheStringsArg {
Bool(bool),
Literal(StringCacheMode),
}

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=CacheStringsArg::Bool(true), allow_partial=false))]
pub fn from_json<'py>(
py: Python<'py>,
data: &Bound<'_, PyAny>,
allow_inf_nan: bool,
cache_strings: bool,
cache_strings: CacheStringsArg,
allow_partial: bool,
) -> PyResult<Bound<'py, PyAny>> {
let v_match = data
.validate_bytes(false)
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_strings).map_err(|e| jiter::map_json_error(json_bytes, &e))
let cache_mode = match cache_strings {
CacheStringsArg::Bool(b) => b.into(),
CacheStringsArg::Literal(mode) => mode,
};
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_mode, allow_partial)
.map_err(|e| jiter::map_json_error(json_bytes, &e))
}

pub fn get_pydantic_core_version() -> &'static str {
Expand Down
9 changes: 6 additions & 3 deletions src/validators/arguments.rs
Expand Up @@ -55,7 +55,8 @@ impl BuildValidator for ArgumentsValidator {
for (arg_index, arg) in arguments_schema.iter().enumerate() {
let arg = arg.downcast::<PyDict>()?;

let name: String = arg.get_as_req(intern!(py, "name"))?;
let py_name: Bound<PyString> = arg.get_as_req(intern!(py, "name"))?;
let name = py_name.to_string();
let mode = arg.get_as::<Bound<'_, PyString>>(intern!(py, "mode"))?;
let mode = mode
.as_ref()
Expand All @@ -77,7 +78,7 @@ impl BuildValidator for ArgumentsValidator {
}
None => Some(LookupKey::from_string(py, &name)),
};
kwarg_key = Some(PyString::new_bound(py, &name).into());
kwarg_key = Some(py_name.into_py(py));
}

let schema = arg.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -274,7 +275,9 @@ impl Validator for ArgumentsValidator {
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?,
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
Expand Down
10 changes: 7 additions & 3 deletions src/validators/dataclass.rs
Expand Up @@ -302,7 +302,10 @@ impl Validator for DataclassArgsValidator {
if let Some(ref validator) = self.extras_validator {
match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_dict.set_item(either_str.as_py_string(py), value)?;
output_dict.set_item(
either_str.as_py_string(py, state.cache_str()),
value,
)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
Expand All @@ -312,7 +315,8 @@ impl Validator for DataclassArgsValidator {
Err(err) => return Err(err),
}
} else {
output_dict.set_item(either_str.as_py_string(py), value)?;
output_dict
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
Expand Down Expand Up @@ -455,7 +459,7 @@ impl BuildValidator for DataclassValidator {
let validator = build_validator(&sub_schema, config, definitions)?;

let post_init = if schema.get_as::<bool>(intern!(py, "post_init"))?.unwrap_or(false) {
Some(PyString::new_bound(py, "__post_init__").into())
Some(intern!(py, "__post_init__").into_py(py))
} else {
None
};
Expand Down
4 changes: 4 additions & 0 deletions src/validators/generator.rs
Expand Up @@ -219,6 +219,7 @@ pub struct InternalValidator {
validation_mode: InputType,
hide_input_in_errors: bool,
validation_error_cause: bool,
cache_str: jiter::StringCacheMode,
}

impl fmt::Debug for InternalValidator {
Expand Down Expand Up @@ -250,6 +251,7 @@ impl InternalValidator {
validation_mode: extra.input_type,
hide_input_in_errors,
validation_error_cause,
cache_str: extra.cache_str,
}
}

Expand All @@ -268,6 +270,7 @@ impl InternalValidator {
from_attributes: self.from_attributes,
context: self.context.as_ref().map(|data| data.bind(py)),
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
state.exactness = self.exactness;
Expand Down Expand Up @@ -302,6 +305,7 @@ impl InternalValidator {
from_attributes: self.from_attributes,
context: self.context.as_ref().map(|data| data.bind(py)),
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
state.exactness = self.exactness;
Expand Down
4 changes: 2 additions & 2 deletions src/validators/json.rs
Expand Up @@ -64,8 +64,8 @@ impl Validator for JsonValidator {
validator.validate(py, &json_value, &mut json_state)
}
None => {
let obj =
jiter::python_parse(py, json_bytes, true, true).map_err(|e| map_json_err(input, e, json_bytes))?;
let obj = jiter::python_parse(py, json_bytes, true, state.cache_str(), false)
.map_err(|e| map_json_err(input, e, json_bytes))?;
Ok(obj.unbind())
}
}
Expand Down

0 comments on commit 2e2c139

Please sign in to comment.