Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for typing.Self (fix #5992) #9023

Merged
merged 10 commits into from Mar 25, 2024
36 changes: 30 additions & 6 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -77,10 +77,8 @@
from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns
from ._forward_ref import PydanticRecursiveRef
from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types
from ._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import is_finalvar, is_self_type
from ._utils import lenient_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -315,6 +313,7 @@ class GenerateSchema:
'_needs_apply_discriminated_union',
'_has_invalid_schema',
'field_name_stack',
'model_type_stack',
'defs',
)

Expand All @@ -331,19 +330,22 @@ def __init__(
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
self.field_name_stack = _FieldNameStack()
self.model_type_stack = _ModelTypeStack()
self.defs = _Definitions()

@classmethod
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace_stack: TypesNamespaceStack,
model_type_stack: _ModelTypeStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace_stack = types_namespace_stack
obj.model_type_stack = model_type_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
Expand All @@ -365,6 +367,7 @@ def _current_generate_schema(self) -> GenerateSchema:
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace_stack,
self.model_type_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -637,6 +640,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema.
"""
# avoid calling `__get_pydantic_core_schema__` if we've already visited this object
if is_self_type(obj):
obj = self.model_type_stack.get()
with self.defs.get_schema_or_ref(obj) as (_, maybe_schema):
if maybe_schema is not None:
return maybe_schema
Expand Down Expand Up @@ -773,7 +778,8 @@ def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
from ..main import BaseModel

if lenient_issubclass(obj, BaseModel):
return self._model_schema(obj)
with self.model_type_stack.push(obj):
return self._model_schema(obj)

if isinstance(obj, PydanticRecursiveRef):
return core_schema.definition_reference_schema(schema_ref=obj.type_ref)
Expand Down Expand Up @@ -853,7 +859,6 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901

if _typing_extra.is_dataclass(obj):
return self._dataclass_schema(obj, None)

res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
Expand Down Expand Up @@ -2289,3 +2294,22 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


class _ModelTypeStack:
__slots__ = ('_stack',)

def __init__(self) -> None:
self._stack: list[str] = []

@contextmanager
def push(self, type_obj: str) -> Iterator[None]:
self._stack.append(type_obj)
yield
self._stack.pop()

def get(self) -> str | None:
Youssefares marked this conversation as resolved.
Show resolved Hide resolved
if self._stack:
return self._stack[-1]
else:
return None
5 changes: 5 additions & 0 deletions pydantic/_internal/_typing_extra.py
Expand Up @@ -467,3 +467,8 @@ def is_generic_alias(type_: type[Any]) -> bool:

def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, typing._GenericAlias) # type: ignore


def is_self_type(tp: Any) -> bool:
"""Check if a given class is a Self type (from `typing` or `typing_extensions`)"""
return isinstance(tp, typing_base) and getattr(tp, '_name', None) == 'Self'
82 changes: 82 additions & 0 deletions tests/test_types_self.py
@@ -0,0 +1,82 @@
import typing

import pytest
import typing_extensions

from pydantic import BaseModel, ValidationError


@pytest.fixture(
name='Self',
params=[
pytest.param(typing, id='typing.Self'),
pytest.param(typing_extensions, id='t_e.Self'),
],
)
def fixture_self_all(request):
try:
return request.param.Self
except AttributeError:
pytest.skip(f'Self is not available from {request.param}')
Youssefares marked this conversation as resolved.
Show resolved Hide resolved


def test_recursive_model(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

assert SelfRef(data=1, ref={'data': 2}).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}}


def test_recursive_model_invalid(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

with pytest.raises(
ValidationError,
match=r'ref\.ref\s+Input should be a valid dictionary or instance of SelfRef \[type=model_type,',
):
SelfRef(data=1, ref={'data': 2, 'ref': 3}).model_dump()


def test_recursive_model_with_subclass(Self):
"""
Youssefares marked this conversation as resolved.
Show resolved Hide resolved
Self refs should be valid and should reference the correct class in covariant direction
"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int

assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None, 'y': 4}, # SubSelfRef.ref: SubSelfRef
'y': 2,
}
assert SelfRef(x=1, ref=SubSelfRef(x=2, y=3)).model_dump() == {
'x': 1,
'ref': {'x': 2, 'ref': None},
} # SelfRef.ref: SelfRef


def test_recursive_model_with_subclass_invalid(Self):
"""
Self refs are invalid in contravariant direction
"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int

with pytest.raises(
ValidationError,
match=r'ref\s+Input should be a valid dictionary or instance of SubSelfRef \[type=model_type,',
):
SubSelfRef(x=1, ref=SelfRef(x=2), y=3).model_dump()