diff --git a/app/dataplane/reverse/protocol/tool_prompt.py b/app/dataplane/reverse/protocol/tool_prompt.py index b3617ed0..8b869361 100644 --- a/app/dataplane/reverse/protocol/tool_prompt.py +++ b/app/dataplane/reverse/protocol/tool_prompt.py @@ -42,6 +42,19 @@ NOTE: Even if you believe you cannot fulfill the request, you must still follow the WHEN TO CALL rule above.\ """ +_CODEX_TOOL_NOTE = """ + +CODEX TOOL RULES: +- If the user asks you to inspect local files, list directories, run shell/bash commands, edit files, apply patches, or verify tests, you MUST call the matching tool instead of describing what you would do. +- For shell/bash/terminal commands, use exec_command with a JSON object containing at least {"cmd": "..."}. +- For interactive or long-running shell sessions, use write_stdin with the existing session_id when you need to send input or poll output. Do not replace write_stdin with exec_command. +- For file edits, use exec_command to run apply_patch with a heredoc. Do not write files with echo, printf, cat, tee, sed -i, or Python. +- Example edit command: apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: example.txt\n+hello\n*** End Patch\nPATCH +- If apply_patch fails, inspect the target file and retry with a corrected, minimal patch. Do not repeat the same failed patch unchanged. +- Do not invent results from the filesystem, web, or tools. Call the tool and wait for the tool result. +- Do not mention unavailable tool names. Use exactly one of the AVAILABLE TOOLS names. +""" + _CHOICE_AUTO = "WHEN TO CALL: Call a tool when it is clearly needed. Otherwise respond in plain text." _CHOICE_NONE = "WHEN TO CALL: Do NOT call any tools. Respond in plain text only." _CHOICE_REQUIRED = "WHEN TO CALL: You MUST output a XML block. Do NOT write any plain-text reply. If you are uncertain, still call the most relevant tool with your best guess at the parameters." @@ -65,10 +78,13 @@ def build_tool_system_prompt( """ tool_defs = _format_tool_definitions(tools) choice_instruction = _build_choice_instruction(tools, tool_choice) - return _TOOL_SYSTEM_HEADER.format( + prompt = _TOOL_SYSTEM_HEADER.format( tool_definitions=tool_defs, tool_choice_instruction=choice_instruction, ) + if _looks_like_codex_tools(tools): + prompt += _CODEX_TOOL_NOTE + return prompt def extract_tool_names(tools: list[dict[str, Any]]) -> list[str]: @@ -84,7 +100,12 @@ def extract_tool_names(tools: list[dict[str, Any]]) -> list[str]: def inject_into_message(message: str, system_prompt: str) -> str: """Prepend the tool system prompt to the flattened message string.""" - return f"[system]: {system_prompt}\n\n{message}" + return ( + f"[system]: {system_prompt}\n\n" + f"{message}\n\n" + "[system]: If a tool is needed now, your next response must be only the " + " XML block. Do not say that you will call a tool; call it." + ) def tool_calls_to_xml(tool_calls: list[dict[str, Any]]) -> str: @@ -155,3 +176,8 @@ def _build_choice_instruction( if forced_name: return _CHOICE_FORCED.format(name=forced_name) return _CHOICE_AUTO + + +def _looks_like_codex_tools(tools: list[dict[str, Any]]) -> bool: + names = set(extract_tool_names(tools)) + return bool({"exec_command", "apply_patch", "write_stdin"} & names) diff --git a/app/products/openai/_codex_tools.py b/app/products/openai/_codex_tools.py new file mode 100644 index 00000000..23b87e6f --- /dev/null +++ b/app/products/openai/_codex_tools.py @@ -0,0 +1,545 @@ +"""Codex-oriented Responses API tool fallback helpers.""" + +from __future__ import annotations + +import re +import shlex +from typing import Any + +import orjson + +from app.dataplane.reverse.protocol.tool_parser import ParsedToolCall +from app.platform.logging.logger import logger + + +# --------------------------------------------------------------------------- +# Tool format normalisation +# --------------------------------------------------------------------------- + +def _to_chat_tools(tools: list[dict]) -> list[dict]: + """Normalise Responses API tool format → Chat Completions format. + + Responses API: {type, name, description, parameters} (flat) + Chat Completions: {type, function: {name, description, parameters}} + + Already-wrapped tools are passed through unchanged so this is safe to + call regardless of which format the caller used. + """ + normalised = [] + for tool in tools: + if not isinstance(tool, dict): + continue + if tool.get("type") == "function" and "function" not in tool and "name" in tool: + normalised.append({ + "type": "function", + "function": { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("parameters"), + }, + }) + elif tool.get("type") == "function": + normalised.append(tool) + elif "name" in tool and "parameters" in tool: + # Some Responses clients send function-shaped tools with a custom + # type. Grok only sees the prompt, so normalize the callable + # schema and ignore non-callable provider metadata. + normalised.append({ + "type": "function", + "function": { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("parameters"), + }, + }) + else: + # Skip namespace/web_search/image tools here. They require native + # Responses semantics; passing them as fake functions confuses the + # model and Codex will not execute them as function_call items. + continue + return normalised + + +_LOCAL_TOOL_REQUEST_RE = re.compile( + r"shell|bash|terminal|command|exec_command|apply_patch|write_stdin|" + r"\bpwd\b|\bls\b|\brg\b|\bcat\b|\bsed\b|\bfind\b|" + r"工具|调用|运行|执行|命令|终端|查看|读取|列出|文件|目录|修改|编辑", + re.IGNORECASE, +) +_WRITE_STDIN_REQUEST_RE = re.compile( + r"write_stdin|stdin|send input|send .* to session|poll output|session_id|" + r"输入到.*会话|发送.*会话|轮询|长进程|交互", + re.IGNORECASE, +) +_TOOL_RESULT_RE = re.compile( + r"\[tool result|function_call_output|tool_call_id|session_id|patch:\s*(failed|completed)", + re.IGNORECASE, +) +_CODEX_COMMAND_OUTPUT_RE = re.compile( + r"command_execution|aggregated_output|exit_code|Process exited with code|" + r"Chunk ID:|Wall time:|Original token count:", + re.IGNORECASE, +) +_PATCH_COMPLETED_RE = re.compile(r"patch:\s*completed", re.IGNORECASE) +_PATCH_FAILURE_RE = re.compile( + r"apply_patch|patch\b|补丁|diff", + re.IGNORECASE, +) + + +def _looks_like_codex_tool_run(tool_names: list[str], message: str, text: str) -> bool: + if not tool_names: + return False + if any(name and name in text for name in tool_names): + return True + if "exec_command" in tool_names and _LOCAL_TOOL_REQUEST_RE.search(message): + return True + return bool(_LOCAL_TOOL_REQUEST_RE.search(text)) + + +def _forced_tool_choice(tool_names: list[str], message: str, text: str) -> Any: + if "write_stdin" in tool_names and _WRITE_STDIN_REQUEST_RE.search(f"{message}\n{text}"): + return {"type": "function", "function": {"name": "write_stdin"}} + if "exec_command" in tool_names and ( + _LOCAL_TOOL_REQUEST_RE.search(message) or "exec_command" in text + ): + return {"type": "function", "function": {"name": "exec_command"}} + if "apply_patch" in tool_names and ( + "apply_patch" in text or re.search(r"修改|编辑|patch|补丁", message, re.I) + ): + return {"type": "function", "function": {"name": "apply_patch"}} + return "required" + + +def _synthesize_codex_tool_call( + tool_names: list[str], + message: str, + previous_text: str, +) -> list[ParsedToolCall]: + if "write_stdin" in tool_names: + stdin_args = _synthesize_write_stdin_args(message, previous_text) + if stdin_args: + return [ParsedToolCall.make("write_stdin", stdin_args)] + if "exec_command" not in tool_names: + return [] + intent = _latest_user_intent(message) + if _has_prior_tool_result(intent) or _has_prior_command_output(message): + return [] + patch_cmd = _synthesize_apply_patch_command(intent) + if patch_cmd: + return [ParsedToolCall.make("exec_command", {"cmd": patch_cmd})] + cmd = _extract_requested_shell_command(intent, previous_text) + if not cmd: + return [] + return [ParsedToolCall.make("exec_command", {"cmd": cmd})] + + +def _normalize_codex_tool_calls( + calls: list[ParsedToolCall], + *, + tool_names: list[str], + message: str, +) -> list[ParsedToolCall]: + if "exec_command" not in tool_names: + return calls + normalized: list[ParsedToolCall] = [] + for call in calls: + if call.name != "exec_command": + normalized.append(call) + continue + args = _json_args(call.arguments) + cmd = str(args.get("cmd", "")).strip() + if _is_duplicate_completed_patch(cmd, message): + logger.info("responses suppressed duplicate completed apply_patch command") + continue + if _looks_like_direct_file_write(cmd) and _looks_like_edit_request(message): + patch_cmd = _command_to_apply_patch(cmd) or _synthesize_apply_patch_command(message) + if patch_cmd: + args["cmd"] = patch_cmd + normalized.append(ParsedToolCall(call.call_id, call.name, _json_dumps(args))) + continue + normalized.append(call) + return normalized + + +def _has_prior_tool_result(message: str) -> bool: + return bool(_TOOL_RESULT_RE.search(message)) + + +def _has_prior_command_output(message: str) -> bool: + if not _CODEX_COMMAND_OUTPUT_RE.search(message): + return False + latest_user_idx = message.lower().rfind("[user]:") + if latest_user_idx < 0: + return True + return bool(_CODEX_COMMAND_OUTPUT_RE.search(message[latest_user_idx:])) + + +def _is_duplicate_completed_patch(cmd: str, message: str) -> bool: + if not cmd or "apply_patch" not in cmd or not _PATCH_COMPLETED_RE.search(message): + return False + target = _extract_patch_target(cmd) + if not target: + return True + return target in message + + +def _extract_patch_target(cmd: str) -> str | None: + m = re.search(r"^\*\*\* (?:Add|Update|Delete) File:\s+(.+?)\s*$", cmd, re.M) + if not m: + return None + path = m.group(1).strip() + return path if _valid_patch_path(path) else None + + +def _json_args(arguments: str) -> dict[str, Any]: + try: + parsed = orjson.loads(arguments) + return parsed if isinstance(parsed, dict) else {} + except Exception: + return {} + + +def _json_dumps(value: Any) -> str: + return orjson.dumps(value).decode() + + +def _looks_like_edit_request(message: str) -> bool: + return bool(re.search(r"apply_patch|patch|修改|编辑|创建|新增|写入|内容为|create file|edit file|write file", _latest_user_intent(message), re.I)) + + +def _looks_like_direct_file_write(cmd: str) -> bool: + return bool(re.search(r"(^|\s)(echo|printf|cat|tee)\b[\s\S]*(>|>>|\btee\b)", cmd)) + + +def _command_to_apply_patch(cmd: str) -> str | None: + parsed = _parse_simple_write_command(cmd) + if not parsed: + return None + path, content, append = parsed + if append: + return None + return _apply_patch_add_file_command(path, content) + + +def _parse_simple_write_command(cmd: str) -> tuple[str, str, bool] | None: + # echo 'ok' > file + m = re.match(r"""echo\s+(['"])(.*?)\1\s*(>>?)\s*([^\s]+)\s*$""", cmd, re.S) + if m: + return m.group(4), m.group(2) + "\n", m.group(3) == ">>" + + # printf 'ok\n' > file + m = re.match(r"""printf\s+(['"])(.*?)\1\s*(>>?)\s*([^\s]+)\s*$""", cmd, re.S) + if m: + content = bytes(m.group(2), "utf-8").decode("unicode_escape") + return m.group(4), content, m.group(3) == ">>" + + # cat > file <<'EOF' ... EOF + m = re.match( + r"""cat\s*>\s*([^\s]+)\s*<<['"]?([A-Za-z0-9_]+)['"]?\n([\s\S]*)\n\2\s*$""", + cmd, + ) + if m: + return m.group(1), m.group(3) + "\n", False + + # cat <<'EOF' > file ... EOF + m = re.match( + r"""cat\s*<<['"]?([A-Za-z0-9_]+)['"]?\s*>\s*([^\s]+)\n([\s\S]*)\n\1\s*$""", + cmd, + ) + if m: + return m.group(2), m.group(3) + "\n", False + + # tee file <<'EOF' ... EOF + m = re.match( + r"""tee\s+([^\s]+)\s*<<['"]?([A-Za-z0-9_]+)['"]?\n([\s\S]*)\n\2\s*$""", + cmd, + ) + if m: + return m.group(1), m.group(3) + "\n", False + return None + + +def _synthesize_apply_patch_command(message: str) -> str | None: + intent = _latest_user_intent(message) + replace_cmd = _synthesize_simple_replace_patch(intent) + if replace_cmd: + return replace_cmd + if not _looks_like_simple_create_request(intent): + return None + path = _extract_target_filename(intent) + content = _extract_requested_file_content(intent) + if not path or content is None: + return None + return _apply_patch_add_file_command(path, content) + + +def _synthesize_simple_replace_patch(intent: str) -> str | None: + if not intent or len(intent) > 1200: + return None + if not re.search(r"修改|编辑|替换|改成|replace|change", intent, re.I): + return None + path = _extract_target_filename(intent) + if not path: + return None + pair = _extract_replacement_pair(intent) + if not pair: + return None + old, new = pair + return _apply_patch_replace_line_command(path, old, new) + + +def _extract_replacement_pair(intent: str) -> tuple[str, str] | None: + patterns = ( + r"(?:把|将)\s*(?:文件\s*)?`?[A-Za-z0-9_./-]+\.[A-Za-z0-9_-]+`?\s*(?:里|中的|里面的)?(?:的)?\s*[`'\"]?([^`'\"\n。;;,,\s]+)[`'\"]?\s*(?:改成|替换成|换成)\s*[`'\"]?([^`'\"\n。;;,,\s]+)[`'\"]?", + r"(?:replace|change)\s+[`'\"]?([^`'\"\n]+?)[`'\"]?\s+(?:with|to)\s+[`'\"]?([^`'\"\n]+?)[`'\"]?\s+(?:in|inside)\s+`?[A-Za-z0-9_./-]+\.[A-Za-z0-9_-]+`?", + ) + for pattern in patterns: + m = re.search(pattern, intent, re.I) + if m: + old = m.group(1).strip() + new = m.group(2).strip() + if _valid_patch_line(old) and _valid_patch_line(new): + return old, new + return None + + +def _latest_user_intent(message: str) -> str: + """Return the latest user-facing request, excluding injected tool schemas. + + The fallback synthesizer must be conservative: the full prompt contains + tool descriptions, model names and schema fragments, all of which can look + like filenames. Only use explicit user/conversation blocks when present. + """ + blocks = re.findall( + r"\[(?:user|conversation)\]:\s*([\s\S]*?)(?=\n\[(?:system|assistant|tool|conversation|user)\]:|\Z)", + message, + re.IGNORECASE, + ) + if blocks: + return blocks[-1].strip() + + # Drop the final injected reminder and the leading tool prompt when this is + # a flattened prompt built by inject_into_message(). + text = re.split(r"\n\n\[system\]:\s*If a tool is needed now", message, maxsplit=1, flags=re.I)[0] + if text.startswith("[system]:") and "AVAILABLE TOOLS:" in text: + for marker in ( + "Do not mention unavailable tool names. Use exactly one of the AVAILABLE TOOLS names.", + "NOTE: Even if you believe you cannot fulfill the request, you must still follow the WHEN TO CALL rule above.", + ): + if marker in text: + text = text.rsplit(marker, 1)[1] + break + else: + parts = text.split("\n\n", 1) + text = parts[1] if len(parts) > 1 else "" + return text.strip() + + +def _looks_like_simple_create_request(intent: str) -> bool: + if not intent or len(intent) > 1200: + return False + has_create = re.search(r"创建|新增|create (?:a )?file|add (?:a )?file|write (?:a )?file", intent, re.I) + has_content = re.search(r"内容为|内容是|content(?:\s+is|\s+为)?|with content", intent, re.I) + return bool(has_create and has_content and _extract_target_filename(intent)) + + +def _extract_target_filename(message: str) -> str | None: + patterns = ( + r"(?:文件|file)\s+`?([A-Za-z0-9_./-]+\.[A-Za-z0-9_-]+)`?", + r"`([A-Za-z0-9_./-]+\.[A-Za-z0-9_-]+)`", + r"\b((?:\.?/)?(?:[A-Za-z0-9_-]+/)*[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+)\b", + ) + for pattern in patterns: + m = re.search(pattern, message, re.I) + if m: + path = m.group(1).strip() + if _valid_patch_path(path): + return path + return None + + +def _extract_requested_file_content(message: str) -> str | None: + m = re.search(r"(?:内容为|内容是|content(?:\s+is|\s+为)?|with content)\s*```(?:[A-Za-z0-9_-]+)?\n([\s\S]*?)\n```", message, re.I) + if m: + return m.group(1).rstrip("\n") + "\n" + m = re.search(r"(?:内容为|内容是|content(?:\s+is|\s+为)?|with content)\s*[`'\"]?([^`'\"\n。;;,,]+)[`'\"]?", message, re.I) + if not m: + return None + return m.group(1).strip() + "\n" + + +def _valid_patch_path(path: str) -> bool: + if not path or path.startswith(("/", "~")) or ".." in path.split("/"): + return False + if path in {"gpt-5.5", "gpt-5.4", "gpt-5.3", "text"}: + return False + return bool(re.match(r"^[A-Za-z0-9_./-]+\.[A-Za-z0-9_-]+$", path)) + + +def _apply_patch_add_file_command(path: str, content: str) -> str: + lines = ["apply_patch <<'PATCH'", "*** Begin Patch", f"*** Add File: {path}"] + for line in content.splitlines(): + lines.append(f"+{line}") + if content.endswith("\n") and not content.splitlines(): + lines.append("+") + lines.extend(["*** End Patch", "PATCH"]) + return "\n".join(lines) + + +def _apply_patch_replace_line_command(path: str, old: str, new: str) -> str: + return "\n".join( + [ + "apply_patch <<'PATCH'", + "*** Begin Patch", + f"*** Update File: {path}", + "@@", + f"-{old}", + f"+{new}", + "*** End Patch", + "PATCH", + ] + ) + + +def _valid_patch_line(value: str) -> bool: + if not value or len(value) > 300: + return False + return "\n" not in value and "\r" not in value + + +def _synthesize_write_stdin_args(message: str, previous_text: str) -> dict[str, Any] | None: + haystack = f"{_latest_user_intent(message)}\n{previous_text}" + if not _WRITE_STDIN_REQUEST_RE.search(haystack): + return None + sid_match = re.search(r"(?:session_id|session|会话)\s*[:=#]?\s*([0-9]{1,12})", haystack, re.I) + if not sid_match: + return None + chars = "" + for pattern in ( + r"(?:send|write|输入|发送)\s+[`'\"]([^`'\"]{0,500})[`'\"]", + r"(?:chars|input|内容)\s*[:=]\s*[`'\"]([^`'\"]{0,500})[`'\"]", + ): + m = re.search(pattern, haystack, re.I) + if m: + chars = m.group(1) + break + return { + "session_id": int(sid_match.group(1)), + "chars": chars, + } + + +def _extract_requested_shell_command(message: str, previous_text: str) -> str | None: + haystack = f"{_latest_user_intent(message)}\n\n{previous_text}" + + # Commands in code spans are the cleanest signal. + for m in re.finditer(r"`([^`\n]{1,300})`", haystack): + candidate = m.group(1).strip() + if _looks_safe_shell_snippet(candidate): + return candidate + + patterns = ( + r"(?:run|execute|运行|执行)\s+(?:the\s+)?(?:shell\s+)?(?:command\s+|命令\s*)?([A-Za-z0-9_./ -]{1,160})", + r"(?:使用|用)\s*(?:shell|bash|终端).*?(?:运行|执行)\s*([A-Za-z0-9_./ -]{1,160})", + ) + for pattern in patterns: + m = re.search(pattern, haystack, re.IGNORECASE) + if not m: + continue + candidate = _clean_command_candidate(m.group(1)) + if _looks_safe_shell_snippet(candidate): + return candidate + + lowered = haystack.lower() + if re.search(r"\bpwd\b", lowered) or "当前路径" in haystack or "当前目录路径" in haystack: + return "pwd" + if ( + "desktop" in lowered + or "桌面" in haystack + ) and ( + "有哪些" in haystack + or "列出" in haystack + or "看看" in haystack + or re.search(r"\b(list|show|see)\b", lowered) + ): + return "ls -la ~/Desktop | sed -n '1,40p'" + if ( + "有哪些文件" in haystack + or "列出" in haystack and "文件" in haystack + or re.search(r"\blist\b.*\bfiles\b", lowered) + ): + return "find . -maxdepth 2 -print | sort" + read_cmd = _synthesize_read_file_command(haystack) + if read_cmd: + return read_cmd + search_cmd = _synthesize_search_command(haystack) + if search_cmd: + return search_cmd + if re.search(r"\bls\b", lowered): + return "ls -la" + return None + + +def _synthesize_read_file_command(haystack: str) -> str | None: + if not re.search(r"读取|查看|打开|读一下|read|show|cat", haystack, re.I): + return None + path = _extract_target_filename(haystack) + if not path: + return None + return f"sed -n '1,200p' {shlex.quote(path)}" + + +def _synthesize_search_command(haystack: str) -> str | None: + if not re.search(r"搜索|查找|包含|grep|rg|search|find", haystack, re.I): + return None + term = _extract_search_term(haystack) + if not term: + return None + return f"rg -n -- {shlex.quote(term)} ." + + +def _extract_search_term(haystack: str) -> str | None: + patterns = ( + r"(?:包含|含有)\s*[`'\"]?([^`'\"\n。;;,,\s]{1,120})[`'\"]?", + r"(?:搜索|查找)\s*[`'\"]?([^`'\"\n。;;,,\s]{1,120})[`'\"]?", + r"(?:search|find|grep)\s+(?:for\s+)?[`'\"]?([^`'\"\n]{1,120})[`'\"]?", + ) + for pattern in patterns: + m = re.search(pattern, haystack, re.I) + if not m: + continue + term = m.group(1).strip() + if term and not re.search(r"\s(?:的|文件|路径)$", term): + return term + return None + + +def _clean_command_candidate(candidate: str) -> str: + candidate = candidate.strip() + candidate = re.split(r"[\n\r。;;]", candidate, maxsplit=1)[0].strip() + candidate = re.sub(r"^(?:就是|为|is|as)\s+", "", candidate, flags=re.I).strip() + return candidate.strip("'\" ") + + +def _looks_safe_shell_snippet(candidate: str | None) -> bool: + if not candidate: + return False + if len(candidate) > 300: + return False + # Reject prose-like captures and obvious shell control chains. Codex will + # still enforce its own sandbox/approval policy after receiving the call. + if any(token in candidate for token in ("\n", "\r", "&&", "||", ";", "| sh", "|sh")): + return False + return bool(re.match(r"^[A-Za-z0-9_./~:${}\\[\\]*?=,'\" -]+$", candidate)) + + +__all__ = [ + "_PATCH_COMPLETED_RE", + "_PATCH_FAILURE_RE", + "_forced_tool_choice", + "_looks_like_codex_tool_run", + "_normalize_codex_tool_calls", + "_synthesize_codex_tool_call", + "_to_chat_tools", +] diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py index d816c7a9..c7d39249 100644 --- a/app/products/openai/responses.py +++ b/app/products/openai/responses.py @@ -5,10 +5,9 @@ """ import asyncio +import re from typing import Any, AsyncGenerator -import orjson - from app.platform.logging.logger import logger from app.platform.config.snapshot import get_config from app.platform.errors import RateLimitError, UpstreamError @@ -32,33 +31,110 @@ from ._tool_sieve import ToolSieve -# --------------------------------------------------------------------------- -# Tool format normalisation -# --------------------------------------------------------------------------- +from ._codex_tools import ( + _PATCH_COMPLETED_RE, + _PATCH_FAILURE_RE, + _forced_tool_choice, + _looks_like_codex_tool_run, + _normalize_codex_tool_calls, + _synthesize_codex_tool_call, + _to_chat_tools, +) -def _to_chat_tools(tools: list[dict]) -> list[dict]: - """Normalise Responses API tool format → Chat Completions format. - Responses API: {type, name, description, parameters} (flat) - Chat Completions: {type, function: {name, description, parameters}} +async def _collect_chat_text( + *, + token: str, + selected_mode_id: int, + message: str, + files: list[str], + timeout_s: float, +) -> tuple[str, str, StreamAdapter]: + adapter = StreamAdapter() + async for line in _stream_chat( + token = token, + mode_id = ModeId(selected_mode_id), + message = message, + files = files, + timeout_s = timeout_s, + ): + event_type, data = classify_line(line) + if event_type == "done": + break + if event_type != "data" or not data: + continue + ended = False + for ev in adapter.feed(data): + if ev.kind == "soft_stop": + ended = True + break + if ended: + break + return "".join(adapter.text_buf), "".join(adapter.thinking_buf), adapter - Already-wrapped tools are passed through unchanged so this is safe to - call regardless of which format the caller used. - """ - normalised = [] - for tool in tools: - if tool.get("type") == "function" and "function" not in tool and "name" in tool: - normalised.append({ - "type": "function", - "function": { - "name": tool.get("name", ""), - "description": tool.get("description", ""), - "parameters": tool.get("parameters"), - }, - }) - else: - normalised.append(tool) - return normalised + +async def _repair_tool_calls( + *, + token: str, + selected_mode_id: int, + original_message: str, + previous_text: str, + chat_tools: list[dict], + tool_names: list[str], + timeout_s: float, +) -> list | None: + if not _looks_like_codex_tool_run(tool_names, original_message, previous_text): + return None + if _PATCH_COMPLETED_RE.search(original_message) and _PATCH_FAILURE_RE.search(original_message): + return None + + forced_choice = _forced_tool_choice(tool_names, original_message, previous_text) + repair_prompt = build_tool_system_prompt(chat_tools, forced_choice) + retry_hint = "" + if _PATCH_FAILURE_RE.search(original_message) and re.search(r"patch:\s*failed|apply_patch.*failed|Invalid Context", original_message, re.I): + retry_hint = ( + "\n[system]: The previous apply_patch failed. Inspect the target " + "file if needed, then emit a corrected minimal apply_patch command. " + "Do not repeat the same failed patch unchanged.\n" + ) + repair_message = ( + f"[system]: {repair_prompt}\n\n" + "[system]: The previous assistant response failed because it described " + "a tool call instead of emitting a structured tool call. Convert the " + "intended next action into the required XML now. Output XML only.\n\n" + f"{retry_hint}" + f"[conversation]:\n{original_message}\n\n" + f"[previous assistant response]:\n{previous_text}" + ) + repaired_text, _, _ = await _collect_chat_text( + token=token, + selected_mode_id=selected_mode_id, + message=repair_message, + files=[], + timeout_s=timeout_s, + ) + result = parse_tool_calls(repaired_text, tool_names) + if result.calls: + logger.info( + "responses repaired missed tool call: tool_names={} call_count={}", + tool_names, + len(result.calls), + ) + return result.calls + synthetic = _synthesize_codex_tool_call(tool_names, original_message, previous_text) + if synthetic: + logger.info( + "responses synthesized codex tool call: tool={} args={}", + synthetic[0].name, + synthetic[0].arguments, + ) + return synthetic + logger.warning( + "responses tool repair failed: tool_names={} text_excerpt={}", + tool_names, + repaired_text[:500], + ) + return None # --------------------------------------------------------------------------- @@ -120,6 +196,19 @@ async def _emit_fc_events(items: list[dict], base_idx: int): }) +async def _emit_response_start(response_id: str, model: str): + """Emit the standard Responses API stream opening events.""" + response = make_resp_object(response_id, model, "in_progress", []) + yield format_sse("response.created", { + "type": "response.created", + "response": response, + }) + yield format_sse("response.in_progress", { + "type": "response.in_progress", + "response": response, + }) + + # --------------------------------------------------------------------------- # Input normalisation # --------------------------------------------------------------------------- @@ -235,12 +324,14 @@ async def create( # Tool prompt injection — only modify the message text, never the Grok payload # Normalise to Chat Completions format first (Responses API uses a flat structure) tool_names: list[str] = [] + chat_tools: list[dict] = [] if tools: chat_tools = _to_chat_tools(tools) tool_names = extract_tool_names(chat_tools) - tool_prompt = build_tool_system_prompt(chat_tools, tool_choice) - message = inject_into_message(message, tool_prompt) - logger.info("responses tool injection: tool_names={} choice={}", tool_names, tool_choice) + if chat_tools: + tool_prompt = build_tool_system_prompt(chat_tools, tool_choice) + message = inject_into_message(message, tool_prompt) + logger.info("responses tool injection: tool_names={} choice={}", tool_names, tool_choice) from app.dataplane.account import _directory as _acct_dir if _acct_dir is None: @@ -258,6 +349,34 @@ async def create( # Streaming # ------------------------------------------------------------------------- async def _run_stream() -> AsyncGenerator[str, None]: + pre_calls = _normalize_codex_tool_calls( + _synthesize_codex_tool_call(tool_names, message, "") if tool_names else [], + tool_names=tool_names, + message=message, + ) + if pre_calls: + async for evt in _emit_response_start(response_id, model): + yield evt + fc_items = _build_fc_items(pre_calls) + async for evt in _emit_fc_events(fc_items, 0): + yield evt + pt = estimate_prompt_tokens(message) + ct = estimate_tool_call_tokens(pre_calls) + yield format_sse("response.completed", { + "type": "response.completed", + "response": make_resp_object( + response_id, model, "completed", fc_items, + build_resp_usage(pt, ct, 0), + ), + }) + yield "data: [DONE]\n\n" + logger.info( + "responses stream pre-synthesized codex tool_calls: model={} call_count={}", + model, + len(pre_calls), + ) + return + excluded: list[str] = [] for attempt in range(max_retries + 1): acct, selected_mode_id = await reserve_account( @@ -286,10 +405,129 @@ async def _run_stream() -> AsyncGenerator[str, None]: try: try: - yield format_sse("response.created", { - "type": "response.created", - "response": make_resp_object(response_id, model, "in_progress", []), - }) + if tool_names: + async for evt in _emit_response_start(response_id, model): + yield evt + full_text, full_think, tool_adapter = await _collect_chat_text( + token=token, + selected_mode_id=selected_mode_id, + message=message, + files=files, + timeout_s=timeout_s, + ) + parse_result = parse_tool_calls(full_text, tool_names) + calls = parse_result.calls + parsed_calls = list(calls) + if not calls: + repaired = await _repair_tool_calls( + token=token, + selected_mode_id=selected_mode_id, + original_message=message, + previous_text=full_text or full_think, + chat_tools=chat_tools, + tool_names=tool_names, + timeout_s=timeout_s, + ) + calls = repaired or [] + calls = _normalize_codex_tool_calls( + calls, + tool_names=tool_names, + message=message, + ) + if parsed_calls and not calls and _PATCH_COMPLETED_RE.search(message): + full_text = "Done." + + if calls: + fc_items = _build_fc_items(calls) + async for evt in _emit_fc_events(fc_items, 0): + yield evt + pt = estimate_prompt_tokens(message) + ct = estimate_tool_call_tokens(calls) + yield format_sse("response.completed", { + "type": "response.completed", + "response": make_resp_object( + response_id, model, "completed", fc_items, + build_resp_usage(pt, ct, 0), + ), + }) + yield "data: [DONE]\n\n" + success = True + logger.info("responses stream tool_calls buffered: attempt={}/{} model={} call_count={}", + attempt + 1, max_retries + 1, model, len(calls)) + else: + if tool_adapter.image_urls: + for url, img_id in tool_adapter.image_urls: + img_text = await _resolve_image(token, url, img_id) + full_text += ("\n\n" if full_text else "") + img_text + references = tool_adapter.references_suffix() + if references: + full_text += references + msg_item = { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": full_text, "annotations": tool_adapter.annotations_list()}], + "status": "completed", + } + yield format_sse("response.output_item.added", { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": message_id, "type": "message", + "role": "assistant", "content": [], "status": "in_progress", + }, + }) + yield format_sse("response.content_part.added", { + "type": "response.content_part.added", + "item_id": message_id, + "output_index": 0, + "content_index": 0, + "part": {"type": "output_text", "text": "", "annotations": []}, + }) + if full_text: + yield format_sse("response.output_text.delta", { + "type": "response.output_text.delta", + "item_id": message_id, + "output_index": 0, + "content_index": 0, + "delta": full_text, + }) + yield format_sse("response.output_text.done", { + "type": "response.output_text.done", + "item_id": message_id, + "output_index": 0, + "content_index": 0, + "text": full_text, + }) + yield format_sse("response.content_part.done", { + "type": "response.content_part.done", + "item_id": message_id, + "output_index": 0, + "content_index": 0, + "part": msg_item["content"][0], + }) + yield format_sse("response.output_item.done", { + "type": "response.output_item.done", + "output_index": 0, + "item": msg_item, + }) + pt = estimate_prompt_tokens(message) + ct = estimate_tokens(full_text) + yield format_sse("response.completed", { + "type": "response.completed", + "response": make_resp_object( + response_id, model, "completed", [msg_item], + build_resp_usage(pt, ct, 0), + ), + }) + yield "data: [DONE]\n\n" + success = True + logger.info("responses stream text buffered: attempt={}/{} model={} text_len={}", + attempt + 1, max_retries + 1, model, len(full_text)) + return + + async for evt in _emit_response_start(response_id, model): + yield evt ended = False async for line in _stream_chat( @@ -604,6 +842,21 @@ async def _run_stream() -> AsyncGenerator[str, None]: # ------------------------------------------------------------------------- # Non-streaming # ------------------------------------------------------------------------- + pre_calls = _normalize_codex_tool_calls( + _synthesize_codex_tool_call(tool_names, message, "") if tool_names else [], + tool_names=tool_names, + message=message, + ) + if pre_calls: + output = _build_fc_items(pre_calls) + pt = estimate_prompt_tokens(message) + ct = estimate_tool_call_tokens(pre_calls) + logger.info("responses pre-synthesized codex tool_calls: model={} call_count={}", model, len(pre_calls)) + return make_resp_object( + response_id, model, "completed", output, + build_resp_usage(pt, ct, 0), + ) + excluded: list[str] = [] token = "" adapter = StreamAdapter() @@ -698,7 +951,27 @@ async def _run_stream() -> AsyncGenerator[str, None]: # Check for tool calls in the accumulated text if tool_names: tc_result = parse_tool_calls(full_text, tool_names) - if tc_result.calls: + calls = tc_result.calls + parsed_calls = list(calls) + if not calls: + repaired = await _repair_tool_calls( + token=token, + selected_mode_id=selected_mode_id, + original_message=message, + previous_text=full_text or full_think, + chat_tools=chat_tools, + tool_names=tool_names, + timeout_s=timeout_s, + ) + calls = repaired or [] + calls = _normalize_codex_tool_calls( + calls, + tool_names=tool_names, + message=message, + ) + if parsed_calls and not calls and _PATCH_COMPLETED_RE.search(message): + full_text = "Done." + if calls: output: list[dict] = [] if full_think: output.append({ @@ -707,11 +980,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: "summary": [{"type": "summary_text", "text": full_think}], "status": "completed", }) - output.extend(_build_fc_items(tc_result.calls)) + output.extend(_build_fc_items(calls)) pt = estimate_prompt_tokens(message) - ct = estimate_tool_call_tokens(tc_result.calls) + ct = estimate_tool_call_tokens(calls) rt = estimate_tokens(full_think) if full_think else 0 - logger.info("responses tool_calls: model={} calls={}", model, len(tc_result.calls)) + logger.info("responses tool_calls: model={} calls={}", model, len(calls)) return make_resp_object( response_id, model, "completed", output, build_resp_usage(pt, ct + rt, rt), diff --git a/app/products/openai/router.py b/app/products/openai/router.py index 01a27504..19092a10 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -56,6 +56,49 @@ def _model_available_for_pools(spec: ModelSpec, pools: frozenset[str]) -> bool: return False +_CODEX_BASE_INSTRUCTIONS = """You are Codex, a coding agent running on the user's computer. + +Use the available tools carefully, prefer rg for searches, and preserve user changes. +When editing files, prefer apply_patch for manual changes and report what you changed. +""" + + +def _codex_model_catalog(models: list[dict]) -> dict: + catalog = [] + for idx, model in enumerate(models): + model_id = str(model["id"]) + catalog.append( + { + "slug": model_id, + "display_name": str(model.get("name") or model_id), + "description": f"{model.get('name') or model_id} via grok2api.", + "default_reasoning_level": None, + "supported_reasoning_levels": [], + "supports_reasoning_summaries": False, + "default_reasoning_summary": "none", + "support_verbosity": False, + "default_verbosity": None, + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": True, + "priority": idx + 10, + "base_instructions": _CODEX_BASE_INSTRUCTIONS, + "apply_patch_tool_type": "freeform", + "web_search_tool_type": "text_and_image", + "truncation_policy": {"mode": "tokens", "limit": 10000}, + "supports_parallel_tool_calls": True, + "supports_image_detail_original": False, + "context_window": 131072, + "max_context_window": 131072, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text", "image"], + "supports_search_tool": False, + } + ) + return {"models": catalog} + + # --------------------------------------------------------------------------- # /v1/models # --------------------------------------------------------------------------- @@ -77,6 +120,8 @@ async def list_models(request: Request): for m in model_registry.list_enabled() if _model_available_for_pools(m, pools) ] + if "client_version" in request.query_params: + return JSONResponse(_codex_model_catalog(models)) return JSONResponse({"object": "list", "data": models}) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_codex_tools.py b/tests/test_codex_tools.py new file mode 100644 index 00000000..da8269c1 --- /dev/null +++ b/tests/test_codex_tools.py @@ -0,0 +1,172 @@ +import json +import unittest + +from app.dataplane.reverse.protocol.tool_parser import ParsedToolCall +from app.products.openai._codex_tools import ( + _normalize_codex_tool_calls, + _synthesize_codex_tool_call, +) +from app.dataplane.reverse.protocol.tool_prompt import ( + build_tool_system_prompt, + inject_into_message, +) + + +class CodexToolFallbackTests(unittest.TestCase): + def test_synthesizes_write_stdin_for_session_input(self): + calls = _synthesize_codex_tool_call( + ["exec_command", "write_stdin"], + "[user]: 向 session_id 123 输入 'hello'", + "", + ) + + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0].name, "write_stdin") + self.assertEqual(json.loads(calls[0].arguments), {"session_id": 123, "chars": "hello"}) + + def test_synthesizes_apply_patch_for_simple_create_file_request(self): + calls = _synthesize_codex_tool_call( + ["exec_command"], + "[user]: 请创建文件 codex_apply_patch_probe.txt,内容为 ok", + "", + ) + + self.assertEqual(len(calls), 1) + args = json.loads(calls[0].arguments) + self.assertIn("apply_patch <<'PATCH'", args["cmd"]) + self.assertIn("*** Add File: codex_apply_patch_probe.txt", args["cmd"]) + self.assertIn("+ok", args["cmd"]) + + def test_does_not_treat_model_name_as_filename(self): + calls = _synthesize_codex_tool_call( + ["exec_command"], + "[system]: Available models include gpt-5.5.\n\n[user]: 帮我改一下项目结构", + "", + ) + + self.assertEqual(calls, []) + + def test_rewrites_simple_echo_write_to_apply_patch(self): + call = ParsedToolCall.make("exec_command", {"cmd": "echo 'ok' > probe.txt"}) + + normalized = _normalize_codex_tool_calls( + [call], + tool_names=["exec_command"], + message="[user]: 请创建文件 probe.txt,内容为 ok", + ) + + self.assertEqual(len(normalized), 1) + cmd = json.loads(normalized[0].arguments)["cmd"] + self.assertIn("apply_patch <<'PATCH'", cmd) + self.assertIn("*** Add File: probe.txt", cmd) + self.assertIn("+ok", cmd) + + def test_suppresses_duplicate_completed_apply_patch(self): + patch_cmd = "\n".join( + [ + "apply_patch <<'PATCH'", + "*** Begin Patch", + "*** Add File: probe.txt", + "+ok", + "*** End Patch", + "PATCH", + ] + ) + call = ParsedToolCall.make("exec_command", {"cmd": patch_cmd}) + + normalized = _normalize_codex_tool_calls( + [call], + tool_names=["exec_command"], + message="[tool result]:\npatch: completed\n/abs/probe.txt\n*** Add File: probe.txt", + ) + + self.assertEqual(normalized, []) + + def test_synthesizes_desktop_listing_despite_injected_tool_text(self): + message = ( + "[system]: AVAILABLE TOOLS:\n" + "function_call_output session_id patch: completed\n\n" + "[user]: 帮我看看我的桌面文件有哪些" + ) + + calls = _synthesize_codex_tool_call(["exec_command"], message, "") + + self.assertEqual(len(calls), 1) + self.assertEqual( + json.loads(calls[0].arguments), + {"cmd": "ls -la ~/Desktop | sed -n '1,40p'"}, + ) + + def test_does_not_repeat_desktop_listing_after_command_output(self): + message = ( + "[user]: 帮我看看我的桌面文件有哪些\n\n" + "[tool result]: command_execution exit_code=0 aggregated_output='total 77944'" + ) + + calls = _synthesize_codex_tool_call(["exec_command"], message, "") + + self.assertEqual(calls, []) + + def test_synthesizes_read_file_command(self): + calls = _synthesize_codex_tool_call( + ["exec_command"], + "[user]: 读取 sample.txt 内容并告诉我", + "", + ) + + self.assertEqual(len(calls), 1) + self.assertEqual(json.loads(calls[0].arguments), {"cmd": "sed -n '1,200p' sample.txt"}) + + def test_synthesizes_read_file_command_after_tool_injection(self): + tools = [ + { + "type": "function", + "function": { + "name": "exec_command", + "description": "Run a command.", + "parameters": { + "type": "object", + "properties": { + "cmd": {"type": "string", "description": "Example: python pyrepl.py"}, + }, + }, + }, + } + ] + message = inject_into_message( + "读取 sample.txt 内容并告诉我", + build_tool_system_prompt(tools, "auto"), + ) + + calls = _synthesize_codex_tool_call(["exec_command"], message, "") + + self.assertEqual(len(calls), 1) + self.assertEqual(json.loads(calls[0].arguments), {"cmd": "sed -n '1,200p' sample.txt"}) + + def test_synthesizes_search_command(self): + calls = _synthesize_codex_tool_call( + ["exec_command"], + "[user]: 搜索当前目录下包含 alpha 的文件,并告诉我文件路径", + "", + ) + + self.assertEqual(len(calls), 1) + self.assertEqual(json.loads(calls[0].arguments), {"cmd": "rg -n -- alpha ."}) + + def test_synthesizes_simple_replace_patch(self): + calls = _synthesize_codex_tool_call( + ["exec_command"], + "[user]: 把 sample.txt 里的 alpha 改成 beta,然后告诉我修改后的内容", + "", + ) + + self.assertEqual(len(calls), 1) + cmd = json.loads(calls[0].arguments)["cmd"] + self.assertIn("apply_patch <<'PATCH'", cmd) + self.assertIn("*** Update File: sample.txt", cmd) + self.assertIn("-alpha", cmd) + self.assertIn("+beta", cmd) + + +if __name__ == "__main__": + unittest.main()