Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added fix for signature of inherited dataclass #7925

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
157 changes: 115 additions & 42 deletions pydantic/_internal/_dataclasses.py
Expand Up @@ -23,6 +23,7 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._config import ConfigWrapper
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
Expand Down Expand Up @@ -123,19 +124,23 @@ def complete_dataclass(
typevars_map,
)

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
# Get a temporary signature before we change the __init__
fields = cls.__pydantic_fields__ # type: ignore
# This needs to be called before we change the __init__
sig = generate_dataclass_signature(cls, fields, config_wrapper)
howsunjow marked this conversation as resolved.
Show resolved Hide resolved

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)

__init__.__qualname__ = f'{cls.__qualname__}.__init__'
sig = generate_dataclass_signature(cls)

cls.__init__ = __init__ # type: ignore
cls.__signature__ = sig # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore

# Set to the temporary signature
howsunjow marked this conversation as resolved.
Show resolved Hide resolved
cls.__signature__ = sig # type: ignore
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
try:
if get_core_schema:
Expand Down Expand Up @@ -185,54 +190,122 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None:
return True


def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature:
"""Generate signature for a pydantic dataclass.
def process_param_defaults(param: Parameter) -> Parameter:
"""Custom processing where the parameter default is of type FieldInfo

This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses.
If we change this eventually, we should make this function's logic more closely mirror that from
`pydantic._internal._model_construction.generate_model_signature`.
Args:
param (Parameter): The parameter

Returns:
Parameter: The custom processed parameter
"""
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(annotation=annotation, name=name, default=default)
return param


def generate_dataclass_signature(
cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for a pydantic dataclass.

Args:
cls: The dataclass.
fields: The model fields.
config_wrapper: The config wrapper instance.

Returns:
The signature.
The dataclass signature.
"""
sig = signature(cls)
final_params: dict[str, Parameter] = {}

for param in sig.parameters.values():
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
from itertools import islice

present_params = signature(cls.__init__).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = process_param_defaults(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = process_param_defaults(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

param = param.replace(annotation=annotation, name=name, default=default)
final_params[param.name] = param
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = process_param_defaults(var_kw.replace(name=var_kw_name))

return Signature(parameters=list(final_params.values()), return_annotation=None)
return Signature(parameters=list(merged_params.values()), return_annotation=None)


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2507,6 +2507,19 @@ class Model:
)


def test_inherited_dataclass_signature():
@pydantic.dataclasses.dataclass
class A:
a: int

@pydantic.dataclasses.dataclass
class B(A):
b: int

assert str(inspect.signature(A)) == '(a: int) -> None'
assert str(inspect.signature(B)) == '(a: int, b: int) -> None'


def test_dataclasses_with_slots_and_default():
@pydantic.dataclasses.dataclass(slots=True)
class A:
Expand Down