Skip to content

Commit

Permalink
Add lax_str and lax_int support for enum values not inherited from st…
Browse files Browse the repository at this point in the history
…r/int (#1015)
  • Loading branch information
michaelhly committed Oct 26, 2023
1 parent 23d1065 commit 866eb2d
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 3 deletions.
21 changes: 20 additions & 1 deletion src/input/input_python.rs
Expand Up @@ -21,7 +21,10 @@ use super::datetime::{
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
EitherTime,
};
use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
str_as_int,
};
use super::{
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
Expand Down Expand Up @@ -256,6 +259,8 @@ impl<'a> Input<'a> for PyAny {
|| self.is_instance(decimal_type.as_ref(py)).unwrap_or_default()
} {
Ok(self.str()?.into())
} else if let Some(enum_val) = maybe_as_enum(self) {
Ok(enum_val.str()?.into())
} else {
Err(ValError::new(ErrorTypeDefaults::StringType, self))
}
Expand Down Expand Up @@ -340,6 +345,8 @@ impl<'a> Input<'a> for PyAny {
decimal_as_int(self.py(), self, decimal)
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else if let Some(enum_val) = maybe_as_enum(self) {
Ok(EitherInt::Py(enum_val))
} else {
Err(ValError::new(ErrorTypeDefaults::IntType, self))
}
Expand Down Expand Up @@ -759,6 +766,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult<Option<Cow<
}
}

/// Utility for extracting an enum value, if possible.
fn maybe_as_enum(v: &PyAny) -> Option<&PyAny> {
let py = v.py();
let enum_meta_object = get_enum_meta_object(py);
let meta_type = v.get_type().get_type();
if meta_type.is(&enum_meta_object) {
v.getattr(intern!(py, "value")).ok()
} else {
None
}
}

#[cfg(PyPy)]
static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();

Expand Down
16 changes: 15 additions & 1 deletion src/input/shared.rs
@@ -1,11 +1,25 @@
use num_bigint::BigInt;
use pyo3::{intern, PyAny, Python};
use pyo3::sync::GILOnceCell;
use pyo3::{intern, Py, PyAny, Python, ToPyObject};

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};

use super::parse_json::{JsonArray, JsonInput};
use super::{EitherFloat, EitherInt, Input};

static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
ENUM_META_OBJECT
.get_or_init(py, || {
py.import(intern!(py, "enum"))
.and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta")))
.unwrap()
.to_object(py)
})
.clone()
}

pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> {
ValError::new(
ErrorType::JsonInvalid {
Expand Down
3 changes: 2 additions & 1 deletion src/serializers/ob_type.rs
Expand Up @@ -259,8 +259,9 @@ impl ObTypeLookup {
fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool {
// only test on the type itself, not base types
if op_value.is_some() {
let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type();
let meta_type = py_type.get_type();
meta_type.is(&self.enum_object)
meta_type.is(enum_meta_type)
} else {
false
}
Expand Down
13 changes: 13 additions & 0 deletions tests/validators/test_int.py
Expand Up @@ -459,3 +459,16 @@ def test_float_subclass() -> None:
v_lax = v.validate_python(FloatSubclass(1))
assert v_lax == 1
assert type(v_lax) == int


def test_int_subclass_plain_enum() -> None:
v = SchemaValidator({'type': 'int'})

from enum import Enum

class PlainEnum(Enum):
ONE = 1

v_lax = v.validate_python(PlainEnum.ONE)
assert v_lax == 1
assert type(v_lax) == int
15 changes: 15 additions & 0 deletions tests/validators/test_string.py
Expand Up @@ -249,6 +249,21 @@ def test_lax_subclass(FruitEnum, kwargs):
assert repr(p) == "'pear'"


@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr)
def test_lax_subclass_plain_enum(kwargs):
v = SchemaValidator(core_schema.str_schema(**kwargs))

from enum import Enum

class PlainEnum(Enum):
ONE = 'one'

p = v.validate_python(PlainEnum.ONE)
assert p == 'one'
assert type(p) is str
assert repr(p) == "'one'"


def test_subclass_preserved() -> None:
class StrSubclass(str):
pass
Expand Down

0 comments on commit 866eb2d

Please sign in to comment.