diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 8e461b49..d6cb2b83 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -101,13 +101,12 @@ def execute_rollouts( async def _execute_with_semaphore(idx): async with semaphore: + evaluation_row: EvaluationRow = evaluation_rows[idx] + trajectory = await self._execute_rollout( - envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time + envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time, evaluation_row ) - # Convert trajectory to EvaluationRow immediately - evaluation_row: EvaluationRow = evaluation_rows[idx] - # Handle multimodal content by extracting text from complex content structures messages = [] for msg in trajectory.conversation_history: @@ -161,6 +160,7 @@ async def _execute_rollout( recording_mode: bool, playback_mode: bool, start_time: float, + evaluation_row: Optional[EvaluationRow] = None, ) -> Trajectory: """ Execute a single rollout for one environment (async version for thread execution). @@ -170,6 +170,25 @@ async def _execute_rollout( session = envs.sessions[rollout_idx] dataset_row = envs.dataset_rows[rollout_idx] + # Helper function to sync conversation history to evaluation_row.messages + def update_evaluation_row_messages(): + if evaluation_row: + + def extract_text_content(msg_dict): + msg_copy = dict(msg_dict) + if isinstance(msg_copy.get("content"), list): + for content_block in msg_copy["content"]: + if isinstance(content_block, dict) and content_block.get("type") == "text": + msg_copy["content"] = content_block.get("text", "") + break + else: + msg_copy["content"] = "" + return msg_copy + + evaluation_row.messages = [ + Message.model_validate(extract_text_content(msg)) for msg in trajectory.conversation_history + ] + # Initialize trajectory trajectory = Trajectory( session=session, @@ -223,6 +242,7 @@ async def _execute_rollout( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] + update_evaluation_row_messages() logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}") @@ -251,6 +271,7 @@ async def _execute_rollout( user_prompt = envs.format_user_prompt(rollout_idx, user_content) trajectory.conversation_history.append({"role": "user", "content": user_prompt}) + update_evaluation_row_messages() # Check if user simulator signaled termination if UserSimulator.is_stop(user_message): @@ -262,6 +283,7 @@ async def _execute_rollout( tool_calls, usage_stats, finish_reason = await policy( tool_schema, rollout_idx, trajectory.conversation_history ) + update_evaluation_row_messages() # calc llm usage stats happened in this turn if there is aany if usage_stats: @@ -297,6 +319,7 @@ async def _execute_rollout( env_end, info, ) + update_evaluation_row_messages() # Update trajectory with both data and control plane information trajectory.observations.append(observation) @@ -379,6 +402,7 @@ async def _execute_rollout( _, usage_stats, finish_reason = await policy( tool_schema, rollout_idx, trajectory.conversation_history ) + update_evaluation_row_messages() if usage_stats: trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens trajectory.usage["completion_tokens"] += usage_stats.completion_tokens