Skip to content

Commit cdf92b5

Browse files
Benny Chen
authored and committed
type fix round 7
1 parent 56c7bdd commit cdf92b5

File tree

7 files changed

+105
-38
lines changed

7 files changed

+105
-38
lines changed

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,9 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
424424

425425
@evaluation_test(
426426
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
427-
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
428-
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
427+
# Provide a flat list per run (Sequence[InputMessagesParam]) to match signature
428+
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
429+
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
429430
rollout_processor=SingleTurnRolloutProcessor(),
430431
aggregation_method="mean",
431432
passed_threshold=None,
@@ -467,8 +468,8 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
467468

468469
@evaluation_test(
469470
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
470-
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
471-
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
471+
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
472+
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
472473
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
473474
aggregation_method="mean",
474475
passed_threshold=None,
@@ -511,7 +512,7 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
511512
@evaluation_test(
512513
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
513514
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
514-
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
515+
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
515516
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS),
516517
aggregation_method="mean",
517518
passed_threshold=None,

eval_protocol/datasets/loader.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def load_and_process_dataset(
9797
# preprocessing_steps: Optional[List[str]] = None, # To be implemented
9898
hf_extra_load_params: Optional[Dict[str, Any]] = None,
9999
**kwargs: Any, # Catch-all for other params
100-
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
100+
) -> Union[Dataset, DatasetDict]:
101101
"""
102102
Loads a dataset from the specified source.
103103
@@ -116,7 +116,8 @@ def load_and_process_dataset(
116116
Returns:
117117
Loaded dataset, typically as Hugging Face Dataset or DatasetDict.
118118
"""
119-
loaded_dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]]
119+
# Hugging Face load_dataset always returns Dataset or DatasetDict in our supported modes
120+
loaded_dataset: Union[Dataset, DatasetDict]
120121

121122
# Prepare kwargs for datasets.load_dataset, separating out custom ones
122123
load_kwargs_for_hf = hf_extra_load_params.copy() if hf_extra_load_params else {}
@@ -238,9 +239,6 @@ def load_and_process_dataset(
238239
for s_name in loaded_dataset.keys():
239240
if len(loaded_dataset[s_name]) > max_samples:
240241
loaded_dataset[s_name] = loaded_dataset[s_name].select(range(max_samples))
241-
elif isinstance(loaded_dataset, list): # Should not happen if always converting to HF Dataset
242-
if len(loaded_dataset) > max_samples:
243-
loaded_dataset = loaded_dataset[:max_samples]
244242

245243
# Apply column mapping if provided
246244
if column_mapping_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):

