Skip to content

Commit

Permalink
Only compare pydantic fields in BaseModel.__eq__ instead of whole `…
Browse files Browse the repository at this point in the history
…__dict__` (#7825)

Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
QuentinSoubeyranAqemia and dmontagu committed Nov 17, 2023
1 parent 575dcd9 commit 62ed0aa
Show file tree
Hide file tree
Showing 6 changed files with 723 additions and 18 deletions.
12 changes: 2 additions & 10 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -25,7 +25,7 @@
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute
from ._utils import ClassAttribute, SafeGetItemProxy
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -422,19 +422,11 @@ def hash_func(self: Any) -> int:
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(FallbackDict(self.__dict__))) # type: ignore
return hash(getter(SafeGetItemProxy(self.__dict__)))

return hash_func


class FallbackDict:
def __init__(self, inner):
self.inner = inner

def __getitem__(self, key):
return self.inner.get(key)


def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
Expand Down
33 changes: 30 additions & 3 deletions pydantic/_internal/_utils.py
Expand Up @@ -4,14 +4,15 @@
"""
from __future__ import annotations as _annotations

import dataclasses
import keyword
import typing
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, TypeVar
from typing import Any, Mapping, TypeVar

from typing_extensions import TypeAlias, TypeGuard

Expand Down Expand Up @@ -317,7 +318,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
return deepcopy(obj) # slowest way when we actually might need a deepcopy


_EMPTY = object()
_SENTINEL = object()


def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
Expand All @@ -329,7 +330,33 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
if left_item is not right_item:
return False
return True


@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes is safe to use in `operator.itemgetter` when some keys may be missing
"""

# Define __slots__manually for performances
# @dataclasses.dataclass() only support slots=True in python>=3.10
__slots__ = ('wrapped',)

wrapped: Mapping[str, Any]

def __getitem__(self, __key: str) -> Any:
return self.wrapped.get(__key, _SENTINEL)

# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:

def __contains__(self, __key: str) -> bool:
return self.wrapped.__contains__(__key)
45 changes: 42 additions & 3 deletions pydantic/main.py
@@ -1,6 +1,7 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import operator
import sys
import types
import typing
Expand Down Expand Up @@ -868,12 +869,50 @@ def __eq__(self, other: Any) -> bool:
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

return (
# Perform common checks first
if not (
self_type == other_type
and self.__dict__ == other.__dict__
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
)
):
return False

# We only want to compare pydantic fields but ignoring fields is costly.
# We'll perform a fast check first, and fallback only when needed
# See GH-7444 and GH-7825 for rationale and a performance benchmark

# First, do the fast (and sometimes faulty) __dict__ comparison
if self.__dict__ == other.__dict__:
# If the check above passes, then pydantic fields are equal, we can return early
return True

# We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return
# early if there are no keys to ignore (we would just return False later on anyway)
model_fields = type(self).model_fields.keys()
if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields:
return False

# If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore
# Resort to costly filtering of the __dict__ objects
# We use operator.itemgetter because it is much faster than dict comprehensions
# NOTE: Contrary to standard python class and instances, when the Model class has a default value for an
# attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute
# raises an error in BaseModel.__getattr__ instead of returning the class attribute
# So we can use operator.itemgetter() instead of operator.attrgetter()
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
# In rare cases (such as when using the deprecated BaseModel.copy() method),
# the __dict__ may not contain all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts
# for missing keys, and wrapping it in a `try` doesn't slow things down much
# in the common case.
self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__)
other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__)
return getter(self_fields_proxy) == getter(other_fields_proxy)

# other instance is not a BaseModel
else:
return NotImplemented # delegate to the other item in the comparison

Expand Down

0 comments on commit 62ed0aa

Please sign in to comment.