Skip to content

Commit

Permalink
Fix resolution of forward refs in dataclass base classes that are not…
Browse files Browse the repository at this point in the history
… present in the subclass module namespace (#8751)
  • Loading branch information
matsjoyce-refeyn committed Feb 13, 2024
1 parent b145516 commit 4672662
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
22 changes: 18 additions & 4 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -8,7 +8,7 @@
import sys
import typing
import warnings
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from copy import copy, deepcopy
from enum import Enum
from functools import partial
Expand Down Expand Up @@ -1470,8 +1470,18 @@ def _dataclass_schema(
if origin is not None:
dataclass = origin

config = getattr(dataclass, '__pydantic_config__', None)
with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(dataclass):
with ExitStack() as dataclass_bases_stack:
# Pushing a namespace prioritises items already in the stack, so iterate though the MRO forwards
for dataclass_base in dataclass.__mro__:
if dataclasses.is_dataclass(dataclass_base):
dataclass_bases_stack.enter_context(self._types_namespace_stack.push(dataclass_base))

# Pushing a config overwrites the previous config, so iterate though the MRO backwards
for dataclass_base in reversed(dataclass.__mro__):
if dataclasses.is_dataclass(dataclass_base):
config = getattr(dataclass_base, '__pydantic_config__', None)
dataclass_bases_stack.enter_context(self._config_wrapper_stack.push(config))

core_config = self._config_wrapper.core_config(dataclass)

self = self._current_generate_schema
Expand All @@ -1491,7 +1501,7 @@ def _dataclass_schema(
)

# disallow combination of init=False on a dataclass field and extra='allow' on a dataclass
if config and config.get('extra') == 'allow':
if self._config_wrapper_stack.tail.extra == 'allow':
# disallow combination of init=False on a dataclass field and extra='allow' on a dataclass
for field_name, field in fields.items():
if field.init is False:
Expand Down Expand Up @@ -1540,6 +1550,10 @@ def _dataclass_schema(
self.defs.definitions[dataclass_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(dataclass_ref)

# Type checkers seem to assume ExitStack may suppress exceptions and therefore
# control flow can exit the `with` block without returning.
assert False, 'Unreachable'

def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema:
"""Generate schema for a Callable.
Expand Down
73 changes: 73 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -1622,6 +1622,79 @@ class D2:
]


@pytest.mark.parametrize(
'dataclass_decorator',
[
pydantic.dataclasses.dataclass,
dataclasses.dataclass,
],
ids=['pydantic', 'stdlib'],
)
def test_base_dataclasses_annotations_resolving(create_module, dataclass_decorator: Callable):
@create_module
def module():
import dataclasses
from typing import NewType

OddInt = NewType('OddInt', int)

@dataclasses.dataclass
class D1:
d1: 'OddInt'
s: str

__pydantic_config__ = {'str_to_lower': True}

@dataclass_decorator
class D2(module.D1):
d2: int

assert TypeAdapter(D2).validate_python({'d1': 1, 'd2': 2, 's': 'ABC'}) == D2(d1=1, d2=2, s='abc')


@pytest.mark.parametrize(
'dataclass_decorator',
[
pydantic.dataclasses.dataclass,
dataclasses.dataclass,
],
ids=['pydantic', 'stdlib'],
)
def test_base_dataclasses_annotations_resolving_with_override(create_module, dataclass_decorator: Callable):
@create_module
def module1():
import dataclasses
from typing import NewType

IDType = NewType('IDType', int)

@dataclasses.dataclass
class D1:
db_id: 'IDType'

__pydantic_config__ = {'str_to_lower': True}

@create_module
def module2():
import dataclasses
from typing import NewType

IDType = NewType('IDType', str)

@dataclasses.dataclass
class D2:
db_id: 'IDType'
s: str

__pydantic_config__ = {'str_to_lower': False}

@dataclass_decorator
class D3(module1.D1, module2.D2):
...

assert TypeAdapter(D3).validate_python({'db_id': 42, 's': 'ABC'}) == D3(db_id=42, s='abc')


@pytest.mark.skipif(sys.version_info < (3, 10), reason='kw_only is not available in python < 3.10')
def test_kw_only():
@pydantic.dataclasses.dataclass(kw_only=True)
Expand Down

0 comments on commit 4672662

Please sign in to comment.