Skip to content

Commit

Permalink
Do not override model_post_init in subclass with private attrs (#7302)
Browse files Browse the repository at this point in the history
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
Viicos and dmontagu committed Sep 11, 2023
1 parent bb8289e commit ed009ba
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 9 deletions.
18 changes: 11 additions & 7 deletions pydantic/_internal/_decorators.py
Expand Up @@ -341,6 +341,9 @@ def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
return tuple(merge_seqs(seqs))


_sentinel = object()


def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
"""Get the attribute from the next class in the MRO that has it,
aiming to simulate calling the method on the actual class.
Expand All @@ -362,17 +365,18 @@ def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -
"""
if isinstance(tp, tuple):
for base in mro_for_bases(tp):
if hasattr(base, name):
return getattr(base, name)
attribute = base.__dict__.get(name, _sentinel)
if attribute is not _sentinel:
attribute_get = getattr(attribute, '__get__', None)
if attribute_get is not None:
return attribute_get(None, tp)
return attribute
raise AttributeError(f'{name} not found in {tp}')
else:
try:
return getattr(tp, name)
except AttributeError as e:
for base in mro(tp):
if hasattr(base, name):
return getattr(base, name)
raise e
except AttributeError:
return get_attribute_from_bases(mro(tp), name)


def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
Expand Down
16 changes: 14 additions & 2 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -104,9 +104,9 @@ def __new__(
namespace, config_wrapper.ignored_types, class_vars, base_field_names
)
if private_attributes:
if 'model_post_init' in namespace:
original_model_post_init = get_model_post_init(namespace, bases)
if original_model_post_init is not None:
# if there are private_attributes and a model_post_init function, we handle both
original_model_post_init = namespace['model_post_init']

def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
"""We need to both initialize private attributes and call the user-defined model_post_init
Expand Down Expand Up @@ -266,6 +266,18 @@ def init_private_attributes(self: BaseModel, __context: Any) -> None:
object_setattr(self, '__pydantic_private__', pydantic_private)


def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
"""Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined."""
if 'model_post_init' in namespace:
return namespace['model_post_init']

from ..main import BaseModel

model_post_init = get_attribute_from_bases(bases, 'model_post_init')
if model_post_init is not BaseModel.model_post_init:
return model_post_init


def inspect_namespace( # noqa C901
namespace: dict[str, Any],
ignored_types: tuple[type[Any], ...],
Expand Down
42 changes: 42 additions & 0 deletions tests/test_main.py
Expand Up @@ -1979,6 +1979,48 @@ def model_post_init(self, __context: Any) -> None:
BaseModel.model_post_init = original_base_model_post_init


def test_model_post_init_subclass_private_attrs():
"""https://github.com/pydantic/pydantic/issues/7293"""
calls = []

class A(BaseModel):
a: int = 1

def model_post_init(self, __context: Any) -> None:
calls.append(f'{self.__class__.__name__}.model_post_init')

class B(A):
pass

class C(B):
_private: bool = True

C()

assert calls == ['C.model_post_init']


def test_model_post_init_correct_mro():
"""https://github.com/pydantic/pydantic/issues/7293"""
calls = []

class A(BaseModel):
a: int = 1

class B(BaseModel):
b: int = 1

def model_post_init(self, __context: Any) -> None:
calls.append(f'{self.__class__.__name__}.model_post_init')

class C(A, B):
_private: bool = True

C()

assert calls == ['C.model_post_init']


def test_deeper_recursive_model():
class A(BaseModel):
b: 'B'
Expand Down

0 comments on commit ed009ba

Please sign in to comment.