Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better-support mypy strict equality flag #8799

Merged
merged 3 commits into from Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
110 changes: 56 additions & 54 deletions pydantic/main.py
Expand Up @@ -867,60 +867,62 @@ def __setstate__(self, state: dict[Any, Any]) -> None:
_object_setattr(self, '__pydantic_private__', state['__pydantic_private__'])
_object_setattr(self, '__dict__', state['__dict__'])

def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
# When comparing instances of generic types for equality, as long as all field values are equal,
# only require their generic origin types to be equal, rather than exact type equality.
# This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

# Perform common checks first
if not (
self_type == other_type
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
):
return False

# We only want to compare pydantic fields but ignoring fields is costly.
# We'll perform a fast check first, and fallback only when needed
# See GH-7444 and GH-7825 for rationale and a performance benchmark

# First, do the fast (and sometimes faulty) __dict__ comparison
if self.__dict__ == other.__dict__:
# If the check above passes, then pydantic fields are equal, we can return early
return True

# We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return
# early if there are no keys to ignore (we would just return False later on anyway)
model_fields = type(self).model_fields.keys()
if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields:
return False

# If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore
# Resort to costly filtering of the __dict__ objects
# We use operator.itemgetter because it is much faster than dict comprehensions
# NOTE: Contrary to standard python class and instances, when the Model class has a default value for an
# attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute
# raises an error in BaseModel.__getattr__ instead of returning the class attribute
# So we can use operator.itemgetter() instead of operator.attrgetter()
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
# In rare cases (such as when using the deprecated BaseModel.copy() method),
# the __dict__ may not contain all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts
# for missing keys, and wrapping it in a `try` doesn't slow things down much
# in the common case.
self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__)
other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__)
return getter(self_fields_proxy) == getter(other_fields_proxy)

# other instance is not a BaseModel
else:
return NotImplemented # delegate to the other item in the comparison
if not typing.TYPE_CHECKING:

def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
# When comparing instances of generic types for equality, as long as all field values are equal,
# only require their generic origin types to be equal, rather than exact type equality.
# This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

# Perform common checks first
if not (
self_type == other_type
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
):
return False

# We only want to compare pydantic fields but ignoring fields is costly.
# We'll perform a fast check first, and fallback only when needed
# See GH-7444 and GH-7825 for rationale and a performance benchmark

# First, do the fast (and sometimes faulty) __dict__ comparison
if self.__dict__ == other.__dict__:
# If the check above passes, then pydantic fields are equal, we can return early
return True

# We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return
# early if there are no keys to ignore (we would just return False later on anyway)
model_fields = type(self).model_fields.keys()
if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields:
return False

# If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore
# Resort to costly filtering of the __dict__ objects
# We use operator.itemgetter because it is much faster than dict comprehensions
# NOTE: Contrary to standard python class and instances, when the Model class has a default value for an
# attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute
# raises an error in BaseModel.__getattr__ instead of returning the class attribute
# So we can use operator.itemgetter() instead of operator.attrgetter()
getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
# In rare cases (such as when using the deprecated BaseModel.copy() method),
# the __dict__ may not contain all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts
# for missing keys, and wrapping it in a `try` doesn't slow things down much
# in the common case.
self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__)
other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__)
return getter(self_fields_proxy) == getter(other_fields_proxy)

# other instance is not a BaseModel
else:
return NotImplemented # delegate to the other item in the comparison

if typing.TYPE_CHECKING:
# We put `__init_subclass__` in a TYPE_CHECKING block because, even though we want the type-checking benefits
Expand Down
48 changes: 48 additions & 0 deletions tests/mypy/configs/pyproject-plugin-strict-equality.toml
@@ -0,0 +1,48 @@
[build-system]
requires = ["poetry>=0.12"]
build_backend = "poetry.masonry.api"

[tool.poetry]
name = "test"
version = "0.0.1"
readme = "README.md"
authors = [
"author@example.com"
]

[tool.poetry.dependencies]
python = "*"

[tool.pytest.ini_options]
addopts = "-v -p no:warnings"

[tool.mypy]
plugins = "pydantic.mypy"
ignore_missing_imports = true
warn_return_any = true
warn_unreachable = true
warn_unused_configs = true
follow_imports = "normal"
show_column_numbers = true
strict_optional = true
warn_redundant_casts = true
pretty = false
strict = true
warn_unused_ignores = true
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_untyped_decorators = false
strict_equality = true

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true


[[tool.mypy.overrides]]
module = [
'pydantic_core.*',
]
follow_imports = "skip"
11 changes: 11 additions & 0 deletions tests/mypy/modules/strict_equality.py
@@ -0,0 +1,11 @@
from pydantic import BaseModel


class User(BaseModel):
username: str


user = User(username='test')
print(user == 'test')
print(user.username == int('1'))
print(user.username == 'test')
@@ -0,0 +1,13 @@
from pydantic import BaseModel


class User(BaseModel):
username: str


user = User(username='test')
print(user == 'test')
# MYPY: error: Non-overlapping equality check (left operand type: "User", right operand type: "Literal['test']") [comparison-overlap]
print(user.username == int('1'))
# MYPY: error: Non-overlapping equality check (left operand type: "str", right operand type: "int") [comparison-overlap]
print(user.username == 'test')
8 changes: 7 additions & 1 deletion tests/mypy/test_mypy.py
Expand Up @@ -112,6 +112,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
('pyproject-default.toml', 'computed_fields.py'),
('pyproject-default.toml', 'with_config_decorator.py'),
('pyproject-plugin-no-strict-optional.toml', 'no_strict_optional.py'),
('pyproject-plugin-strict-equality.toml', 'strict_equality.py'),
]
)

Expand Down Expand Up @@ -192,6 +193,8 @@ def test_mypy_results(config_filename: str, python_filename: str, request: pytes
if test_config.existing is not None:
existing_output_code = test_config.existing.output_path.read_text()
print(f'Comparing output with {test_config.existing.output_path}')
else:
print('Expecting no mypy errors')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasted some time confused by the lack of this message


merged_output = merge_python_and_mypy_output(input_code, mypy_out)

Expand All @@ -202,6 +205,9 @@ def test_mypy_results(config_filename: str, python_filename: str, request: pytes
test_config.current.output_path.parent.mkdir(parents=True, exist_ok=True)
test_config.current.output_path.write_text(merged_output)
else:
print('**** Merged Output ****')
print(merged_output)
print('***********************')
Comment on lines +208 to +210
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was helpful for debugging, and I don't see much reason not to include it

assert existing_output_code is not None, 'No output file found, run `make test-mypy-update` to create it'
assert merged_output == existing_output_code
expected_returncode = get_expected_return_code(existing_output_code)
Expand Down Expand Up @@ -279,7 +285,7 @@ def merge_python_and_mypy_output(source_code: str, mypy_output: str) -> str:
if not line:
continue
try:
line_number, message = line.split(':', maxsplit=1)
line_number, message = re.split(r':(?:\d+:)?', line, maxsplit=1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is just a python 3.12 thing, but I was running into an issue where the messages had two numbers, like, 9:7: error: Non-overlapping equality check (left operand type: "User", right operand type: "Literal['test']") [comparison-overlap]. I think the second number is the character index. Considering this happened for me locally but I don't think it does in CI, it seems best to just prepare for it to be present in some cases and not others and split that off like this if present.

merged_lines.insert(int(line_number), (f'# MYPY: {message.strip()}', True))
except ValueError:
# This could happen due to lack of a ':' in `line`, or the pre-':' contents not being a number
Expand Down