Simplify flattening and inlining of CoreSchema
adriangb committed Sep 20, 2023
1 parent 70d3c3e commit fcb97b4
Showing 11 changed files with 154 additions and 201 deletions.
164 changes: 92 additions & 72 deletions pydantic/_internal/_core_utils.py
@@ -5,15 +5,14 @@
Any,
Callable,
Hashable,
Iterable,
TypeVar,
Union,
_GenericAlias, # type: ignore
cast,
)

from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypeAliasType, TypeGuard, get_args
from typing_extensions import TypeAliasType, TypedDict, TypeGuard, get_args

from . import _repr

@@ -40,6 +39,8 @@
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}

_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'


def is_core_schema(
schema: CoreSchemaOrField,
@@ -416,92 +417,122 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor
return f(schema, _dispatch)


def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema: # noqa: C901
all_defs: dict[str, core_schema.CoreSchema] = {}
class _DefinitionsState(TypedDict):
definitions: dict[str, core_schema.CoreSchema]
ref_counts: dict[str, int]
involved_in_recursion: dict[str, bool]
current_recursion_ref_count: dict[str, int]

def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema:
definitions = list(defs)
if definitions:
return core_schema.definitions_schema(schema=schema, definitions=definitions)
return schema

def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
"""Simplify schema references by:
1. Inlining any definitions that are only referenced in one place and are not involved in a cycle.
2. Removing any unused `ref` references from schemas.
"""
state = _DefinitionsState(
definitions={},
ref_counts=defaultdict(int),
involved_in_recursion={},
current_recursion_ref_count=defaultdict(int),
)

def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
state['definitions'].update(definitions_cache['definitions'])
return s

if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
all_defs[ref] = recurse(definition, collect_refs)
state['definitions'][ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
all_defs[ref] = s
return recurse(s, collect_refs)

schema = walk_core_schema(schema, collect_refs)

def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definitions':
# iterate ourselves, we don't want to flatten the actual defs!
definitions: list[CoreSchema] = s.pop('definitions') # type: ignore
schema: CoreSchema = s.pop('schema') # type: ignore
# remaining keys are optional like 'serialization'
schema: CoreSchema = {**schema, **s} # type: ignore
s['schema'] = recurse(schema, flatten_refs)
for definition in definitions:
recurse(definition, flatten_refs) # don't re-assign here!
return schema
else:
s = recurse(s, flatten_refs)
ref = get_ref(s)
if ref and ref in all_defs:
all_defs[ref] = s
state['definitions'][ref] = s
recurse(s, collect_refs)
return core_schema.definition_reference_schema(schema_ref=ref)
return s

schema = walk_core_schema(schema, flatten_refs)

for def_schema in all_defs.values():
walk_core_schema(def_schema, flatten_refs)

if not inline:
return make_result(schema, all_defs.values())
else:
return recurse(s, collect_refs)

ref_counts: defaultdict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: defaultdict[str, int] = defaultdict(int)
schema = walk_core_schema(schema, collect_refs)

def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
for ref in definitions_cache['ref_counts']:
state['ref_counts'][ref] += definitions_cache['ref_counts'][ref]
# it's possible that a schema was seen before we hit the cache
# and also exists in the cache, in which case it is involved in a recursion
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
# if it's involved in recursion in the inner schema mark it globally as involved in a recursion
for ref_in_recursion in definitions_cache['involved_in_recursion']:
if ref_in_recursion:
state['involved_in_recursion'][ref_in_recursion] = True
return s

if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
state['ref_counts'][ref] += 1

if ref_counts[ref] >= 2:
if state['ref_counts'][ref] >= 2:
# If this model is involved in a recursion this should be detected
# on its second encounter, we can safely stop the walk here.
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
return s

current_recursion_ref_count[ref] += 1
recurse(all_defs[ref], count_refs)
current_recursion_ref_count[ref] -= 1
state['current_recursion_ref_count'][ref] += 1
recurse(state['definitions'][ref], count_refs)
state['current_recursion_ref_count'][ref] -= 1
return s

schema = walk_core_schema(schema, count_refs)

assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
assert all(c == 0 for c in state['current_recursion_ref_count'].values()), 'this is a bug! please report it'

definitions_cache = _DefinitionsState(
definitions=state['definitions'],
ref_counts=dict(state['ref_counts']),
involved_in_recursion=state['involved_in_recursion'],
current_recursion_ref_count=dict(state['current_recursion_ref_count']),
)

def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
if state['ref_counts'][ref] > 1:
return False
if state['involved_in_recursion'].get(ref, False):
return False
if 'serialization' in s:
return False
if 'metadata' in s:
metadata = s['metadata']
for k in (
'pydantic_js_functions',
'pydantic_js_annotation_functions',
'pydantic.internal.union_discriminator',
):
if k in metadata:
# we need to keep this as a ref
return False
return True

def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definition-ref':
ref = s['schema_ref']
# Check if the reference is only used once and not involved in recursion
if ref_counts[ref] <= 1 and not involved_in_recursion.get(ref, False):
# Check if the reference is only used once, not involved in recursion and does not have
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = all_defs.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
new.pop('ref') # type: ignore
new = state['definitions'].pop(ref)
state['ref_counts'][ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
@@ -515,23 +546,12 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core

schema = walk_core_schema(schema, inline_refs)

definitions = [d for d in all_defs.values() if ref_counts[d['ref']] > 0] # type: ignore
return make_result(schema, definitions)
definitions = [d for d in state['definitions'].values() if state['ref_counts'][d['ref']] > 0] # type: ignore


def flatten_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Simplify schema references by:
1. Grouping all definitions into a single top-level `definitions` schema, similar to a JSON schema's `#/$defs`.
"""
return _simplify_schema_references(schema, inline=False)


def inline_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Simplify schema references by:
1. Inlining any definitions that are only referenced in one place and are not involved in a cycle.
2. Removing any unused `ref` references from schemas.
"""
return _simplify_schema_references(schema, inline=True)
if definitions:
schema = core_schema.definitions_schema(schema=schema, definitions=definitions)
schema.setdefault('metadata', {})[_DEFINITIONS_CACHE_METADATA_KEY] = definitions_cache # type: ignore
return schema


def pretty_print_core_schema(
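The hunk above replaces the two-pass flatten_schema_defs / inline_schema_defs helpers with a single simplify_schema_references walk: it collects all definitions, counts definition-ref usages, and inlines every definition that is referenced exactly once, is not involved in a cycle, and carries no extra keys such as 'serialization' or JSON-schema metadata; the resulting _DefinitionsState is cached in the schema's metadata under 'pydantic.definitions_cache' so later passes over the same subtree can reuse it. A minimal sketch of the effect, using the public pydantic_core.core_schema builders (the internal import path is taken from the diff and may change between versions):

from pydantic_core import core_schema

from pydantic._internal._core_utils import simplify_schema_references  # internal helper

# 'single' is referenced once and is not recursive, so it should be inlined;
# 'shared' is referenced twice, so it has to stay behind a definition-ref.
schema = core_schema.definitions_schema(
    schema=core_schema.union_schema([
        core_schema.definition_reference_schema('single'),
        core_schema.definition_reference_schema('shared'),
        core_schema.definition_reference_schema('shared'),
    ]),
    definitions=[
        core_schema.int_schema(ref='single'),
        core_schema.str_schema(ref='shared'),
    ],
)

simplified = simplify_schema_references(schema)
# Expected shape: a definitions schema that still defines 'shared', wrapping a union
# whose first choice is the inlined int schema instead of a definition-ref.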
12 changes: 6 additions & 6 deletions pydantic/_internal/_dataclasses.py
@@ -16,7 +16,7 @@
from ..fields import FieldInfo
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _discriminated_union, _typing_extra
from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs
from ._core_utils import collect_invalid_schemas, simplify_schema_references
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
@@ -152,19 +152,19 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -
core_config = config_wrapper.core_config(cls)

schema = gen_schema.collect_definitions(schema)
schema = flatten_schema_defs(schema)
if collect_invalid_schemas(schema):
set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types')
return False

schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))

# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls.__pydantic_core_schema__ = schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_core_schema = inline_schema_defs(schema)
cls.__pydantic_validator__ = validator = SchemaValidator(simplified_core_schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(simplified_core_schema, core_config)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = SchemaValidator(schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

if config_wrapper.validate_assignment:

3 changes: 0 additions & 3 deletions pydantic/_internal/_generate_schema.py
@@ -52,7 +52,6 @@
from ._core_utils import (
CoreSchemaOrField,
define_expected_missing_refs,
flatten_schema_defs,
get_type_ref,
is_list_like_schema_with_items_schema,
)
@@ -518,8 +517,6 @@ def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
def get_ref(s: CoreSchema) -> str:
return s['ref'] # type: ignore

schema = flatten_schema_defs(schema)

if schema['type'] == 'definitions':
self.defs.definitions.update({get_ref(s): s for s in schema['definitions']})
schema = schema['schema']
10 changes: 5 additions & 5 deletions pydantic/_internal/_model_construction.py
@@ -16,7 +16,7 @@
from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr
from ..warnings import PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs
from ._core_utils import collect_invalid_schemas, simplify_schema_references
from ._decorators import (
ComputedFieldInfo,
DecoratorInfos,
@@ -487,16 +487,16 @@ def complete_model_class(
core_config = config_wrapper.core_config(cls)

schema = gen_schema.collect_definitions(schema)
schema = apply_discriminators(flatten_schema_defs(schema))
if collect_invalid_schemas(schema):
set_model_mocks(cls, cls_name)
return False

schema = apply_discriminators(simplify_schema_references(schema))

# debug(schema)
cls.__pydantic_core_schema__ = schema
simplified_core_schema = inline_schema_defs(schema)
cls.__pydantic_validator__ = SchemaValidator(simplified_core_schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(simplified_core_schema, core_config)
cls.__pydantic_validator__ = SchemaValidator(schema, core_config)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True

# set __signature__ attr only for model class, but not for its instances
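As with the dataclass path above, model construction now applies discriminators to the simplified schema and stores that single schema on the class, handing the same object to both SchemaValidator and SchemaSerializer rather than keeping an un-inlined __pydantic_core_schema__ plus a separate simplified copy. A rough sketch of the observable effect using only the public BaseModel API (the expected schema shape is an assumption, not captured output):

from pydantic import BaseModel

class Inner(BaseModel):
    x: int

class Outer(BaseModel):
    inner: Inner

# Inner is referenced exactly once and is not recursive, so under this commit its
# definition is expected to be inlined: the stored schema should be a plain 'model'
# schema rather than a 'definitions' wrapper around a definition-ref.
print(Outer.__pydantic_core_schema__['type'])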
21 changes: 10 additions & 11 deletions pydantic/_internal/_validate_call.py
@@ -10,7 +10,7 @@
from ..config import ConfigDict
from . import _discriminated_union, _generate_schema, _typing_extra
from ._config import ConfigWrapper
from ._core_utils import flatten_schema_defs, inline_schema_defs
from ._core_utils import simplify_schema_references


@dataclass
@@ -61,11 +61,12 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
self.__pydantic_core_schema__ = schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
schema = simplify_schema_references(schema)
self.__pydantic_core_schema__ = schema = schema
core_config = config_wrapper.core_config(self)
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_schema = inline_schema_defs(schema)
self.__pydantic_validator__ = pydantic_core.SchemaValidator(simplified_schema, core_config)
schema = _discriminated_union.apply_discriminators(schema)
self.__pydantic_validator__ = pydantic_core.SchemaValidator(schema, core_config)

if self._validate_return:
return_type = (
@@ -74,13 +75,11 @@
else Any
)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
self.__return_pydantic_core_schema__ = schema = gen_schema.collect_definitions(
gen_schema.generate_schema(return_type)
)
schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type))
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
self.__return_pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_schema = inline_schema_defs(schema)
validator = pydantic_core.SchemaValidator(simplified_schema, core_config)
validator = pydantic_core.SchemaValidator(schema, core_config)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
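validate_call follows the same pattern: the argument schema and, when return validation is enabled, the return schema are simplified once and the validators are built directly from those schemas. Caller-visible behaviour should be unchanged; a small usage sketch:

from pydantic import validate_call

@validate_call(validate_return=True)
def double(x: int) -> int:
    return x * 2

# Arguments are validated (and coerced in lax mode) before the call, and the
# return value is validated as well because validate_return=True.
assert double('2') == 4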
7 changes: 3 additions & 4 deletions pydantic/type_adapter.py
@@ -166,21 +166,20 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth
except AttributeError:
core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1)

core_schema = _discriminated_union.apply_discriminators(_core_utils.flatten_schema_defs(core_schema))
simplified_core_schema = _core_utils.inline_schema_defs(core_schema)
core_schema = _discriminated_union.apply_discriminators(_core_utils.simplify_schema_references(core_schema))

core_config = config_wrapper.core_config(None)
validator: SchemaValidator
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
validator = SchemaValidator(simplified_core_schema, core_config)
validator = SchemaValidator(core_schema, core_config)

serializer: SchemaSerializer
try:
serializer = _getattr_no_parents(type, '__pydantic_serializer__')
except AttributeError:
serializer = SchemaSerializer(simplified_core_schema, core_config)
serializer = SchemaSerializer(core_schema, core_config)

self.core_schema = core_schema
self.validator = validator
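TypeAdapter likewise stores the simplified schema on self.core_schema and reuses it for both the validator and the serializer instead of building a separate inlined copy. A brief sketch (the printed schema type is an expectation for this simple, non-recursive type, not captured output):

from pydantic import TypeAdapter

ta = TypeAdapter(list[int])
assert ta.validate_python(['1', 2]) == [1, 2]
# With no recursive or shared definitions there is nothing to keep in a
# 'definitions' wrapper, so the stored schema should be a plain 'list' schema.
print(ta.core_schema['type'])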
4 changes: 2 additions & 2 deletions tests/test_dataclasses.py
@@ -2070,8 +2070,8 @@ class GenericDataclass(Generic[T]):

# verify that generic parameters are showing up in the type ref for generic dataclasses
# this can probably be removed if the schema changes in some way that makes this part of the test fail
assert '[int:' in validator1.core_schema['schema']['schema_ref']
assert '[str:' in validator2.core_schema['schema']['schema_ref']
assert '[int:' in validator1.core_schema['ref']
assert '[str:' in validator2.core_schema['ref']

assert validator1.validate_python({'x': 1}).x == 1
assert validator2.validate_python({'x': 'hello world'}).x == 'hello world'
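The updated assertions reflect the inlining: before this commit the parametrized type ref of the generic dataclass was only reachable through an inner definition-ref (core_schema['schema']['schema_ref']); with single-use definitions inlined, the dataclass schema itself is the top-level schema and the ref sits at core_schema['ref']. Illustrative shapes only, hand-written and heavily abbreviated rather than captured output:

# Before: the stored schema wrapped a definition-ref to the parametrized dataclass.
before = {
    'type': 'definitions',
    'schema': {'type': 'definition-ref', 'schema_ref': 'GenericDataclass[int:...]'},
    'definitions': ['...'],
}

# After: the (inlined) dataclass schema is top-level and carries the ref directly.
after = {
    'type': 'dataclass',
    'ref': 'GenericDataclass[int:...]',
    # remaining dataclass schema keys omitted
}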
