Skip to content

Commit

Permalink
Merge pull request #505 from python-rope/lieryan-411-parensless-gener…
Browse files Browse the repository at this point in the history
…ator-comprehension

Fix extracting generator without parens
  • Loading branch information
lieryan committed Oct 3, 2022
2 parents 5d0fcb7 + 381385a commit 2733daa
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# **Upcoming release**

## Bug fixes

- #411 Fix extracting generator without parens

# Release 1.3.0

## Bug fixes
Expand Down
20 changes: 19 additions & 1 deletion rope/refactor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,20 @@ def returning_named_expr(self):
)
return self._returning_named_expr

_returning_generator = None

@property
def returning_generator_exp(self):
"""Does the extracted piece contains a generator expression"""
if self._returning_generator is None:
self._returning_generator = (
isinstance(self._parsed_extracted, ast.Module)
and isinstance(self._parsed_extracted.body[0], ast.Expr)
and isinstance(self._parsed_extracted.body[0].value, ast.GeneratorExp)
)

return self._returning_generator


class _ExtractCollector:
"""Collects information needed for performing the extract"""
Expand Down Expand Up @@ -1103,6 +1117,10 @@ def _get_single_expression_body(extracted, info):
if not large_multiline:
extracted = _join_lines(extracted)
multiline_expression = "\n" in extracted
if info.returning_named_expr or (multiline_expression and not large_multiline):
if (
info.returning_named_expr
or info.returning_generator_exp
or (multiline_expression and not large_multiline)
):
extracted = "(" + extracted + ")"
return extracted
21 changes: 21 additions & 0 deletions ropetest/refactor/extracttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2708,6 +2708,27 @@ def _a(y):
""")
self.assertEqual(expected, refactored)

def test_extract_with_generator_2(self):
code = dedent("""\
def f():
y = [1,2,3,4]
a = sum(x for x in y)
""")
extract_target = "x for x in y"
start, end = code.index(extract_target), code.index(extract_target) + len(
extract_target
)
refactored = self.do_extract_method(code, start, end, "_a")
expected = dedent("""\
def f():
y = [1,2,3,4]
a = sum(_a(y))
def _a(y):
return (x for x in y)
""")
self.assertEqual(expected, refactored)

def test_extract_with_set_comprehension(self):
code = dedent("""\
def f():
Expand Down

0 comments on commit 2733daa

Please sign in to comment.