Skip to content

Commit f0cda72

Browse files
benjibc and Benny Chen authored
type fix round 7 (#149)
* type fix round 7
* fix a few more

Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent 56c7bdd commit f0cda72

File tree

14 files changed

+138
-49
lines changed

14 files changed

+138
-49
lines changed

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,9 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
424424

425425
@evaluation_test(
426426
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
427+
# Wrap dataset messages in an extra list to match Sequence[list[InputMessagesParam]]
427428
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
428-
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
429+
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
429430
rollout_processor=SingleTurnRolloutProcessor(),
430431
aggregation_method="mean",
431432
passed_threshold=None,
@@ -468,7 +469,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
468469
@evaluation_test(
469470
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
470471
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
471-
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
472+
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
472473
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
473474
aggregation_method="mean",
474475
passed_threshold=None,
@@ -510,8 +511,8 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
510511

511512
@evaluation_test(
512513
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
513-
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
514-
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
514+
input_messages=[[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS]],
515+
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
515516
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS),
516517
aggregation_method="mean",
517518
passed_threshold=None,

eval_protocol/benchmarks/test_tau_bench_airline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
147147
messages = row.messages
148148

149149
# Get evaluation criteria and user_simulation from input_metadata.dataset_info
150-
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
150+
dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {}
151151
evaluation_criteria = dataset_info.get("evaluation_criteria", {})
152152

153153
nl_assertions = evaluation_criteria.get("nl_assertions", [])

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
137137
messages = row.messages
138138

139139
# Get evaluation criteria and user_simulation from input_metadata.dataset_info
140-
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
140+
dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {}
141141
evaluation_criteria = dataset_info.get("evaluation_criteria", {})
142142

143143
nl_assertions = evaluation_criteria.get("nl_assertions", [])

eval_protocol/datasets/loader.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def load_and_process_dataset(
9797
# preprocessing_steps: Optional[List[str]] = None, # To be implemented
9898
hf_extra_load_params: Optional[Dict[str, Any]] = None,
9999
**kwargs: Any, # Catch-all for other params
100-
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
100+
) -> Union[Dataset, DatasetDict]:
101101
"""
102102
Loads a dataset from the specified source.
103103
@@ -116,7 +116,8 @@ def load_and_process_dataset(
116116
Returns:
117117
Loaded dataset, typically as Hugging Face Dataset or DatasetDict.
118118
"""
119-
loaded_dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]]
119+
# Hugging Face load_dataset always returns Dataset or DatasetDict in our supported modes
120+
loaded_dataset: Union[Dataset, DatasetDict]
120121

121122
# Prepare kwargs for datasets.load_dataset, separating out custom ones
122123
load_kwargs_for_hf = hf_extra_load_params.copy() if hf_extra_load_params else {}
@@ -238,9 +239,6 @@ def load_and_process_dataset(
238239
for s_name in loaded_dataset.keys():
239240
if len(loaded_dataset[s_name]) > max_samples:
240241
loaded_dataset[s_name] = loaded_dataset[s_name].select(range(max_samples))
241-
elif isinstance(loaded_dataset, list): # Should not happen if always converting to HF Dataset
242-
if len(loaded_dataset) > max_samples:
243-
loaded_dataset = loaded_dataset[:max_samples]
244242

245243
# Apply column mapping if provided
246244
if column_mapping_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):

eval_protocol/execution/pipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
8787

8888
try:
8989
backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
90+
assert self.mcp_intermediary_client is not None
9091
init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)
9192

