Skip to content

Commit

Permalink
Merge pull request #9163 from bluetech/rewrite-end-lineno
Browse files Browse the repository at this point in the history
rewrite: fixup end_lineno, end_col_offset of rewritten asserts
  • Loading branch information
bluetech committed Oct 5, 2021
2 parents c82bda2 + 6a5211f commit 54811b2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
1 change: 1 addition & 0 deletions changelog/9163.bugfix.rst
@@ -0,0 +1 @@
The end line number and end column offset are now properly set for rewritten assert statements.
24 changes: 9 additions & 15 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -19,6 +19,7 @@
from typing import Dict
from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
Expand Down Expand Up @@ -539,19 +540,11 @@ def _check_if_assertion_pass_impl() -> bool:
}


def set_location(node, lineno, col_offset):
"""Set node location information recursively."""

def _fix(node, lineno, col_offset):
if "lineno" in node._attributes:
node.lineno = lineno
if "col_offset" in node._attributes:
node.col_offset = col_offset
for child in ast.iter_child_nodes(node):
_fix(child, lineno, col_offset)

_fix(node, lineno, col_offset)
return node
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
"""Recursively yield node and all its children in depth-first order."""
yield node
for child in ast.iter_child_nodes(node):
yield from traverse_node(child)


@functools.lru_cache(maxsize=1)
Expand Down Expand Up @@ -954,9 +947,10 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, ast.NameConstant(None))
self.statements.append(clear)
# Fix line numbers.
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
return self.statements

def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
Expand Down
22 changes: 22 additions & 0 deletions testing/test_assertrewrite.py
Expand Up @@ -111,6 +111,28 @@ def test_place_initial_imports(self) -> None:
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)

def test_location_is_set(self) -> None:
s = textwrap.dedent(
"""
assert False, (
"Ouch"
)
"""
)
m = rewrite(s)
for node in m.body:
if isinstance(node, ast.Import):
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert n.lineno == 3
assert n.col_offset == 0
if sys.version_info >= (3, 8):
assert n.end_lineno == 6
assert n.end_col_offset == 3

def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s)
Expand Down

0 comments on commit 54811b2

Please sign in to comment.