Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hash function generation for frozen models with unusual MRO #7274

Merged
merged 2 commits into from Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 19 additions & 11 deletions pydantic/_internal/_decorators.py
Expand Up @@ -309,6 +309,11 @@ def mro(tp: type[Any]) -> tuple[type[Any], ...]:
# GenericAlias and some other cases
pass

bases = get_bases(tp)
return (tp,) + mro_for_bases(bases)


def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
while True:
non_empty = [seq for seq in seqs if seq]
Expand All @@ -332,14 +337,11 @@ def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
if seq[0] == candidate:
seq.popleft()

bases = get_bases(tp)
seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
res = tuple(merge_seqs(seqs))

return (tp,) + res
return tuple(merge_seqs(seqs))


def get_attribute_from_bases(tp: type[Any], name: str) -> Any:
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
"""Get the attribute from the next class in the MRO that has it,
aiming to simulate calling the method on the actual class.

Expand All @@ -349,7 +351,7 @@ def get_attribute_from_bases(tp: type[Any], name: str) -> Any:
from its bases (as done here).

Args:
tp: The type or class to search for the attribute.
tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
name: The name of the attribute to retrieve.

Returns:
Expand All @@ -358,13 +360,19 @@ def get_attribute_from_bases(tp: type[Any], name: str) -> Any:
Raises:
AttributeError: If the attribute is not found in any class in the MRO.
"""
try:
return getattr(tp, name)
except AttributeError as e:
for base in reversed(mro(tp)):
if isinstance(tp, tuple):
for base in mro_for_bases(tp):
if hasattr(base, name):
return getattr(base, name)
raise e
raise AttributeError(f'{name} not found in {tp}')
else:
try:
return getattr(tp, name)
except AttributeError as e:
for base in mro(tp):
if hasattr(base, name):
return getattr(base, name)
raise e


def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
Expand Down
22 changes: 11 additions & 11 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -17,7 +17,12 @@
from ..warnings import PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs
from ._decorators import ComputedFieldInfo, DecoratorInfos, PydanticDescriptorProxy
from ._decorators import (
ComputedFieldInfo,
DecoratorInfos,
PydanticDescriptorProxy,
get_attribute_from_bases,
)
from ._discriminated_union import apply_discriminators
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
Expand Down Expand Up @@ -374,16 +379,11 @@ def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...
if '__hash__' in namespace:
return

base_hash_func = None
for base in bases:
base_hash_func = getattr(base, '__hash__', PydanticUndefined)
if base_hash_func is not PydanticUndefined:
break

if base_hash_func is None:
# This will be the case for `BaseModel` since it defines `__eq__` but not `__hash__`.
# In this case, we generate a standard hash function, generally for use with frozen models.

base_hash_func = get_attribute_from_bases(bases, '__hash__')
if base_hash_func in {None, object.__hash__}:
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))

Expand Down
12 changes: 12 additions & 0 deletions tests/test_generics.py
Expand Up @@ -2613,3 +2613,15 @@ class C(B[int]):
'type': 'int_parsing',
},
]


def test_reverse_order_generic_hashability():
T = TypeVar('T')

class Model(Generic[T], BaseModel):
x: T
model_config = dict(frozen=True)

m1 = Model[int](x=1)
m2 = Model[int](x=1)
assert len({m1, m2}) == 1
14 changes: 14 additions & 0 deletions tests/test_types_typeddict.py
Expand Up @@ -21,6 +21,7 @@
PydanticUserError,
ValidationError,
)
from pydantic._internal._decorators import get_attribute_from_bases
from pydantic.functional_serializers import field_serializer, model_serializer
from pydantic.functional_validators import field_validator, model_validator
from pydantic.type_adapter import TypeAdapter
Expand Down Expand Up @@ -905,3 +906,16 @@ class MySubTypedDict(MyMiddleTypedDict):

validated_data = TypeAdapter(MySubTypedDict).validate_python({'x': 'ABC', 'y': 'DEF', 'z': 'GHI'})
assert validated_data == {'x': 'abc', 'y': 'def', 'z': 'ghi'}


def test_typeddict_mro():
class A(TypedDict):
x = 1

class B(A):
x = 2

class C(B):
pass

assert get_attribute_from_bases(C, 'x') == 2
Copy link
Contributor Author

@dmontagu dmontagu Aug 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently on the main branch, this returns 1 @adriangb

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wow I’m very glad you caught this 🙈