Skip to content

Commit

Permalink
implement fix
Browse files Browse the repository at this point in the history
  • Loading branch information
QuentinSoubeyranAqemia committed Oct 16, 2023
1 parent 6fb10b1 commit 65837da
Showing 1 changed file with 150 additions and 29 deletions.
179 changes: 150 additions & 29 deletions pydantic/main.py
@@ -1,11 +1,13 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import enum
import operator
import types
import typing
import warnings
from copy import copy, deepcopy
from typing import Any, ClassVar
from typing import Any, ClassVar, Generic, Mapping, TypeVar

import pydantic_core
import typing_extensions
Expand All @@ -28,7 +30,13 @@
from .config import ConfigDict
from .errors import PydanticUndefinedAnnotation, PydanticUserError
from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr
from .json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue, model_json_schema
from .json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
JsonSchemaValue,
model_json_schema,
)
from .warnings import PydanticDeprecatedSince20

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -56,6 +64,42 @@

_object_setattr = _model_construction.object_setattr

_K = TypeVar('_K')
_V = TypeVar('_V')


# We need a sentinel value for missing fields when comparing models
# Models are equals if-and-only-if they miss the same fields, and since None is a legitimate value
# we can't default to None
# We use the single-value enum trick to allow correct typing when using a sentinel
class _SentinelType(enum.Enum):
SENTINEL = enum.auto()


_SENTINEL = _SentinelType.SENTINEL


class _SafeGetItemProxy(Generic[_K, _V]):
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes is safe to use in `operator.itemgetter` when some keys may be missing
"""

wrapped: Mapping[_K, _V]

def __init__(self, __mapping: Mapping[_K, _V], /) -> None:
self.wrapped = __mapping

def __getitem__(self, __key: _K) -> _SentinelType | _V:
return self.wrapped.get(__key, _SENTINEL)

# required to pass the proxy to operator.itemgetter instances
def __contains__(self, __key: _K) -> bool:
return self.wrapped.__contains__(__key)

def __repr__(self) -> str:
return f'{type(self).__name__}({self.wrapped!r})'


class BaseModel(metaclass=_model_construction.ModelMetaclass):
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/models/
Expand Down Expand Up @@ -383,7 +427,11 @@ def model_json_schema(
The JSON schema for the given model class.
"""
return model_json_schema(
cls, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator, mode=mode
cls,
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)

@classmethod
Expand Down Expand Up @@ -718,7 +766,10 @@ def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
_object_setattr(
m,
'__pydantic_private__',
deepcopy({k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined}, memo=memo),
deepcopy(
{k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined},
memo=memo,
),
)

return m
Expand Down Expand Up @@ -866,12 +917,41 @@ def __eq__(self, other: Any) -> bool:
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

return (
# Perform common checks first
if not (
self_type == other_type
and self.__dict__ == other.__dict__
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
)
):
return False

# Fix GH-7444 by comparing only pydantic fields
# We provide a fast-path for performance: __dict__ comparison is *much* faster
# See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks
if self.__dict__ == other.__dict__:
# If the check above passes, then pydantic fields are equal, we can return early
return True
else:
# Else, we need to perform a more detailed, costlier comparison
# We use operator.itemgetter because it is much faster than dict comprehensions
# NOTE: Contratry to standard python class and instances, when the Model class has
# attribute default values and the model instance doesn't has a corresponding
# attribute, accessing the missing attribute raises an error in
# __getattr__ instance of returning the class attribute
# Thus, using operator.itemgetter() instead of operator.attrgetter() is valid
model_fields = type(self).model_fields.keys()
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _SENTINEL
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
# In rare cases (such as when using the deprecated BaseModel.copy() method),
# the __dict__ may not contain all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts
# for missing keys, and wrapping it in a `try` doesn't slow things down much
# in the common case.
self_fields_proxy = _SafeGetItemProxy(self.__dict__)
other_fields_proxy = _SafeGetItemProxy(other.__dict__)
return getter(self_fields_proxy) == getter(other_fields_proxy)
else:
return NotImplemented # delegate to the other item in the comparison

Expand Down Expand Up @@ -944,10 +1024,14 @@ def __str__(self) -> str:
# ##### Deprecated methods from v1 #####
@property
@typing_extensions.deprecated(
'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
category=PydanticDeprecatedSince20,
)
def __fields__(self) -> dict[str, FieldInfo]:
warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning)
warnings.warn(
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
DeprecationWarning,
)
return self.model_fields

@property
Expand All @@ -957,12 +1041,14 @@ def __fields__(self) -> dict[str, FieldInfo]:
)
def __fields_set__(self) -> set[str]:
warnings.warn(
'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.', DeprecationWarning
'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.',
DeprecationWarning,
)
return self.__pydantic_fields_set__

