Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary logic for definitions schema gen with discriminated unions #8951

Merged
merged 6 commits into from Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 5 additions & 21 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -9,7 +9,6 @@
from ._core_utils import (
CoreSchemaField,
collect_definitions,
simplify_schema_references,
)

if TYPE_CHECKING:
Expand All @@ -36,9 +35,10 @@ def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> Non


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
# Throughout recursion, we allow references to be resolved from the definitions
# that are present in the outermost schema. Before apply_discriminators is called,
# we call simplify_schema_references (in the clean_schema function),
# We recursively walk through the `schema` passed to `apply_discriminators`, applying discriminators
# where necessary at each level. During this recursion, we allow references to be resolved from the definitions
# that are originally present on the original, outermost `schema`. Before `apply_discriminators` is called,
# `simplify_schema_references` is called on the schema (in the `clean_schema` function),
Comment on lines +38 to +41
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just making this a bit more clear

# which often puts the definitions in the outermost schema.
global_definitions: dict[str, CoreSchema] = collect_definitions(schema)

Expand All @@ -55,7 +55,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
s = apply_discriminator(s, discriminator, global_definitions)
return s

return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
return _core_utils.walk_core_schema(schema, inner)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woohoo! Performance benefit!



def apply_discriminator(
Expand Down Expand Up @@ -187,27 +187,11 @@ def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
# Fetch the definitions attached to the (often inner) schema in question,
# and add them to the definitions that we will use to resolve references
original_local_defs = collect_definitions(schema)
self.definitions.update(original_local_defs)

assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True

# If there are any definitions that were present on the original schema but not on the new schema,
# we need to add them to the new schema. This is necessary because the definitions may contain
# schemas that are referenced by the choices in the union, and we need to ensure that the new schema
# contains all the definitions that are necessary to resolve these references.
# Note -- by "original schema", we refer to the schema that was passed to the apply method,
# not the outermost schema that we're recursing on (where self.definitions came from).
new_local_defs = collect_definitions(schema)
missing_defs = original_local_defs.keys() - new_local_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [original_local_defs[ref] for ref in missing_defs])
return schema

def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
Expand Down
68 changes: 31 additions & 37 deletions tests/test_discriminated_union.py
Expand Up @@ -974,60 +974,54 @@ def test_lax_or_strict_definitions() -> None:
discriminated_schema = apply_discriminator(core_schema.union_schema([cat, dog]), 'kind')
# insert_assert(discriminated_schema)
assert discriminated_schema == {
'type': 'definitions',
'schema': {
'type': 'tagged-union',
'choices': {
'cat': {
'type': 'tagged-union',
'choices': {
'cat': {
'type': 'typed-dict',
'fields': {'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['cat']}}},
},
'DOG': {
'type': 'lax-or-strict',
'lax_schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['cat']}}
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
},
},
'DOG': {
'type': 'lax-or-strict',
'lax_schema': {
'strict_schema': {
'type': 'definitions',
'schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
},
},
'strict_schema': {
'type': 'definitions',
'schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
},
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
},
'dog': {
'type': 'lax-or-strict',
'lax_schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
},
},
'dog': {
'type': 'lax-or-strict',
'lax_schema': {
'strict_schema': {
'type': 'definitions',
'schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
},
},
'strict_schema': {
'type': 'definitions',
'schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
},
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
},
'discriminator': 'kind',
'strict': False,
'from_attributes': True,
},
'definitions': [{'type': 'str', 'ref': 'my-str-definition'}],
'discriminator': 'kind',
'strict': False,
'from_attributes': True,
}


Expand Down