Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inheriting annotations in dataclasses #8679

Merged
merged 12 commits into from Feb 1, 2024
49 changes: 28 additions & 21 deletions pydantic/dataclasses.py
Expand Up @@ -93,7 +93,7 @@ def dataclass(


@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
def dataclass( # noqa: C901
_cls: type[_T] | None = None,
*,
init: Literal[False] = False,
Expand Down Expand Up @@ -153,26 +153,33 @@ def make_pydantic_fields_compatible(cls: type[Any]) -> None:
into
`x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
"""
# In Python < 3.9, `__annotations__` might not be present if there are no fields.
# we therefore need to use `getattr` to avoid an `AttributeError`.
for field_name in getattr(cls, '__annotations__', []):
field_value = getattr(cls, field_name, None)
# Process only if this is an instance of `FieldInfo`.
if not isinstance(field_value, FieldInfo):
continue

# Initialize arguments for the standard `dataclasses.field`.
field_args: dict = {'default': field_value}

# Handle `kw_only` for Python 3.10+
if sys.version_info >= (3, 10) and field_value.kw_only:
field_args['kw_only'] = True

# Set `repr` attribute if it's explicitly specified to be not `True`.
if field_value.repr is not True:
field_args['repr'] = field_value.repr

setattr(cls, field_name, dataclasses.field(**field_args))
for annotation_cls in cls.__mro__:
# In Python < 3.9, `__annotations__` might not be present if there are no fields.
# we therefore need to use `getattr` to avoid an `AttributeError`.
annotations = getattr(annotation_cls, '__annotations__', [])
for field_name in annotations:
field_value = getattr(cls, field_name, None)
# Process only if this is an instance of `FieldInfo`.
if not isinstance(field_value, FieldInfo):
continue

# Initialize arguments for the standard `dataclasses.field`.
field_args: dict = {'default': field_value}

# Handle `kw_only` for Python 3.10+
if sys.version_info >= (3, 10) and field_value.kw_only:
field_args['kw_only'] = True

# Set `repr` attribute if it's explicitly specified to be not `True`.
if field_value.repr is not True:
field_args['repr'] = field_value.repr

setattr(cls, field_name, dataclasses.field(**field_args))
# In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations,
# so we must make sure it's initialized before we add to it.
if cls.__dict__.get('__annotations__') is None:
cls.__annotations__ = {}
cls.__annotations__[field_name] = annotations[field_name]

def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
"""Create a Pydantic dataclass from a regular dataclass.
Expand Down
85 changes: 85 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2759,3 +2759,88 @@ def test_disallow_init_false_and_init_var_true() -> None:
@pydantic.dataclasses.dataclass
class Foo:
bar: str = Field(..., init=False, init_var=True)


def test_annotations_valid_for_field_inheritance() -> None:
# testing https://github.com/pydantic/pydantic/issues/8670

@pydantic.dataclasses.dataclass()
class A:
a: int = pydantic.dataclasses.Field()

@pydantic.dataclasses.dataclass()
class B(A):
...

assert B.__pydantic_fields__['a'].annotation is int

assert B(a=1).a == 1


def test_annotations_valid_for_field_inheritance_with_existing_field() -> None:
# variation on testing https://github.com/pydantic/pydantic/issues/8670

@pydantic.dataclasses.dataclass()
class A:
a: int = pydantic.dataclasses.Field()

@pydantic.dataclasses.dataclass()
class B(A):
b: str = pydantic.dataclasses.Field()

assert B.__pydantic_fields__['a'].annotation is int
assert B.__pydantic_fields__['b'].annotation is str

b = B(a=1, b='b')
assert b.a == 1
assert b.b == 'b'


def test_annotations_with_override() -> None:
@pydantic.dataclasses.dataclass()
class A:
a: int
b: int
c: int = pydantic.dataclasses.Field()
d: int = pydantic.dataclasses.Field()

# note, the order of fields is different here, as to test that the annotation
# is correctly set on the field no matter the base's default / current class's default
@pydantic.dataclasses.dataclass()
class B(A):
a: str
c: str
b: str = pydantic.dataclasses.Field()
d: str = pydantic.dataclasses.Field()

b = B(a='a', b='b', c='c', d='d')
for field_name in ['a', 'b', 'c', 'd']:
assert B.__pydantic_fields__[field_name].annotation is str
assert getattr(b, field_name) == field_name


def test_annotation_with_double_override() -> None:
@pydantic.dataclasses.dataclass()
class A:
a: int
b: int
c: int = pydantic.dataclasses.Field()
d: int = pydantic.dataclasses.Field()

# note, the order of fields is different here, as to test that the annotation
# is correctly set on the field no matter the base's default / current class's default
@pydantic.dataclasses.dataclass()
class B(A):
a: str
c: str
b: str = pydantic.dataclasses.Field()
d: str = pydantic.dataclasses.Field()

@pydantic.dataclasses.dataclass()
class C(B):
...

c = C(a='a', b='b', c='c', d='d')
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
for field_name in ['a', 'b', 'c', 'd']:
assert C.__pydantic_fields__[field_name].annotation is str
assert getattr(c, field_name) == field_name