Skip to content

Commit

Permalink
Fix BaseModel type annotations to be resolvable by typing.get_type_hi…
Browse files Browse the repository at this point in the history
…nts (#7680)

Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Co-authored-by: sydney-runkle <sydneymarierunkle@gmail.com>
  • Loading branch information
3 people committed Feb 13, 2024
1 parent 4672662 commit 3d1355f
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 11 deletions.
20 changes: 12 additions & 8 deletions pydantic/main.py
Expand Up @@ -32,22 +32,26 @@
from .json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue, model_json_schema
from .warnings import PydanticDeprecatedSince20

# Always define certain types that are needed to resolve method type hints/annotations
# (even when not type checking) via typing.get_type_hints.
Model = typing.TypeVar('Model', bound='BaseModel')
TupleGenerator = typing.Generator[typing.Tuple[str, Any], None, None]
IncEx: typing_extensions.TypeAlias = typing.Union[
typing.Set[int], typing.Set[str], typing.Dict[int, Any], typing.Dict[str, Any], None
]


if typing.TYPE_CHECKING:
from inspect import Signature
from pathlib import Path

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
from typing_extensions import Literal, Unpack
from typing_extensions import Unpack

from ._internal._utils import AbstractSetIntStr, MappingIntStrAny
from .deprecated.parse import Protocol as DeprecatedParseProtocol
from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr
from .fields import Field as _Field

TupleGenerator = typing.Generator[typing.Tuple[str, Any], None, None]
Model = typing.TypeVar('Model', bound='BaseModel')
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
Expand Down Expand Up @@ -120,7 +124,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass):
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
__pydantic_generic_metadata__: ClassVar[_generics.PydanticGenericMetadata]
__pydantic_parent_namespace__: ClassVar[dict[str, Any] | None]
__pydantic_post_init__: ClassVar[None | Literal['model_post_init']]
__pydantic_post_init__: ClassVar[None | typing_extensions.Literal['model_post_init']]
__pydantic_root_model__: ClassVar[bool]
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator]
Expand Down Expand Up @@ -281,7 +285,7 @@ def model_copy(self: Model, *, update: dict[str, Any] | None = None, deep: bool
def model_dump(
self,
*,
mode: Literal['json', 'python'] | str = 'python',
mode: typing_extensions.Literal['json', 'python'] | str = 'python',
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
Expand Down
4 changes: 1 addition & 3 deletions pydantic/root_model.py
Expand Up @@ -24,14 +24,12 @@
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField,))
class _RootModelMetaclass(_model_construction.ModelMetaclass):
...

Model = typing.TypeVar('Model', bound='BaseModel')
else:
_RootModelMetaclass = _model_construction.ModelMetaclass

__all__ = ('RootModel',)


Model = typing.TypeVar('Model', bound='BaseModel')
RootModelRootType = typing.TypeVar('RootModelRootType')


Expand Down
142 changes: 142 additions & 0 deletions tests/test_type_hints.py
@@ -0,0 +1,142 @@
"""
Test pydantic model type hints (annotations) and that they can be
queried by :py:meth:`typing.get_type_hints`.
"""
import inspect
import sys
from typing import (
Any,
Dict,
Generic,
Optional,
Set,
TypeVar,
)

import pytest
import typing_extensions

from pydantic import (
BaseModel,
RootModel,
)
from pydantic.dataclasses import dataclass

DEPRECATED_MODEL_MEMBERS = {
'construct',
'copy',
'dict',
'from_orm',
'json',
'json_schema',
'parse_file',
'parse_obj',
}

# Disable deprecation warnings, as we enumerate members that may be
# i.e. pydantic.warnings.PydanticDeprecatedSince20: The `__fields__` attribute is deprecated,
# use `model_fields` instead.
# Additionally, only run these tests for 3.10+
pytestmark = [
pytest.mark.filterwarnings('ignore::DeprecationWarning'),
pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher to work properly'),
]


@pytest.fixture(name='ParentModel', scope='session')
def parent_sub_model_fixture():
class UltraSimpleModel(BaseModel):
a: float
b: int = 10

class ParentModel(BaseModel):
grape: bool
banana: UltraSimpleModel

return ParentModel


def inspect_type_hints(
obj_type, members: Optional[Set[str]] = None, exclude_members: Optional[Set[str]] = None, recursion_limit: int = 3
):
"""
Test an object and its members to make sure type hints can be resolved.
:param obj_type: Type to check
:param members: Explicit set of members to check, None to check all
:param exclude_members: Set of member names to exclude
:param recursion_limit: Recursion limit (0 to disallow)
"""

try:
hints = typing_extensions.get_type_hints(obj_type)
assert isinstance(hints, dict), f'Type annotation(s) on {obj_type} are invalid'
except NameError as ex:
raise AssertionError(f'Type annotation(s) on {obj_type} are invalid: {str(ex)}') from ex

if recursion_limit <= 0:
return

if isinstance(obj_type, type):
# Check class members
for member_name, member_obj in inspect.getmembers(obj_type):
if member_name.startswith('_'):
# Ignore private members
continue
if (members and member_name not in members) or (exclude_members and member_name in exclude_members):
continue

if inspect.isclass(member_obj) or inspect.isfunction(member_obj):
# Inspect all child members (can"t exclude specific ones)
print(f'Inspecting {obj_type}.{member_name}') # Add this line
inspect_type_hints(member_obj, recursion_limit=recursion_limit - 1)


@pytest.mark.parametrize(
('obj_type', 'members', 'exclude_members'),
[
(BaseModel, None, DEPRECATED_MODEL_MEMBERS),
(RootModel, None, DEPRECATED_MODEL_MEMBERS),
],
)
def test_obj_type_hints(obj_type, members: Optional[Set[str]], exclude_members: Optional[Set[str]]):
"""
Test an object and its members to make sure type hints can be resolved.
:param obj_type: Type to check
:param members: Explicit set of members to check, None to check all
:param exclude_members: Set of member names to exclude
"""
inspect_type_hints(obj_type, members, exclude_members)


def test_parent_sub_model(ParentModel):
inspect_type_hints(ParentModel, None, DEPRECATED_MODEL_MEMBERS)


def test_root_model_as_field():
class MyRootModel(RootModel[int]):
pass

class MyModel(BaseModel):
root_model: MyRootModel

inspect_type_hints(MyRootModel, None, DEPRECATED_MODEL_MEMBERS)
inspect_type_hints(MyModel, None, DEPRECATED_MODEL_MEMBERS)


def test_generics():
data_type = TypeVar('data_type')

class Result(BaseModel, Generic[data_type]):
data: data_type

inspect_type_hints(Result, None, DEPRECATED_MODEL_MEMBERS)
inspect_type_hints(Result[Dict[str, Any]], None, DEPRECATED_MODEL_MEMBERS)


def test_dataclasses():
@dataclass
class MyDataclass:
a: int
b: float

inspect_type_hints(MyDataclass)

0 comments on commit 3d1355f

Please sign in to comment.