Skip to content

Commit

Permalink
Support serialization of deque when passed to `Sequence[blah blah b…
Browse files Browse the repository at this point in the history
…lah]` (#9128)
  • Loading branch information
sydney-runkle committed Mar 27, 2024
1 parent b30f44d commit 477e0a2
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 19 deletions.
10 changes: 9 additions & 1 deletion pydantic/_internal/_generate_schema.py
Expand Up @@ -1434,6 +1434,8 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:

def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema:
"""Generate schema for a Sequence, e.g. `Sequence[int]`."""
from ._std_types_schema import serialize_sequence_via_list

item_type = self._get_first_arg_or_any(sequence_type)
item_type_schema = self.generate_schema(item_type)
list_schema = core_schema.list_schema(item_type_schema)
Expand All @@ -1445,7 +1447,13 @@ def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema:
python_schema = core_schema.chain_schema(
[python_schema, core_schema.no_info_wrap_validator_function(sequence_validator, list_schema)],
)
return core_schema.json_or_python_schema(json_schema=list_schema, python_schema=python_schema)

serialization = core_schema.wrap_serializer_function_ser_schema(
serialize_sequence_via_list, schema=item_type_schema, info_arg=True
)
return core_schema.json_or_python_schema(
json_schema=list_schema, python_schema=python_schema, serialization=serialization
)

def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema:
"""Generate a schema for an `Iterable`."""
Expand Down
43 changes: 25 additions & 18 deletions pydantic/_internal/_std_types_schema.py
Expand Up @@ -275,6 +275,30 @@ def dequeue_validator(
return collections.deque(handler(input_value), maxlen=maxlen)


def serialize_sequence_via_list(
v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []

mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
if mapped_origin is None:
# we shouldn't hit this branch, should probably add a serialization error or something
return v

for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)

if info.mode_is_json():
return items
else:
return mapped_origin(items)


@dataclasses.dataclass(**slots_true)
class SequenceValidator:
mapped_origin: type[Any]
Expand All @@ -283,23 +307,6 @@ class SequenceValidator:
max_length: int | None = None
strict: bool | None = None

def serialize_sequence_via_list(
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []
for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)

if info.mode_is_json():
return items
else:
return self.mapped_origin(items)

def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.item_source_type is Any:
items_schema = None
Expand Down Expand Up @@ -344,7 +351,7 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
)

serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
)

strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
Expand Down
9 changes: 9 additions & 0 deletions tests/test_types.py
Expand Up @@ -6636,3 +6636,12 @@ class Model(BaseModel):
obj = Model.model_validate({'my_model': {'my_enum': 'a'}, 'other_model': {'my_enum': 'a'}})
assert not isinstance(obj.my_model.my_enum, MyEnum)
assert isinstance(obj.other_model.my_enum, MyEnum)


def test_can_serialize_deque_passed_to_sequence() -> None:
ta = TypeAdapter(Sequence[int])
my_dec = ta.validate_python(deque([1, 2, 3]))
assert my_dec == deque([1, 2, 3])

assert ta.dump_python(my_dec) == my_dec
assert ta.dump_json(my_dec) == b'[1,2,3]'

0 comments on commit 477e0a2

Please sign in to comment.