Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ser_json_<timedelta|bytes> on default in GenerateJsonSchema #7269

Merged
merged 6 commits into from Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 41 additions & 2 deletions pydantic/_internal/_config.py
@@ -1,10 +1,21 @@
from __future__ import annotations as _annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, cast
from contextlib import contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Iterator,
cast,
)

from pydantic_core import core_schema
from typing_extensions import Literal, Self
from typing_extensions import (
Literal,
Self,
)

from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
Expand Down Expand Up @@ -169,6 +180,34 @@ def __repr__(self):
return f'ConfigWrapper({c})'


class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""

def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]

def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
return nullcontext()

if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)

@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()

return _context_manager()


config_defaults = ConfigDict(
title=None,
str_to_lower=False,
Expand Down
33 changes: 2 additions & 31 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -8,7 +8,7 @@
import sys
import typing
import warnings
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager
from copy import copy
from enum import Enum
from functools import partial
Expand All @@ -20,7 +20,6 @@
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Dict,
ForwardRef,
Iterable,
Expand All @@ -45,7 +44,7 @@
from ..warnings import PydanticDeprecatedSince20
from . import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from ._annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from ._config import ConfigWrapper
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import (
CoreMetadataHandler,
build_metadata_dict,
Expand Down Expand Up @@ -260,34 +259,6 @@ def _add_custom_serialization_from_json_encoders(
return schema


class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""

def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]

def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
return nullcontext()

if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)

@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()

return _context_manager()


class GenerateSchema:
"""Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... ."""

Expand Down
93 changes: 70 additions & 23 deletions pydantic/json_schema.py
Expand Up @@ -35,21 +35,37 @@
)
Kludex marked this conversation as resolved.
Show resolved Hide resolved

import pydantic_core
from pydantic_core import CoreConfig, CoreSchema, PydanticOmit, core_schema, to_jsonable_python
from pydantic_core import (
CoreConfig,
CoreSchema,
PydanticOmit,
core_schema,
to_jsonable_python,
)
from pydantic_core.core_schema import ComputedField
from typing_extensions import Annotated, Literal, assert_never

from pydantic._internal import _annotated_handlers, _internal_dataclass

from ._internal import _core_metadata, _core_utils, _mock_val_ser, _schema_generation_shared, _typing_extra
from ._internal import (
_annotated_handlers,
_config,
_core_metadata,
_core_utils,
_internal_dataclass,
_mock_val_ser,
_schema_generation_shared,
_typing_extra,
)
from .config import JsonSchemaExtraCallable
from .errors import PydanticInvalidForJsonSchema, PydanticUserError

if TYPE_CHECKING:
from . import ConfigDict
from ._internal._core_utils import CoreSchemaField, CoreSchemaOrField
from ._internal._dataclasses import PydanticDataclass
from ._internal._schema_generation_shared import GetJsonSchemaFunction, GetJsonSchemaHandler
from ._internal._schema_generation_shared import (
GetJsonSchemaFunction,
GetJsonSchemaHandler,
)
from .main import BaseModel


Expand All @@ -75,7 +91,10 @@
for validation inputs, or that will be matched by serialization outputs.
"""

_MODE_TITLE_MAPPING: dict[JsonSchemaMode, str] = {'validation': 'Input', 'serialization': 'Output'}
_MODE_TITLE_MAPPING: dict[JsonSchemaMode, str] = {
'validation': 'Input',
'serialization': 'Output',
}


def update_json_schema(schema: JsonSchemaValue, updates: dict[str, Any]) -> JsonSchemaValue:
Expand Down Expand Up @@ -266,6 +285,7 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA
self.json_to_defs_refs: dict[JsonRef, DefsRef] = {}

self.definitions: dict[DefsRef, JsonSchemaValue] = {}
self._config_wrapper_stack = _config.ConfigWrapperStack(_config.ConfigWrapper({}))

self.mode: JsonSchemaMode = 'validation'

Expand Down Expand Up @@ -318,8 +338,9 @@ def build_schema_type_to_method(
return mapping

def generate_definitions(
self, inputs: Sequence[tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema]]
) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], dict[DefsRef, JsonSchemaValue]]:
self,
inputs: Sequence[tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema]],
) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], dict[DefsRef, JsonSchemaValue],]:
"""Generates JSON schema definitions from a list of core schemas, pairing the generated definitions with a
mapping that links the input keys to the definition references.

