Skip to content

Commit

Permalink
Improve cache based on PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 21, 2023
1 parent 388e62f commit 39f1ae6
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 33 deletions.
45 changes: 29 additions & 16 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -42,6 +42,14 @@
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'

NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
schema building because one of it's members refers to a definition that was not yet defined when the union
was first encountered.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
"""


def is_core_schema(
Expand Down Expand Up @@ -136,11 +144,11 @@ def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_sche

def define_expected_missing_refs(
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> tuple[core_schema.CoreSchema, bool]:
) -> core_schema.CoreSchema | None:
if not allowed_missing_refs:
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
return schema, False
return None

refs = collect_definitions(schema).keys()

Expand All @@ -149,29 +157,34 @@ def define_expected_missing_refs(
definitions: list[core_schema.CoreSchema] = [
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
# Issue: https://github.com/pydantic/pydantic-core/issues/619
core_schema.none_schema(ref=ref, metadata={'pydantic_debug_missing_ref': True, 'invalid': True})
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
for ref in expected_missing_refs
]
return core_schema.definitions_schema(schema, definitions), True
return schema, False
return core_schema.definitions_schema(schema, definitions)
return None


def collect_invalid_schemas(schema: core_schema.CoreSchema) -> list[core_schema.CoreSchema]:
invalid_schemas: list[core_schema.CoreSchema] = []
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
invalid = False

def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
metadata = s.get('metadata', None)
if metadata is None:
return recurse(s, _is_schema_valid)
invalid = metadata.get('invalid', None)
if invalid is False:
return s
elif invalid is True:
invalid_schemas.append(s)
nonlocal invalid
if 'metadata' in s:
metadata = s['metadata']
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
if invalid is True:
invalid = True
return s
return recurse(s, _is_schema_valid)

walk_core_schema(schema, _is_schema_valid)
return invalid_schemas
if 'metadata' in schema:
metadata = schema['metadata']
metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = invalid
else:
schema['metadata'] = {HAS_INVALID_SCHEMAS_METADATA_KEY: invalid}
return invalid


T = TypeVar('T')
Expand Down
18 changes: 9 additions & 9 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -50,6 +50,7 @@
build_metadata_dict,
)
from ._core_utils import (
HAS_INVALID_SCHEMAS_METADATA_KEY,
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
define_expected_missing_refs,
Expand Down Expand Up @@ -513,15 +514,14 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
model_name=cls.__name__,
)
inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None)
inner_schema, has_invalid_schema = define_expected_missing_refs(
inner_schema, recursively_defined_type_refs()
)
inner_schema = apply_model_validators(inner_schema, model_validators, 'inner')

if has_invalid_schema:
new_inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs())
if new_inner_schema is not None:
inner_schema = new_inner_schema
self._has_invalid_schema = True

metadata['invalid'] = has_invalid_schema
metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = True
else:
metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = False
inner_schema = apply_model_validators(inner_schema, model_validators, 'inner')

model_schema = core_schema.model_schema(
cls,
Expand Down Expand Up @@ -655,7 +655,7 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:
def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
metadata = schema.setdefault('metadata', {})
metadata[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = self._needs_apply_discriminated_union
metadata['invalid'] = self._has_invalid_schema
metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = self._has_invalid_schema
return schema

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
Expand Down
9 changes: 2 additions & 7 deletions tests/test_generics.py
Expand Up @@ -1782,7 +1782,7 @@ class M2(BaseModel, Generic[V3]):
M1 = module.M1

# assert M1.__pydantic_core_schema__ == {}
assert collect_invalid_schemas(M1.__pydantic_core_schema__) == []
assert collect_invalid_schemas(M1.__pydantic_core_schema__) is False


def test_generic_recursive_models_complicated(create_module):
Expand Down Expand Up @@ -1842,7 +1842,7 @@ class M2(BaseModel, Generic[V3]):

M1 = module.M1

assert collect_invalid_schemas(M1.__pydantic_core_schema__) == []
assert collect_invalid_schemas(M1.__pydantic_core_schema__) is False


def test_generic_recursive_models_in_container(create_module):
Expand All @@ -1863,11 +1863,6 @@ class MyGenericModel(BaseModel, Generic[T]):
assert type(instance.foobar[0]) == MyGenericModel[int]


def test_schema_is_valid():
assert not collect_invalid_schemas(core_schema.none_schema())
assert collect_invalid_schemas(core_schema.nullable_schema(core_schema.int_schema(metadata={'invalid': True})))


def test_generic_enum():
T = TypeVar('T')

Expand Down
16 changes: 15 additions & 1 deletion tests/test_internal.py
Expand Up @@ -7,7 +7,13 @@
from pydantic_core import CoreSchema, SchemaValidator
from pydantic_core import core_schema as cs

from pydantic._internal._core_utils import Walk, simplify_schema_references, walk_core_schema
from pydantic._internal._core_utils import (
HAS_INVALID_SCHEMAS_METADATA_KEY,
Walk,
collect_invalid_schemas,
simplify_schema_references,
walk_core_schema,
)
from pydantic._internal._repr import Representation


Expand Down Expand Up @@ -192,3 +198,11 @@ class Obj(Representation):
' ) (Obj)',
]
assert list(obj.__rich_repr__()) == [('int_attr', 42), ('str_attr', 'Marvin')]


def test_schema_is_valid():
assert collect_invalid_schemas(cs.none_schema()) is False
assert (
collect_invalid_schemas(cs.nullable_schema(cs.int_schema(metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})))
is True
)

0 comments on commit 39f1ae6

Please sign in to comment.