Skip to content

Commit

Permalink
Cache invalid schemas during CoreSchema post processing
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 21, 2023
1 parent 198c8c6 commit 876a058
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 36 deletions.
46 changes: 29 additions & 17 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -42,6 +42,14 @@
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'

NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
schema building because one of its members refers to a definition that was not yet defined when the union
was first encountered.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
"""


def is_core_schema(
Expand Down Expand Up @@ -136,43 +144,47 @@ def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_sche

def define_expected_missing_refs(
    schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema | None:
    """Add placeholder definitions for refs that are allowed to be missing from `schema`.

    Args:
        schema: The schema to inspect.
        allowed_missing_refs: Refs that may legitimately be undefined at this point.

    Returns:
        A definitions schema wrapping `schema` together with one placeholder per ref from
        `allowed_missing_refs` that `collect_definitions(schema)` does not report, or `None`
        when no placeholder had to be added (callers keep using the schema they passed in).
    """
    if not allowed_missing_refs:
        # in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
        # this is a common case (will be hit for all non-generic models), so it's worth optimizing for
        return None

    defined_refs = collect_definitions(schema).keys()
    expected_missing_refs = allowed_missing_refs.difference(defined_refs)
    if not expected_missing_refs:
        return None

    placeholders: list[core_schema.CoreSchema] = [
        # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
        #   Issue: https://github.com/pydantic/pydantic-core/issues/619
        core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
        for ref in expected_missing_refs
    ]
    return core_schema.definitions_schema(schema, placeholders)


def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
    """Report whether `schema` contains a sub-schema flagged as invalid, and cache the answer.

    A sub-schema whose metadata already carries `HAS_INVALID_SCHEMAS_METADATA_KEY` is treated
    as a cached verdict for its whole subtree, so the walk does not descend into it.

    Args:
        schema: The schema to inspect. Its metadata is updated in place with the result
            (a metadata dict is created if the schema has none).

    Returns:
        `True` if any visited schema is flagged as invalid, otherwise `False`.
    """
    invalid = False

    def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        nonlocal invalid
        if 'metadata' in s:
            metadata = s['metadata']
            if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
                # Only ever flip the result to True: a later sibling that is cached as valid
                # must not erase an invalid schema that was found earlier in the walk.
                if metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]:
                    invalid = True
                return s
        return recurse(s, _is_schema_valid)

    walk_core_schema(schema, _is_schema_valid)
    # Cache the verdict on the root so a later call can return without walking again.
    schema.setdefault('metadata', {})[HAS_INVALID_SCHEMAS_METADATA_KEY] = invalid
    return invalid


T = TypeVar('T')
Expand Down
13 changes: 6 additions & 7 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -41,13 +41,12 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', None)
if metadata is not None:
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s

return _core_utils.walk_core_schema(schema, inner)
Expand Down
29 changes: 25 additions & 4 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -50,6 +50,7 @@
build_metadata_dict,
)
from ._core_utils import (
HAS_INVALID_SCHEMAS_METADATA_KEY,
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
define_expected_missing_refs,
Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(
self._types_namespace = types_namespace
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
self.defs = _Definitions()

@classmethod
Expand All @@ -289,6 +291,7 @@ def __from_parent(
obj._types_namespace = types_namespace
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
obj.defs = defs
return obj

Expand Down Expand Up @@ -514,7 +517,13 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
model_name=cls.__name__,
)
inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None)
inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs())
new_inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs())
if new_inner_schema is not None:
inner_schema = new_inner_schema
self._has_invalid_schema = True
metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = True
else:
metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = False
inner_schema = apply_model_validators(inner_schema, model_validators, 'inner')

model_schema = core_schema.model_schema(
Expand Down Expand Up @@ -648,14 +657,26 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:

def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """Record this generator's bookkeeping flags in the metadata of `schema`.

    Both `_needs_apply_discriminated_union` and `_has_invalid_schema` are written under their
    metadata keys; a metadata dict is created when the schema does not have one yet.

    Args:
        schema: The freshly generated schema; it is mutated in place.

    Returns:
        The same `schema` object.
    """
    metadata = schema.setdefault('metadata', {})
    metadata[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = self._needs_apply_discriminated_union
    metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = self._has_invalid_schema
    return schema

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
    """Recursively generate a pydantic-core schema for any supported python type.

    The two bookkeeping flags are reset before generating so that the values stamped onto the
    returned schema describe only that schema's own subtree. Afterwards they are OR-ed back
    with the values the caller had accumulated.
    """
    outer_has_invalid_schema = self._has_invalid_schema
    outer_needs_apply_discriminated_union = self._needs_apply_discriminated_union
    self._has_invalid_schema = False
    self._needs_apply_discriminated_union = False
    try:
        return self._post_process_generated_schema(self._generate_schema_inner(obj))
    finally:
        # Restore in `finally` so the caller's accumulated flags survive a failed generation.
        self._has_invalid_schema = self._has_invalid_schema or outer_has_invalid_schema
        self._needs_apply_discriminated_union = (
            self._needs_apply_discriminated_union or outer_needs_apply_discriminated_union
        )

def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
if isinstance(obj, _AnnotatedType):
Expand Down
9 changes: 2 additions & 7 deletions tests/test_generics.py
Expand Up @@ -1783,7 +1783,7 @@ class M2(BaseModel, Generic[V3]):
M1 = module.M1

# assert M1.__pydantic_core_schema__ == {}
assert collect_invalid_schemas(M1.__pydantic_core_schema__) == []
assert collect_invalid_schemas(M1.__pydantic_core_schema__) is False


def test_generic_recursive_models_complicated(create_module):
Expand Down Expand Up @@ -1843,7 +1843,7 @@ class M2(BaseModel, Generic[V3]):

M1 = module.M1

assert collect_invalid_schemas(M1.__pydantic_core_schema__) == []
assert collect_invalid_schemas(M1.__pydantic_core_schema__) is False


def test_generic_recursive_models_in_container(create_module):
Expand All @@ -1864,11 +1864,6 @@ class MyGenericModel(BaseModel, Generic[T]):
assert type(instance.foobar[0]) == MyGenericModel[int]


def test_schema_is_valid():
assert not collect_invalid_schemas(core_schema.none_schema())
assert collect_invalid_schemas(core_schema.nullable_schema(core_schema.int_schema(metadata={'invalid': True})))


def test_generic_enum():
T = TypeVar('T')

Expand Down
16 changes: 15 additions & 1 deletion tests/test_internal.py
Expand Up @@ -7,7 +7,13 @@
from pydantic_core import CoreSchema, SchemaValidator
from pydantic_core import core_schema as cs

from pydantic._internal._core_utils import Walk, simplify_schema_references, walk_core_schema
from pydantic._internal._core_utils import (
HAS_INVALID_SCHEMAS_METADATA_KEY,
Walk,
collect_invalid_schemas,
simplify_schema_references,
walk_core_schema,
)
from pydantic._internal._repr import Representation


Expand Down Expand Up @@ -192,3 +198,11 @@ class Obj(Representation):
' ) (Obj)',
]
assert list(obj.__rich_repr__()) == [('int_attr', 42), ('str_attr', 'Marvin')]


def test_schema_is_valid():
    """An unflagged schema is reported valid; a nested schema flagged as invalid is detected."""
    unflagged = cs.none_schema()
    assert collect_invalid_schemas(unflagged) is False

    flagged_inner = cs.int_schema(metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
    wrapper = cs.nullable_schema(flagged_inner)
    assert collect_invalid_schemas(wrapper) is True

0 comments on commit 876a058

Please sign in to comment.