diff --git a/changelog/9163.bugfix.rst b/changelog/9163.bugfix.rst new file mode 100644 index 00000000000..fb559d10fe8 --- /dev/null +++ b/changelog/9163.bugfix.rst @@ -0,0 +1 @@ +The end line number and end column offset are now properly set for rewritten assert statements. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 99416aed20b..1f7de90d5a9 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -19,6 +19,7 @@ from typing import Dict from typing import IO from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Sequence @@ -539,19 +540,11 @@ def _check_if_assertion_pass_impl() -> bool: } -def set_location(node, lineno, col_offset): - """Set node location information recursively.""" - - def _fix(node, lineno, col_offset): - if "lineno" in node._attributes: - node.lineno = lineno - if "col_offset" in node._attributes: - node.col_offset = col_offset - for child in ast.iter_child_nodes(node): - _fix(child, lineno, col_offset) - - _fix(node, lineno, col_offset) - return node +def traverse_node(node: ast.AST) -> Iterator[ast.AST]: + """Recursively yield node and all its children in depth-first order.""" + yield node + for child in ast.iter_child_nodes(node): + yield from traverse_node(child) @functools.lru_cache(maxsize=1) @@ -954,9 +947,10 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: variables = [ast.Name(name, ast.Store()) for name in self.variables] clear = ast.Assign(variables, ast.NameConstant(None)) self.statements.append(clear) - # Fix line numbers. + # Fix locations (line numbers/column offsets). for stmt in self.statements: - set_location(stmt, assert_.lineno, assert_.col_offset) + for node in traverse_node(stmt): + ast.copy_location(node, assert_) return self.statements def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index d732cbdd709..5e63d61fa30 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -111,6 +111,28 @@ def test_place_initial_imports(self) -> None: assert imp.col_offset == 0 assert isinstance(m.body[3], ast.Expr) + def test_location_is_set(self) -> None: + s = textwrap.dedent( + """ + + assert False, ( + + "Ouch" + ) + + """ + ) + m = rewrite(s) + for node in m.body: + if isinstance(node, ast.Import): + continue + for n in [node, *ast.iter_child_nodes(node)]: + assert n.lineno == 3 + assert n.col_offset == 0 + if sys.version_info >= (3, 8): + assert n.end_lineno == 6 + assert n.end_col_offset == 3 + def test_dont_rewrite(self) -> None: s = """'PYTEST_DONT_REWRITE'\nassert 14""" m = rewrite(s)