Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions eval_protocol/benchmarks/test_livebench_data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,9 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:

@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
# Wrap dataset messages in an extra list to match Sequence[list[InputMessagesParam]]
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
Expand Down Expand Up @@ -468,7 +469,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
aggregation_method="mean",
passed_threshold=None,
Expand Down Expand Up @@ -510,8 +511,8 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:

@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
input_messages=[[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS]],
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS),
aggregation_method="mean",
passed_threshold=None,
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/benchmarks/test_tau_bench_airline.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
messages = row.messages

# Get evaluation criteria and user_simulation from input_metadata.dataset_info
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {}
evaluation_criteria = dataset_info.get("evaluation_criteria", {})

nl_assertions = evaluation_criteria.get("nl_assertions", [])
Expand All @@ -169,7 +169,7 @@
if msg.tool_calls:
for tool_call in msg.tool_calls:
arguments = json.loads(tool_call.function.arguments)
tau2_tool_call = ToolCall(

Check failure on line 172 in eval_protocol/benchmarks/test_tau_bench_airline.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "requestor" (reportCallIssue)
id=tool_call.id,
name=tool_call.function.name,
arguments=arguments,
Expand All @@ -181,11 +181,11 @@
trajectory_objects.append(UserMessage(role=role, content=text_content))
elif role == "tool":
tool_id = msg.tool_call_id
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content))

Check failure on line 184 in eval_protocol/benchmarks/test_tau_bench_airline.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "requestor" (reportCallIssue)

reward = 1.0

evaluation_criteria = EvaluationCriteria(

Check failure on line 188 in eval_protocol/benchmarks/test_tau_bench_airline.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "env_assertions" (reportCallIssue)
nl_assertions=nl_assertions,
communicate_info=communicate_info,
actions=actions,
Expand All @@ -195,8 +195,8 @@
],
)

task = Task(

Check failure on line 198 in eval_protocol/benchmarks/test_tau_bench_airline.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Arguments missing for parameters "description", "ticket", "initial_state" (reportCallIssue)
id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")

Check failure on line 199 in eval_protocol/benchmarks/test_tau_bench_airline.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "persona" (reportCallIssue)
) # id and user_scenario are required for the Task type but not used in calculating reward
assert task.evaluation_criteria is not None, "Task evaluation criteria is None"

Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/benchmarks/test_tau_bench_retail.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
messages = row.messages

# Get evaluation criteria and user_simulation from input_metadata.dataset_info
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {}
evaluation_criteria = dataset_info.get("evaluation_criteria", {})

nl_assertions = evaluation_criteria.get("nl_assertions", [])
Expand All @@ -159,7 +159,7 @@
if msg.tool_calls:
for tool_call in msg.tool_calls:
arguments = json.loads(tool_call.function.arguments)
tau2_tool_call = ToolCall(

Check failure on line 162 in eval_protocol/benchmarks/test_tau_bench_retail.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "requestor" (reportCallIssue)
id=tool_call.id,
name=tool_call.function.name,
arguments=arguments,
Expand All @@ -171,11 +171,11 @@
trajectory_objects.append(UserMessage(role=role, content=text_content))
elif role == "tool":
tool_id = msg.tool_call_id
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content))

Check failure on line 174 in eval_protocol/benchmarks/test_tau_bench_retail.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "requestor" (reportCallIssue)

reward = 1.0

evaluation_criteria = EvaluationCriteria(

Check failure on line 178 in eval_protocol/benchmarks/test_tau_bench_retail.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument missing for parameter "env_assertions" (reportCallIssue)
nl_assertions=nl_assertions,
communicate_info=communicate_info,
actions=actions,
Expand Down
8 changes: 3 additions & 5 deletions eval_protocol/datasets/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def load_and_process_dataset(
# preprocessing_steps: Optional[List[str]] = None, # To be implemented
hf_extra_load_params: Optional[Dict[str, Any]] = None,
**kwargs: Any, # Catch-all for other params
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
) -> Union[Dataset, DatasetDict]:
"""
Loads a dataset from the specified source.
Expand All @@ -116,7 +116,8 @@ def load_and_process_dataset(
Returns:
Loaded dataset, typically as Hugging Face Dataset or DatasetDict.
"""
loaded_dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]]
# Hugging Face load_dataset always returns Dataset or DatasetDict in our supported modes
loaded_dataset: Union[Dataset, DatasetDict]

# Prepare kwargs for datasets.load_dataset, separating out custom ones
load_kwargs_for_hf = hf_extra_load_params.copy() if hf_extra_load_params else {}
Expand Down Expand Up @@ -238,9 +239,6 @@ def load_and_process_dataset(
for s_name in loaded_dataset.keys():
if len(loaded_dataset[s_name]) > max_samples:
loaded_dataset[s_name] = loaded_dataset[s_name].select(range(max_samples))
elif isinstance(loaded_dataset, list): # Should not happen if always converting to HF Dataset
if len(loaded_dataset) > max_samples:
loaded_dataset = loaded_dataset[:max_samples]

# Apply column mapping if provided
if column_mapping_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):
Expand Down
7 changes: 7 additions & 0 deletions eval_protocol/execution/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)

