diff --git a/eval_protocol/playback_policy.py b/eval_protocol/playback_policy.py
index 876e419f..a84fa834 100644
--- a/eval_protocol/playback_policy.py
+++ b/eval_protocol/playback_policy.py
@@ -12,8 +12,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Tuple
 
-from openai.types import CompletionUsage
-
 from .types import MCPToolCall
 
 logger = logging.getLogger(__name__)
@@ -207,7 +205,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List["MCPToolCall"], CompletionUsage, str]:
+    ) -> Tuple[List["MCPToolCall"], Optional[Dict[str, int]], Optional[str]]:
         """
         Generate tool calls in live mode. Concrete classes must implement this.
 
@@ -217,7 +215,7 @@ async def _generate_live_tool_calls(
             conversation_history: Current conversation history for this environment
 
         Returns:
-            List of ToolCall objects and LLM interation usage stats
+            Tuple of (tool calls, optional usage dict, optional correlation id)
         """
         pass
 
@@ -341,33 +339,7 @@ def get_playback_progress(self) -> Dict[str, Any]:
 
         return progress
 
-    def log_conversation_state_for_playback(
-        self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
-    ):
-        """
-        Log the current conversation state in the format required for playback.
-
-        Base implementation that subclasses can override with specific behavior.
-        Expected format: {"env_index": 0, "step": 0, "messages": [{..}, {..}]}
-
-        Args:
-            env_index: Environment index
-            step: Current step number
-            conversation_history: List of conversation messages
-        """
-        # Use EP_PLAYBACK_FILE environment variable for recording
-        playback_file = os.environ.get("EP_PLAYBACK_FILE")
-        if not playback_file:
-            return  # No recording file specified
-
-        playback_entry = {
-            "env_index": env_index,
-            "step": step,
-            "messages": conversation_history.copy(),
-        }
-
-        with open(playback_file, "a") as f:
-            f.write(json.dumps(playback_entry) + "\n")
+    # Duplicate definition removed
 
     def log_conversation_state_for_playback(
         self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
diff --git a/eval_protocol/rewards/accuracy_length.py b/eval_protocol/rewards/accuracy_length.py
index b8e64eb2..45a99a3e 100644
--- a/eval_protocol/rewards/accuracy_length.py
+++ b/eval_protocol/rewards/accuracy_length.py
@@ -7,7 +7,7 @@
 """
 
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, cast
 
 from ..models import EvaluateResult, Message, MetricResult
 from ..typed_interface import reward_function
@@ -77,12 +77,25 @@ def cosine_scaled_accuracy_length_reward(
                 )
             },
         )
-    text: str = response.content
+    # Coerce response content to string
+    text: str
+    if isinstance(response.content, str):
+        text = response.content
+    elif isinstance(response.content, list) and response.content:
+        # Join text parts if provided as structured content
+        try:
+            text = " ".join(part.text for part in response.content)  # type: ignore[union-attr]
+        except Exception:
+            text = ""
+    else:
+        text = ""
 
     # Step 1: Evaluate accuracy
-    accuracy_eval_result = accuracy_reward(
-        messages=messages,  # Pass the full messages list
-        ground_truth=ground_truth,  # Pass the ground_truth list
+    # Ensure ground_truth is a list if provided; default to [] for compatibility
+    gt_for_accuracy = ground_truth if ground_truth is not None else []
+    accuracy_eval_result = cast(Any, accuracy_reward)(
+        messages=messages,
+        ground_truth=gt_for_accuracy,
         extract_fn=extract_fn,
         compare_fn=compare_fn,
     )
diff --git a/eval_protocol/server.py b/eval_protocol/server.py
index 7a013f5a..76944761 100644
--- a/eval_protocol/server.py
+++ b/eval_protocol/server.py
@@ -2,11 +2,11 @@
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast
 
-import uvicorn
-from fastapi import FastAPI, HTTPException, Request
-from pydantic import BaseModel, Field
+import uvicorn  # type: ignore[reportMissingImports]
+from fastapi import FastAPI, HTTPException, Request  # type: ignore[reportMissingImports]
+from pydantic import BaseModel, Field  # type: ignore[reportMissingImports]
 
 from .models import EvaluateResult
 
@@ -254,7 +254,7 @@ async def reward(request_data: RewardRequest):
             return result.model_dump()
         elif isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
             logger.warning("Reward function passed to create_app returned legacy tuple format.")
-            score, components = result
+            score, components = cast(Tuple[float, Dict[str, Any]], result)
             return {"score": score, "metrics": components}
         else:
             raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")
diff --git a/eval_protocol/utils/static_policy.py b/eval_protocol/utils/static_policy.py
index c8b31792..7c0a2c52 100644
--- a/eval_protocol/utils/static_policy.py
+++ b/eval_protocol/utils/static_policy.py
@@ -19,7 +19,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 # Import the base policy and types for proper recording functionality
-from openai.types import CompletionUsage
+from typing import Optional as _Optional
 
 from ..playback_policy import PlaybackPolicyBase
 from ..types import MCPToolCall
@@ -73,7 +73,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
+    ) -> Tuple[List[MCPToolCall], Optional[Dict[str, int]], Optional[str]]:
         """
         Generate tool calls in live mode using the static action sequence.
 
@@ -105,7 +105,7 @@ async def _generate_live_tool_calls(
 
         logger.debug(f"🎮 Env {env_index} step {step_count}: {action}")
 
-        usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        usage_stats: Optional[Dict[str, int]] = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
         return [tool_call], usage_stats, None
 
     def add_tool_response(
@@ -116,7 +116,7 @@ def add_tool_response(
         self, env_index: int, tool_call: MCPToolCall, tool_response: str,
         conversation_history: List[Dict[str, Any]],
         reward: float = 0.0,
         terminated: bool = False,
-        info: Dict[str, Any] = None,
+        info: Optional[Dict[str, Any]] = None,
     ):
         """Add tool call and response to conversation history for recording."""
@@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
+    ) -> Tuple[List[MCPToolCall], Optional[Dict[str, int]], Optional[str]]:
         """
         Generate random tool calls in live mode.
 
@@ -240,7 +240,7 @@ async def _generate_live_tool_calls(
 
         logger.debug(f"🎲 Env {env_index}: {action}")
 
-        usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        usage_stats: Optional[Dict[str, int]] = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
         return [tool_call], usage_stats, None
 
     def add_tool_response(
@@ -251,7 +251,7 @@ def add_tool_response(
         self, env_index: int, tool_call: MCPToolCall, tool_response: str,
         conversation_history: List[Dict[str, Any]],
         reward: float = 0.0,
         terminated: bool = False,
-        info: Dict[str, Any] = None,
+        info: Optional[Dict[str, Any]] = None,
     ):
         """Add tool call and response to conversation history for recording."""