Skip to content

Commit

Permalink
Merge branch 'dh/input-assocs' into dh/json-cow
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Mar 18, 2024
2 parents 2a3f32c + 230309d commit 3a27298
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 73 deletions.
78 changes: 56 additions & 22 deletions src/input/input_abstract.rs
@@ -1,7 +1,7 @@
use std::fmt;

use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyType};
use pyo3::types::{PyDict, PyList, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{ErrorTypeDefaults, InputValue, ValError, ValResult};
Expand Down Expand Up @@ -42,6 +42,8 @@ impl TryFrom<&str> for InputType {
}
}

/// Shorthand for the return type of the `validate_*` methods: a `ValResult` whose success
/// value is a `ValidationMatch` wrapping `T`.
pub type ValMatch<T> = ValResult<ValidationMatch<T>>;

/// all types have three methods: `validate_*`, `strict_*`, `lax_*`
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
Expand Down Expand Up @@ -87,13 +89,13 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a, 'py>>;

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<ValidationMatch<EitherString<'_>>>;
fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;

fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>>;
fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch<EitherBytes<'a, 'py>>;

fn validate_bool(&self, strict: bool) -> ValResult<ValidationMatch<bool>>;
fn validate_bool(&self, strict: bool) -> ValMatch<bool>;

fn validate_int(&self, strict: bool) -> ValResult<ValidationMatch<EitherInt<'_>>>;
fn validate_int(&self, strict: bool) -> ValMatch<EitherInt<'_>>;

fn exact_int(&self) -> ValResult<EitherInt<'_>> {
self.validate_int(true).and_then(|val_match| {
Expand All @@ -113,7 +115,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
})
}

fn validate_float(&self, strict: bool) -> ValResult<ValidationMatch<EitherFloat<'_>>>;
fn validate_float(&self, strict: bool) -> ValMatch<EitherFloat<'_>>;

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
if strict {
Expand Down Expand Up @@ -145,18 +147,11 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
self.validate_dict(strict)
}

fn validate_list<'a>(&'a self, strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
if strict {
self.strict_list()
} else {
self.lax_list()
}
}
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
self.strict_list()
}
type List<'a>: Iterable<'py> + AsPyList<'py>
where
Self: 'a;

fn validate_list(&self, strict: bool) -> ValMatch<Self::List<'_>>;

fn validate_tuple<'a>(&'a self, strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
if strict {
Expand Down Expand Up @@ -201,25 +196,25 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_iter(&self) -> ValResult<GenericIterator>;

fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>>;
fn validate_date(&self, strict: bool) -> ValMatch<EitherDate<'py>>;

fn validate_time(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherTime<'py>>>;
) -> ValMatch<EitherTime<'py>>;

fn validate_datetime(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>>;
) -> ValMatch<EitherDateTime<'py>>;

fn validate_timedelta(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>>;
) -> ValMatch<EitherTimedelta<'py>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand All @@ -238,3 +233,42 @@ impl<'py, T: Input<'py> + ?Sized> BorrowInput<'py> for &'_ T {
self
}
}

/// An enum with no variants, so no value of this type can ever be constructed.
///
/// It is used as a placeholder type (for example as an input's `List` associated type) when
/// that input can never produce such a value.
pub enum Never {}

/// Pairs with `Iterable` below: an `Iterable` passes its concrete iterator to a
/// `ConsumeIterator`, which turns that iterator into a single `Output` value.
pub trait ConsumeIterator<T> {
/// The value produced from the iterator.
type Output;
/// Takes ownership of both `self` and `iterator` (any iterator yielding `T`) and returns
/// the resulting `Output`.
fn consume_iterator(self, iterator: impl Iterator<Item = T>) -> Self::Output;
}

/// This slightly awkward trait is used to define types which can be iterable. This formulation
/// arises because the Python enums have several different underlying iterator types, and we want
/// to be able to dispatch over each of them without overhead.
///
/// Rather than returning an iterator, `iterate` hands the iterator to a `ConsumeIterator`.
pub trait Iterable<'py> {
/// The item type handed to the consumer (inside a `PyResult`); it must implement `BorrowInput`.
type Input: BorrowInput<'py>;
/// The number of items, if available.
/// NOTE(review): presumably `None` means the length is not known up front - confirm against
/// the implementors.
fn len(&self) -> Option<usize>;
/// Consumes `self`, passing `consumer` an iterator of `PyResult<Self::Input>` items, and
/// returns the consumer's `Output` (`R`) inside a `ValResult`.
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Input>, Output = R>) -> ValResult<R>;
}

