Skip to content

Commit 0644511

Browse files
benjibcBenny Chen
andauthored
fix pyright round 3 (#142)
* fix pyright round 3 * fix tests --------- Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent a2e232a commit 0644511

File tree

12 files changed

+87
-38
lines changed

12 files changed

+87
-38
lines changed

eval_protocol/cli_commands/deploy.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import socket
3030
import subprocess
3131

32-
def start_process(command, log_path, env=None):
32+
def _fallback_start_process(command, log_path, env=None):
3333
"""Fallback process starter."""
3434
try:
3535
with open(log_path, "w") as log_file:
@@ -39,7 +39,7 @@ def start_process(command, log_path, env=None):
3939
print(f"Error starting process: {e}")
4040
return None
4141

42-
def stop_process(pid):
42+
def _fallback_stop_process(pid):
4343
"""Fallback process stopper."""
4444
try:
4545
import os
@@ -48,15 +48,21 @@ def stop_process(pid):
4848
except Exception:
4949
pass
5050

51-
def start_serveo_and_get_url(local_port, log_path):
51+
def _fallback_start_serveo_and_get_url(local_port, log_path):
5252
"""Fallback serveo tunnel - returns None to indicate unavailable."""
5353
print("Serveo tunneling not available - development module not found")
5454
return None, None
5555

56-
def start_ngrok_and_get_url(local_port, log_path):
56+
def _fallback_start_ngrok_and_get_url(local_port, log_path):
5757
"""Fallback ngrok tunnel - returns None to indicate unavailable."""
5858
print("ngrok tunneling not available - development module not found")
5959
return None, None
60+
61+
# Expose unified names using fallbacks
62+
start_process = _fallback_start_process
63+
stop_process = _fallback_stop_process
64+
start_serveo_and_get_url = _fallback_start_serveo_and_get_url
65+
start_ngrok_and_get_url = _fallback_start_ngrok_and_get_url
6066
else:
6167
# Wrap imported helpers to present consistent, simple signatures used below
6268
def start_process(command, log_path, env=None):
@@ -66,7 +72,7 @@ def stop_process(pid):
6672
return _stop_process(pid)
6773

6874
def start_serveo_and_get_url(local_port, log_path):
69-
return _start_serveo_and_get_url(local_port=local_port, log_path=log_path)
75+
return _start_serveo_and_get_url(local_port=local_port, log_file_path=log_path)
7076

7177
def start_ngrok_and_get_url(local_port, log_path):
7278
return _start_ngrok_and_get_url(local_port=local_port, ngrok_log_file=log_path)

eval_protocol/evaluation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
import types
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
10+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
1111

1212
if TYPE_CHECKING:
1313
# For type checking only
@@ -173,6 +173,8 @@ def __init__(
173173
self.description = ""
174174
self.display_name = ""
175175
self.api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai")
176+
# Optional requirements string for multi-metric mode (when loaded differently)
177+
self._loaded_multi_metric_requirements_str: Optional[str] = None
176178

177179
if self.ts_mode_config:
178180
python_code = self.ts_mode_config.get("python_code")
@@ -264,7 +266,7 @@ def load_metric_folder(self, metric_name, folder_path):
264266
elif isinstance(elt, ast.Str): # Python < 3.8
265267
reqs.append(elt.s)
266268
if reqs:
267-
metric_requirements_list = reqs
269+
metric_requirements_list = cast(List[str], reqs)
268270
elif isinstance(keyword.value, ast.Constant) and isinstance(
269271
keyword.value.value, str
270272
): # Python 3.8+ (single req string)

eval_protocol/get_pep440_version.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
# Cache for PEP 440 version string
22
import subprocess
33

4-
_version_cache = {"version": None, "base_version": None}
4+
from typing import Dict, Optional, TypedDict
5+
6+
7+
class _VersionCache(TypedDict):
8+
version: Optional[str]
9+
base_version: Optional[str]
10+
11+
12+
_version_cache: _VersionCache = {"version": None, "base_version": None}
513

614

715
def get_pep440_version(base_version=None):

eval_protocol/mcp/client/connection.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,26 +306,28 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any
306306
resource_content = await mcp_session.read_resource(initial_state_resource.uri)
307307

308308
# Handle the new ResourceContents format
309-
if hasattr(resource_content, "text"):
309+
text_value = getattr(resource_content, "text", None)
310+
if text_value is not None:
310311
try:
311-
initial_observation = json.loads(resource_content.text)
312+
initial_observation = json.loads(text_value)
312313
logger.info(
313314
f"Session {session.session_id}: ✅ Successfully parsed JSON initial state with grid_layout: {initial_observation.get('grid_layout', 'N/A')[:20]}..."
314315
)
315316
except json.JSONDecodeError:
316-
initial_observation = {"observation": resource_content.text}
317+
initial_observation = {"observation": text_value}
317318
elif (
318319
hasattr(resource_content, "contents")
319320
and resource_content.contents
320321
and len(resource_content.contents) > 0
321322
):
322323
# Fallback to old format for backward compatibility
323324
content = resource_content.contents[0]
324-
if hasattr(content, "text"):
325+
content_text = getattr(content, "text", None)
326+
if content_text is not None:
325327
try:
326-
initial_observation = json.loads(content.text)
328+
initial_observation = json.loads(content_text)
327329
except json.JSONDecodeError:
328-
initial_observation = {"observation": content.text}
330+
initial_observation = {"observation": content_text}
329331
else:
330332
initial_observation = {"observation": str(resource_content)}
331333
else:
@@ -359,23 +361,25 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any
359361
)
360362

361363
# Handle the new ResourceContents format
362-
if hasattr(resource_content, "text"):
364+
text_value_2 = getattr(resource_content, "text", None)
365+
if text_value_2 is not None:
363366
try:
364-
initial_observation = json.loads(resource_content.text)
367+
initial_observation = json.loads(text_value_2)
365368
except json.JSONDecodeError:
366-
initial_observation = {"observation": resource_content.text}
369+
initial_observation = {"observation": text_value_2}
367370
elif (
368371
hasattr(resource_content, "contents")
369372
and resource_content.contents
370373
and len(resource_content.contents) > 0
371374
):
372375
# Fallback to old format for backward compatibility
373376
content = resource_content.contents[0]
374-
if hasattr(content, "text"):
377+
content_text_2 = getattr(content, "text", None)
378+
if content_text_2 is not None:
375379
try:
376-
initial_observation = json.loads(content.text)
380+
initial_observation = json.loads(content_text_2)
377381
except json.JSONDecodeError:
378-
initial_observation = {"observation": content.text}
382+
initial_observation = {"observation": content_text_2}
379383
else:
380384
initial_observation = {"observation": str(content)}
381385
else:

eval_protocol/mcp/clients.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def __init__(self, intermediary_server_url: str):
2929

3030
async def connect(self):
3131
"""Establishes connection and MCP session."""
32-
if self._mcp_session is not None and not self._mcp_session.is_closed:
32+
# ClientSession does not expose a stable public `is_closed`; consider session presence sufficient
33+
if self._mcp_session is not None:
3334
logger.debug("Already connected.")
3435
return
3536

@@ -97,26 +98,27 @@ async def _call_intermediary_tool(self, tool_name: str, tool_args_payload: Dict[
9798
if mcp_response.isError or not mcp_response.content or not hasattr(mcp_response.content[0], "text"):
9899
error_message = f"Tool call '{tool_name}' to intermediary failed."
99100
if mcp_response.isError and mcp_response.content and hasattr(mcp_response.content[0], "text"):
100-
error_message += f" Details: {mcp_response.content[0].text}"
101+
error_text = getattr(mcp_response.content[0], "text", "")
102+
error_message += f" Details: {error_text}"
101103
elif mcp_response.isError:
102104
error_message += " No detailed error message in content."
103105
logger.error(error_message)
104106
try:
105107
if mcp_response.content and hasattr(mcp_response.content[0], "text"):
106-
parsed_error = json.loads(mcp_response.content[0].text)
108+
parsed_error = json.loads(getattr(mcp_response.content[0], "text", ""))
107109
if isinstance(parsed_error, dict) and "error" in parsed_error:
108110
raise RuntimeError(f"{error_message} Nested error: {parsed_error['error']}")
109111
except (json.JSONDecodeError, TypeError):
110112
pass
111113
raise RuntimeError(error_message)
112114

113115
try:
114-
parsed_result = json.loads(mcp_response.content[0].text)
116+
parsed_result = json.loads(getattr(mcp_response.content[0], "text", ""))
115117
logger.debug(f"Parsed JSON result from intermediary for '{tool_name}': {parsed_result}")
116118
return parsed_result
117119
except json.JSONDecodeError as e:
118120
logger.error(
119-
f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {mcp_response.content[0].text}. Error: {e}"
121+
f"Failed to parse JSON from intermediary's tool '{tool_name}' response content: {getattr(mcp_response.content[0], 'text', '')}. Error: {e}"
120122
)
121123
raise RuntimeError(f"Failed to parse JSON response from intermediary tool '{tool_name}'.")
122124

eval_protocol/rewards/code_execution.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def extract_code_blocks(text: str, language: Optional[str] = None) -> List[Dict[
8080
List of dictionaries with "code" and "language" keys
8181
"""
8282
pattern = r"```(\w*)\n([\s\S]*?)\n```"
83-
matches = re.findall(pattern, text)
83+
matches = re.findall(pattern, text or "")
8484

8585
code_blocks = []
8686
verbose_patterns_removed = []
@@ -1098,7 +1098,15 @@ def fractional_code_reward(
10981098
},
10991099
)
11001100

1101-
code_blocks = extract_code_blocks(response_content, language)
1101+
# Normalize content to string; Message.content may be str or list of content parts
1102+
_last_content = response_content
1103+
response_content_str = (
1104+
_last_content
1105+
if isinstance(_last_content, str)
1106+
else "".join([getattr(p, "text", "") for p in (_last_content or [])])
1107+
)
1108+
1109+
code_blocks = extract_code_blocks(response_content_str, language)
11021110

11031111
if not code_blocks:
11041112
return EvaluateResult(
@@ -1617,7 +1625,7 @@ class Capturing(list):
16171625
def __enter__(self):
16181626
self._stdout = sys.stdout
16191627
sys.stdout = self._stringio = StringIO()
1620-
self._stringio.close = lambda x: None
1628+
self._stringio.close = lambda: None
16211629
return self
16221630

16231631
def __exit__(self, *args):

eval_protocol/rewards/deepcoder_reward.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ def deepcoder_code_reward(
7373
is_score_valid=False,
7474
)
7575

76-
assistant_content = messages[-1].content
76+
assistant_content_raw = messages[-1].content
77+
assistant_content = (
78+
assistant_content_raw
79+
if isinstance(assistant_content_raw, str)
80+
else "".join([getattr(p, "text", "") for p in (assistant_content_raw or [])])
81+
)
7782
test_cases = ground_truth
7883

7984
code_blocks = extract_code_blocks(assistant_content, language)

eval_protocol/rewards/list_comparison_math_reward.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,12 @@ def list_comparison_math_reward(
127127
},
128128
)
129129

130-
gen_content = messages[-1].content
130+
gen_content_raw = messages[-1].content
131+
gen_content = (
132+
gen_content_raw
133+
if isinstance(gen_content_raw, str)
134+
else "".join([getattr(p, "text", "") for p in (gen_content_raw or [])])
135+
)
131136
orig_content = ground_truth
132137

133138
if not gen_content:

eval_protocol/rewards/multiple_choice_math_reward.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,12 @@ def multiple_choice_math_reward(
134134
if messages and len(messages) > 0:
135135
gen_response_message = messages[-1]
136136
if gen_response_message.role == "assistant":
137-
gen_content = gen_response_message.content or ""
137+
raw_gen_content = gen_response_message.content
138+
gen_content = (
139+
raw_gen_content
140+
if isinstance(raw_gen_content, str)
141+
else "".join([getattr(p, "text", "") for p in (raw_gen_content or [])])
142+
)
138143

139144
if not gen_content:
140145
metrics["error_generated_message"] = MetricResult(
@@ -152,7 +157,12 @@ def multiple_choice_math_reward(
152157
if ground_truth and len(ground_truth) > 0:
153158
orig_response_message = ground_truth[0]
154159
if orig_response_message.role == "assistant":
155-
orig_content = orig_response_message.content or ""
160+
raw_orig_content = orig_response_message.content
161+
orig_content = (
162+
raw_orig_content
163+
if isinstance(raw_orig_content, str)
164+
else "".join([getattr(p, "text", "") for p in (raw_orig_content or [])])
165+
)
156166

157167
if not orig_content:
158168
metrics["error_original_message"] = MetricResult(

eval_protocol/typed_interface.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def decorator(func: F) -> F:
8282

8383
if not has_var_keyword:
8484
raise ValueError(
85-
f"Function '{func.__name__}' must accept **kwargs parameter. "
86-
f"Please add '**kwargs' to the function signature."
85+
f"Function '{func.__name__}' must accept **kwargs parameter. Please add '**kwargs' to the function signature."
8786
)
8887

8988
# Setup resources once when the decorator is applied
@@ -113,7 +112,7 @@ def _is_list_of_message_annotation(annotation: Any) -> bool:
113112
inner = non_none[0]
114113
inner_origin = get_origin(inner)
115114
inner_args = get_args(inner)
116-
return inner_origin in (list, List) and inner_args and inner_args[0] == Message
115+
return (inner_origin in (list, List)) and bool(inner_args) and (inner_args[0] == Message)
117116
return False
118117

119118
def _prepare_final_args(*args: Any, **kwargs: Any):

0 commit comments

Comments
 (0)