Skip to content

Commit 3ec9a06

Browse files
benjibc (Benny Chen)
and co-authors authored
pyright fix round 5 (#144)
Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent 1a5e11e commit 3ec9a06

File tree

13 files changed

+199
-75
lines changed

13 files changed

+199
-75
lines changed

eval_protocol/_version.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def run_command(
121121
if verbose:
122122
print("unable to find command, tried %s" % (commands,))
123123
return None, None
124-
stdout = process.communicate()[0].strip().decode()
124+
stdout_bytes = process.communicate()[0]
125+
stdout_raw = stdout_bytes.decode() if isinstance(stdout_bytes, (bytes, bytearray)) else stdout_bytes
126+
stdout = str(stdout_raw).strip()
125127
if process.returncode != 0:
126128
if verbose:
127129
print("unable to run %s (error)" % dispcmd)

eval_protocol/adapters/bigquery.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,36 @@
77
from __future__ import annotations
88

99
import logging
10-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union
10+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union, cast, TypeAlias
1111

1212
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
1313

1414
logger = logging.getLogger(__name__)
1515

1616
try:
17+
# Import at runtime if available
1718
from google.auth.exceptions import DefaultCredentialsError
18-
from google.cloud import bigquery
19+
from google.cloud import bigquery as _bigquery_runtime # type: ignore
1920
from google.cloud.exceptions import Forbidden, NotFound
2021
from google.oauth2 import service_account
2122

2223
BIGQUERY_AVAILABLE = True
2324
except ImportError:
25+
# Provide fallbacks for type checking/runtime when package is missing
26+
DefaultCredentialsError = Exception # type: ignore[assignment]
27+
Forbidden = Exception # type: ignore[assignment]
28+
NotFound = Exception # type: ignore[assignment]
29+
service_account: Any
30+
service_account = None
31+
_bigquery_runtime = None # type: ignore[assignment]
2432
BIGQUERY_AVAILABLE = False
2533
# Optional dependency: avoid noisy warnings during import
2634
logger.debug("Google Cloud BigQuery not installed. Optional feature disabled.")
2735

28-
# Avoid importing BigQuery types at runtime for annotations when not installed
29-
if TYPE_CHECKING:
30-
from google.cloud import bigquery as _bigquery_type
31-
32-
QueryParameterType = Union[
33-
_bigquery_type.ScalarQueryParameter,
34-
_bigquery_type.ArrayQueryParameter,
35-
]
36-
else:
37-
QueryParameterType = Any
36+
# Simple type aliases to avoid importing optional google types under pyright
37+
QueryParameterType: TypeAlias = Any
38+
BigQueryClient: TypeAlias = Any
39+
QueryJobConfig: TypeAlias = Any
3840

3941
# Type alias for transformation function
4042
TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]]
@@ -98,7 +100,13 @@ def __init__(
98100
client_args["location"] = location
99101

100102
client_args.update(client_kwargs)
101-
self.client = bigquery.Client(**client_args)
103+
# Use runtime alias to avoid basedpyright import symbol error when lib is missing
104+
if _bigquery_runtime is None:
105+
raise ImportError(
106+
"google-cloud-bigquery is not installed. Install with: pip install 'eval-protocol[bigquery]'"
107+
)
108+
# Avoid strict typing on optional dependency
109+
self.client = _bigquery_runtime.Client(**client_args) # type: ignore[no-untyped-call, assignment]
102110

103111
except DefaultCredentialsError as e:
104112
logger.error("Failed to authenticate with BigQuery: %s", e)
@@ -139,7 +147,9 @@ def get_evaluation_rows(
139147
"""
140148
try:
141149
# Configure query job
142-
job_config = bigquery.QueryJobConfig()
150+
if _bigquery_runtime is None:
151+
raise RuntimeError("BigQuery runtime not available")
152+
job_config = _bigquery_runtime.QueryJobConfig() # type: ignore[no-untyped-call]
143153
if query_params:
144154
job_config.query_parameters = query_params
145155
if self.location:

eval_protocol/adapters/langchain.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import os
4-
from typing import List
4+
from typing import Any, Dict, List, Optional
55

66
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
77

@@ -49,10 +49,10 @@ def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
4949
parts.append(item)
5050
content = "\n".join(parts)
5151

52-
tool_calls_payload = None
52+
tool_calls_payload: Optional[List[Dict[str, Any]]] = None
5353

54-
def _normalize_tool_calls(tc_list: list) -> list[dict]:
55-
mapped: List[dict] = []
54+
def _normalize_tool_calls(tc_list: List[Any]) -> List[Dict[str, Any]]:
55+
mapped: List[Dict[str, Any]] = []
5656
for call in tc_list:
5757
if not isinstance(call, dict):
5858
continue
@@ -104,8 +104,13 @@ def _normalize_tool_calls(tc_list: list) -> list[dict]:
104104
if collected:
105105
reasoning_content = "\n\n".join([s for s in collected if s]) or None
106106

107+
# Message.tool_calls expects List[ChatCompletionMessageToolCall] | None.
108+
# We pass through Dicts at runtime but avoid type error by casting.
107109
ep_msg = Message(
108-
role="assistant", content=content, tool_calls=tool_calls_payload, reasoning_content=reasoning_content
110+
role="assistant",
111+
content=content,
112+
tool_calls=tool_calls_payload, # type: ignore[arg-type]
113+
reasoning_content=reasoning_content,
109114
)
110115
_dbg_print(
111116
"[EP-Ser] -> EP Message:",

eval_protocol/adapters/langfuse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import logging
88
from datetime import datetime
9-
from typing import Any, Dict, Iterator, List, Optional
9+
from typing import Any, Dict, Iterator, List, Optional, cast
1010

1111
from eval_protocol.models import EvaluationRow, InputMetadata, Message
1212

@@ -63,7 +63,7 @@ def __init__(
6363
if not LANGFUSE_AVAILABLE:
6464
raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'")
6565

66-
self.client = Langfuse(public_key=public_key, secret_key=secret_key, host=host)
66+
self.client = cast(Any, Langfuse)(public_key=public_key, secret_key=secret_key, host=host)
6767
self.project_id = project_id
6868

6969
def get_evaluation_rows(

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def _extract_abcd_letter(text: str) -> str | None:
5858

5959

6060
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
61+
# assert that all the messages just have a plain .content string field
62+
assert all(isinstance(m.content, str) for m in msgs), "Messages must have a plain .content string field"
6163
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
6264

6365

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
188188
task = Task(
189189
id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")
190190
) # id and user_scenario are required for the Task type but not used in calculating reward
191+
assert task.evaluation_criteria is not None, "Task evaluation criteria is None"
191192

192193
if RewardType.DB in task.evaluation_criteria.reward_basis:
193194
env_reward_info = EnvironmentEvaluator.calculate_reward(

eval_protocol/evaluation.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,22 +257,21 @@ def load_metric_folder(self, metric_name, folder_path):
257257
for keyword in decorator_node.keywords:
258258
if keyword.arg == "requirements":
259259
if isinstance(keyword.value, ast.List):
260-
reqs = []
260+
reqs: List[str] = []
261261
for elt in keyword.value.elts:
262-
if isinstance(elt, ast.Constant) and isinstance(
263-
elt.value, str
264-
): # Python 3.8+
265-
reqs.append(elt.value)
262+
if isinstance(elt, ast.Constant): # Python 3.8+
263+
if isinstance(elt.value, str):
264+
reqs.append(cast(str, elt.value))
266265
elif isinstance(elt, ast.Str): # Python < 3.8
267-
reqs.append(elt.s)
266+
reqs.append(cast(str, elt.s))
268267
if reqs:
269268
metric_requirements_list = cast(List[str], reqs)
270269
elif isinstance(keyword.value, ast.Constant) and isinstance(
271270
keyword.value.value, str
272271
): # Python 3.8+ (single req string)
273-
metric_requirements_list = [keyword.value.value]
272+
metric_requirements_list = [cast(str, keyword.value.value)]
274273
elif isinstance(keyword.value, ast.Str): # Python < 3.8 (single req string)
275-
metric_requirements_list = [keyword.value.s]
274+
metric_requirements_list = [cast(str, keyword.value.s)]
276275
break
277276
if metric_requirements_list:
278277
break

eval_protocol/mcp/client/connection.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,24 +441,34 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
441441
# Extract data plane results (observation only)
442442
if tool_result.content and len(tool_result.content) > 0:
443443
content = tool_result.content[0]
444-
if hasattr(content, "text"):
444+
# Safely attempt to read a "text" attribute if present across content types
445+
text_attr = getattr(content, "text", None)
446+
if isinstance(text_attr, str):
447+
content_text = text_attr
448+
elif isinstance(text_attr, list):
449+
# text can also be an array of parts with optional .text fields
450+
content_text = "".join([getattr(p, "text", "") for p in text_attr])
451+
else:
452+
content_text = None
453+
454+
if isinstance(content_text, str):
445455
# Fix: Handle empty or invalid JSON responses gracefully
446-
if not content.text or content.text.strip() == "":
456+
if content_text.strip() == "":
447457
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
448458
observation = {
449459
"observation": "empty_response",
450460
"session_id": session.session_id,
451461
}
452462
else:
453463
try:
454-
observation = json.loads(content.text)
464+
observation = json.loads(content_text)
455465
except json.JSONDecodeError as e:
456466
logger.warning(
457-
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
467+
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content_text}. Error: {e}"
458468
)
459469
# Create a structured response from the raw text
460470
observation = {
461-
"observation": content.text,
471+
"observation": content_text,
462472
"session_id": session.session_id,
463473
"error": "invalid_json_response",
464474
}

eval_protocol/mcp/execution/policy.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,19 @@ def _setup_litellm_caching(
117117
logger.info("🗄️ Initialized disk caching")
118118

119119
elif cache_type == "s3":
120-
from litellm.caching.s3_cache import S3Cache
121-
122-
litellm.cache = S3Cache()
123-
logger.info("🗄️ Initialized S3 caching")
120+
try:
121+
from litellm.caching.s3_cache import S3Cache
122+
123+
# Some versions require positional or named 's3_bucket_name'
124+
s3_bucket_name = os.getenv("LITELLM_S3_BUCKET")
125+
if not s3_bucket_name:
126+
raise ValueError("Missing LITELLM_S3_BUCKET for S3 cache")
127+
# Use explicit arg name expected by basedpyright
128+
litellm.cache = S3Cache(s3_bucket_name=s3_bucket_name)
129+
logger.info("🗄️ Initialized S3 caching for bucket %s", s3_bucket_name)
130+
except Exception as e:
131+
logger.warning(f"Failed to initialize S3 cache ({e}); falling back to in-memory cache")
132+
litellm.cache = Cache()
124133

125134
except Exception as e:
126135
logger.warning(f"Failed to setup {cache_type} caching: {e}. Falling back to in-memory cache.")
@@ -147,7 +156,7 @@ def _clean_messages_for_api(self, messages: List[Dict]) -> List[Dict]:
147156
clean_messages.append(clean_msg)
148157
return clean_messages
149158

150-
async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
159+
async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> Dict[str, Any]:
151160
"""
152161
Make an LLM API call with retry logic and caching.
153162
@@ -162,7 +171,7 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
162171
clean_messages = self._clean_messages_for_api(messages)
163172

164173
# Prepare request parameters
165-
request_params = {
174+
request_params: Dict[str, Any] = {
166175
"messages": clean_messages,
167176
"temperature": self.temperature,
168177
"max_tokens": self.max_tokens,
@@ -188,7 +197,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
188197
response = await acompletion(model=self.model_id, **request_params)
189198

190199
# Log cache hit/miss for monitoring
191-
cache_hit = getattr(response, "_hidden_params", {}).get("cache_hit", False)
200+
hidden = getattr(response, "_hidden_params", {})
201+
cache_hit = hidden.get("cache_hit", False) if isinstance(hidden, dict) else False
192202
if cache_hit:
193203
logger.debug(f"🎯 Cache hit for model: {self.model_id}")
194204
else:
@@ -199,31 +209,34 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
199209
"choices": [
200210
{
201211
"message": {
202-
"role": response.choices[0].message.role,
203-
"content": response.choices[0].message.content,
212+
"role": getattr(getattr(response.choices[0], "message", object()), "role", "assistant"),
213+
"content": getattr(getattr(response.choices[0], "message", object()), "content", None),
204214
"tool_calls": (
205215
[
206216
{
207-
"id": tc.id,
208-
"type": tc.type,
217+
"id": getattr(tc, "id", None),
218+
"type": getattr(tc, "type", "function"),
209219
"function": {
210-
"name": tc.function.name,
211-
"arguments": tc.function.arguments,
220+
"name": getattr(getattr(tc, "function", None), "name", "tool"),
221+
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
212222
},
213223
}
214-
for tc in (response.choices[0].message.tool_calls or [])
224+
for tc in (
225+
getattr(getattr(response.choices[0], "message", object()), "tool_calls", [])
226+
or []
227+
)
215228
]
216-
if response.choices[0].message.tool_calls
229+
if getattr(getattr(response.choices[0], "message", object()), "tool_calls", None)
217230
else []
218231
),
219232
},
220-
"finish_reason": response.choices[0].finish_reason,
233+
"finish_reason": getattr(response.choices[0], "finish_reason", None),
221234
}
222235
],
223236
"usage": {
224-
"prompt_tokens": response.usage.prompt_tokens,
225-
"completion_tokens": response.usage.completion_tokens,
226-
"total_tokens": response.usage.total_tokens,
237+
"prompt_tokens": getattr(getattr(response, "usage", {}), "prompt_tokens", 0),
238+
"completion_tokens": getattr(getattr(response, "usage", {}), "completion_tokens", 0),
239+
"total_tokens": getattr(getattr(response, "usage", {}), "total_tokens", 0),
227240
},
228241
}
229242

eval_protocol/mcp_servers/tau2/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
def get_server_script_path() -> str:
1313
"""Get the path to the tau2 MCP server script."""
1414
try:
15-
# Try to get from installed package
16-
with importlib.resources.as_file(importlib.resources.files(__package__) / "server.py") as server_path:
15+
# Try to get from installed package. __package__ can be None during some tooling.
16+
package = __package__ if __package__ is not None else __name__
17+
with importlib.resources.as_file(importlib.resources.files(package) / "server.py") as server_path:
1718
return str(server_path)
1819
except (ImportError, FileNotFoundError):
1920
# Fallback for development environment

0 commit comments

Comments (0)