Expand Down Expand Up @@ -649,7 +670,8 @@ def bytes_schema(self, schema: core_schema.BytesSchema) -> JsonSchemaValue:
Returns:
The generated JSON schema.
"""
json_schema = {'type': 'string', 'format': 'binary'}
config = self._config_wrapper_stack.tail
json_schema = {'type': 'string', 'format': 'base64url' if config.ser_json_bytes == 'base64' else 'binary'}
self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes)
return json_schema

Expand Down Expand Up @@ -697,6 +719,9 @@ def timedelta_schema(self, schema: core_schema.TimedeltaSchema) -> JsonSchemaVal
Returns:
The generated JSON schema.
"""
config = self._config_wrapper_stack.tail
if config.ser_json_timedelta == 'float':
return {'type': 'number'}
return {'type': 'string', 'format': 'duration'}

def literal_schema(self, schema: core_schema.LiteralSchema) -> JsonSchemaValue:
Expand Down Expand Up @@ -1286,12 +1311,13 @@ def model_schema(self, schema: core_schema.ModelSchema) -> JsonSchemaValue:
"""
# We do not use schema['model'].model_json_schema() here
# because it could lead to inconsistent refs handling, etc.
json_schema = self.generate_inner(schema['schema'])

cls = cast('type[BaseModel]', schema['cls'])
config = cls.model_config
title = config.get('title')

with self._config_wrapper_stack.push(config):
json_schema = self.generate_inner(schema['schema'])

json_schema_extra = config.get('json_schema_extra')
if cls.__pydantic_root_model__:
root_json_schema_extra = cls.model_fields['root'].json_schema_extra
Expand Down Expand Up @@ -1461,13 +1487,13 @@ def dataclass_schema(self, schema: core_schema.DataclassSchema) -> JsonSchemaVal
Returns:
The generated JSON schema.
"""
json_schema = self.generate_inner(schema['schema']).copy()

cls = schema['cls']
config: ConfigDict = getattr(cls, '__pydantic_config__', cast('ConfigDict', {}))

title = config.get('title') or cls.__name__

with self._config_wrapper_stack.push(config):
json_schema = self.generate_inner(schema['schema']).copy()

json_schema_extra = config.get('json_schema_extra')
json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra)

Expand Down Expand Up @@ -1522,7 +1548,9 @@ def arguments_schema(self, schema: core_schema.ArgumentsSchema) -> JsonSchemaVal
)

def kw_arguments_schema(
self, arguments: list[core_schema.ArgumentsParameter], var_kwargs_schema: CoreSchema | None
self,
arguments: list[core_schema.ArgumentsParameter],
var_kwargs_schema: CoreSchema | None,
) -> JsonSchemaValue:
"""Generates a JSON schema that matches a schema that defines a function's keyword arguments.

Expand Down Expand Up @@ -1559,7 +1587,9 @@ def kw_arguments_schema(
return json_schema

def p_arguments_schema(
self, arguments: list[core_schema.ArgumentsParameter], var_args_schema: CoreSchema | None
self,
arguments: list[core_schema.ArgumentsParameter],
var_args_schema: CoreSchema | None,
) -> JsonSchemaValue:
"""Generates a JSON schema that matches a schema that defines a function's positional arguments.

Expand Down Expand Up @@ -1650,7 +1680,11 @@ def json_schema(self, schema: core_schema.JsonSchema) -> JsonSchemaValue:
content_core_schema = schema.get('schema') or core_schema.any_schema()
content_json_schema = self.generate_inner(content_core_schema)
if self.mode == 'validation':
return {'type': 'string', 'contentMediaType': 'application/json', 'contentSchema': content_json_schema}
return {
'type': 'string',
'contentMediaType': 'application/json',
'contentSchema': content_json_schema,
}
else:
# self.mode == 'serialization'
return content_json_schema
Expand Down Expand Up @@ -1725,7 +1759,8 @@ def definition_ref_schema(self, schema: core_schema.DefinitionReferenceSchema) -
return ref_json_schema

