Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for typing.Self (fix #5992) #9023

Merged
merged 10 commits into from Mar 25, 2024
51 changes: 42 additions & 9 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -77,10 +77,8 @@
from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns
from ._forward_ref import PydanticRecursiveRef
from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types
from ._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import is_finalvar, is_self_type
from ._utils import lenient_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -315,6 +313,7 @@ class GenerateSchema:
'_needs_apply_discriminated_union',
'_has_invalid_schema',
'field_name_stack',
'model_type_stack',
'defs',
)

Expand All @@ -331,19 +330,22 @@ def __init__(
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
self.field_name_stack = _FieldNameStack()
self.model_type_stack = _ModelTypeStack()
self.defs = _Definitions()

@classmethod
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace_stack: TypesNamespaceStack,
model_type_stack: _ModelTypeStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace_stack = types_namespace_stack
obj.model_type_stack = model_type_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
Expand All @@ -365,6 +367,7 @@ def _current_generate_schema(self) -> GenerateSchema:
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace_stack,
self.model_type_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -637,6 +640,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema.
"""
# avoid calling `__get_pydantic_core_schema__` if we've already visited this object
if is_self_type(obj):
obj = self.model_type_stack.get()
with self.defs.get_schema_or_ref(obj) as (_, maybe_schema):
if maybe_schema is not None:
return maybe_schema
Expand Down Expand Up @@ -773,7 +778,8 @@ def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
from ..main import BaseModel

if lenient_issubclass(obj, BaseModel):
return self._model_schema(obj)
with self.model_type_stack.push(obj):
return self._model_schema(obj)

if isinstance(obj, PydanticRecursiveRef):
return core_schema.definition_reference_schema(schema_ref=obj.type_ref)
Expand Down Expand Up @@ -853,7 +859,6 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901

if _typing_extra.is_dataclass(obj):
return self._dataclass_schema(obj, None)

res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
Expand Down Expand Up @@ -1238,7 +1243,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
"""
from ..fields import FieldInfo

with self.defs.get_schema_or_ref(typed_dict_cls) as (typed_dict_ref, maybe_schema):
with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref(typed_dict_cls) as (
typed_dict_ref,
maybe_schema,
):
if maybe_schema is not None:
return maybe_schema

Expand Down Expand Up @@ -1325,7 +1333,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema:
"""Generate schema for a NamedTuple."""
with self.defs.get_schema_or_ref(namedtuple_cls) as (namedtuple_ref, maybe_schema):
with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref(namedtuple_cls) as (
namedtuple_ref,
maybe_schema,
):
if maybe_schema is not None:
return maybe_schema
typevars_map = get_standard_typevars_map(namedtuple_cls)
Expand Down Expand Up @@ -1514,7 +1525,10 @@ def _dataclass_schema(
self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None
) -> core_schema.CoreSchema:
"""Generate schema for a dataclass."""
with self.defs.get_schema_or_ref(dataclass) as (dataclass_ref, maybe_schema):
with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref(dataclass) as (
dataclass_ref,
maybe_schema,
):
if maybe_schema is not None:
return maybe_schema

Expand Down Expand Up @@ -2289,3 +2303,22 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


class _ModelTypeStack:
__slots__ = ('_stack',)

def __init__(self) -> None:
self._stack: list[type] = []

@contextmanager
def push(self, type_obj: type) -> Iterator[None]:
self._stack.append(type_obj)
yield
self._stack.pop()

def get(self) -> type | None:
if self._stack:
return self._stack[-1]
else:
return None
5 changes: 5 additions & 0 deletions pydantic/_internal/_typing_extra.py
Expand Up @@ -467,3 +467,8 @@ def is_generic_alias(type_: type[Any]) -> bool:

def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, typing._GenericAlias) # type: ignore


def is_self_type(tp: Any) -> bool:
"""Check if a given class is a Self type (from `typing` or `typing_extensions`)"""
return isinstance(tp, typing_base) and getattr(tp, '_name', None) == 'Self'
171 changes: 171 additions & 0 deletions tests/test_types_self.py
@@ -0,0 +1,171 @@
import dataclasses
import typing
from typing import List, Optional, Union

