Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eagerly resolve discriminated unions and cache cases where we can't #7529

Merged
merged 2 commits into from Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -41,6 +41,8 @@

_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'

NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'


def is_core_schema(
schema: CoreSchemaOrField,
Expand Down
28 changes: 22 additions & 6 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -6,11 +6,21 @@

from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import CoreSchemaField, collect_definitions
from ._core_utils import NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, CoreSchemaField, collect_definitions

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'


class MissingDefinitionForUnionRef(Exception):
"""Raised when applying a discriminated union discriminator to a schema
requires a definition that is not yet defined
"""

def __init__(self, ref: str) -> None:
self.ref = ref
super().__init__(f'Missing definition for ref {self.ref!r}')


def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
Expand All @@ -19,16 +29,23 @@ def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions = collect_definitions(schema)
definitions: dict[str, CoreSchema] | None = None

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
if 'metadata' in s:
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
return s

s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s

Expand All @@ -53,7 +70,7 @@ def apply_discriminator(
- If `discriminator` is used with invalid union variant.
- If `discriminator` is used with `Union` type with one variant.
- If `discriminator` value mapped to multiple choices.
ValueError:
MissingDefinitionForUnionRef:
If the definition for ref is missing.
PydanticUserError:
- If a model in union doesn't have a discriminator field.
Expand Down Expand Up @@ -242,7 +259,7 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise ValueError(f"Missing definition for ref {choice['schema_ref']!r}")
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
elif choice['type'] not in {
'model',
Expand Down Expand Up @@ -344,9 +361,8 @@ def _infer_discriminator_values_for_choice( # noqa C901
elif choice['type'] == 'definition-ref':
schema_ref = choice['schema_ref']
if schema_ref not in self.definitions:
raise ValueError(f'Missing definition for inner ref {schema_ref!r}')
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)

else:
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
Expand Down
50 changes: 40 additions & 10 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -50,6 +50,7 @@
build_metadata_dict,
)
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
define_expected_missing_refs,
get_type_ref,
Expand Down Expand Up @@ -267,11 +268,12 @@ def __init__(
config_wrapper: ConfigWrapper,
types_namespace: dict[str, Any] | None,
typevars_map: dict[Any, Any] | None = None,
):
) -> None:
# we need a stack for recursing into child models
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace = types_namespace
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self.defs = _Definitions()

@classmethod
Expand All @@ -286,6 +288,7 @@ def __from_parent(
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace = types_namespace
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj.defs = defs
return obj

Expand Down Expand Up @@ -355,6 +358,22 @@ def _unknown_type_schema(self, obj: Any) -> CoreSchema:
' `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.'
)

def _apply_discriminator_to_union(self, schema: CoreSchema, discriminator: Any) -> CoreSchema:
try:
return _discriminated_union.apply_discriminator(
schema,
discriminator,
)
except _discriminated_union.MissingDefinitionForUnionRef:
# defer until defs are resolved
_discriminated_union.set_discriminator(
schema,
discriminator,
)
schema.setdefault('metadata', {})[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = True
self._needs_apply_discriminated_union = True
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
return schema

def collect_definitions(self, schema: CoreSchema) -> CoreSchema:
ref = cast('str | None', schema.get('ref', None))
if ref:
Expand Down Expand Up @@ -420,6 +439,8 @@ def generate_schema(

schema = _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, obj, schema)

schema = self._post_process_generated_schema(schema)

return schema

def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -506,7 +527,7 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:

schema = self._apply_model_serializers(model_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, model_validators, 'outer')
self.defs.definitions[model_ref] = schema
self.defs.definitions[model_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(model_ref)

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
Expand Down Expand Up @@ -562,8 +583,11 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C

ref: str | None = schema.get('ref', None)
if ref:
self.defs.definitions[ref] = schema
self.defs.definitions[ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(ref)

schema = self._post_process_generated_schema(schema)

return schema

def _resolve_forward_ref(self, obj: Any) -> Any:
Expand Down Expand Up @@ -619,8 +643,17 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:
raise TypeError(f'Expected two type arguments for {origin}, got 1')
return args[0], args[1]

def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
schema.setdefault('metadata', {})[
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY
] = self._needs_apply_discriminated_union
return schema

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
"""Recursively generate a pydantic-core schema for any supported python type."""
return self._post_process_generated_schema(self._generate_schema_inner(obj))

def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
if isinstance(obj, _AnnotatedType):
return self._annotated_schema(obj)

Expand Down Expand Up @@ -847,7 +880,7 @@ def _common_field_schema(self, name: str, field_info: FieldInfo, decorators: Dec
source_type, annotations = field_info.annotation, field_info.metadata

def set_discriminator(schema: CoreSchema) -> CoreSchema:
_discriminated_union.set_discriminator(schema, field_info.discriminator)
schema = self._apply_discriminator_to_union(schema, field_info.discriminator)
return schema

if field_info.discriminator is not None:
Expand Down Expand Up @@ -1065,7 +1098,7 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

schema = self._apply_model_serializers(td_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, decorators.model_validators.values(), 'all')
self.defs.definitions[typed_dict_ref] = schema
self.defs.definitions[typed_dict_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(typed_dict_ref)

def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -1318,7 +1351,7 @@ def _dataclass_schema(
)
schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, model_validators, 'outer')
self.defs.definitions[dataclass_ref] = schema
self.defs.definitions[dataclass_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(dataclass_ref)

def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema:
Expand Down Expand Up @@ -1525,10 +1558,7 @@ def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any
schema = self._apply_single_annotation(schema, field_metadata)

if metadata.discriminator is not None:
_discriminated_union.set_discriminator(
schema,
metadata.discriminator,
)
schema = self._apply_discriminator_to_union(schema, metadata.discriminator)
return schema

if schema['type'] == 'nullable':
Expand Down