Skip to content

Commit

Permalink
Add support for dataclass fields init (#8552)
Browse files Browse the repository at this point in the history
Co-authored-by: sydney-runkle <sydneymarierunkle@gmail.com>
  • Loading branch information
dmontagu and sydney-runkle committed Jan 19, 2024
1 parent 48d0df4 commit a2a4281
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/concepts/fields.md
Expand Up @@ -459,6 +459,7 @@ print(foo)

There are fields that can be used to constrain dataclasses:

* `init`: Whether the field should be included in the `__init__` of the dataclass.
* `init_var`: Whether the field should be seen as an [init-only field] in the dataclass.
* `kw_only`: Whether the field should be a keyword-only argument in the constructor of the dataclass.

Expand Down
43 changes: 43 additions & 0 deletions docs/errors/usage_errors.md
Expand Up @@ -1046,4 +1046,47 @@ class Model(BaseModel):
date: datetime.date = Field(description='A date')
```

## Incompatible `dataclass` `init` and `extra` settings {#dataclass-init-false-extra-allow}

Pydantic does not allow the specification of the `extra='allow'` setting on a dataclass
while any of the fields have `init=False` set.

Thus, you may not do something like the following:

```py test="skip"
from pydantic import ConfigDict, Field
from pydantic.dataclasses import dataclass


@dataclass(config=ConfigDict(extra='allow'))
class A:
a: int = Field(init=False, default=1)
```

The above snippet results in the following error during schema building for the `A` dataclass:

```
pydantic.errors.PydanticUserError: Field a has `init=False` and dataclass has config setting `extra="allow"`.
This combination is not allowed.
```

## Incompatible `init` and `init_var` settings on `dataclass` field {#clashing-init-and-init-var}

The `init=False` and `init_var=True` settings are mutually exclusive. Doing so results in the `PydanticUserError` shown in the example below.

```py test="skip"
from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class Foo:
bar: str = Field(..., init=False, init_var=True)


"""
pydantic.errors.PydanticUserError: Dataclass field bar has init=False and init_var=True, but these are mutually exclusive.
"""
```

{% endraw %}
11 changes: 10 additions & 1 deletion pydantic/_internal/_fields.py
Expand Up @@ -10,6 +10,8 @@

from pydantic_core import PydanticUndefined

from pydantic.errors import PydanticUserError

from . import _typing_extra
from ._config import ConfigWrapper
from ._repr import Representation
Expand Down Expand Up @@ -284,11 +286,18 @@ def collect_dataclass_fields(

if isinstance(dataclass_field.default, FieldInfo):
if dataclass_field.default.init_var:
# TODO: same note as above
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)

# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)

fields[ann_name] = field_info

if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
Expand Down
13 changes: 13 additions & 0 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -941,6 +941,7 @@ def _generate_dc_field_schema(
return core_schema.dataclass_field(
name,
common_field['schema'],
init=field_info.init,
init_only=field_info.init_var or None,
kw_only=None if field_info.kw_only else False,
serialization_exclude=common_field['serialization_exclude'],
Expand Down Expand Up @@ -1468,6 +1469,18 @@ def _dataclass_schema(
self._types_namespace,
typevars_map=typevars_map,
)

# disallow combination of init=False on a dataclass field and extra='allow' on a dataclass
if config and config.get('extra') == 'allow':
# disallow combination of init=False on a dataclass field and extra='allow' on a dataclass
for field_name, field in fields.items():
if field.init is False:
raise PydanticUserError(
f'Field {field_name} has `init=False` and dataclass has config setting `extra="allow"`. '
f'This combination is not allowed.',
code='dataclass-init-false-extra-allow',
)

decorators = dataclass.__dict__.get('__pydantic_decorators__') or DecoratorInfos.build(dataclass)
# Move kw_only=False args to the start of the list, as this is how vanilla dataclasses work.
# Note that when kw_only is missing or None, it is treated as equivalent to kw_only=True
Expand Down
5 changes: 4 additions & 1 deletion pydantic/_internal/_signature.py
Expand Up @@ -68,7 +68,7 @@ def _process_param_defaults(param: Parameter) -> Parameter:
return param


def _generate_signature_parameters(
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
Expand All @@ -85,6 +85,9 @@ def _generate_signature_parameters(
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if fields.get(param.name):
# exclude params with init=False
if getattr(fields[param.name], 'init', True) is False:
continue
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
if param.annotation == 'Any':
param = param.replace(annotation=Any)
Expand Down
2 changes: 2 additions & 0 deletions pydantic/errors.py
Expand Up @@ -59,6 +59,8 @@
'type-adapter-config-unused',
'root-model-extra',
'unevaluable-type-annotation',
'dataclass-init-false-extra-allow',
'clashing-init-and-init-var',
]


Expand Down
15 changes: 13 additions & 2 deletions pydantic/fields.py
Expand Up @@ -64,6 +64,7 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
frozen: bool | None
validate_default: bool | None
repr: bool
init: bool | None
init_var: bool | None
kw_only: bool | None

Expand Down Expand Up @@ -101,7 +102,8 @@ class FieldInfo(_repr.Representation):
frozen: Whether the field is frozen.
validate_default: Whether to validate the default value of the field.
repr: Whether to include the field in representation of the model.
init_var: Whether the field should be included in the constructor of the dataclass.
init: Whether the field should be included in the constructor of the dataclass.
init_var: Whether the field should _only_ be included in the constructor of the dataclass, and not stored.
kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass.
metadata: List of metadata constraints.
"""
Expand All @@ -122,6 +124,7 @@ class FieldInfo(_repr.Representation):
frozen: bool | None
validate_default: bool | None
repr: bool
init: bool | None
init_var: bool | None
kw_only: bool | None
metadata: list[Any]
Expand All @@ -143,6 +146,7 @@ class FieldInfo(_repr.Representation):
'frozen',
'validate_default',
'repr',
'init',
'init_var',
'kw_only',
'metadata',
Expand Down Expand Up @@ -203,6 +207,7 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
self.validate_default = kwargs.pop('validate_default', None)
self.frozen = kwargs.pop('frozen', None)
# currently only used on dataclasses
self.init = kwargs.pop('init', None)
self.init_var = kwargs.pop('init_var', None)
self.kw_only = kwargs.pop('kw_only', None)

Expand Down Expand Up @@ -360,6 +365,7 @@ class MyModel(pydantic.BaseModel):
)
pydantic_field.frozen = final or pydantic_field.frozen
pydantic_field.init_var = init_var
pydantic_field.init = getattr(default, 'init', None)
pydantic_field.kw_only = getattr(default, 'kw_only', None)
return pydantic_field
else:
Expand Down Expand Up @@ -596,6 +602,7 @@ class _EmptyKwargs(typing_extensions.TypedDict):
frozen=None,
validate_default=None,
repr=True,
init=None,
init_var=None,
kw_only=None,
pattern=None,
Expand Down Expand Up @@ -630,6 +637,7 @@ def Field( # noqa: C901
frozen: bool | None = _Unset,
validate_default: bool | None = _Unset,
repr: bool = _Unset,
init: bool | None = _Unset,
init_var: bool | None = _Unset,
kw_only: bool | None = _Unset,
pattern: str | None = _Unset,
Expand Down Expand Up @@ -675,7 +683,9 @@ def Field( # noqa: C901
validate_default: If `True`, apply validation to the default value every time you create an instance.
Otherwise, for performance reasons, the default value of the field is trusted and not validated.
repr: A boolean indicating whether to include the field in the `__repr__` output.
init_var: Whether the field should be included in the constructor of the dataclass.
init: Whether the field should be included in the constructor of the dataclass.
(Only applies to dataclasses.)
init_var: Whether the field should _only_ be included in the constructor of the dataclass.
(Only applies to dataclasses.)
kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass.
(Only applies to dataclasses.)
Expand Down Expand Up @@ -784,6 +794,7 @@ def Field( # noqa: C901
pattern=pattern,
validate_default=validate_default,
repr=repr,
init=init,
init_var=init_var,
kw_only=kw_only,
strict=strict,
Expand Down
92 changes: 92 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2648,3 +2648,95 @@ class Model(BaseModel):
n: Nested

assert Model.model_validate_strings({'n': {'d': '2017-01-01'}}).n.d == date(2017, 1, 1)


@pytest.mark.parametrize('field_constructor', [dataclasses.field, pydantic.dataclasses.Field])
@pytest.mark.parametrize('extra', ['ignore', 'forbid'])
def test_init_false_not_in_signature(extra, field_constructor):
@pydantic.dataclasses.dataclass(config=ConfigDict(extra=extra))
class MyDataclass:
a: int = field_constructor(init=False, default=-1)
b: int = pydantic.dataclasses.Field(default=2)

signature = inspect.signature(MyDataclass)
# `a` should not be in the __init__
assert 'a' not in signature.parameters.keys()
assert 'b' in signature.parameters.keys()


init_test_cases = [
({'a': 2, 'b': -1}, 'ignore', {'a': 2, 'b': 1}),
({'a': 2}, 'ignore', {'a': 2, 'b': 1}),
(
{'a': 2, 'b': -1},
'forbid',
[
{
'type': 'unexpected_keyword_argument',
'loc': ('b',),
'msg': 'Unexpected keyword argument',
'input': -1,
}
],
),
({'a': 2}, 'forbid', {'a': 2, 'b': 1}),
]


@pytest.mark.parametrize('field_constructor', [dataclasses.field, pydantic.dataclasses.Field])
@pytest.mark.parametrize(
'input_data,extra,expected',
init_test_cases,
)
def test_init_false_with_post_init(input_data, extra, expected, field_constructor):
@pydantic.dataclasses.dataclass(config=ConfigDict(extra=extra))
class MyDataclass:
a: int
b: int = field_constructor(init=False)

def __post_init__(self):
self.b = 1

if isinstance(expected, list):
with pytest.raises(ValidationError) as exc_info:
MyDataclass(**input_data)

assert exc_info.value.errors(include_url=False) == expected
else:
assert dataclasses.asdict(MyDataclass(**input_data)) == expected


@pytest.mark.parametrize('field_constructor', [dataclasses.field, pydantic.dataclasses.Field])
@pytest.mark.parametrize(
'input_data,extra,expected',
init_test_cases,
)
def test_init_false_with_default(input_data, extra, expected, field_constructor):
@pydantic.dataclasses.dataclass(config=ConfigDict(extra=extra))
class MyDataclass:
a: int
b: int = field_constructor(init=False, default=1)

if isinstance(expected, list):
with pytest.raises(ValidationError) as exc_info:
MyDataclass(**input_data)

assert exc_info.value.errors(include_url=False) == expected
else:
assert dataclasses.asdict(MyDataclass(**input_data)) == expected


def test_disallow_extra_allow_and_init_false() -> None:
with pytest.raises(PydanticUserError, match='This combination is not allowed.'):

@pydantic.dataclasses.dataclass(config=ConfigDict(extra='allow'))
class A:
a: int = Field(init=False, default=1)


def test_disallow_init_false_and_init_var_true() -> None:
with pytest.raises(PydanticUserError, match='mutually exclusive.'):

@pydantic.dataclasses.dataclass
class Foo:
bar: str = Field(..., init=False, init_var=True)

0 comments on commit a2a4281

Please sign in to comment.