Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a stack for the types namespace #8378

Merged
merged 5 commits into from Jan 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
62 changes: 44 additions & 18 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -267,12 +267,36 @@ def _add_custom_serialization_from_json_encoders(
return schema


TypesNamespace = dict[str, Any] | None


class TypesNamespaceStack:
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
"""A stack of types namespaces."""

def __init__(self, types_namespace: TypesNamespace):
self._types_namespace_stack: list[TypesNamespace] = [types_namespace]

@property
def tail(self) -> TypesNamespace:
return self._types_namespace_stack[-1]

@contextmanager
def push(self, for_type: type[Any], existing_types_namespace: TypesNamespace):
# Not sure whether we should really preserve/extend the existing types namespace or not
types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(existing_types_namespace or {})}
self._types_namespace_stack.append(types_namespace)
try:
yield
finally:
self._types_namespace_stack.pop()


class GenerateSchema:
"""Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... ."""

__slots__ = (
'_config_wrapper_stack',
'_types_namespace',
'_types_namespace_stack',
'_typevars_map',
'_needs_apply_discriminated_union',
'_has_invalid_schema',
Expand All @@ -288,7 +312,7 @@ def __init__(
) -> None:
# we need a stack for recursing into child models
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace = types_namespace
self._types_namespace_stack = TypesNamespaceStack(types_namespace)
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
Expand All @@ -299,13 +323,13 @@ def __init__(
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace: dict[str, Any] | None,
types_namespace_stack: TypesNamespaceStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace = types_namespace
obj._types_namespace_stack = types_namespace_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
Expand All @@ -317,12 +341,16 @@ def __from_parent(
def _config_wrapper(self) -> ConfigWrapper:
return self._config_wrapper_stack.tail

@property
def _types_namespace(self) -> dict[str, Any] | None:
return self._types_namespace_stack.tail

@property
def _current_generate_schema(self) -> GenerateSchema:
cls = self._config_wrapper.schema_generator or GenerateSchema
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace,
self._types_namespace_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -709,7 +737,8 @@ def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
self._has_invalid_schema = False
needs_apply_discriminated_union = self._needs_apply_discriminated_union
self._needs_apply_discriminated_union = False
schema = self._post_process_generated_schema(self._generate_schema_inner(obj))
schema = self._generate_schema_inner(obj)
schema = self._post_process_generated_schema(schema)
self._has_invalid_schema = self._has_invalid_schema or has_invalid_schema
self._needs_apply_discriminated_union = self._needs_apply_discriminated_union or needs_apply_discriminated_union
return schema
Expand Down Expand Up @@ -1111,19 +1140,14 @@ def _type_alias_type_schema(

origin = get_origin(obj) or obj

namespace = (self._types_namespace or {}).copy()
new_namespace = {**_typing_extra.get_cls_types_namespace(origin), **namespace}
annotation = origin.__value__

self._types_namespace = new_namespace
typevars_map = get_standard_typevars_map(obj)

annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
schema['ref'] = ref # type: ignore
self._types_namespace = namespace or None
with self._types_namespace_stack.push(origin, self._types_namespace):
annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
schema['ref'] = ref # type: ignore
self.defs.definitions[ref] = schema
return core_schema.definition_reference_schema(ref)

Expand Down Expand Up @@ -1424,7 +1448,9 @@ def _dataclass_schema(
dataclass = origin

config = getattr(dataclass, '__pydantic_config__', None)
with self._config_wrapper_stack.push(config):
with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(
dataclass, self._types_namespace
):
core_config = self._config_wrapper.core_config(dataclass)

self = self._current_generate_schema
Expand Down