Skip to content

Commit

Permalink
Eagerly resolve discriminated unions and cache cases where we can't (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 20, 2023
1 parent 2d36952 commit 1cb0b78
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
2 changes: 2 additions & 0 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -41,6 +41,8 @@

_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'

NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'


def is_core_schema(
schema: CoreSchemaOrField,
Expand Down
28 changes: 22 additions & 6 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -6,11 +6,21 @@

from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import CoreSchemaField, collect_definitions
from ._core_utils import NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, CoreSchemaField, collect_definitions

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'


class MissingDefinitionForUnionRef(Exception):
"""Raised when applying a discriminated union discriminator to a schema
requires a definition that is not yet defined
"""

def __init__(self, ref: str) -> None:
self.ref = ref
super().__init__(f'Missing definition for ref {self.ref!r}')


def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
Expand All @@ -19,16 +29,23 @@ def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions = collect_definitions(schema)
definitions: dict[str, CoreSchema] | None = None

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
if 'metadata' in s:
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
return s

s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s

Expand All @@ -53,7 +70,7 @@ def apply_discriminator(
- If `discriminator` is used with invalid union variant.
- If `discriminator` is used with `Union` type with one variant.
- If `discriminator` value mapped to multiple choices.
ValueError:
MissingDefinitionForUnionRef:
If the definition for ref is missing.
PydanticUserError:
- If a model in union doesn't have a discriminator field.
Expand Down Expand Up @@ -242,7 +259,7 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise ValueError(f"Missing definition for ref {choice['schema_ref']!r}")
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
elif choice['type'] not in {
'model',
Expand Down Expand Up @@ -344,9 +361,8 @@ def _infer_discriminator_values_for_choice( # noqa C901
elif choice['type'] == 'definition-ref':
schema_ref = choice['schema_ref']
if schema_ref not in self.definitions:
raise ValueError(f'Missing definition for inner ref {schema_ref!r}')
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)

else:
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
Expand Down
50 changes: 40 additions & 10 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -50,6 +50,7 @@
build_metadata_dict,
)
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
define_expected_missing_refs,
get_type_ref,
Expand Down Expand Up @@ -267,11 +268,12 @@ def __init__(
config_wrapper: ConfigWrapper,
types_namespace: dict[str, Any] | None,
typevars_map: dict[Any, Any] | None = None,
):
) -> None:
# we need a stack for recursing into child models
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace = types_namespace
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self.defs = _Definitions()

@classmethod
Expand All @@ -286,6 +288,7 @@ def __from_parent(
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace = types_namespace
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj.defs = defs
return obj

Expand Down Expand Up @@ -355,6 +358,22 @@ def _unknown_type_schema(self, obj: Any) -> CoreSchema:
' `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.'
)

def _apply_discriminator_to_union(self, schema: CoreSchema, discriminator: Any) -> CoreSchema:
try:
return _discriminated_union.apply_discriminator(
schema,
discriminator,
)
except _discriminated_union.MissingDefinitionForUnionRef:
# defer until defs are resolved
_discriminated_union.set_discriminator(
schema,
discriminator,
)
schema.setdefault('metadata', {})[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = True
self._needs_apply_discriminated_union = True
return schema

def collect_definitions(self, schema: CoreSchema) -> CoreSchema:
ref = cast('str | None', schema.get('ref', None))
if ref:
Expand Down Expand Up @@ -420,6 +439,8 @@ def generate_schema(

schema = _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, obj, schema)

schema = self._post_process_generated_schema(schema)

return schema

def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -506,7 +527,7 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:

schema = self._apply_model_serializers(model_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, model_validators, 'outer')
self.defs.definitions[model_ref] = schema
self.defs.definitions[model_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(model_ref)

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
Expand Down Expand Up @@ -562,8 +583,11 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C

ref: str | None = schema.get('ref', None)
if ref:
self.defs.definitions[ref] = schema
self.defs.definitions[ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(ref)

schema = self._post_process_generated_schema(schema)

return schema

def _resolve_forward_ref(self, obj: Any) -> Any:
Expand Down Expand Up @@ -619,8 +643,17 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:
raise TypeError(f'Expected two type arguments for {origin}, got 1')
return args[0], args[1]

def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
schema.setdefault('metadata', {})[
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY
] = self._needs_apply_discriminated_union
return schema

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
"""Recursively generate a pydantic-core schema for any supported python type."""
return self._post_process_generated_schema(self._generate_schema_inner(obj))

def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
if isinstance(obj, _AnnotatedType):
return self._annotated_schema(obj)

Expand Down Expand Up @@ -847,7 +880,7 @@ def _common_field_schema(self, name: str, field_info: FieldInfo, decorators: Dec
source_type, annotations = field_info.annotation, field_info.metadata

def set_discriminator(schema: CoreSchema) -> CoreSchema:
_discriminated_union.set_discriminator(schema, field_info.discriminator)
schema = self._apply_discriminator_to_union(schema, field_info.discriminator)
return schema

if field_info.discriminator is not None:
Expand Down Expand Up @@ -1065,7 +1098,7 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

schema = self._apply_model_serializers(td_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, decorators.model_validators.values(), 'all')
self.defs.definitions[typed_dict_ref] = schema
self.defs.definitions[typed_dict_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(typed_dict_ref)

def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -1318,7 +1351,7 @@ def _dataclass_schema(
)
schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, model_validators, 'outer')
self.defs.definitions[dataclass_ref] = schema
self.defs.definitions[dataclass_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(dataclass_ref)

def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema:
Expand Down Expand Up @@ -1525,10 +1558,7 @@ def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any
schema = self._apply_single_annotation(schema, field_metadata)

if metadata.discriminator is not None:
_discriminated_union.set_discriminator(
schema,
metadata.discriminator,
)
schema = self._apply_discriminator_to_union(schema, metadata.discriminator)
return schema

if schema['type'] == 'nullable':
Expand Down

0 comments on commit 1cb0b78

Please sign in to comment.