""" 纯 Python Controller AST 解析器(基于 javalang,无需 Java/Maven)。 仅解析 Spring @RestController / @Controller 的接口映射与参数。 """ from __future__ import annotations import re from pathlib import Path from typing import Dict, List, Optional import javalang from javalang.tree import ( Annotation, ClassDeclaration, ElementValuePair, FieldDeclaration, FormalParameter, Literal, MemberReference, MethodDeclaration, ) from models import ApiEndpoint, ApiParameter MAPPING_ANNS = {"GetMapping", "PostMapping", "PutMapping", "DeleteMapping", "PatchMapping", "RequestMapping"} CONTROLLER_ANNS = {"RestController", "Controller"} # Spring MVC 框架自动注入参数,不属于 API 调用方入参,解析时忽略 FRAMEWORK_PARAM_TYPES = { "HttpServletRequest", "HttpServletResponse", "HttpSession", "ServletRequest", "ServletResponse", "WebRequest", "NativeWebRequest", "Model", "ModelMap", "RedirectAttributes", "BindingResult", "Errors", "Authentication", "Principal", "Locale", "TimeZone", "InputStream", "OutputStream", "Reader", "Writer", "HttpHeaders", "UriComponentsBuilder", } def _is_framework_param(type_name: str, param_name: str) -> bool: """判断是否为框架注入参数(非 API 调用方需要传递)。""" simple = type_name.split(".")[-1].replace(">", "").replace("<", "").strip() if simple in FRAMEWORK_PARAM_TYPES: return True if param_name in ("request", "response") and ( simple.endswith("Request") or simple.endswith("Response") ): return True return False def _ann_simple_name(ann: Annotation) -> str: """获取注解简单类名。""" return ann.name.split(".")[-1] def _literal_str(node) -> str: """从 AST 节点提取字符串或布尔值。""" if node is None: return "" if isinstance(node, Literal): v = node.value if isinstance(v, bool): return str(v).lower() return str(v or "").strip('"').strip("'") if isinstance(node, MemberReference): return node.member if isinstance(node, bool): return str(node).lower() return str(node).strip('"').strip("'") def _collect_ann_members(ann: Annotation) -> Dict[str, str]: """ 解析注解成员为字典。 支持 @GetMapping("/path") 与 @RequestParam(value="x", required=false)。 """ members: Dict[str, str] = {} el = ann.element if el is None: return members if isinstance(el, ElementValuePair): members[el.name] = _literal_str(el.value) elif isinstance(el, list): for item in el: if isinstance(item, ElementValuePair): members[item.name] = _literal_str(item.value) else: members["value"] = _literal_str(el) return members def _ann_string(ann: Annotation, *keys: str) -> str: """从注解提取字符串属性。""" members = _collect_ann_members(ann) for k in keys: if k in members and members[k]: return members[k] if "value" in members: return members["value"] return "" def _type_to_str(type_node) -> str: """Java 类型节点转字符串。""" if type_node is None: return "Object" if isinstance(type_node, javalang.tree.BasicType): return type_node.name if isinstance(type_node, javalang.tree.ReferenceType): name = type_node.name or "Object" if type_node.arguments: args = ",".join(_type_to_str(a.type) for a in type_node.arguments) return f"{name}<{args}>" if type_node.sub_type: return _type_to_str(type_node.sub_type) return name return str(type_node) def _normalize_path(path: str) -> str: """规范化 URI 路径。""" if not path or not path.strip(): return "" path = path.strip() if not path.startswith("/"): path = "/" + path return re.sub(r"/+", "/", path) def _join_paths(base: str, sub: str) -> str: """拼接类级与方法级路径。""" b, m = _normalize_path(base), _normalize_path(sub) if not b: return m or "/" if not m: return b if b.endswith("/") and m.startswith("/"): return b + m[1:] if not b.endswith("/") and not m.startswith("/"): return b + "/" + m return b + m def _http_method(ann_name: str, ann: Annotation) -> str: """从映射注解推断 HTTP 方法。 支持大小写不敏感匹配,避免 PUTMapping、POSTMapping 等不规范写法导致解析失败。 """ mapping = { "GetMapping": "GET", "PostMapping": "POST", "PutMapping": "PUT", "DeleteMapping": "DELETE", "PatchMapping": "PATCH", } # 大小写不敏感匹配 for key, value in mapping.items(): if key.lower() == ann_name.lower(): return value if ann_name.lower() == "requestmapping": m = _ann_string(ann, "method") if m: return m.replace("RequestMethod.", "").upper() return "GET" return "GET" def _has_ann(node, name: str) -> bool: """节点是否含指定注解。""" anns = getattr(node, "annotations", None) or [] return any(_ann_simple_name(a) == name for a in anns) def _find_ann(node, name: str) -> Optional[Annotation]: """查找指定注解。""" for a in getattr(node, "annotations", None) or []: if _ann_simple_name(a) == name: return a return None def _param_source(param: FormalParameter) -> str: """参数来源:path / query / header / form / body。""" if _has_ann(param, "PathVariable"): return "path" if _has_ann(param, "RequestHeader"): return "header" if _has_ann(param, "RequestPart") or _has_ann(param, "ModelAttribute"): return "form" if _has_ann(param, "RequestBody"): return "body" return "query" def _param_name(param: FormalParameter) -> str: """解析参数名。""" for ann_name in ("RequestParam", "PathVariable", "RequestHeader", "RequestPart"): ann = _find_ann(param, ann_name) if ann: val = _ann_string(ann, "value", "name") if val: return val return param.name def _param_required(param: FormalParameter) -> bool: """是否必填。""" ann = _find_ann(param, "RequestParam") if ann: members = _collect_ann_members(ann) if members.get("required", "").lower() == "false": return False type_name = _type_to_str(param.type) if type_name.startswith("Optional"): return False return not _has_ann(param, "Nullable") class ControllerAstParser: """ 基于 javalang 的 Controller 解析器。 只解析传入的文件,不扫描整个目录(CI 更快)。 """ def __init__(self, repo_root: Path, source_dir: Path): """ :param repo_root: 仓库根目录 :param source_dir: Java 源码根目录(repo_root 下的相对路径对应的绝对路径) """ self.repo_root = repo_root self.source_dir = source_dir self._dto_cache: Dict[str, List[ApiParameter]] = {} def parse_file_content(self, source: str, repo_relative_path: str) -> List[ApiEndpoint]: """ 解析单个 Java 源文件内容。 :param source: 源码文本 :param repo_relative_path: 相对仓库根目录的路径(与 git diff 一致) :return: 端点列表 """ endpoints: List[ApiEndpoint] = [] try: tree = javalang.parse.parse(source) except (javalang.parser.JavaSyntaxError, RecursionError) as exc: print(f"[警告] 解析失败 {repo_relative_path}: {exc}") return endpoints for type_decl in tree.types or []: if not isinstance(type_decl, ClassDeclaration): continue if not self._is_controller(type_decl): continue class_path = "" for ann in type_decl.annotations or []: if _ann_simple_name(ann) == "RequestMapping": class_path = _ann_string(ann, "value", "path") break for method in type_decl.methods or []: ep = self._parse_method(method, type_decl.name, class_path, repo_relative_path) if ep: endpoints.append(ep) return endpoints def _is_controller(self, cls: ClassDeclaration) -> bool: """是否为 Controller 类。""" return any(_ann_simple_name(a) in CONTROLLER_ANNS for a in (cls.annotations or [])) def _parse_method( self, method: MethodDeclaration, class_name: str, class_path: str, source_file: str, ) -> Optional[ApiEndpoint]: """解析带映射注解的方法。""" for ann in method.annotations or []: ann_name = _ann_simple_name(ann) if ann_name not in MAPPING_ANNS: continue method_path = _ann_string(ann, "value", "path") params = [] for p in method.parameters or []: params.extend(self._extract_param(p)) return ApiEndpoint( http_method=_http_method(ann_name, ann), uri=_join_paths(class_path, method_path), controller_class=class_name, method_name=method.name, source_file=source_file.replace("\\", "/"), parameters=params, ) return None def _extract_param(self, param: FormalParameter) -> List[ApiParameter]: """提取方法参数,@RequestBody 展开 DTO 字段;忽略框架注入参数。""" type_name = _type_to_str(param.type) name = _param_name(param) if _is_framework_param(type_name, name): return [] if _has_ann(param, "RequestBody"): return self._expand_dto(type_name, "body") return [ ApiParameter( name=name, type=type_name, required=_param_required(param), source=_param_source(param), ) ] def _expand_dto(self, type_name: str, source: str) -> List[ApiParameter]: """展开 @RequestBody DTO 字段。""" simple = type_name.split(".")[-1].replace(">", "").replace("<", "").strip() if simple in self._dto_cache: return self._dto_cache[simple] dto_file = self._find_dto_file(simple) if not dto_file: result = [ApiParameter(name=simple, type=type_name, required=True, source=source)] self._dto_cache[simple] = result return result try: tree = javalang.parse.parse(dto_file.read_text(encoding="utf-8", errors="ignore")) except (javalang.parser.JavaSyntaxError, OSError): result = [ApiParameter(name=simple, type=type_name, required=True, source=source)] self._dto_cache[simple] = result return result fields: List[ApiParameter] = [] for type_decl in tree.types or []: if not isinstance(type_decl, ClassDeclaration): continue for field in type_decl.fields or []: if "static" in (field.modifiers or []): continue for decl in field.declarators: fields.append( ApiParameter( name=decl.name, type=_type_to_str(field.type), required=not _has_ann(field, "Nullable"), source=source, ) ) if not fields: fields = [ApiParameter(name=simple, type=type_name, required=True, source=source)] self._dto_cache[simple] = fields return fields def _find_dto_file(self, simple_name: str) -> Optional[Path]: """在源码目录中查找 DTO 文件。""" if not self.source_dir.exists(): return None target = f"{simple_name}.java" for path in self.source_dir.rglob(target): return path return None def parse_controller_files( repo_root: Path, source_subdir: str, file_paths: List[str], file_contents: Dict[str, str], ) -> List[ApiEndpoint]: """ 批量解析指定 Controller 文件(仅解析传入的文件,不全量扫描)。 :param repo_root: 仓库根目录 :param source_subdir: 源码子目录(相对仓库根) :param file_paths: 要解析的文件路径列表(相对仓库根) :param file_contents: {文件路径: 源码内容} :return: 所有端点 """ source_dir = (repo_root / source_subdir).resolve() parser = ControllerAstParser(repo_root, source_dir) endpoints: List[ApiEndpoint] = [] for path in file_paths: norm = path.replace("\\", "/") content = file_contents.get(norm) if not content: continue endpoints.extend(parser.parse_file_content(content, norm)) return endpoints