def ser_schema(
self, schema: core_schema.SerSchema | core_schema.IncExSeqSerSchema | core_schema.IncExDictSerSchema
self,
schema: core_schema.SerSchema | core_schema.IncExSeqSerSchema | core_schema.IncExDictSerSchema,
) -> JsonSchemaValue | None:
"""Generates a JSON schema that matches a schema that defines a serialized object.

Expand Down Expand Up @@ -1942,10 +1977,18 @@ def encode_default(self, dft: Any) -> Any:
Returns:
The encoded default value.
"""
return pydantic_core.to_jsonable_python(dft)
config = self._config_wrapper_stack.tail
return pydantic_core.to_jsonable_python(
dft,
timedelta_mode=config.ser_json_timedelta,
bytes_mode=config.ser_json_bytes,
)

def update_with_validations(
self, json_schema: JsonSchemaValue, core_schema: CoreSchema, mapping: dict[str, str]
self,
json_schema: JsonSchemaValue,
core_schema: CoreSchema,
mapping: dict[str, str],
) -> None:
"""Update the json_schema with the corresponding validations specified in the core_schema,
using the provided mapping to translate keys in core_schema to the appropriate keys for a JSON schema.
Expand Down Expand Up @@ -2133,7 +2176,7 @@ def models_json_schema(
description: str | None = None,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
) -> tuple[dict[tuple[type[BaseModel] | type[PydanticDataclass], JsonSchemaMode], JsonSchemaValue], JsonSchemaValue]:
) -> tuple[dict[tuple[type[BaseModel] | type[PydanticDataclass], JsonSchemaMode], JsonSchemaValue,], JsonSchemaValue,]:
"""Utility function to generate a JSON Schema for multiple models.

Args:
Expand Down Expand Up @@ -2223,7 +2266,9 @@ class WithJsonSchema:
mode: Literal['validation', 'serialization'] | None = None

def __get_pydantic_json_schema__(
self, core_schema: core_schema.CoreSchema, handler: _annotated_handlers.GetJsonSchemaHandler
self,
core_schema: core_schema.CoreSchema,
handler: _annotated_handlers.GetJsonSchemaHandler,
) -> JsonSchemaValue:
mode = self.mode or handler.mode
if mode != handler.mode:
Expand Down Expand Up @@ -2253,7 +2298,9 @@ class Examples:
mode: Literal['validation', 'serialization'] | None = None

def __get_pydantic_json_schema__(
self, core_schema: core_schema.CoreSchema, handler: _annotated_handlers.GetJsonSchemaHandler
self,
core_schema: core_schema.CoreSchema,
handler: _annotated_handlers.GetJsonSchemaHandler,
) -> JsonSchemaValue:
mode = self.mode or handler.mode
json_schema = handler(core_schema)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_json_schema.py
Expand Up @@ -1679,6 +1679,50 @@ class Outer(BaseModel):
}


@pytest.mark.parametrize(
'ser_json_timedelta,properties',
[
('float', {'duration': {'default': 300.0, 'title': 'Duration', 'type': 'number'}}),
('iso8601', {'duration': {'default': 'PT300S', 'format': 'duration', 'title': 'Duration', 'type': 'string'}}),
],
)
def test_model_default_timedelta(ser_json_timedelta: Literal['float', 'iso8601'], properties: dict[str, Any]):
class Model(BaseModel):
model_config = ConfigDict(ser_json_timedelta=ser_json_timedelta)

duration: timedelta = timedelta(minutes=5)

assert Model.model_json_schema(mode='serialization') == {
'properties': properties,
'required': ['duration'],
'title': 'Model',
'type': 'object',
}


# test ser_json_bytes
@pytest.mark.parametrize(
'ser_json_bytes,properties',
[
('base64', {'data': {'default': 'Zm9vYmFy', 'format': 'base64url', 'title': 'Data', 'type': 'string'}}),
('utf8', {'data': {'default': 'foobar', 'format': 'binary', 'title': 'Data', 'type': 'string'}}),
],
)
def test_model_default_bytes(ser_json_bytes: Literal['base64', 'utf8'], properties: dict[str, Any]):
class Model(BaseModel):
model_config = ConfigDict(ser_json_bytes=ser_json_bytes)

data: bytes = b'foobar'

# insert_assert(Model.model_json_schema(mode='serialization'))
assert Model.model_json_schema(mode='serialization') == {
'properties': properties,
'required': ['data'],
'title': 'Model',
'type': 'object',
}


def test_model_subclass_metadata():
class A(BaseModel):
"""A Model docstring"""
Expand Down