Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,12 @@ def execute_rollouts(

async def _execute_with_semaphore(idx):
async with semaphore:
evaluation_row: EvaluationRow = evaluation_rows[idx]

trajectory = await self._execute_rollout(
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time, evaluation_row
)

# Convert trajectory to EvaluationRow immediately
evaluation_row: EvaluationRow = evaluation_rows[idx]

# Handle multimodal content by extracting text from complex content structures
messages = []
for msg in trajectory.conversation_history:
Expand Down Expand Up @@ -161,6 +160,7 @@ async def _execute_rollout(
recording_mode: bool,
playback_mode: bool,
start_time: float,
evaluation_row: Optional[EvaluationRow] = None,
) -> Trajectory:
"""
Execute a single rollout for one environment (async version for thread execution).
Expand All @@ -170,6 +170,25 @@ async def _execute_rollout(
session = envs.sessions[rollout_idx]
dataset_row = envs.dataset_rows[rollout_idx]

# Helper function to sync conversation history to evaluation_row.messages
def update_evaluation_row_messages():
    """Mirror trajectory.conversation_history into evaluation_row.messages.

    Does nothing when evaluation_row is falsy (e.g. not supplied). Otherwise
    every history entry is copied, list-valued content is collapsed to plain
    text, and the result is validated into a Message.
    """
    if not evaluation_row:
        return

    def extract_text_content(msg_dict):
        # Shallow copy so the conversation history entry itself is not modified.
        flattened = dict(msg_dict)
        content = flattened.get("content")
        if isinstance(content, list):
            # Keep the text of the first "text" block; fall back to an empty
            # string when the list holds no such block (or is empty).
            flattened["content"] = next(
                (
                    block.get("text", "")
                    for block in content
                    if isinstance(block, dict) and block.get("type") == "text"
                ),
                "",
            )
        return flattened

    evaluation_row.messages = [
        Message.model_validate(extract_text_content(entry))
        for entry in trajectory.conversation_history
    ]

# Initialize trajectory
trajectory = Trajectory(
session=session,
Expand Down Expand Up @@ -223,6 +242,7 @@ async def _execute_rollout(
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
update_evaluation_row_messages()

logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")

Expand Down Expand Up @@ -251,6 +271,7 @@ async def _execute_rollout(

user_prompt = envs.format_user_prompt(rollout_idx, user_content)
trajectory.conversation_history.append({"role": "user", "content": user_prompt})
update_evaluation_row_messages()

# Check if user simulator signaled termination
if UserSimulator.is_stop(user_message):
Expand All @@ -262,6 +283,7 @@ async def _execute_rollout(
tool_calls, usage_stats, finish_reason = await policy(
tool_schema, rollout_idx, trajectory.conversation_history
)
update_evaluation_row_messages()

# Accumulate the LLM usage stats for this turn, if there are any
if usage_stats:
Expand Down Expand Up @@ -297,6 +319,7 @@ async def _execute_rollout(
env_end,
info,
)
update_evaluation_row_messages()

# Update trajectory with both data and control plane information
trajectory.observations.append(observation)
Expand Down Expand Up @@ -379,6 +402,7 @@ async def _execute_rollout(
_, usage_stats, finish_reason = await policy(
tool_schema, rollout_idx, trajectory.conversation_history
)
update_evaluation_row_messages()
if usage_stats:
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
Expand Down
Loading