Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only hash model_fields, not whole __dict__ #7786

Merged
merged 6 commits into from Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 37 additions & 13 deletions pydantic/_internal/_model_construction.py
@@ -1,6 +1,7 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations

import operator
import typing
import warnings
import weakref
Expand Down Expand Up @@ -123,9 +124,6 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}

if config_wrapper.frozen:
set_default_hash_func(namespace, bases)

cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore

from ..main import BaseModel
Expand Down Expand Up @@ -181,6 +179,10 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:

types_namespace = get_cls_types_namespace(cls, parent_namespace)
set_model_fields(cls, bases, config_wrapper, types_namespace)

if config_wrapper.frozen and '__hash__' not in namespace:
set_default_hash_func(cls, bases)

complete_model_class(
cls,
cls_name,
Expand Down Expand Up @@ -388,19 +390,41 @@ def inspect_namespace( # noqa C901
return private_attributes


def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
if '__hash__' in namespace:
return

def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
base_hash_func = get_attribute_from_bases(bases, '__hash__')
if base_hash_func in {None, object.__hash__}:
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
new_hash_func = make_hash_func(cls)
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
# If `__hash__` is some default, we generate a hash function.
# It will be `None` if not overridden from BaseModel.
# It may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
# In the last case we still need a new hash function to account for new `model_fields`.
cls.__hash__ = new_hash_func


def make_hash_func(cls: type[BaseModel]) -> Any:
getter = operator.itemgetter(*cls.model_fields.keys()) if cls.model_fields else lambda _: 0

def hash_func(self: Any) -> int:
try:
return hash(getter(self.__dict__))
except KeyError:
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(FallbackDict(self.__dict__))) # type: ignore

return hash_func


class FallbackDict:
def __init__(self, inner):
self.inner = inner

namespace['__hash__'] = hash_func
def __getitem__(self, key):
return self.inner.get(key)


def set_model_fields(
Expand Down
82 changes: 63 additions & 19 deletions tests/test_main.py
Expand Up @@ -590,6 +590,16 @@ class TestModel(BaseModel):
assert "unhashable type: 'list'" in exc_info.value.args[0]


def test_hash_function_empty_model():
class TestModel(BaseModel):
model_config = ConfigDict(frozen=True)

m = TestModel()
m2 = TestModel()
assert m == m2
assert hash(m) == hash(m2)


def test_hash_function_give_different_result_for_different_object():
class TestModel(BaseModel):
model_config = ConfigDict(frozen=True)
Expand All @@ -602,13 +612,45 @@ class TestModel(BaseModel):
assert hash(m) == hash(m2)
assert hash(m) != hash(m3)

# Redefined `TestModel`

def test_hash_function_works_when_instance_dict_modified():
class TestModel(BaseModel):
model_config = ConfigDict(frozen=True)
a: int = 10

m4 = TestModel()
assert hash(m) != hash(m4)
a: int
b: int

m = TestModel(a=1, b=2)
h = hash(m)

# Test edge cases where __dict__ is modified
# @functools.cached_property can add keys to __dict__, these should be ignored.
m.__dict__['c'] = 1
assert hash(m) == h

# Order of keys can be changed, e.g. with the deprecated copy method, which shouldn't matter.
m.__dict__ = {'b': 2, 'a': 1}
assert hash(m) == h

# Keys can be missing, e.g. when using the deprecated copy method.
# This should change the hash, and more importantly hashing shouldn't raise a KeyError.
del m.__dict__['a']
assert h != hash(m)


def test_default_hash_function_overrides_default_hash_function():
class A(BaseModel):
model_config = ConfigDict(frozen=True)

x: int

class B(A):
model_config = ConfigDict(frozen=True)

y: int

assert A.__hash__ != B.__hash__
assert hash(A(x=1)) != hash(B(x=1, y=2)) != hash(B(x=1, y=3))


def test_hash_method_is_inherited_for_frozen_models():
Expand Down Expand Up @@ -2220,7 +2262,7 @@ class Model(BaseModel):
def test_model_equality_generics():
T = TypeVar('T')

class GenericModel(BaseModel, Generic[T]):
class GenericModel(BaseModel, Generic[T], frozen=True):
x: T

class ConcreteModel(BaseModel):
Expand All @@ -2233,20 +2275,22 @@ class ConcreteModel(BaseModel):
assert GenericModel(x=1) != GenericModel(x=2)

S = TypeVar('S')
assert GenericModel(x=1) == GenericModel(x=1)
assert GenericModel(x=1) == GenericModel[S](x=1)
assert GenericModel(x=1) == GenericModel[Any](x=1)
assert GenericModel(x=1) == GenericModel[float](x=1)

assert GenericModel[int](x=1) == GenericModel[int](x=1)
assert GenericModel[int](x=1) == GenericModel[S](x=1)
assert GenericModel[int](x=1) == GenericModel[Any](x=1)
assert GenericModel[int](x=1) == GenericModel[float](x=1)

# Test that it works with nesting as well
nested_any = GenericModel[GenericModel[Any]](x=GenericModel[Any](x=1))
nested_int = GenericModel[GenericModel[int]](x=GenericModel[int](x=1))
assert nested_any == nested_int
models = [
GenericModel(x=1),
GenericModel[S](x=1),
GenericModel[Any](x=1),
GenericModel[int](x=1),
GenericModel[float](x=1),
]
for m1 in models:
for m2 in models:
# Test that it works with nesting as well
m3 = GenericModel[type(m1)](x=m1)
m4 = GenericModel[type(m2)](x=m2)
assert m1 == m2
assert m3 == m4
assert hash(m1) == hash(m2)
assert hash(m3) == hash(m4)


def test_model_validate_strict() -> None:
Expand Down