Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schema gen for nested dataclasses #9114

Merged
merged 11 commits into from Mar 27, 2024
28 changes: 17 additions & 11 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -636,24 +636,30 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
ref_mode = 'to-def'

schema: CoreSchema
get_schema = getattr(obj, '__get_pydantic_core_schema__', None)
if get_schema is None:
validators = getattr(obj, '__get_validators__', None)
if validators is None:
return None
warn(
'`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
PydanticDeprecatedSince20,
)
schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()])
else:

if (get_schema := getattr(obj, '__get_pydantic_core_schema__', None)) is not None:
if len(inspect.signature(get_schema).parameters) == 1:
# (source) -> CoreSchema
schema = get_schema(source)
else:
schema = get_schema(
source, CallbackGetCoreSchemaHandler(self._generate_schema_inner, self, ref_mode=ref_mode)
)
# fmt: off
elif (existing_schema := getattr(obj, '__pydantic_core_schema__', None)) is not None and existing_schema.get(
'cls', None
) == obj:
schema = existing_schema
# fmt: on
elif (validators := getattr(obj, '__get_validators__', None)) is not None:
warn(
'`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
PydanticDeprecatedSince20,
)
schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()])
else:
# we have no existing schema information on the property, exit early so that we can go generate a schema
return None

schema = self._unpack_refs_defs(schema)

Expand Down
32 changes: 32 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2906,3 +2906,35 @@ class C(B):
for field_name in ['a', 'b', 'c', 'd']:
assert class_.__pydantic_fields__[field_name].annotation is str
assert getattr(instance, field_name) == field_name


def test_schema_valid_for_inner_generic() -> None:
T = TypeVar('T')

@pydantic.dataclasses.dataclass()
class Inner(Generic[T]):
x: T

@pydantic.dataclasses.dataclass()
class Outer:
inner: Inner[int]

assert Outer(inner={'x': 1}).inner.x == 1
# note, this isn't Inner[Int] like it is for the BaseModel case, but the type of x is substituted, which is the important part
assert Outer.__pydantic_core_schema__['schema']['fields'][0]['schema']['cls'] == Inner
assert (
Outer.__pydantic_core_schema__['schema']['fields'][0]['schema']['schema']['fields'][0]['schema']['type']
== 'int'
)


def test_validation_works_for_cyclical_forward_refs() -> None:
@pydantic.dataclasses.dataclass()
class X:
y: Union['Y', None]

@pydantic.dataclasses.dataclass()
class Y:
x: Union[X, None]

assert Y(x={'y': None}).x.y is None
28 changes: 28 additions & 0 deletions tests/test_main.py
Expand Up @@ -3235,3 +3235,31 @@ class Child(module.Base):

assert Child.CONST1 == 'a'
assert Child.CONST2 == 'b'


def test_schema_valid_for_inner_generic() -> None:
T = TypeVar('T')

class Inner(BaseModel, Generic[T]):
x: T

class Outer(BaseModel):
inner: Inner[int]

assert Outer(inner={'x': 1}).inner.x == 1
# confirming that the typevars are substituted in the outer model schema
assert Outer.__pydantic_core_schema__['schema']['fields']['inner']['schema']['cls'] == Inner[int]
assert (
Outer.__pydantic_core_schema__['schema']['fields']['inner']['schema']['schema']['fields']['x']['schema']['type']
== 'int'
)


def test_validation_works_for_cyclical_forward_refs() -> None:
class X(BaseModel):
y: Union['Y', None]

class Y(BaseModel):
x: Union[X, None]

assert Y(x={'y': None}).x.y is None