Skip to content

Commit

Permalink
Fix inheriting annotations in dataclasses (#8679)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
  • Loading branch information
sydney-runkle and alexmojaki committed Feb 5, 2024
1 parent f3532ed commit 43327d8
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 21 deletions.
49 changes: 28 additions & 21 deletions pydantic/dataclasses.py
Expand Up @@ -93,7 +93,7 @@ def dataclass(


@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
def dataclass( # noqa: C901
_cls: type[_T] | None = None,
*,
init: Literal[False] = False,
Expand Down Expand Up @@ -153,26 +153,33 @@ def make_pydantic_fields_compatible(cls: type[Any]) -> None:
into
`x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
"""
# In Python < 3.9, `__annotations__` might not be present if there are no fields.
# we therefore need to use `getattr` to avoid an `AttributeError`.
for field_name in getattr(cls, '__annotations__', []):
field_value = getattr(cls, field_name, None)
# Process only if this is an instance of `FieldInfo`.
if not isinstance(field_value, FieldInfo):
continue

# Initialize arguments for the standard `dataclasses.field`.
field_args: dict = {'default': field_value}

# Handle `kw_only` for Python 3.10+
if sys.version_info >= (3, 10) and field_value.kw_only:
field_args['kw_only'] = True

# Set `repr` attribute if it's explicitly specified to be not `True`.
if field_value.repr is not True:
field_args['repr'] = field_value.repr

setattr(cls, field_name, dataclasses.field(**field_args))
for annotation_cls in cls.__mro__:
# In Python < 3.9, `__annotations__` might not be present if there are no fields.
# we therefore need to use `getattr` to avoid an `AttributeError`.
annotations = getattr(annotation_cls, '__annotations__', [])
for field_name in annotations:
field_value = getattr(cls, field_name, None)
# Process only if this is an instance of `FieldInfo`.
if not isinstance(field_value, FieldInfo):
continue

# Initialize arguments for the standard `dataclasses.field`.
field_args: dict = {'default': field_value}

# Handle `kw_only` for Python 3.10+
if sys.version_info >= (3, 10) and field_value.kw_only:
field_args['kw_only'] = True

# Set `repr` attribute if it's explicitly specified to be not `True`.
if field_value.repr is not True:
field_args['repr'] = field_value.repr

setattr(cls, field_name, dataclasses.field(**field_args))
# In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations,
# so we must make sure it's initialized before we add to it.
if cls.__dict__.get('__annotations__') is None:
cls.__annotations__ = {}
cls.__annotations__[field_name] = annotations[field_name]

def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
"""Create a Pydantic dataclass from a regular dataclass.
Expand Down
63 changes: 63 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -2740,3 +2740,66 @@ def test_disallow_init_false_and_init_var_true() -> None:
@pydantic.dataclasses.dataclass
class Foo:
bar: str = Field(..., init=False, init_var=True)


def test_annotations_valid_for_field_inheritance() -> None:
# testing https://github.com/pydantic/pydantic/issues/8670

@pydantic.dataclasses.dataclass()
class A:
a: int = pydantic.dataclasses.Field()

@pydantic.dataclasses.dataclass()
class B(A):
...

assert B.__pydantic_fields__['a'].annotation is int

assert B(a=1).a == 1


def test_annotations_valid_for_field_inheritance_with_existing_field() -> None:
# variation on testing https://github.com/pydantic/pydantic/issues/8670

@pydantic.dataclasses.dataclass()
class A:
a: int = pydantic.dataclasses.Field()

@pydantic.dataclasses.dataclass()
class B(A):
b: str = pydantic.dataclasses.Field()

assert B.__pydantic_fields__['a'].annotation is int
assert B.__pydantic_fields__['b'].annotation is str

b = B(a=1, b='b')
assert b.a == 1
assert b.b == 'b'


def test_annotation_with_double_override() -> None:
@pydantic.dataclasses.dataclass()
class A:
a: int
b: int
c: int = pydantic.dataclasses.Field()
d: int = pydantic.dataclasses.Field()

# note, the order of fields is different here, as to test that the annotation
# is correctly set on the field no matter the base's default / current class's default
@pydantic.dataclasses.dataclass()
class B(A):
a: str
c: str
b: str = pydantic.dataclasses.Field()
d: str = pydantic.dataclasses.Field()

@pydantic.dataclasses.dataclass()
class C(B):
...

for class_ in [B, C]:
instance = class_(a='a', b='b', c='c', d='d')
for field_name in ['a', 'b', 'c', 'd']:
assert class_.__pydantic_fields__[field_name].annotation is str
assert getattr(instance, field_name) == field_name

0 comments on commit 43327d8

Please sign in to comment.