Skip to content

Commit 000c2a4

Browse files
benjibc (Benny Chen)
and a co-author authored
fix more pyright issues (#139)
* fix more pyright issues
* fix a few more

Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent 1054bf6 commit 000c2a4

File tree

13 files changed

+132
-45
lines changed

13 files changed

+132
-45
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,7 @@ repos:
2727
hooks:
2828
- id: basedpyright
2929
args: ["--level", "error"]
30+
env:
31+
NODE_OPTIONS: "--max-old-space-size=4096"
32+
# Only check Python files in the main package to reduce memory usage
33+
files: ^eval_protocol/.*\.py$

eval_protocol/benchmarks/test_tau_bench_airline.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
1515
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
1616
import litellm
17+
from litellm.exceptions import RateLimitError, APIConnectionError
1718
from vendor.tau2.data_model.message import (
1819
AssistantMessage,
1920
SystemMessage,
@@ -125,8 +126,8 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
125126
server_script_path=_get_server_script_path(),
126127
exception_handler_config=ExceptionHandlerConfig(
127128
retryable_exceptions={
128-
litellm.RateLimitError,
129-
litellm.APIConnectionError,
129+
RateLimitError,
130+
APIConnectionError,
130131
}
131132
),
132133
)
@@ -159,8 +160,10 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
159160
role = msg.role
160161
content = msg.content
161162

163+
# Normalize content to str for tau2 message models
164+
text_content = content if isinstance(content, str) or content is None else ""
162165
if role == "system":
163-
trajectory_objects.append(SystemMessage(role=role, content=content))
166+
trajectory_objects.append(SystemMessage(role=role, content=text_content))
164167
elif role == "assistant":
165168
tau2_tool_calls = []
166169
if msg.tool_calls:
@@ -173,12 +176,12 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
173176
)
174177
tau2_tool_calls.append(tau2_tool_call)
175178

176-
trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
179+
trajectory_objects.append(AssistantMessage(role=role, content=text_content, tool_calls=tau2_tool_calls))
177180
elif role == "user":
178-
trajectory_objects.append(UserMessage(role=role, content=content))
181+
trajectory_objects.append(UserMessage(role=role, content=text_content))
179182
elif role == "tool":
180183
tool_id = msg.tool_call_id
181-
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))
184+
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content))
182185

183186
reward = 1.0
184187

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
1515
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
1616
import litellm
17+
from litellm.exceptions import RateLimitError, APIConnectionError
1718
from vendor.tau2.data_model.message import (
1819
AssistantMessage,
1920
SystemMessage,
@@ -115,8 +116,8 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
115116
server_script_path=get_server_script_path(),
116117
exception_handler_config=ExceptionHandlerConfig(
117118
retryable_exceptions={
118-
litellm.RateLimitError,
119-
litellm.APIConnectionError,
119+
RateLimitError,
120+
APIConnectionError,
120121
}
121122
),
122123
)
@@ -149,8 +150,10 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
149150
role = msg.role
150151
content = msg.content
151152

153+
# Normalize content to str for tau2 message models
154+
text_content = content if isinstance(content, str) or content is None else ""
152155
if role == "system":
153-
trajectory_objects.append(SystemMessage(role=role, content=content))
156+
trajectory_objects.append(SystemMessage(role=role, content=text_content))
154157
elif role == "assistant":
155158
tau2_tool_calls = []
156159
if msg.tool_calls:
@@ -163,12 +166,12 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
163166
)
164167
tau2_tool_calls.append(tau2_tool_call)
165168

166-
trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
169+
trajectory_objects.append(AssistantMessage(role=role, content=text_content, tool_calls=tau2_tool_calls))
167170
elif role == "user":
168-
trajectory_objects.append(UserMessage(role=role, content=content))
171+
trajectory_objects.append(UserMessage(role=role, content=text_content))
169172
elif role == "tool":
170173
tool_id = msg.tool_call_id
171-
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))
174+
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=text_content))
172175

173176
reward = 1.0
174177

