Skip to content

Commit

Permalink
Remove schema building caches (#7624)
Browse files Browse the repository at this point in the history
Co-authored-by: Serge Matveenko <lig@pydantic.dev>
  • Loading branch information
adriangb and lig committed Sep 26, 2023
1 parent 734d3f9 commit 38bc2da
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 120 deletions.
163 changes: 59 additions & 104 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -14,7 +14,7 @@

from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeAliasType, TypedDict, TypeGuard, get_args
from typing_extensions import TypeAliasType, TypeGuard, get_args

from . import _repr

Expand Down Expand Up @@ -128,12 +128,6 @@ def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema
defs: dict[str, CoreSchema] = {}

def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
defs.update(definitions_cache['definitions'])
return s

ref = get_ref(s)
if ref:
defs[ref] = s
Expand Down Expand Up @@ -215,7 +209,7 @@ def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchem
return f(schema, self._walk)

def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema, f)
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
Expand Down Expand Up @@ -436,101 +430,62 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema, _dispatch)


class _DefinitionsState(TypedDict):
definitions: dict[str, core_schema.CoreSchema]
ref_counts: dict[str, int]
involved_in_recursion: dict[str, bool]
current_recursion_ref_count: dict[str, int]
return f(schema.copy(), _dispatch)


def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
"""Simplify schema references by:
1. Inlining any definitions that are only referenced in one place and are not involved in a cycle.
2. Removing any unused `ref` references from schemas.
"""
state = _DefinitionsState(
definitions={},
ref_counts=defaultdict(int),
involved_in_recursion={},
current_recursion_ref_count=defaultdict(int),
)
definitions: dict[str, core_schema.CoreSchema] = {}
ref_counts: dict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)

def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
state['definitions'].update(definitions_cache['definitions'])
return s