@typing_extensions.deprecated(
'The `dict` method is deprecated; use `model_dump` instead.', category=PydanticDeprecatedSince20
'The `dict` method is deprecated; use `model_dump` instead.',
category=PydanticDeprecatedSince20,
)
def dict( # noqa: D102
self,
Expand All @@ -985,7 +1071,8 @@ def dict( # noqa: D102
)

@typing_extensions.deprecated(
'The `json` method is deprecated; use `model_dump_json` instead.', category=PydanticDeprecatedSince20
'The `json` method is deprecated; use `model_dump_json` instead.',
category=PydanticDeprecatedSince20,
)
def json( # noqa: D102
self,
Expand Down Expand Up @@ -1018,10 +1105,14 @@ def json( # noqa: D102

@classmethod
@typing_extensions.deprecated(
'The `parse_obj` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20
'The `parse_obj` method is deprecated; use `model_validate` instead.',
category=PydanticDeprecatedSince20,
)
def parse_obj(cls: type[Model], obj: Any) -> Model: # noqa: D102
warnings.warn('The `parse_obj` method is deprecated; use `model_validate` instead.', DeprecationWarning)
warnings.warn(
'The `parse_obj` method is deprecated; use `model_validate` instead.',
DeprecationWarning,
)
return cls.model_validate(obj)

@classmethod
Expand Down Expand Up @@ -1122,20 +1213,26 @@ def from_orm(cls: type[Model], obj: Any) -> Model: # noqa: D102
)
if not cls.model_config.get('from_attributes', None):
raise PydanticUserError(
'You must set the config attribute `from_attributes=True` to use from_orm', code=None
'You must set the config attribute `from_attributes=True` to use from_orm',
code=None,
)
return cls.model_validate(obj)

@classmethod
@typing_extensions.deprecated(
'The `construct` method is deprecated; use `model_construct` instead.', category=PydanticDeprecatedSince20
'The `construct` method is deprecated; use `model_construct` instead.',
category=PydanticDeprecatedSince20,
)
def construct(cls: type[Model], _fields_set: set[str] | None = None, **values: Any) -> Model: # noqa: D102
warnings.warn('The `construct` method is deprecated; use `model_construct` instead.', DeprecationWarning)
warnings.warn(
'The `construct` method is deprecated; use `model_construct` instead.',
DeprecationWarning,
)
return cls.model_construct(_fields_set=_fields_set, **values)

@typing_extensions.deprecated(
'The copy method is deprecated; use `model_copy` instead.', category=PydanticDeprecatedSince20
'The copy method is deprecated; use `model_copy` instead.',
category=PydanticDeprecatedSince20,
)
def copy(
self: Model,
Expand Down Expand Up @@ -1179,7 +1276,12 @@ def copy(

values = dict(
copy_internals._iter(
self, to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False
self,
to_dict=False,
by_alias=False,
include=include,
exclude=exclude,
exclude_unset=False,
),
**(update or {}),
)
Expand Down Expand Up @@ -1213,12 +1315,16 @@ def copy(

@classmethod
@typing_extensions.deprecated(
'The `schema` method is deprecated; use `model_json_schema` instead.', category=PydanticDeprecatedSince20
'The `schema` method is deprecated; use `model_json_schema` instead.',
category=PydanticDeprecatedSince20,
)
def schema( # noqa: D102
cls, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE
) -> typing.Dict[str, Any]: # noqa UP006
warnings.warn('The `schema` method is deprecated; use `model_json_schema` instead.', DeprecationWarning)
warnings.warn(
'The `schema` method is deprecated; use `model_json_schema` instead.',
DeprecationWarning,
)
return cls.model_json_schema(by_alias=by_alias, ref_template=ref_template)

@classmethod
Expand All @@ -1227,7 +1333,11 @@ def schema( # noqa: D102
category=PydanticDeprecatedSince20,
)
def schema_json( # noqa: D102
cls, *, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE, **dumps_kwargs: Any
cls,
*,
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
**dumps_kwargs: Any,
) -> str: # pragma: no cover
import json

Expand All @@ -1245,10 +1355,14 @@ def schema_json( # noqa: D102

@classmethod
@typing_extensions.deprecated(
'The `validate` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20
'The `validate` method is deprecated; use `model_validate` instead.',
category=PydanticDeprecatedSince20,
)
def validate(cls: type[Model], value: Any) -> Model: # noqa: D102
warnings.warn('The `validate` method is deprecated; use `model_validate` instead.', DeprecationWarning)
warnings.warn(
'The `validate` method is deprecated; use `model_validate` instead.',
DeprecationWarning,
)
return cls.model_validate(value)

@classmethod
Expand All @@ -1258,17 +1372,22 @@ def validate(cls: type[Model], value: Any) -> Model: # noqa: D102
)
def update_forward_refs(cls, **localns: Any) -> None: # noqa: D102
warnings.warn(
'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.', DeprecationWarning
'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.',
DeprecationWarning,
)
if localns: # pragma: no cover
raise TypeError('`localns` arguments are not longer accepted.')
cls.model_rebuild(force=True)

@typing_extensions.deprecated(
'The private method `_iter` will be removed and should no longer be used.', category=PydanticDeprecatedSince20
'The private method `_iter` will be removed and should no longer be used.',
category=PydanticDeprecatedSince20,
)
def _iter(self, *args: Any, **kwargs: Any) -> Any:
warnings.warn('The private method `_iter` will be removed and should no longer be used.', DeprecationWarning)
warnings.warn(
'The private method `_iter` will be removed and should no longer be used.',
DeprecationWarning,
)

from .deprecated import copy_internals

Expand All @@ -1294,7 +1413,8 @@ def _copy_and_set_values(self, *args: Any, **kwargs: Any) -> Any:
)
def _get_value(cls, *args: Any, **kwargs: Any) -> Any:
warnings.warn(
'The private method `_get_value` will be removed and should no longer be used.', DeprecationWarning
'The private method `_get_value` will be removed and should no longer be used.',
DeprecationWarning,
)

from .deprecated import copy_internals
Expand All @@ -1307,7 +1427,8 @@ def _get_value(cls, *args: Any, **kwargs: Any) -> Any:
)
def _calculate_keys(self, *args: Any, **kwargs: Any) -> Any:
warnings.warn(
'The private method `_calculate_keys` will be removed and should no longer be used.', DeprecationWarning
'The private method `_calculate_keys` will be removed and should no longer be used.',
DeprecationWarning,
)

from .deprecated import copy_internals
Expand Down

0 comments on commit 65837da

Please sign in to comment.