Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e7dfab6

Browse files
committedApr 14, 2025·
refine
1 parent 21b0be2 commit e7dfab6

File tree

2 files changed

+38
-24
lines changed

2 files changed

+38
-24
lines changed
 

‎ci_scripts/check_api_parameters.py

+32-20
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import re
2121
import sys
2222

23+
import paddle # noqa: F401
24+
2325

2426
def add_path(path):
2527
if path not in sys.path:
@@ -70,6 +72,7 @@ def _check_params_in_description(rstfilename, paramstr):
7072
params_in_title.remove("/")
7173
if "*" in params_in_title:
7274
params_in_title.remove("*")
75+
params_in_title = ", ".join(params_in_title)
7376

7477
funcdescnode = extract_params_desc_from_rst_file(rstfilename)
7578
if funcdescnode:
@@ -122,34 +125,20 @@ def _check_params_in_description_with_fullargspec(rstfilename, funcname):
122125
info = ""
123126
try:
124127
func = eval(funcname)
125-
except NameError:
126-
import paddle # noqa: F401
127-
128-
func = eval(funcname)
128+
except AttributeError:
129+
flag = False
130+
info = f"function {funcname} in rst file {rstfilename} not found in paddle module, please check it."
131+
return flag, info
129132
source = inspect.getsource(func)
130133

131-
class FunctionDefExtractor(ast.NodeTransformer):
132-
target_name = func.__name__
133-
134-
def visit_FunctionDef(self, node):
135-
if node.name == self.target_name:
136-
node.decorator_list = []
137-
node.body = [ast.Pass()]
138-
return node
139-
return None
140-
141134
tree = ast.parse(source)
142-
modified_tree = FunctionDefExtractor().visit(tree)
143-
modified_tree.body = [
144-
node for node in modified_tree.body if node is not None
145-
]
146-
147-
func_node = modified_tree.body[0]
135+
func_node = tree.body[0]
148136
params_inspec = gen_functions_args_str(func_node).split(", ")
149137
if "/" in params_inspec:
150138
params_inspec.remove("/")
151139
if "*" in params_inspec:
152140
params_inspec.remove("*")
141+
params_inspec = ", ".join(params_inspec)
153142
funcdescnode = extract_params_desc_from_rst_file(rstfilename)
154143
if funcdescnode:
155144
items = funcdescnode.children[1].children[0].children
@@ -207,16 +196,39 @@ def check_api_parameters(rstfiles, apiinfo):
207196
print(f"checking : {rstfile}")
208197
with open(rstfilename, "r") as rst_fobj:
209198
func_found = False
199+
is_first_line = True
200+
api_label = None
210201
for line in rst_fobj:
202+
if is_first_line:
203+
api_label = (
204+
line.strip()
205+
.removeprefix(".. _cn_api_")
206+
.replace("_", ".")
207+
.removesuffix("__upper")
208+
)
209+
is_first_line = False
211210
mo = pat.match(line)
212211
if mo:
213212
func_found = True
214213
functype = mo.group(1)
215214
if functype not in ("function", "method"):
215+
# TODO: check class method
216216
check_passed.append(rstfile)
217217
continue
218218
funcname = mo.group(2)
219219
paramstr = mo.group(3)
220+
221+
# check same as the api_label
222+
if funcname != api_label:
223+
# if funcname is a function, try to back to class
224+
obj = eval(funcname)
225+
if inspect.isfunction(obj):
226+
class_name = ".".join(funcname.split(".")[:-1])
227+
if class_name != api_label:
228+
flag = False
229+
info = f"funcname in title is not same as the label name: {funcname} != {api_label}."
230+
return flag, info
231+
220232
flag = False
221233
func_found_in_json = False
222234
for apiobj in apiinfo.values():

‎docs/api/gen_doc.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,9 @@ def parse_module_file(mod):
328328
and n.name == "__init__"
329329
):
330330
api_info_dict[obj_id]["args"] = (
331-
gen_functions_args_str(n)
331+
gen_functions_args_str(
332+
n, skip_self=True
333+
)
332334
)
333335
break
334336
else:
@@ -361,7 +363,7 @@ def parse_module_file(mod):
361363
logger.debug("%s omitted", obj_full_name)
362364

363365

364-
def gen_functions_args_str(node):
366+
def gen_functions_args_str(node, skip_self=False):
365367
def _process_positional_args(args, params):
366368
positional_args = args.posonlyargs + args.args
367369
num_defaults = len(args.defaults)
@@ -370,7 +372,7 @@ def _process_positional_args(args, params):
370372
first_default_pos = total_positional - num_defaults
371373
if args.posonlyargs:
372374
for idx, arg in enumerate(args.posonlyargs):
373-
if arg.arg == "self":
375+
if skip_self and arg.arg == "self":
374376
continue
375377
param = _format_arg_with_default(
376378
arg, idx, first_default_pos, args.defaults
@@ -379,7 +381,7 @@ def _process_positional_args(args, params):
379381
params.append("/")
380382

381383
for idx, arg in enumerate(args.args):
382-
if arg.arg == "self":
384+
if skip_self and arg.arg == "self":
383385
continue
384386
global_idx = idx + len(args.posonlyargs)
385387
param = _format_arg_with_default(

0 commit comments

Comments
 (0)