1212import threading
1313import time
1414from dataclasses import asdict
15- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
15+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union , cast
1616
1717import anyio
1818from openai .types import CompletionUsage
@@ -126,7 +126,15 @@ async def _execute_with_semaphore(idx):
126126
127127 evaluation_row .messages = messages
128128 evaluation_row .tools = shared_tool_schema
129- evaluation_row .usage = CompletionUsage (** trajectory .usage )
129+ # Some OpenAI SDK versions type CompletionUsage as a TypedDict; construct via cast to avoid ctor mismatches
130+ evaluation_row .usage = cast (
131+ CompletionUsage ,
132+ {
133+ "prompt_tokens" : trajectory .usage .get ("prompt_tokens" , 0 ),
134+ "completion_tokens" : trajectory .usage .get ("completion_tokens" , 0 ),
135+ "total_tokens" : trajectory .usage .get ("total_tokens" , 0 ),
136+ },
137+ )
130138 evaluation_row .input_metadata .completion_params = {
131139 "model" : policy .model_id ,
132140 "temperature" : getattr (policy , "temperature" , None ),
@@ -138,8 +146,14 @@ async def _execute_with_semaphore(idx):
138146 extra_info = None
139147 if trajectory .control_plane_summary .get ("error_message" ):
140148 extra_info = {"error_message" : trajectory .control_plane_summary .get ("error_message" )}
149+ # Convert string termination reason to TerminationReason enum if needed
150+ term_reason = (
151+ trajectory .termination_reason
152+ if isinstance (trajectory .termination_reason , TerminationReason )
153+ else TerminationReason .from_str (str (trajectory .termination_reason ))
154+ )
141155 evaluation_row .rollout_status = Status .rollout_finished (
142- termination_reason = trajectory . termination_reason , extra_info = extra_info
156+ termination_reason = term_reason , extra_info = extra_info
143157 )
144158 else :
145159 evaluation_row .rollout_status = Status .rollout_running ()
@@ -231,8 +245,9 @@ def extract_text_content(msg_dict):
231245
232246 # Get initial messages in tau2-bench format for user simulator
233247 user_simulator_state = user_simulator .get_init_state ()
248+ # Generate initial user response by prompting the simulator with a user role message
234249 user_message , user_simulator_state = await user_simulator .generate_next_message (
235- AssistantMessage (role = "assistant " , content = "Hi! How can I help you today? " ),
250+ UserMessage (role = "user " , content = "" ),
236251 user_simulator_state ,
237252 )
238253 current_observation = user_message .content if user_message .content else ""
@@ -264,8 +279,11 @@ def extract_text_content(msg_dict):
264279 # Last message was from the agent, so simulate the user's response
265280 if user_simulator_messages and isinstance (user_simulator_messages [- 1 ], AssistantMessage ):
266281 # Generate user response using the simulator
282+ # Pass the assistant message content to drive the simulated user's next response
283+ last_assistant = user_simulator_messages [- 1 ]
267284 user_message , user_simulator_state = await user_simulator .generate_next_message (
268- user_simulator_messages [- 1 ], user_simulator_state
285+ last_assistant ,
286+ user_simulator_state ,
269287 )
270288 user_content = user_message .content if user_message .content else ""
271289
@@ -285,11 +303,33 @@ def extract_text_content(msg_dict):
285303 )
286304 update_evaluation_row_messages ()
287305
288- # calc llm usage stats happened in this turn if there is aany
306+ # Update LLM usage stats if available; support both dict-like and attribute access
289307 if usage_stats :
290- trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
291- trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
292- trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
308+ try :
309+ prompt_tokens = (
310+ usage_stats .get ("prompt_tokens" )
311+ if isinstance (usage_stats , dict )
312+ else usage_stats .prompt_tokens
313+ )
314+ completion_tokens = (
315+ usage_stats .get ("completion_tokens" )
316+ if isinstance (usage_stats , dict )
317+ else usage_stats .completion_tokens
318+ )
319+ total_tokens = (
320+ usage_stats .get ("total_tokens" )
321+ if isinstance (usage_stats , dict )
322+ else usage_stats .total_tokens
323+ )
324+ if isinstance (prompt_tokens , int ):
325+ trajectory .usage ["prompt_tokens" ] += prompt_tokens
326+ if isinstance (completion_tokens , int ):
327+ trajectory .usage ["completion_tokens" ] += completion_tokens
328+ if isinstance (total_tokens , int ):
329+ trajectory .usage ["total_tokens" ] += total_tokens
330+ except Exception :
331+ # Best-effort; ignore malformed usage stats
332+ pass
293333
294334 # If no tool call is generated, turn is finished
295335 if len (tool_calls ) == 1 :
@@ -300,7 +340,8 @@ def extract_text_content(msg_dict):
300340 # If there's no user simulator, this marks the end of the episode, since the LLM decided that no tool call is needed.
301341 elif tool_calls [0 ].tool_name in ["_playback_terminate" , "_no_tool_call" ]:
302342 trajectory .terminated = True
303- trajectory .termination_reason = TerminationReason .from_str (finish_reason )
343+ # Ensure finish_reason is a string before converting
344+ trajectory .termination_reason = TerminationReason .from_str (str (finish_reason ))
304345 break
305346
306347 # Execute each tool call sequentially
@@ -404,11 +445,32 @@ def extract_text_content(msg_dict):
404445 )
405446 update_evaluation_row_messages ()
406447 if usage_stats :
407- trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
408- trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
409- trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
448+ try :
449+ prompt_tokens = (
450+ usage_stats .get ("prompt_tokens" )
451+ if isinstance (usage_stats , dict )
452+ else usage_stats .prompt_tokens
453+ )
454+ completion_tokens = (
455+ usage_stats .get ("completion_tokens" )
456+ if isinstance (usage_stats , dict )
457+ else usage_stats .completion_tokens
458+ )
459+ total_tokens = (
460+ usage_stats .get ("total_tokens" )
461+ if isinstance (usage_stats , dict )
462+ else usage_stats .total_tokens
463+ )
464+ if isinstance (prompt_tokens , int ):
465+ trajectory .usage ["prompt_tokens" ] += prompt_tokens
466+ if isinstance (completion_tokens , int ):
467+ trajectory .usage ["completion_tokens" ] += completion_tokens
468+ if isinstance (total_tokens , int ):
469+ trajectory .usage ["total_tokens" ] += total_tokens
470+ except Exception :
471+ pass
410472 trajectory .terminated = True
411- trajectory .termination_reason = TerminationReason .from_str (finish_reason )
473+ trajectory .termination_reason = TerminationReason .from_str (str ( finish_reason ) )
412474 trajectory .control_plane_summary .update (
413475 {
414476 "total_reward" : trajectory .total_reward ,
0 commit comments