Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added fix for signature of inherited dataclass #7925

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
103 changes: 56 additions & 47 deletions pydantic/_internal/_dataclasses.py
Expand Up @@ -6,7 +6,7 @@
import typing
import warnings
from functools import partial, wraps
from inspect import Parameter, Signature, signature
from inspect import Parameter, Signature
from typing import Any, Callable, ClassVar

from pydantic_core import (
Expand All @@ -23,8 +23,9 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._config import ConfigWrapper
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
Expand Down Expand Up @@ -123,19 +124,20 @@ def complete_dataclass(
typevars_map,
)

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
# This needs to be called before we change the __init__
sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)

__init__.__qualname__ = f'{cls.__qualname__}.__init__'
sig = generate_dataclass_signature(cls)

cls.__init__ = __init__ # type: ignore
cls.__signature__ = sig # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore

cls.__signature__ = sig # type: ignore
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
try:
if get_core_schema:
Expand Down Expand Up @@ -185,54 +187,61 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None:
return True


def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature:
"""Generate signature for a pydantic dataclass.
def process_param_defaults(param: Parameter) -> Parameter:
"""Custom processing where the parameter default is of type FieldInfo

This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses.
If we change this eventually, we should make this function's logic more closely mirror that from
`pydantic._internal._model_construction.generate_model_signature`.
Args:
param (Parameter): The parameter

Returns:
Parameter: The custom processed parameter
"""
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(annotation=annotation, name=name, default=default)
return param


def generate_dataclass_signature(
cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for a pydantic dataclass.

Args:
cls: The dataclass.
fields: The model fields.
config_wrapper: The config wrapper instance.

Returns:
The signature.
The dataclass signature.
"""
sig = signature(cls)
final_params: dict[str, Parameter] = {}

for param in sig.parameters.values():
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore

param = param.replace(annotation=annotation, name=name, default=default)
final_params[param.name] = param

return Signature(parameters=list(final_params.values()), return_annotation=None)
return generate_pydantic_signature(
init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults
)


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
Expand Down
86 changes: 85 additions & 1 deletion pydantic/_internal/_generate_schema.py
Expand Up @@ -82,7 +82,7 @@
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._utils import lenient_issubclass
from ._utils import is_valid_identifier, lenient_issubclass

if TYPE_CHECKING:
from ..main import BaseModel
Expand Down Expand Up @@ -2083,3 +2083,87 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
post_process_parameter: Callable[[Parameter], Parameter] = lambda x: x,
) -> inspect.Signature:
"""Generate signature for a pydantic class generated by inheriting from BaseModel or
using the dataclass annotation

Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
post_process_parameter: Optional additional processing for parameter

Returns:
The dataclass/BaseModel subclass signature.
"""
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = post_process_parameter(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = post_process_parameter(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = post_process_parameter(var_kw.replace(name=var_kw_name))

return inspect.Signature(parameters=list(merged_params.values()), return_annotation=None)
70 changes: 3 additions & 67 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -24,12 +24,12 @@
get_attribute_from_bases,
)
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute, is_valid_identifier
from ._utils import ClassAttribute
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -518,71 +518,7 @@ def generate_model_signature(
Returns:
The model signature.
"""
from inspect import Parameter, Signature, signature
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

return Signature(parameters=list(merged_params.values()), return_annotation=None)
return generate_pydantic_signature(init, fields, config_wrapper)


class _PydanticWeakRef:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2507,6 +2507,19 @@ class Model:
)


def test_inherited_dataclass_signature():
@pydantic.dataclasses.dataclass
class A:
a: int

@pydantic.dataclasses.dataclass
class B(A):
b: int

assert str(inspect.signature(A)) == '(a: int) -> None'
assert str(inspect.signature(B)) == '(a: int, b: int) -> None'


def test_dataclasses_with_slots_and_default():
@pydantic.dataclasses.dataclass(slots=True)
class A:
Expand Down