Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace ultra_strict with new union implementation #867

Merged
merged 15 commits into from Nov 8, 2023
Merged
12 changes: 0 additions & 12 deletions src/definitions.rs
Expand Up @@ -31,21 +31,9 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

impl<T> Definitions<T> {
pub fn values(&self) -> impl Iterator<Item = &Definition<T>> {
self.0.values()
}
}

/// Internal type which contains a definition to be filled
pub struct Definition<T>(Arc<DefinitionInner<T>>);

impl<T> Definition<T> {
pub fn get(&self) -> Option<&T> {
self.0.value.get()
}
}

struct DefinitionInner<T> {
value: OnceLock<T>,
name: LazyName,
Expand Down
162 changes: 26 additions & 136 deletions src/input/input_abstract.rs
Expand Up @@ -6,13 +6,13 @@ use pyo3::{intern, prelude::*};

use jiter::JsonValue;

use crate::errors::{AsLocItem, InputValue, ValResult};
use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch};

#[derive(Debug, Clone, Copy)]
pub enum InputType {
Expand Down Expand Up @@ -48,7 +48,7 @@ impl TryFrom<&str> for InputType {
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem {
pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
fn as_error_value(&'a self) -> InputValue<'a>;

fn identity(&self) -> Option<usize> {
Expand Down Expand Up @@ -91,85 +91,37 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem {

fn parse_json(&'a self) -> ValResult<'a, JsonValue>;

fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {
if strict {
self.strict_str()
} else {
self.lax_str(coerce_numbers_to_str)
}
}
fn strict_str(&'a self) -> ValResult<EitherString<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_str(&'a self, _coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {
self.strict_str()
}
fn validate_str(
&'a self,
strict: bool,
coerce_numbers_to_str: bool,
) -> ValResult<ValidationMatch<EitherString<'a>>>;

fn validate_bytes(&'a self, strict: bool) -> ValResult<EitherBytes<'a>> {
if strict {
self.strict_bytes()
} else {
self.lax_bytes()
}
}
fn strict_bytes(&'a self) -> ValResult<EitherBytes<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_bytes(&'a self) -> ValResult<EitherBytes<'a>> {
self.strict_bytes()
}
fn validate_bytes(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a>>>;

fn validate_bool(&self, strict: bool) -> ValResult<bool> {
if strict {
self.strict_bool()
} else {
self.lax_bool()
}
}
fn strict_bool(&self) -> ValResult<bool>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_bool(&self) -> ValResult<bool> {
self.strict_bool()
}
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>>;

fn validate_int(&'a self, strict: bool) -> ValResult<EitherInt<'a>> {
if strict {
self.strict_int()
} else {
self.lax_int()
}
}
fn strict_int(&'a self) -> ValResult<EitherInt<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
}
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>>;

/// Extract an EitherInt from the input, only allowing exact
/// matches for an Int (no subclasses)
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
self.validate_int(true).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::IntType, self))
})
}

/// Extract a String from the input, only allowing exact
/// matches for a String (no subclasses)
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
self.strict_str()
self.validate_str(true, false).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::StringType, self))
})
}

fn validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult<EitherFloat<'a>> {
if ultra_strict {
self.ultra_strict_float()
} else if strict {
self.strict_float()
} else {
self.lax_float()
}
}
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>>;

fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> {
if strict {
Expand Down Expand Up @@ -257,87 +209,25 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem {

fn validate_iter(&self) -> ValResult<GenericIterator>;

fn validate_date(&self, strict: bool) -> ValResult<EitherDate> {
if strict {
self.strict_date()
} else {
self.lax_date()
}
}
fn strict_date(&self) -> ValResult<EitherDate>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_date(&self) -> ValResult<EitherDate> {
self.strict_date()
}
fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate>>;

fn validate_time(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime> {
if strict {
self.strict_time(microseconds_overflow_behavior)
} else {
self.lax_time(microseconds_overflow_behavior)
}
}
fn strict_time(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_time(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime> {
self.strict_time(microseconds_overflow_behavior)
}
) -> ValResult<ValidationMatch<EitherTime>>;

fn validate_datetime(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime> {
if strict {
self.strict_datetime(microseconds_overflow_behavior)
} else {
self.lax_datetime(microseconds_overflow_behavior)
}
}
fn strict_datetime(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_datetime(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime> {
self.strict_datetime(microseconds_overflow_behavior)
}
) -> ValResult<ValidationMatch<EitherDateTime>>;

fn validate_timedelta(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta> {
if strict {
self.strict_timedelta(microseconds_overflow_behavior)
} else {
self.lax_timedelta(microseconds_overflow_behavior)
}
}
fn strict_timedelta(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_timedelta(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta> {
self.strict_timedelta(microseconds_overflow_behavior)
}
) -> ValResult<ValidationMatch<EitherTimedelta>>;
}

/// The problem to solve here is that iterating a `StringMapping` returns an owned
Expand Down