9293
if init_response.get("error"):
@@ -109,6 +110,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
109110
current_instance_id = inst_info_dict.get("instance_id")
110111
if not current_instance_id:
111112
continue
113+
assert self.mcp_intermediary_client is not None
112114
list_tools_result = await self.mcp_intermediary_client.list_backend_tools(
113115
rk_session_id=rk_session_id,
114116
instance_id=current_instance_id,
@@ -130,6 +132,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
130132
if rk_session_id and self.mcp_intermediary_client:
131133
logger.info(f"Sample {sample_id}: Cleaning up tool discovery session '{rk_session_id}'.")
132134
try:
135+
assert self.mcp_intermediary_client is not None
133136
await self.mcp_intermediary_client.cleanup_session(rk_session_id)
134137
except Exception as e_cl:
135138
logger.error(
@@ -276,6 +279,7 @@ async def _execute_mcp_agent_rollout(
276279

277280
try:
278281
backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
282+
assert self.mcp_intermediary_client is not None
279283
init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)
280284
if init_response.get("error"):
281285
raise RuntimeError(
@@ -331,6 +335,7 @@ async def _execute_mcp_agent_rollout(
331335
if not isinstance(tool_args_dict, dict):
332336
raise ValueError("Args not dict")
333337

338+
assert self.mcp_intermediary_client is not None
334339
exec_result = await self.mcp_intermediary_client.call_backend_tool(
335340
rk_session_id=rk_session_id,
336341
instance_id=primary_instance_id_for_agent_actions,
@@ -405,6 +410,7 @@ async def _execute_mcp_agent_rollout(
405410
state_capture_tool = self.cfg.agent.get("state_capture_tool")
406411
if state_capture_tool:
407412
state_capture_args = dict(self.cfg.agent.get("state_capture_args", OmegaConf.create({})))
413+
assert self.mcp_intermediary_client is not None
408414
final_filesystem_state_from_mcp = await self.mcp_intermediary_client.call_backend_tool(
409415
rk_session_id=rk_session_id,
410416
instance_id=primary_instance_id_for_agent_actions,
@@ -432,6 +438,7 @@ async def _execute_mcp_agent_rollout(
432438
}
433439
finally:
434440
if rk_session_id and self.mcp_intermediary_client:
441+
assert self.mcp_intermediary_client is not None
435442
await self.mcp_intermediary_client.cleanup_session(rk_session_id)
436443

437444
async def _process_single_sample(

eval_protocol/integrations/braintrust.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Adapters for integrating Eval Protocol with Braintrust scoring functions."""
22

3-
from typing import Any, Callable, List, Optional
3+
from typing import Any, Callable, List, Optional, cast
44

55
from eval_protocol.models import EvaluateResult, Message
66
from eval_protocol.typed_interface import reward_function
@@ -17,8 +17,7 @@ def scorer_to_reward_fn(
1717
) -> Callable[[List[Message], Optional[List[Message]]], EvaluateResult]:
1818
"""Wrap a Braintrust scorer as an Eval Protocol reward function."""
1919

20-
@reward_function
21-
def reward_fn(
20+
def reward_fn_core(
2221
messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any
2322
) -> EvaluateResult:
2423
input_val = messages_to_input(messages) if messages_to_input else messages[0].content
@@ -29,9 +28,11 @@ def reward_fn(
2928
ground_truth_to_expected(ground_truth) if ground_truth_to_expected else ground_truth[-1].content
3029
)
3130
score = scorer(input_val, output_val, expected_val)
32-
return EvaluateResult(score=score)
31+
return EvaluateResult(score=float(score))
3332

34-
return reward_fn
33+
# Wrap with reward_function decorator while preserving precise callable type for type checker
34+
wrapped = reward_function(reward_fn_core)
35+
return cast(Callable[[List[Message], Optional[List[Message]]], EvaluateResult], wrapped)
3536

3637

3738
def reward_fn_to_scorer(
@@ -47,7 +48,7 @@ def scorer(input_val: Any, output: Any, expected: Any) -> float:
4748
ground_truth = None
4849
if expected is not None:
4950
ground_truth = [Message(role="assistant", content=str(expected))]
50-
result = reward_fn(messages=messages, ground_truth=ground_truth)
51-
return result.score
51+
result = reward_fn(messages, ground_truth)
52+
return float(result.score)
5253

5354
return scorer

eval_protocol/mcp/client/connection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,24 +441,25 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
441441
# Extract data plane results (observation only)
442442
if tool_result.content and len(tool_result.content) > 0:
443443
content = tool_result.content[0]
444-
if hasattr(content, "text"):
444+
text_value = getattr(content, "text", None)
445+
if isinstance(text_value, str):
445446
# Fix: Handle empty or invalid JSON responses gracefully
446-
if not content.text or content.text.strip() == "":
447+
if text_value.strip() == "":
447448
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
448449
observation = {
449450
"observation": "empty_response",
450451
"session_id": session.session_id,
451452
}
452453
else:
453454
try:
454-
observation = json.loads(content.text)
455+
observation = json.loads(text_value)
455456
except json.JSONDecodeError as e:
456457
logger.warning(
457-
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
458+
f"Session {session.session_id}: Invalid JSON from {tool_name}: {text_value}. Error: {e}"
458459
)
459460
# Create a structured response from the raw text
460461
observation = {
461-
"observation": content.text,
462+
"observation": text_value,
462463
"session_id": session.session_id,
463464
"error": "invalid_json_response",
464465
}

0 commit comments

Comments (0)