Skip to content

Commit

Permalink
Added fix for signature of inherited dataclass (#7925)
Browse files Browse the repository at this point in the history
Co-authored-by: howsun.jow <howsun@etiq.ai>
  • Loading branch information
howsunjow and hjow1 committed Oct 27, 2023
1 parent 419398d commit 0ca1cf2
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 115 deletions.
103 changes: 56 additions & 47 deletions pydantic/_internal/_dataclasses.py
Expand Up @@ -6,7 +6,7 @@
import typing
import warnings
from functools import partial, wraps
from inspect import Parameter, Signature, signature
from inspect import Parameter, Signature
from typing import Any, Callable, ClassVar

from pydantic_core import (
Expand All @@ -23,8 +23,9 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._config import ConfigWrapper
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
Expand Down Expand Up @@ -123,19 +124,20 @@ def complete_dataclass(
typevars_map,
)

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
# This needs to be called before we change the __init__
sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)

__init__.__qualname__ = f'{cls.__qualname__}.__init__'
sig = generate_dataclass_signature(cls)

cls.__init__ = __init__ # type: ignore
cls.__signature__ = sig # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore

cls.__signature__ = sig # type: ignore
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
try:
if get_core_schema:
Expand Down Expand Up @@ -185,54 +187,61 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None:
return True


def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature:
"""Generate signature for a pydantic dataclass.
def process_param_defaults(param: Parameter) -> Parameter:
"""Custom processing where the parameter default is of type FieldInfo
This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses.
If we change this eventually, we should make this function's logic more closely mirror that from
`pydantic._internal._model_construction.generate_model_signature`.
Args:
param (Parameter): The parameter
Returns:
Parameter: The custom processed parameter
"""
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(annotation=annotation, name=name, default=default)
return param


def generate_dataclass_signature(
cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for a pydantic dataclass.
Args:
cls: The dataclass.
fields: The model fields.
config_wrapper: The config wrapper instance.
Returns:
The signature.
The dataclass signature.
"""
sig = signature(cls)
final_params: dict[str, Parameter] = {}

for param in sig.parameters.values():
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any

# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias

# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore

param = param.replace(annotation=annotation, name=name, default=default)
final_params[param.name] = param

return Signature(parameters=list(final_params.values()), return_annotation=None)
return generate_pydantic_signature(
init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults
)


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
Expand Down
86 changes: 85 additions & 1 deletion pydantic/_internal/_generate_schema.py
Expand Up @@ -82,7 +82,7 @@
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._utils import lenient_issubclass
from ._utils import is_valid_identifier, lenient_issubclass

if TYPE_CHECKING:
from ..main import BaseModel
Expand Down Expand Up @@ -2085,3 +2085,87 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
post_process_parameter: Callable[[Parameter], Parameter] = lambda x: x,
) -> inspect.Signature:
"""Generate signature for a pydantic class generated by inheriting from BaseModel or
using the dataclass annotation
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
post_process_parameter: Optional additional processing for parameter
Returns:
The dataclass/BaseModel subclass signature.
"""
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = post_process_parameter(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = post_process_parameter(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = post_process_parameter(var_kw.replace(name=var_kw_name))

return inspect.Signature(parameters=list(merged_params.values()), return_annotation=None)
70 changes: 3 additions & 67 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -24,12 +24,12 @@
get_attribute_from_bases,
)
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute, is_valid_identifier
from ._utils import ClassAttribute
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -518,71 +518,7 @@ def generate_model_signature(
Returns:
The model signature.
"""
from inspect import Parameter, Signature, signature
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

return Signature(parameters=list(merged_params.values()), return_annotation=None)
return generate_pydantic_signature(init, fields, config_wrapper)


class _PydanticWeakRef:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2507,6 +2507,19 @@ class Model:
)


def test_inherited_dataclass_signature():
@pydantic.dataclasses.dataclass
class A:
a: int

@pydantic.dataclasses.dataclass
class B(A):
b: int

assert str(inspect.signature(A)) == '(a: int) -> None'
assert str(inspect.signature(B)) == '(a: int, b: int) -> None'


def test_dataclasses_with_slots_and_default():
@pydantic.dataclasses.dataclass(slots=True)
class A:
Expand Down

0 comments on commit 0ca1cf2

Please sign in to comment.