Skip to content

Commit

Permalink
Serialize unsubstituted type vars as Any (#7606)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 25, 2023
1 parent 1f59075 commit 095ff12
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 24 deletions.
108 changes: 89 additions & 19 deletions docs/concepts/models.md
Expand Up @@ -783,43 +783,113 @@ print(concrete_model(a=1, b=1))

If you need to perform isinstance checks against parametrized generics, you can do this by subclassing the parametrized generic class. This looks like `class MyIntModel(MyGenericModel[int]): ...` and `isinstance(my_model, MyIntModel)`.

If a Pydantic model is used in a `TypeVar` constraint, [`SerializeAsAny`](serialization.md#serializing-with-duck-typing) can be used to
serialize it using the concrete model instead of the model `TypeVar` is bound to.
If a Pydantic model is used in a `TypeVar` bound and the generic type is never parametrized then Pydantic will use the bound for validation but treat the value as `Any` in terms of serialization:

```py
from typing import Generic, TypeVar
from typing import Generic, Optional, TypeVar

from pydantic import BaseModel, SerializeAsAny
from pydantic import BaseModel


class Model(BaseModel):
a: int = 42
class ErrorDetails(BaseModel):
foo: str


class DataModel(Model):
b: int = 2
c: int = 3
ErrorDataT = TypeVar('ErrorDataT', bound=ErrorDetails)


class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: Optional[ErrorDataT]


BoundT = TypeVar('BoundT', bound=Model)
class MyErrorDetails(ErrorDetails):
bar: str


class GenericModel(BaseModel, Generic[BoundT]):
data: BoundT
# serialized as Any
error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='var2'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'var2',
},
}

# serialized using the concrete parametrization
# note that `'bar': 'var2'` is missing
error = Error[ErrorDetails](
message='We just had an error',
details=ErrorDetails(foo='var'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}
```

class SerializeAsAnyModel(BaseModel, Generic[BoundT]):
data: SerializeAsAny[BoundT]
If you use a `default=...` (available in Python >= 3.13 or via `typing-extensions`) or constraints (`TypeVar('T', str, int)`; note that you rarely want to use this form of a `TypeVar`) then the default value or constraints will be used for both validation and serialization if the type variable is not parametrized. You can override this behavior using `pydantic.SerializeAsAny`:

```py
from typing import Generic, Optional

from typing_extensions import TypeVar

data_model = DataModel()
from pydantic import BaseModel, SerializeAsAny

print(GenericModel(data=data_model).model_dump())
#> {'data': {'a': 42}}

class ErrorDetails(BaseModel):
foo: str

print(SerializeAsAnyModel(data=data_model).model_dump())
#> {'data': {'a': 42, 'b': 2, 'c': 3}}

ErrorDataT = TypeVar('ErrorDataT', default=ErrorDetails)


class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: Optional[ErrorDataT]


class MyErrorDetails(ErrorDetails):
bar: str


# serialized using the default's serializer
error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='var2'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}


class SerializeAsAnyError(BaseModel, Generic[ErrorDataT]):
message: str
details: Optional[SerializeAsAny[ErrorDataT]]


# serialized as Any
error = SerializeAsAnyError(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}
```

## Dynamic model creation
Expand Down
24 changes: 20 additions & 4 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1449,10 +1449,26 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche
def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.CoreSchema:
assert isinstance(typevar, typing.TypeVar)

if typevar.__bound__:
return self.generate_schema(typevar.__bound__)
elif typevar.__constraints__:
return self._union_schema(typing.Union[typevar.__constraints__]) # type: ignore
bound = typevar.__bound__
constraints = typevar.__constraints__
not_set = object()
default = getattr(typevar, '__default__', not_set)

if (bound is not None) + (len(constraints) != 0) + (default is not not_set) > 1:
raise NotImplementedError(
'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults'
)

if default is not not_set:
return self.generate_schema(default)
elif constraints:
return self._union_schema(typing.Union[constraints]) # type: ignore
elif bound:
schema = self.generate_schema(bound)
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
lambda x, h: h(x), schema=core_schema.any_schema()
)
return schema
else:
return core_schema.any_schema()

Expand Down
150 changes: 149 additions & 1 deletion tests/test_generics.py
Expand Up @@ -32,7 +32,18 @@
import pytest
from dirty_equals import HasRepr, IsStr
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Annotated, Literal, OrderedDict, ParamSpec, TypeVarTuple, Unpack, get_args
from typing_extensions import (
Annotated,
Literal,
OrderedDict,
ParamSpec,
TypeVarTuple,
Unpack,
get_args,
)
from typing_extensions import (
TypeVar as TypingExtensionsTypeVar,
)

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -2622,3 +2633,140 @@ class Model(Generic[T], BaseModel):
m1 = Model[int](x=1)
m2 = Model[int](x=1)
assert len({m1, m2}) == 1


def test_serialize_unsubstituted_typevars_bound() -> None:
class ErrorDetails(BaseModel):
foo: str

ErrorDataT = TypeVar('ErrorDataT', bound=ErrorDetails)

class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: ErrorDataT

class MyErrorDetails(ErrorDetails):
bar: str

sample_error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}

sample_error = Error[ErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}

sample_error = Error[MyErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}


@pytest.mark.parametrize(
'type_var',
[
TypingExtensionsTypeVar('ErrorDataT', default=BaseModel),
TypeVar('ErrorDataT', BaseModel, str),
],
ids=['default', 'constraint'],
)
def test_serialize_unsubstituted_typevars_bound(
type_var: Type[BaseModel],
) -> None:
class ErrorDetails(BaseModel):
foo: str

class Error(BaseModel, Generic[type_var]): # type: ignore
message: str
details: type_var

class MyErrorDetails(ErrorDetails):
bar: str

sample_error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {},
}

sample_error = Error[ErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}

sample_error = Error[MyErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}


def test_mix_default_and_constraints() -> None:
T = TypingExtensionsTypeVar('T', str, int, default=str)

msg = 'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults'
with pytest.raises(NotImplementedError, match=msg):

class _(BaseModel, Generic[T]):
x: T

0 comments on commit 095ff12

Please sign in to comment.