Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extracting generator without parens #505

Merged
merged 3 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# **Upcoming release**

## Bug fixes

- #411 Fix extracting generator without parens

# Release 1.3.0

## Bug fixes
Expand Down
20 changes: 19 additions & 1 deletion rope/refactor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,20 @@ def returning_named_expr(self):
)
return self._returning_named_expr

_returning_generator = None

@property
def returning_generator_exp(self):
"""Does the extracted piece contains a generator expression"""
if self._returning_generator is None:
self._returning_generator = (
isinstance(self._parsed_extracted, ast.Module)
and isinstance(self._parsed_extracted.body[0], ast.Expr)
and isinstance(self._parsed_extracted.body[0].value, ast.GeneratorExp)
)

return self._returning_generator


class _ExtractCollector:
"""Collects information needed for performing the extract"""
Expand Down Expand Up @@ -1103,6 +1117,10 @@ def _get_single_expression_body(extracted, info):
if not large_multiline:
extracted = _join_lines(extracted)
multiline_expression = "\n" in extracted
if info.returning_named_expr or (multiline_expression and not large_multiline):
if (
info.returning_named_expr
or info.returning_generator_exp
or (multiline_expression and not large_multiline)
):
extracted = "(" + extracted + ")"
return extracted
21 changes: 21 additions & 0 deletions ropetest/refactor/extracttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2708,6 +2708,27 @@ def _a(y):
""")
self.assertEqual(expected, refactored)

def test_extract_with_generator_2(self):
code = dedent("""\
def f():
y = [1,2,3,4]
a = sum(x for x in y)
""")
extract_target = "x for x in y"
start, end = code.index(extract_target), code.index(extract_target) + len(
extract_target
)
refactored = self.do_extract_method(code, start, end, "_a")
expected = dedent("""\
def f():
y = [1,2,3,4]
a = sum(_a(y))

def _a(y):
return (x for x in y)
""")
self.assertEqual(expected, refactored)

def test_extract_with_set_comprehension(self):
code = dedent("""\
def f():
Expand Down