Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ser_json_<timedelta|bytes> on default in GenerateJsonSchema #7269

Merged
merged 6 commits into from Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 41 additions & 2 deletions pydantic/_internal/_config.py
@@ -1,10 +1,21 @@
from __future__ import annotations as _annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, cast
from contextlib import contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Iterator,
cast,
)

from pydantic_core import core_schema
from typing_extensions import Literal, Self
from typing_extensions import (
Literal,
Self,
)

from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
Expand Down Expand Up @@ -169,6 +180,34 @@ def __repr__(self):
return f'ConfigWrapper({c})'


class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""

def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]

def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
return nullcontext()

if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)

@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()

return _context_manager()


config_defaults = ConfigDict(
title=None,
str_to_lower=False,
Expand Down
4 changes: 4 additions & 0 deletions pydantic/_internal/_core_metadata.py
Expand Up @@ -30,6 +30,8 @@ class CoreMetadata(typing_extensions.TypedDict, total=False):
# prefer positional over keyword arguments for an 'arguments' schema.
pydantic_js_prefer_positional_arguments: bool | None

pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema


class CoreMetadataHandler:
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
Expand Down Expand Up @@ -67,6 +69,7 @@ def build_metadata_dict(
js_functions: list[GetJsonSchemaFunction] | None = None,
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
js_prefer_positional_arguments: bool | None = None,
typed_dict_cls: type[Any] | None = None,
initial_metadata: Any | None = None,
) -> Any:
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
Expand All @@ -79,6 +82,7 @@ def build_metadata_dict(
pydantic_js_functions=js_functions or [],
pydantic_js_annotation_functions=js_annotation_functions or [],
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
pydantic_typed_dict_cls=typed_dict_cls,
)
metadata = {k: v for k, v in metadata.items() if v is not None}

Expand Down
37 changes: 5 additions & 32 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -8,7 +8,7 @@
import sys
import typing
import warnings
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager
from copy import copy
from enum import Enum
from functools import partial
Expand All @@ -20,7 +20,6 @@
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Dict,
ForwardRef,
Iterable,
Expand All @@ -45,7 +44,7 @@
from ..warnings import PydanticDeprecatedSince20
from . import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from ._annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from ._config import ConfigWrapper
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import (
CoreMetadataHandler,
build_metadata_dict,
Expand Down Expand Up @@ -261,34 +260,6 @@ def _add_custom_serialization_from_json_encoders(
return schema


class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""

def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]

def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
return nullcontext()

if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)

@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()

return _context_manager()


class GenerateSchema:
"""Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... ."""

Expand Down Expand Up @@ -1098,7 +1069,9 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
field_name, field_info, decorators, required=required
)

metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)])
metadata = build_metadata_dict(
js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)], typed_dict_cls=typed_dict_cls
)

td_schema = core_schema.typed_dict_schema(
fields,
Expand Down
62 changes: 48 additions & 14 deletions pydantic/json_schema.py
Expand Up @@ -35,13 +35,21 @@
)
Kludex marked this conversation as resolved.
Show resolved Hide resolved

import pydantic_core
from pydantic_core import CoreConfig, CoreSchema, PydanticOmit, core_schema, to_jsonable_python
from pydantic_core import CoreSchema, PydanticOmit, core_schema, to_jsonable_python
from pydantic_core.core_schema import ComputedField
from typing_extensions import Annotated, Literal, assert_never

from pydantic._internal import _annotated_handlers, _internal_dataclass

from ._internal import _core_metadata, _core_utils, _mock_val_ser, _schema_generation_shared, _typing_extra
from ._internal import (
_annotated_handlers,
_config,
_core_metadata,
_core_utils,
_decorators,
_internal_dataclass,
_mock_val_ser,
_schema_generation_shared,
_typing_extra,
)
from .config import JsonSchemaExtraCallable
from .errors import PydanticInvalidForJsonSchema, PydanticUserError

Expand Down Expand Up @@ -266,6 +274,7 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA
self.json_to_defs_refs: dict[JsonRef, DefsRef] = {}

self.definitions: dict[DefsRef, JsonSchemaValue] = {}
self._config_wrapper_stack = _config.ConfigWrapperStack(_config.ConfigWrapper({}))

