From 8fcecf65ab6ca61a185edd0cadd0c147ad71ec86 Mon Sep 17 00:00:00 2001 From: 1aal Date: Tue, 9 Jun 2026 21:48:25 +0800 Subject: [PATCH] fix: require initial tool choice for evidence runs --- agent/agent_init.py | 2 + agent/conversation_loop.py | 39 ++++++++++++ run_agent.py | 2 + runtime_manager/manager.py | 29 +++++++++ runtime_manager/worker_main.py | 28 +++++++++ .../test_required_tool_choice_policy.py | 62 +++++++++++++++++++ tests/runtime_manager/test_registry.py | 17 ++++- 7 files changed, 178 insertions(+), 1 deletion(-) create mode 100644 tests/run_agent/test_required_tool_choice_policy.py diff --git a/agent/agent_init.py b/agent/agent_init.py index 30bb6d837053..4ccdff0ead6e 100644 --- a/agent/agent_init.py +++ b/agent/agent_init.py @@ -198,6 +198,7 @@ def init_agent( reasoning_config: Dict[str, Any] = None, service_tier: str = None, request_overrides: Dict[str, Any] = None, + tool_choice_policy: str = None, prefill_messages: List[Dict[str, Any]] = None, platform: str = None, user_id: str = None, @@ -482,6 +483,7 @@ def init_agent( agent.reasoning_config = reasoning_config # None = use default (medium for OpenRouter) agent.service_tier = service_tier agent.request_overrides = dict(request_overrides or {}) + agent.tool_choice_policy = str(tool_choice_policy or "").strip() agent.prefill_messages = prefill_messages or [] # Prefilled conversation turns agent._force_ascii_payload = False diff --git a/agent/conversation_loop.py b/agent/conversation_loop.py index bd1881e912bc..9784fca950d9 100644 --- a/agent/conversation_loop.py +++ b/agent/conversation_loop.py @@ -70,6 +70,44 @@ # to treat it as cancellation metadata rather than assistant prose. 
INTERRUPT_WAITING_FOR_MODEL_PREFIX = "Operation interrupted: waiting for model response ("

# Policy names (matched after strip + lower-casing) that force the model to
# call a tool until the current turn contains its first tool result.
_REQUIRE_UNTIL_FIRST_TOOL_POLICIES = frozenset(
    {
        "require_until_first_tool",
        "required_until_first_tool",
    }
)


def _current_turn_has_tool_result(messages: List[Dict[str, Any]], current_turn_user_idx: int) -> bool:
    """Return True when a ``role == "tool"`` message follows the current turn's user message.

    Args:
        messages: Conversation history as a list of message dicts. Anything
            that is not a list yields ``False``; non-dict entries are ignored.
        current_turn_user_idx: Index of the user message that opened the
            current turn. Scanning starts at the message right after it.

    Returns:
        ``True`` if at least one tool message is found in the scanned range.
    """
    if not isinstance(messages, list):
        return False

    index_usable = (
        isinstance(current_turn_user_idx, int)
        and not isinstance(current_turn_user_idx, bool)
        and current_turn_user_idx >= 0
    )
    if index_usable:
        scan_from = current_turn_user_idx + 1
    else:
        # Without a usable index, scanning from message 0 would let a tool
        # result from an *earlier* turn count as evidence for this one.
        # Anchor on the most recent user message instead; if there is none,
        # scan the whole history.
        scan_from = 0
        for position in range(len(messages) - 1, -1, -1):
            entry = messages[position]
            if isinstance(entry, dict) and entry.get("role") == "user":
                scan_from = position + 1
                break

    return any(
        isinstance(entry, dict) and entry.get("role") == "tool"
        for entry in messages[scan_from:]
    )


def _maybe_apply_required_tool_choice(
    agent: Any,
    api_kwargs: Dict[str, Any],
    messages: List[Dict[str, Any]],
    current_turn_user_idx: int,
) -> None:
    """Force a first tool call only until the current turn has tool evidence.

    Mutates ``api_kwargs`` in place by setting ``tool_choice`` to
    ``"required"``. Nothing is changed when any of the following holds:

    * ``agent.tool_choice_policy`` is missing or not one of the
      "require until first tool" policy names;
    * ``api_kwargs`` is not a dict;
    * ``api_kwargs`` already carries a non-``None`` ``tool_choice`` (an
      explicit choice is never overridden);
    * ``api_kwargs`` has no (or empty) ``tools``;
    * the current turn already contains a tool result.

    Args:
        agent: Object whose optional ``tool_choice_policy`` attribute is read.
        api_kwargs: Request keyword arguments about to be sent to the model.
        messages: Conversation history used to detect tool results.
        current_turn_user_idx: Index of the current turn's user message.
    """
    policy = str(getattr(agent, "tool_choice_policy", "") or "").strip().lower()
    if policy not in _REQUIRE_UNTIL_FIRST_TOOL_POLICIES:
        return
    if not isinstance(api_kwargs, dict):
        return
    if api_kwargs.get("tool_choice") is not None:
        return
    if not api_kwargs.get("tools"):
        # Requiring a tool call with no tools offered makes no sense.
        return
    if _current_turn_has_tool_result(messages, current_turn_user_idx):
        return
    api_kwargs["tool_choice"] = "required"
agent._reapply_reasoning_echo_for_provider(api_messages) api_kwargs = agent._build_api_kwargs(api_messages) + _maybe_apply_required_tool_choice(agent, api_kwargs, messages, current_turn_user_idx) if agent._force_ascii_payload: _sanitize_structure_non_ascii(api_kwargs) if agent.api_mode == "codex_responses": diff --git a/run_agent.py b/run_agent.py index 9c720bcbfe09..e2ac58322d67 100644 --- a/run_agent.py +++ b/run_agent.py @@ -387,6 +387,7 @@ def __init__( reasoning_config: Dict[str, Any] = None, service_tier: str = None, request_overrides: Dict[str, Any] = None, + tool_choice_policy: str = None, prefill_messages: List[Dict[str, Any]] = None, platform: str = None, user_id: str = None, @@ -460,6 +461,7 @@ def __init__( reasoning_config=reasoning_config, service_tier=service_tier, request_overrides=request_overrides, + tool_choice_policy=tool_choice_policy, prefill_messages=prefill_messages, platform=platform, user_id=user_id, diff --git a/runtime_manager/manager.py b/runtime_manager/manager.py index c621309a9f59..6e6eccc6a3d8 100644 --- a/runtime_manager/manager.py +++ b/runtime_manager/manager.py @@ -96,6 +96,7 @@ async def start_run(self, payload: dict[str, Any]) -> RunHandle: llm_config.get("base_url"), llm_config.get("baseURL"), ) + tool_choice_policy = _tool_choice_policy_for_payload(payload, llm_config) run_id = f"run_{uuid.uuid4().hex}" await self._reserve_run(run_id=run_id, user_id=user_id, conversation_id=conversation_id) handle = self.registry.create( @@ -154,6 +155,8 @@ async def start_run(self, payload: dict[str, Any]) -> RunHandle: "max_iterations": resolved.max_iterations, "metadata": payload.get("metadata") or {}, "artifact_dir": str(artifact_dir), + "requires_tool_evidence": _truthy_payload_flag(payload.get("requires_tool_evidence")), + "tool_choice_policy": tool_choice_policy, } assert proc.stdin is not None proc.stdin.write((json.dumps(worker_request, ensure_ascii=False) + "\n").encode("utf-8")) @@ -411,6 +414,32 @@ def _first_present(*values: 
# Case-insensitive string spellings that count as an enabled flag.
_TRUTHY_FLAG_STRINGS = frozenset({"1", "true", "yes", "y", "on"})


def _truthy_payload_flag(value: Any) -> bool:
    """Interpret a loosely-typed run-payload flag as a boolean.

    ``None`` is ``False``; booleans pass through; strings are true only for
    ``1/true/yes/y/on`` (case-insensitive, surrounding whitespace ignored);
    numbers are true when non-zero; anything else uses ``bool(value)``.
    """
    if value is None:
        return False
    # bool must be tested before the numeric branch: bool is a subclass of int.
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_FLAG_STRINGS
    if isinstance(value, (int, float)):
        return value != 0
    return bool(value)


def _tool_choice_policy_for_payload(payload: dict[str, Any], llm_config: dict[str, Any]) -> str:
    """Resolve the tool-choice policy for a run payload (manager side).

    Precedence: ``payload["tool_choice_policy"]``,
    ``payload["toolChoicePolicy"]``, ``llm_config["tool_choice_policy"]``,
    ``llm_config["toolChoicePolicy"]``. The first candidate that is a
    non-blank string wins and is returned stripped. Blank or non-string
    candidates are skipped so they cannot shadow a valid lower-priority one.

    If no explicit policy is found and ``payload["requires_tool_evidence"]``
    is truthy, ``"require_until_first_tool"`` is returned; otherwise ``""``.
    """
    payload = payload if isinstance(payload, dict) else {}
    llm_config = llm_config if isinstance(llm_config, dict) else {}
    candidates = (
        payload.get("tool_choice_policy"),
        payload.get("toolChoicePolicy"),
        llm_config.get("tool_choice_policy"),
        llm_config.get("toolChoicePolicy"),
    )
    for candidate in candidates:
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
    if _truthy_payload_flag(payload.get("requires_tool_evidence")):
        return "require_until_first_tool"
    return ""


def _truthy_request_flag(value: Any) -> bool:
    """Interpret a loosely-typed worker-request flag as a boolean.

    Same rules as the manager-side flag parser: ``None`` is ``False``;
    booleans pass through; strings are true only for ``1/true/yes/y/on``
    (case-insensitive, stripped); numbers are true when non-zero; anything
    else uses ``bool(value)``.
    """
    if value is None:
        return False
    # bool must be tested before the numeric branch: bool is a subclass of int.
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_FLAG_STRINGS
    if isinstance(value, (int, float)):
        return value != 0
    return bool(value)


def _tool_choice_policy_from_request(request: dict[str, Any], llm_config: dict[str, Any]) -> str:
    """Resolve the tool-choice policy from a worker request (worker side).

    Precedence: ``request["tool_choice_policy"]``,
    ``request["toolChoicePolicy"]``, ``llm_config["tool_choice_policy"]``,
    ``llm_config["toolChoicePolicy"]``. The first candidate that is a
    non-blank string wins and is returned stripped. Blank or non-string
    candidates are skipped so they cannot shadow a valid lower-priority one.

    If no explicit policy is found and ``request["requires_tool_evidence"]``
    is truthy, ``"require_until_first_tool"`` is returned; otherwise ``""``.
    """
    request = request if isinstance(request, dict) else {}
    llm_config = llm_config if isinstance(llm_config, dict) else {}
    candidates = (
        request.get("tool_choice_policy"),
        request.get("toolChoicePolicy"),
        llm_config.get("tool_choice_policy"),
        llm_config.get("toolChoicePolicy"),
    )
    for candidate in candidates:
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
    if _truthy_request_flag(request.get("requires_tool_evidence")):
        return "require_until_first_tool"
    return ""
"tool", "tool_call_id": "call_1", "content": "kubectl output"}, + ], + 0, + ) + + assert "tool_choice" not in api_kwargs + + +def test_required_tool_choice_policy_does_not_override_explicit_choice(): + api_kwargs = _kwargs() + api_kwargs["tool_choice"] = "auto" + _maybe_apply_required_tool_choice( + _agent(), + api_kwargs, + [{"role": "user", "content": "检查一下当前集群的整体健康状态"}], + 0, + ) + + assert api_kwargs["tool_choice"] == "auto" diff --git a/tests/runtime_manager/test_registry.py b/tests/runtime_manager/test_registry.py index 7d189155ca05..71d886305384 100644 --- a/tests/runtime_manager/test_registry.py +++ b/tests/runtime_manager/test_registry.py @@ -71,6 +71,18 @@ def test_runtime_worker_normalizes_cloud_provider_aliases_to_hermes_names(): assert _normalize_agent_provider("qwen-oauth") == "qwen-oauth" +def test_runtime_worker_maps_evidence_requirement_to_tool_choice_policy(): + from runtime_manager.worker_main import _tool_choice_policy_from_request + + assert _tool_choice_policy_from_request({"requires_tool_evidence": True}, {}) == "require_until_first_tool" + assert _tool_choice_policy_from_request({"requires_tool_evidence": "true"}, {}) == "require_until_first_tool" + assert _tool_choice_policy_from_request( + {"requires_tool_evidence": True, "tool_choice_policy": "custom-policy"}, + {}, + ) == "custom-policy" + assert _tool_choice_policy_from_request({}, {}) == "" + + def test_runtime_worker_tool_event_helpers_are_json_safe(): from runtime_manager.worker_main import ( _approval_display_fields, @@ -520,7 +532,7 @@ async def test_runtime_manager_forwards_per_run_llm_config_to_worker(tmp_path): "import json, sys, time", "req = json.loads(sys.stdin.readline())", "run_id = req['run_id']", - "print(json.dumps({'event': 'run.completed', 'run_id': run_id, 'timestamp': time.time(), 'output': json.dumps({'model': req.get('model'), 'provider': req.get('provider'), 'base_url': req.get('base_url'), 'api_key': req.get('api_key')})}), flush=True)", + 
"print(json.dumps({'event': 'run.completed', 'run_id': run_id, 'timestamp': time.time(), 'output': json.dumps({'model': req.get('model'), 'provider': req.get('provider'), 'base_url': req.get('base_url'), 'api_key': req.get('api_key'), 'requires_tool_evidence': req.get('requires_tool_evidence'), 'tool_choice_policy': req.get('tool_choice_policy')})}), flush=True)", ] ), encoding="utf-8", @@ -540,6 +552,7 @@ async def test_runtime_manager_forwards_per_run_llm_config_to_worker(tmp_path): "user_id": "user-1", "conversation_id": "conv-1", "message": "hello", + "requires_tool_evidence": True, "llm_config": { "provider": "openai", "model": "gpt-4.1", @@ -562,6 +575,8 @@ async def test_runtime_manager_forwards_per_run_llm_config_to_worker(tmp_path): "provider": "openai", "base_url": "https://models.example/v1", "api_key": "sk-test", + "requires_tool_evidence": True, + "tool_choice_policy": "require_until_first_tool", }