Skip to content

Commit ed1dc59

Browse files
cursoragent authored and benjibc committed
Refactor type handling and improve error resilience across modules
Co-authored-by: bchen <bchen@fireworks.ai>
1 parent 3ec9a06 commit ed1dc59

File tree

9 files changed

+84
-92
lines changed

9 files changed

+84
-92
lines changed

eval_protocol/agent/orchestrator.py

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,10 @@
1111
import logging
1212
import os
1313
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast
14+
import importlib.util as _importlib_util
1415

15-
# Attempt to import OpenAI client
16-
try:
17-
from openai import AsyncOpenAI, OpenAI
18-
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
19-
from openai.types.chat.chat_completion_message_tool_call import (
20-
ChatCompletionMessageToolCall,
21-
)
22-
23-
OPENAI_AVAILABLE = True
24-
except ImportError:
25-
OPENAI_AVAILABLE = False
26-
# Define dummy types if openai is not installed, to avoid runtime errors on load
27-
from typing import Any, Dict, List, Optional, Union
28-
29-
# Use simple class definitions for runtime and type checking
30-
class OpenAI:
31-
def __init__(self, **kwargs: Any) -> None:
32-
pass
33-
34-
class AsyncOpenAI:
35-
def __init__(self, **kwargs: Any) -> None:
36-
pass
37-
38-
class ChatCompletionMessage:
39-
content: str = ""
40-
role: str = "assistant"
41-
42-
class ChatCompletionToolParam:
43-
pass
44-
45-
class ChatCompletionMessageToolCall:
46-
pass
16+
# Determine OpenAI availability without importing symbols for typing
17+
OPENAI_AVAILABLE = _importlib_util.find_spec("openai") is not None
4718

4819

4920
# Max steps for the inner loop within a single user turn
@@ -71,17 +42,19 @@ def __init__(self, task_definition: TaskDefinitionModel):
7142
self.logger = logging.getLogger(f"Orchestrator.{self.task_definition.name}")
7243
self.logger.setLevel(logging.DEBUG) # Ensure debug logs are processed
7344
self.logger.info(f"Orchestrator initialized for task: {self.task_definition.name}")
74-
self._openai_client: Optional[AsyncOpenAI] = None
45+
# Use Any here to avoid pyright stubs mismatches across openai versions
46+
self._openai_client: Optional[Any] = None
7547

7648
def _initialize_openai_client(self):
7749
"""Initializes the AsyncOpenAI client if available and not already initialized."""
7850
if not OPENAI_AVAILABLE:
7951
self.logger.warning("OpenAI library not available. Cannot use OpenAI models.")
8052
return
8153
if self._openai_client is None:
82-
# Consider adding error handling for missing API key
8354
try:
84-
self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
55+
from openai import AsyncOpenAI # type: ignore[import-not-found]
56+
57+
self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) # type: ignore[call-arg]
8558
self.logger.info("AsyncOpenAI client initialized.")
8659
except Exception as e:
8760
self.logger.error(f"Failed to initialize AsyncOpenAI client: {e}")
@@ -94,7 +67,9 @@ def _initialize_fireworks_client(self):
9467
return
9568
if self._openai_client is None:
9669
try:
97-
self._openai_client = AsyncOpenAI(
70+
from openai import AsyncOpenAI # type: ignore[import-not-found]
71+
72+
self._openai_client = AsyncOpenAI( # type: ignore[call-arg]
9873
api_key=os.environ.get("FIREWORKS_API_KEY"),
9974
base_url="https://api.fireworks.ai/inference/v1",
10075
)
@@ -469,18 +444,20 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
469444
# Initialize the episode resource with sample data if provided
470445
if sample_data:
471446
self.logger.info(f"Initializing episode resource with sample data: {sample_data}")
472-
if hasattr(episode_resource, "initialize"):
473-
await episode_resource.initialize(**sample_data)
447+
initializer = getattr(episode_resource, "initialize", None)
448+
if callable(initializer):
449+
await initializer(**sample_data) # type: ignore[misc]
474450
else:
475451
self.logger.warning(
476452
f"Episode resource {type(episode_resource).__name__} does not have initialize method"
477453
)
478454

479455
# Get initial state for injection into first prompt (for HTTP rollout)
480456
initial_state_description = None
481-
if hasattr(episode_resource, "get_initial_state_description"):
457+
get_init_state = getattr(episode_resource, "get_initial_state_description", None)
458+
if callable(get_init_state):
482459
try:
483-
initial_state_description = await episode_resource.get_initial_state_description()
460+
initial_state_description = await get_init_state() # type: ignore[misc]
484461
self.logger.info("Retrieved initial state description for first prompt")
485462
except Exception as e:
486463
self.logger.warning(f"Failed to get initial state description: {e}")
@@ -577,21 +554,21 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
577554
) # Get adapters for execution
578555

