Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace ultra_strict with new union implementation #867

Merged
merged 15 commits into from Nov 8, 2023
Merged
12 changes: 0 additions & 12 deletions src/definitions.rs
Expand Up @@ -31,21 +31,9 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

impl<T> Definitions<T> {
pub fn values(&self) -> impl Iterator<Item = &Definition<T>> {
self.0.values()
}
}

/// Internal type which contains a definition to be filled
pub struct Definition<T>(Arc<DefinitionInner<T>>);

impl<T> Definition<T> {
pub fn get(&self) -> Option<&T> {
self.0.value.get()
}
}

struct DefinitionInner<T> {
value: OnceLock<T>,
name: LazyName,
Expand Down
79 changes: 21 additions & 58 deletions src/input/input_abstract.rs
Expand Up @@ -6,13 +6,13 @@ use pyo3::{intern, prelude::*};

use jiter::JsonValue;

use crate::errors::{AsLocItem, InputValue, ValResult};
use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch};

#[derive(Debug, Clone, Copy)]
pub enum InputType {
Expand Down Expand Up @@ -48,7 +48,7 @@ impl TryFrom<&str> for InputType {
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem {
pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
fn as_error_value(&'a self) -> InputValue<'a>;

fn identity(&self) -> Option<usize> {
Expand Down Expand Up @@ -91,18 +91,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem {

fn parse_json(&'a self) -> ValResult<'a, JsonValue>;

fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {
if strict {
self.strict_str()
} else {
self.lax_str(coerce_numbers_to_str)
}
}
fn strict_str(&'a self) -> ValResult<EitherString<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_str(&'a self, _coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {
self.strict_str()
}
fn validate_str(
&'a self,
strict: bool,
coerce_numbers_to_str: bool,
) -> ValResult<ValidationMatch<EitherString<'a>>>;

fn validate_bytes(&'a self, strict: bool) -> ValResult<EitherBytes<'a>> {
if strict {
Expand All @@ -117,59 +110,29 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem {
self.strict_bytes()
}

fn validate_bool(&self, strict: bool) -> ValResult<bool> {
if strict {
self.strict_bool()
} else {
self.lax_bool()
}
}
fn strict_bool(&self) -> ValResult<bool>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_bool(&self) -> ValResult<bool> {
self.strict_bool()
}
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>>;

fn validate_int(&'a self, strict: bool) -> ValResult<EitherInt<'a>> {
if strict {
self.strict_int()
} else {
self.lax_int()
}
}
fn strict_int(&'a self) -> ValResult<EitherInt<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
}
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>>;

/// Extract an EitherInt from the input, only allowing exact
/// matches for an Int (no subclasses)
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
self.validate_int(true).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::IntType, self))
})
}

/// Extract a String from the input, only allowing exact
/// matches for a String (no subclasses)
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
self.strict_str()
self.validate_str(true, false).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::StringType, self))
})
}

fn validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult<EitherFloat<'a>> {
if ultra_strict {
self.ultra_strict_float()
} else if strict {
self.strict_float()
} else {
self.lax_float()
}
}
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>>;

fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> {
if strict {
Expand Down
118 changes: 52 additions & 66 deletions src/input/input_json.rs
Expand Up @@ -13,6 +13,7 @@ use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::return_enums::ValidationMatch;
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
Expand Down Expand Up @@ -84,18 +85,28 @@ impl<'a> Input<'a> for JsonValue {
}
}

fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
match self {
JsonValue::Str(s) => Ok(s.as_str().into()),
_ => Err(ValError::new(ErrorTypeDefaults::StringType, self)),
}
}
fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {

fn validate_str(
&'a self,
strict: bool,
coerce_numbers_to_str: bool,
) -> ValResult<ValidationMatch<EitherString<'a>>> {
// Justification for `strict` instead of `exact` is that in JSON strings can also
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this want to change in future, if so can we leave a consistent TODO: V3 comment to make it easy to find in future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This depends what you think of e.g. pydantic/pydantic#7097. The current semantic that this PR preserves is that UUID is "as good" as str for a UUID-format JSON string, so it's dependent on ordering. Changing this to be "exact" means that in JSON mode a str in a union wins over UUID, datetime, etc, irrespective of position.

That is a positive change for simplicity and matches Python semantics better, so there is a case to consider breaking this, but it's not clear cut IMO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's hard to argue that in datetime | str or UUID | str or datetime | int that the type which actually matches JSON type shouldn't take priority. If you agree I would be happy to change it now.

Otherwise add the TODO: V3 comment so we remember to change it in future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree but feeling a bit cautious so will mark as TODO and we can decide later if it's a 2.6 or V3 thing.

// represent other datatypes such as UUID and date more exactly, so string is a
// converting input
// TODO: in V3 we may want to make JSON str always win if in union, for consistency,
// see https://github.com/pydantic/pydantic-core/pull/867#discussion_r1386582501
match self {
JsonValue::Str(s) => Ok(s.as_str().into()),
JsonValue::Int(i) if coerce_numbers_to_str => Ok(i.to_string().into()),
JsonValue::BigInt(b) if coerce_numbers_to_str => Ok(b.to_string().into()),
JsonValue::Float(f) if coerce_numbers_to_str => Ok(f.to_string().into()),
JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_str().into())),
JsonValue::Int(i) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(i.to_string().into())),
JsonValue::BigInt(b) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(b.to_string().into())),
JsonValue::Float(f) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(f.to_string().into())),
_ => Err(ValError::new(ErrorTypeDefaults::StringType, self)),
}
}
Expand All @@ -111,70 +122,39 @@ impl<'a> Input<'a> for JsonValue {
self.validate_bytes(false)
}

