Skip to content

Commit caf93cf

Browse files
benjibc and cursoragent authored
Add type safety and error handling improvements to evaluation pipeline (#146)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
1 parent fbb8d34 commit caf93cf

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

eval_protocol/agent/task_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from copy import deepcopy
1717
from datetime import datetime
1818
from pathlib import Path
19-
from typing import Any, Dict, List, Optional, Set, Tuple
19+
from typing import Any, Dict, List, Optional, Set, Tuple, cast
2020

2121
import requests
2222

@@ -684,6 +684,7 @@ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_d
684684

685685
# Add sample metadata to the result
686686
if isinstance(result, dict):
687+
result = cast(Dict[str, Any], result)
687688
result["sample_data"] = sample_data
688689
result["sample_index"] = sample_index
689690
result["rollout_index"] = rollout_index
@@ -920,9 +921,10 @@ def _save_detailed_results(
920921
if chosen_dir is None:
921922
chosen_dir = Path(".")
922923

923-
output_file = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl"
924+
output_path = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl"
924925

925-
output_path = Path(output_file)
926+
else:
927+
output_path = Path(output_file)
926928

927929
try:
928930
self.logger.info("=== TRAJECTORY SAVE DEBUG START ===")

eval_protocol/execution/pipeline.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
import logging
1010
import os
11-
from typing import Any, Dict, List, Optional, Union
11+
from typing import Any, Dict, List, Optional, Union, cast
1212

1313
import aiohttp
1414
import hydra
@@ -312,7 +312,7 @@ async def _execute_mcp_agent_rollout(
312312
tools=openai_formatted_tools,
313313
)
314314

315-
assistant_msg_for_history = {"role": "assistant"}
315+
assistant_msg_for_history: Dict[str, Any] = {"role": "assistant"}
316316

317317
if generation_output_turn.tool_calls:
318318
assistant_msg_for_history["tool_calls"] = [
@@ -479,7 +479,7 @@ async def _process_single_sample_internal(
479479
sample: Dict[str, Any],
480480
http_session: Optional[aiohttp.ClientSession], # For model_client, not mcp_client
481481
original_index: Optional[int] = None,
482-
) -> Optional[Dict[str, Any]]:
482+
) -> Dict[str, Any]:
483483
sample_id_fallback = (
484484
f"idx_{original_index}" if original_index is not None else "unknown_id_" + os.urandom(4).hex()
485485
)
@@ -497,7 +497,10 @@ async def _process_single_sample_internal(
497497
logger.warning(
498498
f"Skipping sample {sample_id}: needs either ('user_query' + 'ground_truth_for_eval') for generation or 'messages' for evaluation."
499499
)
500-
return None
500+
return {
501+
"id": sample_id,
502+
"error": "Missing required fields for generation/evaluation",
503+
}
501504

502505
original_system_prompt = sample.get("system_prompt") or self.cfg.get("system_prompt")
503506
discovered_tools_for_llm_prompt: List[Dict[str, Any]] = []
@@ -582,13 +585,18 @@ async def _process_single_sample_internal(
582585
}
583586
else:
584587
logger.warning(f"Sample {sample_id}: Evaluation mode requires generation.enabled=false")
585-
return None
588+
return {
589+
"id": sample_id,
590+
"error": "Evaluation mode requires generation.enabled=false",
591+
}
586592

587593
# Generation mode: Initial messages for the main rollout (or single generation if not agent)
594+
# At this point, generation format is guaranteed by the control flow above; cast for type checking
595+
user_query_str: str = cast(str, user_query)
588596
current_messages_for_rollout: List[Dict[str, Any]] = []
589597
if system_prompt_content:
590598
current_messages_for_rollout.append({"role": "system", "content": system_prompt_content})
591-
current_messages_for_rollout.append({"role": "user", "content": user_query})
599+
current_messages_for_rollout.append({"role": "user", "content": user_query_str})
592600

593601
# --- LLM Generation / Agent Rollout ---
594602
if not self.cfg.generation.enabled:
@@ -605,10 +613,14 @@ async def _process_single_sample_internal(
605613
final_assistant_output_for_log = self.cache.get(
606614
sample_id=sample_id,
607615
system_prompt=original_system_prompt,
608-
user_query=user_query,
616+
user_query=cast(str, user_query),
609617
model_name=gen_cfg.get("model_name", "unknown_model"),
610618
temperature=gen_cfg.get("temperature", 0.0),
611-
# ... other cache params
619+
top_p=gen_cfg.get("top_p", 1.0),
620+
top_k=gen_cfg.get("top_k", 0),
621+
min_p=gen_cfg.get("min_p", 0.0),
622+
max_tokens=gen_cfg.get("max_tokens", 2048),
623+
reasoning_effort=gen_cfg.get("reasoning_effort", None),
612624
)
613625
if not final_assistant_output_for_log:
614626
return {
@@ -627,7 +639,7 @@ async def _process_single_sample_internal(
627639
elif self.mcp_intermediary_client and self.cfg.agent.type == "mcp_agent":
628640
mcp_result = await self._execute_mcp_agent_rollout(
629641
sample_id=sample_id,
630-
user_query=user_query,
642+
user_query=user_query_str,
631643
system_prompt_content=system_prompt_content,
632644
openai_formatted_tools=openai_formatted_tools,
633645
http_session=http_session,
@@ -676,7 +688,7 @@ async def _process_single_sample_internal(
676688
else:
677689
generation_result = await self._execute_standard_generation(
678690
sample_id=sample_id,
679-
user_query=user_query,
691+
user_query=user_query_str,
680692
system_prompt_content=system_prompt_content,
681693
http_session=http_session,
682694
)

0 commit comments

Comments (0)