self.mode: JsonSchemaMode = 'validation'

Expand All @@ -291,6 +300,10 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA
# of a single instance of a schema generator
self._used = False

@property
def _config(self) -> _config.ConfigWrapper:
return self._config_wrapper_stack.tail

def build_schema_type_to_method(
self,
) -> dict[CoreSchemaOrFieldType, Callable[[CoreSchemaOrField], JsonSchemaValue]]:
Expand Down Expand Up @@ -649,7 +662,7 @@ def bytes_schema(self, schema: core_schema.BytesSchema) -> JsonSchemaValue:
Returns:
The generated JSON schema.
"""
json_schema = {'type': 'string', 'format': 'binary'}
json_schema = {'type': 'string', 'format': 'base64url' if self._config.ser_json_bytes == 'base64' else 'binary'}
self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes)
return json_schema

Expand Down Expand Up @@ -697,6 +710,8 @@ def timedelta_schema(self, schema: core_schema.TimedeltaSchema) -> JsonSchemaVal
Returns:
The generated JSON schema.
"""
if self._config.ser_json_timedelta == 'float':
return {'type': 'number'}
return {'type': 'string', 'format': 'duration'}

def literal_schema(self, schema: core_schema.LiteralSchema) -> JsonSchemaValue:
Expand Down Expand Up @@ -1168,10 +1183,12 @@ def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaVa
]
if self.mode == 'serialization':
named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', [])))
json_schema = self._named_required_fields_schema(named_required_fields)
config: CoreConfig | None = schema.get('config', None)

extra = (config or {}).get('extra_fields_behavior', 'ignore')
config = _get_typed_dict_config(schema)
with self._config_wrapper_stack.push(config):
json_schema = self._named_required_fields_schema(named_required_fields)

extra = config.get('extra', 'ignore')
if extra == 'forbid':
json_schema['additionalProperties'] = False
elif extra == 'allow':
Expand Down Expand Up @@ -1286,12 +1303,13 @@ def model_schema(self, schema: core_schema.ModelSchema) -> JsonSchemaValue:
"""
# We do not use schema['model'].model_json_schema() here
# because it could lead to inconsistent refs handling, etc.
json_schema = self.generate_inner(schema['schema'])

cls = cast('type[BaseModel]', schema['cls'])
config = cls.model_config
title = config.get('title')

with self._config_wrapper_stack.push(config):
json_schema = self.generate_inner(schema['schema'])

json_schema_extra = config.get('json_schema_extra')
if cls.__pydantic_root_model__:
root_json_schema_extra = cls.model_fields['root'].json_schema_extra
Expand Down Expand Up @@ -1461,13 +1479,13 @@ def dataclass_schema(self, schema: core_schema.DataclassSchema) -> JsonSchemaVal
Returns:
The generated JSON schema.
"""
json_schema = self.generate_inner(schema['schema']).copy()

cls = schema['cls']
config: ConfigDict = getattr(cls, '__pydantic_config__', cast('ConfigDict', {}))

title = config.get('title') or cls.__name__

with self._config_wrapper_stack.push(config):
json_schema = self.generate_inner(schema['schema']).copy()

json_schema_extra = config.get('json_schema_extra')
json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra)

Expand Down Expand Up @@ -1942,7 +1960,12 @@ def encode_default(self, dft: Any) -> Any:
Returns:
The encoded default value.
"""
return pydantic_core.to_jsonable_python(dft)
config = self._config
return pydantic_core.to_jsonable_python(
dft,
timedelta_mode=config.ser_json_timedelta,
bytes_mode=config.ser_json_bytes,
)

def update_with_validations(
self, json_schema: JsonSchemaValue, core_schema: CoreSchema, mapping: dict[str, str]
Expand Down Expand Up @@ -2321,3 +2344,14 @@ def __get_pydantic_json_schema__(

def __hash__(self) -> int:
return hash(type(self))


def _get_typed_dict_config(schema: core_schema.TypedDictSchema) -> ConfigDict:
metadata = _core_metadata.CoreMetadataHandler(schema).metadata
cls = metadata.get('pydantic_typed_dict_cls')
if cls is not None:
try:
return _decorators.get_attribute_from_bases(cls, '__pydantic_config__')
except AttributeError:
pass
return {}