eval_protocol/integrations/braintrust.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Adapters for integrating Eval Protocol with Braintrust scoring functions."""
22

3-
from typing import Any, Callable, List, Optional
3+
from typing import Any, Callable, List, Optional, cast
44

55
from eval_protocol.models import EvaluateResult, Message
66
from eval_protocol.typed_interface import reward_function
@@ -17,8 +17,7 @@ def scorer_to_reward_fn(
1717
) -> Callable[[List[Message], Optional[List[Message]]], EvaluateResult]:
1818
"""Wrap a Braintrust scorer as an Eval Protocol reward function."""
1919

20-
@reward_function
21-
def reward_fn(
20+
def reward_fn_core(
2221
messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any
2322
) -> EvaluateResult:
2423
input_val = messages_to_input(messages) if messages_to_input else messages[0].content
@@ -29,9 +28,11 @@ def reward_fn(
2928
ground_truth_to_expected(ground_truth) if ground_truth_to_expected else ground_truth[-1].content
3029
)
3130
score = scorer(input_val, output_val, expected_val)
32-
return EvaluateResult(score=score)
31+
return EvaluateResult(score=float(score))
3332

34-
return reward_fn
33+
# Wrap with reward_function decorator while preserving precise callable type for type checker
34+
wrapped = reward_function(reward_fn_core)
35+
return cast(Callable[[List[Message], Optional[List[Message]]], EvaluateResult], wrapped)
3536

3637

3738
def reward_fn_to_scorer(
@@ -48,6 +49,6 @@ def scorer(input_val: Any, output: Any, expected: Any) -> float:
4849
if expected is not None:
4950
ground_truth = [Message(role="assistant", content=str(expected))]
5051
result = reward_fn(messages=messages, ground_truth=ground_truth)
51-
return result.score
52+
return float(result.score)
5253

5354
return scorer

eval_protocol/mcp/client/connection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,24 +441,25 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
441441
# Extract data plane results (observation only)
442442
if tool_result.content and len(tool_result.content) > 0:
443443
content = tool_result.content[0]
444-
if hasattr(content, "text"):
444+
text_value = getattr(content, "text", None)
445+
if isinstance(text_value, str):
445446
# Fix: Handle empty or invalid JSON responses gracefully
446-
if not content.text or content.text.strip() == "":
447+
if text_value.strip() == "":
447448
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
448449
observation = {
449450
"observation": "empty_response",
450451
"session_id": session.session_id,
451452
}
452453
else:
453454
try:
454-
observation = json.loads(content.text)
455+
observation = json.loads(text_value)
455456
except json.JSONDecodeError as e:
456457
logger.warning(
457-
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
458+
f"Session {session.session_id}: Invalid JSON from {tool_name}: {text_value}. Error: {e}"
458459
)
459460
# Create a structured response from the raw text
460461
observation = {
461-
"observation": content.text,
462+
"observation": text_value,
462463
"session_id": session.session_id,
463464
"error": "invalid_json_response",
464465
}

eval_protocol/mcp/execution/manager.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import threading
1313
import time
1414
from dataclasses import asdict
15-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
15+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
1616

1717
import anyio
1818
from openai.types import CompletionUsage
@@ -126,7 +126,15 @@ async def _execute_with_semaphore(idx):
126126

127127
evaluation_row.messages = messages
128128
evaluation_row.tools = shared_tool_schema
129-
evaluation_row.usage = CompletionUsage(**trajectory.usage)
129+
# Some OpenAI SDK versions type CompletionUsage as a TypedDict; construct via cast to avoid ctor mismatches
130+
evaluation_row.usage = cast(
131+
CompletionUsage,
132+
{
133+
"prompt_tokens": trajectory.usage.get("prompt_tokens", 0),
134+
"completion_tokens": trajectory.usage.get("completion_tokens", 0),
135+
"total_tokens": trajectory.usage.get("total_tokens", 0),
136+
},
137+
)
130138
evaluation_row.input_metadata.completion_params = {
131139
"model": policy.model_id,
132140
"temperature": getattr(policy, "temperature", None),
@@ -138,8 +146,14 @@ async def _execute_with_semaphore(idx):
138146
extra_info = None
139147
if trajectory.control_plane_summary.get("error_message"):
140148
extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")}
149+
# Convert string termination reason to TerminationReason enum if needed
150+
term_reason = (
151+
trajectory.termination_reason
152+
if isinstance(trajectory.termination_reason, TerminationReason)
153+
else TerminationReason.from_str(str(trajectory.termination_reason))
154+
)
141155
evaluation_row.rollout_status = Status.rollout_finished(
142-
termination_reason=trajectory.termination_reason, extra_info=extra_info
156+
termination_reason=term_reason, extra_info=extra_info
143157
)
144158
else:
145159
evaluation_row.rollout_status = Status.rollout_running()
@@ -231,8 +245,9 @@ def extract_text_content(msg_dict):
231245

232246
# Get initial messages in tau2-bench format for user simulator
233247
user_simulator_state = user_simulator.get_init_state()
248+
# Generate initial user response by prompting the simulator with a user role message
234249
user_message, user_simulator_state = await user_simulator.generate_next_message(
235-
AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
250+
UserMessage(role="user", content=""),
236251
user_simulator_state,
237252
)
238253
current_observation = user_message.content if user_message.content else ""
@@ -264,8 +279,11 @@ def extract_text_content(msg_dict):
264279
# Last message was agent, simulated user response
265280
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
266281
# Generate user response using the simulator
282+
# Pass the assistant message content to drive the simulated user's next response
283+
last_assistant = user_simulator_messages[-1]
267284
user_message, user_simulator_state = await user_simulator.generate_next_message(
268-
user_simulator_messages[-1], user_simulator_state
285+
last_assistant,
286+
user_simulator_state,
269287
)
270288
user_content = user_message.content if user_message.content else ""
271289

@@ -285,11 +303,33 @@ def extract_text_content(msg_dict):
285303
)
286304
update_evaluation_row_messages()
287305

288-
# calc llm usage stats happened in this turn if there is aany
306+
# Update LLM usage stats if available; support both dict-like and attribute access
289307
if usage_stats:
290-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
291-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
292-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
308+
try:
309+
prompt_tokens = (
310+
usage_stats.get("prompt_tokens")
311+
if isinstance(usage_stats, dict)
312+
else usage_stats.prompt_tokens
313+
)
314+
completion_tokens = (
315+
usage_stats.get("completion_tokens")
316+
if isinstance(usage_stats, dict)
317+
else usage_stats.completion_tokens
318+
)
319+
total_tokens = (
320+
usage_stats.get("total_tokens")
321+
if isinstance(usage_stats, dict)
322+
else usage_stats.total_tokens
323+
)
324+
if isinstance(prompt_tokens, int):
325+
trajectory.usage["prompt_tokens"] += prompt_tokens
326+
if isinstance(completion_tokens, int):
327+
trajectory.usage["completion_tokens"] += completion_tokens
328+
if isinstance(total_tokens, int):
329+
trajectory.usage["total_tokens"] += total_tokens
330+
except Exception:
331+
# Best-effort; ignore malformed usage stats
332+
pass
293333

294334
# If no tool call is generated, turn is finished
295335
if len(tool_calls) == 1:
@@ -300,7 +340,8 @@ def extract_text_content(msg_dict):
300340
# If there's no user simulator, then it marks the end of the episode as LLM think there is no tool call needed.
301341
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
302342
trajectory.terminated = True
303-
trajectory.termination_reason = TerminationReason.from_str(finish_reason)
343+
# Ensure finish_reason is a string before converting
344+
trajectory.termination_reason = TerminationReason.from_str(str(finish_reason))
304345
break
305346

306347
# Execute each tool call sequentially
@@ -404,11 +445,32 @@ def extract_text_content(msg_dict):
404445
)
405446
update_evaluation_row_messages()
406447
if usage_stats:
407-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
408-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
409-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
448+
try:
449+
prompt_tokens = (
450+
usage_stats.get("prompt_tokens")
451+
if isinstance(usage_stats, dict)
452+
else usage_stats.prompt_tokens
453+
)
454+
completion_tokens = (
455+
usage_stats.get("completion_tokens")
456+
if isinstance(usage_stats, dict)
457+
else usage_stats.completion_tokens
458+
)
459+
total_tokens = (
460+
usage_stats.get("total_tokens")
461+
if isinstance(usage_stats, dict)
462+
else usage_stats.total_tokens
463+
)
464+
if isinstance(prompt_tokens, int):
465+
trajectory.usage["prompt_tokens"] += prompt_tokens
466+
if isinstance(completion_tokens, int):
467+
trajectory.usage["completion_tokens"] += completion_tokens
468+
if isinstance(total_tokens, int):
469+
trajectory.usage["total_tokens"] += total_tokens
470+
except Exception:
471+
pass
410472
trajectory.terminated = True
411-
trajectory.termination_reason = TerminationReason.from_str(finish_reason)
473+
trajectory.termination_reason = TerminationReason.from_str(str(finish_reason))
412474
trajectory.control_plane_summary.update(
413475
{
414476
"total_reward": trajectory.total_reward,

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _wait_for_server_ready(self, timeout: int = 15) -> bool:
107107

108108
while time.time() - start_time < timeout:
109109
# Check if process is still running
110-
if self.process.poll() is not None:
110+
if self.process and self.process.poll() is not None:
111111
print("Server process exited early")
112112
return False
113113

@@ -220,7 +220,9 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
220220
self.server.start()
221221

222222
self.policy = ep.LiteLLMPolicy(
223-
model_id=config.completion_params.get("model", None),
223+
model_id=str(
224+
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
225+
),
224226
temperature=config.completion_params.get("temperature", 0.0),
225227
max_tokens=config.completion_params.get("max_tokens", 4096),
226228
**(config.completion_params.get("extra_body", {}) or {}),

eval_protocol/rewards/function_calling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,11 @@ def exact_tool_match_reward(
394394
try:
395395
ground_truth = json.loads(ground_truth)
396396
except json.JSONDecodeError:
397+
# Cast to string before slicing to satisfy type checker if ground_truth is of unknown type
398+
gt_preview = str(ground_truth)
397399
return EvaluateResult(
398400
score=0.0,
399-
reason=f"Ground truth was a string but failed to parse as JSON: {ground_truth[:100]}...",
401+
reason=f"Ground truth was a string but failed to parse as JSON: {gt_preview[:100]}...",
400402
metrics={},
401403
)
402404

0 commit comments

Comments
 (0)