-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix nested discriminated union schema gen, pt 2 #8932
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,6 @@ | |
from ..errors import PydanticUserError | ||
from . import _core_utils | ||
from ._core_utils import ( | ||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, | ||
CoreSchemaField, | ||
collect_definitions, | ||
simplify_schema_references, | ||
|
@@ -29,7 +28,7 @@ def __init__(self, ref: str) -> None: | |
super().__init__(f'Missing definition for ref {self.ref!r}') | ||
|
||
|
||
def set_discriminator(schema: CoreSchema, discriminator: Any) -> None: | ||
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None: | ||
schema.setdefault('metadata', {}) | ||
metadata = schema.get('metadata') | ||
assert metadata is not None | ||
|
@@ -41,25 +40,16 @@ def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSche | |
|
||
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema: | ||
nonlocal definitions | ||
if 'metadata' in s: | ||
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False: | ||
return s | ||
|
||
s = recurse(s, inner) | ||
if s['type'] == 'tagged-union': | ||
return s | ||
|
||
metadata = s.get('metadata', {}) | ||
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None) | ||
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None) | ||
if discriminator is not None: | ||
if definitions is None: | ||
definitions = collect_definitions(schema) | ||
# After we collect the definitions schemas, we must run through the discriminator | ||
# application logic for each one. This step is crucial to prevent an exponential | ||
# increase in complexity that occurs if schemas are left as 'union' schemas | ||
# rather than 'tagged-union' schemas. | ||
# For more details, see https://github.com/pydantic/pydantic/pull/8904#discussion_r1504687302 | ||
definitions = {k: recurse(v, inner) for k, v in definitions.items()} | ||
s = apply_discriminator(s, discriminator, definitions) | ||
return s | ||
|
||
|
@@ -274,6 +264,10 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None: | |
* Validating that each allowed discriminator value maps to a unique choice | ||
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema. | ||
""" | ||
if choice['type'] == 'definition-ref': | ||
if choice['schema_ref'] not in self.definitions: | ||
raise MissingDefinitionForUnionRef(choice['schema_ref']) | ||
|
||
if choice['type'] == 'none': | ||
self._should_be_nullable = True | ||
elif choice['type'] == 'definitions': | ||
|
@@ -285,17 +279,14 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None: | |
# Reverse the choices list before extending the stack so that they get handled in the order they occur | ||
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]] | ||
self._choices_to_handle.extend(choices_schemas) | ||
elif choice['type'] == 'definition-ref': | ||
if choice['schema_ref'] not in self.definitions: | ||
raise MissingDefinitionForUnionRef(choice['schema_ref']) | ||
self._handle_choice(self.definitions[choice['schema_ref']]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Important!!! If a choice is of type Our schema walking logic walks through both the schema and the definitions, so we can rest easy knowing that unions will be converted to tagged unions in the definitions list as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work with JSON schema generation if e.g. the ref'ed schema is itself a discriminated union? Maybe that can't happen, and either way this seems like an improvement if no tests fail, but still There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. I believe that the walk core schema logic handles definitions schemas such that the function being applied during the walk is also applied to all of the definitions in a definitions schema, so that's why I felt comfortable removing this step. This can be seen via the example test I added - discriminated union transformation logic is applied to the 2 schemas in the definitions list that require said changes! |
||
elif choice['type'] not in { | ||
'model', | ||
'typed-dict', | ||
'tagged-union', | ||
'lax-or-strict', | ||
'dataclass', | ||
'dataclass-args', | ||
'definition-ref', | ||
} and not _core_utils.is_function_with_inner_schema(choice): | ||
# We should eventually handle 'definition-ref' as well | ||
raise TypeError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1868,8 +1868,6 @@ class LeafState(BaseModel): | |
state_type: Literal['leaf'] | ||
|
||
AnyState = Annotated[Union[NestedState, LoopState, LeafState], Field(..., discriminator='state_type')] | ||
NestedState.model_rebuild() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you expand this section, you can see that this test showcases the example I've shown in the PR description 👍 |
||
LoopState.model_rebuild() | ||
adapter = TypeAdapter(AnyState) | ||
|
||
assert adapter.core_schema['schema']['type'] == 'tagged-union' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually get rid of this once we use it