Skip to content

Commit

Permalink
Improve schema gen for nested dataclasses (#9114)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Mar 27, 2024
1 parent 7ac7881 commit b30f44d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 11 deletions.
28 changes: 17 additions & 11 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -636,24 +636,30 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
ref_mode = 'to-def'

schema: CoreSchema
get_schema = getattr(obj, '__get_pydantic_core_schema__', None)
if get_schema is None:
validators = getattr(obj, '__get_validators__', None)
if validators is None:
return None
warn(
'`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
PydanticDeprecatedSince20,
)
schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()])
else:

if (get_schema := getattr(obj, '__get_pydantic_core_schema__', None)) is not None:
if len(inspect.signature(get_schema).parameters) == 1:
# (source) -> CoreSchema
schema = get_schema(source)
else:
schema = get_schema(
source, CallbackGetCoreSchemaHandler(self._generate_schema_inner, self, ref_mode=ref_mode)
)
# fmt: off
elif (existing_schema := getattr(obj, '__pydantic_core_schema__', None)) is not None and existing_schema.get(
'cls', None
) == obj:
schema = existing_schema
# fmt: on
elif (validators := getattr(obj, '__get_validators__', None)) is not None:
warn(
'`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
PydanticDeprecatedSince20,
)
schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()])
else:
# we have no existing schema information on the property, exit early so that we can go generate a schema
return None

schema = self._unpack_refs_defs(schema)

Expand Down
32 changes: 32 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2906,3 +2906,35 @@ class C(B):
for field_name in ['a', 'b', 'c', 'd']:
assert class_.__pydantic_fields__[field_name].annotation is str
assert getattr(instance, field_name) == field_name


def test_schema_valid_for_inner_generic() -> None:
T = TypeVar('T')

@pydantic.dataclasses.dataclass()
class Inner(Generic[T]):
x: T

@pydantic.dataclasses.dataclass()
class Outer:
inner: Inner[int]

assert Outer(inner={'x': 1}).inner.x == 1
# note, this isn't Inner[Int] like it is for the BaseModel case, but the type of x is substituted, which is the important part
assert Outer.__pydantic_core_schema__['schema']['fields'][0]['schema']['cls'] == Inner
assert (
Outer.__pydantic_core_schema__['schema']['fields'][0]['schema']['schema']['fields'][0]['schema']['type']
== 'int'
)


def test_validation_works_for_cyclical_forward_refs() -> None:
@pydantic.dataclasses.dataclass()
class X:
y: Union['Y', None]

@pydantic.dataclasses.dataclass()
class Y:
x: Union[X, None]

assert Y(x={'y': None}).x.y is None
28 changes: 28 additions & 0 deletions tests/test_main.py
Expand Up @@ -3235,3 +3235,31 @@ class Child(module.Base):

assert Child.CONST1 == 'a'
assert Child.CONST2 == 'b'


def test_schema_valid_for_inner_generic() -> None:
T = TypeVar('T')

class Inner(BaseModel, Generic[T]):
x: T

class Outer(BaseModel):
inner: Inner[int]

assert Outer(inner={'x': 1}).inner.x == 1
# confirming that the typevars are substituted in the outer model schema
assert Outer.__pydantic_core_schema__['schema']['fields']['inner']['schema']['cls'] == Inner[int]
assert (
Outer.__pydantic_core_schema__['schema']['fields']['inner']['schema']['schema']['fields']['x']['schema']['type']
== 'int'
)


def test_validation_works_for_cyclical_forward_refs() -> None:
class X(BaseModel):
y: Union['Y', None]

class Y(BaseModel):
x: Union[X, None]

assert Y(x={'y': None}).x.y is None

0 comments on commit b30f44d

Please sign in to comment.