Files
AI-Check-Test/.gitea/checker/controller_ast_parser.py
2026-06-05 15:42:29 +08:00

498 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
纯 Python Controller AST 解析器(基于 javalang无需 Java/Maven
仅解析 Spring @RestController / @Controller 的接口映射与参数。
"""
from __future__ import annotations
import re
from pathlib import Path
from typing import Dict, List, Optional
import javalang
from javalang.tree import (
Annotation,
ClassDeclaration,
ElementValuePair,
FieldDeclaration,
FormalParameter,
Literal,
MemberReference,
MethodDeclaration,
)
from models import ApiEndpoint, ApiParameter
MAPPING_ANNS = {"GetMapping", "PostMapping", "PutMapping", "DeleteMapping", "PatchMapping", "RequestMapping"}
CONTROLLER_ANNS = {"RestController", "Controller"}
# Parameters Spring MVC injects automatically. They are not supplied by the API caller,
# so the parser ignores them.
FRAMEWORK_PARAM_TYPES = {
    "HttpServletRequest",
    "HttpServletResponse",
    "HttpSession",
    "ServletRequest",
    "ServletResponse",
    "WebRequest",
    "NativeWebRequest",
    "Model",
    "ModelMap",
    "RedirectAttributes",
    "BindingResult",
    "Errors",
    "Authentication",
    "Principal",
    "Locale",
    "TimeZone",
    "InputStream",
    "OutputStream",
    "Reader",
    "Writer",
    "HttpHeaders",
    "UriComponentsBuilder",
}
# Type-name suffixes of servlet/reactive request and response objects that are not
# listed above, e.g. MultipartHttpServletRequest or ServerHttpResponse. This is kept
# deliberately narrower than a bare "Request"/"Response" suffix, so caller-facing DTOs
# such as CreateUserRequest are not mistaken for framework objects.
_FRAMEWORK_TYPE_SUFFIXES = (
    "ServletRequest",
    "ServletResponse",
    "HttpRequest",
    "HttpResponse",
    "WebRequest",
)


def _is_framework_param(type_name: str, param_name: str) -> bool:
    """Return True if the parameter is injected by the framework rather than sent by the caller.

    :param type_name: rendered Java type; may be qualified or generic
    :param param_name: resolved parameter name
    """
    simple = type_name.split(".")[-1].replace(">", "").replace("<", "").strip()
    if simple in FRAMEWORK_PARAM_TYPES:
        return True
    # Heuristic for request/response wrappers missing from the set above. It only
    # applies to the conventional parameter names and to framework-style suffixes.
    return param_name in ("request", "response") and simple.endswith(_FRAMEWORK_TYPE_SUFFIXES)
def _ann_simple_name(ann: Annotation) -> str:
    """Return the unqualified annotation name (``a.b.GetMapping`` -> ``GetMapping``)."""
    _, _, simple_name = ann.name.rpartition(".")
    return simple_name
def _literal_str(node) -> str:
    """Render an annotation element value as a plain string.

    Supported forms:

    - String and boolean literals, with surrounding quotes removed.
    - Member references such as ``RequestMethod.POST``, which yield the member name.
    - Array initialisers such as ``{"/a", "/b"}``, which yield their first element,
      because an endpoint carries a single path or verb.
    - ``+`` concatenations of the above.

    Anything else falls back to ``str(node)`` with quotes stripped.
    """
    if node is None:
        return ""
    # javalang returns a plain list for an empty initializer ``{}``.
    if isinstance(node, list):
        return _literal_str(node[0]) if node else ""
    # Looked up defensively: isinstance(x, ()) is simply False if the class is absent.
    array_type = getattr(javalang.tree, "ElementArrayValue", ())
    binary_type = getattr(javalang.tree, "BinaryOperation", ())
    if isinstance(node, array_type):
        values = node.values or []
        return _literal_str(values[0]) if values else ""
    if isinstance(node, binary_type) and node.operator == "+":
        return _literal_str(node.operandl) + _literal_str(node.operandr)
    if isinstance(node, Literal):
        v = node.value
        if isinstance(v, bool):
            return str(v).lower()
        return str(v or "").strip('"').strip("'")
    if isinstance(node, MemberReference):
        return node.member
    if isinstance(node, bool):
        return str(node).lower()
    return str(node).strip('"').strip("'")
def _collect_ann_members(ann: Annotation) -> Dict[str, str]:
    """
    Flatten an annotation's arguments into a ``{member: string value}`` dict.

    A single unnamed argument (``@GetMapping("/path")``) is stored under ``"value"``.
    Named arguments (``@RequestParam(value="x", required=false)``) keep their names.
    An annotation without arguments yields an empty dict.
    """
    element = ann.element
    if element is None:
        return {}
    if isinstance(element, ElementValuePair):
        pairs = [element]
    elif isinstance(element, list):
        pairs = [item for item in element if isinstance(item, ElementValuePair)]
    else:
        return {"value": _literal_str(element)}
    return {pair.name: _literal_str(pair.value) for pair in pairs}
def _ann_string(ann: Annotation, *keys: str) -> str:
    """Return the first non-empty attribute among ``keys`` of an annotation.

    If none of ``keys`` yields a non-empty value, the ``value`` attribute is returned
    when present, even if it was not requested. Otherwise "" is returned.
    """
    members = _collect_ann_members(ann)
    for candidate in (members.get(key) for key in keys):
        if candidate:
            return candidate
    return members.get("value", "")
def _type_to_str(type_node) -> str:
    """Render a javalang type node as Java-like source text.

    Rendering rules:

    - Generic arguments are kept (``List<String>``).
    - Qualified names collapse to their last segment (``java.util.List`` -> ``List``).
    - Array dimensions are appended (``String[]``, ``int[][]``).
    - ``None`` renders as ``Object``.
    """
    if type_node is None:
        return "Object"
    # javalang records array dimensions on the outermost node of a (qualified) type.
    dims = "[]" * len(getattr(type_node, "dimensions", None) or ())
    if isinstance(type_node, javalang.tree.BasicType):
        return type_node.name + dims
    if isinstance(type_node, javalang.tree.ReferenceType):
        name = type_node.name or "Object"
        if type_node.arguments:
            args = ",".join(_type_to_str(a.type) for a in type_node.arguments)
            return f"{name}<{args}>{dims}"
        if type_node.sub_type:
            return _type_to_str(type_node.sub_type) + dims
        return name + dims
    return str(type_node)
def _normalize_path(path: str) -> str:
    """Normalize a URI path.

    The result has a leading ``/`` and no repeated slashes. Empty or blank input
    (including None) gives "".
    """
    cleaned = (path or "").strip()
    if not cleaned:
        return ""
    if not cleaned.startswith("/"):
        cleaned = "/" + cleaned
    return re.sub(r"/+", "/", cleaned)


def _join_paths(base: str, sub: str) -> str:
    """Join a class-level and a method-level mapping path.

    Both parts are normalized first. An empty pair maps to ``/``, and exactly one
    slash separates the two parts.
    """
    b, m = _normalize_path(base), _normalize_path(sub)
    if not b:
        return m or "/"
    if not m:
        return b
    # A non-empty normalized ``m`` always starts with "/", so drop any trailing
    # slash of ``b`` to avoid doubling it.
    return b.rstrip("/") + m
def _http_method(ann_name: str, ann: Annotation) -> str:
    """Infer the HTTP verb of a mapping annotation.

    Shortcut annotations (``@GetMapping`` and friends) are matched case-insensitively,
    so non-standard spellings such as ``PUTMapping`` still resolve.

    For ``@RequestMapping`` the verb is taken from its ``method`` attribute only. When
    that attribute is absent (Spring then matches every verb) or unrecognised,
    ``GET`` is reported.
    """
    shortcuts = {
        "getmapping": "GET",
        "postmapping": "POST",
        "putmapping": "PUT",
        "deletemapping": "DELETE",
        "patchmapping": "PATCH",
    }
    key = ann_name.lower()
    if key in shortcuts:
        return shortcuts[key]
    if key == "requestmapping":
        # Read ``method`` directly. _ann_string would fall back to ``value`` and turn
        # @RequestMapping("/users") into the bogus verb "/USERS".
        raw = _collect_ann_members(ann).get("method", "")
        # Accept "POST", "RequestMethod.POST" or a multi-valued rendering; first wins.
        match = re.search(r"\b(GET|POST|PUT|DELETE|PATCH|HEAD|OPTIONS|TRACE)\b", raw.upper())
        if match:
            return match.group(1)
    return "GET"
def _has_ann(node, name: str) -> bool:
    """Return True if ``node`` carries an annotation whose simple name equals ``name``."""
    for annotation in getattr(node, "annotations", None) or ():
        if _ann_simple_name(annotation) == name:
            return True
    return False
def _find_ann(node, name: str) -> Optional[Annotation]:
    """Return the first annotation on ``node`` whose simple name is ``name``, else None."""
    annotations = getattr(node, "annotations", None) or ()
    return next((a for a in annotations if _ann_simple_name(a) == name), None)
def _param_source(param: FormalParameter) -> str:
    """Classify where a parameter's value comes from: path, header, form, body or query.

    Annotations are checked in a fixed precedence order. A parameter carrying none of
    the binding annotations below is treated as a query parameter.
    """
    precedence = (
        ("PathVariable", "path"),
        ("RequestHeader", "header"),
        ("RequestPart", "form"),
        ("ModelAttribute", "form"),
        ("RequestBody", "body"),
    )
    return next(
        (source for ann_name, source in precedence if _has_ann(param, ann_name)),
        "query",
    )
def _param_name(param: FormalParameter) -> str:
    """Resolve the name the API caller uses for a parameter.

    Binding annotations are checked in the order RequestParam, PathVariable,
    RequestHeader, RequestPart. The first one that declares an explicit
    ``value``/``name`` wins. Otherwise the Java formal-parameter name is used.
    """
    naming_anns = ("RequestParam", "PathVariable", "RequestHeader", "RequestPart")
    for found in (_find_ann(param, ann_name) for ann_name in naming_anns):
        explicit = _ann_string(found, "value", "name") if found else ""
        if explicit:
            return explicit
    return param.name
def _param_required(param: FormalParameter) -> bool:
    """Return whether the API caller must supply the parameter.

    The parameter is optional when any of these holds:

    - A binding annotation (@RequestParam, @RequestHeader, @PathVariable or
      @RequestPart) sets ``required = false``.
    - @RequestParam / @RequestHeader declares a ``defaultValue``. Spring treats that
      as implicitly not required.
    - The rendered type starts with ``Optional``.
    - The parameter is annotated @Nullable.
    """
    for ann_name in ("RequestParam", "RequestHeader", "PathVariable", "RequestPart"):
        ann = _find_ann(param, ann_name)
        if not ann:
            continue
        members = _collect_ann_members(ann)
        if members.get("required", "").lower() == "false":
            return False
        if ann_name in ("RequestParam", "RequestHeader") and "defaultValue" in members:
            return False
    if _type_to_str(param.type).startswith("Optional"):
        return False
    return not _has_ann(param, "Nullable")
# Matches one ``@param <name> <description>`` tag in a raw JavaDoc block.
# The description may span continuation lines. It stops at the next ``* @tag`` line,
# at the closing ``*/`` (on the same line or its own line), or at end of text.
# Only spaces/tabs may separate the name from the description, so an undocumented
# ``@param x`` cannot swallow the tag that follows it.
JAVADOC_PARAM_RE = re.compile(
    r"@param\s+(\w+)[ \t]*(.*?)(?=\n\s*\*\s*@|\s*\*/|\Z)",
    re.DOTALL,
)


def _clean_javadoc_text(text: str) -> str:
    """Replace JavaDoc ``*`` line prefixes with spaces and collapse whitespace."""
    cleaned = re.sub(r"\s*\*\s?", " ", text)
    return re.sub(r"\s+", " ", cleaned).strip()


def _parse_javadoc_params(javadoc: str) -> Dict[str, str]:
    """Map each ``@param`` name in a JavaDoc block to its cleaned description.

    Tags with an empty description are omitted. A later tag for the same name wins.
    """
    if not javadoc:
        return {}
    result: Dict[str, str] = {}
    for match in JAVADOC_PARAM_RE.finditer(javadoc):
        desc = _clean_javadoc_text(match.group(2))
        if desc:
            result[match.group(1)] = desc
    return result
def _extract_javadoc_before_line(source: str, target_line: int) -> str:
    """
    Return the JavaDoc block (``/** ... */``) that documents the declaration on ``target_line``.

    ``target_line`` is 1-indexed, matching the declaration line number.

    The search walks up from the line above the declaration. It skips blank lines,
    annotation lines and annotation continuation lines (e.g. the second line of a
    multi-line ``@Operation(...)``).

    It returns "" in these cases:

    - It reaches a line ending in ``;``, ``{`` or ``}`` (the previous member or the
      class header).
    - The nearest comment is a plain ``/* ... */`` comment or a trailing comment after
      code, rather than JavaDoc.
    """
    if not source or target_line <= 1:
        return ""
    lines = source.splitlines()
    idx = min(target_line - 2, len(lines) - 1)
    # Phase 1: find the line closing the comment directly above the declaration.
    while idx >= 0:
        stripped = lines[idx].strip()
        if not stripped or stripped.startswith("@"):
            idx -= 1
            continue
        if stripped.endswith("*/"):
            break
        if stripped.endswith((";", "{", "}")):
            return ""
        idx -= 1  # continuation of a multi-line annotation, or a // comment
    if idx < 0:
        return ""
    end_idx = idx
    # Phase 2: walk up to the opening of that same comment.
    while idx >= 0:
        stripped = lines[idx].strip()
        if stripped.startswith("/**"):
            return "\n".join(lines[idx : end_idx + 1])
        if stripped.startswith("/*") or ("/*" in stripped and not stripped.startswith("*")):
            # A plain block comment, or a trailing comment after code: not JavaDoc.
            return ""
        idx -= 1
    return ""
def _lookup_param_description(
    javadoc_params: Dict[str, str], param: FormalParameter, resolved_name: str
) -> Optional[str]:
    """Find the JavaDoc ``@param`` text for a parameter.

    The annotation-resolved name is tried before the Java formal-parameter name, and
    empty names are ignored. Returns None when neither name is documented.
    """
    candidates = [key for key in (resolved_name, param.name) if key]
    return next((javadoc_params[key] for key in candidates if key in javadoc_params), None)
class ControllerAstParser:
    """
    Controller parser built on javalang.

    Endpoints are parsed only from the files passed in (faster in CI). The source
    directory is searched only to locate DTO classes referenced by ``@RequestBody``.
    """

    # Binding annotations whose presence means the caller supplies the value. Such
    # parameters must never be dropped by the framework-injected-parameter heuristic.
    _CALLER_BINDING_ANNS = ("RequestBody", "RequestParam", "PathVariable", "RequestPart", "ModelAttribute")
    # Collection wrappers unwrapped when expanding a @RequestBody DTO
    # (``List<UserDto>`` -> ``UserDto``).
    _COLLECTION_TYPES = frozenset(
        {"List", "ArrayList", "LinkedList", "Set", "HashSet", "LinkedHashSet", "Collection", "Iterable"}
    )

    def __init__(self, repo_root: Path, source_dir: Path):
        """
        :param repo_root: repository root directory
        :param source_dir: absolute path of the Java source root (a directory under repo_root)
        """
        self.repo_root = repo_root
        self.source_dir = source_dir
        # DTO simple name -> expanded parameters, shared by every file this instance parses.
        self._dto_cache: Dict[str, List[ApiParameter]] = {}
        # Text of the file currently being parsed; used to locate JavaDoc by line number.
        self._current_source = ""

    @staticmethod
    def _parse_errors() -> tuple:
        """Return the exceptions javalang may raise on source it cannot handle."""
        # LexerError is looked up defensively so a javalang build without it cannot
        # break the ``except`` clause that uses this tuple.
        lexer_error = getattr(getattr(javalang, "tokenizer", None), "LexerError", None)
        errors = (javalang.parser.JavaSyntaxError, RecursionError)
        return errors + ((lexer_error,) if lexer_error else ())

    def parse_file_content(self, source: str, repo_relative_path: str) -> List[ApiEndpoint]:
        """
        Parse one Java source file and return its endpoints.

        :param source: source text
        :param repo_relative_path: path relative to the repository root (as in git diff)
        :return: endpoints found; empty if the file cannot be parsed or has no controller
        """
        endpoints: List[ApiEndpoint] = []
        self._current_source = source
        try:
            tree = javalang.parse.parse(source)
        except self._parse_errors() as exc:
            # JavaSyntaxError keeps its message in ``description``; str(exc) is often empty.
            reason = getattr(exc, "description", None) or exc
            print(f"[警告] 解析失败 {repo_relative_path}: {reason}")
            return endpoints
        for type_decl in tree.types or []:
            if not isinstance(type_decl, ClassDeclaration) or not self._is_controller(type_decl):
                continue
            class_path = ""
            for ann in type_decl.annotations or []:
                if _ann_simple_name(ann) == "RequestMapping":
                    class_path = _ann_string(ann, "value", "path")
                    break
            for method in type_decl.methods or []:
                ep = self._parse_method(method, type_decl.name, class_path, repo_relative_path)
                if ep:
                    endpoints.append(ep)
        return endpoints

    def _is_controller(self, cls: ClassDeclaration) -> bool:
        """Return True if the class carries @RestController or @Controller."""
        return any(_ann_simple_name(a) in CONTROLLER_ANNS for a in (cls.annotations or []))

    def _parse_method(
        self,
        method: MethodDeclaration,
        class_name: str,
        class_path: str,
        source_file: str,
    ) -> Optional[ApiEndpoint]:
        """Build an endpoint from the first mapping annotation on ``method``.

        Returns None when the method has no mapping annotation.
        """
        # Case-insensitive, matching _http_method's handling of spellings like PUTMapping.
        mapping_names = {name.lower() for name in MAPPING_ANNS}
        for ann in method.annotations or []:
            ann_name = _ann_simple_name(ann)
            if ann_name.lower() not in mapping_names:
                continue
            method_path = _ann_string(ann, "value", "path")
            javadoc_params = _parse_javadoc_params(self._method_javadoc(method))
            params: List[ApiParameter] = []
            for p in method.parameters or []:
                params.extend(self._extract_param(p, javadoc_params))
            return ApiEndpoint(
                http_method=_http_method(ann_name, ann),
                uri=_join_paths(class_path, method_path),
                controller_class=class_name,
                method_name=method.name,
                source_file=source_file.replace("\\", "/"),
                parameters=params,
            )
        return None

    def _method_javadoc(self, method: MethodDeclaration) -> str:
        """Return the JavaDoc preceding ``method``, or "" if none is found."""
        position = getattr(method, "position", None)
        javadoc = ""
        if position:
            javadoc = _extract_javadoc_before_line(self._current_source, position.line)
        # Fall back to the comment javalang attaches to the declaration (if any).
        # That comment is not confused by multi-line annotations between the JavaDoc
        # and the method.
        return javadoc or getattr(method, "documentation", None) or ""

    def _extract_param(
        self, param: FormalParameter, javadoc_params: Optional[Dict[str, str]] = None
    ) -> List[ApiParameter]:
        """Turn a method parameter into API parameters.

        @RequestBody types are expanded into their DTO fields. Framework-injected
        parameters are dropped, except when an explicit binding annotation shows the
        caller supplies them.
        """
        type_name = _type_to_str(param.type)
        name = _param_name(param)
        javadoc_params = javadoc_params or {}
        if _has_ann(param, "RequestBody"):
            return self._expand_dto(type_name, "body")
        explicitly_bound = any(_has_ann(param, a) for a in self._CALLER_BINDING_ANNS)
        if not explicitly_bound and _is_framework_param(type_name, name):
            return []
        description = _lookup_param_description(javadoc_params, param, name)
        return [
            ApiParameter(
                name=name,
                type=type_name,
                required=_param_required(param),
                source=_param_source(param),
                description=description,
            )
        ]

    def _expand_dto(self, type_name: str, source: str) -> List[ApiParameter]:
        """Expand a @RequestBody type into one parameter per DTO field.

        Falls back to a single required parameter named after the type when the DTO
        source cannot be found or parsed, or declares no instance fields. Results are
        cached per simple name, and callers receive a copy of the cached list.
        """
        simple = self._dto_simple_name(type_name)
        if simple not in self._dto_cache:
            fields = self._load_dto_fields(simple, source)
            self._dto_cache[simple] = fields or [
                ApiParameter(name=simple, type=type_name, required=True, source=source)
            ]
        return list(self._dto_cache[simple])

    @classmethod
    def _dto_simple_name(cls, type_name: str) -> str:
        """Reduce a rendered type to the simple class name of the DTO to expand.

        Array suffixes are dropped and collection wrappers are unwrapped
        (``List<UserDto>[]`` -> ``UserDto``). Other generic types keep only their raw
        name (``Map<String,Object>`` -> ``Map``).
        """
        name = type_name.strip()
        while name.endswith("[]"):
            name = name[:-2].rstrip()
        raw, has_args, args = name.partition("<")
        raw = raw.split(".")[-1].strip()
        if has_args and raw in cls._COLLECTION_TYPES:
            inner = args[: args.rfind(">")] if ">" in args else args
            return cls._dto_simple_name(inner)
        return raw

    def _load_dto_fields(self, simple: str, source: str) -> List[ApiParameter]:
        """Parse the DTO source file and return one parameter per instance field.

        Returns [] if the file is missing or cannot be read or parsed.
        """
        dto_file = self._find_dto_file(simple) if simple else None
        if not dto_file:
            return []
        try:
            dto_source = dto_file.read_text(encoding="utf-8", errors="ignore")
            tree = javalang.parse.parse(dto_source)
        except (OSError,) + self._parse_errors():
            return []
        fields: List[ApiParameter] = []
        for type_decl in tree.types or []:
            if not isinstance(type_decl, ClassDeclaration):
                continue
            for field in type_decl.fields or []:
                if "static" in (field.modifiers or []):
                    continue
                field_desc = self._field_description(field, dto_source)
                field_type = _type_to_str(field.type)
                required = not _has_ann(field, "Nullable")
                for decl in field.declarators:
                    fields.append(
                        ApiParameter(
                            name=decl.name,
                            type=field_type,
                            required=required,
                            source=source,
                            description=field_desc,
                        )
                    )
        return fields

    @staticmethod
    def _field_description(field: FieldDeclaration, dto_source: str) -> Optional[str]:
        """Return the cleaned JavaDoc text of a DTO field, or None if it has none."""
        position = getattr(field, "position", None)
        javadoc = _extract_javadoc_before_line(dto_source, position.line) if position else ""
        javadoc = javadoc or getattr(field, "documentation", None) or ""
        return _clean_javadoc_text(javadoc.replace("/**", "").replace("*/", "").strip()) or None

    def _find_dto_file(self, simple_name: str) -> Optional[Path]:
        """Return the first ``<simple_name>.java`` found under the source directory, or None."""
        if not self.source_dir.exists():
            return None
        return next(self.source_dir.rglob(f"{simple_name}.java"), None)
def parse_controller_files(
    repo_root: Path,
    source_subdir: str,
    file_paths: List[str],
    file_contents: Dict[str, str],
) -> List[ApiEndpoint]:
    """
    Parse the given controller files only (no full repository scan).

    :param repo_root: repository root directory
    :param source_subdir: Java source subdirectory, relative to repo_root
    :param file_paths: files to parse, relative to repo_root
    :param file_contents: {file path: source text}; keys may use "/" or the caller's
        original separators
    :return: all endpoints found, in file order
    """
    source_dir = (repo_root / source_subdir).resolve()
    parser = ControllerAstParser(repo_root, source_dir)
    endpoints: List[ApiEndpoint] = []
    for path in file_paths:
        norm = path.replace("\\", "/")
        # Prefer the forward-slash key, but accept contents keyed by the original path too.
        content = file_contents.get(norm) or file_contents.get(path)
        if not content:
            continue
        endpoints.extend(parser.parse_file_content(content, norm))
    return endpoints