Skip to content

Commit

Permalink
Simplify handling of typing.Annotated in GenerateSchema (#6887)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Jul 26, 2023
1 parent 0008ed3 commit 14f27b5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 37 deletions.
61 changes: 25 additions & 36 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -87,6 +87,7 @@
from ._schema_generation_shared import GetJsonSchemaFunction

_SUPPORTS_TYPEDDICT = sys.version_info >= (3, 12)
_AnnotatedType = type(Annotated[int, 123])

FieldDecoratorInfo = Union[ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo]
FieldDecoratorInfoType = TypeVar('FieldDecoratorInfoType', bound=FieldDecoratorInfo)
Expand Down Expand Up @@ -391,6 +392,17 @@ def collect_definitions(self, schema: CoreSchema) -> CoreSchema:
list(self.defs.definitions.values()),
)

def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None:
metadata = CoreMetadataHandler(metadata_schema).metadata
pydantic_js_functions = metadata.setdefault('pydantic_js_functions', [])
# because of how we generate core schemas for nested generic models
# we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times
# this check may fail to catch duplicates if the function is a `functools.partial`
# or something like that
# but if it does it'll fail by inserting the duplicate
if js_function not in pydantic_js_functions:
pydantic_js_functions.append(js_function)

def generate_schema(
self,
obj: Any,
Expand Down Expand Up @@ -421,29 +433,6 @@ def generate_schema(
- If `typing.TypedDict` is used instead of `typing_extensions.TypedDict` on Python < 3.12.
- If `__modify_schema__` method is used instead of `__get_pydantic_json_schema__`.
"""
if isinstance(obj, type(Annotated[int, 123])):
return self._annotated_schema(obj)
return self._generate_schema_for_type(
obj, from_dunder_get_core_schema=from_dunder_get_core_schema, from_prepare_args=from_prepare_args
)

def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None:
metadata = CoreMetadataHandler(metadata_schema).metadata
pydantic_js_functions = metadata.setdefault('pydantic_js_functions', [])
# because of how we generate core schemas for nested generic models
# we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times
# this check may fail to catch duplicates if the function is a `functools.partial`
# or something like that
# but if it does it'll fail by inserting the duplicate
if js_function not in pydantic_js_functions:
pydantic_js_functions.append(js_function)

def _generate_schema_for_type(
self,
obj: Any,
from_dunder_get_core_schema: bool = True,
from_prepare_args: bool = True,
) -> CoreSchema:
schema: CoreSchema | None = None

if from_prepare_args:
Expand Down Expand Up @@ -668,6 +657,9 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
"""Recursively generate a pydantic-core schema for any supported python type."""
if isinstance(obj, _AnnotatedType):
return self._annotated_schema(obj)

if isinstance(obj, dict):
# we assume this is already a valid schema
return obj # type: ignore[return-value]
Expand Down Expand Up @@ -1509,7 +1501,7 @@ def _prepare_annotations(self, source_type: Any, annotations: Iterable[Any]) ->

return source_type, list(annotations)

def _apply_annotations( # noqa: C901
def _apply_annotations(
self,
source_type: Any,
annotations: list[Any],
Expand Down Expand Up @@ -1542,19 +1534,16 @@ def _apply_annotations( # noqa: C901
pydantic_js_annotation_functions: list[GetJsonSchemaFunction] = []

def inner_handler(obj: Any) -> CoreSchema:
if isinstance(obj, type(Annotated[int, 123])):
schema = transform_inner_schema(self._annotated_schema(obj))
from_property = self._generate_schema_from_property(obj, obj)
if from_property is None:
schema = self._generate_schema(obj)
else:
from_property = self._generate_schema_from_property(obj, obj)
if from_property is None:
schema = self._generate_schema(obj)
else:
schema = from_property
metadata_js_function = _extract_get_pydantic_json_schema(obj, schema)
if metadata_js_function is not None:
metadata_schema = resolve_original_schema(schema, self.defs.definitions)
if metadata_schema is not None:
self._add_js_function(metadata_schema, metadata_js_function)
schema = from_property
metadata_js_function = _extract_get_pydantic_json_schema(obj, schema)
if metadata_js_function is not None:
metadata_schema = resolve_original_schema(schema, self.defs.definitions)
if metadata_schema is not None:
self._add_js_function(metadata_schema, metadata_js_function)
return transform_inner_schema(schema)

get_inner_schema = CallbackGetCoreSchemaHandler(inner_handler, self)
Expand Down
65 changes: 64 additions & 1 deletion tests/test_discriminated_union.py
@@ -1,7 +1,7 @@
import re
import sys
from enum import Enum, IntEnum
from typing import Generic, Optional, TypeVar, Union
from typing import Generic, Optional, Sequence, TypeVar, Union

import pytest
from dirty_equals import HasRepr, IsStr
Expand Down Expand Up @@ -1271,3 +1271,66 @@ class Model(BaseModel):
'type': 'union_tag_invalid',
}
]


def test_sequence_discriminated_union():
class Cat(BaseModel):
pet_type: Literal['cat']
meows: int

class Dog(BaseModel):
pet_type: Literal['dog']
barks: float

class Lizard(BaseModel):
pet_type: Literal['reptile', 'lizard']
scales: bool

Pet = Annotated[Union[Cat, Dog, Lizard], Field(discriminator='pet_type')]

class Model(BaseModel):
pet: Sequence[Pet]
n: int

assert Model.model_json_schema() == {
'$defs': {
'Cat': {
'properties': {
'meows': {'title': 'Meows', 'type': 'integer'},
'pet_type': {'const': 'cat', 'title': 'Pet Type'},
},
'required': ['pet_type', 'meows'],
'title': 'Cat',
'type': 'object',
},
'Dog': {
'properties': {
'barks': {'title': 'Barks', 'type': 'number'},
'pet_type': {'const': 'dog', 'title': 'Pet Type'},
},
'required': ['pet_type', 'barks'],
'title': 'Dog',
'type': 'object',
},
'Lizard': {
'properties': {
'pet_type': {'enum': ['reptile', 'lizard'], 'title': 'Pet Type', 'type': 'string'},
'scales': {'title': 'Scales', 'type': 'boolean'},
},
'required': ['pet_type', 'scales'],
'title': 'Lizard',
'type': 'object',
},
},
'properties': {
'n': {'title': 'N', 'type': 'integer'},
'pet': {
'items': {'anyOf': [{'$ref': '#/$defs/Cat'}, {'$ref': '#/$defs/Dog'}, {'$ref': '#/$defs/Lizard'}]},
'title': 'Pet',
'type': 'array',
},
},
'required': ['pet', 'n'],
'title': 'Model',
'type': 'object',
}

0 comments on commit 14f27b5

Please sign in to comment.