Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for typing.Self (fix #5992) #9023

Merged
merged 10 commits into from Mar 25, 2024
36 changes: 30 additions & 6 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -77,10 +77,8 @@
from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns
from ._forward_ref import PydanticRecursiveRef
from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types
from ._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import is_finalvar, is_self_type
from ._utils import lenient_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -315,6 +313,7 @@ class GenerateSchema:
'_needs_apply_discriminated_union',
'_has_invalid_schema',
'field_name_stack',
'model_type_stack',
'defs',
)

Expand All @@ -331,19 +330,22 @@ def __init__(
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
self.field_name_stack = _FieldNameStack()
self.model_type_stack = _ModelTypeStack()
self.defs = _Definitions()

@classmethod
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace_stack: TypesNamespaceStack,
model_type_stack: _ModelTypeStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace_stack = types_namespace_stack
obj.model_type_stack = model_type_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
Expand All @@ -365,6 +367,7 @@ def _current_generate_schema(self) -> GenerateSchema:
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace_stack,
self.model_type_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -637,6 +640,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema.
"""
# avoid calling `__get_pydantic_core_schema__` if we've already visited this object
if is_self_type(obj):
obj = self.model_type_stack.get()
with self.defs.get_schema_or_ref(obj) as (_, maybe_schema):
if maybe_schema is not None:
return maybe_schema
Expand Down Expand Up @@ -773,7 +778,8 @@ def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
from ..main import BaseModel

if lenient_issubclass(obj, BaseModel):
return self._model_schema(obj)
with self.model_type_stack.push(obj):
return self._model_schema(obj)

if isinstance(obj, PydanticRecursiveRef):
return core_schema.definition_reference_schema(schema_ref=obj.type_ref)
Expand Down Expand Up @@ -853,7 +859,6 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901

if _typing_extra.is_dataclass(obj):
return self._dataclass_schema(obj, None)

res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
Expand Down Expand Up @@ -2289,3 +2294,22 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


class _ModelTypeStack:
__slots__ = ('_stack',)

def __init__(self) -> None:
self._stack: list[str] = []

@contextmanager
def push(self, type_obj: str) -> Iterator[None]:
self._stack.append(type_obj)
yield
self._stack.pop()

def get(self) -> str | None:
Youssefares marked this conversation as resolved.
Show resolved Hide resolved
if self._stack:
return self._stack[-1]
else:
return None
5 changes: 5 additions & 0 deletions pydantic/_internal/_typing_extra.py
Expand Up @@ -467,3 +467,8 @@ def is_generic_alias(type_: type[Any]) -> bool:

def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, typing._GenericAlias) # type: ignore


def is_self_type(tp: Any) -> bool:
"""Check if a given class is a Self type (from `typing` or `typing_extensions`)"""
return isinstance(tp, typing_base) and getattr(tp, '_name', None) == 'Self'
40 changes: 9 additions & 31 deletions pydantic/v1/typing.py
Expand Up @@ -2,38 +2,15 @@
import typing
from collections.abc import Callable
from os import PathLike
from typing import ( # type: ignore
TYPE_CHECKING,
AbstractSet,
Any,
Callable as TypingCallable,
ClassVar,
Dict,
ForwardRef,
Generator,
Iterable,
List,
Mapping,
NewType,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
_eval_type,
cast,
get_type_hints,
)
from typing import TYPE_CHECKING, AbstractSet, Any
from typing import Callable as TypingCallable # type: ignore
from typing import (ClassVar, Dict, ForwardRef, Generator, Iterable, List,
Mapping, NewType, Optional, Sequence, Set, Tuple, Type,
TypeVar, Union, _eval_type, cast, get_type_hints)

from typing_extensions import (
Annotated,
Final,
Literal,
NotRequired as TypedDictNotRequired,
Required as TypedDictRequired,
)
from typing_extensions import Annotated, Final, Literal
from typing_extensions import NotRequired as TypedDictNotRequired
from typing_extensions import Required as TypedDictRequired

try:
from typing import _TypingBase as typing_base # type: ignore
Expand Down Expand Up @@ -317,6 +294,7 @@ def is_union(tp: Optional[Type[Any]]) -> bool:
'is_union',
'StrPath',
'MappingIntStrAny',
'is_self_type'
)


Expand Down
75 changes: 75 additions & 0 deletions tests/test_types_self.py
@@ -0,0 +1,75 @@
import typing

import pytest
import typing_extensions

from pydantic import BaseModel, ValidationError


@pytest.fixture(
name='Self',
params=[
pytest.param(typing, id='typing.Self'),
pytest.param(typing_extensions, id='t_e.Self'),
],
)
def fixture_self_all(request):
try:
return request.param.Self
except AttributeError:
pytest.skip(f'Self is not available from {request.param}')
Youssefares marked this conversation as resolved.
Show resolved Hide resolved


def test_recursive_model(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

assert SelfRef(data=1, ref={'data': 2}).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}}


def test_recursive_model_invalid(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

with pytest.raises(
ValidationError,
match=r'ref\.ref\s+Input should be a valid dictionary or instance of SelfRef \[type=model_type,',
):
SelfRef(data=1, ref={'data': 2, 'ref': 3}).model_dump()


def test_recursive_model_with_subclass(Self):
"""
Youssefares marked this conversation as resolved.
Show resolved Hide resolved
Self refs should be valid in covariant direction
"""

class SelfRef(BaseModel):
data: int
ref: Self | None = None

class SubSelfRef(SelfRef):
pass
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved

assert SubSelfRef(data=1, ref=SubSelfRef(data=2)).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}}
assert SelfRef(data=1, ref=SubSelfRef(data=2)).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}}


def test_recursive_model_with_subclass_invalid(Self):
"""
Self refs are invalid in contravariant direction
"""

class SelfRef(BaseModel):
data: int
ref: Self | None = None

class SubSelfRef(SelfRef):
pass

with pytest.raises(
ValidationError,
match=r'ref\s+Input should be a valid dictionary or instance of SubSelfRef \[type=model_type,',
):
SubSelfRef(data=1, ref=SelfRef(data=2)).model_dump()