接口类对象解析

This commit is contained in:
2026-06-05 16:35:51 +08:00
parent 3cba3bb74e
commit 03fb9766a6
7 changed files with 192 additions and 49 deletions

View File

@@ -23,7 +23,8 @@ from javalang.tree import (
from models import ApiEndpoint, ApiParameter
MAPPING_ANNS = {"GetMapping", "PostMapping", "PutMapping", "DeleteMapping", "PatchMapping", "RequestMapping"}
# javax.validation 必填注解
REQUIRED_FIELD_ANNS = {"NotNull", "NotEmpty", "NotBlank"}
CONTROLLER_ANNS = {"RestController", "Controller"}
# Spring MVC 框架自动注入参数,不属于 API 调用方入参,解析时忽略
@@ -281,6 +282,16 @@ def _extract_javadoc_before_line(source: str, target_line: int) -> str:
return "\n".join(lines[idx : end_idx + 1])
def _field_required(field: FieldDeclaration) -> bool:
"""DTO 字段是否必填(@NotNull / @NotEmpty / @NotBlank"""
if _has_ann(field, "Nullable"):
return False
for ann in field.annotations or []:
if _ann_simple_name(ann) in REQUIRED_FIELD_ANNS:
return True
return False
def _lookup_param_description(
javadoc_params: Dict[str, str], param: FormalParameter, resolved_name: str
) -> Optional[str]:
@@ -297,13 +308,13 @@ class ControllerAstParser:
只解析传入的文件不扫描整个目录CI 更快)。
"""
def __init__(self, repo_root: Path, source_dir: Path):
def __init__(self, repo_root: Path, source_dirs: List[Path]):
"""
:param repo_root: 仓库根目录
:param source_dir: Java 源码根目录repo_root 下的相对路径对应的绝对路径
:param repo_root: 仓库根目录
:param source_dirs: Java 源码根目录列表(用于查找 DTO 等
"""
self.repo_root = repo_root
self.source_dir = source_dir
self.source_dirs = source_dirs
self._dto_cache: Dict[str, List[ApiParameter]] = {}
self._current_source = ""
@@ -393,7 +404,13 @@ class ControllerAstParser:
return []
if _has_ann(param, "RequestBody"):
return self._expand_dto(type_name, "body")
body_desc = _lookup_param_description(javadoc_params, param, name)
return self._expand_dto(
type_name,
"body",
body_param_name=param.name,
body_param_desc=body_desc,
)
description = _lookup_param_description(javadoc_params, param, name)
return [
@@ -406,24 +423,51 @@ class ControllerAstParser:
)
]
def _expand_dto(self, type_name: str, source: str) -> List[ApiParameter]:
"""展开 @RequestBody DTO 字段。"""
def _expand_dto(
self,
type_name: str,
source: str,
body_param_name: str = "",
body_param_desc: Optional[str] = None,
) -> List[ApiParameter]:
"""展开 @RequestBody DTO 一级字段。"""
simple = type_name.split(".")[-1].replace(">", "").replace("<", "").strip()
if simple in self._dto_cache:
return self._dto_cache[simple]
cache_key = f"{simple}:{body_param_name}"
if cache_key in self._dto_cache:
return self._dto_cache[cache_key]
dto_file = self._find_dto_file(simple)
if not dto_file:
result = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
self._dto_cache[simple] = result
result = [
ApiParameter(
name=simple,
type=type_name,
required=True,
source=source,
description=body_param_desc,
parent_dto=simple,
body_param_name=body_param_name or None,
)
]
self._dto_cache[cache_key] = result
return result
try:
dto_source = dto_file.read_text(encoding="utf-8", errors="ignore")
tree = javalang.parse.parse(dto_source)
except (javalang.parser.JavaSyntaxError, OSError):
result = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
self._dto_cache[simple] = result
result = [
ApiParameter(
name=simple,
type=type_name,
required=True,
source=source,
description=body_param_desc,
parent_dto=simple,
body_param_name=body_param_name or None,
)
]
self._dto_cache[cache_key] = result
return result
fields: List[ApiParameter] = []
@@ -446,45 +490,61 @@ class ControllerAstParser:
ApiParameter(
name=decl.name,
type=_type_to_str(field.type),
required=not _has_ann(field, "Nullable"),
required=_field_required(field),
source=source,
description=field_desc,
parent_dto=simple,
body_param_name=body_param_name or None,
)
)
if not fields:
fields = [ApiParameter(name=simple, type=type_name, required=True, source=source)]
fields = [
ApiParameter(
name=simple,
type=type_name,
required=True,
source=source,
description=body_param_desc,
parent_dto=simple,
body_param_name=body_param_name or None,
)
]
self._dto_cache[simple] = fields
self._dto_cache[cache_key] = fields
return fields
def _find_dto_file(self, simple_name: str) -> Optional[Path]:
"""在源码目录中查找 DTO 文件。"""
if not self.source_dir.exists():
return None
"""配置的源码目录及仓库内 src/main/java 中查找 DTO 文件。"""
target = f"{simple_name}.java"
for path in self.source_dir.rglob(target):
return path
for source_dir in self.source_dirs:
if source_dir.exists():
for path in source_dir.rglob(target):
return path
if self.repo_root.exists():
for path in self.repo_root.rglob(target):
if "src/main/java" in path.as_posix():
return path
return None
def parse_controller_files(
repo_root: Path,
source_subdir: str,
source_subdirs: List[str],
file_paths: List[str],
file_contents: Dict[str, str],
) -> List[ApiEndpoint]:
"""
批量解析指定 Controller 文件(仅解析传入的文件,不全量扫描)。
:param repo_root: 仓库根目录
:param source_subdir: 源码子目录(相对仓库根)
:param file_paths: 要解析的文件路径列表(相对仓库根)
:param file_contents: {文件路径: 源码内容}
:param repo_root: 仓库根目录
:param source_subdirs: 源码子目录列表(相对仓库根)
:param file_paths: 要解析的文件路径列表(相对仓库根)
:param file_contents: {文件路径: 源码内容}
:return: 所有端点
"""
source_dir = (repo_root / source_subdir).resolve()
parser = ControllerAstParser(repo_root, source_dir)
source_dirs = [(repo_root / sub).resolve() for sub in source_subdirs]
parser = ControllerAstParser(repo_root, source_dirs)
endpoints: List[ApiEndpoint] = []
for path in file_paths: