Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BaseModel type annotations to be resolvable by typing.get_type_hints #7680

Merged
merged 6 commits into from Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 12 additions & 8 deletions pydantic/main.py
Expand Up @@ -32,22 +32,26 @@
from .json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue, model_json_schema
from .warnings import PydanticDeprecatedSince20

# Always define certain types that are needed to resolve method type hints/annotations
# (even when not type checking) via typing.get_type_hints.
Model = typing.TypeVar('Model', bound='BaseModel')
TupleGenerator = typing.Generator[typing.Tuple[str, Any], None, None]
IncEx: typing_extensions.TypeAlias = typing.Union[
typing.Set[int], typing.Set[str], typing.Dict[int, Any], typing.Dict[str, Any], None
]
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved


if typing.TYPE_CHECKING:
from inspect import Signature
from pathlib import Path

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
from typing_extensions import Literal, Unpack
from typing_extensions import Unpack

from ._internal._utils import AbstractSetIntStr, MappingIntStrAny
from .deprecated.parse import Protocol as DeprecatedParseProtocol
from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr
from .fields import Field as _Field

TupleGenerator = typing.Generator[typing.Tuple[str, Any], None, None]
Model = typing.TypeVar('Model', bound='BaseModel')
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment should be moved. Or better yet addressed - can mypy not handle pydantic.config.JsonValue?

IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
Expand Down Expand Up @@ -120,7 +124,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass):
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
__pydantic_generic_metadata__: ClassVar[_generics.PydanticGenericMetadata]
__pydantic_parent_namespace__: ClassVar[dict[str, Any] | None]
__pydantic_post_init__: ClassVar[None | Literal['model_post_init']]
__pydantic_post_init__: ClassVar[None | typing_extensions.Literal['model_post_init']]
__pydantic_root_model__: ClassVar[bool]
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator]
Expand Down Expand Up @@ -281,7 +285,7 @@ def model_copy(self: Model, *, update: dict[str, Any] | None = None, deep: bool
def model_dump(
self,
*,
mode: Literal['json', 'python'] | str = 'python',
mode: typing_extensions.Literal['json', 'python'] | str = 'python',
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
Expand Down
4 changes: 1 addition & 3 deletions pydantic/root_model.py
Expand Up @@ -24,14 +24,12 @@
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField,))
class _RootModelMetaclass(_model_construction.ModelMetaclass):
...

Model = typing.TypeVar('Model', bound='BaseModel')
else:
_RootModelMetaclass = _model_construction.ModelMetaclass

__all__ = ('RootModel',)


Model = typing.TypeVar('Model', bound='BaseModel')
RootModelRootType = typing.TypeVar('RootModelRootType')


Expand Down
142 changes: 142 additions & 0 deletions tests/test_type_hints.py
@@ -0,0 +1,142 @@
"""
Test pydantic model type hints (annotations) and that they can be
queried by :py:meth:`typing.get_type_hints`.
"""
import inspect
import sys
from typing import (
Any,
Dict,
Generic,
Optional,
Set,
TypeVar,
)

import pytest
import typing_extensions

from pydantic import (
BaseModel,
RootModel,
)
from pydantic.dataclasses import dataclass

DEPRECATED_MODEL_MEMBERS = {
'construct',
'copy',
'dict',
'from_orm',
'json',
'json_schema',
'parse_file',
'parse_obj',
}

# Disable deprecation warnings, as we enumerate members that may be
# i.e. pydantic.warnings.PydanticDeprecatedSince20: The `__fields__` attribute is deprecated,
# use `model_fields` instead.
# Additionally, only run these tests for 3.10+
pytestmark = [
pytest.mark.filterwarnings('ignore::DeprecationWarning'),
pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher to work properly'),
]


@pytest.fixture(name='ParentModel', scope='session')
def parent_sub_model_fixture():
class UltraSimpleModel(BaseModel):
a: float
b: int = 10

class ParentModel(BaseModel):
grape: bool
banana: UltraSimpleModel

return ParentModel


def inspect_type_hints(
obj_type, members: Optional[Set[str]] = None, exclude_members: Optional[Set[str]] = None, recursion_limit: int = 3
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to return the hints from this function, so that they can be asserted to be of the correct form? Or is there just way too much stuff there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. They're pretty gnarly, so I'd opt to skip returning for now. Down the line, if someone's feeling quite invested, they could make the addition, but I don't see it as a blocker on this PR.

"""
Test an object and its members to make sure type hints can be resolved.
:param obj_type: Type to check
:param members: Explicit set of members to check, None to check all
:param exclude_members: Set of member names to exclude
:param recursion_limit: Recursion limit (0 to disallow)
"""

try:
hints = typing_extensions.get_type_hints(obj_type)
assert isinstance(hints, dict), f'Type annotation(s) on {obj_type} are invalid'
except NameError as ex:
raise AssertionError(f'Type annotation(s) on {obj_type} are invalid: {str(ex)}') from ex

if recursion_limit <= 0:
return

if isinstance(obj_type, type):
# Check class members
for member_name, member_obj in inspect.getmembers(obj_type):
if member_name.startswith('_'):
# Ignore private members
continue
if (members and member_name not in members) or (exclude_members and member_name in exclude_members):
continue

if inspect.isclass(member_obj) or inspect.isfunction(member_obj):
# Inspect all child members (can"t exclude specific ones)
print(f'Inspecting {obj_type}.{member_name}') # Add this line
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
inspect_type_hints(member_obj, recursion_limit=recursion_limit - 1)


@pytest.mark.parametrize(
('obj_type', 'members', 'exclude_members'),
[
(BaseModel, None, DEPRECATED_MODEL_MEMBERS),
(RootModel, None, DEPRECATED_MODEL_MEMBERS),
],
)
def test_obj_type_hints(obj_type, members: Optional[Set[str]], exclude_members: Optional[Set[str]]):
"""
Test an object and its members to make sure type hints can be resolved.
:param obj_type: Type to check
:param members: Explicit set of members to check, None to check all
:param exclude_members: Set of member names to exclude
"""
inspect_type_hints(obj_type, members, exclude_members)


def test_parent_sub_model(ParentModel):
inspect_type_hints(ParentModel, None, DEPRECATED_MODEL_MEMBERS)


def test_root_model_as_field():
class MyRootModel(RootModel[int]):
pass

class MyModel(BaseModel):
root_model: MyRootModel

inspect_type_hints(MyRootModel, None, DEPRECATED_MODEL_MEMBERS)
inspect_type_hints(MyModel, None, DEPRECATED_MODEL_MEMBERS)


def test_generics():
data_type = TypeVar('data_type')

class Result(BaseModel, Generic[data_type]):
data: data_type

inspect_type_hints(Result, None, DEPRECATED_MODEL_MEMBERS)
inspect_type_hints(Result[Dict[str, Any]], None, DEPRECATED_MODEL_MEMBERS)


def test_dataclasses():
@dataclass
class MyDataclass:
a: int
b: float

inspect_type_hints(MyDataclass)