Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion eval_protocol/benchmarks/test_tau_bench_airline.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
trajectory_objects.append(UserMessage(role=role, content=text_content))
elif role == "tool":
tool_id = msg.tool_call_id
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content, requestor="assistant"))
trajectory_objects.append(
ToolMessage(id=tool_id or "unknown_tool_call", role=role, content=text_content, requestor="assistant")
)

reward = 1.0

Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/benchmarks/test_tau_bench_retail.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
trajectory_objects.append(UserMessage(role=role, content=text_content))
elif role == "tool":
tool_id = msg.tool_call_id
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content, requestor="assistant"))
trajectory_objects.append(
ToolMessage(id=tool_id or "unknown_tool_call", role=role, content=text_content, requestor="assistant")
)

reward = 1.0

Expand Down
8 changes: 5 additions & 3 deletions eval_protocol/execution/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,9 +847,11 @@ async def process_with_semaphore_wrapper(sample_idx: int, sample_data: Dict[str,

for i_outer in range(0, len(tasks), batch_size_for_logging):
batch_tasks = tasks[i_outer : i_outer + batch_size_for_logging]
batch_results_values: List[
Union[Exception, Dict[str, Any], List[Dict[str, Any]]]
] = await asyncio.gather(*batch_tasks, return_exceptions=True)
# asyncio.gather with return_exceptions=True is typed as returning List[Any]; cast to the expected union
batch_results_values = cast(
List[Union[Exception, Dict[str, Any], List[Dict[str, Any]]]],
await asyncio.gather(*batch_tasks, return_exceptions=True),
)
for res_idx, res_or_exc in enumerate(batch_results_values):
if isinstance(res_or_exc, Exception):
logger.error(
Expand Down
21 changes: 13 additions & 8 deletions eval_protocol/mcp/simulation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def reset_environment(self, env, seed): ...
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable, cast
from pydantic import AnyUrl

import uvicorn
Expand Down Expand Up @@ -327,12 +327,12 @@ async def list_resources():
# Extract docstring as description
description = resource_func.__doc__ or f"Resource {resource_name}"

# Some callables may not have the attribute; guard for type checkers
# MyPy/Pyright: Resource expects AnyUrl; convert string to str, letting pydantic coerce it
uri_value = getattr(resource_func, "_resource_uri", f"/{resource_name}")
# Some callables may not have the attribute; guard for type checkers.
# Resource expects AnyUrl; pass as str and allow coercion by pydantic.
uri_value: str = str(getattr(resource_func, "_resource_uri", f"/{resource_name}"))
resources.append(
Resource(
uri=uri_value,
uri=cast(AnyUrl, uri_value),
name=resource_name,
description=description,
mimeType="application/json",
Expand All @@ -347,10 +347,15 @@ def _register_session_handlers(self):
"""Register session initialization and cleanup handlers."""

@self.app.set_logging_level()
async def set_logging_level(level: str):
async def set_logging_level(level: str) -> None:
"""Handle logging level requests."""
logger.setLevel(getattr(logging, level.upper()))
return {}
# Validate and set logging level; ignore invalid values gracefully
try:
numeric_level = getattr(logging, level.upper())
if isinstance(numeric_level, int):
logger.setLevel(numeric_level)
except Exception:
pass

# NOTE: The low-level Server doesn't have built-in session lifecycle hooks
# We'll need to capture client_info during the first request in each session
Expand Down
5 changes: 3 additions & 2 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ async def rollout(
)

# Await all tasks and return concrete EvaluationRows
results: List[EvaluationRow] = await asyncio.gather(*tasks)
# asyncio.gather returns a list of EvaluationRow here; the type-ignore appeases Pyright when it infers the coroutine result types
results: List[EvaluationRow] = await asyncio.gather(*tasks) # type: ignore[reportUnknownArgumentType]
return results


Expand Down Expand Up @@ -343,7 +344,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
policy = FireworksPolicy("test-model")

# Run short rollout
evaluation_rows = rollout(envs, policy=policy, steps=10)
evaluation_rows = await rollout(envs, policy=policy, steps=10)

if evaluation_rows and len(evaluation_rows[0].messages) > 1:
results["successful"] += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, A
"""Reset the environment to initial state"""
logger.info("🔄 Resetting airline environment - reloading database from disk")
# FlightDB.load expects a str path
self.db = FlightDB.load(str(AIRLINE_DB_PATH))
# Ensure the loaded object matches the expected FlightDB type.
# FlightDB.load returns vendor.tau2.domains.airline.data_model.FlightDB, which is compatible.
db_loaded = FlightDB.load(str(AIRLINE_DB_PATH))
assert isinstance(db_loaded, FlightDB)
self.db = db_loaded
self.airline_tools = AirlineTools(self.db)

return {}, {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class MockEnvironment:
def __init__(self, config: Optional[Dict[str, Any]] = None):
self.config = config or {}
# MockDB.load expects a str path
self.db = MockDB.load(str(MOCK_DB_PATH))
db_loaded = MockDB.load(str(MOCK_DB_PATH))
assert isinstance(db_loaded, MockDB)
self.db = db_loaded
self.mock_tools = MockTools(self.db)

def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Reset the environment to initial state"""
# RetailDB.load expects a str path
self.db = RetailDB.load(str(RETAIL_DB_PATH))
db_loaded = RetailDB.load(str(RETAIL_DB_PATH))
assert isinstance(db_loaded, RetailDB)
self.db = db_loaded
self.retail_tools = RetailTools(self.db)

return {}, {}
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
)

# Create MCP environments directly from evaluation_rows
assert self.policy is not None, "Policy must be initialized before rollout"
envs = ep.make(
"http://localhost:9700/mcp/",
evaluation_rows=rows,
Expand All @@ -252,6 +253,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
# Get rollout tasks from ep.rollout
async def _run_rollout_and_wrap(row_index: int) -> EvaluationRow:
# ep.rollout now returns concrete results
assert self.policy is not None, "Policy must be initialized before rollout"
results = await ep.rollout(
envs,
policy=self.policy,
Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/rewards/apps_testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
print(f"get method = {datetime.now().time()}")

try:
method = getattr(tmp, method_name)
# Ensure attribute name is a string for getattr
method_name_str = str(method_name)
method = getattr(tmp, method_name_str)
except AttributeError: # More specific exception
signal.alarm(0)
error_traceback = traceback.format_exc()
Expand Down
6 changes: 3 additions & 3 deletions eval_protocol/rewards/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def schema_jaccard_reward(
DeprecationWarning,
stacklevel=2,
)
return exact_tool_match_reward(messages, ground_truth, **kwargs)
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)


@reward_function
Expand Down Expand Up @@ -493,7 +493,7 @@ def llm_judge_reward(
DeprecationWarning,
stacklevel=2,
)
return exact_tool_match_reward(messages, ground_truth, **kwargs)
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)


@reward_function
Expand Down Expand Up @@ -537,7 +537,7 @@ def composite_function_call_reward(
DeprecationWarning,
stacklevel=2,
)
return exact_tool_match_reward(messages, ground_truth, **kwargs)
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)


# JSON schema reward functions have been moved to json_schema.py module
6 changes: 3 additions & 3 deletions eval_protocol/rewards/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def json_schema_reward(
else:
try:
parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment]
content_text = "\n".join(p.text for p in parts)
content_text = "\n".join(getattr(p, "text", "") for p in parts)
except Exception:
content_text = ""
else:
Expand Down Expand Up @@ -290,8 +290,8 @@ def json_schema_reward_with_llm_judge(
normalized_weights = {k: v / total_weight for k, v in weights.items()}

schema_result = json_schema_reward(
messages,
ground_truth,
messages=messages,
ground_truth=ground_truth,
json_content=json_content,
expected_schema=expected_schema,
**kwargs,
Expand Down
Loading