491 lines
16 KiB
Python
491 lines
16 KiB
Python
"""
|
||
纯 Python Controller AST 解析器(基于 javalang,无需 Java/Maven)。
|
||
仅解析 Spring @RestController / @Controller 的接口映射与参数。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional
|
||
|
||
import javalang
|
||
from javalang.tree import (
|
||
Annotation,
|
||
ClassDeclaration,
|
||
ElementValuePair,
|
||
FieldDeclaration,
|
||
FormalParameter,
|
||
Literal,
|
||
MemberReference,
|
||
MethodDeclaration,
|
||
)
|
||
|
||
from models import ApiEndpoint, ApiParameter
|
||
|
||
# Spring MVC annotations that bind a handler method (or a controller class) to a URL.
MAPPING_ANNS = {
    "GetMapping",
    "PostMapping",
    "PutMapping",
    "DeleteMapping",
    "PatchMapping",
    "RequestMapping",
}

# Class-level annotations that mark a Spring MVC controller.
CONTROLLER_ANNS = {"RestController", "Controller"}

# Parameter types that Spring MVC injects by itself. API callers never send
# these, so they are skipped when endpoint parameters are collected.
FRAMEWORK_PARAM_TYPES = {
    # Servlet API objects
    "HttpServletRequest",
    "HttpServletResponse",
    "HttpSession",
    "ServletRequest",
    "ServletResponse",
    # Spring request abstractions
    "WebRequest",
    "NativeWebRequest",
    # View model and data binding
    "Model",
    "ModelMap",
    "RedirectAttributes",
    "BindingResult",
    "Errors",
    # Security / locale context
    "Authentication",
    "Principal",
    "Locale",
    "TimeZone",
    # Raw request/response streams
    "InputStream",
    "OutputStream",
    "Reader",
    "Writer",
    # Miscellaneous helpers
    "HttpHeaders",
    "UriComponentsBuilder",
}
|
||
|
||
|
||
def _is_framework_param(type_name: str, param_name: str) -> bool:
    """Return True when a handler parameter is injected by the framework.

    Such parameters are never supplied by API callers. The type name is reduced
    to its last dot-separated segment with ``<``/``>`` removed. It counts as
    framework-injected when that name is listed in ``FRAMEWORK_PARAM_TYPES``.
    It also counts when the parameter is literally named ``request`` or
    ``response`` and the reduced type name ends with ``Request``/``Response``.
    """
    bare = type_name.split(".")[-1]
    for bracket in "<>":
        bare = bare.replace(bracket, "")
    bare = bare.strip()

    if bare in FRAMEWORK_PARAM_TYPES:
        return True

    looks_like_io_object = bare.endswith(("Request", "Response"))
    return param_name in ("request", "response") and looks_like_io_object
|
||
|
||
|
||
def _ann_simple_name(ann: Annotation) -> str:
|
||
"""获取注解简单类名。"""
|
||
return ann.name.split(".")[-1]
|
||
|
||
|
||
def _literal_str(node) -> str:
    """Render an annotation element value as a plain string.

    * ``None`` -> ``""``.
    * Array initializers such as ``{"/a", "/b"}`` (javalang ``ElementArrayValue``,
      or a plain ``list``, presumably what javalang yields for an empty ``{}``)
      -> the first element rendered recursively, or ``""`` when empty.
    * ``Literal`` -> its text with surrounding quotes stripped (``"/users"`` ->
      ``/users``). A Python ``bool`` value is rendered as ``true``/``false``.
    * ``MemberReference`` -> the member name only (``RequestMethod.POST`` ->
      ``POST``).
    * Anything else -> ``str(node)`` with surrounding quotes stripped.
    """
    if node is None:
        return ""
    if isinstance(node, javalang.tree.ElementArrayValue):
        # Multi-valued attributes, e.g. @GetMapping({"/a", "/b"}) or
        # method = {RequestMethod.GET, RequestMethod.POST}. Without this branch
        # they would fall through to str(node) and yield the node's repr.
        # Only the first value is reported.
        node = node.values or []
    if isinstance(node, list):
        return _literal_str(node[0]) if node else ""
    if isinstance(node, Literal):
        v = node.value
        if isinstance(v, bool):
            return str(v).lower()
        return str(v or "").strip('"').strip("'")
    if isinstance(node, MemberReference):
        return node.member
    if isinstance(node, bool):
        return str(node).lower()
    return str(node).strip('"').strip("'")
|
||
|
||
|
||
def _collect_ann_members(ann: Annotation) -> Dict[str, str]:
    """Flatten an annotation's elements into a ``{member_name: value}`` dict.

    Both shorthand and named forms are covered. ``@GetMapping("/path")`` yields
    ``{"value": "/path"}``, and ``@RequestParam(value="x", required=false)``
    yields ``{"value": "x", "required": "false"}``. Values are rendered with
    ``_literal_str``. An annotation without elements yields ``{}``.
    """
    element = ann.element
    if element is None:
        return {}
    if isinstance(element, ElementValuePair):
        return {element.name: _literal_str(element.value)}
    if isinstance(element, list):
        # Only name=value pairs are collected; other list entries are ignored.
        return {
            pair.name: _literal_str(pair.value)
            for pair in element
            if isinstance(pair, ElementValuePair)
        }
    # A single unnamed value is Java shorthand for the ``value`` member.
    return {"value": _literal_str(element)}
|
||
|
||
|
||
def _ann_string(ann: Annotation, *keys: str) -> str:
    """Return the first non-empty annotation member among *keys*, tried in order.

    If none of them is set, fall back to the ``value`` member (even when it is
    empty), and finally to ``""``.
    """
    members = _collect_ann_members(ann)
    first_hit = next((members[key] for key in keys if members.get(key)), None)
    if first_hit is not None:
        return first_hit
    return members.get("value", "")
|
||
|
||
|
||
def _type_to_str(type_node) -> str:
    """Render a javalang type node as a compact Java type string.

    * ``None`` -> ``"Object"``.
    * Primitive (``BasicType``) -> its name.
    * ``ReferenceType`` with generic arguments -> ``Name<Arg1,Arg2>``, arguments
      rendered recursively and comma-joined without spaces.
    * ``ReferenceType`` without arguments but with a ``sub_type`` (presumably the
      next segment of a qualified name) -> the rendered ``sub_type``.
    * Any other ``ReferenceType`` -> its name (``"Object"`` if the name is empty).
    * Anything else -> ``str(type_node)``.

    NOTE(review): the node's array dimensions are not read, so array types are
    rendered like their element type.
    """
    if type_node is None:
        return "Object"
    if isinstance(type_node, javalang.tree.BasicType):
        return type_node.name
    if not isinstance(type_node, javalang.tree.ReferenceType):
        return str(type_node)

    base = type_node.name or "Object"
    if type_node.arguments:
        rendered = [_type_to_str(arg.type) for arg in type_node.arguments]
        return "%s<%s>" % (base, ",".join(rendered))
    return _type_to_str(type_node.sub_type) if type_node.sub_type else base
|
||
|
||
|
||
def _normalize_path(path: str) -> str:
|
||
"""规范化 URI 路径。"""
|
||
if not path or not path.strip():
|
||
return ""
|
||
path = path.strip()
|
||
if not path.startswith("/"):
|
||
path = "/" + path
|
||
return re.sub(r"/+", "/", path)
|
||
|
||
|
||
def _join_paths(base: str, sub: str) -> str:
    """Join a class-level and a method-level mapping path.

    Both parts are normalized first. An empty class path yields the method path
    (or ``/`` if that is empty too), and an empty method path yields the class
    path. Otherwise the two are concatenated with exactly one ``/`` between them.
    """
    prefix = _normalize_path(base)
    suffix = _normalize_path(sub)
    if not prefix:
        return suffix or "/"
    if not suffix:
        return prefix
    # A non-empty normalized suffix always starts with "/", so only a trailing
    # slash on the prefix can cause a doubled separator.
    return prefix[:-1] + suffix if prefix.endswith("/") else prefix + suffix
|
||
|
||
|
||
def _http_method(ann_name: str, ann: Annotation) -> str:
|
||
"""从映射注解推断 HTTP 方法。
|
||
|
||
支持大小写不敏感匹配,避免 PUTMapping、POSTMapping 等不规范写法导致解析失败。
|
||
"""
|
||
mapping = {
|
||
"GetMapping": "GET",
|
||
"PostMapping": "POST",
|
||
"PutMapping": "PUT",
|
||
"DeleteMapping": "DELETE",
|
||
"PatchMapping": "PATCH",
|
||
}
|
||
# 大小写不敏感匹配
|
||
for key, value in mapping.items():
|
||
if key.lower() == ann_name.lower():
|
||
return value
|
||
if ann_name.lower() == "requestmapping":
|
||
m = _ann_string(ann, "method")
|
||
if m:
|
||
return m.replace("RequestMethod.", "").upper()
|
||
return "GET"
|
||
return "GET"
|
||
|
||
|
||
def _has_ann(node, name: str) -> bool:
    """Return True if *node* carries an annotation whose simple name is *name*.

    A node without an ``annotations`` attribute, or with ``None`` there, has no
    annotations.
    """
    for candidate in getattr(node, "annotations", None) or []:
        if _ann_simple_name(candidate) == name:
            return True
    return False
|
||
|
||
|
||
def _find_ann(node, name: str) -> Optional[Annotation]:
    """Return the first annotation on *node* whose simple name is *name*, else None."""
    annotations = getattr(node, "annotations", None) or []
    return next((a for a in annotations if _ann_simple_name(a) == name), None)
|
||
|
||
|
||
def _param_source(param: FormalParameter) -> str:
    """Classify where a handler parameter is bound from.

    Returns ``path``, ``header``, ``form``, ``body`` or ``query``. Annotations
    are checked in that precedence order. Parameters without any of them
    (including plain ``@RequestParam``) are classified as ``query``.
    """
    precedence = (
        (("PathVariable",), "path"),
        (("RequestHeader",), "header"),
        (("RequestPart", "ModelAttribute"), "form"),
        (("RequestBody",), "body"),
    )
    for ann_names, source in precedence:
        if any(_has_ann(param, ann_name) for ann_name in ann_names):
            return source
    return "query"
|
||
|
||
|
||
def _param_name(param: FormalParameter) -> str:
    """Return the name a caller uses for a handler parameter.

    The binding annotations are tried in this order: RequestParam, PathVariable,
    RequestHeader, RequestPart. The first one that declares a non-empty
    ``value``/``name`` supplies the name. Otherwise the Java parameter name is
    used.
    """
    binding_anns = ("RequestParam", "PathVariable", "RequestHeader", "RequestPart")
    for ann in (_find_ann(param, ann_name) for ann_name in binding_anns):
        explicit = _ann_string(ann, "value", "name") if ann else ""
        if explicit:
            return explicit
    return param.name
|
||
|
||
|
||
def _param_required(param: FormalParameter) -> bool:
    """Decide whether a (non-body) handler parameter is mandatory.

    The parameter is optional in any of these cases:

    * One of its binding annotations (``@RequestParam``, ``@PathVariable``,
      ``@RequestHeader``, ``@RequestPart``) declares ``required = false``.
    * One of those annotations supplies a ``defaultValue``; Spring treats a
      default value as implicitly ``required = false``.
    * Its rendered type name starts with ``Optional`` (``Optional<T>``,
      ``OptionalInt`` ...).
    * It is annotated ``@Nullable``.

    Otherwise it is required.
    """
    for ann_name in ("RequestParam", "PathVariable", "RequestHeader", "RequestPart"):
        ann = _find_ann(param, ann_name)
        if ann is None:
            continue
        members = _collect_ann_members(ann)
        if members.get("required", "").lower() == "false":
            return False
        # The presence of the key matters, not its value: defaultValue = "" is a
        # valid default too.
        if "defaultValue" in members:
            return False

    if _type_to_str(param.type).startswith("Optional"):
        return False
    return not _has_ann(param, "Nullable")
|
||
|
||
|
||
def _param_description(param: FormalParameter) -> Optional[str]:
    """Return a human-readable description of a handler parameter, if any.

    Sources, in priority order:

    1. Swagger ``@ApiParam``: its ``value`` (or ``name``) member.
    2. ``@RequestParam``: its ``value``/``name``, used as a hint only when it
       differs from the Java parameter name.

    Method-level ``@ApiImplicitParam`` is not consulted. Returns None when
    nothing applies.
    """
    api_param = _find_ann(param, "ApiParam")
    swagger_text = _ann_string(api_param, "value", "name") if api_param else ""
    if swagger_text:
        return swagger_text

    request_param = _find_ann(param, "RequestParam")
    if not request_param:
        return None
    hint = _ann_string(request_param, "value", "name")
    return hint if hint and hint != param.name else None
|
||
|
||
|
||
def _parse_javadoc_params(source_code: str) -> Dict[str, Dict[str, str]]:
|
||
"""
|
||
使用正则从 Java 源代码中提取方法上的 Javadoc @param 描述。
|
||
|
||
返回结构:
|
||
{ method_name: { param_name: description } }
|
||
"""
|
||
javadoc_map: Dict[str, Dict[str, str]] = {}
|
||
|
||
# 匹配 Javadoc 块 + 其后的方法声明
|
||
# 简单启发式:Javadoc 后面第一个标识符 + ( 视为方法名
|
||
pattern = re.compile(
|
||
r'/\*\*(?P<javadoc>.*?)\*/\s*(?:public|private|protected|static|\s)*'
|
||
r'[\w<>\[\],\.\s]+\s+(?P<method>\w+)\s*\(',
|
||
re.DOTALL | re.MULTILINE
|
||
)
|
||
|
||
for match in pattern.finditer(source_code):
|
||
javadoc_block = match.group('javadoc')
|
||
method_name = match.group('method')
|
||
|
||
param_descs: Dict[str, str] = {}
|
||
# 提取 @param 行
|
||
param_pattern = re.compile(r'@param\s+(\w+)\s+([^\n@]+)', re.IGNORECASE)
|
||
for p_match in param_pattern.finditer(javadoc_block):
|
||
p_name = p_match.group(1)
|
||
p_desc = p_match.group(2).strip()
|
||
if p_desc:
|
||
param_descs[p_name] = p_desc
|
||
|
||
if param_descs:
|
||
javadoc_map[method_name] = param_descs
|
||
|
||
return javadoc_map
|
||
|
||
|
||
class ControllerAstParser:
    """Spring MVC controller parser built on javalang.

    Only the files handed to it are parsed, with no up-front directory scan,
    which keeps CI runs fast. The source tree is searched lazily, and only to
    locate ``@RequestBody`` DTO classes.
    """

    def __init__(self, repo_root: Path, source_dir: Path):
        """
        :param repo_root: repository root directory.
        :param source_dir: absolute path of the Java source root (the directory
            searched for DTO classes).
        """
        self.repo_root = repo_root
        self.source_dir = source_dir
        # Simple DTO class name -> expanded parameters, cached for this instance.
        self._dto_cache: Dict[str, List[ApiParameter]] = {}
        # method_name -> param_name -> Javadoc description for the current file.
        self._javadoc_map: Dict[str, Dict[str, str]] = {}

    def parse_file_content(self, source: str, repo_relative_path: str) -> List[ApiEndpoint]:
        """Parse one Java source file and return its controller endpoints.

        :param source: Java source text.
        :param repo_relative_path: path relative to the repository root (same form
            as in ``git diff``); recorded on every endpoint.
        :return: endpoints declared by top-level controller classes. Returns an
            empty list, after printing a warning, when the file cannot be
            tokenized or parsed.
        """
        endpoints: List[ApiEndpoint] = []
        try:
            tree = javalang.parse.parse(source)
        except (
            javalang.tokenizer.LexerError,
            javalang.parser.JavaSyntaxError,
            RecursionError,
        ) as exc:
            # LexerError is raised for malformed tokens (e.g. an unterminated
            # string literal). Catching it skips the bad file with a warning
            # instead of aborting the whole batch.
            print(f"[警告] 解析失败 {repo_relative_path}: {exc}")
            return endpoints

        # Javadoc @param text for this file, used as a fallback parameter description.
        self._javadoc_map = _parse_javadoc_params(source)

        for type_decl in tree.types or []:
            if not isinstance(type_decl, ClassDeclaration):
                continue
            if not self._is_controller(type_decl):
                continue

            # Class-level @RequestMapping path, prefixed to every method path.
            class_path = ""
            for ann in type_decl.annotations or []:
                if _ann_simple_name(ann) == "RequestMapping":
                    class_path = _ann_string(ann, "value", "path")
                    break

            for method in type_decl.methods or []:
                ep = self._parse_method(method, type_decl.name, class_path, repo_relative_path)
                if ep:
                    endpoints.append(ep)

        return endpoints

    def _is_controller(self, cls: ClassDeclaration) -> bool:
        """Return True if the class carries @RestController or @Controller."""
        return any(_ann_simple_name(a) in CONTROLLER_ANNS for a in (cls.annotations or []))

    def _parse_method(
        self,
        method: MethodDeclaration,
        class_name: str,
        class_path: str,
        source_file: str,
    ) -> Optional[ApiEndpoint]:
        """Build an endpoint from the method's first mapping annotation.

        Returns None when the method has no Spring mapping annotation.
        """
        for ann in method.annotations or []:
            ann_name = _ann_simple_name(ann)
            if ann_name not in MAPPING_ANNS:
                continue

            method_path = _ann_string(ann, "value", "path")
            params: List[ApiParameter] = []
            for p in method.parameters or []:
                params.extend(self._extract_param(p, method_name=method.name))

            return ApiEndpoint(
                http_method=_http_method(ann_name, ann),
                uri=_join_paths(class_path, method_path),
                controller_class=class_name,
                method_name=method.name,
                source_file=source_file.replace("\\", "/"),
                parameters=params,
            )
        return None

    def _extract_param(self, param: FormalParameter, method_name: Optional[str] = None) -> List[ApiParameter]:
        """Convert one handler parameter into API parameters.

        ``@RequestBody`` parameters are expanded into their DTO fields.
        Framework-injected parameters yield an empty list. Every other parameter
        yields exactly one entry.

        :param param: the javalang formal parameter.
        :param method_name: enclosing method name, used to look up Javadoc
            ``@param`` text when no annotation supplies a description.
        """
        type_name = _type_to_str(param.type)

        # Check @RequestBody before the framework filter. A body DTO is caller
        # input even when it matches that heuristic, e.g.
        # ``@RequestBody CreateUserRequest request`` (named "request", type
        # ending in "Request").
        if _has_ann(param, "RequestBody"):
            return self._expand_dto(type_name, "body")

        name = _param_name(param)
        if _is_framework_param(type_name, name):
            return []

        desc = _param_description(param)
        # Fall back to the Javadoc @param text when annotations give no description.
        if not desc and method_name and method_name in self._javadoc_map:
            desc = self._javadoc_map[method_name].get(name)

        return [
            ApiParameter(
                name=name,
                type=type_name,
                required=_param_required(param),
                source=_param_source(param),
                description=desc,
            )
        ]

    def _expand_dto(self, type_name: str, source: str) -> List[ApiParameter]:
        """Expand a ``@RequestBody`` DTO into one parameter per instance field.

        The DTO's ``.java`` file is located by simple name under ``source_dir``.
        A single placeholder parameter named after the type is returned instead
        when the file is not found, cannot be read or parsed, or declares no
        instance fields. Results are cached by simple name; *source* is not part
        of the cache key.

        NOTE(review): inherited fields are not collected. Generic wrappers such
        as ``List<Foo>`` are looked up under the concatenated name (``ListFoo``).
        """
        simple = type_name.split(".")[-1].replace(">", "").replace("<", "").strip()
        if simple in self._dto_cache:
            return self._dto_cache[simple]

        dto_file = self._find_dto_file(simple)
        if not dto_file:
            result = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
            self._dto_cache[simple] = result
            return result

        try:
            tree = javalang.parse.parse(dto_file.read_text(encoding="utf-8", errors="ignore"))
        except (
            javalang.tokenizer.LexerError,
            javalang.parser.JavaSyntaxError,
            RecursionError,
            OSError,
        ):
            result = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
            self._dto_cache[simple] = result
            return result

        classes = [t for t in tree.types or [] if isinstance(t, ClassDeclaration)]
        # Prefer the class the file is named after. Other top-level classes in
        # the same file are not part of this DTO.
        classes = [c for c in classes if c.name == simple] or classes

        fields: List[ApiParameter] = []
        for type_decl in classes:
            for field in type_decl.fields or []:
                if "static" in (field.modifiers or []):
                    continue
                # Field description from Swagger 2 @ApiModelProperty, when present.
                api_model_ann = _find_ann(field, "ApiModelProperty")
                desc = _ann_string(api_model_ann, "value") if api_model_ann else None
                for decl in field.declarators:
                    fields.append(
                        ApiParameter(
                            name=decl.name,
                            type=_type_to_str(field.type),
                            required=not _has_ann(field, "Nullable"),
                            source=source,
                            description=desc,
                        )
                    )

        if not fields:
            fields = [ApiParameter(name=simple, type=type_name, required=True, source=source)]

        self._dto_cache[simple] = fields
        return fields

    def _find_dto_file(self, simple_name: str) -> Optional[Path]:
        """Return the first ``<simple_name>.java`` found under ``source_dir``, or None."""
        if not self.source_dir.exists():
            return None
        return next(self.source_dir.rglob(f"{simple_name}.java"), None)
|
||
|
||
|
||
def parse_controller_files(
    repo_root: Path,
    source_subdir: str,
    file_paths: List[str],
    file_contents: Dict[str, str],
) -> List[ApiEndpoint]:
    """Parse the given controller files only, without a full-tree scan.

    :param repo_root: repository root directory.
    :param source_subdir: Java source directory relative to ``repo_root``; it is
        searched for ``@RequestBody`` DTO classes.
    :param file_paths: files to parse, relative to ``repo_root``.
    :param file_contents: ``{file path: source text}``. Looked up first by the
        path with ``/`` separators, then by the path exactly as given.
    :return: endpoints of all files, in input order. Paths without non-empty
        content are skipped.
    """
    source_dir = (repo_root / source_subdir).resolve()
    parser = ControllerAstParser(repo_root, source_dir)
    endpoints: List[ApiEndpoint] = []

    for path in file_paths:
        norm = path.replace("\\", "/")
        # Callers may key file_contents by the raw (e.g. Windows-style) path.
        # Fall back to it instead of silently skipping the file.
        content = file_contents.get(norm) or file_contents.get(path)
        if not content:
            continue
        endpoints.extend(parser.parse_file_content(content, norm))

    return endpoints
|