Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PrivateAttr is passed from Annotated default position #8004

Merged
merged 9 commits into from Nov 3, 2023
85 changes: 63 additions & 22 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 47 additions & 13 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -9,20 +9,29 @@
from types import FunctionType
from typing import Any, Callable, Generic, Mapping

import typing_extensions
from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import dataclass_transform, deprecated

from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._decorators import (
DecoratorInfos,
PydanticDescriptorProxy,
get_attribute_from_bases,
)
from ._fields import (
collect_model_fields,
is_valid_field_name,
is_valid_privateattr_name,
)
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute
from ._validate_call import ValidateCallWrapper

Expand Down Expand Up @@ -84,7 +93,11 @@ def __new__(
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
# call we're in the middle of is for the `BaseModel` class.
if bases:
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)
(
base_field_names,
class_vars,
base_private_attributes,
) = mcs._collect_bases_data(bases)

config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
namespace['model_config'] = config_wrapper.config_dict
Expand All @@ -108,7 +121,10 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
namespace['model_post_init'] = init_private_attributes

namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
namespace['__private_attributes__'] = {
**base_private_attributes,
**private_attributes,
}
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved

if config_wrapper.frozen:
set_default_hash_func(namespace, bases)
Expand Down Expand Up @@ -163,7 +179,9 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
'parameters': parameters,
}

cls.__pydantic_complete__ = False # Ensure this specific class gets completed
cls.__pydantic_complete__ = (
False # Ensure this specific class gets completed
)

# preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
# for attributes not in `new_namespace` (e.g. private attributes)
Expand Down Expand Up @@ -241,10 +259,14 @@ def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str

@property
@deprecated(
'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
category=PydanticDeprecatedSince20,
)
def __fields__(self) -> dict[str, FieldInfo]:
warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning)
warnings.warn(
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
DeprecationWarning,
)
return self.model_fields # type: ignore


Expand Down Expand Up @@ -365,7 +387,8 @@ def inspect_namespace( # noqa C901
)
elif isinstance(value, FieldInfo):
raise PydanticUserError(
f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation'
f'Field {var_name!r} requires a type annotation',
code='model-field-missing-annotation',
)
else:
raise PydanticUserError(
Expand All @@ -374,7 +397,6 @@ def inspect_namespace( # noqa C901
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
code='model-field-missing-annotation',
)

for ann_name, ann_type in raw_annotations.items():
if (
is_valid_privateattr_name(ann_name)
Expand All @@ -384,6 +406,12 @@ def inspect_namespace( # noqa C901
and ann_type not in all_ignored_types
and getattr(ann_type, '__module__', None) != 'functools'
):
if is_annotated(ann_type):
_, *metadata = typing_extensions.get_args(ann_type)
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
if private_attr is not None:
private_attributes[ann_name] = private_attr
continue
private_attributes[ann_name] = PrivateAttr()

return private_attributes
Expand All @@ -405,7 +433,10 @@ def hash_func(self: Any) -> int:


def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
cls: type[BaseModel],
bases: tuple[type[Any], ...],
config_wrapper: ConfigWrapper,
types_namespace: dict[str, Any],
) -> None:
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.

Expand Down Expand Up @@ -513,13 +544,16 @@ def complete_model_class(

# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute(
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
'__signature__',
generate_model_signature(cls.__init__, cls.model_fields, config_wrapper),
)
return True


def generate_model_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
) -> Signature:
"""Generate signature for model based on its fields.

Expand Down