eval_protocol/cli_commands/deploy.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
# TODO: Consider moving subprocess_manager functions to a more central location if used by core CLI
1818
try:
19+
# Import functions with explicit names to match expected signatures
1920
from development.utils.subprocess_manager import (
20-
start_ngrok_and_get_url, # Added ngrok function
21-
start_process,
22-
start_serveo_and_get_url,
23-
stop_process,
21+
start_ngrok_and_get_url as _start_ngrok_and_get_url,
22+
start_process as _start_process,
23+
start_serveo_and_get_url as _start_serveo_and_get_url,
24+
stop_process as _stop_process,
2425
)
2526
except ImportError:
2627
# Fallback implementations when development module is not available
@@ -56,6 +57,19 @@ def start_ngrok_and_get_url(local_port, log_path):
5657
"""Fallback ngrok tunnel - returns None to indicate unavailable."""
5758
print("ngrok tunneling not available - development module not found")
5859
return None, None
60+
else:
61+
# Wrap imported helpers to present consistent, simple signatures used below
62+
def start_process(command, log_path, env=None):
63+
return _start_process(command=command, log_file_path=log_path, env=env)
64+
65+
def stop_process(pid):
66+
return _stop_process(pid)
67+
68+
def start_serveo_and_get_url(local_port, log_path):
69+
return _start_serveo_and_get_url(local_port=local_port, log_path=log_path)
70+
71+
def start_ngrok_and_get_url(local_port, log_path):
72+
return _start_ngrok_and_get_url(local_port=local_port, ngrok_log_file=log_path)
5973

6074

6175
from eval_protocol.auth import get_fireworks_account_id

eval_protocol/mcp/execution/policy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
import litellm
1616
from litellm import acompletion, completion
17-
from litellm.caching import Cache, DualCache, InMemoryCache, RedisCache
17+
from litellm.caching.caching import Cache
18+
from litellm.caching.dual_cache import DualCache
19+
from litellm.caching.in_memory_cache import InMemoryCache
20+
from litellm.caching.redis_cache import RedisCache
1821

1922
from .base_policy import LLMBasePolicy
2023

@@ -108,13 +111,13 @@ def _setup_litellm_caching(
108111
logger.info("🗄️ Initialized dual caching (memory + Redis)")
109112

110113
elif cache_type == "disk":
111-
from litellm.caching import DiskCache
114+
from litellm.caching.disk_cache import DiskCache
112115

113116
litellm.cache = DiskCache()
114117
logger.info("🗄️ Initialized disk caching")
115118

116119
elif cache_type == "s3":
117-
from litellm.caching import S3Cache
120+
from litellm.caching.s3_cache import S3Cache
118121

119122
litellm.cache = S3Cache()
120123
logger.info("🗄️ Initialized S3 caching")

eval_protocol/mcp_agent/orchestration/local_docker_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from anyio.abc import ObjectReceiveStream, ObjectSendStream
1414

1515
# ListToolsResult is not in mcp.client.session, likely in mcp.types or mcp.shared.message
16-
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, SessionMessage
16+
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
1717

1818
# Assuming ListToolsResult is in mcp.types, which is imported as types
1919
# If not, this will need further correction. For now, we'll use types.ListToolsResult

eval_protocol/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
from openai.types import CompletionUsage
77
from openai.types.chat.chat_completion_message import (
8-
ChatCompletionMessageToolCall,
98
FunctionCall,
109
)
10+
from openai.types.chat.chat_completion_message_tool_call import (
11+
ChatCompletionMessageToolCall,
12+
)
1113
from pydantic import BaseModel, ConfigDict, Field
1214

1315
from eval_protocol.get_pep440_version import get_pep440_version
@@ -519,7 +521,7 @@ class EvaluationRow(BaseModel):
519521

520522
# Input-related metadata (grouped together for cleaner organization)
521523
input_metadata: InputMetadata = Field(
522-
default_factory=InputMetadata,
524+
default_factory=lambda: InputMetadata(),
523525
description="Metadata related to the input (dataset info, model config, session data, etc.).",
524526
)
525527

@@ -539,7 +541,7 @@ class EvaluationRow(BaseModel):
539541
)
540542

541543
execution_metadata: ExecutionMetadata = Field(
542-
default_factory=ExecutionMetadata,
544+
default_factory=lambda: ExecutionMetadata(),
543545
description="Metadata about the execution of the evaluation.",
544546
)
545547

