Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix schema generation for generics with union type bounds #7899

Merged
merged 2 commits into from Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1269,13 +1269,20 @@ def _type_schema(self) -> core_schema.CoreSchema:
custom_error_message='Input should be a type',
)

def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema:
"""Generate schema for `Type[Union[X, ...]]`."""
args = self._get_args_resolving_forward_refs(union_type, required=True)
return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args])

def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
"""Generate schema for a Type, e.g. `Type[int]`."""
type_param = self._get_first_arg_or_any(type_)
if type_param == Any:
return self._type_schema()
elif isinstance(type_param, typing.TypeVar):
if type_param.__bound__:
if _typing_extra.origin_is_union(get_origin(type_param.__bound__)):
return self._union_is_subclass_schema(type_param.__bound__)
return core_schema.is_subclass_schema(type_param.__bound__)
elif type_param.__constraints__:
return core_schema.union_schema(
Expand All @@ -1284,8 +1291,7 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
else:
return self._type_schema()
elif _typing_extra.origin_is_union(get_origin(type_param)):
args = self._get_args_resolving_forward_refs(type_param, required=True)
return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args])
return self._union_is_subclass_schema(type_param)
else:
return core_schema.is_subclass_schema(type_param)

Expand Down
16 changes: 16 additions & 0 deletions tests/test_generics.py
Expand Up @@ -1508,6 +1508,22 @@ class MyInt(int):
ReferenceModel[MyInt]


def test_generic_with_referenced_generic_union_type_bound():
T = TypeVar('T', bound=Union[str, int])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From typing docs there is an example of a union bound, so I think it is completely correct to test this case 👍


class ModelWithType(BaseModel, Generic[T]):
some_type: Type[T]

class MyInt(int):
...

class MyStr(str):
...

ModelWithType[MyInt]
ModelWithType[MyStr]


def test_generic_with_referenced_generic_type_constraints():
T = TypeVar('T', int, str)

Expand Down