Skip to content

Commit

Permalink
Use ser_json_<timedelta|bytes> on default in GenerateJsonSchema (#…
Browse files Browse the repository at this point in the history
…7269)

Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
Kludex and dmontagu committed Aug 29, 2023
1 parent 84282ef commit 2acf1af
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 48 deletions.
43 changes: 41 additions & 2 deletions pydantic/_internal/_config.py
@@ -1,10 +1,21 @@
from __future__ import annotations as _annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, cast
from contextlib import contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Iterator,
cast,
)

from pydantic_core import core_schema
from typing_extensions import Literal, Self
from typing_extensions import (
Literal,
Self,
)

from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
Expand Down Expand Up @@ -169,6 +180,34 @@ def __repr__(self):
return f'ConfigWrapper({c})'


class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""

def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]

def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
return nullcontext()

if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)

@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()

return _context_manager()


config_defaults = ConfigDict(
title=None,
str_to_lower=False,
Expand Down
4 changes: 4 additions & 0 deletions pydantic/_internal/_core_metadata.py
Expand Up @@ -30,6 +30,8 @@ class CoreMetadata(typing_extensions.TypedDict, total=False):
# prefer positional over keyword arguments for an 'arguments' schema.
pydantic_js_prefer_positional_arguments: bool | None

pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema


class CoreMetadataHandler:
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
Expand Down Expand Up @@ -67,6 +69,7 @@ def build_metadata_dict(
js_functions: list[GetJsonSchemaFunction] | None = None,
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
js_prefer_positional_arguments: bool | None = None,
typed_dict_cls: type[Any] | None = None,
initial_metadata: Any | None = None,
) -> Any:
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
Expand All @@ -79,6 +82,7 @@ def build_metadata_dict(
pydantic_js_functions=js_functions or [],
pydantic_js_annotation_functions=js_annotation_functions or [],
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
pydantic_typed_dict_cls=typed_dict_cls,
)
metadata = {k: v for k, v in metadata.items() if v is not None}

Expand Down
37 changes: 5 additions & 32 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -8,7 +8,7 @@
import sys
import typing
import warnings
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager
from copy import copy
from enum import Enum
from functools import partial
Expand All @@ -20,7 +20,6 @@
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Dict,
ForwardRef,
Iterable,
Expand All @@ -45,7 +44,7 @@
from ..warnings import PydanticDeprecatedSince20
from . import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from ._annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from ._config import ConfigWrapper
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import (
CoreMetadataHandler,
build_metadata_dict,
Expand Down Expand Up @@ -261,34 +260,6 @@ def _add_custom_serialization_from_json_encoders(
return schema


class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""

def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]

def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
return nullcontext()

if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)

@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()

return _context_manager()


class GenerateSchema:
"""Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... ."""

Expand Down Expand Up @@ -1098,7 +1069,9 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
field_name, field_info, decorators, required=required
)

metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)])
metadata = build_metadata_dict(
js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)], typed_dict_cls=typed_dict_cls
)

td_schema = core_schema.typed_dict_schema(
fields,
Expand Down
62 changes: 48 additions & 14 deletions pydantic/json_schema.py
Expand Up @@ -35,13 +35,21 @@
)

import pydantic_core
from pydantic_core import CoreConfig, CoreSchema, PydanticOmit, core_schema, to_jsonable_python
from pydantic_core import CoreSchema, PydanticOmit, core_schema, to_jsonable_python
from pydantic_core.core_schema import ComputedField
from typing_extensions import Annotated, Literal, assert_never

from pydantic._internal import _annotated_handlers, _internal_dataclass

from ._internal import _core_metadata, _core_utils, _mock_val_ser, _schema_generation_shared, _typing_extra
from ._internal import (
_annotated_handlers,
_config,
_core_metadata,
_core_utils,
_decorators,
_internal_dataclass,
_mock_val_ser,
_schema_generation_shared,
_typing_extra,
)
from .config import JsonSchemaExtraCallable
from .errors import PydanticInvalidForJsonSchema, PydanticUserError

Expand Down Expand Up @@ -266,6 +274,7 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA
self.json_to_defs_refs: dict[JsonRef, DefsRef] = {}

self.definitions: dict[DefsRef, JsonSchemaValue] = {}
self._config_wrapper_stack = _config.ConfigWrapperStack(_config.ConfigWrapper({}))

