Skip to content

Commit

Permalink
Use a stack for the types namespace (#8378)
Browse files Browse the repository at this point in the history
Co-authored-by: sydney-runkle <sydneymarierunkle@gmail.com>
  • Loading branch information
dmontagu and sydney-runkle committed Jan 10, 2024
1 parent 2f48249 commit 54a1576
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
60 changes: 41 additions & 19 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -272,12 +272,35 @@ def _add_custom_serialization_from_json_encoders(
return schema


TypesNamespace = Union[Dict[str, Any], None]


class TypesNamespaceStack:
"""A stack of types namespaces."""

def __init__(self, types_namespace: TypesNamespace):
self._types_namespace_stack: list[TypesNamespace] = [types_namespace]

@property
def tail(self) -> TypesNamespace:
return self._types_namespace_stack[-1]

@contextmanager
def push(self, for_type: type[Any]):
types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(self.tail or {})}
self._types_namespace_stack.append(types_namespace)
try:
yield
finally:
self._types_namespace_stack.pop()


class GenerateSchema:
"""Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... ."""

__slots__ = (
'_config_wrapper_stack',
'_types_namespace',
'_types_namespace_stack',
'_typevars_map',
'_needs_apply_discriminated_union',
'_has_invalid_schema',
Expand All @@ -293,7 +316,7 @@ def __init__(
) -> None:
# we need a stack for recursing into child models
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace = types_namespace
self._types_namespace_stack = TypesNamespaceStack(types_namespace)
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
Expand All @@ -304,13 +327,13 @@ def __init__(
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace: dict[str, Any] | None,
types_namespace_stack: TypesNamespaceStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace = types_namespace
obj._types_namespace_stack = types_namespace_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
Expand All @@ -322,12 +345,16 @@ def __from_parent(
def _config_wrapper(self) -> ConfigWrapper:
return self._config_wrapper_stack.tail

@property
def _types_namespace(self) -> dict[str, Any] | None:
return self._types_namespace_stack.tail

@property
def _current_generate_schema(self) -> GenerateSchema:
cls = self._config_wrapper.schema_generator or GenerateSchema
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace,
self._types_namespace_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -524,7 +551,7 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
extras_schema = self.generate_schema(extra_items_type)
break

with self._config_wrapper_stack.push(config_wrapper):
with self._config_wrapper_stack.push(config_wrapper), self._types_namespace_stack.push(cls):
self = self._current_generate_schema
if cls.__pydantic_root_model__:
root_field = self._common_field_schema('root', fields['root'], decorators)
Expand Down Expand Up @@ -1114,19 +1141,14 @@ def _type_alias_type_schema(

origin = get_origin(obj) or obj

namespace = (self._types_namespace or {}).copy()
new_namespace = {**_typing_extra.get_cls_types_namespace(origin), **namespace}
annotation = origin.__value__

self._types_namespace = new_namespace
typevars_map = get_standard_typevars_map(obj)

annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
schema['ref'] = ref # type: ignore
self._types_namespace = namespace or None
with self._types_namespace_stack.push(origin):
annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
schema['ref'] = ref # type: ignore
self.defs.definitions[ref] = schema
return core_schema.definition_reference_schema(ref)

Expand Down Expand Up @@ -1173,7 +1195,7 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
except AttributeError:
config = None

with self._config_wrapper_stack.push(config):
with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(typed_dict_cls):
core_config = self._config_wrapper.core_config(typed_dict_cls)

self = self._current_generate_schema
Expand Down Expand Up @@ -1427,7 +1449,7 @@ def _dataclass_schema(
dataclass = origin

config = getattr(dataclass, '__pydantic_config__', None)
with self._config_wrapper_stack.push(config):
with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(dataclass):
core_config = self._config_wrapper.core_config(dataclass)

self = self._current_generate_schema
Expand Down
20 changes: 20 additions & 0 deletions tests/test_create_model.py
Expand Up @@ -581,3 +581,23 @@ def test_json_schema_with_inner_models_with_duplicate_names():
'title': 'a',
'type': 'object',
}


def test_resolving_forward_refs_across_modules(create_module):
module = create_module(
# language=Python
"""\
from __future__ import annotations
from dataclasses import dataclass
from pydantic import BaseModel
class X(BaseModel):
pass
@dataclass
class Y:
x: X
"""
)
Z = create_model('Z', y=(module.Y, ...))
assert Z(y={'x': {}}).y is not None

0 comments on commit 54a1576

Please sign in to comment.