Skip to content

Commit

Permalink
Fix #411. Fix extracting generator without parens
Browse files Browse the repository at this point in the history
Extracting generator comprehension without selecting the surrounding
parentheses may produce invalid code. This commit adds surrounding
parentheses when it detects that the selected code is a bare generator
comprehension.
  • Loading branch information
lieryan committed Sep 27, 2022
1 parent 5d0fcb7 commit 64d39ae
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
16 changes: 15 additions & 1 deletion rope/refactor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,20 @@ def returning_named_expr(self):
)
return self._returning_named_expr

_returning_generator = None

@property
def returning_generator_exp(self):
"""Does the extracted piece contains a generator expression"""
if self._returning_generator is None:
self._returning_generator = (
isinstance(self._parsed_extracted, ast.Module)
and isinstance(self._parsed_extracted.body[0], ast.Expr)
and isinstance(self._parsed_extracted.body[0].value, ast.GeneratorExp)
)

return self._returning_generator


class _ExtractCollector:
"""Collects information needed for performing the extract"""
Expand Down Expand Up @@ -1103,6 +1117,6 @@ def _get_single_expression_body(extracted, info):
if not large_multiline:
extracted = _join_lines(extracted)
multiline_expression = "\n" in extracted
if info.returning_named_expr or (multiline_expression and not large_multiline):
if info.returning_named_expr or info.returning_generator_exp or (multiline_expression and not large_multiline):
extracted = "(" + extracted + ")"
return extracted
21 changes: 21 additions & 0 deletions ropetest/refactor/extracttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2708,6 +2708,27 @@ def _a(y):
""")
self.assertEqual(expected, refactored)

def test_extract_with_generator_2(self):
code = dedent("""\
def f():
y = [1,2,3,4]
a = sum(x for x in y)
""")
extract_target = "x for x in y"
start, end = code.index(extract_target), code.index(extract_target) + len(
extract_target
)
refactored = self.do_extract_method(code, start, end, "_a")
expected = dedent("""\
def f():
y = [1,2,3,4]
a = sum(_a(y))
def _a(y):
return (x for x in y)
""")
self.assertEqual(expected, refactored)

def test_extract_with_set_comprehension(self):
code = dedent("""\
def f():
Expand Down

0 comments on commit 64d39ae

Please sign in to comment.