Skip to content

Commit

Permalink
Merge pull request #638 from python-rope/lieryan-fine-grained-ast-nod…
Browse files Browse the repository at this point in the history
…e-identification

Add a function to identify ast Constant nodes more granularly
  • Loading branch information
lieryan committed Jan 3, 2023
2 parents fe7c0fe + ba269c8 commit 163aef3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 44 deletions.
33 changes: 33 additions & 0 deletions rope/base/ast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
import ast
import sys
from ast import * # noqa: F401,F403

from rope.base import fscommands

try:
from ast import _const_node_type_names
except ImportError:
# backported from stdlib `ast`
_const_node_type_names = {
bool: "NameConstant", # should be before int
type(None): "NameConstant",
int: "Num",
float: "Num",
complex: "Num",
str: "Str",
bytes: "Bytes",
type(...): "Ellipsis",
}


def parse(source, filename="<string>", *args, **kwargs): # type: ignore
if isinstance(source, str):
Expand Down Expand Up @@ -35,3 +51,20 @@ def visit(self, node):
method = "_" + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)


def get_const_subtype_name(node):
"""Get pre-3.8 ast node name"""
# fmt: off
assert sys.version_info >= (3, 8), "This should only be called in Python 3.8 and above"
# fmt: on
assert isinstance(node, ast.Constant)
return _const_node_type_names[type(node.value)]


def get_node_type_name(node):
return (
get_const_subtype_name(node)
if isinstance(node, ast.Constant)
else node.__class__.__name__
)
52 changes: 8 additions & 44 deletions ropetest/refactor/patchedasttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ropetest import testutils

NameConstant = "Name" if sys.version_info <= (3, 8) else "NameConstant"
Bytes = "Bytes" if (3, 0) <= sys.version_info <= (3, 8) else "Str"


class PatchedASTTest(unittest.TestCase):
Expand Down Expand Up @@ -57,8 +56,8 @@ def test_bytes_string(self):
checker = _ResultChecker(self, ast_frag)
str_fragment = 'b"("'
start = source.index(str_fragment)
checker.check_region(Bytes, start, start + len(str_fragment))
checker.check_children(Bytes, [str_fragment])
checker.check_region("Bytes", start, start + len(str_fragment))
checker.check_children("Bytes", [str_fragment])

def test_integer_literals_and_region(self):
source = "a = 10\n"
Expand Down Expand Up @@ -839,28 +838,14 @@ def test_list_comp_node_with_multiple_comprehensions(self):
)

def test_set_node(self):
# make sure we are in a python version with set literals
source = "{1, 2}\n"

try:
eval(source)
except SyntaxError:
return

ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_region("Set", 0, len(source) - 1)
checker.check_children("Set", ["{", "", "Num", "", ",", " ", "Num", "", "}"])

def test_set_comp_node(self):
# make sure we are in a python version with set comprehensions
source = "{i for i in range(1) if True}\n"

try:
eval(source)
except SyntaxError:
return

ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_region("SetComp", 0, len(source) - 1)
Expand All @@ -873,14 +858,7 @@ def test_set_comp_node(self):
)

def test_dict_comp_node(self):
# make sure we are in a python version with dict comprehensions
source = "{i:i for i in range(1) if True}\n"

try:
eval(source)
except SyntaxError:
return

ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_region("DictComp", 0, len(source) - 1)
Expand Down Expand Up @@ -1285,7 +1263,7 @@ def test_match_node_with_constant_match_value(self):
checker = _ResultChecker(self, ast_frag)
self.assert_single_case_match_block(checker, "MatchValue")
checker.check_children("MatchValue", [
"Constant"
"Num"
])

@testutils.only_for_versions_higher("3.10")
Expand Down Expand Up @@ -1340,7 +1318,7 @@ def test_match_node_with_match_class(self):
")",
])
checker.check_children("MatchValue", [
"Constant"
"Num"
])

@testutils.only_for_versions_higher("3.10")
Expand Down Expand Up @@ -1488,7 +1466,7 @@ def test_match_node_with_match_mapping_match_as(self):
checker.check_children("MatchMapping", [
"{",
"",
"Constant",
"Str",
"",
":",
" ",
Expand Down Expand Up @@ -1519,17 +1497,10 @@ class Search:

def __call__(self, node):
for text in goal:
if sys.version_info >= (3, 8) and text in [
"Num",
"Str",
"NameConstant",
"Ellipsis",
]:
text = "Constant"
if str(node).startswith(text):
self.result = node
break
if node.__class__.__name__.startswith(text):
if ast.get_node_type_name(node).startswith(text):
self.result = node
break
return self.result is not None
Expand All @@ -1554,15 +1525,8 @@ def check_children(self, text, children):
break
else:
self.test_case.assertNotEqual("", text, "probably ignoring some node")
if sys.version_info >= (3, 8) and expected in [
"Num",
"Str",
"NameConstant",
"Ellipsis",
]:
expected = "Constant"
self.test_case.assertTrue(
child.__class__.__name__.startswith(expected),
ast.get_node_type_name(child).startswith(expected),
msg="Expected <%s> but was <%s>"
% (expected, child.__class__.__name__),
% (expected, ast.get_node_type_name(child)),
)

0 comments on commit 163aef3

Please sign in to comment.