Skip to content

Commit

Permalink
Better-support mypy strict equality flag (#8799)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 13, 2024
1 parent 14ce997 commit 3c6bddb
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 55 deletions.
110 changes: 56 additions & 54 deletions pydantic/main.py
Expand Up @@ -867,60 +867,62 @@ def __setstate__(self, state: dict[Any, Any]) -> None:
_object_setattr(self, '__pydantic_private__', state['__pydantic_private__'])
_object_setattr(self, '__dict__', state['__dict__'])

def __eq__(self, other: Any) -> bool:
    """Compare two models by generic origin type, private attributes, extras and field values.

    Returns `NotImplemented` when `other` is not a `BaseModel`, so Python can
    delegate to the other operand of the comparison.
    """
    if not isinstance(other, BaseModel):
        return NotImplemented  # delegate to the other item in the comparison

    # Parametrized generics are compared through their origin class so that
    # MyGeneric(x=1) and MyGeneric[Any](x=1) are not considered different types.
    lhs_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
    rhs_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

    # Checks shared by every path below: type, private attributes, extra values.
    shared_state_matches = (
        lhs_type == rhs_type
        and self.__pydantic_private__ == other.__pydantic_private__
        and self.__pydantic_extra__ == other.__pydantic_extra__
    )
    if not shared_state_matches:
        return False

    # Only pydantic fields should take part in the comparison, but filtering
    # `__dict__` is costly, so a fast check runs first and the filtering is a
    # fallback (see GH-7444 and GH-7825 for rationale and a benchmark).
    lhs_dict = self.__dict__
    rhs_dict = other.__dict__

    # Fast path: identical `__dict__` contents imply identical field values.
    if lhs_dict == rhs_dict:
        return True

    # If neither `__dict__` holds keys beyond the model fields there is nothing
    # to ignore, so the inequality found above is the final answer.
    field_names = type(self).model_fields.keys()
    if lhs_dict.keys() <= field_names and rhs_dict.keys() <= field_names:
        return False

    # Non-field keys with unequal values are present and must be ignored.
    # operator.itemgetter is used because it is much faster than a dict
    # comprehension. Item access (rather than attribute access) is sufficient
    # because BaseModel.__getattr__ raises for a field missing on the instance
    # instead of falling back to the class-level default.
    pick_fields = operator.itemgetter(*field_names) if field_names else lambda _: _utils._SENTINEL
    try:
        return pick_fields(lhs_dict) == pick_fields(rhs_dict)
    except KeyError:
        # Rare case (e.g. after the deprecated BaseModel.copy()): `__dict__`
        # may lack some model fields. The plain lookup above is kept as the
        # common path because it is much faster than any lookup that tolerates
        # missing keys; the proxies are only built when it fails.
        safe_lhs = _utils.SafeGetItemProxy(lhs_dict)
        safe_rhs = _utils.SafeGetItemProxy(rhs_dict)
        return pick_fields(safe_lhs) == pick_fields(safe_rhs)
if not typing.TYPE_CHECKING:
    # `__eq__` is defined only at runtime, hidden from static type checkers.
    # NOTE(review): per the commit title, this is what lets mypy's
    # `strict_equality` flag report non-overlapping comparisons such as
    # `model == 'some string'` -- confirm against the mypy docs.

    def __eq__(self, other: Any) -> bool:
        """Compare two models by generic origin type, private attributes, extras and field values.

        Returns `NotImplemented` when `other` is not a `BaseModel`, so Python
        can delegate to the other operand of the comparison.
        """
        if not isinstance(other, BaseModel):
            return NotImplemented  # delegate to the other item in the comparison

        # Parametrized generics are compared through their origin class so that
        # MyGeneric(x=1) and MyGeneric[Any](x=1) are not considered different types.
        lhs_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
        rhs_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

        # Checks shared by every path below: type, private attributes, extra values.
        shared_state_matches = (
            lhs_type == rhs_type
            and self.__pydantic_private__ == other.__pydantic_private__
            and self.__pydantic_extra__ == other.__pydantic_extra__
        )
        if not shared_state_matches:
            return False

        # Only pydantic fields should take part in the comparison, but filtering
        # `__dict__` is costly, so a fast check runs first and the filtering is
        # a fallback (see GH-7444 and GH-7825 for rationale and a benchmark).
        lhs_dict = self.__dict__
        rhs_dict = other.__dict__

        # Fast path: identical `__dict__` contents imply identical field values.
        if lhs_dict == rhs_dict:
            return True

        # If neither `__dict__` holds keys beyond the model fields there is
        # nothing to ignore, so the inequality found above is the final answer.
        field_names = type(self).model_fields.keys()
        if lhs_dict.keys() <= field_names and rhs_dict.keys() <= field_names:
            return False

        # Non-field keys with unequal values are present and must be ignored.
        # operator.itemgetter is used because it is much faster than a dict
        # comprehension. Item access (rather than attribute access) is
        # sufficient because BaseModel.__getattr__ raises for a field missing
        # on the instance instead of falling back to the class-level default.
        pick_fields = operator.itemgetter(*field_names) if field_names else lambda _: _utils._SENTINEL
        try:
            return pick_fields(lhs_dict) == pick_fields(rhs_dict)
        except KeyError:
            # Rare case (e.g. after the deprecated BaseModel.copy()): `__dict__`
            # may lack some model fields. The plain lookup above is kept as the
            # common path because it is much faster than any lookup that
            # tolerates missing keys; the proxies are only built when it fails.
            safe_lhs = _utils.SafeGetItemProxy(lhs_dict)
            safe_rhs = _utils.SafeGetItemProxy(rhs_dict)
            return pick_fields(safe_lhs) == pick_fields(safe_rhs)

if typing.TYPE_CHECKING:
# We put `__init_subclass__` in a TYPE_CHECKING block because, even though we want the type-checking benefits
Expand Down
48 changes: 48 additions & 0 deletions tests/mypy/configs/pyproject-plugin-strict-equality.toml
@@ -0,0 +1,48 @@
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"

[tool.poetry]
name = "test"
version = "0.0.1"
readme = "README.md"
authors = [
"author@example.com"
]

[tool.poetry.dependencies]
python = "*"

[tool.pytest.ini_options]
addopts = "-v -p no:warnings"

[tool.mypy]
plugins = "pydantic.mypy"
ignore_missing_imports = true
warn_return_any = true
warn_unreachable = true
warn_unused_configs = true
follow_imports = "normal"
show_column_numbers = true
strict_optional = true
warn_redundant_casts = true
pretty = false
strict = true
warn_unused_ignores = true
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_untyped_decorators = false
strict_equality = true

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true


[[tool.mypy.overrides]]
module = [
'pydantic_core.*',
]
follow_imports = "skip"
11 changes: 11 additions & 0 deletions tests/mypy/modules/strict_equality.py
@@ -0,0 +1,11 @@
from pydantic import BaseModel


class User(BaseModel):
username: str


user = User(username='test')
print(user == 'test')
print(user.username == int('1'))
print(user.username == 'test')
@@ -0,0 +1,13 @@
from pydantic import BaseModel


class User(BaseModel):
username: str


user = User(username='test')
print(user == 'test')
# MYPY: error: Non-overlapping equality check (left operand type: "User", right operand type: "Literal['test']") [comparison-overlap]
print(user.username == int('1'))
# MYPY: error: Non-overlapping equality check (left operand type: "str", right operand type: "int") [comparison-overlap]
print(user.username == 'test')
8 changes: 7 additions & 1 deletion tests/mypy/test_mypy.py
Expand Up @@ -112,6 +112,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
('pyproject-default.toml', 'computed_fields.py'),
('pyproject-default.toml', 'with_config_decorator.py'),
('pyproject-plugin-no-strict-optional.toml', 'no_strict_optional.py'),
('pyproject-plugin-strict-equality.toml', 'strict_equality.py'),
]
)