579556
# Format tools for OpenAI API (should be done once per user turn, or if tools change)
580-
openai_tools: List[ChatCompletionToolParam] = []
557+
openai_tools: List[Dict[str, Any]] = []
581558
if OPENAI_AVAILABLE:
582559
# First add tools from the resource
583560
for spec in resource_tool_specs:
584561
# Ensure spec has the structure with name and parameters
585562
if "name" in spec and "parameters" in spec:
586563
openai_tools.append(
587-
ChatCompletionToolParam(
588-
type="function",
589-
function={
564+
{
565+
"type": "function",
566+
"function": {
590567
"name": spec["name"],
591568
"description": spec.get("description", ""),
592-
"parameters": spec["parameters"], # Assuming this matches OpenAI schema
569+
"parameters": spec["parameters"], # Assuming OpenAI-compatible schema
593570
},
594-
)
571+
}
595572
)
596573
else:
597574
self.logger.warning(f"Skipping tool spec due to missing name/parameters: {spec}")
@@ -605,14 +582,14 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
605582
registry_tools = self.tools_module.R.get_openai_tools()
606583
for tool_spec in registry_tools:
607584
openai_tools.append(
608-
ChatCompletionToolParam(
609-
type="function",
610-
function={
585+
{
586+
"type": "function",
587+
"function": {
611588
"name": tool_spec["name"],
612589
"description": tool_spec.get("description", ""),
613590
"parameters": tool_spec["parameters"],
614591
},
615-
)
592+
}
616593
)
617594
else:
618595
self.logger.warning("OpenAI not available, cannot format tools for API.")
@@ -642,6 +619,7 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
642619
if not self._openai_client:
643620
raise Exception("OpenAI client not initialized")
644621

622+
# type: ignore[reportUnknownMemberType]
645623
response = await self._openai_client.chat.completions.create(
646624
model=agent_model_name,
647625
messages=conversation_messages, # type: ignore

eval_protocol/agent/resources/docker_resource.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,17 @@
1313
try:
1414
import docker
1515

16-
if TYPE_CHECKING:
17-
from docker.errors import APIError, DockerException, NotFound
18-
from docker.models.containers import Container
19-
else:
20-
from docker.errors import APIError, DockerException, NotFound
21-
from docker.models.containers import Container
16+
# Import for runtime; annotate as Any to avoid mismatched type aliasing across modules
17+
from docker.errors import APIError as _APIError, DockerException as _DockerException, NotFound as _NotFound
18+
from docker.models.containers import Container as _Container
2219

2320
DOCKER_SDK_AVAILABLE = True
2421
# Ensure these are available for type checking even if the runtime import fails
2522
# The `else` block for DOCKER_SDK_AVAILABLE = False will define runtime dummies.
26-
DockerException = DockerException
27-
NotFound = NotFound
28-
APIError = APIError
29-
Container = Container
23+
DockerException = _DockerException # type: ignore[assignment]
24+
NotFound = _NotFound # type: ignore[assignment]
25+
APIError = _APIError # type: ignore[assignment]
26+
Container = _Container # type: ignore[assignment]
3027
try:
3128
_daemon_check_client = docker.from_env()
3229
_daemon_check_client.ping()
@@ -99,7 +96,7 @@ def __init__(self) -> None:
9996
raise RuntimeError("Docker daemon not running or not accessible")
10097
self._client = docker.from_env()
10198
self._config: Dict[str, Any] = {}
102-
self._container: Optional[Container] = None
99+
self._container: Optional[Any] = None
103100
self._image_id_for_fork_or_checkpoint: Optional[str] = (
104101
None # Stores the ID of the image used for the current container
105102
)
@@ -108,14 +105,14 @@ def __init__(self) -> None:
108105
def _generate_name(self, prefix: str) -> str:
109106
return f"rk_{prefix}_{uuid.uuid4().hex}"
110107

111-
def _cleanup_container(self, container: Optional[Container]) -> None:
108+
def _cleanup_container(self, container: Optional[Any]) -> None:
112109
if container:
113110
try:
114111
container.remove(force=True, v=True) # v=True to remove volumes
115112
except NotFound:
116113
pass # Already removed
117114
except APIError as e:
118-
print(f"DockerResource: Error removing container {(container.id or '')[:12]}: {e}")
115+
print(f"DockerResource: Error removing container {(getattr(container, 'id', '') or '')[:12]}: {e}")
119116

120117
def _cleanup_image(self, image_id: Optional[str]) -> None:
121118
if image_id:

eval_protocol/agent/task_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,10 @@ async def execute_single_rollout(rollout_index: int):
492492
# Convert EvaluateResult to dict if needed
493493
if hasattr(result, "model_dump"):
494494
# Pydantic model - convert to dict
495-
result = result.model_dump()
495+
result = result.model_dump() # type: ignore[call-arg]
496496
elif hasattr(result, "dict"):
497497
# Older pydantic models
498-
result = result.dict()
498+
result = result.dict() # type: ignore[call-arg]
499499
# If it's already a dict, leave it as is
500500

501501
# Add reward function inputs to the result for JSONL trajectory storage
@@ -529,7 +529,14 @@ async def execute_single_rollout(rollout_index: int):
529529

530530
# Execute all rollouts concurrently
531531
rollout_tasks = [execute_single_rollout(i) for i in range(num_rollouts)]
532-
rollout_results = await asyncio.gather(*rollout_tasks)
532+
rollout_results_raw = await asyncio.gather(*rollout_tasks)
533+
# Normalize to list of dicts for typing purposes where possible
534+
rollout_results: List[Dict[str, Any]] = []
535+
for item in rollout_results_raw:
536+
if isinstance(item, dict):
537+
rollout_results.append(item)
538+
else:
539+
rollout_results.append({"result": item})
533540

534541
# Log failed rollouts but return all results for comprehensive analysis
535542
successful_results = [r for r in rollout_results if not (isinstance(r, dict) and "error" in r)]
@@ -665,10 +672,10 @@ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_d
665672
# Convert EvaluateResult to dict if needed
666673
if hasattr(result, "model_dump"):
667674
# Pydantic model - convert to dict
668-
result = result.model_dump()
675+
result = result.model_dump() # type: ignore[call-arg]
669676
elif hasattr(result, "dict"):
670677
# Older pydantic models
671-
result = result.dict()
678+
result = result.dict() # type: ignore[call-arg]
672679
# If it's already a dict, leave it as is
673680

674681
# Add reward function inputs to the result for JSONL trajectory storage

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from mcp.client.streamable_http import streamablehttp_client
1111
from mcp.types import CallToolResult
1212
from openai.types import FunctionDefinition
13-
from openai.types.chat import ChatCompletionToolParam
1413

1514
from eval_protocol.models import (
1615
MCPConfigurationServerStdio,
@@ -125,29 +124,29 @@ async def _connect_to_server(
125124
[tool.name for tool in tools],
126125
)
127126

128-
async def get_available_tools(self) -> List[ChatCompletionToolParam]:
127+
async def get_available_tools(self) -> List[Dict[str, Any]]:
129128
"""Get all available tools from all connected servers"""
130129
all_tools = []
131130
for server_name, session in self.sessions.items():
132131
try:
133132
response = await session.list_tools()
134133
for tool in response.tools:
135134
all_tools.append(
136-
ChatCompletionToolParam(
137-
function=FunctionDefinition(
138-
name=tool.name, # Prefix with server name
139-
description=tool.description,
140-
parameters=tool.inputSchema,
141-
),
142-
type="function",
143-
)
135+
{
136+
"type": "function",
137+
"function": {
138+
"name": tool.name,
139+
"description": tool.description,
140+
"parameters": tool.inputSchema,
141+
},
142+
}
144143
)
145144
except Exception as e:
146145
print(f"Error listing tools from server '{server_name}': {e}")
147146

148147
return all_tools
149148

150-
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> CallToolResult:
149+
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Union[CallToolResult, str]:
151150
"""Call a specific tool by name with arguments"""
152151

153152
session = self.tools_to_sessions[tool_name]

eval_protocol/mcp/simple_process_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
import uuid
1515
from contextlib import AsyncExitStack
16-
from typing import Dict, Tuple
16+
from typing import Dict, Tuple, Optional
1717

1818
from mcp.client.session import ClientSession
1919
from mcp.client.streamable_http import streamablehttp_client
@@ -26,7 +26,7 @@ class SimpleServerProcessManager:
2626
def __init__(
2727
self,
2828
script_path: str,
29-
python_executable: str = None,
29+
python_executable: Optional[str] = None,
3030
port_range: Tuple[int, int] = (10000, 11000),
3131
):
3232
"""

eval_protocol/mcp_agent/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,12 @@ async def main_async(config_path: str, host: str, port: int):
9898

9999
# 2. Instantiate StreamableHTTPSessionManager
100100
# Pass the internal _mcp_server (the MCPServer instance) from our FastMCP subclass
101+
if _mcp_server_instance_ref is None:
102+
logger.error("Failed to initialize RewardKitIntermediaryServer")
103+
return
104+
101105
session_manager = StreamableHTTPSessionManager(
102-
app=_mcp_server_instance_ref._mcp_server,
106+
app=_mcp_server_instance_ref._mcp_server, # type: ignore[attr-defined]
103107
event_store=None,
104108
json_response=True, # Changed to True
105109
)

eval_protocol/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ class EvaluationRow(BaseModel):
552552
)
553553

554554
execution_metadata: ExecutionMetadata = Field(
555-
default_factory=lambda: ExecutionMetadata(),
555+
default_factory=lambda: ExecutionMetadata(run_id=None),
556556
description="Metadata about the execution of the evaluation.",
557557
)
558558

eval_protocol/playback_policy.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,16 @@ async def __call__(
243243

244244
if messages is None:
245245
# No more recorded actions - signal early termination
246-
return [
247-
MCPToolCall(
248-
"_playback_terminate",
249-
{"reason": "no_more_recorded_actions"},
250-
)
251-
]
246+
return (
247+
[
248+
MCPToolCall(
249+
"_playback_terminate",
250+
{"reason": "no_more_recorded_actions"},
251+
)
252+
],
253+
None,
254+
None,
255+
)
252256

253257
# Return the recorded tool call
254258
return self._extract_tool_call_from_messages(messages, env_index), None, None

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from mcp.types import CallToolResult, TextContent
88
from openai import NOT_GIVEN, NotGiven
9-
from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam
9+
from openai.types.chat import ChatCompletionContentPartTextParam
1010
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
1111

1212
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
@@ -166,13 +166,16 @@ async def _execute_tool_call(
166166
"""
167167
assert self.mcp_client is not None, "MCP client is not initialized"
168168
tool_result = await self.mcp_client.call_tool(tool_name, tool_args_dict)
169-
content = self._get_content_from_tool_result(tool_result)
169+
# Accept string errors from client and normalize to text content
170+
content = self._get_content_from_tool_result(tool_result) # type: ignore[arg-type]
170171
return tool_call_id, content
171172

172-
def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[TextContent]:
173+
def _get_content_from_tool_result(self, tool_result: CallToolResult | str) -> List[TextContent]:
173174
if getattr(tool_result, "structuredContent", None):
174175
return [TextContent(text=json.dumps(tool_result.structuredContent), type="text")]
175176
normalized: List[TextContent] = []
177+
if isinstance(tool_result, str):
178+
return [TextContent(text=tool_result, type="text")]
176179
for content in getattr(tool_result, "content", []) or []:
177180
if isinstance(content, TextContent):
178181
normalized.append(content)

0 commit comments

Comments (0)