Skip to content

Commit

Permalink
chore: share implementation with pydantic#7786
Browse files Browse the repository at this point in the history
  • Loading branch information
QuentinSoubeyranAqemia committed Nov 15, 2023
1 parent 4639866 commit c756c1f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 52 deletions.
12 changes: 2 additions & 10 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -25,7 +25,7 @@
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute
from ._utils import ClassAttribute, SafeGetItemProxy
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -422,19 +422,11 @@ def hash_func(self: Any) -> int:
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(FallbackDict(self.__dict__))) # type: ignore
return hash(getter(SafeGetItemProxy(self.__dict__)))

return hash_func


class FallbackDict:
def __init__(self, inner):
self.inner = inner

def __getitem__(self, key):
return self.inner.get(key)


def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
Expand Down
42 changes: 41 additions & 1 deletion pydantic/_internal/_utils.py
Expand Up @@ -4,14 +4,16 @@
"""
from __future__ import annotations as _annotations

import dataclasses
import enum
import keyword
import typing
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, TypeVar
from typing import Any, Generic, Mapping, TypeVar

from typing_extensions import TypeAlias, TypeGuard

Expand Down Expand Up @@ -333,3 +335,41 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
if left_item is not right_item:
return False
return True


_KeyType = TypeVar('_KeyType')
_ValueType = TypeVar('_ValueType')


# We need a sentinel value for missing fields when comparing models
# Models are equals if-and-only-if they miss the same fields, and since None is a legitimate value
# we can't default to None
# We use the single-value enum trick to allow correct typing when using a sentinel
# https://github.com/python/typing/issues/689#issuecomment-561568944
class SentinelType(enum.Enum):
SENTINEL = enum.auto()


SENTINEL = SentinelType.SENTINEL


@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy(Generic[_KeyType, _ValueType]):
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes is safe to use in `operator.itemgetter` when some keys may be missing
"""

wrapped: Mapping[_KeyType, _ValueType]

def __getitem__(self, __key: _KeyType) -> SentinelType | _ValueType:
return self.wrapped.get(__key, SENTINEL)

# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:

def __contains__(self, __key: _KeyType) -> bool:
return self.wrapped.__contains__(__key)
45 changes: 4 additions & 41 deletions pydantic/main.py
@@ -1,14 +1,13 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import enum
import operator
import sys
import types
import typing
import warnings
from copy import copy, deepcopy
from typing import Any, ClassVar, Generic, Mapping, TypeVar
from typing import Any, ClassVar

import pydantic_core
import typing_extensions
Expand Down Expand Up @@ -58,42 +57,6 @@

_object_setattr = _model_construction.object_setattr

_K = TypeVar('_K')
_V = TypeVar('_V')


# We need a sentinel value for missing fields when comparing models
# Models are equals if-and-only-if they miss the same fields, and since None is a legitimate value
# we can't default to None
# We use the single-value enum trick to allow correct typing when using a sentinel
class _SentinelType(enum.Enum):
SENTINEL = enum.auto()


_SENTINEL = _SentinelType.SENTINEL


class _SafeGetItemProxy(Generic[_K, _V]):
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes is safe to use in `operator.itemgetter` when some keys may be missing
"""

wrapped: Mapping[_K, _V]

def __init__(self, __mapping: Mapping[_K, _V]) -> None:
self.wrapped = __mapping

def __getitem__(self, __key: _K) -> _SentinelType | _V:
return self.wrapped.get(__key, _SENTINEL)

# required to pass the proxy to operator.itemgetter instances
def __contains__(self, __key: _K) -> bool:
return self.wrapped.__contains__(__key)

def __repr__(self) -> str:
return f'{type(self).__name__}({self.wrapped!r})'


class BaseModel(metaclass=_model_construction.ModelMetaclass):
"""Usage docs: https://docs.pydantic.dev/2.5/concepts/models/
Expand Down Expand Up @@ -936,7 +899,7 @@ def __eq__(self, other: Any) -> bool:
# attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute
# raises an error in BaseModel.__getattr__ instead of returning the class attribute
# So we can use operator.itemgetter() instead of operator.attrgetter()
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _SENTINEL
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils.SENTINEL
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
Expand All @@ -945,8 +908,8 @@ def __eq__(self, other: Any) -> bool:
# getter(self.__dict__) is much faster than any 'safe' method that accounts
# for missing keys, and wrapping it in a `try` doesn't slow things down much
# in the common case.
self_fields_proxy = _SafeGetItemProxy(self.__dict__)
other_fields_proxy = _SafeGetItemProxy(other.__dict__)
self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__)
other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__)
return getter(self_fields_proxy) == getter(other_fields_proxy)

# other instance is not a BaseModel
Expand Down

0 comments on commit c756c1f

Please sign in to comment.