From 3e3037757e128427269646c90da82a36f7170600 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 1 Sep 2025 07:36:50 +0000 Subject: [PATCH] Add type safety and error handling improvements to evaluation pipeline Co-authored-by: bchen --- eval_protocol/agent/task_manager.py | 8 +++++--- eval_protocol/execution/pipeline.py | 32 ++++++++++++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/eval_protocol/agent/task_manager.py b/eval_protocol/agent/task_manager.py index 582e9fc3..1384003d 100644 --- a/eval_protocol/agent/task_manager.py +++ b/eval_protocol/agent/task_manager.py @@ -16,7 +16,7 @@ from copy import deepcopy from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, cast import requests @@ -684,6 +684,7 @@ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_d # Add sample metadata to the result if isinstance(result, dict): + result = cast(Dict[str, Any], result) result["sample_data"] = sample_data result["sample_index"] = sample_index result["rollout_index"] = rollout_index @@ -920,9 +921,10 @@ def _save_detailed_results( if chosen_dir is None: chosen_dir = Path(".") - output_file = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl" + output_path = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl" - output_path = Path(output_file) + else: + output_path = Path(output_file) try: self.logger.info("=== TRAJECTORY SAVE DEBUG START ===") diff --git a/eval_protocol/execution/pipeline.py b/eval_protocol/execution/pipeline.py index 5cd7b4d9..8830f3ff 100644 --- a/eval_protocol/execution/pipeline.py +++ b/eval_protocol/execution/pipeline.py @@ -8,7 +8,7 @@ import json import logging import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast import aiohttp import hydra @@ -312,7 +312,7 @@ async def _execute_mcp_agent_rollout( tools=openai_formatted_tools, ) - assistant_msg_for_history = {"role": "assistant"} + assistant_msg_for_history: Dict[str, Any] = {"role": "assistant"} if generation_output_turn.tool_calls: assistant_msg_for_history["tool_calls"] = [ @@ -479,7 +479,7 @@ async def _process_single_sample_internal( sample: Dict[str, Any], http_session: Optional[aiohttp.ClientSession], # For model_client, not mcp_client original_index: Optional[int] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Dict[str, Any]: sample_id_fallback = ( f"idx_{original_index}" if original_index is not None else "unknown_id_" + os.urandom(4).hex() ) @@ -497,7 +497,10 @@ async def _process_single_sample_internal( logger.warning( f"Skipping sample {sample_id}: needs either ('user_query' + 'ground_truth_for_eval') for generation or 'messages' for evaluation." ) - return None + return { + "id": sample_id, + "error": "Missing required fields for generation/evaluation", + } original_system_prompt = sample.get("system_prompt") or self.cfg.get("system_prompt") discovered_tools_for_llm_prompt: List[Dict[str, Any]] = [] @@ -582,13 +585,18 @@ async def _process_single_sample_internal( } else: logger.warning(f"Sample {sample_id}: Evaluation mode requires generation.enabled=false") - return None + return { + "id": sample_id, + "error": "Evaluation mode requires generation.enabled=false", + } # Generation mode: Initial messages for the main rollout (or single generation if not agent) + # At this point, generation format is guaranteed by the control flow above; cast for type checking + user_query_str: str = cast(str, user_query) current_messages_for_rollout: List[Dict[str, Any]] = [] if system_prompt_content: current_messages_for_rollout.append({"role": "system", "content": system_prompt_content}) - current_messages_for_rollout.append({"role": "user", "content": user_query}) + current_messages_for_rollout.append({"role": "user", "content": user_query_str}) # --- LLM Generation / Agent Rollout --- if not self.cfg.generation.enabled: @@ -605,10 +613,14 @@ async def _process_single_sample_internal( final_assistant_output_for_log = self.cache.get( sample_id=sample_id, system_prompt=original_system_prompt, - user_query=user_query, + user_query=cast(str, user_query), model_name=gen_cfg.get("model_name", "unknown_model"), temperature=gen_cfg.get("temperature", 0.0), - # ... other cache params + top_p=gen_cfg.get("top_p", 1.0), + top_k=gen_cfg.get("top_k", 0), + min_p=gen_cfg.get("min_p", 0.0), + max_tokens=gen_cfg.get("max_tokens", 2048), + reasoning_effort=gen_cfg.get("reasoning_effort", None), ) if not final_assistant_output_for_log: return { @@ -627,7 +639,7 @@ async def _process_single_sample_internal( elif self.mcp_intermediary_client and self.cfg.agent.type == "mcp_agent": mcp_result = await self._execute_mcp_agent_rollout( sample_id=sample_id, - user_query=user_query, + user_query=user_query_str, system_prompt_content=system_prompt_content, openai_formatted_tools=openai_formatted_tools, http_session=http_session, @@ -676,7 +688,7 @@ async def _process_single_sample_internal( else: generation_result = await self._execute_standard_generation( sample_id=sample_id, - user_query=user_query, + user_query=user_query_str, system_prompt_content=system_prompt_content, http_session=http_session, )