try:
backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
assert self.mcp_intermediary_client is not None
init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)

if init_response.get("error"):
Expand All @@ -109,6 +110,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
current_instance_id = inst_info_dict.get("instance_id")
if not current_instance_id:
continue
assert self.mcp_intermediary_client is not None
list_tools_result = await self.mcp_intermediary_client.list_backend_tools(
rk_session_id=rk_session_id,
instance_id=current_instance_id,
Expand All @@ -130,6 +132,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
if rk_session_id and self.mcp_intermediary_client:
logger.info(f"Sample {sample_id}: Cleaning up tool discovery session '{rk_session_id}'.")
try:
assert self.mcp_intermediary_client is not None
await self.mcp_intermediary_client.cleanup_session(rk_session_id)
except Exception as e_cl:
logger.error(
Expand Down Expand Up @@ -276,6 +279,7 @@ async def _execute_mcp_agent_rollout(

try:
backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
assert self.mcp_intermediary_client is not None
init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)
if init_response.get("error"):
raise RuntimeError(
Expand Down Expand Up @@ -331,6 +335,7 @@ async def _execute_mcp_agent_rollout(
if not isinstance(tool_args_dict, dict):
raise ValueError("Args not dict")

assert self.mcp_intermediary_client is not None
exec_result = await self.mcp_intermediary_client.call_backend_tool(
rk_session_id=rk_session_id,
instance_id=primary_instance_id_for_agent_actions,
Expand Down Expand Up @@ -405,6 +410,7 @@ async def _execute_mcp_agent_rollout(
state_capture_tool = self.cfg.agent.get("state_capture_tool")
if state_capture_tool:
state_capture_args = dict(self.cfg.agent.get("state_capture_args", OmegaConf.create({})))
assert self.mcp_intermediary_client is not None
final_filesystem_state_from_mcp = await self.mcp_intermediary_client.call_backend_tool(
rk_session_id=rk_session_id,
instance_id=primary_instance_id_for_agent_actions,
Expand Down Expand Up @@ -432,6 +438,7 @@ async def _execute_mcp_agent_rollout(
}
finally:
if rk_session_id and self.mcp_intermediary_client:
assert self.mcp_intermediary_client is not None
await self.mcp_intermediary_client.cleanup_session(rk_session_id)

async def _process_single_sample(
Expand Down
15 changes: 8 additions & 7 deletions eval_protocol/integrations/braintrust.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Adapters for integrating Eval Protocol with Braintrust scoring functions."""

from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, cast

from eval_protocol.models import EvaluateResult, Message
from eval_protocol.typed_interface import reward_function
Expand All @@ -17,8 +17,7 @@ def scorer_to_reward_fn(
) -> Callable[[List[Message], Optional[List[Message]]], EvaluateResult]:
"""Wrap a Braintrust scorer as an Eval Protocol reward function."""

@reward_function
def reward_fn(
def reward_fn_core(
messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any
) -> EvaluateResult:
input_val = messages_to_input(messages) if messages_to_input else messages[0].content
Expand All @@ -29,9 +28,11 @@ def reward_fn(
ground_truth_to_expected(ground_truth) if ground_truth_to_expected else ground_truth[-1].content
)
score = scorer(input_val, output_val, expected_val)
return EvaluateResult(score=score)
return EvaluateResult(score=float(score))

return reward_fn
# Wrap with reward_function decorator while preserving precise callable type for type checker
wrapped = reward_function(reward_fn_core)
return cast(Callable[[List[Message], Optional[List[Message]]], EvaluateResult], wrapped)


def reward_fn_to_scorer(
Expand All @@ -47,7 +48,7 @@ def scorer(input_val: Any, output: Any, expected: Any) -> float:
ground_truth = None
if expected is not None:
ground_truth = [Message(role="assistant", content=str(expected))]
result = reward_fn(messages=messages, ground_truth=ground_truth)
return result.score
result = reward_fn(messages, ground_truth)
return float(result.score)

return scorer
11 changes: 6 additions & 5 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,24 +441,25 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
# Extract data plane results (observation only)
if tool_result.content and len(tool_result.content) > 0:
content = tool_result.content[0]
if hasattr(content, "text"):
text_value = getattr(content, "text", None)
if isinstance(text_value, str):
# Fix: Handle empty or invalid JSON responses gracefully
if not content.text or content.text.strip() == "":
if text_value.strip() == "":
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
observation = {
"observation": "empty_response",
"session_id": session.session_id,
}
else:
try:
observation = json.loads(content.text)
observation = json.loads(text_value)
except json.JSONDecodeError as e:
logger.warning(
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
f"Session {session.session_id}: Invalid JSON from {tool_name}: {text_value}. Error: {e}"
)
# Create a structured response from the raw text
observation = {
"observation": content.text,
"observation": text_value,
"session_id": session.session_id,
"error": "invalid_json_response",
}
Expand Down
Loading
Loading