Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a function to identify ast Constant nodes more granularly #638

Merged
merged 2 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions rope/base/ast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
import ast
import sys
from ast import * # noqa: F401,F403

from rope.base import fscommands

try:
from ast import _const_node_type_names
except ImportError:
# backported from stdlib `ast`
_const_node_type_names = {
bool: "NameConstant", # should be before int
type(None): "NameConstant",
int: "Num",
float: "Num",
complex: "Num",
str: "Str",
bytes: "Bytes",
type(...): "Ellipsis",
}


def parse(source, filename="<string>", *args, **kwargs): # type: ignore
if isinstance(source, str):
Expand Down Expand Up @@ -35,3 +51,20 @@ def visit(self, node):
method = "_" + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)


def get_const_subtype_name(node):
"""Get pre-3.8 ast node name"""
# fmt: off
assert sys.version_info >= (3, 8), "This should only be called in Python 3.8 and above"
# fmt: on
assert isinstance(node, ast.Constant)
return _const_node_type_names[type(node.value)]


def get_node_type_name(node):
return (
get_const_subtype_name(node)
if isinstance(node, ast.Constant)
else node.__class__.__name__
)
52 changes: 8 additions & 44 deletions ropetest/refactor/patchedasttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ropetest import testutils

NameConstant = "Name" if sys.version_info <= (3, 8) else "NameConstant"
Bytes = "Bytes" if (3, 0) <= sys.version_info <= (3, 8) else "Str"


class PatchedASTTest(unittest.TestCase):
Expand Down Expand Up @@ -57,8 +56,8 @@ def test_bytes_string(self):
checker = _ResultChecker(self, ast_frag)
str_fragment = 'b"("'
start = source.index(str_fragment)
checker.check_region(Bytes, start, start + len(str_fragment))
checker.check_children(Bytes, [str_fragment])
checker.check_region("Bytes", start, start + len(str_fragment))
checker.check_children("Bytes", [str_fragment])

def test_integer_literals_and_region(self):
source = "a = 10\n"
Expand Down Expand Up @@ -839,28 +838,14 @@ def test_list_comp_node_with_multiple_comprehensions(self):
)

def test_set_node(self):
# make sure we are in a python version with set literals
source = "{1, 2}\n"

try:
eval(source)
except SyntaxError:
return

ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_region("Set", 0, len(source) - 1)
checker.check_children("Set", ["{", "", "Num", "", ",", " ", "Num", "", "}"])

def test_set_comp_node(self):
# make sure we are in a python version with set comprehensions
source = "{i for i in range(1) if True}\n"

try:
eval(source)
except SyntaxError:
return

ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_region("SetComp", 0, len(source) - 1)
Expand All @@ -873,14 +858,7 @@ def test_set_comp_node(self):
)

def test_dict_comp_node(self):
# make sure we are in a python version with dict comprehensions
source = "{i:i for i in range(1) if True}\n"

try:
eval(source)
except SyntaxError:
return

ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_region("DictComp", 0, len(source) - 1)
Expand Down Expand Up @@ -1285,7 +1263,7 @@ def test_match_node_with_constant_match_value(self):
checker = _ResultChecker(self, ast_frag)
self.assert_single_case_match_block(checker, "MatchValue")
checker.check_children("MatchValue", [
"Constant"
"Num"
])

@testutils.only_for_versions_higher("3.10")
Expand Down Expand Up @@ -1340,7 +1318,7 @@ def test_match_node_with_match_class(self):
")",
])
checker.check_children("MatchValue", [
"Constant"
"Num"
])

@testutils.only_for_versions_higher("3.10")
Expand Down Expand Up @@ -1488,7 +1466,7 @@ def test_match_node_with_match_mapping_match_as(self):
checker.check_children("MatchMapping", [
"{",
"",
"Constant",
"Str",
"",
":",
" ",
Expand Down Expand Up @@ -1519,17 +1497,10 @@ class Search:

def __call__(self, node):
for text in goal:
if sys.version_info >= (3, 8) and text in [
"Num",
"Str",
"NameConstant",
"Ellipsis",
]:
text = "Constant"
if str(node).startswith(text):
self.result = node
break
if node.__class__.__name__.startswith(text):
if ast.get_node_type_name(node).startswith(text):
self.result = node
break
return self.result is not None
Expand All @@ -1554,15 +1525,8 @@ def check_children(self, text, children):
break
else:
self.test_case.assertNotEqual("", text, "probably ignoring some node")
if sys.version_info >= (3, 8) and expected in [
"Num",
"Str",
"NameConstant",
"Ellipsis",
]:
expected = "Constant"
self.test_case.assertTrue(
child.__class__.__name__.startswith(expected),
ast.get_node_type_name(child).startswith(expected),
msg="Expected <%s> but was <%s>"
% (expected, child.__class__.__name__),
% (expected, ast.get_node_type_name(child)),
)