Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify flattening and inlining of CoreSchema #7523

Merged
merged 3 commits into from Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
168 changes: 94 additions & 74 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -5,15 +5,14 @@
Any,
Callable,
Hashable,
Iterable,
TypeVar,
Union,
_GenericAlias, # type: ignore
cast,
)

from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypeAliasType, TypeGuard, get_args
from typing_extensions import TypeAliasType, TypedDict, TypeGuard, get_args

from . import _repr

Expand All @@ -40,6 +39,8 @@
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}

_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'


def is_core_schema(
schema: CoreSchemaOrField,
Expand Down Expand Up @@ -132,10 +133,10 @@ def define_expected_missing_refs(
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
return schema
refs = set()
refs: set[str] = set()

def _record_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
ref: str | None = s.get('ref')
ref = get_ref(s)
if ref:
refs.add(ref)
return recurse(s, _record_refs)
Expand Down Expand Up @@ -416,92 +417,122 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor
return f(schema, _dispatch)


def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema: # noqa: C901
all_defs: dict[str, core_schema.CoreSchema] = {}
# Bookkeeping shared by the walker passes of `simplify_schema_references`; a copy with the same shape
# is stored on the returned schema under the `_DEFINITIONS_CACHE_METADATA_KEY` metadata key.
class _DefinitionsState(TypedDict):
# ref -> the schema registered under that ref
definitions: dict[str, core_schema.CoreSchema]
# ref -> number of `definition-ref` schemas encountered that point at it
ref_counts: dict[str, int]
# ref -> True once the ref is encountered again while its own definition is still being walked
involved_in_recursion: dict[str, bool]
# ref -> incremented just before walking the ref's definition and decremented right after
current_recursion_ref_count: dict[str, int]

def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema:
definitions = list(defs)
if definitions:
return core_schema.definitions_schema(schema=schema, definitions=definitions)
return schema

def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
"""Simplify schema references by:
1. Inlining any definitions that are only referenced in one place and are not involved in a cycle.
2. Removing any unused `ref` references from schemas.
"""
state = _DefinitionsState(
definitions={},
ref_counts=defaultdict(int),
involved_in_recursion={},
current_recursion_ref_count=defaultdict(int),
)

def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
state['definitions'].update(definitions_cache['definitions'])
return s

if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
all_defs[ref] = recurse(definition, collect_refs)
state['definitions'][ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
all_defs[ref] = s
return recurse(s, collect_refs)

schema = walk_core_schema(schema, collect_refs)

def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definitions':
# iterate ourselves, we don't want to flatten the actual defs!
definitions: list[CoreSchema] = s.pop('definitions') # type: ignore
schema: CoreSchema = s.pop('schema') # type: ignore
# remaining keys are optional like 'serialization'
schema: CoreSchema = {**schema, **s} # type: ignore
s['schema'] = recurse(schema, flatten_refs)
for definition in definitions:
recurse(definition, flatten_refs) # don't re-assign here!
return schema
else:
s = recurse(s, flatten_refs)
ref = get_ref(s)
if ref and ref in all_defs:
all_defs[ref] = s
state['definitions'][ref] = s
recurse(s, collect_refs)
return core_schema.definition_reference_schema(schema_ref=ref)
return s

schema = walk_core_schema(schema, flatten_refs)

for def_schema in all_defs.values():
walk_core_schema(def_schema, flatten_refs)

if not inline:
return make_result(schema, all_defs.values())
else:
return recurse(s, collect_refs)

ref_counts: defaultdict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: defaultdict[str, int] = defaultdict(int)
schema = walk_core_schema(schema, collect_refs)

def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
for ref in definitions_cache['ref_counts']:
state['ref_counts'][ref] += definitions_cache['ref_counts'][ref]
# it's possible that a schema was seen before we hit the cache
# and also exists in the cache, in which case it is involved in a recursion
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
# if it's involved in recursion in the inner schema mark it globally as involved in a recursion
for ref_in_recursion in definitions_cache['involved_in_recursion']:
if ref_in_recursion:
state['involved_in_recursion'][ref_in_recursion] = True
return s

if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
state['ref_counts'][ref] += 1

if ref_counts[ref] >= 2:
if state['ref_counts'][ref] >= 2:
# If this model is involved in a recursion this should be detected
# on its second encounter, we can safely stop the walk here.
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
return s

current_recursion_ref_count[ref] += 1
recurse(all_defs[ref], count_refs)
current_recursion_ref_count[ref] -= 1
state['current_recursion_ref_count'][ref] += 1
recurse(state['definitions'][ref], count_refs)
state['current_recursion_ref_count'][ref] -= 1
return s

schema = walk_core_schema(schema, count_refs)

assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
assert all(c == 0 for c in state['current_recursion_ref_count'].values()), 'this is a bug! please report it'

definitions_cache = _DefinitionsState(
definitions=state['definitions'],
ref_counts=dict(state['ref_counts']),
involved_in_recursion=state['involved_in_recursion'],
current_recursion_ref_count=dict(state['current_recursion_ref_count']),
)

# Decide whether the `definition-ref` schema `s`, which points at `ref`, may be replaced by the
# referenced definition itself. Returns True only when none of the checks below rejects it.
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
# referenced from more than one place: must stay a shared definition
if state['ref_counts'][ref] > 1:
return False
# part of a reference cycle: inlining it would never terminate
if state['involved_in_recursion'].get(ref, False):
return False
# the def-ref schema itself carries a 'serialization' key: keep the reference as-is
if 'serialization' in s:
return False
if 'metadata' in s:
metadata = s['metadata']
# any of these metadata keys on the def-ref schema also blocks inlining
for k in (
'pydantic_js_functions',
'pydantic_js_annotation_functions',
'pydantic.internal.union_discriminator',
):
if k in metadata:
# we need to keep this as a ref
return False
return True

def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definition-ref':
ref = s['schema_ref']
# Check if the reference is only used once and not involved in recursion
if ref_counts[ref] <= 1 and not involved_in_recursion.get(ref, False):
# Check if the reference is only used once, not involved in recursion and does not have
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = all_defs.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
new.pop('ref') # type: ignore
new = state['definitions'].pop(ref)
state['ref_counts'][ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
Expand All @@ -515,23 +546,12 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core

schema = walk_core_schema(schema, inline_refs)

definitions = [d for d in all_defs.values() if ref_counts[d['ref']] > 0] # type: ignore
return make_result(schema, definitions)
definitions = [d for d in state['definitions'].values() if state['ref_counts'][d['ref']] > 0] # type: ignore


def flatten_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Simplify schema references by:
1. Grouping all definitions into a single top-level `definitions` schema, similar to a JSON schema's `#/$defs`.
"""
return _simplify_schema_references(schema, inline=False)


def inline_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Simplify schema references by:
1. Inlining any definitions that are only referenced in one place and are not involved in a cycle.
2. Removing any unused `ref` references from schemas.
"""
return _simplify_schema_references(schema, inline=True)
if definitions:
schema = core_schema.definitions_schema(schema=schema, definitions=definitions)
schema.setdefault('metadata', {})[_DEFINITIONS_CACHE_METADATA_KEY] = definitions_cache # type: ignore
return schema


def pretty_print_core_schema(
Expand Down
12 changes: 6 additions & 6 deletions pydantic/_internal/_dataclasses.py
Expand Up @@ -16,7 +16,7 @@
from ..fields import FieldInfo
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _discriminated_union, _typing_extra
from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs
from ._core_utils import collect_invalid_schemas, simplify_schema_references
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
Expand Down Expand Up @@ -152,19 +152,19 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -
core_config = config_wrapper.core_config(cls)

schema = gen_schema.collect_definitions(schema)
schema = flatten_schema_defs(schema)
if collect_invalid_schemas(schema):
set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types')
return False

schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))

# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls.__pydantic_core_schema__ = schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_core_schema = inline_schema_defs(schema)
cls.__pydantic_validator__ = validator = SchemaValidator(simplified_core_schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(simplified_core_schema, core_config)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = SchemaValidator(schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

if config_wrapper.validate_assignment:

Expand Down
3 changes: 0 additions & 3 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -52,7 +52,6 @@
from ._core_utils import (
CoreSchemaOrField,
define_expected_missing_refs,
flatten_schema_defs,
get_type_ref,
is_list_like_schema_with_items_schema,
)
Expand Down Expand Up @@ -518,8 +517,6 @@ def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
def get_ref(s: CoreSchema) -> str:
return s['ref'] # type: ignore

schema = flatten_schema_defs(schema)

if schema['type'] == 'definitions':
self.defs.definitions.update({get_ref(s): s for s in schema['definitions']})
schema = schema['schema']
Expand Down
10 changes: 5 additions & 5 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -16,7 +16,7 @@
from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr
from ..warnings import PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs
from ._core_utils import collect_invalid_schemas, simplify_schema_references
from ._decorators import (
ComputedFieldInfo,
DecoratorInfos,
Expand Down Expand Up @@ -487,16 +487,16 @@ def complete_model_class(
core_config = config_wrapper.core_config(cls)

schema = gen_schema.collect_definitions(schema)
schema = apply_discriminators(flatten_schema_defs(schema))
if collect_invalid_schemas(schema):
set_model_mocks(cls, cls_name)
return False

schema = apply_discriminators(simplify_schema_references(schema))

# debug(schema)
cls.__pydantic_core_schema__ = schema
simplified_core_schema = inline_schema_defs(schema)
cls.__pydantic_validator__ = SchemaValidator(simplified_core_schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(simplified_core_schema, core_config)
cls.__pydantic_validator__ = SchemaValidator(schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True

# set __signature__ attr only for model class, but not for its instances
Expand Down
21 changes: 10 additions & 11 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -10,7 +10,7 @@
from ..config import ConfigDict
from . import _discriminated_union, _generate_schema, _typing_extra
from ._config import ConfigWrapper
from ._core_utils import flatten_schema_defs, inline_schema_defs
from ._core_utils import simplify_schema_references


@dataclass
Expand Down Expand Up @@ -61,11 +61,12 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
self.__pydantic_core_schema__ = schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
schema = simplify_schema_references(schema)
self.__pydantic_core_schema__ = schema = schema
core_config = config_wrapper.core_config(self)
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_schema = inline_schema_defs(schema)
self.__pydantic_validator__ = pydantic_core.SchemaValidator(simplified_schema, core_config)
schema = _discriminated_union.apply_discriminators(schema)
self.__pydantic_validator__ = pydantic_core.SchemaValidator(schema, core_config)

if self._validate_return:
return_type = (
Expand All @@ -74,13 +75,11 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
else Any
)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
self.__return_pydantic_core_schema__ = schema = gen_schema.collect_definitions(
gen_schema.generate_schema(return_type)
)
schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type))
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
self.__return_pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_schema = inline_schema_defs(schema)
validator = pydantic_core.SchemaValidator(simplified_schema, core_config)
validator = pydantic_core.SchemaValidator(schema, core_config)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
Expand Down
7 changes: 3 additions & 4 deletions pydantic/type_adapter.py
Expand Up @@ -166,21 +166,20 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth
except AttributeError:
core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1)

core_schema = _discriminated_union.apply_discriminators(_core_utils.flatten_schema_defs(core_schema))
simplified_core_schema = _core_utils.inline_schema_defs(core_schema)
core_schema = _discriminated_union.apply_discriminators(_core_utils.simplify_schema_references(core_schema))

core_config = config_wrapper.core_config(None)
validator: SchemaValidator
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
validator = SchemaValidator(simplified_core_schema, core_config)
validator = SchemaValidator(core_schema, core_config)

serializer: SchemaSerializer
try:
serializer = _getattr_no_parents(type, '__pydantic_serializer__')
except AttributeError:
serializer = SchemaSerializer(simplified_core_schema, core_config)
serializer = SchemaSerializer(core_schema, core_config)

self.core_schema = core_schema
self.validator = validator
Expand Down