From 393fea1e2d5b2d0612a820e2706f850a1ebc2800 Mon Sep 17 00:00:00 2001 From: azinneck0485 <123660683+azinneck0485@users.noreply.github.com> Date: Sat, 25 Nov 2023 19:10:32 -0500 Subject: [PATCH] #418: Modify B017 to no longer have a false negative when raises() is imported directly from pytest (#424) * handle ast.ImportFrom nodes in b005 check to add this to the list of imports * work out logic to determine if ast.Name node "raises" correspondes to the pytest.raises package * integrate new logic for B017 to the check itself * add check for matches keyword to prevent B017; whitespace corrections to pass tests * update tests to capture changes to B017 logic * files modified by pre-commit --- bugbear.py | 28 +++++++++++++++++++++++----- tests/b017.py | 12 ++++++++++++ tests/test_bugbear.py | 4 +++- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/bugbear.py b/bugbear.py index 5fbfa19..ee2afd4 100644 --- a/bugbear.py +++ b/bugbear.py @@ -547,6 +547,9 @@ def visit_Import(self, node): self.check_for_b005(node) self.generic_visit(node) + def visit_ImportFrom(self, node): + self.visit_Import(node) + def visit_Set(self, node): self.check_for_b033(node) self.generic_visit(node) @@ -555,6 +558,9 @@ def check_for_b005(self, node): if isinstance(node, ast.Import): for name in node.names: self._b005_imports.add(name.asname or name.name) + elif isinstance(node, ast.ImportFrom): + for name in node.names: + self._b005_imports.add(f"{node.module}.{name.name or name.asname}") elif isinstance(node, ast.Call): if node.func.attr not in B005.methods: return # method name doesn't match @@ -652,13 +658,25 @@ def check_for_b017(self, node): if ( hasattr(item_context, "func") - and isinstance(item_context.func, ast.Attribute) and ( - item_context.func.attr == "assertRaises" + ( + isinstance(item_context.func, ast.Attribute) + and ( + item_context.func.attr == "assertRaises" + or ( + item_context.func.attr == "raises" + and isinstance(item_context.func.value, ast.Name) + and item_context.func.value.id == "pytest" + and "match" + not in [kwd.arg for kwd in item_context.keywords] + ) + ) + ) or ( - item_context.func.attr == "raises" - and isinstance(item_context.func.value, ast.Name) - and item_context.func.value.id == "pytest" + isinstance(item_context.func, ast.Name) + and item_context.func.id == "raises" + and isinstance(item_context.func.ctx, ast.Load) + and "pytest.raises" in self._b005_imports and "match" not in [kwd.arg for kwd in item_context.keywords] ) ) diff --git a/tests/b017.py b/tests/b017.py index 2365cd1..af52f6d 100644 --- a/tests/b017.py +++ b/tests/b017.py @@ -7,6 +7,7 @@ import unittest import pytest +from pytest import raises CONSTANT = True @@ -28,31 +29,42 @@ def evil_raises(self) -> None: raise Exception("Evil I say!") with pytest.raises(Exception): raise Exception("Evil I say!") + with raises(Exception): + raise Exception("Evil I say!") # These are evil as well but we are only testing inside a with statement self.assertRaises(Exception, lambda x, y: x / y, 1, y=0) pytest.raises(Exception, lambda x, y: x / y, 1, y=0) + raises(Exception, lambda x, y: x / y, 1, y=0) def context_manager_raises(self) -> None: with self.assertRaises(Exception) as ex: raise Exception("Context manager is good") with pytest.raises(Exception) as pyt_ex: raise Exception("Context manager is good") + with raises(Exception) as r_ex: + raise Exception("Context manager is good") self.assertEqual("Context manager is good", str(ex.exception)) self.assertEqual("Context manager is good", str(pyt_ex.value)) + self.assertEqual("Context manager is good", str(r_ex.value)) def regex_raises(self) -> None: with self.assertRaisesRegex(Exception, "Regex is good"): raise Exception("Regex is good") with pytest.raises(Exception, match="Regex is good"): raise Exception("Regex is good") + with raises(Exception, match="Regex is good"): + raise Exception("Regex is good") def non_context_manager_raises(self) -> None: self.assertRaises(ZeroDivisionError, lambda x, y: x / y, 1, y=0) pytest.raises(ZeroDivisionError, lambda x, y: x / y, 1, y=0) + raises(ZeroDivisionError, lambda x, y: x / y, 1, y=0) def raises_with_absolute_reference(self): with self.assertRaises(asyncio.CancelledError): Foo() with pytest.raises(asyncio.CancelledError): Foo() + with raises(asyncio.CancelledError): + Foo() diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 49377b4..22ccd6a 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -258,7 +258,7 @@ def test_b017(self): filename = Path(__file__).absolute().parent / "b017.py" bbc = BugBearChecker(filename=str(filename)) errors = list(bbc.run()) - expected = self.errors(B017(25, 8), B017(27, 8), B017(29, 8)) + expected = self.errors(B017(26, 8), B017(28, 8), B017(30, 8), B017(32, 8)) self.assertEqual(errors, expected) def test_b018_functions(self): @@ -530,9 +530,11 @@ def test_b908(self): B908(15, 8), B908(21, 8), B908(27, 8), + B017(37, 0), B908(37, 0), B908(41, 0), B908(45, 0), + B017(56, 0), ) self.assertEqual(errors, expected)