Skip to content

Commit

Permalink
Fix schema build for nested dataclasses / TypedDicts with discriminat…
Browse files Browse the repository at this point in the history
…ors (#8950)

Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
sydney-runkle and dmontagu committed Mar 5, 2024
1 parent b09e06c commit a25882c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
31 changes: 22 additions & 9 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -36,10 +36,14 @@ def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> Non


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions: dict[str, CoreSchema] | None = None
# Throughout recursion, we allow references to be resolved from the definitions
# that are present in the outermost schema. Before apply_discriminators is called,
# we call simplify_schema_references (in the clean_schema function),
# which often puts the definitions in the outermost schema.
global_definitions: dict[str, CoreSchema] = collect_definitions(schema)

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
nonlocal global_definitions

s = recurse(s, inner)
if s['type'] == 'tagged-union':
Expand All @@ -48,9 +52,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
metadata = s.get('metadata', {})
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
s = apply_discriminator(s, discriminator, global_definitions)
return s

return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
Expand Down Expand Up @@ -185,16 +187,27 @@ def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
self.definitions.update(collect_definitions(schema))
# Fetch the definitions attached to the (often inner) schema in question,
# and add them to the definitions that we will use to resolve references
original_local_defs = collect_definitions(schema)
self.definitions.update(original_local_defs)

assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()

# If there are any definitions that were present on the original schema but not on the new schema,
# we need to add them to the new schema. This is necessary because the definitions may contain
# schemas that are referenced by the choices in the union, and we need to ensure that the new schema
# contains all the definitions that are necessary to resolve these references.
# Note -- by "original schema", we refer to the schema that was passed to the apply method,
# not the outermost schema that we're recursing on (where self.definitions came from).
new_local_defs = collect_definitions(schema)
missing_defs = original_local_defs.keys() - new_local_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
schema = core_schema.definitions_schema(schema, [original_local_defs[ref] for ref in missing_defs])
return schema

def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_discriminated_union.py
Expand Up @@ -1959,3 +1959,41 @@ class Bar:
if 'Foo' in definition['ref']:
for field in definition['schema']['fields']:
assert field['schema']['type'] == 'tagged-union' if field['name'] == 'x' else True


def test_discriminated_union_with_nested_dataclass() -> None:
@pydantic_dataclass
class Cat:
type: Literal['cat'] = 'cat'

@pydantic_dataclass
class Dog:
type: Literal['dog'] = 'dog'

@pydantic_dataclass
class NestedDataClass:
animal: Annotated[Union[Cat, Dog], Discriminator('type')]

@pydantic_dataclass
class Root:
data_class: NestedDataClass

ta = TypeAdapter(Root)
assert ta.core_schema['schema']['fields'][0]['schema']['schema']['fields'][0]['schema']['type'] == 'tagged-union'


def test_discriminated_union_with_nested_typed_dicts() -> None:
class Cat(TypedDict):
type: Literal['cat']

class Dog(TypedDict):
type: Literal['dog']

class NestedTypedDict(TypedDict):
animal: Annotated[Union[Cat, Dog], Discriminator('type')]

class Root(TypedDict):
data_class: NestedTypedDict

ta = TypeAdapter(Root)
assert ta.core_schema['fields']['data_class']['schema']['fields']['animal']['schema']['type'] == 'tagged-union'

0 comments on commit a25882c

Please sign in to comment.