if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
state['definitions'][ref] = definition
definitions[ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
state['definitions'][ref] = s
recurse(s, collect_refs)
new = recurse(s, collect_refs)
new_ref = get_ref(new)
if new_ref:
definitions[new_ref] = new
return core_schema.definition_reference_schema(schema_ref=ref)
else:
return recurse(s, collect_refs)

schema = walk_core_schema(schema, collect_refs)

def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
for ref in definitions_cache['ref_counts']:
state['ref_counts'][ref] += definitions_cache['ref_counts'][ref]
# it's possible that a schema was seen before we hit the cache
# and also exists in the cache, in which case it is involved in a recursion
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
# if it's involved in recursion in the inner schema mark it globally as involved in a recursion
for ref_in_recursion in definitions_cache['involved_in_recursion']:
if ref_in_recursion:
state['involved_in_recursion'][ref_in_recursion] = True
return s

if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
state['ref_counts'][ref] += 1
ref_counts[ref] += 1

if state['ref_counts'][ref] >= 2:
if ref_counts[ref] >= 2:
# If this model is involved in a recursion this should be detected
# on its second encounter, we can safely stop the walk here.
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
return s

state['current_recursion_ref_count'][ref] += 1
recurse(state['definitions'][ref], count_refs)
state['current_recursion_ref_count'][ref] -= 1
current_recursion_ref_count[ref] += 1
recurse(definitions[ref], count_refs)
current_recursion_ref_count[ref] -= 1
return s

schema = walk_core_schema(schema, count_refs)

assert all(c == 0 for c in state['current_recursion_ref_count'].values()), 'this is a bug! please report it'

definitions_cache = _DefinitionsState(
definitions=state['definitions'],
ref_counts=dict(state['ref_counts']),
involved_in_recursion=state['involved_in_recursion'],
current_recursion_ref_count=dict(state['current_recursion_ref_count']),
)
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'

def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
if state['ref_counts'][ref] > 1:
if ref_counts[ref] > 1:
return False
if state['involved_in_recursion'].get(ref, False):
if involved_in_recursion.get(ref, False):
return False
if 'serialization' in s:
return False
Expand All @@ -553,8 +508,8 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = state['definitions'].pop(ref)
state['ref_counts'][ref] -= 1 # because we just replaced it!
new = definitions.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
Expand All @@ -568,17 +523,44 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core

schema = walk_core_schema(schema, inline_refs)

definitions = [d for d in state['definitions'].values() if state['ref_counts'][d['ref']] > 0] # type: ignore
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore

if definitions:
schema = core_schema.definitions_schema(schema=schema, definitions=definitions)
if 'metadata' in schema:
schema['metadata'][_DEFINITIONS_CACHE_METADATA_KEY] = definitions_cache
else:
schema['metadata'] = {_DEFINITIONS_CACHE_METADATA_KEY: definitions_cache}
if def_values:
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
return schema


def _strip_metadata(schema: CoreSchema) -> CoreSchema:
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
s = s.copy()
s.pop('metadata', None)
if s['type'] == 'model-fields':
s = s.copy()
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
for field_name, field_schema in s['fields'].items():
field_schema.pop('metadata', None)
s['fields'][field_name] = field_schema
computed_fields = s.get('computed_fields', None)
if computed_fields:
s['computed_fields'] = [cf.copy() for cf in computed_fields]
for cf in computed_fields:
cf.pop('metadata', None)
else:
s.pop('computed_fields', None)
elif s['type'] == 'model':
# remove some defaults
if s.get('custom_init', True) is False:
s.pop('custom_init')
if s.get('root_model', True) is False:
s.pop('root_model')
if {'title'}.issuperset(s.get('config', {}).keys()):
s.pop('config', None)

return recurse(s, strip_metadata)

return walk_core_schema(schema, strip_metadata)


def pretty_print_core_schema(
schema: CoreSchema,
include_metadata: bool = False,
Expand All @@ -593,34 +575,7 @@ def pretty_print_core_schema(
from rich import print # type: ignore # install it manually in your dev env

if not include_metadata:

def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
s.pop('metadata', None)
if s['type'] == 'model-fields':
s = s.copy()
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
for field_name, field_schema in s['fields'].items():
field_schema.pop('metadata', None)
s['fields'][field_name] = field_schema
computed_fields = s.get('computed_fields', None)
if computed_fields:
s['computed_fields'] = [cf.copy() for cf in computed_fields]
for cf in computed_fields:
cf.pop('metadata', None)
else:
s.pop('computed_fields', None)
elif s['type'] == 'model':
# remove some defaults
if s.get('custom_init', True) is False:
s.pop('custom_init')
if s.get('root_model', True) is False:
s.pop('root_model')
if {'title'}.issuperset(s.get('config', {}).keys()):
s.pop('config')

return recurse(s, strip_metadata)

schema = walk_core_schema(schema, strip_metadata)
schema = _strip_metadata(schema)

return print(schema)

Expand Down
20 changes: 4 additions & 16 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -54,6 +54,7 @@
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
define_expected_missing_refs,
get_ref,
get_type_ref,
is_list_like_schema_with_items_schema,
)
Expand Down Expand Up @@ -396,6 +397,8 @@ def collect_definitions(self, schema: CoreSchema) -> CoreSchema:
ref = cast('str | None', schema.get('ref', None))
if ref:
self.defs.definitions[ref] = schema
if 'ref' in schema:
schema = core_schema.definition_reference_schema(schema['ref'])
return core_schema.definitions_schema(
schema,
list(self.defs.definitions.values()),
Expand Down Expand Up @@ -554,19 +557,6 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
self.defs.definitions[model_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(model_ref)

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
"""Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions`
and return the inner schema.
"""

def get_ref(s: CoreSchema) -> str:
return s['ref'] # type: ignore

if schema['type'] == 'definitions':
self.defs.definitions.update({get_ref(s): s for s in schema['definitions']})
schema = schema['schema']
return schema

def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None:
"""Try to generate schema from either the `__get_pydantic_core_schema__` function or
`__pydantic_core_schema__` property.
Expand Down Expand Up @@ -603,9 +593,7 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode)
)

schema = self._unpack_refs_defs(schema)

ref: str | None = schema.get('ref', None)
ref = get_ref(schema)
if ref:
self.defs.definitions[ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(ref)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_edge_cases.py
Expand Up @@ -2612,3 +2612,25 @@ def __exit__(self, _exception_type, exception, exception_traceback):
MyModel(**data)

assert len(traceback_exceptions) == 1


def test_recursive_walk_fails_on_double_diamond_composition():
class A(BaseModel):
pass

class B(BaseModel):
a_1: A
a_2: A

class C(BaseModel):
b: B

class D(BaseModel):
c_1: C
c_2: C

class E(BaseModel):
c: C

# This is just to check that above model contraption doesn't fail
assert E(c=C(b=B(a_1=A(), a_2=A()))).model_dump() == {'c': {'b': {'a_1': {}, 'a_2': {}}}}

0 comments on commit 38bc2da

Please sign in to comment.