脚本修改
All checks were successful
API接口参数变更检测 / api-param-check (push) Successful in 21s

This commit is contained in:
2026-06-03 15:33:52 +08:00
parent 2c20a26af8
commit 5a57c32558
2 changed files with 395 additions and 0 deletions

View File

@@ -0,0 +1,362 @@
"""
纯 Python Controller AST 解析器(基于 javalang无需 Java/Maven
仅解析 Spring @RestController / @Controller 的接口映射与参数。
"""
from __future__ import annotations
import re
from pathlib import Path
from typing import Dict, List, Optional, Union
import javalang
from javalang.tree import (
Annotation,
ClassDeclaration,
FieldDeclaration,
FormalParameter,
Literal,
MemberReference,
MethodDeclaration,
)
from models import ApiEndpoint, ApiParameter
MAPPING_ANNS = {"GetMapping", "PostMapping", "PutMapping", "DeleteMapping", "PatchMapping", "RequestMapping"}
CONTROLLER_ANNS = {"RestController", "Controller"}
def _ann_simple_name(ann: Annotation) -> str:
"""获取注解简单类名。"""
return ann.name.split(".")[-1]
def _literal_str(node) -> str:
"""从 Literal 节点提取字符串值。"""
if node is None:
return ""
if isinstance(node, Literal):
v = node.value or ""
return str(v).strip('"').strip("'")
if isinstance(node, MemberReference):
return node.member
return str(node).strip('"').strip("'")
def _collect_ann_members(ann: Annotation) -> Dict[str, str]:
"""
解析注解成员为字典。
支持 @GetMapping("/path") 与 @RequestParam(value="x", required=false)。
"""
members: Dict[str, str] = {}
el = ann.element
if el is None:
return members
if isinstance(el, list):
for item in el:
if hasattr(item, "name") and hasattr(item, "value"):
members[item.name] = _literal_str(item.value)
elif hasattr(el, "name") and hasattr(el, "value"):
members[el.name] = _literal_str(el.value)
else:
members["value"] = _literal_str(el)
return members
def _ann_string(ann: Annotation, *keys: str) -> str:
"""从注解提取字符串属性。"""
members = _collect_ann_members(ann)
for k in keys:
if k in members and members[k]:
return members[k]
if "value" in members:
return members["value"]
return ""
def _type_to_str(type_node) -> str:
"""Java 类型节点转字符串。"""
if type_node is None:
return "Object"
if isinstance(type_node, javalang.tree.BasicType):
return type_node.name
if isinstance(type_node, javalang.tree.ReferenceType):
name = type_node.name or "Object"
if type_node.arguments:
args = ",".join(_type_to_str(a.type) for a in type_node.arguments)
return f"{name}<{args}>"
if type_node.sub_type:
return _type_to_str(type_node.sub_type)
return name
return str(type_node)
def _normalize_path(path: str) -> str:
"""规范化 URI 路径。"""
if not path or not path.strip():
return ""
path = path.strip()
if not path.startswith("/"):
path = "/" + path
return re.sub(r"/+", "/", path)
def _join_paths(base: str, sub: str) -> str:
"""拼接类级与方法级路径。"""
b, m = _normalize_path(base), _normalize_path(sub)
if not b:
return m or "/"
if not m:
return b
if b.endswith("/") and m.startswith("/"):
return b + m[1:]
if not b.endswith("/") and not m.startswith("/"):
return b + "/" + m
return b + m
def _http_method(ann_name: str, ann: Annotation) -> str:
"""从映射注解推断 HTTP 方法。"""
mapping = {
"GetMapping": "GET",
"PostMapping": "POST",
"PutMapping": "PUT",
"DeleteMapping": "DELETE",
"PatchMapping": "PATCH",
}
if ann_name in mapping:
return mapping[ann_name]
if ann_name == "RequestMapping":
m = _ann_string(ann, "method")
if m:
return m.replace("RequestMethod.", "").upper()
return "GET"
return "GET"
def _has_ann(node, name: str) -> bool:
"""节点是否含指定注解。"""
anns = getattr(node, "annotations", None) or []
return any(_ann_simple_name(a) == name for a in anns)
def _find_ann(node, name: str) -> Optional[Annotation]:
"""查找指定注解。"""
for a in getattr(node, "annotations", None) or []:
if _ann_simple_name(a) == name:
return a
return None
def _param_source(param: FormalParameter) -> str:
"""参数来源path / query / header / form / body。"""
if _has_ann(param, "PathVariable"):
return "path"
if _has_ann(param, "RequestHeader"):
return "header"
if _has_ann(param, "RequestPart") or _has_ann(param, "ModelAttribute"):
return "form"
if _has_ann(param, "RequestBody"):
return "body"
return "query"
def _param_name(param: FormalParameter) -> str:
"""解析参数名。"""
for ann_name in ("RequestParam", "PathVariable", "RequestHeader", "RequestPart"):
ann = _find_ann(param, ann_name)
if ann:
val = _ann_string(ann, "value", "name")
if val:
return val
return param.name
def _param_required(param: FormalParameter) -> bool:
"""是否必填。"""
ann = _find_ann(param, "RequestParam")
if ann:
members = _collect_ann_members(ann)
if members.get("required", "").lower() == "false":
return False
type_name = _type_to_str(param.type)
if type_name.startswith("Optional"):
return False
return not _has_ann(param, "Nullable")
class ControllerAstParser:
"""
基于 javalang 的 Controller 解析器。
只解析传入的文件不扫描整个目录CI 更快)。
"""
def __init__(self, repo_root: Path, source_dir: Path):
"""
:param repo_root: 仓库根目录
:param source_dir: Java 源码根目录repo_root 下的相对路径对应的绝对路径)
"""
self.repo_root = repo_root
self.source_dir = source_dir
self._dto_cache: Dict[str, List[ApiParameter]] = {}
def parse_file_content(self, source: str, repo_relative_path: str) -> List[ApiEndpoint]:
"""
解析单个 Java 源文件内容。
:param source: 源码文本
:param repo_relative_path: 相对仓库根目录的路径(与 git diff 一致)
:return: 端点列表
"""
endpoints: List[ApiEndpoint] = []
try:
tree = javalang.parse.parse(source)
except (javalang.parser.JavaSyntaxError, RecursionError) as exc:
print(f"[警告] 解析失败 {repo_relative_path}: {exc}")
return endpoints
for type_decl in tree.types or []:
if not isinstance(type_decl, ClassDeclaration):
continue
if not self._is_controller(type_decl):
continue
class_path = ""
for ann in type_decl.annotations or []:
if _ann_simple_name(ann) == "RequestMapping":
class_path = _ann_string(ann, "value", "path")
break
for method in type_decl.methods or []:
ep = self._parse_method(method, type_decl.name, class_path, repo_relative_path)
if ep:
endpoints.append(ep)
return endpoints
def _is_controller(self, cls: ClassDeclaration) -> bool:
"""是否为 Controller 类。"""
return any(_ann_simple_name(a) in CONTROLLER_ANNS for a in (cls.annotations or []))
def _parse_method(
self,
method: MethodDeclaration,
class_name: str,
class_path: str,
source_file: str,
) -> Optional[ApiEndpoint]:
"""解析带映射注解的方法。"""
for ann in method.annotations or []:
ann_name = _ann_simple_name(ann)
if ann_name not in MAPPING_ANNS:
continue
method_path = _ann_string(ann, "value", "path")
params = []
for p in method.parameters or []:
params.extend(self._extract_param(p))
return ApiEndpoint(
http_method=_http_method(ann_name, ann),
uri=_join_paths(class_path, method_path),
controller_class=class_name,
method_name=method.name,
source_file=source_file.replace("\\", "/"),
parameters=params,
)
return None
def _extract_param(self, param: FormalParameter) -> List[ApiParameter]:
"""提取方法参数,@RequestBody 展开 DTO 字段。"""
type_name = _type_to_str(param.type)
if _has_ann(param, "RequestBody"):
return self._expand_dto(type_name, "body")
return [
ApiParameter(
name=_param_name(param),
type=type_name,
required=_param_required(param),
source=_param_source(param),
)
]
def _expand_dto(self, type_name: str, source: str) -> List[ApiParameter]:
"""展开 @RequestBody DTO 字段。"""
simple = type_name.split(".")[-1].replace(">", "").replace("<", "").strip()
if simple in self._dto_cache:
return self._dto_cache[simple]
dto_file = self._find_dto_file(simple)
if not dto_file:
result = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
self._dto_cache[simple] = result
return result
try:
tree = javalang.parse.parse(dto_file.read_text(encoding="utf-8", errors="ignore"))
except (javalang.parser.JavaSyntaxError, OSError):
result = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
self._dto_cache[simple] = result
return result
fields: List[ApiParameter] = []
for type_decl in tree.types or []:
if not isinstance(type_decl, ClassDeclaration):
continue
for field in type_decl.fields or []:
if "static" in (field.modifiers or []):
continue
for decl in field.declarators:
fields.append(
ApiParameter(
name=decl.name,
type=_type_to_str(field.type),
required=not _has_ann(field, "Nullable"),
source=source,
)
)
if not fields:
fields = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
self._dto_cache[simple] = fields
return fields
def _find_dto_file(self, simple_name: str) -> Optional[Path]:
"""在源码目录中查找 DTO 文件。"""
if not self.source_dir.exists():
return None
target = f"{simple_name}.java"
for path in self.source_dir.rglob(target):
return path
return None
def parse_controller_files(
repo_root: Path,
source_subdir: str,
file_paths: List[str],
file_contents: Dict[str, str],
) -> List[ApiEndpoint]:
"""
批量解析指定 Controller 文件(仅解析传入的文件,不全量扫描)。
:param repo_root: 仓库根目录
:param source_subdir: 源码子目录(相对仓库根)
:param file_paths: 要解析的文件路径列表(相对仓库根)
:param file_contents: {文件路径: 源码内容}
:return: 所有端点
"""
source_dir = repo_root / source_subdir.replace("/", Path.sep)
parser = ControllerAstParser(repo_root, source_dir)
endpoints: List[ApiEndpoint] = []
for path in file_paths:
content = file_contents.get(path)
if not content:
continue
norm = path.replace("\\", "/")
endpoints.extend(parser.parse_file_content(content, norm))
return endpoints

33
.gitea/checker/models.py Normal file
View File

@@ -0,0 +1,33 @@
"""
Controller 端点数据模型。
"""
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class ApiParameter:
"""单个接口参数。"""
name: str
type: str
required: bool = True
source: str = "query"
description: Optional[str] = None
@dataclass
class ApiEndpoint:
"""单个 Controller 接口端点。"""
http_method: str
uri: str
controller_class: str
method_name: str
source_file: str
parameters: List[ApiParameter] = field(default_factory=list)
@property
def endpoint_key(self) -> str:
return f"{self.http_method} {self.uri}"