Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update typevar handling when default is not set #7719

Merged
merged 4 commits into from Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1451,6 +1451,8 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.
constraints = typevar.__constraints__
not_set = object()
default = getattr(typevar, '__default__', not_set)
if default is None:
default = not_set
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the default not be None? How does this differentiate default=None from no default?

Copy link
Contributor Author

@pmmmwh pmmmwh Oct 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the PEP:

  • when default is unset, __default__ would be None
  • if default=None, __default__ would be NoneType

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get rid of our internal unset thing then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yes - let me try and see if that works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)


if (bound is not None) + (len(constraints) != 0) + (default is not not_set) > 1:
raise NotImplementedError(
Expand Down
65 changes: 64 additions & 1 deletion tests/test_generics.py
Expand Up @@ -2639,6 +2639,7 @@ def test_serialize_unsubstituted_typevars_bound() -> None:
class ErrorDetails(BaseModel):
foo: str

# This version of `TypeVar` does not support `default` on Python <3.12
ErrorDataT = TypeVar('ErrorDataT', bound=ErrorDetails)

class Error(BaseModel, Generic[ErrorDataT]):
Expand Down Expand Up @@ -2696,6 +2697,68 @@ class MyErrorDetails(ErrorDetails):
}


def test_serialize_unsubstituted_typevars_bound_default_supported() -> None:
class ErrorDetails(BaseModel):
foo: str

# This version of `TypeVar` always support `default`
ErrorDataT = TypingExtensionsTypeVar('ErrorDataT', bound=ErrorDetails)

class Error(BaseModel, Generic[ErrorDataT]):
message: str
details: ErrorDataT

class MyErrorDetails(ErrorDetails):
bar: str

sample_error = Error(
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}

sample_error = Error[ErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
},
}

sample_error = Error[MyErrorDetails](
message='We just had an error',
details=MyErrorDetails(foo='var', bar='baz'),
)
assert sample_error.details.model_dump() == {
'foo': 'var',
'bar': 'baz',
}
assert sample_error.model_dump() == {
'message': 'We just had an error',
'details': {
'foo': 'var',
'bar': 'baz',
},
}


@pytest.mark.parametrize(
'type_var',
[
Expand All @@ -2704,7 +2767,7 @@ class MyErrorDetails(ErrorDetails):
],
ids=['default', 'constraint'],
)
def test_serialize_unsubstituted_typevars_bound(
def test_serialize_unsubstituted_typevars_variants(
type_var: Type[BaseModel],
) -> None:
class ErrorDetails(BaseModel):
Expand Down