Skip to content

Commit

Permalink
Use stricter serializer for unions of simple types (#1132)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexdrydew committed Jan 10, 2024
1 parent 4df7624 commit 8dde89e
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/serializers/type_serializers/simple.rs
Expand Up @@ -5,11 +5,12 @@ use std::borrow::Cow;

use serde::Serialize;

use crate::PydanticSerializationUnexpectedValue;
use crate::{definitions::DefinitionsBuilder, input::Int};

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType,
SerMode, TypeSerializer,
SerCheck, SerMode, TypeSerializer,
};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -85,7 +86,7 @@ impl TypeSerializer for NoneSerializer {
}

macro_rules! build_simple_serializer {
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident) => {
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident, $subtypes_allowed:expr) => {
#[derive(Debug, Clone)]
pub struct $struct_name;

Expand Down Expand Up @@ -114,12 +115,15 @@ macro_rules! build_simple_serializer {
let py = value.py();
match extra.ob_type_lookup.is_type(value, $ob_type) {
IsType::Exact => Ok(value.into_py(py)),
IsType::Subclass => match extra.mode {
SerMode::Json => {
let rust_value = value.extract::<$rust_type>()?;
Ok(rust_value.to_object(py))
}
_ => infer_to_python(value, include, exclude, extra),
IsType::Subclass => match extra.check {
SerCheck::Strict => Err(PydanticSerializationUnexpectedValue::new_err(None)),
SerCheck::Lax | SerCheck::None => match extra.mode {
SerMode::Json => {
let rust_value = value.extract::<$rust_type>()?;
Ok(rust_value.to_object(py))
}
_ => infer_to_python(value, include, exclude, extra),
},
},
IsType::False => {
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
Expand Down Expand Up @@ -160,6 +164,10 @@ macro_rules! build_simple_serializer {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self) -> bool {
$subtypes_allowed
}
}
};
}
Expand All @@ -168,7 +176,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult<Cow<str>> {
Ok(key.str()?.to_string_lossy())
}

build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key);
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true);

pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
let v = if key.is_true().unwrap_or(false) {
Expand All @@ -179,4 +187,4 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
Ok(Cow::Borrowed(v))
}

build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key);
build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key, false);
116 changes: 116 additions & 0 deletions tests/serializers/test_union.py
@@ -1,6 +1,8 @@
import dataclasses
import json
import re
import uuid
from decimal import Decimal
from typing import Any, ClassVar, Union

import pytest
Expand Down Expand Up @@ -510,3 +512,117 @@ class Item(BaseModel):
)

assert s.to_python([DBUser(name='John', password='secret')]) == [{'name': 'John'}]


EXAMPLE_UUID = uuid.uuid4()


class IntSubclass(int):
pass


@pytest.mark.parametrize('reverse', [False, True])
@pytest.mark.parametrize(
'core_schema_left,core_schema_right,input_value,expected_value',
[
(core_schema.int_schema(), core_schema.bool_schema(), True, True),
(core_schema.int_schema(), core_schema.bool_schema(), 1, 1),
(core_schema.str_schema(), core_schema.int_schema(), 1, 1),
(core_schema.str_schema(), core_schema.int_schema(), '1', '1'),
(core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1),
(
core_schema.decimal_schema(),
core_schema.int_schema(),
Decimal('1'),
Decimal('1'),
),
(core_schema.decimal_schema(), core_schema.int_schema(), 1, 1),
(
core_schema.decimal_schema(),
core_schema.float_schema(),
Decimal('1.'),
Decimal('1.'),
),
(
core_schema.decimal_schema(),
core_schema.str_schema(),
Decimal('_1'),
Decimal('_1'),
),
(
core_schema.decimal_schema(),
core_schema.str_schema(),
'_1',
'_1',
),
(
core_schema.uuid_schema(),
core_schema.str_schema(),
EXAMPLE_UUID,
EXAMPLE_UUID,
),
(
core_schema.uuid_schema(),
core_schema.str_schema(),
str(EXAMPLE_UUID),
str(EXAMPLE_UUID),
),
],
)
def test_union_serializer_picks_exact_type_over_subclass(
core_schema_left, core_schema_right, input_value, expected_value, reverse
):
s = SchemaSerializer(
core_schema.union_schema(
[core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right]
)
)
assert s.to_python(input_value) == expected_value


@pytest.mark.parametrize('reverse', [False, True])
@pytest.mark.parametrize(
'core_schema_left,core_schema_right,input_value,expected_value',
[
(core_schema.int_schema(), core_schema.bool_schema(), True, True),
(core_schema.int_schema(), core_schema.bool_schema(), 1, 1),
(core_schema.str_schema(), core_schema.int_schema(), 1, 1),
(core_schema.str_schema(), core_schema.int_schema(), '1', '1'),
(core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1),
(
core_schema.decimal_schema(),
core_schema.int_schema(),
Decimal('1'),
'1',
),
(core_schema.decimal_schema(), core_schema.int_schema(), 1, 1),
(
core_schema.decimal_schema(),
core_schema.float_schema(),
Decimal('1.'),
'1',
),
(
core_schema.decimal_schema(),
core_schema.str_schema(),
Decimal('_1'),
'1',
),
(
core_schema.decimal_schema(),
core_schema.str_schema(),
'_1',
'_1',
),
],
)
def test_union_serializer_picks_exact_type_over_subclass_json(
core_schema_left, core_schema_right, input_value, expected_value, reverse
):
s = SchemaSerializer(
core_schema.union_schema(
[core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right]
)
)
assert s.to_python(input_value, mode='json') == expected_value
assert s.to_json(input_value) == json.dumps(expected_value).encode()

0 comments on commit 8dde89e

Please sign in to comment.