Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved value inference of __all__ declaration #487

Merged
merged 14 commits into from
Jun 7, 2022
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

- #479 Add ABC and type hints for TaskHandle and JobSet (@bageljrkhanofemus)
- #486 Drop Python 2 support (@bageljrkhanofemus, @lieryan)
- #487 Improved value inference of __all__ declaration (@lieryan)

# Release 1.1.1

Expand Down
71 changes: 53 additions & 18 deletions rope/refactor/importutils/module_imports.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from rope.base import ast
from rope.base import exceptions
from rope.base import pynames
from rope.base import utils
from rope.refactor.importutils import actions
from rope.refactor.importutils import importinfo
from typing import Union, List

from rope.base import ast, exceptions, pynames, pynamesdef, utils
from rope.refactor.importutils import actions, importinfo


class ModuleImports:
Expand Down Expand Up @@ -32,37 +30,74 @@ def _get_unbound_names(self, defined_pyobject):
return visitor.unbound

def _get_all_star_list(self, pymodule):
def _resolve_name(
name: Union[pynamesdef.AssignedName, pynames.ImportedName]
) -> List:
while isinstance(name, pynames.ImportedName):
try:
name = name.imported_module.get_object().get_attribute(
name.imported_name,
)
except exceptions.AttributeNotFoundError:
return []
assert isinstance(name, pynamesdef.AssignedName)
return name.assignments

result = set()
try:
all_star_list = pymodule.get_attribute("__all__")
except exceptions.AttributeNotFoundError:
return result

assignments = [
assignment.ast_node for assignment in _resolve_name(all_star_list)
]

# FIXME: Need a better way to recursively infer possible values.
# Currently pyobjects can recursively infer type, but not values.
# Do a very basic 1-level value inference
for assignment in all_star_list.assignments:
if isinstance(assignment.ast_node, ast.List):
stack = list(assignment.ast_node.elts)
while assignments:
assignment = assignments.pop()
if isinstance(assignment, ast.List):
stack = list(assignment.elts)
while stack:
el = stack.pop()
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
name = pymodule.get_attribute(el.id)
if isinstance(name, pynames.AssignedName):
for av in name.assignments:
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(el, ast.IfExp):
if isinstance(el, ast.IfExp):
stack.append(el.body)
stack.append(el.orelse)
elif isinstance(el, ast.Starred):
assignments.append(el.value)
else:
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
try:
name = pymodule.get_attribute(el.id)
except exceptions.AttributeNotFoundError:
continue
else:
for av in _resolve_name(name):
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(assignment, ast.Name):
try:
name = pymodule.get_attribute(assignment.id)
except exceptions.AttributeNotFoundError:
continue
else:
assignments.extend(
assignment.ast_node for assignment in _resolve_name(name)
)
elif isinstance(assignment, ast.BinOp):
assignments.append(assignment.left)
assignments.append(assignment.right)
return result

def remove_unused_imports(self):
can_select = _OneTimeSelector(
self._get_unbound_names(self.pymodule)
| self._get_all_star_list(self.pymodule)
| {"__all__"}
)
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select
Expand Down
111 changes: 111 additions & 0 deletions ropetest/refactor/importutilstest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,6 +1873,33 @@ def test_organizing_imports_all_star_tolerates_non_list_of_str_1(self):
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_assigned_name_alias(self):
code = expected = dedent("""\
from package import name_one, name_two


foo = ['name_one', 'name_two']
__all__ = foo
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_imported_name_alias(self):
self.mod1.write("foo = ['name_one', 'name_two']")
self.mod2.write("from pkg1.mod1 import foo")
code = expected = dedent("""\
from package import name_one, name_two

from pkg2.mod2 import foo


__all__ = foo
""")
self.mod3.write(code)
pymod = self.project.get_pymodule(self.mod3)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self):
code = expected = dedent("""\
from package import name_one, name_two
Expand All @@ -1886,6 +1913,30 @@ def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self):
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_plusjoin(self):
code = expected = dedent("""\
from package import name_one, name_two


foo = ['name_two']
__all__ = ['name_one'] + foo
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_starjoin(self):
code = expected = dedent("""\
from package import name_one, name_two


foo = ['name_two']
__all__ = ['name_one', *foo]
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

@testutils.time_limit(60)
def test_organizing_imports_all_star_no_infinite_loop(self):
code = expected = dedent("""\
Expand All @@ -1900,6 +1951,66 @@ def test_organizing_imports_all_star_no_infinite_loop(self):
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_resolve_imported_name(self):
self.mod1.write("foo = 'name_one'")

code = expected = dedent("""\
from package import name_one, name_two

from pkg1.mod1 import foo


__all__ = [foo, 'name_two']
""")
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_undefined_variable(self):
code = expected = dedent("""\
from foo import some_name


__all__ = ['some_name', undefined_variable]
""")
self.mod1.write(code)

pymod = self.project.get_pymodule(self.mod1)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_imports_undefined_variable_with_imported_name(self):
self.mod1.write("")
self.mod2.write("from pkg1.mod1 import undefined_variable")

code = expected = dedent("""\
from pkg2.mod2 import undefined_variable


__all__ = undefined_variable
""")
self.mod3.write(code)

pymod = self.project.get_pymodule(self.mod3)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_organizing_indirect_all_star_import(self):
self.mod1.write("some_name = 1")
self.mod2.write(
dedent("""\
__all__ = ['some_name', *imported_all]
""")
)

code = expected = dedent("""\
from mod1 import some_name

from mod2 import __all__
""")
self.mod3.write(code)

pymod = self.project.get_pymodule(self.mod3)
self.assertEqual(expected, self.import_tools.organize_imports(pymod))

def test_customized_import_organization(self):
self.mod.write(
dedent("""\
Expand Down