Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions eval_protocol/agent/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, cast

import requests

Expand Down Expand Up @@ -684,6 +684,7 @@ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_d

# Add sample metadata to the result
if isinstance(result, dict):
result = cast(Dict[str, Any], result)
result["sample_data"] = sample_data
result["sample_index"] = sample_index
result["rollout_index"] = rollout_index
Expand Down Expand Up @@ -920,9 +921,10 @@ def _save_detailed_results(
if chosen_dir is None:
chosen_dir = Path(".")

output_file = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl"
output_path = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl"

output_path = Path(output_file)
else:
output_path = Path(output_file)

try:
self.logger.info("=== TRAJECTORY SAVE DEBUG START ===")
Expand Down
32 changes: 22 additions & 10 deletions eval_protocol/execution/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import aiohttp
import hydra
Expand Down Expand Up @@ -312,7 +312,7 @@ async def _execute_mcp_agent_rollout(
tools=openai_formatted_tools,
)

assistant_msg_for_history = {"role": "assistant"}
assistant_msg_for_history: Dict[str, Any] = {"role": "assistant"}

if generation_output_turn.tool_calls:
assistant_msg_for_history["tool_calls"] = [
Expand Down Expand Up @@ -479,7 +479,7 @@ async def _process_single_sample_internal(
sample: Dict[str, Any],
http_session: Optional[aiohttp.ClientSession], # For model_client, not mcp_client
original_index: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
) -> Dict[str, Any]:
sample_id_fallback = (
f"idx_{original_index}" if original_index is not None else "unknown_id_" + os.urandom(4).hex()
)
Expand All @@ -497,7 +497,10 @@ async def _process_single_sample_internal(
logger.warning(
f"Skipping sample {sample_id}: needs either ('user_query' + 'ground_truth_for_eval') for generation or 'messages' for evaluation."
)
return None
return {
"id": sample_id,
"error": "Missing required fields for generation/evaluation",
}

original_system_prompt = sample.get("system_prompt") or self.cfg.get("system_prompt")
discovered_tools_for_llm_prompt: List[Dict[str, Any]] = []
Expand Down Expand Up @@ -582,13 +585,18 @@ async def _process_single_sample_internal(
}
else:
logger.warning(f"Sample {sample_id}: Evaluation mode requires generation.enabled=false")
return None
return {
"id": sample_id,
"error": "Evaluation mode requires generation.enabled=false",
}

# Generation mode: Initial messages for the main rollout (or single generation if not agent)
# At this point, generation format is guaranteed by the control flow above; cast for type checking
user_query_str: str = cast(str, user_query)
current_messages_for_rollout: List[Dict[str, Any]] = []
if system_prompt_content:
current_messages_for_rollout.append({"role": "system", "content": system_prompt_content})
current_messages_for_rollout.append({"role": "user", "content": user_query})
current_messages_for_rollout.append({"role": "user", "content": user_query_str})

# --- LLM Generation / Agent Rollout ---
if not self.cfg.generation.enabled:
Expand All @@ -605,10 +613,14 @@ async def _process_single_sample_internal(
final_assistant_output_for_log = self.cache.get(
sample_id=sample_id,
system_prompt=original_system_prompt,
user_query=user_query,
user_query=cast(str, user_query),
model_name=gen_cfg.get("model_name", "unknown_model"),
temperature=gen_cfg.get("temperature", 0.0),
# ... other cache params
top_p=gen_cfg.get("top_p", 1.0),
top_k=gen_cfg.get("top_k", 0),
min_p=gen_cfg.get("min_p", 0.0),
max_tokens=gen_cfg.get("max_tokens", 2048),
reasoning_effort=gen_cfg.get("reasoning_effort", None),
)
if not final_assistant_output_for_log:
return {
Expand All @@ -627,7 +639,7 @@ async def _process_single_sample_internal(
elif self.mcp_intermediary_client and self.cfg.agent.type == "mcp_agent":
mcp_result = await self._execute_mcp_agent_rollout(
sample_id=sample_id,
user_query=user_query,
user_query=user_query_str,
system_prompt_content=system_prompt_content,
openai_formatted_tools=openai_formatted_tools,
http_session=http_session,
Expand Down Expand Up @@ -676,7 +688,7 @@ async def _process_single_sample_internal(
else:
generation_result = await self._execute_standard_generation(
sample_id=sample_id,
user_query=user_query,
user_query=user_query_str,
system_prompt_content=system_prompt_content,
http_session=http_session,
)
Expand Down
Loading