fn strict_bool(&self) -> ValResult<bool> {
match self {
JsonValue::Bool(b) => Ok(*b),
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
}
}
fn lax_bool(&self) -> ValResult<bool> {
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
match self {
JsonValue::Bool(b) => Ok(*b),
JsonValue::Str(s) => str_as_bool(self, s),
JsonValue::Int(int) => int_as_bool(self, *int),
JsonValue::Float(float) => match float_as_int(self, *float) {
JsonValue::Bool(b) => Ok(ValidationMatch::exact(*b)),
JsonValue::Str(s) if !strict => str_as_bool(self, s).map(ValidationMatch::lax),
JsonValue::Int(int) if !strict => int_as_bool(self, *int).map(ValidationMatch::lax),
JsonValue::Float(float) if !strict => match float_as_int(self, *float) {
Ok(int) => int
.as_bool()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)),
.ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self))
.map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
},
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
}
}

fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
match self {
JsonValue::Int(i) => Ok(EitherInt::I64(*i)),
JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())),
_ => Err(ValError::new(ErrorTypeDefaults::IntType, self)),
}
}
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
match self {
JsonValue::Bool(b) => match *b {
true => Ok(EitherInt::I64(1)),
false => Ok(EitherInt::I64(0)),
},
JsonValue::Int(i) => Ok(EitherInt::I64(*i)),
JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())),
JsonValue::Float(f) => float_as_int(self, *f),
JsonValue::Str(str) => str_as_int(self, str),
JsonValue::Int(i) => Ok(ValidationMatch::exact(EitherInt::I64(*i))),
JsonValue::BigInt(b) => Ok(ValidationMatch::exact(EitherInt::BigInt(b.clone()))),
JsonValue::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherInt::I64((*b).into()))),
JsonValue::Float(f) if !strict => float_as_int(self, *f).map(ValidationMatch::lax),
JsonValue::Str(str) if !strict => str_as_int(self, str).map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::IntType, self)),
}
}

fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
match self {
JsonValue::Float(f) => Ok(EitherFloat::F64(*f)),
_ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
match self {
JsonValue::Float(f) => Ok(EitherFloat::F64(*f)),
JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)),
_ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
match self {
JsonValue::Bool(b) => match *b {
true => Ok(EitherFloat::F64(1.0)),
false => Ok(EitherFloat::F64(0.0)),
},
JsonValue::Float(f) => Ok(EitherFloat::F64(*f)),
JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)),
JsonValue::Str(str) => str_as_float(self, str),
JsonValue::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))),
JsonValue::Int(i) => Ok(ValidationMatch::strict(EitherFloat::F64(*i as f64))),
JsonValue::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherFloat::F64(if *b { 1.0 } else { 0.0 }))),
JsonValue::Str(str) if !strict => str_as_float(self, str).map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
Expand Down Expand Up @@ -399,30 +379,36 @@ impl<'a> Input<'a> for String {
JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e))
}

fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
Ok(self.as_str().into())
fn validate_str(
&'a self,
_strict: bool,
_coerce_numbers_to_str: bool,
) -> ValResult<ValidationMatch<EitherString<'a>>> {
// Justification for `strict` instead of `exact` is that in JSON strings can also
// represent other datatypes such as UUID and date more exactly, so string is a
// converting input
// TODO: in V3 we may want to make JSON str always win if in union, for consistency,
// see https://github.com/pydantic/pydantic-core/pull/867#discussion_r1386582501
Ok(ValidationMatch::strict(self.as_str().into()))
}

fn strict_bytes(&'a self) -> ValResult<EitherBytes<'a>> {
Ok(self.as_bytes().into())
}

fn strict_bool(&self) -> ValResult<bool> {
str_as_bool(self, self)
fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
str_as_bool(self, self).map(ValidationMatch::lax)
}

fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
match self.parse() {
Ok(i) => Ok(EitherInt::I64(i)),
Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)),
}
}

fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
str_as_float(self, self)
fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
str_as_float(self, self).map(ValidationMatch::lax)
}

fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
Expand Down