Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#418: Modify B017 to no longer have a false negative when raises() is imported directly from pytest #424

Merged
merged 6 commits into from Nov 26, 2023
28 changes: 23 additions & 5 deletions bugbear.py
Expand Up @@ -547,6 +547,9 @@ def visit_Import(self, node):
self.check_for_b005(node)
self.generic_visit(node)

def visit_ImportFrom(self, node):
self.visit_Import(node)

def visit_Set(self, node):
self.check_for_b033(node)
self.generic_visit(node)
Expand All @@ -555,6 +558,9 @@ def check_for_b005(self, node):
if isinstance(node, ast.Import):
for name in node.names:
self._b005_imports.add(name.asname or name.name)
elif isinstance(node, ast.ImportFrom):
for name in node.names:
self._b005_imports.add(f"{node.module}.{name.name or name.asname}")
elif isinstance(node, ast.Call):
if node.func.attr not in B005.methods:
return # method name doesn't match
Expand Down Expand Up @@ -652,13 +658,25 @@ def check_for_b017(self, node):

if (
hasattr(item_context, "func")
and isinstance(item_context.func, ast.Attribute)
and (
item_context.func.attr == "assertRaises"
(
isinstance(item_context.func, ast.Attribute)
and (
item_context.func.attr == "assertRaises"
or (
item_context.func.attr == "raises"
and isinstance(item_context.func.value, ast.Name)
and item_context.func.value.id == "pytest"
and "match"
not in [kwd.arg for kwd in item_context.keywords]
)
)
)
or (
item_context.func.attr == "raises"
and isinstance(item_context.func.value, ast.Name)
and item_context.func.value.id == "pytest"
isinstance(item_context.func, ast.Name)
and item_context.func.id == "raises"
and isinstance(item_context.func.ctx, ast.Load)
and "pytest.raises" in self._b005_imports
and "match" not in [kwd.arg for kwd in item_context.keywords]
)
)
Expand Down
12 changes: 12 additions & 0 deletions tests/b017.py
Expand Up @@ -7,6 +7,7 @@
import unittest

import pytest
from pytest import raises

CONSTANT = True

Expand All @@ -28,31 +29,42 @@ def evil_raises(self) -> None:
raise Exception("Evil I say!")
with pytest.raises(Exception):
raise Exception("Evil I say!")
with raises(Exception):
raise Exception("Evil I say!")
# These are evil as well but we are only testing inside a with statement
self.assertRaises(Exception, lambda x, y: x / y, 1, y=0)
pytest.raises(Exception, lambda x, y: x / y, 1, y=0)
raises(Exception, lambda x, y: x / y, 1, y=0)

def context_manager_raises(self) -> None:
with self.assertRaises(Exception) as ex:
raise Exception("Context manager is good")
with pytest.raises(Exception) as pyt_ex:
raise Exception("Context manager is good")
with raises(Exception) as r_ex:
raise Exception("Context manager is good")

self.assertEqual("Context manager is good", str(ex.exception))
self.assertEqual("Context manager is good", str(pyt_ex.value))
self.assertEqual("Context manager is good", str(r_ex.value))

def regex_raises(self) -> None:
with self.assertRaisesRegex(Exception, "Regex is good"):
raise Exception("Regex is good")
with pytest.raises(Exception, match="Regex is good"):
raise Exception("Regex is good")
with raises(Exception, match="Regex is good"):
raise Exception("Regex is good")

def non_context_manager_raises(self) -> None:
self.assertRaises(ZeroDivisionError, lambda x, y: x / y, 1, y=0)
pytest.raises(ZeroDivisionError, lambda x, y: x / y, 1, y=0)
raises(ZeroDivisionError, lambda x, y: x / y, 1, y=0)

def raises_with_absolute_reference(self):
with self.assertRaises(asyncio.CancelledError):
Foo()
with pytest.raises(asyncio.CancelledError):
Foo()
with raises(asyncio.CancelledError):
Foo()
4 changes: 3 additions & 1 deletion tests/test_bugbear.py
Expand Up @@ -258,7 +258,7 @@ def test_b017(self):
filename = Path(__file__).absolute().parent / "b017.py"
bbc = BugBearChecker(filename=str(filename))
errors = list(bbc.run())
expected = self.errors(B017(25, 8), B017(27, 8), B017(29, 8))
expected = self.errors(B017(26, 8), B017(28, 8), B017(30, 8), B017(32, 8))
self.assertEqual(errors, expected)

def test_b018_functions(self):
Expand Down Expand Up @@ -530,9 +530,11 @@ def test_b908(self):
B908(15, 8),
B908(21, 8),
B908(27, 8),
B017(37, 0),
B908(37, 0),
B908(41, 0),
B908(45, 0),
B017(56, 0),
)
self.assertEqual(errors, expected)

Expand Down