Skip to content

Commit

Permalink
Fix nested discriminated union schema gen, pt 2 (#8932)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Mar 8, 2024
1 parent 459cc34 commit f0e0606
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 45 deletions.
5 changes: 0 additions & 5 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -42,11 +42,6 @@

_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'

NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
schema building because one of its members refers to a definition that was not yet defined when the union
was first encountered.
"""
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
Expand Down
23 changes: 7 additions & 16 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -7,7 +7,6 @@
from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaField,
collect_definitions,
simplify_schema_references,
Expand All @@ -29,7 +28,7 @@ def __init__(self, ref: str) -> None:
super().__init__(f'Missing definition for ref {self.ref!r}')


def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
assert metadata is not None
Expand All @@ -41,25 +40,16 @@ def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSche

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
if 'metadata' in s:
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
return s

s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
# After we collect the definitions schemas, we must run through the discriminator
# application logic for each one. This step is crucial to prevent an exponential
# increase in complexity that occurs if schemas are left as 'union' schemas
# rather than 'tagged-union' schemas.
# For more details, see https://github.com/pydantic/pydantic/pull/8904#discussion_r1504687302
definitions = {k: recurse(v, inner) for k, v in definitions.items()}
s = apply_discriminator(s, discriminator, definitions)
return s

Expand Down Expand Up @@ -274,6 +264,10 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])

if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
Expand All @@ -285,17 +279,14 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
elif choice['type'] not in {
'model',
'typed-dict',
'tagged-union',
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
raise TypeError(
Expand Down
25 changes: 4 additions & 21 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -48,7 +48,6 @@
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import CoreMetadataHandler, build_metadata_dict
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
collect_invalid_schemas,
define_expected_missing_refs,
Expand Down Expand Up @@ -302,7 +301,6 @@ class GenerateSchema:
'_config_wrapper_stack',
'_types_namespace_stack',
'_typevars_map',
'_needs_apply_discriminated_union',
'_has_invalid_schema',
'field_name_stack',
'defs',
Expand All @@ -318,7 +316,6 @@ def __init__(
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace_stack = TypesNamespaceStack(types_namespace)
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
self.field_name_stack = _FieldNameStack()
self.defs = _Definitions()
Expand All @@ -335,7 +332,6 @@ def __from_parent(
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace_stack = types_namespace_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
obj.field_name_stack = _FieldNameStack()
obj.defs = defs
Expand Down Expand Up @@ -416,15 +412,10 @@ def _apply_discriminator_to_union(
)
except _discriminated_union.MissingDefinitionForUnionRef:
# defer until defs are resolved
_discriminated_union.set_discriminator(
_discriminated_union.set_discriminator_in_metadata(
schema,
discriminator,
)
if 'metadata' in schema:
schema['metadata'][NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = True
else:
schema['metadata'] = {NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY: True}
self._needs_apply_discriminated_union = True
return schema

class CollectedInvalid(Exception):
Expand Down Expand Up @@ -719,24 +710,16 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:
return args[0], args[1]

def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
if 'metadata' in schema:
metadata = schema['metadata']
metadata[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = self._needs_apply_discriminated_union
else:
schema['metadata'] = {
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY: self._needs_apply_discriminated_union,
}
if 'metadata' not in schema:
schema['metadata'] = {}
return schema

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
"""Recursively generate a pydantic-core schema for any supported python type."""
has_invalid_schema = self._has_invalid_schema
self._has_invalid_schema = False
needs_apply_discriminated_union = self._needs_apply_discriminated_union
self._needs_apply_discriminated_union = False
schema = self._post_process_generated_schema(self._generate_schema_inner(obj))
schema = self._generate_schema_inner(obj)
self._has_invalid_schema = self._has_invalid_schema or has_invalid_schema
self._needs_apply_discriminated_union = self._needs_apply_discriminated_union or needs_apply_discriminated_union
return schema

def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
Expand Down
89 changes: 86 additions & 3 deletions tests/test_discriminated_union.py
Expand Up @@ -7,10 +7,11 @@
import pytest
from dirty_equals import HasRepr, IsStr
from pydantic_core import SchemaValidator, core_schema
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated, Literal, TypedDict

from pydantic import BaseModel, ConfigDict, Discriminator, Field, TypeAdapter, ValidationError, field_validator
from pydantic._internal._discriminated_union import apply_discriminator
from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic.errors import PydanticUserError
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema
Expand Down Expand Up @@ -1868,11 +1869,93 @@ class LeafState(BaseModel):
state_type: Literal['leaf']

AnyState = Annotated[Union[NestedState, LoopState, LeafState], Field(..., discriminator='state_type')]
NestedState.model_rebuild()
LoopState.model_rebuild()
adapter = TypeAdapter(AnyState)

assert adapter.core_schema['schema']['type'] == 'tagged-union'
for definition in adapter.core_schema['definitions']:
if definition['schema']['model_name'] in ['NestedState', 'LoopState']:
assert definition['schema']['fields']['substate']['schema']['schema']['type'] == 'tagged-union'


def test_recursive_discriminiated_union_with_typed_dict() -> None:
    """A self-referential discriminated union of TypedDicts must be a tagged union at every level."""

    class Foo(TypedDict):
        type: Literal['foo']
        # NOTE(review): forward reference to the local `Foobar` alias defined below;
        # presumably resolved from this function's namespace, so keep the alias
        # name and the `TypeAdapter(...)` call in this scope -- confirm.
        x: 'Foobar'

    class Bar(TypedDict):
        type: Literal['bar']

    Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
    adapter = TypeAdapter(Foobar)

    # Because the union is tagged, a bad payload yields exactly one error, both for
    # an unknown top-level tag and for an unknown tag inside the nested `x` field.
    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python({'type': 'wrong'})
    assert len(exc_info.value.errors()) == 1

    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python({'type': 'foo', 'x': {'type': 'wrong'}})
    assert len(exc_info.value.errors()) == 1

    # The outer schema and the recursive `x` field of every `Foo` definition
    # must both have been converted to 'tagged-union'.
    schema = adapter.core_schema
    assert schema['schema']['type'] == 'tagged-union'
    foo_definitions = [d for d in schema['definitions'] if 'Foo' in d['ref']]
    for foo_definition in foo_definitions:
        assert foo_definition['fields']['x']['schema']['type'] == 'tagged-union'


def test_recursive_discriminiated_union_with_base_model() -> None:
    """A self-referential discriminated union of BaseModels must be a tagged union at every level."""

    class Foo(BaseModel):
        type: Literal['foo']
        # NOTE(review): forward reference to the local `Foobar` alias defined below;
        # presumably resolved from this function's namespace, so keep the alias
        # name and the `TypeAdapter(...)` call in this scope -- confirm.
        x: 'Foobar'

    class Bar(BaseModel):
        type: Literal['bar']

    Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
    adapter = TypeAdapter(Foobar)

    # Because the union is tagged, a bad payload yields exactly one error, both for
    # an unknown top-level tag and for an unknown tag inside the nested `x` field.
    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python({'type': 'wrong'})
    assert len(exc_info.value.errors()) == 1

    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python({'type': 'foo', 'x': {'type': 'wrong'}})
    assert len(exc_info.value.errors()) == 1

    # The outer schema and the recursive `x` field of every `Foo` definition
    # must both have been converted to 'tagged-union'.
    schema = adapter.core_schema
    assert schema['schema']['type'] == 'tagged-union'
    foo_definitions = [d for d in schema['definitions'] if 'Foo' in d['ref']]
    for foo_definition in foo_definitions:
        assert foo_definition['schema']['fields']['x']['schema']['type'] == 'tagged-union'


def test_recursive_discriminated_union_with_pydantic_dataclass() -> None:
    """A self-referential discriminated union of pydantic dataclasses must be a tagged union at every level."""

    @pydantic_dataclass
    class Foo:
        type: Literal['foo']
        # NOTE(review): forward reference to the local `Foobar` alias defined below;
        # presumably resolved from this function's namespace, so keep the alias
        # name and the `TypeAdapter(...)` call in this scope -- confirm.
        x: 'Foobar'

    @pydantic_dataclass
    class Bar:
        type: Literal['bar']

    Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
    adapter = TypeAdapter(Foobar)

    # Because the union is tagged, a bad payload yields exactly one error, both for
    # an unknown top-level tag and for an unknown tag inside the nested `x` field.
    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python({'type': 'wrong'})
    assert len(exc_info.value.errors()) == 1

    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python({'type': 'foo', 'x': {'type': 'wrong'}})
    assert len(exc_info.value.errors()) == 1

    # The outer schema must be a tagged union, and so must the recursive `x`
    # field of every `Foo` definition. Here the fields are iterated as a
    # sequence of entries carrying a 'name' key, so `x` is located by name.
    schema = adapter.core_schema
    assert schema['schema']['type'] == 'tagged-union'
    foo_definitions = [d for d in schema['definitions'] if 'Foo' in d['ref']]
    for foo_definition in foo_definitions:
        for field in foo_definition['schema']['fields']:
            if field['name'] == 'x':
                assert field['schema']['type'] == 'tagged-union'

0 comments on commit f0e0606

Please sign in to comment.