Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize unsubstituted type vars as Any #7606

Merged
merged 5 commits into from Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
108 changes: 89 additions & 19 deletions docs/concepts/models.md
Expand Up @@ -783,43 +783,113 @@ print(concrete_model(a=1, b=1))

If you need to perform isinstance checks against parametrized generics, you can do this by subclassing the parametrized generic class. This looks like `class MyIntModel(MyGenericModel[int]): ...` and `isinstance(my_model, MyIntModel)`.

If a Pydantic model is used in a `TypeVar` constraint, [`SerializeAsAny`](serialization.md#serializing-with-duck-typing) can be used to
serialize it using the concrete model instead of the model `TypeVar` is bound to.
If a Pydantic model is used in a `TypeVar` bound and the generic type is never parametrized then Pydantic will use the bound for validation but treat the value as `Any` in terms of serialization:

```py
from typing import Generic, TypeVar
from typing import Generic, Optional, TypeVar

from pydantic import BaseModel, SerializeAsAny
from pydantic import BaseModel


class Model(BaseModel):
a: int = 42
class ErrorDetails(BaseModel):
foo: str


class DataModel(Model):
b: int = 2
c: int = 3
ErrorDataT = TypeVar('ErrorDataT', bound=ErrorDetails)


class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: Optional[ErrorDataT]


BoundT = TypeVar('BoundT', bound=Model)
class MyErrorDetails(ErrorDetails):
bar: str


class GenericModel(BaseModel, Generic[BoundT]):
data: BoundT
# serialized as Any
error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='var2'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'var2',
},
}

# serialized using the concrete parametrization
# note that `'bar': 'var2'` is missing
error = Error[ErrorDetails](
message='We just had an error',
details=ErrorDetails(foo='var'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}
```

class SerializeAsAnyModel(BaseModel, Generic[BoundT]):
data: SerializeAsAny[BoundT]
If you use a `default=...` (available in Python >= 3.13 or via `typing-extensions`) or constraints (`TypeVar('T', str, int)`; note that you rarely want to use this form of a `TypeVar`) then the default value or constraints will be used for both validation and serialization if the type variable is not parametrized. You can override this behavior using `pydantic.SerializeAsAny`:

```py
from typing import Generic, Optional

from typing_extensions import TypeVar

data_model = DataModel()
from pydantic import BaseModel, SerializeAsAny

print(GenericModel(data=data_model).model_dump())
#> {'data': {'a': 42}}

class ErrorDetails(BaseModel):
foo: str

print(SerializeAsAnyModel(data=data_model).model_dump())
#> {'data': {'a': 42, 'b': 2, 'c': 3}}

ErrorDataT = TypeVar('ErrorDataT', default=ErrorDetails)


class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: Optional[ErrorDataT]


class MyErrorDetails(ErrorDetails):
bar: str


# serialized using the default's serializer
error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='var2'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}


class SerializeAsAnyError(BaseModel, Generic[ErrorDataT]):
message: str
details: Optional[SerializeAsAny[ErrorDataT]]


# serialized as Any
error = SerializeAsAnyError(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}
```

## Dynamic model creation
Expand Down
24 changes: 20 additions & 4 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1449,10 +1449,26 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche
def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.CoreSchema:
assert isinstance(typevar, typing.TypeVar)

if typevar.__bound__:
return self.generate_schema(typevar.__bound__)
elif typevar.__constraints__:
return self._union_schema(typing.Union[typevar.__constraints__]) # type: ignore
bound = typevar.__bound__
constraints = typevar.__constraints__
not_set = object()
default = getattr(typevar, '__default__', not_set)

if (bound is not None) + (len(constraints) != 0) + (default is not not_set) > 1:
raise NotImplementedError(
'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults'
)

if default is not not_set:
return self.generate_schema(default)
elif constraints:
return self._union_schema(typing.Union[constraints]) # type: ignore
elif bound:
schema = self.generate_schema(bound)
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
lambda x, h: h(x), schema=core_schema.any_schema()
)
return schema
else:
return core_schema.any_schema()

Expand Down
150 changes: 149 additions & 1 deletion tests/test_generics.py
Expand Up @@ -32,7 +32,18 @@
import pytest
from dirty_equals import HasRepr, IsStr
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Annotated, Literal, OrderedDict, ParamSpec, TypeVarTuple, Unpack, get_args
from typing_extensions import (
Annotated,
Literal,
OrderedDict,
ParamSpec,
TypeVarTuple,
Unpack,
get_args,
)
from typing_extensions import (
TypeVar as TypingExtensionsTypeVar,
)

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -2622,3 +2633,140 @@ class Model(Generic[T], BaseModel):
m1 = Model[int](x=1)
m2 = Model[int](x=1)
assert len({m1, m2}) == 1


def test_serialize_unsubstituted_typevars_bound() -> None:
class ErrorDetails(BaseModel):
foo: str

ErrorDataT = TypeVar('ErrorDataT', bound=ErrorDetails)

class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: ErrorDataT

class MyErrorDetails(ErrorDetails):
bar: str

sample_error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}

sample_error = Error[ErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}

sample_error = Error[MyErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}


@pytest.mark.parametrize(
'type_var',
[
TypingExtensionsTypeVar('ErrorDataT', default=BaseModel),
TypeVar('ErrorDataT', BaseModel, str),
],
ids=['default', 'constraint'],
)
def test_serialize_unsubstituted_typevars_bound(
type_var: Type[BaseModel],
) -> None:
class ErrorDetails(BaseModel):
foo: str

class Error(BaseModel, Generic[type_var]): # type: ignore
message: str
details: type_var

class MyErrorDetails(ErrorDetails):
bar: str

sample_error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {},
}

sample_error = Error[ErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}

sample_error = Error[MyErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}


def test_mix_default_and_constraints() -> None:
T = TypingExtensionsTypeVar('T', str, int, default=str)

msg = 'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults'
with pytest.raises(NotImplementedError, match=msg):

class _(BaseModel, Generic[T]):
x: T