Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a stack for the types namespace #8378

Merged
merged 5 commits into from Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 41 additions & 19 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -267,12 +267,35 @@ def _add_custom_serialization_from_json_encoders(
return schema


TypesNamespace = Union[Dict[str, Any], None]


class TypesNamespaceStack:
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
"""A stack of types namespaces."""

def __init__(self, types_namespace: TypesNamespace):
self._types_namespace_stack: list[TypesNamespace] = [types_namespace]

@property
def tail(self) -> TypesNamespace:
return self._types_namespace_stack[-1]

@contextmanager
def push(self, for_type: type[Any]):
types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(self.tail or {})}
self._types_namespace_stack.append(types_namespace)
try:
yield
finally:
self._types_namespace_stack.pop()


class GenerateSchema:
"""Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... ."""

__slots__ = (
'_config_wrapper_stack',
'_types_namespace',
'_types_namespace_stack',
'_typevars_map',
'_needs_apply_discriminated_union',
'_has_invalid_schema',
Expand All @@ -288,7 +311,7 @@ def __init__(
) -> None:
# we need a stack for recursing into child models
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace = types_namespace
self._types_namespace_stack = TypesNamespaceStack(types_namespace)
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
Expand All @@ -299,13 +322,13 @@ def __init__(
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace: dict[str, Any] | None,
types_namespace_stack: TypesNamespaceStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace = types_namespace
obj._types_namespace_stack = types_namespace_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
Expand All @@ -317,12 +340,16 @@ def __from_parent(
def _config_wrapper(self) -> ConfigWrapper:
return self._config_wrapper_stack.tail

@property
def _types_namespace(self) -> dict[str, Any] | None:
return self._types_namespace_stack.tail

@property
def _current_generate_schema(self) -> GenerateSchema:
cls = self._config_wrapper.schema_generator or GenerateSchema
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace,
self._types_namespace_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -526,7 +553,7 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
extras_schema = self.generate_schema(extra_items_type)
break

with self._config_wrapper_stack.push(config_wrapper):
with self._config_wrapper_stack.push(config_wrapper), self._types_namespace_stack.push(cls):
self = self._current_generate_schema
if cls.__pydantic_root_model__:
root_field = self._common_field_schema('root', fields['root'], decorators)
Expand Down Expand Up @@ -1111,19 +1138,14 @@ def _type_alias_type_schema(

origin = get_origin(obj) or obj

namespace = (self._types_namespace or {}).copy()
new_namespace = {**_typing_extra.get_cls_types_namespace(origin), **namespace}
annotation = origin.__value__

self._types_namespace = new_namespace
typevars_map = get_standard_typevars_map(obj)

annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
schema['ref'] = ref # type: ignore
self._types_namespace = namespace or None
with self._types_namespace_stack.push(origin):
annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
schema['ref'] = ref # type: ignore
self.defs.definitions[ref] = schema
return core_schema.definition_reference_schema(ref)

Expand Down Expand Up @@ -1170,7 +1192,7 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
except AttributeError:
config = None

with self._config_wrapper_stack.push(config):
with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(typed_dict_cls):
core_config = self._config_wrapper.core_config(typed_dict_cls)

self = self._current_generate_schema
Expand Down Expand Up @@ -1424,7 +1446,7 @@ def _dataclass_schema(
dataclass = origin

config = getattr(dataclass, '__pydantic_config__', None)
with self._config_wrapper_stack.push(config):
with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(dataclass):
core_config = self._config_wrapper.core_config(dataclass)

self = self._current_generate_schema
Expand Down
20 changes: 20 additions & 0 deletions tests/test_create_model.py
Expand Up @@ -548,3 +548,23 @@ def test_json_schema_with_inner_models_with_duplicate_names():
'title': 'a',
'type': 'object',
}


def test_resolving_forward_refs_across_modules(create_module):
module = create_module(
# language=Python
"""\
from __future__ import annotations
from dataclasses import dataclass
from pydantic import BaseModel

class X(BaseModel):
pass

@dataclass
class Y:
x: X
"""
)
Z = create_model('Z', y=(module.Y, ...))
assert Z(y={'x': {}}).y is not None