Fix discriminated union schema gen bug #8904

Merged
merged 4 commits into from Feb 27, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions pydantic/_internal/_discriminated_union.py
@@ -54,6 +54,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
        if discriminator is not None:
            if definitions is None:
                definitions = collect_definitions(schema)
                definitions = {k: recurse(v, inner) for k, v in definitions.items()}
            s = apply_discriminator(s, discriminator, definitions)
Member

🤯 care to explain how this fixes things?

Member Author

Sure thing! I'll add some more comments in the code to explain things 😄!

I'll also add a more in-depth note to this PR, so we can reference that for more detail, if needed 👍.

Member Author

So previously, with the above snippet of code, the resultant core_schema was the following:

{'definitions': [{'cls': <class '__main__.LeafState'>,
                  'config': {'title': 'LeafState'},
                  'custom_init': False,
                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False,
                               'pydantic_js_annotation_functions': [],
                               'pydantic_js_functions': [functools.partial(<function modify_model_json_schema at 0x101bcd900>, cls=<class '__main__.LeafState'>),
                                                         <bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.LeafState'>>]},
                  'ref': '__main__.LeafState:5484303040',
                  'root_model': False,
                  'schema': {'computed_fields': [],
                             'fields': {'state_type': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101c09d80>],
                                                                    'pydantic_js_functions': []},
                                                       'schema': {'expected': ['leaf'],
                                                                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                  'type': 'literal'},
                                                       'type': 'model-field'}},
                             'model_name': 'LeafState',
                             'type': 'model-fields'},
                  'type': 'model'},
                 {'cls': <class '__main__.LoopState'>,
                  'config': {'title': 'LoopState'},
                  'custom_init': False,
                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False,
                               'pydantic_js_annotation_functions': [],
                               'pydantic_js_functions': [functools.partial(<function modify_model_json_schema at 0x101bcd900>, cls=<class '__main__.LoopState'>),
                                                         <bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.LoopState'>>]},
                  'ref': '__main__.LoopState:5484288784',
                  'root_model': False,
                  'schema': {'computed_fields': [],
                             'fields': {'state_type': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf8310>],
                                                                    'pydantic_js_functions': []},
                                                       'schema': {'expected': ['loop'],
                                                                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                  'type': 'literal'},
                                                       'type': 'model-field'},
                                        'substate': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf83a0>],
                                                                  'pydantic_js_functions': []},
                                                     'schema': {'default': Ellipsis,
                                                                'schema': {'choices': [{'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.NestedState:5484231952',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LoopState:5484288784',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LeafState:5484303040',
                                                                                        'type': 'definition-ref'}],
                                                                           'metadata': {'pydantic.internal.needs_apply_discriminated_union': True,
                                                                                        'pydantic.internal.union_discriminator': 'state_type'},
                                                                           'type': 'union'},
                                                                'type': 'default'},
                                                     'type': 'model-field'}},
                             'model_name': 'LoopState',
                             'type': 'model-fields'},
                  'type': 'model'},
                 {'cls': <class '__main__.NestedState'>,
                  'config': {'title': 'NestedState'},
                  'custom_init': False,
                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False,
                               'pydantic_js_annotation_functions': [],
                               'pydantic_js_functions': [functools.partial(<function modify_model_json_schema at 0x101bcd900>, cls=<class '__main__.NestedState'>),
                                                         <bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.NestedState'>>]},
                  'ref': '__main__.NestedState:5484231952',
                  'root_model': False,
                  'schema': {'computed_fields': [],
                             'fields': {'state_type': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf80d0>],
                                                                    'pydantic_js_functions': []},
                                                       'schema': {'expected': ['nested'],
                                                                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                  'type': 'literal'},
                                                       'type': 'model-field'},
                                        'substate': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf8160>],
                                                                  'pydantic_js_functions': []},
                                                     'schema': {'default': Ellipsis,
                                                                'schema': {'choices': [{'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.NestedState:5484231952',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LoopState:5484288784',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LeafState:5484303040',
                                                                                        'type': 'definition-ref'}],
                                                                           'metadata': {'pydantic.internal.needs_apply_discriminated_union': True,
                                                                                        'pydantic.internal.union_discriminator': 'state_type'},
                                                                           'type': 'union'},
                                                                'type': 'default'},
                                                     'type': 'model-field'}},
                             'model_name': 'NestedState',
                             'type': 'model-fields'},
                  'type': 'model'}],
 'schema': {'choices': {'leaf': {'schema_ref': '__main__.LeafState:5484303040',
                                 'type': 'definition-ref'},
                        'loop': {'schema_ref': '__main__.LoopState:5484288784',
                                 'type': 'definition-ref'},
                        'nested': {'schema_ref': '__main__.NestedState:5484231952',
                                   'type': 'definition-ref'}},
            'discriminator': 'state_type',
            'from_attributes': True,
            'metadata': {'pydantic.internal.needs_apply_discriminated_union': True,
                         'pydantic.internal.union_discriminator': 'state_type'},
            'strict': False,
            'type': 'tagged-union'},
 'type': 'definitions'}

We already identified that, within the definitions list, the substate schemas for NestedState and LoopState had type 'union' rather than 'tagged-union'. Interestingly, though, the metadata for those schemas had the necessary discriminated-union related information:

'metadata': {'pydantic.internal.needs_apply_discriminated_union': True, 'pydantic.internal.union_discriminator': 'state_type'}
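
To make that observation easier to check against a dump like the one above, here is a minimal sketch of a helper that walks a core schema dict and reports every plain 'union' node that still carries discriminator metadata. The metadata keys are copied from the dump; the helper name `find_unapplied_unions` and the trimmed `sample` schema are hypothetical and not part of pydantic.

```python
from typing import Any, Iterator, Tuple

# Metadata keys as they appear in the dump above.
NEEDS_APPLY_KEY = 'pydantic.internal.needs_apply_discriminated_union'
DISCRIMINATOR_KEY = 'pydantic.internal.union_discriminator'

Path = Tuple[Any, ...]


def find_unapplied_unions(node: Any, path: Path = ()) -> Iterator[Tuple[Path, str]]:
    """Yield (path, discriminator) for each plain 'union' node that still has discriminator metadata."""
    if isinstance(node, dict):
        metadata = node.get('metadata')
        if node.get('type') == 'union' and isinstance(metadata, dict) and DISCRIMINATOR_KEY in metadata:
            yield path, metadata[DISCRIMINATOR_KEY]
        for key, value in node.items():
            yield from find_unapplied_unions(value, path + (key,))
    elif isinstance(node, (list, tuple)):
        for index, value in enumerate(node):
            yield from find_unapplied_unions(value, path + (index,))


# Hypothetical, heavily trimmed version of the dump: one definition whose
# `substate` field is still a plain union, plus an already-tagged top level.
pending = {NEEDS_APPLY_KEY: True, DISCRIMINATOR_KEY: 'state_type'}
sample = {
    'type': 'definitions',
    'definitions': [
        {
            'type': 'model',
            'schema': {
                'type': 'model-fields',
                'fields': {
                    'substate': {
                        'type': 'model-field',
                        'schema': {
                            'type': 'default',
                            'schema': {'type': 'union', 'choices': [], 'metadata': dict(pending)},
                        },
                    },
                },
            },
        },
    ],
    'schema': {'type': 'tagged-union', 'discriminator': 'state_type', 'choices': {}, 'metadata': dict(pending)},
}

hits = list(find_unapplied_unions(sample))
for hit_path, discriminator in hits:
    print(hit_path, discriminator)
```

Only the nested `substate` union is reported; the top-level schema is skipped because its type is already 'tagged-union'.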

So the question, then, is why aren't we "applying" the discriminators to those schemas? This is where my change comes in: after we collect the definitions schemas, we also need to pass over them and apply discriminators where appropriate. Within the inner function, we were already calling recurse, but never on the definitions schemas, which was the missing piece.
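
As a rough illustration of the fix, here is a toy model of the walk. It is not pydantic's actual implementation: `walk`, `apply_discriminators`, the `recurse_into_definitions` flag, the `union_discriminator` key, and the simplified schema shapes are all assumptions made for the example. It only shows why a union stored under definitions stays untouched unless the definitions themselves are walked.

```python
from typing import Any, Callable, Dict

Schema = Dict[str, Any]

# Hypothetical stand-in for the discriminator metadata key.
DISCRIMINATOR_KEY = 'union_discriminator'


def walk(node: Schema, visit: Callable[[Schema], Schema]) -> Schema:
    """Toy walker: rebuild the children first, then let `visit` rewrite the node."""
    node = dict(node)
    if 'fields' in node:
        node['fields'] = {name: walk(child, visit) for name, child in node['fields'].items()}
    if 'choices' in node:
        node['choices'] = [walk(child, visit) for child in node['choices']]
    return visit(node)


def apply_discriminators(root: Schema, recurse_into_definitions: bool) -> Schema:
    def visit(node: Schema) -> Schema:
        discriminator = node.get('metadata', {}).get(DISCRIMINATOR_KEY)
        if node['type'] == 'union' and discriminator is not None:
            # Simplification: real tagged unions key their choices by tag value;
            # the toy keeps the original list.
            return {'type': 'tagged-union', 'discriminator': discriminator, 'choices': node['choices']}
        return node

    definitions = dict(root['definitions'])
    if recurse_into_definitions:
        # The missing piece: walk the collected definitions as well.
        definitions = {ref: walk(schema, visit) for ref, schema in definitions.items()}
    return {'type': 'definitions', 'schema': walk(root['schema'], visit), 'definitions': definitions}


union = {
    'type': 'union',
    'metadata': {DISCRIMINATOR_KEY: 'state_type'},
    'choices': [
        {'type': 'definition-ref', 'schema_ref': 'LoopState'},
        {'type': 'definition-ref', 'schema_ref': 'LeafState'},
    ],
}
root = {
    'type': 'definitions',
    'schema': union,
    'definitions': {
        'LoopState': {'type': 'model', 'fields': {'substate': union}},
        'LeafState': {'type': 'model', 'fields': {}},
    },
}

before = apply_discriminators(root, recurse_into_definitions=False)
after = apply_discriminators(root, recurse_into_definitions=True)
print(before['definitions']['LoopState']['fields']['substate']['type'])
print(after['definitions']['LoopState']['fields']['substate']['type'])
```

With `recurse_into_definitions=False`, only the top-level union is converted, which mirrors the dump above; with `True`, the nested `substate` union is converted as well.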

Member

Totally makes sense!

        return s

24 changes: 24 additions & 0 deletions tests/test_discriminated_union.py
@@ -1852,3 +1852,27 @@ class SubModel(MyModel):
'title': 'MyModel',
'type': 'object',
}


def test_nested_schema_gen_uses_tagged_union_in_ref() -> None:
    class NestedState(BaseModel):
        state_type: Literal['nested']
        substate: 'AnyState'

    # If this type is left out, the model behaves normally again
    class LoopState(BaseModel):
        state_type: Literal['loop']
        substate: 'AnyState'

    class LeafState(BaseModel):
        state_type: Literal['leaf']

    AnyState = Annotated[Union[NestedState, LoopState, LeafState], Field(..., discriminator='state_type')]
    NestedState.model_rebuild()
    LoopState.model_rebuild()
    adapter = TypeAdapter(AnyState)

    assert adapter.core_schema['schema']['type'] == 'tagged-union'
    for definition in adapter.core_schema['definitions']:
        if definition['schema']['model_name'] in ['NestedState', 'LoopState']:
            assert definition['schema']['fields']['substate']['schema']['schema']['type'] == 'tagged-union'