// Necessary for inputs which don't support certain types, e.g. String -> list.
//
// `Never` has no variants, so neither method can ever be called. Matching on the uninhabited
// receiver with zero arms lets the compiler prove that, instead of relying on a runtime
// `unreachable!()` panic.
impl<'py> Iterable<'py> for Never {
    type Input = Bound<'py, PyAny>; // Doesn't really matter what this is

    /// Cannot be called: no `Never` value exists to borrow.
    fn len(&self) -> Option<usize> {
        match *self {}
    }

    /// Cannot be called: no `Never` value exists to consume.
    fn iterate<R>(self, _consumer: impl ConsumeIterator<PyResult<Self::Input>, Output = R>) -> ValResult<R> {
        match self {}
    }
}

/// Optimization pathway for inputs which are already python lists.
pub trait AsPyList<'py>: Iterable<'py> {
/// Returns a borrowed Python list, or `None`.
/// NOTE(review): presumably `Some` only when the iterable is backed by an actual `PyList` -
/// confirm against the implementors.
fn as_py_list(&self) -> Option<&Bound<'py, PyList>>;
}

// `Never` has no variants, so this method can never be called. The empty `match` on the
// uninhabited receiver is checked by the compiler, so no runtime `unreachable!()` is needed.
impl<'py> AsPyList<'py> for Never {
    /// Cannot be called: no `Never` value exists to borrow.
    fn as_py_list(&self) -> Option<&Bound<'py, PyList>> {
        match *self {}
    }
}
38 changes: 28 additions & 10 deletions src/input/input_json.rs
Expand Up @@ -2,7 +2,8 @@ use std::borrow::Cow;

use jiter::{JsonArray, JsonValue};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};
use pyo3::types::{PyDict, PyList, PyString};
use smallvec::SmallVec;
use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

Expand All @@ -13,6 +14,7 @@ use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{AsPyList, ConsumeIterator, Iterable, Never, ValMatch};
use super::return_enums::ValidationMatch;
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int};
use super::{
Expand All @@ -37,7 +39,7 @@ impl From<JsonValue<'_>> for LocItem {
}
}

impl<'py> Input<'py> for JsonValue<'_> {
impl<'py, 'data> Input<'py> for JsonValue<'data> {
fn as_error_value(&self) -> InputValue {
// cloning JsonValue is cheap due to use of Arc
InputValue::Json(self.clone().into_static())
Expand Down Expand Up @@ -172,16 +174,14 @@ impl<'py> Input<'py> for JsonValue<'_> {
self.validate_dict(false)
}

fn validate_list<'a>(&'a self, _strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
type List<'a> = &'a JsonArray<'data> where Self: 'a;

fn validate_list(&self, _strict: bool) -> ValMatch<&JsonArray<'data>> {
match self {
JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)),
JsonValue::Array(a) => Ok(ValidationMatch::strict(a)),
_ => Err(ValError::new(ErrorTypeDefaults::ListType, self)),
}
}
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
self.validate_list(false)
}

fn validate_tuple<'a>(&'a self, _strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
// just as in set's case, List has to be allowed
Expand Down Expand Up @@ -375,8 +375,9 @@ impl<'py> Input<'py> for str {
Err(ValError::new(ErrorTypeDefaults::DictType, self))
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
type List<'a> = Never;

// List validation is not supported for this input type: the call always fails with a
// `ListType` error built from `self`, and `_strict` is ignored.
fn validate_list(&self, _strict: bool) -> ValMatch<Never> {
Err(ValError::new(ErrorTypeDefaults::ListType, self))
}

Expand Down Expand Up @@ -449,3 +450,20 @@ impl BorrowInput<'_> for String {
/// Builds a `JsonArray` holding one `JsonValue::Str` per `char` of `s`, in order.
fn string_to_vec(s: &str) -> JsonArray {
    let chars_as_json = s.chars().map(|ch| {
        let text = ch.to_string();
        JsonValue::Str(text.into())
    });
    JsonArray::new(chars_as_json.collect())
}

// Iteration over a borrowed JSON array: items are references to the array's elements.
impl<'a, 'data> Iterable<'_> for &'a JsonArray<'data> {
    type Input = &'a JsonValue<'data>;

    /// Always `Some`: the element count is read directly from the underlying `SmallVec`.
    fn len(&self) -> Option<usize> {
        // Called by path, presumably so the call cannot resolve back to `Iterable::len`.
        let element_count = SmallVec::len(self);
        Some(element_count)
    }

    /// Wraps every element in `Ok` and hands the resulting iterator to `consumer`; this
    /// implementation never returns `Err` itself.
    fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Input>, Output = R>) -> ValResult<R> {
        let elements = self.iter().map(Ok);
        let output = consumer.consume_iterator(elements);
        Ok(output)
    }
}

