Fix discriminated union schema gen bug #8904

Merged
merged 4 commits into from Feb 27, 2024

sydney-runkle
Member

Fix #8709

With this code snippet:

from __future__ import annotations
from typing import Literal, Annotated
from pydantic import Field, TypeAdapter, BaseModel

class NestedState(BaseModel):
    state_type: Literal["nested"]
    substate: AnyState

# If this type is left out, the model behaves normally again
class LoopState(BaseModel):
    state_type: Literal["loop"]
    substate: AnyState

class LeafState(BaseModel):
    state_type: Literal["leaf"]

AnyState = Annotated[NestedState | LoopState | LeafState, Field(..., discriminator="state_type")]

def build_nested_state(n):
    if n <= 0:
        return {"state_type": "leaf"}
    else:
        return {"state_type": "loop", "substate": {"state_type": "nested", "substate": build_nested_state(n-1)}}
        
adapter = TypeAdapter(AnyState)

Previously:

# the next statement takes around 0.8s
adapter.validate_python(build_nested_state(9))

# the next statement takes around 3.5s
adapter.validate_python(build_nested_state(10))

# the next statement takes around 12.5s
adapter.validate_python(build_nested_state(11))

Now, note the speedups:

import time

start = time.time()
adapter.validate_python(build_nested_state(9))
print(time.time() - start)
#> 0.0063190460205078125

start = time.time()
adapter.validate_python(build_nested_state(10))
print(time.time() - start)
#> 0.02297687530517578

start = time.time()
adapter.validate_python(build_nested_state(11))
print(time.time() - start)
#> 0.057311058044433594
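
To give a feel for why the old timings roughly quadrupled per step, here is a toy cost model. It is only a sketch, not pydantic's implementation: it assumes an untagged union attempts every member, and that each member validates its `substate` field even when the `state_type` literal does not match. `untagged_cost`, `tagged_cost`, and `HAS_SUBSTATE` are hypothetical names made up for this illustration.

```python
def build_nested_state(n):
    # Same payload shape as in the snippet above.
    if n <= 0:
        return {"state_type": "leaf"}
    return {
        "state_type": "loop",
        "substate": {"state_type": "nested", "substate": build_nested_state(n - 1)},
    }


# Hypothetical table: which union members carry a recursive `substate` field.
HAS_SUBSTATE = {"nested": True, "loop": True, "leaf": False}


def untagged_cost(data):
    """Count member attempts when every union member is tried (assumed model)."""
    calls = 0
    for has_substate in HAS_SUBSTATE.values():
        calls += 1  # one attempt per member, whether or not the tag matches
        if has_substate and "substate" in data:
            # Assumption: the member still validates `substate` after a tag mismatch.
            calls += untagged_cost(data["substate"])
    return calls


def tagged_cost(data):
    """Count member attempts when the tag selects exactly one member."""
    calls = 1
    if HAS_SUBSTATE[data["state_type"]] and "substate" in data:
        calls += tagged_cost(data["substate"])
    return calls


for n in (9, 10, 11):
    payload = build_nested_state(n)
    print(n, untagged_cost(payload), tagged_cost(payload))
```

In this model the untagged count doubles per nesting level, so it roughly quadruples each time `n` grows by one (each step adds two levels), while the tagged count grows linearly. That matches the shape of the timings above, but not necessarily pydantic's exact behavior.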

@sydney-runkle sydney-runkle added the relnotes-fix Used for bugfixes. label Feb 27, 2024

codspeed-hq bot commented Feb 27, 2024

CodSpeed Performance Report

Merging #8904 will not alter performance

Comparing fix-schema-building-bug (d35cd68) with main (752daaf)

Summary

✅ 10 untouched benchmarks

cloudflare-pages bot commented Feb 27, 2024

Deploying with Cloudflare Pages

Latest commit: d35cd68
Status: ✅  Deploy successful!
Preview URL: https://36f1154d.pydantic-docs2.pages.dev
Branch Preview URL: https://fix-schema-building-bug.pydantic-docs2.pages.dev

Member

@adriangb adriangb left a comment

I think I'd like a little bit more documentation on how this fixes things (maybe resulting in a comment in the code) so we don't make the same mistake again, but other than that LGTM.

@sydney-runkle this is really impressive work in record time, great job.

@@ -54,6 +54,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
definitions = {k: recurse(v, inner) for k, v in definitions.items()}
s = apply_discriminator(s, discriminator, definitions)

Member

🤯 care to explain how this fixes things?

Member Author

Sure thing! I'll add some more comments in the code to explain things 😄!

I'll also add a more in-depth note to this PR, so we can reference that for more detail, if needed 👍.

Member Author

So previously, with the above snippet of code, the resultant core_schema was the following:

{'definitions': [{'cls': <class '__main__.LeafState'>,
                  'config': {'title': 'LeafState'},
                  'custom_init': False,
                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False,
                               'pydantic_js_annotation_functions': [],
                               'pydantic_js_functions': [functools.partial(<function modify_model_json_schema at 0x101bcd900>, cls=<class '__main__.LeafState'>),
                                                         <bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.LeafState'>>]},
                  'ref': '__main__.LeafState:5484303040',
                  'root_model': False,
                  'schema': {'computed_fields': [],
                             'fields': {'state_type': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101c09d80>],
                                                                    'pydantic_js_functions': []},
                                                       'schema': {'expected': ['leaf'],
                                                                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                  'type': 'literal'},
                                                       'type': 'model-field'}},
                             'model_name': 'LeafState',
                             'type': 'model-fields'},
                  'type': 'model'},
                 {'cls': <class '__main__.LoopState'>,
                  'config': {'title': 'LoopState'},
                  'custom_init': False,
                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False,
                               'pydantic_js_annotation_functions': [],
                               'pydantic_js_functions': [functools.partial(<function modify_model_json_schema at 0x101bcd900>, cls=<class '__main__.LoopState'>),
                                                         <bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.LoopState'>>]},
                  'ref': '__main__.LoopState:5484288784',
                  'root_model': False,
                  'schema': {'computed_fields': [],
                             'fields': {'state_type': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf8310>],
                                                                    'pydantic_js_functions': []},
                                                       'schema': {'expected': ['loop'],
                                                                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                  'type': 'literal'},
                                                       'type': 'model-field'},
                                        'substate': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf83a0>],
                                                                  'pydantic_js_functions': []},
                                                     'schema': {'default': Ellipsis,
                                                                'schema': {'choices': [{'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.NestedState:5484231952',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LoopState:5484288784',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LeafState:5484303040',
                                                                                        'type': 'definition-ref'}],
                                                                           'metadata': {'pydantic.internal.needs_apply_discriminated_union': True,
                                                                                        'pydantic.internal.union_discriminator': 'state_type'},
                                                                           'type': 'union'},
                                                                'type': 'default'},
                                                     'type': 'model-field'}},
                             'model_name': 'LoopState',
                             'type': 'model-fields'},
                  'type': 'model'},
                 {'cls': <class '__main__.NestedState'>,
                  'config': {'title': 'NestedState'},
                  'custom_init': False,
                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False,
                               'pydantic_js_annotation_functions': [],
                               'pydantic_js_functions': [functools.partial(<function modify_model_json_schema at 0x101bcd900>, cls=<class '__main__.NestedState'>),
                                                         <bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.NestedState'>>]},
                  'ref': '__main__.NestedState:5484231952',
                  'root_model': False,
                  'schema': {'computed_fields': [],
                             'fields': {'state_type': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf80d0>],
                                                                    'pydantic_js_functions': []},
                                                       'schema': {'expected': ['nested'],
                                                                  'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                  'type': 'literal'},
                                                       'type': 'model-field'},
                                        'substate': {'metadata': {'pydantic_js_annotation_functions': [<function get_json_schema_update_func.<locals>.json_schema_update_func at 0x101cf8160>],
                                                                  'pydantic_js_functions': []},
                                                     'schema': {'default': Ellipsis,
                                                                'schema': {'choices': [{'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.NestedState:5484231952',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LoopState:5484288784',
                                                                                        'type': 'definition-ref'},
                                                                                       {'metadata': {'pydantic.internal.needs_apply_discriminated_union': False},
                                                                                        'schema_ref': '__main__.LeafState:5484303040',
                                                                                        'type': 'definition-ref'}],
                                                                           'metadata': {'pydantic.internal.needs_apply_discriminated_union': True,
                                                                                        'pydantic.internal.union_discriminator': 'state_type'},
                                                                           'type': 'union'},
                                                                'type': 'default'},
                                                     'type': 'model-field'}},
                             'model_name': 'NestedState',
                             'type': 'model-fields'},
                  'type': 'model'}],
 'schema': {'choices': {'leaf': {'schema_ref': '__main__.LeafState:5484303040',
                                 'type': 'definition-ref'},
                        'loop': {'schema_ref': '__main__.LoopState:5484288784',
                                 'type': 'definition-ref'},
                        'nested': {'schema_ref': '__main__.NestedState:5484231952',
                                   'type': 'definition-ref'}},
            'discriminator': 'state_type',
            'from_attributes': True,
            'metadata': {'pydantic.internal.needs_apply_discriminated_union': True,
                         'pydantic.internal.union_discriminator': 'state_type'},
            'strict': False,
            'type': 'tagged-union'},
 'type': 'definitions'}

We already identified that, within the definitions list, the substate schemas for NestedState and LoopState were typed as union rather than tagged-union. Interestingly, though, the metadata for those schemas already carried the necessary discriminated-union information:

'metadata': {'pydantic.internal.needs_apply_discriminated_union': True, 'pydantic.internal.union_discriminator': 'state_type'}
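
As a rough way to spot this in a dump like the one above, the sketch below walks a plain dict/list schema and reports every `union` node whose metadata still has the `needs_apply_discriminated_union` flag set. `find_pending_unions` is a hypothetical helper written for this note, not a pydantic API, and the sample schema is a heavily trimmed stand-in for the real dump.

```python
NEEDS_APPLY = "pydantic.internal.needs_apply_discriminated_union"
DISCRIMINATOR = "pydantic.internal.union_discriminator"


def find_pending_unions(node, path=()):
    """Yield paths to 'union' nodes whose metadata still asks for a discriminator."""
    if isinstance(node, dict):
        if node.get("type") == "union" and node.get("metadata", {}).get(NEEDS_APPLY):
            yield path
        for key, value in node.items():
            yield from find_pending_unions(value, path + (key,))
    elif isinstance(node, list):
        for index, value in enumerate(node):
            yield from find_pending_unions(value, path + (index,))


# Trimmed stand-in for the dump above: one definition with a pending union.
schema = {
    "type": "definitions",
    "schema": {
        "type": "tagged-union",
        "discriminator": "state_type",
        "choices": {},
        "metadata": {NEEDS_APPLY: True, DISCRIMINATOR: "state_type"},
    },
    "definitions": [
        {
            "type": "model",
            "ref": "LoopState",
            "schema": {
                "type": "model-fields",
                "fields": {
                    "substate": {
                        "type": "model-field",
                        "schema": {
                            "type": "default",
                            "schema": {
                                "type": "union",
                                "choices": [],
                                "metadata": {NEEDS_APPLY: True, DISCRIMINATOR: "state_type"},
                            },
                        },
                    }
                },
            },
        }
    ],
}

pending = list(find_pending_unions(schema))
print(pending)
```

The top-level schema is not reported because it is already a `tagged-union`; only the union nested inside the definition is.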

So the question, then, is why we weren't "applying" the discriminators to those schemas. This is where my change comes in: after we collect the definitions schemas, we also need to pass over them and apply discriminators where appropriate. Within the inner function, we were already calling recurse, but we weren't calling it on any of the definitions schemas, which was the missing piece.
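
Here is a simplified model of that idea, using plain dicts rather than the real core schema. Everything in it (`walk`, `apply_discriminators`, the `also_definitions` switch, and the flattened schema shape) is a hypothetical illustration, so treat it as a sketch under those assumptions rather than the code in this PR.

```python
DISCRIMINATOR_KEY = "pydantic.internal.union_discriminator"


def walk(node, visit):
    """Rebuild a nested dict/list structure bottom-up, passing each dict to `visit`."""
    if isinstance(node, dict):
        return visit({key: walk(value, visit) for key, value in node.items()})
    if isinstance(node, list):
        return [walk(value, visit) for value in node]
    return node


def apply_discriminators(root, definitions, also_definitions):
    def visit(node):
        metadata = node.get("metadata", {})
        if node.get("type") != "union" or DISCRIMINATOR_KEY not in metadata:
            return node
        discriminator = metadata[DISCRIMINATOR_KEY]
        choices = {}
        for choice in node["choices"]:
            # The tag comes from the referenced definition's literal field,
            # which is why the definitions are needed here at all.
            target = definitions[choice["schema_ref"]]
            choices[target["fields"][discriminator]["expected"][0]] = choice
        return {"type": "tagged-union", "discriminator": discriminator, "choices": choices}

    resolved = definitions
    if also_definitions:
        # The extra pass: walk the collected definitions as well as the root.
        resolved = {ref: walk(definition, visit) for ref, definition in definitions.items()}
    return walk(root, visit), resolved


def ref(name):
    return {"type": "definition-ref", "schema_ref": name}


def pending_union():
    return {
        "type": "union",
        "choices": [ref("NestedState"), ref("LoopState"), ref("LeafState")],
        "metadata": {DISCRIMINATOR_KEY: "state_type"},
    }


def model(tag, recursive):
    fields = {"state_type": {"type": "literal", "expected": [tag]}}
    if recursive:
        fields["substate"] = pending_union()
    return {"type": "model", "fields": fields}


definitions = {
    "NestedState": model("nested", True),
    "LoopState": model("loop", True),
    "LeafState": model("leaf", False),
}

root_before, defs_before = apply_discriminators(pending_union(), definitions, also_definitions=False)
root_after, defs_after = apply_discriminators(pending_union(), definitions, also_definitions=True)

print(defs_before["LoopState"]["fields"]["substate"]["type"])  # union
print(defs_after["LoopState"]["fields"]["substate"]["type"])   # tagged-union
```

In both runs the root becomes a `tagged-union`; only the run with the extra pass also converts the unions nested inside the definitions.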

Member

Totally makes sense!

Member

@samuelcolvin samuelcolvin left a comment

LGTM!

@sydney-runkle sydney-runkle merged commit 5a4c056 into main Feb 27, 2024
54 checks passed
@sydney-runkle sydney-runkle deleted the fix-schema-building-bug branch February 27, 2024 18:05
@sydney-runkle
Member Author

Update: I believe this fixed the performance issue for the above case, but it didn't entirely fix the schema generation process. Here's a fix for that: #8932.

I'd like to work on a few other schema generation / performance related issues, but hopefully we can get these out in another patch release soon 😄.

Feel free to ping me if you have any questions!

Labels
relnotes-fix Used for bugfixes.
Development

Successfully merging this pull request may close these issues.

Recursive model with discriminated union has exponential time and space complexity for validation
3 participants