Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite: fixup end_lineno, end_col_offset of rewritten asserts #9163

Merged
merged 1 commit into from Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/9163.bugfix.rst
@@ -0,0 +1 @@
The end line number and end column offset are now properly set for rewritten assert statements.
24 changes: 9 additions & 15 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -19,6 +19,7 @@
from typing import Dict
from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
Expand Down Expand Up @@ -539,19 +540,11 @@ def _check_if_assertion_pass_impl() -> bool:
}


def set_location(node, lineno, col_offset):
"""Set node location information recursively."""

def _fix(node, lineno, col_offset):
if "lineno" in node._attributes:
node.lineno = lineno
if "col_offset" in node._attributes:
node.col_offset = col_offset
for child in ast.iter_child_nodes(node):
_fix(child, lineno, col_offset)

_fix(node, lineno, col_offset)
return node
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
"""Recursively yield node and all its children in depth-first order."""
yield node
for child in ast.iter_child_nodes(node):
yield from traverse_node(child)


@functools.lru_cache(maxsize=1)
Expand Down Expand Up @@ -954,9 +947,10 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, ast.NameConstant(None))
self.statements.append(clear)
# Fix line numbers.
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
return self.statements

def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
Expand Down
22 changes: 22 additions & 0 deletions testing/test_assertrewrite.py
Expand Up @@ -111,6 +111,28 @@ def test_place_initial_imports(self) -> None:
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)

def test_location_is_set(self) -> None:
s = textwrap.dedent(
"""

assert False, (

"Ouch"
)

"""
)
m = rewrite(s)
for node in m.body:
if isinstance(node, ast.Import):
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert n.lineno == 3
assert n.col_offset == 0
if sys.version_info >= (3, 8):
assert n.end_lineno == 6
assert n.end_col_offset == 3

def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s)
Expand Down