Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only compare pydantic fields in BaseModel.__eq__ instead of whole __dict__ #7825

Merged
merged 14 commits into from Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 2 additions & 10 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -25,7 +25,7 @@
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute
from ._utils import ClassAttribute, SafeGetItemProxy
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -422,19 +422,11 @@ def hash_func(self: Any) -> int:
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(FallbackDict(self.__dict__))) # type: ignore
return hash(getter(SafeGetItemProxy(self.__dict__)))

return hash_func


class FallbackDict:
def __init__(self, inner):
self.inner = inner

def __getitem__(self, key):
return self.inner.get(key)


def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
Expand Down
33 changes: 30 additions & 3 deletions pydantic/_internal/_utils.py
Expand Up @@ -4,14 +4,15 @@
"""
from __future__ import annotations as _annotations

import dataclasses
import keyword
import typing
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, TypeVar
from typing import Any, Mapping, TypeVar

from typing_extensions import TypeAlias, TypeGuard

Expand Down Expand Up @@ -317,7 +318,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
return deepcopy(obj) # slowest way when we actually might need a deepcopy


_EMPTY = object()
_SENTINEL = object()


def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
Expand All @@ -329,7 +330,33 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
if left_item is not right_item:
return False
return True


@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default

This makes is safe to use in `operator.itemgetter` when some keys may be missing
"""

# Define __slots__manually for performances
# @dataclasses.dataclass() only support slots=True in python>=3.10
__slots__ = ('wrapped',)

wrapped: Mapping[str, Any]

def __getitem__(self, __key: str) -> Any:
return self.wrapped.get(__key, _SENTINEL)

# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:

def __contains__(self, __key: str) -> bool:
return self.wrapped.__contains__(__key)
45 changes: 42 additions & 3 deletions pydantic/main.py
@@ -1,6 +1,7 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import operator
import sys
import types
import typing
Expand Down Expand Up @@ -868,12 +869,50 @@ def __eq__(self, other: Any) -> bool:
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

return (
# Perform common checks first
if not (
self_type == other_type
and self.__dict__ == other.__dict__
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
)
):
return False

# We only want to compare pydantic fields but ignoring fields is costly.
# We'll perform a fast check first, and fallback only when needed
# See GH-7444 and GH-7825 for rationale and a performance benchmark

# First, do the fast (and sometimes faulty) __dict__ comparison
if self.__dict__ == other.__dict__:
# If the check above passes, then pydantic fields are equal, we can return early
return True

# We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return
# early if there are no keys to ignore (we would just return False later on anyway)
model_fields = type(self).model_fields.keys()
if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields:
QuentinSoubeyranAqemia marked this conversation as resolved.
Show resolved Hide resolved
return False

# If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore
# Resort to costly filtering of the __dict__ objects
# We use operator.itemgetter because it is much faster than dict comprehensions
# NOTE: Contrary to standard python class and instances, when the Model class has a default value for an
# attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute
# raises an error in BaseModel.__getattr__ instead of returning the class attribute
# So we can use operator.itemgetter() instead of operator.attrgetter()
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
# In rare cases (such as when using the deprecated BaseModel.copy() method),
# the __dict__ may not contain all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts
# for missing keys, and wrapping it in a `try` doesn't slow things down much
# in the common case.
self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__)
other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__)
return getter(self_fields_proxy) == getter(other_fields_proxy)

# other instance is not a BaseModel
else:
return NotImplemented # delegate to the other item in the comparison

Expand Down