Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 3 additions & 31 deletions eval_protocol/playback_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from openai.types import CompletionUsage

from .types import MCPToolCall

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -207,7 +205,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List["MCPToolCall"], CompletionUsage, str]:
) -> Tuple[List["MCPToolCall"], Optional[Dict[str, int]], Optional[str]]:
"""
Generate tool calls in live mode. Concrete classes must implement this.

Expand All @@ -217,7 +215,7 @@ async def _generate_live_tool_calls(
conversation_history: Current conversation history for this environment

Returns:
List of ToolCall objects and LLM interaction usage stats
Tuple of (tool calls, optional usage dict, optional correlation id)
"""
pass

Expand Down Expand Up @@ -341,33 +339,7 @@ def get_playback_progress(self) -> Dict[str, Any]:

return progress

def log_conversation_state_for_playback(
    self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
):
    """
    Append one conversation snapshot to the playback recording file.

    The destination path is read from the EP_PLAYBACK_FILE environment
    variable. When that variable is unset or empty, nothing is written.
    Each call appends a single JSON line shaped like:
    {"env_index": 0, "step": 0, "messages": [{..}, {..}]}

    Subclasses may override this base implementation.

    Args:
        env_index: Environment index stored under "env_index"
        step: Step number stored under "step"
        conversation_history: Messages stored (as a shallow copy) under "messages"
    """
    recording_path = os.environ.get("EP_PLAYBACK_FILE")
    if recording_path:
        # Snapshot the list itself so the record is built from a shallow copy.
        snapshot = {
            "env_index": env_index,
            "step": step,
            "messages": conversation_history.copy(),
        }

        # Append mode: one JSON object per line, earlier records are preserved.
        with open(recording_path, "a") as sink:
            sink.write(f"{json.dumps(snapshot)}\n")
# Duplicate definition removed

def log_conversation_state_for_playback(
self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
Expand Down
23 changes: 18 additions & 5 deletions eval_protocol/rewards/accuracy_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

import math
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, cast

from ..models import EvaluateResult, Message, MetricResult
from ..typed_interface import reward_function
Expand Down Expand Up @@ -77,12 +77,25 @@ def cosine_scaled_accuracy_length_reward(
)
},
)
text: str = response.content
# Coerce response content to string
text: str
if isinstance(response.content, str):
text = response.content
elif isinstance(response.content, list) and response.content:
# Join text parts if provided as structured content
try:
text = " ".join(part.text for part in response.content) # type: ignore[union-attr]
Copy link

Copilot AI Aug 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type ignore comment suggests uncertainty about the structure. Consider adding a more specific type check or explicit attribute validation before accessing .text to make the code more robust.

Suggested change
text = " ".join(part.text for part in response.content) # type: ignore[union-attr]
text = " ".join(
getattr(part, "text", "") for part in response.content if hasattr(part, "text")
)

Copilot uses AI. Check for mistakes.
except Exception:
text = ""
else:
text = ""

# Step 1: Evaluate accuracy
accuracy_eval_result = accuracy_reward(
messages=messages, # Pass the full messages list
ground_truth=ground_truth, # Pass the ground_truth list
# Ensure ground_truth is a list if provided; default to [] for compatibility
gt_for_accuracy = ground_truth if ground_truth is not None else []
accuracy_eval_result = cast(Any, accuracy_reward)(
Copy link

Copilot AI Aug 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting to Any defeats the purpose of type checking. Consider using a more specific type or ensuring the decorator preserves the original function signature instead of using cast(Any, ...).

Suggested change
accuracy_eval_result = cast(Any, accuracy_reward)(
accuracy_eval_result = accuracy_reward(

Copilot uses AI. Check for mistakes.
messages=messages,
ground_truth=gt_for_accuracy,
extract_fn=extract_fn,
compare_fn=compare_fn,
)
Expand Down
10 changes: 5 additions & 5 deletions eval_protocol/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
import uvicorn # type: ignore[reportMissingImports]
from fastapi import FastAPI, HTTPException, Request # type: ignore[reportMissingImports]
from pydantic import BaseModel, Field # type: ignore[reportMissingImports]

from .models import EvaluateResult

Expand Down Expand Up @@ -254,7 +254,7 @@ async def reward(request_data: RewardRequest):
return result.model_dump()
elif isinstance(result, tuple) and len(result) == 2: # Legacy tuple
logger.warning("Reward function passed to create_app returned legacy tuple format.")
score, components = result
score, components = cast(Tuple[float, Dict[str, Any]], result)
return {"score": score, "metrics": components}
else:
raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")
Expand Down
14 changes: 7 additions & 7 deletions eval_protocol/utils/static_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

# Import the base policy and types for proper recording functionality
from openai.types import CompletionUsage
from typing import Optional as _Optional
Copy link

Copilot AI Aug 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import is unused and creates an unnecessary alias. The regular Optional import on line 19 is sufficient for the usage in this file.

Suggested change
from typing import Optional as _Optional

Copilot uses AI. Check for mistakes.

from ..playback_policy import PlaybackPolicyBase
from ..types import MCPToolCall
Expand Down Expand Up @@ -73,7 +73,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
) -> Tuple[List[MCPToolCall], Optional[Dict[str, int]], Optional[str]]:
"""
Generate tool calls in live mode using the static action sequence.

Expand Down Expand Up @@ -105,7 +105,7 @@ async def _generate_live_tool_calls(

logger.debug(f"🎮 Env {env_index} step {step_count}: {action}")

usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
usage_stats: Optional[Dict[str, int]] = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
return [tool_call], usage_stats, None

def add_tool_response(
Expand All @@ -116,7 +116,7 @@ def add_tool_response(
conversation_history: List[Dict[str, Any]],
reward: float = 0.0,
terminated: bool = False,
info: Dict[str, Any] = None,
info: Optional[Dict[str, Any]] = None,
):
"""Add tool call and response to conversation history for recording."""

Expand Down Expand Up @@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
) -> Tuple[List[MCPToolCall], Optional[Dict[str, int]], Optional[str]]:
"""
Generate random tool calls in live mode.

Expand All @@ -240,7 +240,7 @@ async def _generate_live_tool_calls(

logger.debug(f"🎲 Env {env_index}: {action}")

usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
usage_stats: Optional[Dict[str, int]] = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
return [tool_call], usage_stats, None

def add_tool_response(
Expand All @@ -251,7 +251,7 @@ def add_tool_response(
conversation_history: List[Dict[str, Any]],
reward: float = 0.0,
terminated: bool = False,
info: Dict[str, Any] = None,
info: Optional[Dict[str, Any]] = None,
):
"""Add tool call and response to conversation history for recording."""

Expand Down
Loading