diff --git a/eval_protocol/benchmarks/test_tau_bench_airline.py b/eval_protocol/benchmarks/test_tau_bench_airline.py index d9701e5e..77cfec0c 100644 --- a/eval_protocol/benchmarks/test_tau_bench_airline.py +++ b/eval_protocol/benchmarks/test_tau_bench_airline.py @@ -182,7 +182,9 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow: trajectory_objects.append(UserMessage(role=role, content=text_content)) elif role == "tool": tool_id = msg.tool_call_id - trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content, requestor="assistant")) + trajectory_objects.append( + ToolMessage(id=tool_id or "unknown_tool_call", role=role, content=text_content, requestor="assistant") + ) reward = 1.0 diff --git a/eval_protocol/benchmarks/test_tau_bench_retail.py b/eval_protocol/benchmarks/test_tau_bench_retail.py index 6ca8c040..68ec8430 100644 --- a/eval_protocol/benchmarks/test_tau_bench_retail.py +++ b/eval_protocol/benchmarks/test_tau_bench_retail.py @@ -172,7 +172,9 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow: trajectory_objects.append(UserMessage(role=role, content=text_content)) elif role == "tool": tool_id = msg.tool_call_id - trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content, requestor="assistant")) + trajectory_objects.append( + ToolMessage(id=tool_id or "unknown_tool_call", role=role, content=text_content, requestor="assistant") + ) reward = 1.0 diff --git a/eval_protocol/execution/pipeline.py b/eval_protocol/execution/pipeline.py index dcda4ca0..b61b43b8 100644 --- a/eval_protocol/execution/pipeline.py +++ b/eval_protocol/execution/pipeline.py @@ -847,9 +847,11 @@ async def process_with_semaphore_wrapper(sample_idx: int, sample_data: Dict[str, for i_outer in range(0, len(tasks), batch_size_for_logging): batch_tasks = tasks[i_outer : i_outer + batch_size_for_logging] - batch_results_values: List[ - Union[Exception, Dict[str, Any], List[Dict[str, Any]]] - ] = await asyncio.gather(*batch_tasks, return_exceptions=True) + # asyncio.gather with return_exceptions=True returns List[Any]; cast to expected union + batch_results_values = cast( + List[Union[Exception, Dict[str, Any], List[Dict[str, Any]]]], + await asyncio.gather(*batch_tasks, return_exceptions=True), + ) for res_idx, res_or_exc in enumerate(batch_results_values): if isinstance(res_or_exc, Exception): logger.error( diff --git a/eval_protocol/mcp/simulation_server.py b/eval_protocol/mcp/simulation_server.py index 205e7a27..425ab32c 100644 --- a/eval_protocol/mcp/simulation_server.py +++ b/eval_protocol/mcp/simulation_server.py @@ -30,7 +30,7 @@ def reset_environment(self, env, seed): ... from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable, cast from pydantic import AnyUrl import uvicorn @@ -327,12 +327,12 @@ async def list_resources(): # Extract docstring as description description = resource_func.__doc__ or f"Resource {resource_name}" - # Some callables may not have the attribute; guard for type checkers - # MyPy/Pyright: Resource expects AnyUrl; convert string to str, letting pydantic coerce it - uri_value = getattr(resource_func, "_resource_uri", f"/{resource_name}") + # Some callables may not have the attribute; guard for type checkers. + # Resource expects AnyUrl; pass as str and allow coercion by pydantic. + uri_value: str = str(getattr(resource_func, "_resource_uri", f"/{resource_name}")) resources.append( Resource( - uri=uri_value, + uri=cast(AnyUrl, uri_value), name=resource_name, description=description, mimeType="application/json", @@ -347,10 +347,15 @@ def _register_session_handlers(self): """Register session initialization and cleanup handlers.""" @self.app.set_logging_level() - async def set_logging_level(level: str): + async def set_logging_level(level: str) -> None: """Handle logging level requests.""" - logger.setLevel(getattr(logging, level.upper())) - return {} + # Validate and set logging level; ignore invalid values gracefully + try: + numeric_level = getattr(logging, level.upper()) + if isinstance(numeric_level, int): + logger.setLevel(numeric_level) + except Exception: + pass # NOTE: The low-level Server doesn't have built-in session lifecycle hooks # We'll need to capture client_info during the first request in each session diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index 779f7503..35ad517b 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -315,7 +315,8 @@ async def rollout( ) # Await all tasks and return concrete EvaluationRows - results: List[EvaluationRow] = await asyncio.gather(*tasks) + # Gather returns list of EvaluationRow; use type ignore to appease Pyright when inferring coroutine types + results: List[EvaluationRow] = await asyncio.gather(*tasks) # type: ignore[reportUnknownArgumentType] return results @@ -343,7 +344,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]: policy = FireworksPolicy("test-model") # Run short rollout - evaluation_rows = rollout(envs, policy=policy, steps=10) + evaluation_rows = await rollout(envs, policy=policy, steps=10) if evaluation_rows and len(evaluation_rows[0].messages) > 1: results["successful"] += 1 diff --git a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py index 9be30adf..1e9b00dd 100644 --- a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py +++ b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py @@ -39,7 +39,11 @@ def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, A """Reset the environment to initial state""" logger.info("🔄 Resetting airline environment - reloading database from disk") # FlightDB.load expects a str path - self.db = FlightDB.load(str(AIRLINE_DB_PATH)) + # Ensure type matches expected FlightDB + # FlightDB.load returns vendor.tau2.domains.airline.data_model.FlightDB which is compatible + db_loaded = FlightDB.load(str(AIRLINE_DB_PATH)) + assert isinstance(db_loaded, FlightDB) + self.db = db_loaded self.airline_tools = AirlineTools(self.db) return {}, {} diff --git a/eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py b/eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py index e4a73fdf..87911964 100644 --- a/eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py +++ b/eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py @@ -32,7 +32,9 @@ class MockEnvironment: def __init__(self, config: Optional[Dict[str, Any]] = None): self.config = config or {} # MockDB.load expects a str path - self.db = MockDB.load(str(MOCK_DB_PATH)) + db_loaded = MockDB.load(str(MOCK_DB_PATH)) + assert isinstance(db_loaded, MockDB) + self.db = db_loaded self.mock_tools = MockTools(self.db) def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]: diff --git a/eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py b/eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py index 91c364ad..8f099fd9 100644 --- a/eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py +++ b/eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py @@ -37,7 +37,9 @@ def __init__(self, config: Optional[Dict[str, Any]] = None): def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Reset the environment to initial state""" # RetailDB.load expects a str path - self.db = RetailDB.load(str(RETAIL_DB_PATH)) + db_loaded = RetailDB.load(str(RETAIL_DB_PATH)) + assert isinstance(db_loaded, RetailDB) + self.db = db_loaded self.retail_tools = RetailTools(self.db) return {}, {} diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 21b1e99c..f875ee55 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -243,6 +243,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> ) # Create MCP environments directly from evaluation_rows + assert self.policy is not None, "Policy must be initialized before rollout" envs = ep.make( "http://localhost:9700/mcp/", evaluation_rows=rows, @@ -252,6 +253,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> # Get rollout tasks from ep.rollout async def _run_rollout_and_wrap(row_index: int) -> EvaluationRow: # ep.rollout now returns concrete results + assert self.policy is not None, "Policy must be initialized before rollout" results = await ep.rollout( envs, policy=self.policy, diff --git a/eval_protocol/rewards/apps_testing_util.py b/eval_protocol/rewards/apps_testing_util.py index 27f52cc5..e4b931f9 100644 --- a/eval_protocol/rewards/apps_testing_util.py +++ b/eval_protocol/rewards/apps_testing_util.py @@ -255,7 +255,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): print(f"get method = {datetime.now().time()}") try: - method = getattr(tmp, method_name) + # Ensure attribute name is a string for getattr + method_name_str = str(method_name) + method = getattr(tmp, method_name_str) except AttributeError: # More specific exception signal.alarm(0) error_traceback = traceback.format_exc() diff --git a/eval_protocol/rewards/function_calling.py b/eval_protocol/rewards/function_calling.py index aa0517b6..14e8bfe8 100644 --- a/eval_protocol/rewards/function_calling.py +++ b/eval_protocol/rewards/function_calling.py @@ -451,7 +451,7 @@ def schema_jaccard_reward( DeprecationWarning, stacklevel=2, ) - return exact_tool_match_reward(messages, ground_truth, **kwargs) + return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) @reward_function @@ -493,7 +493,7 @@ def llm_judge_reward( DeprecationWarning, stacklevel=2, ) - return exact_tool_match_reward(messages, ground_truth, **kwargs) + return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) @reward_function @@ -537,7 +537,7 @@ def composite_function_call_reward( DeprecationWarning, stacklevel=2, ) - return exact_tool_match_reward(messages, ground_truth, **kwargs) + return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) # JSON schema reward functions have been moved to json_schema.py module diff --git a/eval_protocol/rewards/json_schema.py b/eval_protocol/rewards/json_schema.py index e48c1160..9c446e67 100644 --- a/eval_protocol/rewards/json_schema.py +++ b/eval_protocol/rewards/json_schema.py @@ -60,7 +60,7 @@ def json_schema_reward( else: try: parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment] - content_text = "\n".join(p.text for p in parts) + content_text = "\n".join(getattr(p, "text", "") for p in parts) except Exception: content_text = "" else: @@ -290,8 +290,8 @@ def json_schema_reward_with_llm_judge( normalized_weights = {k: v / total_weight for k, v in weights.items()} schema_result = json_schema_reward( - messages, - ground_truth, + messages=messages, + ground_truth=ground_truth, json_content=json_content, expected_schema=expected_schema, **kwargs,