Skip to content

Commit

Permalink
Merge pull request #487 from python-rope/lieryan-improved-value-infer…
Browse files Browse the repository at this point in the history
…ence

Improved value inference of __all__ declaration
  • Loading branch information
lieryan committed Jun 7, 2022
2 parents 2893941 + 0666808 commit 064b45a
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

- #479 Add ABC and type hints for TaskHandle and JobSet (@bageljrkhanofemus)
- #486 Drop Python 2 support (@bageljrkhanofemus, @lieryan)
- #487 Improved value inference of __all__ declaration (@lieryan)

# Release 1.1.1

Expand Down
71 changes: 53 additions & 18 deletions rope/refactor/importutils/module_imports.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from rope.base import ast
from rope.base import exceptions
from rope.base import pynames
from rope.base import utils
from rope.refactor.importutils import actions
from rope.refactor.importutils import importinfo
from typing import Union, List

from rope.base import ast, exceptions, pynames, pynamesdef, utils
from rope.refactor.importutils import actions, importinfo


class ModuleImports:
Expand Down Expand Up @@ -32,37 +30,74 @@ def _get_unbound_names(self, defined_pyobject):
return visitor.unbound

def _get_all_star_list(self, pymodule):
def _resolve_name(
name: Union[pynamesdef.AssignedName, pynames.ImportedName]
) -> List:
while isinstance(name, pynames.ImportedName):
try:
name = name.imported_module.get_object().get_attribute(
name.imported_name,
)
except exceptions.AttributeNotFoundError:
return []
assert isinstance(name, pynamesdef.AssignedName)
return name.assignments

result = set()
try:
all_star_list = pymodule.get_attribute("__all__")
except exceptions.AttributeNotFoundError:
return result

assignments = [
assignment.ast_node for assignment in _resolve_name(all_star_list)
]

# FIXME: Need a better way to recursively infer possible values.
# Currently pyobjects can recursively infer type, but not values.
# Do a very basic 1-level value inference
for assignment in all_star_list.assignments:
if isinstance(assignment.ast_node, ast.List):
stack = list(assignment.ast_node.elts)
while assignments:
assignment = assignments.pop()
if isinstance(assignment, ast.List):
stack = list(assignment.elts)
while stack:
el = stack.pop()
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
name = pymodule.get_attribute(el.id)
if isinstance(name, pynames.AssignedName):
for av in name.assignments:
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(el, ast.IfExp):
if isinstance(el, ast.IfExp):
stack.append(el.body)
stack.append(el.orelse)
elif isinstance(el, ast.Starred):
assignments.append(el.value)
else:
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
try:
name = pymodule.get_attribute(el.id)
except exceptions.AttributeNotFoundError:
continue
else:
for av in _resolve_name(name):
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(assignment, ast.Name):
try:
name = pymodule.get_attribute(assignment.id)
except exceptions.AttributeNotFoundError:
continue
else:
assignments.extend(
assignment.ast_node for assignment in _resolve_name(name)
)
elif isinstance(assignment, ast.BinOp):
assignments.append(assignment.left)
assignments.append(assignment.right)
return result

def remove_unused_imports(self):
can_select = _OneTimeSelector(
self._get_unbound_names(self.pymodule)
| self._get_all_star_list(self.pymodule)
| {"__all__"}
)
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select
Expand Down
111 changes: 111 additions & 0 deletions ropetest/refactor/importutilstest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,6 +1873,33 @@ def test_organizing_imports_all_star_tolerates_non_list_of_str_1(self):
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_assigned_name_alias(self):
code = expected = dedent("""\
from package import name_one, name_two
foo = ['name_one', 'name_two']
__all__ = foo
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_imported_name_alias(self):
self.mod1.write("foo = ['name_one', 'name_two']")
self.mod2.write("from pkg1.mod1 import foo")
code = expected = dedent("""\
from package import name_one, name_two
from pkg2.mod2 import foo
__all__ = foo
""")
self.mod3.write(code)
pymod = self.project.get_pymodule(self.mod3)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self):
code = expected = dedent("""\
from package import name_one, name_two
Expand All @@ -1886,6 +1913,30 @@ def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self):
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_plusjoin(self):
code = expected = dedent("""\
from package import name_one, name_two
foo = ['name_two']
__all__ = ['name_one'] + foo
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_starjoin(self):
code = expected = dedent("""\
from package import name_one, name_two
foo = ['name_two']
__all__ = ['name_one', *foo]
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

@testutils.time_limit(60)
def test_organizing_imports_all_star_no_infinite_loop(self):
code = expected = dedent("""\
Expand All @@ -1900,6 +1951,66 @@ def test_organizing_imports_all_star_no_infinite_loop(self):
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_resolve_imported_name(self):
self.mod1.write("foo = 'name_one'")

code = expected = dedent("""\
from package import name_one, name_two
from pkg1.mod1 import foo
__all__ = [foo, 'name_two']
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_undefined_variable(self):
code = expected = dedent("""\
from foo import some_name
__all__ = ['some_name', undefined_variable]
""")
self.mod1.write(code)

pymod = self.project.get_pymodule(self.mod1)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_undefined_variable_with_imported_name(self):
self.mod1.write("")
self.mod2.write("from pkg1.mod1 import undefined_variable")

code = expected = dedent("""\
from pkg2.mod2 import undefined_variable
__all__ = undefined_variable
""")
self.mod3.write(code)

pymod = self.project.get_pymodule(self.mod3)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_indirect_all_star_import(self):
self.mod1.write("some_name = 1")
self.mod2.write(
dedent("""\
__all__ = ['some_name', *imported_all]
""")
)

code = expected = dedent("""\
from mod1 import some_name
from mod2 import __all__
""")
self.mod3.write(code)

pymod = self.project.get_pymodule(self.mod3)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_customized_import_organization(self):
self.mod.write(
dedent("""\
Expand Down

0 comments on commit 064b45a

Please sign in to comment.