Skip to content

Commit 4e89ec4

Browse files
committed
fix post-review robustness for test generation flow
Tighten answer_ext normalization and retry semantics, clear stale generated outputs before final writes, and make cleanup/semantic checks more resilient across platforms and generator branch styles. Made-with: Cursor
1 parent 3baac63 commit 4e89ec4

3 files changed

Lines changed: 57 additions & 14 deletions

File tree

src/autocode_mcp/tools/generator.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,18 +142,19 @@ async def execute(
142142
)
143143

144144
def _check_type34_semantics(self, code: str) -> dict:
145-
has_type3 = bool(re.search(r"type\s*==\s*3", code))
146-
has_type4 = bool(re.search(r"type\s*==\s*4", code))
145+
type3_blocks = self._extract_type_branch_snippets(code, 3)
146+
type4_blocks = self._extract_type_branch_snippets(code, 4)
147+
has_type3 = bool(type3_blocks)
148+
has_type4 = bool(type4_blocks)
147149
if not has_type3 or not has_type4:
148150
return {
149151
"enabled": True,
150-
"passed": False,
151-
"reason": "generator lacks explicit type==3/type==4 branches",
152-
"hint": "需要给 type=3/type=4 设计不同逻辑,避免仅靠参数放大",
152+
"passed": True,
153+
"advisory": True,
154+
"reason": "semantic check could not reliably detect both type=3/type=4 branches",
155+
"hint": "请人工确认 type=3/type=4 分支存在且有实质差异",
153156
}
154157

155-
type3_blocks = re.findall(r"type\s*==\s*3[\s\S]{0,240}", code)
156-
type4_blocks = re.findall(r"type\s*==\s*4[\s\S]{0,240}", code)
157158
norm3 = " ".join(type3_blocks).replace(" ", "")
158159
norm4 = " ".join(type4_blocks).replace(" ", "")
159160
output_lines = [line.strip() for line in code.splitlines() if "cout" in line or "printf" in line]
@@ -166,6 +167,18 @@ def _check_type34_semantics(self, code: str) -> dict:
166167
"hint": "为 type=4 增加针对性卡法,而不仅是 n_max/t_max 取最大值",
167168
}
168169

170+
def _extract_type_branch_snippets(self, code: str, type_value: int) -> list[str]:
171+
patterns = [
172+
rf"type\s*==\s*{type_value}\b",
173+
rf"\b{type_value}\s*==\s*type\b",
174+
rf"case\s+{type_value}\s*:",
175+
]
176+
snippets: list[str] = []
177+
for pattern in patterns:
178+
for match in re.finditer(pattern, code):
179+
snippets.append(code[match.start(): match.start() + 240])
180+
return snippets
181+
169182

170183
class GeneratorRunTool(Tool):
171184
"""运行多策略数据生成器。"""

src/autocode_mcp/tools/problem.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,11 @@ async def execute(
563563
else:
564564
final_tests = candidates
565565

566+
# 最终写盘前清理历史生成产物,防止 resume 场景残留旧编号样例。
567+
clear_before_write_error = self._clear_generated_tests(tests_dir, normalized_answer_ext)
568+
if clear_before_write_error:
569+
return clear_before_write_error
570+
566571
# 写入文件
567572
generated_tests = []
568573
test_manifest: list[dict[str, str | int]] = []
@@ -721,6 +726,8 @@ def _normalize_answer_ext(self, answer_ext: str) -> tuple[str | None, ToolResult
721726
return None, ToolResult.fail("answer_ext cannot be empty")
722727
if not ext.startswith("."):
723728
ext = f".{ext}"
729+
if not any(ch != "." for ch in ext[1:]):
730+
return None, ToolResult.fail("answer_ext must contain non-dot characters")
724731
if any(ch in ext for ch in ('/', '\\', ':', '*', '?', '"', "<", ">", "|")):
725732
return None, ToolResult.fail("answer_ext contains illegal characters")
726733
if ext == ".in":
@@ -758,7 +765,7 @@ def _on_start(pid: int) -> None:
758765
# 取消路径保留 PID 到状态文件,供 cleanup 精准回收。
759766
if started_pid is not None and not cancelled:
760767
active_pids.discard(started_pid)
761-
if not getattr(last_result, "error", None):
768+
if last_result.success:
762769
return last_result
763770
await asyncio.sleep(0.1 * (2**attempt))
764771
if last_result is not None:
@@ -1076,7 +1083,13 @@ async def execute(self, problem_dir: str, kill_all_generators: bool = False) ->
10761083
if os.path.exists(state_path) and not pids:
10771084
os.remove(state_path)
10781085
removed_files.append(state_path)
1079-
return ToolResult.ok(removed_files=removed_files, message="Cleanup finished")
1086+
return ToolResult.ok(
1087+
removed_files=removed_files,
1088+
killed_pids=[],
1089+
failed_pids=[],
1090+
warning="PID termination is only supported on Windows" if kill_all_generators and os.name != "nt" else "",
1091+
message="Cleanup finished",
1092+
)
10801093

10811094
def _load_cleanup_state(self, state_path: str) -> dict | None:
10821095
if not os.path.exists(state_path):

src/autocode_mcp/tools/test_verify.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,16 +493,33 @@ def _check_limit_semantics(self, tests_dir: str) -> dict:
493493
}
494494

495495
def _resolve_answer_ext(self, tests_dir: str, answer_ext: str | None) -> str:
496-
if answer_ext:
497-
return answer_ext if answer_ext.startswith(".") else f".{answer_ext}"
496+
normalized = self._normalize_answer_ext(answer_ext)
497+
if normalized:
498+
return normalized
498499
manifest_path = os.path.join(tests_dir, _TEST_MANIFEST_FILENAME)
499500
if os.path.exists(manifest_path):
500501
try:
501502
with open(manifest_path, encoding="utf-8") as f:
502503
manifest = json.load(f)
503-
ext = manifest.get("answer_ext")
504-
if isinstance(ext, str) and ext:
505-
return ext if ext.startswith(".") else f".{ext}"
504+
ext = self._normalize_answer_ext(manifest.get("answer_ext"))
505+
if ext:
506+
return ext
506507
except (json.JSONDecodeError, OSError):
507508
pass
508509
return ".ans"
510+
511+
def _normalize_answer_ext(self, answer_ext: str | None) -> str | None:
512+
if not isinstance(answer_ext, str):
513+
return None
514+
ext = answer_ext.strip()
515+
if not ext:
516+
return None
517+
if not ext.startswith("."):
518+
ext = f".{ext}"
519+
if not any(ch != "." for ch in ext[1:]):
520+
return None
521+
if any(ch in ext for ch in ('/', '\\', ':', '*', '?', '"', "<", ">", "|")):
522+
return None
523+
if ext == ".in":
524+
return None
525+
return ext

0 commit comments

Comments
 (0)