Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify handling of typing.Annotated in GenerateSchema #6887

Merged
merged 2 commits into from Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 25 additions & 36 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -87,6 +87,7 @@
from ._schema_generation_shared import GetJsonSchemaFunction

_SUPPORTS_TYPEDDICT = sys.version_info >= (3, 12)
_AnnotatedType = type(Annotated[int, 123])
Copy link
Contributor Author

@dmontagu dmontagu Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated calls to type and Annotated.__class_getitem__ while building schemas seemed unnecessary/inefficient, but I didn't profile this or anything. Probably negligible overhead but just ... rubbed me the wrong way reading isinstance(obj, type(Annotated[int, 123])) in a frequently-executed loop (at least during model creation) lol. Can revert this change if preferred.


FieldDecoratorInfo = Union[ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo]
FieldDecoratorInfoType = TypeVar('FieldDecoratorInfoType', bound=FieldDecoratorInfo)
Expand Down Expand Up @@ -391,6 +392,17 @@ def collect_definitions(self, schema: CoreSchema) -> CoreSchema:
list(self.defs.definitions.values()),
)

def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function just moved, no changes were made

metadata = CoreMetadataHandler(metadata_schema).metadata
pydantic_js_functions = metadata.setdefault('pydantic_js_functions', [])
# because of how we generate core schemas for nested generic models
# we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times
# this check may fail to catch duplicates if the function is a `functools.partial`
# or something like that
# but if it does it'll fail by inserting the duplicate
if js_function not in pydantic_js_functions:
pydantic_js_functions.append(js_function)

def generate_schema(
self,
obj: Any,
Expand Down Expand Up @@ -421,29 +433,6 @@ def generate_schema(
- If `typing.TypedDict` is used instead of `typing_extensions.TypedDict` on Python < 3.12.
- If `__modify_schema__` method is used instead of `__get_pydantic_json_schema__`.
"""
if isinstance(obj, type(Annotated[int, 123])):
return self._annotated_schema(obj)
return self._generate_schema_for_type(
obj, from_dunder_get_core_schema=from_dunder_get_core_schema, from_prepare_args=from_prepare_args
)

def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None:
metadata = CoreMetadataHandler(metadata_schema).metadata
pydantic_js_functions = metadata.setdefault('pydantic_js_functions', [])
# because of how we generate core schemas for nested generic models
# we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times
# this check may fail to catch duplicates if the function is a `functools.partial`
# or something like that
# but if it does it'll fail by inserting the duplicate
if js_function not in pydantic_js_functions:
pydantic_js_functions.append(js_function)

def _generate_schema_for_type(
Copy link
Contributor Author

@dmontagu dmontagu Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I eliminated this method since the only place it was called was in generate_schema, and with the removal of the handling of _annotated_schema there, generate_schema was just immediately calling this function with the exact same signature and arguments passing through without modification.

self,
obj: Any,
from_dunder_get_core_schema: bool = True,
from_prepare_args: bool = True,
) -> CoreSchema:
schema: CoreSchema | None = None

if from_prepare_args:
Expand Down Expand Up @@ -668,6 +657,9 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
"""Recursively generate a pydantic-core schema for any supported python type."""
if isinstance(obj, _AnnotatedType):
return self._annotated_schema(obj)

if isinstance(obj, dict):
# we assume this is already a valid schema
return obj # type: ignore[return-value]
Expand Down Expand Up @@ -1509,7 +1501,7 @@ def _prepare_annotations(self, source_type: Any, annotations: Iterable[Any]) ->

return source_type, list(annotations)

def _apply_annotations( # noqa: C901
def _apply_annotations(
self,
source_type: Any,
annotations: list[Any],
Expand Down Expand Up @@ -1542,19 +1534,16 @@ def _apply_annotations( # noqa: C901
pydantic_js_annotation_functions: list[GetJsonSchemaFunction] = []

def inner_handler(obj: Any) -> CoreSchema:
if isinstance(obj, type(Annotated[int, 123])):
schema = transform_inner_schema(self._annotated_schema(obj))
Comment on lines -1545 to -1546
Copy link
Contributor Author

@dmontagu dmontagu Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch no longer seemed necessary now that _generate_schema handles Annotated properly (and Annotated doesn't hit any of the other property extractions as far as I can tell). So I just eliminated the branch here, that's all that's happening in this chunk.

from_property = self._generate_schema_from_property(obj, obj)
if from_property is None:
schema = self._generate_schema(obj)
else:
from_property = self._generate_schema_from_property(obj, obj)
if from_property is None:
schema = self._generate_schema(obj)
else:
schema = from_property
metadata_js_function = _extract_get_pydantic_json_schema(obj, schema)
if metadata_js_function is not None:
metadata_schema = resolve_original_schema(schema, self.defs.definitions)
if metadata_schema is not None:
self._add_js_function(metadata_schema, metadata_js_function)
schema = from_property
metadata_js_function = _extract_get_pydantic_json_schema(obj, schema)
if metadata_js_function is not None:
metadata_schema = resolve_original_schema(schema, self.defs.definitions)
if metadata_schema is not None:
self._add_js_function(metadata_schema, metadata_js_function)
return transform_inner_schema(schema)

get_inner_schema = CallbackGetCoreSchemaHandler(inner_handler, self)
Expand Down
65 changes: 64 additions & 1 deletion tests/test_discriminated_union.py
@@ -1,7 +1,7 @@
import re
import sys
from enum import Enum, IntEnum
from typing import Generic, Optional, TypeVar, Union
from typing import Generic, Optional, Sequence, TypeVar, Union

import pytest
from dirty_equals import HasRepr, IsStr
Expand Down Expand Up @@ -1271,3 +1271,66 @@ class Model(BaseModel):
'type': 'union_tag_invalid',
}
]


def test_sequence_discriminated_union():
class Cat(BaseModel):
pet_type: Literal['cat']
meows: int

class Dog(BaseModel):
pet_type: Literal['dog']
barks: float

class Lizard(BaseModel):
pet_type: Literal['reptile', 'lizard']
scales: bool

Pet = Annotated[Union[Cat, Dog, Lizard], Field(discriminator='pet_type')]

class Model(BaseModel):
pet: Sequence[Pet]
n: int

assert Model.model_json_schema() == {
'$defs': {
'Cat': {
'properties': {
'meows': {'title': 'Meows', 'type': 'integer'},
'pet_type': {'const': 'cat', 'title': 'Pet Type'},
},
'required': ['pet_type', 'meows'],
'title': 'Cat',
'type': 'object',
},
'Dog': {
'properties': {
'barks': {'title': 'Barks', 'type': 'number'},
'pet_type': {'const': 'dog', 'title': 'Pet Type'},
},
'required': ['pet_type', 'barks'],
'title': 'Dog',
'type': 'object',
},
'Lizard': {
'properties': {
'pet_type': {'enum': ['reptile', 'lizard'], 'title': 'Pet Type', 'type': 'string'},
'scales': {'title': 'Scales', 'type': 'boolean'},
},
'required': ['pet_type', 'scales'],
'title': 'Lizard',
'type': 'object',
},
},
'properties': {
'n': {'title': 'N', 'type': 'integer'},
'pet': {
'items': {'anyOf': [{'$ref': '#/$defs/Cat'}, {'$ref': '#/$defs/Dog'}, {'$ref': '#/$defs/Lizard'}]},
'title': 'Pet',
'type': 'array',
},
},
'required': ['pet', 'n'],
'title': 'Model',
'type': 'object',
}