Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor signature generation for simplicity #8572

Merged
merged 13 commits into from Jan 19, 2024
47 changes: 2 additions & 45 deletions pydantic/_internal/_dataclasses.py
Expand Up @@ -2,16 +2,13 @@
from __future__ import annotations as _annotations

import dataclasses
import inspect
import typing
import warnings
from functools import partial, wraps
from inspect import Parameter
from typing import Any, Callable, ClassVar

from pydantic_core import (
ArgsKwargs,
PydanticUndefined,
SchemaSerializer,
SchemaValidator,
core_schema,
Expand All @@ -23,13 +20,12 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._constructor_signature_generators import generate_pydantic_signature
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._utils import is_valid_identifier
from ._signature import generate_pydantic_signature

if typing.TYPE_CHECKING:
from ..config import ConfigDict
Expand Down Expand Up @@ -129,7 +125,7 @@ def complete_dataclass(
init=cls.__init__,
fields=cls.__pydantic_fields__, # type: ignore
config_wrapper=config_wrapper,
parameter_post_processor=process_param_defaults,
is_dataclass=True,
)

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
Expand Down Expand Up @@ -192,45 +188,6 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None:
return True


def process_param_defaults(param: Parameter) -> Parameter:
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.

Args:
param (Parameter): The parameter

Returns:
Parameter: The custom processed parameter
"""
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(annotation=annotation, name=name, default=default)
return param


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.

Expand Down
2 changes: 1 addition & 1 deletion pydantic/_internal/_model_construction.py
Expand Up @@ -18,13 +18,13 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._constructor_signature_generators import generate_pydantic_signature
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute, SafeGetItemProxy
from ._validate_call import ValidateCallWrapper
Expand Down
@@ -1,47 +1,79 @@
from __future__ import annotations

import dataclasses
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable

from pydantic_core import PydanticUndefined

from ._config import ConfigWrapper
from ._utils import is_valid_identifier

if TYPE_CHECKING:
from ..fields import FieldInfo


def _field_name_or_alias(field_name: str, field_info: FieldInfo) -> str:
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
"""Extract the correct name to use for the field when generating a signature.
If it has a valid alias then returns its alais, else returns its name

Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
First priority is given to the validation_alias, then the alias, then the field name.

Args:
field_name: The name of the field
field_info: The field
field_info: The corresponding FieldInfo object.

Returns:
The correct name to use when generating a signature.
"""
return (
field_info.alias if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias) else field_name
)

def _alias_if_valid(x: Any) -> str | None:
"""Return the alias if it is a valid alias and identifier, else None."""
return x if isinstance(x, str) and is_valid_identifier(x) else None

return _alias_if_valid(field_info.alias) or _alias_if_valid(field_info.validation_alias) or field_name

def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
parameter_post_processor: Callable[[Parameter], Parameter] = lambda x: x,
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.

def _process_param_defaults(param: Parameter) -> Parameter:
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.

Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
parameter_post_processor: Optional additional processing for parameter
param (Parameter): The parameter

Returns:
The dataclass/BaseModel subclass signature.
Parameter: The custom processed parameter
"""
from ..fields import FieldInfo

param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
)
return param


def _generate_signature_parameters(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
) -> dict[str, Parameter]:
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
from itertools import islice

present_params = signature(init).parameters.values()
Expand All @@ -53,20 +85,19 @@ def generate_pydantic_signature(
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if fields.get(param.name):
param_name = _field_name_or_alias(param.name, fields[param.name])
param = param.replace(name=param_name)
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = parameter_post_processor(param)
merged_params[param.name] = param

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
param_name = _field_name_or_alias(field_name, field)
param_name = _field_name_for_signature(field_name, field)

if field_name in merged_params or param_name in merged_params:
continue
Expand All @@ -79,8 +110,8 @@ def generate_pydantic_signature(
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = parameter_post_processor(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
)

if config_wrapper.extra == 'allow':
Expand All @@ -103,6 +134,28 @@ def generate_pydantic_signature(
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = parameter_post_processor(var_kw.replace(name=var_kw_name))
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

return merged_params


def generate_pydantic_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper, is_dataclass: bool = False
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.

Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
is_dataclass: Whether the model is a dataclass.

Returns:
The dataclass/BaseModel subclass signature.
"""
merged_params = _generate_signature_parameters(init, fields, config_wrapper)

if is_dataclass:
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}

return Signature(parameters=list(merged_params.values()), return_annotation=None)
6 changes: 3 additions & 3 deletions pydantic/main.py
Expand Up @@ -107,6 +107,9 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass):
This replaces `Model.__fields__` from Pydantic V1.
"""

model_computed_fields: ClassVar[dict[str, ComputedFieldInfo]]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Preferential move of this field

"""A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects."""

__class_vars__: ClassVar[set[str]]
__private_attributes__: ClassVar[dict[str, ModelPrivateAttr]]
__signature__: ClassVar[Signature]
Expand All @@ -122,9 +125,6 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass):
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator]

model_computed_fields: ClassVar[dict[str, ComputedFieldInfo]]
"""A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects."""

# Instance attributes
# Note: we use the non-existent kwarg `init=False` in pydantic.fields.Field below so that @dataclass_transform
# doesn't think these are valid as keyword arguments to the class initializer.
Expand Down