self.mode: JsonSchemaMode = 'validation'

Expand All @@ -291,6 +300,10 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA
# of a single instance of a schema generator
self._used = False

@property
def _config(self) -> _config.ConfigWrapper:
return self._config_wrapper_stack.tail

def build_schema_type_to_method(
self,
) -> dict[CoreSchemaOrFieldType, Callable[[CoreSchemaOrField], JsonSchemaValue]]:
Expand Down Expand Up @@ -649,7 +662,7 @@ def bytes_schema(self, schema: core_schema.BytesSchema) -> JsonSchemaValue:
Returns:
The generated JSON schema.
"""
json_schema = {'type': 'string', 'format': 'binary'}
json_schema = {'type': 'string', 'format': 'base64url' if self._config.ser_json_bytes == 'base64' else 'binary'}
self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes)
return json_schema

Expand Down Expand Up @@ -697,6 +710,8 @@ def timedelta_schema(self, schema: core_schema.TimedeltaSchema) -> JsonSchemaVal
Returns:
The generated JSON schema.
"""
if self._config.ser_json_timedelta == 'float':
return {'type': 'number'}
return {'type': 'string', 'format': 'duration'}

def literal_schema(self, schema: core_schema.LiteralSchema) -> JsonSchemaValue:
Expand Down Expand Up @@ -1168,10 +1183,12 @@ def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaVa
]
if self.mode == 'serialization':
named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', [])))
json_schema = self._named_required_fields_schema(named_required_fields)
config: CoreConfig | None = schema.get('config', None)

extra = (config or {}).get('extra_fields_behavior', 'ignore')
config = _get_typed_dict_config(schema)
with self._config_wrapper_stack.push(config):
json_schema = self._named_required_fields_schema(named_required_fields)

extra = config.get('extra', 'ignore')
if extra == 'forbid':
json_schema['additionalProperties'] = False
elif extra == 'allow':
Expand Down Expand Up @@ -1286,12 +1303,13 @@ def model_schema(self, schema: core_schema.ModelSchema) -> JsonSchemaValue:
"""
# We do not use schema['model'].model_json_schema() here
# because it could lead to inconsistent refs handling, etc.
json_schema = self.generate_inner(schema['schema'])

cls = cast('type[BaseModel]', schema['cls'])
config = cls.model_config
title = config.get('title')

with self._config_wrapper_stack.push(config):
json_schema = self.generate_inner(schema['schema'])

json_schema_extra = config.get('json_schema_extra')
if cls.__pydantic_root_model__:
root_json_schema_extra = cls.model_fields['root'].json_schema_extra
Expand Down Expand Up @@ -1461,13 +1479,13 @@ def dataclass_schema(self, schema: core_schema.DataclassSchema) -> JsonSchemaVal
Returns:
The generated JSON schema.
"""
json_schema = self.generate_inner(schema['schema']).copy()

cls = schema['cls']
config: ConfigDict = getattr(cls, '__pydantic_config__', cast('ConfigDict', {}))

title = config.get('title') or cls.__name__

with self._config_wrapper_stack.push(config):
json_schema = self.generate_inner(schema['schema']).copy()

json_schema_extra = config.get('json_schema_extra')
json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra)

Expand Down Expand Up @@ -1942,7 +1960,12 @@ def encode_default(self, dft: Any) -> Any:
Returns:
The encoded default value.
"""
return pydantic_core.to_jsonable_python(dft)
config = self._config
return pydantic_core.to_jsonable_python(
dft,
timedelta_mode=config.ser_json_timedelta,
bytes_mode=config.ser_json_bytes,
)

def update_with_validations(
self, json_schema: JsonSchemaValue, core_schema: CoreSchema, mapping: dict[str, str]
Expand Down Expand Up @@ -2321,3 +2344,14 @@ def __get_pydantic_json_schema__(

def __hash__(self) -> int:
return hash(type(self))


def _get_typed_dict_config(schema: core_schema.TypedDictSchema) -> ConfigDict:
metadata = _core_metadata.CoreMetadataHandler(schema).metadata
cls = metadata.get('pydantic_typed_dict_cls')
if cls is not None:
try:
return _decorators.get_attribute_from_bases(cls, '__pydantic_config__')
except AttributeError:
pass
return {}

0 comments on commit 2acf1af

Please sign in to comment.