Skip to content

Commit 56ea7ec

Browse files
benjibc and Benny Chen authored
pyright round 9 (#151)
Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent dcf7b0e commit 56ea7ec

File tree

12 files changed, with 49 additions and 25 deletions

12 files changed

+49
-25
lines changed

eval_protocol/benchmarks/test_tau_bench_airline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
182182
trajectory_objects.append(UserMessage(role=role, content=text_content))
183183
elif role == "tool":
184184
tool_id = msg.tool_call_id
185-
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content, requestor="assistant"))
185+
trajectory_objects.append(
186+
ToolMessage(id=tool_id or "unknown_tool_call", role=role, content=text_content, requestor="assistant")
187+
)
186188

187189
reward = 1.0
188190

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
172172
trajectory_objects.append(UserMessage(role=role, content=text_content))
173173
elif role == "tool":
174174
tool_id = msg.tool_call_id
175-
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content, requestor="assistant"))
175+
trajectory_objects.append(
176+
ToolMessage(id=tool_id or "unknown_tool_call", role=role, content=text_content, requestor="assistant")
177+
)
176178

177179
reward = 1.0
178180

eval_protocol/execution/pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,9 +847,11 @@ async def process_with_semaphore_wrapper(sample_idx: int, sample_data: Dict[str,
847847

848848
for i_outer in range(0, len(tasks), batch_size_for_logging):
849849
batch_tasks = tasks[i_outer : i_outer + batch_size_for_logging]
850-
batch_results_values: List[
851-
Union[Exception, Dict[str, Any], List[Dict[str, Any]]]
852-
] = await asyncio.gather(*batch_tasks, return_exceptions=True)
850+
# asyncio.gather with return_exceptions=True returns List[Any]; cast to expected union
851+
batch_results_values = cast(
852+
List[Union[Exception, Dict[str, Any], List[Dict[str, Any]]]],
853+
await asyncio.gather(*batch_tasks, return_exceptions=True),
854+
)
853855
for res_idx, res_or_exc in enumerate(batch_results_values):
854856
if isinstance(res_or_exc, Exception):
855857
logger.error(

eval_protocol/mcp/simulation_server.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def reset_environment(self, env, seed): ...
3030
from abc import ABC, abstractmethod
3131
from collections.abc import AsyncIterator
3232
from contextlib import asynccontextmanager
33-
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable
33+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable, cast
3434
from pydantic import AnyUrl
3535

3636
import uvicorn
@@ -327,12 +327,12 @@ async def list_resources():
327327
# Extract docstring as description
328328
description = resource_func.__doc__ or f"Resource {resource_name}"
329329

330-
# Some callables may not have the attribute; guard for type checkers
331-
# MyPy/Pyright: Resource expects AnyUrl; convert string to str, letting pydantic coerce it
332-
uri_value = getattr(resource_func, "_resource_uri", f"/{resource_name}")
330+
# Some callables may not have the attribute; guard for type checkers.
331+
# Resource expects AnyUrl; pass as str and allow coercion by pydantic.
332+
uri_value: str = str(getattr(resource_func, "_resource_uri", f"/{resource_name}"))
333333
resources.append(
334334
Resource(
335-
uri=uri_value,
335+
uri=cast(AnyUrl, uri_value),
336336
name=resource_name,
337337
description=description,
338338
mimeType="application/json",
@@ -347,10 +347,15 @@ def _register_session_handlers(self):
347347
"""Register session initialization and cleanup handlers."""
348348

349349
@self.app.set_logging_level()
350-
async def set_logging_level(level: str):
350+
async def set_logging_level(level: str) -> None:
351351
"""Handle logging level requests."""
352-
logger.setLevel(getattr(logging, level.upper()))
353-
return {}
352+
# Validate and set logging level; ignore invalid values gracefully
353+
try:
354+
numeric_level = getattr(logging, level.upper())
355+
if isinstance(numeric_level, int):
356+
logger.setLevel(numeric_level)
357+
except Exception:
358+
pass
354359

355360
# NOTE: The low-level Server doesn't have built-in session lifecycle hooks
356361
# We'll need to capture client_info during the first request in each session

eval_protocol/mcp_env.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,8 @@ async def rollout(
315315
)
316316

317317
# Await all tasks and return concrete EvaluationRows
318-
results: List[EvaluationRow] = await asyncio.gather(*tasks)
318+
# Gather returns list of EvaluationRow; use type ignore to appease Pyright when inferring coroutine types
319+
results: List[EvaluationRow] = await asyncio.gather(*tasks) # type: ignore[reportUnknownArgumentType]
319320
return results
320321

321322

@@ -343,7 +344,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
343344
policy = FireworksPolicy("test-model")
344345

345346
# Run short rollout
346-
evaluation_rows = rollout(envs, policy=policy, steps=10)
347+
evaluation_rows = await rollout(envs, policy=policy, steps=10)
347348

348349
if evaluation_rows and len(evaluation_rows[0].messages) > 1:
349350
results["successful"] += 1

eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, A
3939
"""Reset the environment to initial state"""
4040
logger.info("🔄 Resetting airline environment - reloading database from disk")
4141
# FlightDB.load expects a str path
42-
self.db = FlightDB.load(str(AIRLINE_DB_PATH))
42+
# Ensure type matches expected FlightDB
43+
# FlightDB.load returns vendor.tau2.domains.airline.data_model.FlightDB which is compatible
44+
db_loaded = FlightDB.load(str(AIRLINE_DB_PATH))
45+
assert isinstance(db_loaded, FlightDB)
46+
self.db = db_loaded
4347
self.airline_tools = AirlineTools(self.db)
4448

4549
return {}, {}

eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class MockEnvironment:
3232
def __init__(self, config: Optional[Dict[str, Any]] = None):
3333
self.config = config or {}
3434
# MockDB.load expects a str path
35-
self.db = MockDB.load(str(MOCK_DB_PATH))
35+
db_loaded = MockDB.load(str(MOCK_DB_PATH))
36+
assert isinstance(db_loaded, MockDB)
37+
self.db = db_loaded
3638
self.mock_tools = MockTools(self.db)
3739

3840
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:

eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):
3737
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
3838
"""Reset the environment to initial state"""
3939
# RetailDB.load expects a str path
40-
self.db = RetailDB.load(str(RETAIL_DB_PATH))
40+
db_loaded = RetailDB.load(str(RETAIL_DB_PATH))
41+
assert isinstance(db_loaded, RetailDB)
42+
self.db = db_loaded
4143
self.retail_tools = RetailTools(self.db)
4244

4345
return {}, {}

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
243243
)
244244

245245
# Create MCP environments directly from evaluation_rows
246+
assert self.policy is not None, "Policy must be initialized before rollout"
246247
envs = ep.make(
247248
"http://localhost:9700/mcp/",
248249
evaluation_rows=rows,
@@ -252,6 +253,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
252253
# Get rollout tasks from ep.rollout
253254
async def _run_rollout_and_wrap(row_index: int) -> EvaluationRow:
254255
# ep.rollout now returns concrete results
256+
assert self.policy is not None, "Policy must be initialized before rollout"
255257
results = await ep.rollout(
256258
envs,
257259
policy=self.policy,

eval_protocol/rewards/apps_testing_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
255255
print(f"get method = {datetime.now().time()}")
256256

257257
try:
258-
method = getattr(tmp, method_name)
258+
# Ensure attribute name is a string for getattr
259+
method_name_str = str(method_name)
260+
method = getattr(tmp, method_name_str)
259261
except AttributeError: # More specific exception
260262
signal.alarm(0)
261263
error_traceback = traceback.format_exc()

0 commit comments

Comments
 (0)