Expand Down Expand Up @@ -192,6 +193,8 @@ def test_mypy_results(config_filename: str, python_filename: str, request: pytes
if test_config.existing is not None:
existing_output_code = test_config.existing.output_path.read_text()
print(f'Comparing output with {test_config.existing.output_path}')
else:
print('Expecting no mypy errors')

merged_output = merge_python_and_mypy_output(input_code, mypy_out)

Expand All @@ -202,6 +205,9 @@ def test_mypy_results(config_filename: str, python_filename: str, request: pytes
test_config.current.output_path.parent.mkdir(parents=True, exist_ok=True)
test_config.current.output_path.write_text(merged_output)
else:
print('**** Merged Output ****')
print(merged_output)
print('***********************')
assert existing_output_code is not None, 'No output file found, run `make test-mypy-update` to create it'
assert merged_output == existing_output_code
expected_returncode = get_expected_return_code(existing_output_code)
Expand Down Expand Up @@ -279,7 +285,7 @@ def merge_python_and_mypy_output(source_code: str, mypy_output: str) -> str:
if not line:
continue
try:
line_number, message = line.split(':', maxsplit=1)
line_number, message = re.split(r':(?:\d+:)?', line, maxsplit=1)
merged_lines.insert(int(line_number), (f'# MYPY: {message.strip()}', True))
except ValueError:
# This could happen due to lack of a ':' in `line`, or the pre-':' contents not being a number
Expand Down

0 comments on commit 3c6bddb

Please sign in to comment.