import pytest
import typing_extensions
from typing_extensions import NamedTuple, TypedDict

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


@pytest.fixture(
name='Self',
params=[
pytest.param(typing, id='typing.Self'),
pytest.param(typing_extensions, id='t_e.Self'),
],
)
def fixture_self_all(request):
try:
return request.param.Self
except AttributeError:
pytest.skip(f'Self is not available from {request.param}')
Youssefares marked this conversation as resolved.
Show resolved Hide resolved


def test_recursive_model(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

assert SelfRef(data=1, ref={'data': 2}).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}}


def test_recursive_model_invalid(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

with pytest.raises(
ValidationError,
match=r'ref\.ref\s+Input should be a valid dictionary or instance of SelfRef \[type=model_type,',
):
SelfRef(data=1, ref={'data': 2, 'ref': 3}).model_dump()


def test_recursive_model_with_subclass(Self):
"""Self refs should be valid and should reference the correct class in covariant direction"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int

assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None, 'y': 4}, # SubSelfRef.ref: SubSelfRef
'y': 2,
}
assert SelfRef(x=1, ref=SubSelfRef(x=2, y=3)).model_dump() == {
'x': 1,
'ref': {'x': 2, 'ref': None},
} # SelfRef.ref: SelfRef


def test_recursive_model_with_subclass_invalid(Self):
"""Self refs are invalid in contravariant direction"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int

with pytest.raises(
ValidationError,
match=r'ref\s+Input should be a valid dictionary or instance of SubSelfRef \[type=model_type,',
):
SubSelfRef(x=1, ref=SelfRef(x=2), y=3).model_dump()


def test_recursive_model_with_subclass_override(Self):
"""Self refs should be overridable"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int
ref: Optional[Union[SelfRef, Self]] = None

assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None, 'y': 4},
'y': 2,
}
assert SubSelfRef(x=1, ref=SelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None},
'y': 2,
}


def test_self_type_with_field(Self):
with pytest.raises(TypeError, match=r'The following constraints cannot be applied.*\'gt\''):

class SelfRef(BaseModel):
x: int
refs: typing.List[Self] = Field(..., gt=0)


def test_self_type_json_schema(Self):
class SelfRef(BaseModel):
x: int
refs: Optional[List[Self]] = []

assert SelfRef.model_json_schema() == {
'$defs': {
'SelfRef': {
'properties': {
'x': {'title': 'X', 'type': 'integer'},
'refs': {
'anyOf': [{'items': {'$ref': '#/$defs/SelfRef'}, 'type': 'array'}, {'type': 'null'}],
'default': [],
'title': 'Refs',
},
},
'required': ['x'],
'title': 'SelfRef',
'type': 'object',
}
},
'allOf': [{'$ref': '#/$defs/SelfRef'}],
}


def test_self_type_in_named_tuple(Self):
class SelfRefNamedTuple(NamedTuple):
x: int
ref: Self | None

ta = TypeAdapter(SelfRefNamedTuple)
assert ta.validate_python({'x': 1, 'ref': {'x': 2, 'ref': None}}) == (1, (2, None))


def test_self_type_in_typed_dict(Self):
class SelfRefTypedDict(TypedDict):
x: int
ref: Self | None

ta = TypeAdapter(SelfRefTypedDict)
assert ta.validate_python({'x': 1, 'ref': {'x': 2, 'ref': None}}) == {'x': 1, 'ref': {'x': 2, 'ref': None}}


def test_self_type_in_dataclass(Self):
@dataclasses.dataclass(frozen=True)
class SelfRef:
x: int
ref: Self | None

class Model(BaseModel):
item: SelfRef

m = Model.model_validate({'item': {'x': 1, 'ref': {'x': 2, 'ref': None}}})
assert m.item.x == 1
assert m.item.ref.x == 2
with pytest.raises(dataclasses.FrozenInstanceError):
m.item.ref.x = 3