Skip to content

Commit

Permalink
Only hash model_fields, not whole __dict__ (#7786)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki committed Nov 13, 2023
1 parent a168d5f commit f23578e
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 32 deletions.
50 changes: 37 additions & 13 deletions pydantic/_internal/_model_construction.py
@@ -1,6 +1,7 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations

import operator
import typing
import warnings
import weakref
Expand Down Expand Up @@ -111,9 +112,6 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}

if config_wrapper.frozen:
set_default_hash_func(namespace, bases)

cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore

from ..main import BaseModel
Expand Down Expand Up @@ -179,6 +177,10 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:

types_namespace = get_cls_types_namespace(cls, parent_namespace)
set_model_fields(cls, bases, config_wrapper, types_namespace)

if config_wrapper.frozen and '__hash__' not in namespace:
set_default_hash_func(cls, bases)

complete_model_class(
cls,
cls_name,
Expand Down Expand Up @@ -396,19 +398,41 @@ def inspect_namespace( # noqa C901
return private_attributes


def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
if '__hash__' in namespace:
return

def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
base_hash_func = get_attribute_from_bases(bases, '__hash__')
if base_hash_func in {None, object.__hash__}:
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
new_hash_func = make_hash_func(cls)
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
# If `__hash__` is some default, we generate a hash function.
# It will be `None` if not overridden from BaseModel.
# It may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
# In the last case we still need a new hash function to account for new `model_fields`.
cls.__hash__ = new_hash_func


def make_hash_func(cls: type[BaseModel]) -> Any:
    """Build a `__hash__` implementation that hashes only the declared fields of `cls`.

    The field names are read from `cls.model_fields` when this factory is called. The returned
    function looks those names up in the instance `__dict__`, so extra keys stored there and the
    ordering of the `__dict__` do not influence the result.

    Args:
        cls: The class whose `model_fields` keys select the values to hash.

    Returns:
        A function taking the instance and returning an `int`.
    """
    if cls.model_fields:
        pick_fields = operator.itemgetter(*cls.model_fields.keys())
    else:
        # `operator.itemgetter()` requires at least one key, so a class without fields
        # hashes a constant instead.
        def pick_fields(_: Any) -> int:
            return 0

    def hash_func(self: Any) -> int:
        instance_dict = self.__dict__
        try:
            return hash(pick_fields(instance_dict))
        except KeyError:
            # In rare cases (such as when using the deprecated copy method), the __dict__ may not
            # contain all model fields, which is how we can get here.
            # Plain item lookup is much faster than any 'safe' method that accounts for missing
            # keys, and wrapping it in a `try` doesn't slow things down much in the common case,
            # so the tolerant wrapper is only used on this fallback path.
            return hash(pick_fields(FallbackDict(instance_dict)))  # type: ignore

    return hash_func


class FallbackDict:
    def __init__(self, inner: Any) -> None:
        """Store `inner`, the object whose `.get()` method `__getitem__` delegates to."""
        self.inner = inner

namespace['__hash__'] = hash_func
    def __getitem__(self, key: Any) -> Any:
        """Return `self.inner.get(key)` for subscript access (`obj[key]`).

        When `inner` is a plain `dict`, an absent key yields `None` rather than raising
        `KeyError`.
        """
        return self.inner.get(key)


def set_model_fields(
Expand Down
82 changes: 63 additions & 19 deletions tests/test_main.py
Expand Up @@ -607,6 +607,16 @@ class TestModel(BaseModel):
assert "unhashable type: 'list'" in exc_info.value.args[0]


def test_hash_function_empty_model():
    """Two instances of a frozen model that declares no fields are equal and hash the same."""

    class TestModel(BaseModel):
        model_config = ConfigDict(frozen=True)

    first, second = TestModel(), TestModel()
    assert first == second
    assert hash(first) == hash(second)


def test_hash_function_give_different_result_for_different_object():
class TestModel(BaseModel):
model_config = ConfigDict(frozen=True)
Expand All @@ -619,13 +629,45 @@ class TestModel(BaseModel):
assert hash(m) == hash(m2)
assert hash(m) != hash(m3)

# Redefined `TestModel`

def test_hash_function_works_when_instance_dict_modified():
class TestModel(BaseModel):
model_config = ConfigDict(frozen=True)
a: int = 10

m4 = TestModel()
assert hash(m) != hash(m4)
a: int
b: int

m = TestModel(a=1, b=2)
h = hash(m)

# Test edge cases where __dict__ is modified
# @functools.cached_property can add keys to __dict__, these should be ignored.
m.__dict__['c'] = 1
assert hash(m) == h

# Order of keys can be changed, e.g. with the deprecated copy method, which shouldn't matter.
m.__dict__ = {'b': 2, 'a': 1}
assert hash(m) == h

# Keys can be missing, e.g. when using the deprecated copy method.
# This should change the hash, and more importantly hashing shouldn't raise a KeyError.
del m.__dict__['a']
assert h != hash(m)


def test_default_hash_function_overrides_default_hash_function():
    """A frozen subclass of a frozen model gets its own `__hash__`, sensitive to its added field."""

    class A(BaseModel):
        model_config = ConfigDict(frozen=True)

        x: int

    class B(A):
        model_config = ConfigDict(frozen=True)

        y: int

    assert A.__hash__ != B.__hash__

    # Only adjacent pairs are compared, exactly as a chained `!=` would do.
    parent_hash = hash(A(x=1))
    child_hash_y2 = hash(B(x=1, y=2))
    assert parent_hash != child_hash_y2

    child_hash_y3 = hash(B(x=1, y=3))
    assert child_hash_y2 != child_hash_y3


def test_hash_method_is_inherited_for_frozen_models():
Expand Down Expand Up @@ -2269,7 +2311,7 @@ class Model(BaseModel):
def test_model_equality_generics():
T = TypeVar('T')

class GenericModel(BaseModel, Generic[T]):
class GenericModel(BaseModel, Generic[T], frozen=True):
x: T

class ConcreteModel(BaseModel):
Expand All @@ -2282,20 +2324,22 @@ class ConcreteModel(BaseModel):
assert GenericModel(x=1) != GenericModel(x=2)

S = TypeVar('S')
assert GenericModel(x=1) == GenericModel(x=1)
assert GenericModel(x=1) == GenericModel[S](x=1)
assert GenericModel(x=1) == GenericModel[Any](x=1)
assert GenericModel(x=1) == GenericModel[float](x=1)

assert GenericModel[int](x=1) == GenericModel[int](x=1)
assert GenericModel[int](x=1) == GenericModel[S](x=1)
assert GenericModel[int](x=1) == GenericModel[Any](x=1)
assert GenericModel[int](x=1) == GenericModel[float](x=1)

# Test that it works with nesting as well
nested_any = GenericModel[GenericModel[Any]](x=GenericModel[Any](x=1))
nested_int = GenericModel[GenericModel[int]](x=GenericModel[int](x=1))
assert nested_any == nested_int
models = [
GenericModel(x=1),
GenericModel[S](x=1),
GenericModel[Any](x=1),
GenericModel[int](x=1),
GenericModel[float](x=1),
]
for m1 in models:
for m2 in models:
# Test that it works with nesting as well
m3 = GenericModel[type(m1)](x=m1)
m4 = GenericModel[type(m2)](x=m2)
assert m1 == m2
assert m3 == m4
assert hash(m1) == hash(m2)
assert hash(m3) == hash(m4)


def test_model_validate_strict() -> None:
Expand Down

0 comments on commit f23578e

Please sign in to comment.