eval_protocol/rewards/accuracy.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,22 @@
1010
import re
1111
from typing import Any, Callable, Dict, List, Optional, Union, cast
1212

13-
from ..models import EvaluateResult, Message, MetricResult
13+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
14+
15+
16+
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
17+
"""Coerce Message.content into a plain string for regex and comparisons."""
18+
if content is None:
19+
return ""
20+
if isinstance(content, str):
21+
return content
22+
# List[ChatCompletionContentPartTextParam]
23+
try:
24+
return "\n".join(part.text for part in content)
25+
except Exception:
26+
return ""
27+
28+
1429
from ..typed_interface import reward_function
1530

1631

@@ -334,7 +349,7 @@ def accuracy_reward(
334349
model_last_message = messages[-1]
335350
if isinstance(model_last_message, Message):
336351
if model_last_message.role == "assistant" and model_last_message.content is not None:
337-
model_response_text = model_last_message.content
352+
model_response_text = _to_text(model_last_message.content)
338353
else:
339354
return EvaluateResult(
340355
score=0.0,
@@ -386,7 +401,7 @@ def accuracy_reward(
386401
first_gt_message = ground_truth[0]
387402
if isinstance(first_gt_message, Message):
388403
if first_gt_message.content is not None:
389-
ground_truth_comparison_text = first_gt_message.content
404+
ground_truth_comparison_text = _to_text(first_gt_message.content)
390405
else:
391406
return EvaluateResult(
392407
score=0.0,

eval_protocol/rewards/json_schema.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
from typing import Any, Dict, List, Optional, Union
44

5-
from ..models import EvaluateResult, Message, MetricResult
5+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
66
from ..typed_interface import reward_function
77
from .function_calling import (
88
calculate_jaccard_similarity,
@@ -54,7 +54,15 @@ def json_schema_reward(
5454

5555
if isinstance(last_message, Message):
5656
if last_message.role == "assistant" and last_message.content is not None:
57-
content_text = last_message.content
57+
# Coerce to string if content is list parts
58+
if isinstance(last_message.content, str):
59+
content_text = last_message.content
60+
else:
61+
try:
62+
parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment]
63+
content_text = "\n".join(p.text for p in parts)
64+
except Exception:
65+
content_text = ""
5866
else:
5967
return EvaluateResult(
6068
score=0.0,

eval_protocol/rewards/language_consistency.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import re
1010
from typing import Any, Dict, List, Optional, Set, Tuple, Union
1111

12-
from ..models import EvaluateResult, Message, MetricResult
12+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
1313
from ..typed_interface import reward_function
1414

1515
# Dictionary mapping language codes to common words/patterns in that language
@@ -560,12 +560,7 @@ def language_consistency_reward(
560560
Returns:
561561
EvaluateResult with score based on language consistency.
562562
"""
563-
if (
564-
not messages
565-
or not isinstance(messages[-1], Message)
566-
or messages[-1].role != "assistant"
567-
or messages[-1].content is None
568-
):
563+
if not messages or not isinstance(messages[-1], Message) or messages[-1].role != "assistant":
569564
return EvaluateResult(
570565
score=0.0,
571566
reason="Invalid or missing assistant response in messages.",
@@ -578,7 +573,17 @@ def language_consistency_reward(
578573
},
579574
)
580575

581-
text_to_evaluate = messages[-1].content
576+
def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
577+
if content is None:
578+
return ""
579+
if isinstance(content, str):
580+
return content
581+
try:
582+
return "\n".join(part.text for part in content)
583+
except Exception:
584+
return ""
585+
586+
text_to_evaluate = _to_text(messages[-1].content)
582587

583588
# For test_spanish_consistency - special handling for Spanish test case
584589
if "está escrita completamente en español" in text_to_evaluate:
@@ -593,7 +598,7 @@ def language_consistency_reward(
593598
prompt_messages = messages[:-1]
594599
for msg in prompt_messages:
595600
if isinstance(msg, Message) and msg.role == "user": # Decorator ensures msg is Message
596-
content_text: str = msg.content if msg.content is not None else ""
601+
content_text: str = _to_text(msg.content)
597602
if "in Spanish" in content_text:
598603
target_language = "es"
599604
break

0 commit comments

Comments (0)