Skip to content

Commit

Permalink
fix: update typevar handling when default is not set (#7719)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmmmwh committed Oct 5, 2023
1 parent 832a045 commit 32ea570
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
7 changes: 3 additions & 4 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1449,15 +1449,14 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.

bound = typevar.__bound__
constraints = typevar.__constraints__
not_set = object()
default = getattr(typevar, '__default__', not_set)
default = getattr(typevar, '__default__', None)

if (bound is not None) + (len(constraints) != 0) + (default is not not_set) > 1:
if (bound is not None) + (len(constraints) != 0) + (default is not None) > 1:
raise NotImplementedError(
'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults'
)

if default is not not_set:
if default is not None:
return self.generate_schema(default)
elif constraints:
return self._union_schema(typing.Union[constraints]) # type: ignore
Expand Down
65 changes: 64 additions & 1 deletion tests/test_generics.py
Expand Up @@ -2639,6 +2639,7 @@ def test_serialize_unsubstituted_typevars_bound() -> None:
class ErrorDetails(BaseModel):
foo: str

# This version of `TypeVar` does not support `default` on Python <3.12
ErrorDataT = TypeVar('ErrorDataT', bound=ErrorDetails)

class Error(BaseModel, Generic[ErrorDataT]):
Expand Down Expand Up @@ -2696,6 +2697,68 @@ class MyErrorDetails(ErrorDetails):
}


def test_serialize_unsubstituted_typevars_bound_default_supported() -> None:
class ErrorDetails(BaseModel):
foo: str

# This version of `TypeVar` always support `default`
ErrorDataT = TypingExtensionsTypeVar('ErrorDataT', bound=ErrorDetails)

class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: ErrorDataT

class MyErrorDetails(ErrorDetails):
bar: str

sample_error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}

sample_error = Error[ErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}

sample_error = Error[MyErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}


@pytest.mark.parametrize(
'type_var',
[
Expand All @@ -2704,7 +2767,7 @@ class MyErrorDetails(ErrorDetails):
],
ids=['default', 'constraint'],
)
def test_serialize_unsubstituted_typevars_bound(
def test_serialize_unsubstituted_typevars_variants(
type_var: Type[BaseModel],
) -> None:
class ErrorDetails(BaseModel):
Expand Down

0 comments on commit 32ea570

Please sign in to comment.