88import json
99import logging
1010import os
11- from typing import Any , Dict , List , Optional , Union
11+ from typing import Any , Dict , List , Optional , Union , cast
1212
1313import aiohttp
1414import hydra
@@ -312,7 +312,7 @@ async def _execute_mcp_agent_rollout(
312312 tools = openai_formatted_tools ,
313313 )
314314
315- assistant_msg_for_history = {"role" : "assistant" }
315+ assistant_msg_for_history : Dict [ str , Any ] = {"role" : "assistant" }
316316
317317 if generation_output_turn .tool_calls :
318318 assistant_msg_for_history ["tool_calls" ] = [
@@ -479,7 +479,7 @@ async def _process_single_sample_internal(
479479 sample : Dict [str , Any ],
480480 http_session : Optional [aiohttp .ClientSession ], # For model_client, not mcp_client
481481 original_index : Optional [int ] = None ,
482- ) -> Optional [ Dict [str , Any ] ]:
482+ ) -> Dict [str , Any ]:
483483 sample_id_fallback = (
484484 f"idx_{ original_index } " if original_index is not None else "unknown_id_" + os .urandom (4 ).hex ()
485485 )
@@ -497,7 +497,10 @@ async def _process_single_sample_internal(
497497 logger .warning (
498498 f"Skipping sample { sample_id } : needs either ('user_query' + 'ground_truth_for_eval') for generation or 'messages' for evaluation."
499499 )
500- return None
500+ return {
501+ "id" : sample_id ,
502+ "error" : "Missing required fields for generation/evaluation" ,
503+ }
501504
502505 original_system_prompt = sample .get ("system_prompt" ) or self .cfg .get ("system_prompt" )
503506 discovered_tools_for_llm_prompt : List [Dict [str , Any ]] = []
@@ -582,13 +585,18 @@ async def _process_single_sample_internal(
582585 }
583586 else :
584587 logger .warning (f"Sample { sample_id } : Evaluation mode requires generation.enabled=false" )
585- return None
588+ return {
589+ "id" : sample_id ,
590+ "error" : "Evaluation mode requires generation.enabled=false" ,
591+ }
586592
587593 # Generation mode: Initial messages for the main rollout (or single generation if not agent)
594+ # At this point, generation format is guaranteed by the control flow above; cast for type checking
595+ user_query_str : str = cast (str , user_query )
588596 current_messages_for_rollout : List [Dict [str , Any ]] = []
589597 if system_prompt_content :
590598 current_messages_for_rollout .append ({"role" : "system" , "content" : system_prompt_content })
591- current_messages_for_rollout .append ({"role" : "user" , "content" : user_query })
599+ current_messages_for_rollout .append ({"role" : "user" , "content" : user_query_str })
592600
593601 # --- LLM Generation / Agent Rollout ---
594602 if not self .cfg .generation .enabled :
@@ -605,10 +613,14 @@ async def _process_single_sample_internal(
605613 final_assistant_output_for_log = self .cache .get (
606614 sample_id = sample_id ,
607615 system_prompt = original_system_prompt ,
608- user_query = user_query ,
616+ user_query = cast ( str , user_query ) ,
609617 model_name = gen_cfg .get ("model_name" , "unknown_model" ),
610618 temperature = gen_cfg .get ("temperature" , 0.0 ),
611- # ... other cache params
619+ top_p = gen_cfg .get ("top_p" , 1.0 ),
620+ top_k = gen_cfg .get ("top_k" , 0 ),
621+ min_p = gen_cfg .get ("min_p" , 0.0 ),
622+ max_tokens = gen_cfg .get ("max_tokens" , 2048 ),
623+ reasoning_effort = gen_cfg .get ("reasoning_effort" , None ),
612624 )
613625 if not final_assistant_output_for_log :
614626 return {
@@ -627,7 +639,7 @@ async def _process_single_sample_internal(
627639 elif self .mcp_intermediary_client and self .cfg .agent .type == "mcp_agent" :
628640 mcp_result = await self ._execute_mcp_agent_rollout (
629641 sample_id = sample_id ,
630- user_query = user_query ,
642+ user_query = user_query_str ,
631643 system_prompt_content = system_prompt_content ,
632644 openai_formatted_tools = openai_formatted_tools ,
633645 http_session = http_session ,
@@ -676,7 +688,7 @@ async def _process_single_sample_internal(
676688 else :
677689 generation_result = await self ._execute_standard_generation (
678690 sample_id = sample_id ,
679- user_query = user_query ,
691+ user_query = user_query_str ,
680692 system_prompt_content = system_prompt_content ,
681693 http_session = http_session ,
682694 )
0 commit comments