Skip to content

Commit

Permalink
Fix schema generation for generics with union type bounds (#7899)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Oct 23, 2023
1 parent d044f39 commit 47aa97e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1269,13 +1269,20 @@ def _type_schema(self) -> core_schema.CoreSchema:
custom_error_message='Input should be a type',
)

def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema:
"""Generate schema for `Type[Union[X, ...]]`."""
args = self._get_args_resolving_forward_refs(union_type, required=True)
return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args])

def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
"""Generate schema for a Type, e.g. `Type[int]`."""
type_param = self._get_first_arg_or_any(type_)
if type_param == Any:
return self._type_schema()
elif isinstance(type_param, typing.TypeVar):
if type_param.__bound__:
if _typing_extra.origin_is_union(get_origin(type_param.__bound__)):
return self._union_is_subclass_schema(type_param.__bound__)
return core_schema.is_subclass_schema(type_param.__bound__)
elif type_param.__constraints__:
return core_schema.union_schema(
Expand All @@ -1284,8 +1291,7 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
else:
return self._type_schema()
elif _typing_extra.origin_is_union(get_origin(type_param)):
args = self._get_args_resolving_forward_refs(type_param, required=True)
return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args])
return self._union_is_subclass_schema(type_param)
else:
return core_schema.is_subclass_schema(type_param)

Expand Down
16 changes: 16 additions & 0 deletions tests/test_generics.py
Expand Up @@ -1508,6 +1508,22 @@ class MyInt(int):
ReferenceModel[MyInt]


def test_generic_with_referenced_generic_union_type_bound():
T = TypeVar('T', bound=Union[str, int])

class ModelWithType(BaseModel, Generic[T]):
some_type: Type[T]

class MyInt(int):
...

class MyStr(str):
...

ModelWithType[MyInt]
ModelWithType[MyStr]


def test_generic_with_referenced_generic_type_constraints():
T = TypeVar('T', int, str)

Expand Down

0 comments on commit 47aa97e

Please sign in to comment.