Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix schema build for nested dataclasses / TypedDicts with discriminators #8950

Merged
merged 4 commits into from Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 22 additions & 9 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -36,10 +36,14 @@ def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> Non


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions: dict[str, CoreSchema] | None = None
# Throughout recursion, we allow references to be resolved from the definitions
# that are present in the outermost schema. Before apply_discriminators is called,
# we call simplify_schema_references (in the clean_schema function),
# which often puts the definitions in the outermost schema.
global_definitions: dict[str, CoreSchema] = collect_definitions(schema)

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
nonlocal global_definitions

s = recurse(s, inner)
if s['type'] == 'tagged-union':
Expand All @@ -48,9 +52,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
metadata = s.get('metadata', {})
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
s = apply_discriminator(s, discriminator, global_definitions)
return s

return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
Expand Down Expand Up @@ -185,16 +187,27 @@ def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
self.definitions.update(collect_definitions(schema))
# Fetch the definitions attached to the (often inner) schema in question,
# and add them to the definitions that we will use to resolve references
original_local_defs = collect_definitions(schema)
self.definitions.update(original_local_defs)

assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()

# If there are any definitions that were present on the original schema but not on the new schema,
# we need to add them to the new schema. This is necessary because the definitions may contain
# schemas that are referenced by the choices in the union, and we need to ensure that the new schema
# contains all the definitions that are necessary to resolve these references.
# Note -- by "original schema", we refer to the schema that was passed to the apply method,
# not the outermost schema that we're recursing on (where self.definitions came from).
new_local_defs = collect_definitions(schema)
missing_defs = original_local_defs.keys() - new_local_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
schema = core_schema.definitions_schema(schema, [original_local_defs[ref] for ref in missing_defs])
return schema

def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_discriminated_union.py
Expand Up @@ -1959,3 +1959,41 @@ class Bar:
if 'Foo' in definition['ref']:
for field in definition['schema']['fields']:
assert field['schema']['type'] == 'tagged-union' if field['name'] == 'x' else True


def test_discriminated_union_with_nested_dataclass() -> None:
@pydantic_dataclass
class Cat:
type: Literal['cat'] = 'cat'

@pydantic_dataclass
class Dog:
type: Literal['dog'] = 'dog'

@pydantic_dataclass
class NestedDataClass:
animal: Annotated[Union[Cat, Dog], Discriminator('type')]

@pydantic_dataclass
class Root:
data_class: NestedDataClass

ta = TypeAdapter(Root)
assert ta.core_schema['schema']['fields'][0]['schema']['schema']['fields'][0]['schema']['type'] == 'tagged-union'


def test_discriminated_union_with_nested_typed_dicts() -> None:
class Cat(TypedDict):
type: Literal['cat']

class Dog(TypedDict):
type: Literal['dog']

class NestedTypedDict(TypedDict):
animal: Annotated[Union[Cat, Dog], Discriminator('type')]

class Root(TypedDict):
data_class: NestedTypedDict

ta = TypeAdapter(Root)
assert ta.core_schema['fields']['data_class']['schema']['fields']['animal']['schema']['type'] == 'tagged-union'