// A borrowed JSON array never exposes a Python list, so `as_py_list` always returns `None`.
impl<'py> AsPyList<'py> for &'_ JsonArray<'_> {
fn as_py_list(&self) -> Option<&Bound<'py, PyList>> {
None
}
}
34 changes: 18 additions & 16 deletions src/input/input_python.rs
Expand Up @@ -23,6 +23,7 @@ use super::datetime::{
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
EitherTime,
};
use super::input_abstract::ValMatch;
use super::return_enums::ValidationMatch;
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
Expand Down Expand Up @@ -461,24 +462,25 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}
}

fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
match self.lax_list()? {
GenericIterable::List(iter) => Ok(GenericIterable::List(iter)),
_ => Err(ValError::new(ErrorTypeDefaults::ListType, self)),
}
}
type List<'a> = GenericIterable<'a, 'py> where Self: 'a;

fn lax_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
match self
.extract_generic_iterable()
.map_err(|_| ValError::new(ErrorTypeDefaults::ListType, self))?
{
GenericIterable::PyString(_)
| GenericIterable::Bytes(_)
| GenericIterable::Dict(_)
| GenericIterable::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ListType, self)),
other => Ok(other),
/// Validates `self` as a list.
///
/// * A value that downcasts to `PyList` is an exact match, whatever `strict` is.
/// * When `strict` is false, any other generic iterable is a lax match, except strings,
///   bytes, dicts and mappings.
/// * Everything else fails with a `ListType` error.
fn validate_list<'a>(&'a self, strict: bool) -> ValMatch<GenericIterable<'a, 'py>> {
    if let Ok(py_list) = self.downcast::<PyList>() {
        return Ok(ValidationMatch::exact(GenericIterable::List(py_list)));
    }

    if !strict {
        // A failed extraction is not reported as such; it falls through to the `ListType` error.
        if let Ok(iterable) = self.extract_generic_iterable() {
            let excluded = matches!(
                iterable,
                GenericIterable::PyString(_)
                    | GenericIterable::Bytes(_)
                    | GenericIterable::Dict(_)
                    | GenericIterable::Mapping(_)
            );
            if !excluded {
                return Ok(ValidationMatch::lax(iterable));
            }
        }
    }

    Err(ValError::new(ErrorTypeDefaults::ListType, self))
}

fn strict_tuple<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
Expand Down
5 changes: 4 additions & 1 deletion src/input/input_string.rs
Expand Up @@ -11,6 +11,7 @@ use crate::validators::decimal::create_decimal;
use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{Never, ValMatch};
use super::shared::{str_as_bool, str_as_float, str_as_int};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
Expand Down Expand Up @@ -138,7 +139,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
type List<'a> = Never where Self: 'a;

// List validation is not supported for this input type: the call always fails with a
// `ListType` error built from `self`, and `_strict` is ignored.
fn validate_list(&self, _strict: bool) -> ValMatch<Never> {
Err(ValError::new(ErrorTypeDefaults::ListType, self))
}

Expand Down
9 changes: 5 additions & 4 deletions src/input/mod.rs
Expand Up @@ -15,12 +15,13 @@ pub(crate) use datetime::{
duration_as_pytimedelta, pydate_as_date, pydatetime_as_datetime, pytime_as_time, EitherDate, EitherDateTime,
EitherTime, EitherTimedelta,
};
pub(crate) use input_abstract::{BorrowInput, Input, InputType};
pub(crate) use input_abstract::{AsPyList, BorrowInput, ConsumeIterator, Input, InputType, Iterable};
pub(crate) use input_string::StringMapping;
pub(crate) use return_enums::{
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString,
GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator,
MappingGenericIterator, PyArgs, StringMappingGenericIterator, ValidationMatch,
no_validator_iter_to_vec, py_string_str, validate_iter_to_vec, AttributesGenericIterator, DictGenericIterator,
EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator,
GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, MappingGenericIterator, MaxLengthCheck, PyArgs,
StringMappingGenericIterator, ValidationMatch,
};

// Defined here as it's not exported by pyo3
Expand Down

0 comments on commit 3a27298

Please sign in to comment.