Skip to content

Commit 2df860d

Browse files
cursoragent authored and benjibc committed
Refactor type hints and remove OpenAI-specific imports across multiple files
Co-authored-by: bchen <bchen@fireworks.ai>
1 parent facd060 commit 2df860d

File tree

4 files changed

+33
-48
lines changed

4 files changed

+33
-48
lines changed

eval_protocol/playback_policy.py

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from abc import ABC, abstractmethod
1313
from typing import Any, Dict, List, Optional, Tuple
1414

15-
from openai.types import CompletionUsage
16-
1715
from .types import MCPToolCall
1816

1917
logger = logging.getLogger(__name__)
@@ -207,7 +205,7 @@ async def _generate_live_tool_calls(
207205
tool_schemas: List[Dict],
208206
env_index: int,
209207
conversation_history: List[Dict[str, Any]],
210-
) -> Tuple[List["MCPToolCall"], CompletionUsage, str]:
208+
) -> Tuple[List["MCPToolCall"], Optional[Dict[str, int]], Optional[str]]:
211209
"""
212210
Generate tool calls in live mode. Concrete classes must implement this.
213211
@@ -217,7 +215,7 @@ async def _generate_live_tool_calls(
217215
conversation_history: Current conversation history for this environment
218216
219217
Returns:
220-
List of ToolCall objects and LLM interation usage stats
218+
Tuple of (tool calls, optional usage dict, optional correlation id)
221219
"""
222220
pass
223221

@@ -341,33 +339,7 @@ def get_playback_progress(self) -> Dict[str, Any]:
341339

342340
return progress
343341

344-
def log_conversation_state_for_playback(
345-
self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
346-
):
347-
"""
348-
Log the current conversation state in the format required for playback.
349-
350-
Base implementation that subclasses can override with specific behavior.
351-
Expected format: {"env_index": 0, "step": 0, "messages": [{..}, {..}]}
352-
353-
Args:
354-
env_index: Environment index
355-
step: Current step number
356-
conversation_history: List of conversation messages
357-
"""
358-
# Use EP_PLAYBACK_FILE environment variable for recording
359-
playback_file = os.environ.get("EP_PLAYBACK_FILE")
360-
if not playback_file:
361-
return # No recording file specified
362-
363-
playback_entry = {
364-
"env_index": env_index,
365-
"step": step,
366-
"messages": conversation_history.copy(),
367-
}
368-
369-
with open(playback_file, "a") as f:
370-
f.write(json.dumps(playback_entry) + "\n")
342+
# Duplicate definition removed
371343

372344
def log_conversation_state_for_playback(
373345
self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]

eval_protocol/rewards/accuracy_length.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
import math
10-
from typing import Any, Callable, Dict, List, Optional, Union
10+
from typing import Any, Callable, Dict, List, Optional, Union, cast
1111

1212
from ..models import EvaluateResult, Message, MetricResult
1313
from ..typed_interface import reward_function
@@ -77,12 +77,25 @@ def cosine_scaled_accuracy_length_reward(
7777
)
7878
},
7979
)
80-
text: str = response.content
80+
# Coerce response content to string
81+
text: str
82+
if isinstance(response.content, str):
83+
text = response.content
84+
elif isinstance(response.content, list) and response.content:
85+
# Join text parts if provided as structured content
86+
try:
87+
text = " ".join(part.text for part in response.content) # type: ignore[union-attr]
88+
except Exception:
89+
text = ""
90+
else:
91+
text = ""
8192

8293
# Step 1: Evaluate accuracy
83-
accuracy_eval_result = accuracy_reward(
84-
messages=messages, # Pass the full messages list
85-
ground_truth=ground_truth, # Pass the ground_truth list
94+
# Ensure ground_truth is a list if provided; default to [] for compatibility
95+
gt_for_accuracy = ground_truth if ground_truth is not None else []
96+
accuracy_eval_result = cast(Any, accuracy_reward)(
97+
messages=messages,
98+
ground_truth=gt_for_accuracy,
8699
extract_fn=extract_fn,
87100
compare_fn=compare_fn,
88101
)

eval_protocol/server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import json
33
import logging
44
import os
5-
from typing import Any, Callable, Dict, List, Optional, Union
5+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast
66

7-
import uvicorn
8-
from fastapi import FastAPI, HTTPException, Request
9-
from pydantic import BaseModel, Field
7+
import uvicorn # type: ignore[reportMissingImports]
8+
from fastapi import FastAPI, HTTPException, Request # type: ignore[reportMissingImports]
9+
from pydantic import BaseModel, Field # type: ignore[reportMissingImports]
1010

1111
from .models import EvaluateResult
1212

@@ -254,7 +254,7 @@ async def reward(request_data: RewardRequest):
254254
return result.model_dump()
255255
elif isinstance(result, tuple) and len(result) == 2: # Legacy tuple
256256
logger.warning("Reward function passed to create_app returned legacy tuple format.")
257-
score, components = result
257+
score, components = cast(Tuple[float, Dict[str, Any]], result)
258258
return {"score": score, "metrics": components}
259259
else:
260260
raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")

eval_protocol/utils/static_policy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any, Dict, List, Optional, Tuple, Union
2020

2121
# Import the base policy and types for proper recording functionality
22-
from openai.types import CompletionUsage
22+
from typing import Optional as _Optional
2323

2424
from ..playback_policy import PlaybackPolicyBase
2525
from ..types import MCPToolCall
@@ -73,7 +73,7 @@ async def _generate_live_tool_calls(
7373
tool_schemas: List[Dict],
7474
env_index: int,
7575
conversation_history: List[Dict[str, Any]],
76-
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
76+
) -> Tuple[List[MCPToolCall], Optional[Dict[str, int]], Optional[str]]:
7777
"""
7878
Generate tool calls in live mode using the static action sequence.
7979
@@ -105,7 +105,7 @@ async def _generate_live_tool_calls(
105105

106106
logger.debug(f"🎮 Env {env_index} step {step_count}: {action}")
107107

108-
usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
108+
usage_stats: Optional[Dict[str, int]] = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
109109
return [tool_call], usage_stats, None
110110

111111
def add_tool_response(
@@ -116,7 +116,7 @@ def add_tool_response(
116116
conversation_history: List[Dict[str, Any]],
117117
reward: float = 0.0,
118118
terminated: bool = False,
119-
info: Dict[str, Any] = None,
119+
info: Optional[Dict[str, Any]] = None,
120120
):
121121
"""Add tool call and response to conversation history for recording."""
122122

@@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
220220
tool_schemas: List[Dict],
221221
env_index: int,
222222
conversation_history: List[Dict[str, Any]],
223-
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
223+
) -> Tuple[List[MCPToolCall], Optional[Dict[str, int]], Optional[str]]:
224224
"""
225225
Generate random tool calls in live mode.
226226
@@ -240,7 +240,7 @@ async def _generate_live_tool_calls(
240240

241241
logger.debug(f"🎲 Env {env_index}: {action}")
242242

243-
usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
243+
usage_stats: Optional[Dict[str, int]] = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
244244
return [tool_call], usage_stats, None
245245

246246
def add_tool_response(
@@ -251,7 +251,7 @@ def add_tool_response(
251251
conversation_history: List[Dict[str, Any]],
252252
reward: float = 0.0,
253253
terminated: bool = False,
254-
info: Dict[str, Any] = None,
254+
info: Optional[Dict[str, Any]] = None,
255255
):
256256
"""Add tool call and response to conversation history for recording."""
257257

0 commit comments

Comments
 (0)