Skip to content

Commit fbb8d34

Browse files
benjibc and cursoragent
authored
Run precommit and fix type errors (#145)
* Refactor type handling and improve error resilience across modules Co-authored-by: bchen <bchen@fireworks.ai> * Refactor tool handling and deprecate mcp_agent main for better compatibility Co-authored-by: bchen <bchen@fireworks.ai> * Replace SimpleNamespace with Pydantic FunctionLike model for tool definitions Co-authored-by: bchen <bchen@fireworks.ai> --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com>
1 parent 3ec9a06 commit fbb8d34

File tree

10 files changed

+134
-366
lines changed

10 files changed

+134
-366
lines changed

eval_protocol/agent/orchestrator.py

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,33 @@
1111
import logging
1212
import os
1313
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast
14+
import importlib.util as _importlib_util
1415

15-
# Attempt to import OpenAI client
16-
try:
17-
from openai import AsyncOpenAI, OpenAI
18-
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
19-
from openai.types.chat.chat_completion_message_tool_call import (
20-
ChatCompletionMessageToolCall,
21-
)
22-
23-
OPENAI_AVAILABLE = True
24-
except ImportError:
25-
OPENAI_AVAILABLE = False
26-
# Define dummy types if openai is not installed, to avoid runtime errors on load
27-
from typing import Any, Dict, List, Optional, Union
28-
29-
# Use simple class definitions for runtime and type checking
30-
class OpenAI:
31-
def __init__(self, **kwargs: Any) -> None:
32-
pass
16+
# Determine OpenAI availability without importing symbols for typing
17+
OPENAI_AVAILABLE = _importlib_util.find_spec("openai") is not None
3318

34-
class AsyncOpenAI:
35-
def __init__(self, **kwargs: Any) -> None:
36-
pass
19+
# Expose AsyncOpenAI/OpenAI at module level for tests/patching, even if we import lazily elsewhere
20+
if OPENAI_AVAILABLE:
21+
try:
22+
from openai import AsyncOpenAI as AsyncOpenAI, OpenAI as OpenAI # type: ignore[import-not-found]
23+
except Exception:
3724

38-
class ChatCompletionMessage:
39-
content: str = ""
40-
role: str = "assistant"
25+
class AsyncOpenAI: # type: ignore[no-redef]
26+
def __init__(self, **_: Any) -> None:
27+
pass
4128

42-
class ChatCompletionToolParam:
43-
pass
29+
class OpenAI: # type: ignore[no-redef]
30+
def __init__(self, **_: Any) -> None:
31+
pass
32+
else:
4433

45-
class ChatCompletionMessageToolCall:
46-
pass
34+
class AsyncOpenAI: # type: ignore[no-redef]
35+
def __init__(self, **_: Any) -> None:
36+
pass
37+
38+
class OpenAI: # type: ignore[no-redef]
39+
def __init__(self, **_: Any) -> None:
40+
pass
4741

4842

4943
# Max steps for the inner loop within a single user turn
@@ -71,17 +65,19 @@ def __init__(self, task_definition: TaskDefinitionModel):
7165
self.logger = logging.getLogger(f"Orchestrator.{self.task_definition.name}")
7266
self.logger.setLevel(logging.DEBUG) # Ensure debug logs are processed
7367
self.logger.info(f"Orchestrator initialized for task: {self.task_definition.name}")
74-
self._openai_client: Optional[AsyncOpenAI] = None
68+
# Use Any here to avoid pyright stubs mismatches across openai versions
69+
self._openai_client: Optional[Any] = None
7570

7671
def _initialize_openai_client(self):
7772
"""Initializes the AsyncOpenAI client if available and not already initialized."""
7873
if not OPENAI_AVAILABLE:
7974
self.logger.warning("OpenAI library not available. Cannot use OpenAI models.")
8075
return
8176
if self._openai_client is None:
82-
# Consider adding error handling for missing API key
8377
try:
84-
self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
78+
from openai import AsyncOpenAI # type: ignore[import-not-found]
79+
80+
self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) # type: ignore[call-arg]
8581
self.logger.info("AsyncOpenAI client initialized.")
8682
except Exception as e:
8783
self.logger.error(f"Failed to initialize AsyncOpenAI client: {e}")
@@ -94,7 +90,9 @@ def _initialize_fireworks_client(self):
9490
return
9591
if self._openai_client is None:
9692
try:
97-
self._openai_client = AsyncOpenAI(
93+
from openai import AsyncOpenAI # type: ignore[import-not-found]
94+
95+
self._openai_client = AsyncOpenAI( # type: ignore[call-arg]
9896
api_key=os.environ.get("FIREWORKS_API_KEY"),
9997
base_url="https://api.fireworks.ai/inference/v1",
10098
)
@@ -469,18 +467,20 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
469467
# Initialize the episode resource with sample data if provided
470468
if sample_data:
471469
self.logger.info(f"Initializing episode resource with sample data: {sample_data}")
472-
if hasattr(episode_resource, "initialize"):
473-
await episode_resource.initialize(**sample_data)
470+
initializer = getattr(episode_resource, "initialize", None)
471+
if callable(initializer):
472+
await initializer(**sample_data) # type: ignore[misc]
474473
else:
475474
self.logger.warning(
476475
f"Episode resource {type(episode_resource).__name__} does not have initialize method"
477476
)
478477

479478
# Get initial state for injection into first prompt (for HTTP rollout)
480479
initial_state_description = None
481-
if hasattr(episode_resource, "get_initial_state_description"):
480+
get_init_state = getattr(episode_resource, "get_initial_state_description", None)
481+
if callable(get_init_state):
482482
try:
483-
initial_state_description = await episode_resource.get_initial_state_description()
483+
initial_state_description = await get_init_state() # type: ignore[misc]
484484
self.logger.info("Retrieved initial state description for first prompt")
485485
except Exception as e:
486486
self.logger.warning(f"Failed to get initial state description: {e}")
@@ -577,21 +577,21 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
577577
) # Get adapters for execution
578578

579579
# Format tools for OpenAI API (should be done once per user turn, or if tools change)
580-
openai_tools: List[ChatCompletionToolParam] = []
580+
openai_tools: List[Dict[str, Any]] = []
581581
if OPENAI_AVAILABLE:
582582
# First add tools from the resource
583583
for spec in resource_tool_specs:
584584
# Ensure spec has the structure with name and parameters
585585
if "name" in spec and "parameters" in spec:
586586
openai_tools.append(
587-
ChatCompletionToolParam(
588-
type="function",
589-
function={
587+
{
588+
"type": "function",
589+
"function": {
590590
"name": spec["name"],
591591
"description": spec.get("description", ""),
592-
"parameters": spec["parameters"], # Assuming this matches OpenAI schema
592+
"parameters": spec["parameters"], # Assuming OpenAI-compatible schema
593593
},
594-
)
594+
}
595595
)
596596
else:
597597
self.logger.warning(f"Skipping tool spec due to missing name/parameters: {spec}")
@@ -605,14 +605,14 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
605605
registry_tools = self.tools_module.R.get_openai_tools()
606606
for tool_spec in registry_tools:
607607
openai_tools.append(
608-
ChatCompletionToolParam(
609-
type="function",
610-
function={
608+
{
609+
"type": "function",
610+
"function": {
611611
"name": tool_spec["name"],
612612
"description": tool_spec.get("description", ""),
613613
"parameters": tool_spec["parameters"],
614614
},
615-
)
615+
}
616616
)
617617
else:
618618
self.logger.warning("OpenAI not available, cannot format tools for API.")
@@ -642,6 +642,7 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
642642
if not self._openai_client:
643643
raise Exception("OpenAI client not initialized")
644644

645+
# type: ignore[reportUnknownMemberType]
645646
response = await self._openai_client.chat.completions.create(
646647
model=agent_model_name,
647648
messages=conversation_messages, # type: ignore

eval_protocol/agent/resources/docker_resource.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,17 @@
1313
try:
1414
import docker
1515

16-
if TYPE_CHECKING:
17-
from docker.errors import APIError, DockerException, NotFound
18-
from docker.models.containers import Container
19-
else:
20-
from docker.errors import APIError, DockerException, NotFound
21-
from docker.models.containers import Container
16+
# Import for runtime; annotate as Any to avoid mismatched type aliasing across modules
17+
from docker.errors import APIError as _APIError, DockerException as _DockerException, NotFound as _NotFound
18+
from docker.models.containers import Container as _Container
2219

2320
DOCKER_SDK_AVAILABLE = True
2421
# Ensure these are available for type checking even if the runtime import fails
2522
# The `else` block for DOCKER_SDK_AVAILABLE = False will define runtime dummies.
26-
DockerException = DockerException
27-
NotFound = NotFound
28-
APIError = APIError
29-
Container = Container
23+
DockerException = _DockerException # type: ignore[assignment]
24+
NotFound = _NotFound # type: ignore[assignment]
25+
APIError = _APIError # type: ignore[assignment]
26+
Container = _Container # type: ignore[assignment]
3027
try:
3128
_daemon_check_client = docker.from_env()
3229
_daemon_check_client.ping()
@@ -99,7 +96,7 @@ def __init__(self) -> None:
9996
raise RuntimeError("Docker daemon not running or not accessible")
10097
self._client = docker.from_env()
10198
self._config: Dict[str, Any] = {}
102-
self._container: Optional[Container] = None
99+
self._container: Optional[Any] = None
103100
self._image_id_for_fork_or_checkpoint: Optional[str] = (
104101
None # Stores the ID of the image used for the current container
105102
)
@@ -108,14 +105,14 @@ def __init__(self) -> None:
108105
def _generate_name(self, prefix: str) -> str:
109106
return f"rk_{prefix}_{uuid.uuid4().hex}"
110107

111-
def _cleanup_container(self, container: Optional[Container]) -> None:
108+
def _cleanup_container(self, container: Optional[Any]) -> None:
112109
if container:
113110
try:
114111
container.remove(force=True, v=True) # v=True to remove volumes
115112
except NotFound:
116113
pass # Already removed
117114
except APIError as e:
118-
print(f"DockerResource: Error removing container {(container.id or '')[:12]}: {e}")
115+
print(f"DockerResource: Error removing container {(getattr(container, 'id', '') or '')[:12]}: {e}")
119116

120117
def _cleanup_image(self, image_id: Optional[str]) -> None:
121118
if image_id:

eval_protocol/agent/task_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,10 @@ async def execute_single_rollout(rollout_index: int):
492492
# Convert EvaluateResult to dict if needed
493493
if hasattr(result, "model_dump"):
494494
# Pydantic model - convert to dict
495-
result = result.model_dump()
495+
result = result.model_dump() # type: ignore[call-arg]
496496
elif hasattr(result, "dict"):
497497
# Older pydantic models
498-
result = result.dict()
498+
result = result.dict() # type: ignore[call-arg]
499499
# If it's already a dict, leave it as is
500500

501501
# Add reward function inputs to the result for JSONL trajectory storage
@@ -529,7 +529,14 @@ async def execute_single_rollout(rollout_index: int):
529529

530530
# Execute all rollouts concurrently
531531
rollout_tasks = [execute_single_rollout(i) for i in range(num_rollouts)]
532-
rollout_results = await asyncio.gather(*rollout_tasks)
532+
rollout_results_raw = await asyncio.gather(*rollout_tasks)
533+
# Normalize to list of dicts for typing purposes where possible
534+
rollout_results: List[Dict[str, Any]] = []
535+
for item in rollout_results_raw:
536+
if isinstance(item, dict):
537+
rollout_results.append(item)
538+
else:
539+
rollout_results.append({"result": item})
533540

534541
# Log failed rollouts but return all results for comprehensive analysis
535542
successful_results = [r for r in rollout_results if not (isinstance(r, dict) and "error" in r)]
@@ -665,10 +672,10 @@ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_d
665672
# Convert EvaluateResult to dict if needed
666673
if hasattr(result, "model_dump"):
667674
# Pydantic model - convert to dict
668-
result = result.model_dump()
675+
result = result.model_dump() # type: ignore[call-arg]
669676
elif hasattr(result, "dict"):
670677
# Older pydantic models
671-
result = result.dict()
678+
result = result.dict() # type: ignore[call-arg]
672679
# If it's already a dict, leave it as is
673680

674681
# Add reward function inputs to the result for JSONL trajectory storage

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,22 @@
33
from contextlib import AsyncExitStack
44
from dataclasses import dataclass
55
from typing import Any, Dict, List, Optional, Union
6+
from pydantic import BaseModel
7+
from typing import Optional
8+
9+
10+
class FunctionLike(BaseModel):
11+
name: Optional[str] = None
12+
description: Optional[str] = None
13+
parameters: Any = None
14+
615

716
from dotenv import load_dotenv
817
from mcp import ClientSession, StdioServerParameters
918
from mcp.client.stdio import stdio_client
1019
from mcp.client.streamable_http import streamablehttp_client
1120
from mcp.types import CallToolResult
1221
from openai.types import FunctionDefinition
13-
from openai.types.chat import ChatCompletionToolParam
1422

1523
from eval_protocol.models import (
1624
MCPConfigurationServerStdio,
@@ -125,29 +133,29 @@ async def _connect_to_server(
125133
[tool.name for tool in tools],
126134
)
127135

128-
async def get_available_tools(self) -> List[ChatCompletionToolParam]:
136+
async def get_available_tools(self) -> List[Dict[str, Any]]:
129137
"""Get all available tools from all connected servers"""
130138
all_tools = []
131139
for server_name, session in self.sessions.items():
132140
try:
133141
response = await session.list_tools()
134142
for tool in response.tools:
135143
all_tools.append(
136-
ChatCompletionToolParam(
137-
function=FunctionDefinition(
138-
name=tool.name, # Prefix with server name
144+
{
145+
"type": "function",
146+
"function": FunctionLike(
147+
name=tool.name,
139148
description=tool.description,
140149
parameters=tool.inputSchema,
141150
),
142-
type="function",
143-
)
151+
}
144152
)
145153
except Exception as e:
146154
print(f"Error listing tools from server '{server_name}': {e}")
147155

148156
return all_tools
149157

150-
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> CallToolResult:
158+
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Union[CallToolResult, str]:
151159
"""Call a specific tool by name with arguments"""
152160

153161
session = self.tools_to_sessions[tool_name]

eval_protocol/mcp/simple_process_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
import uuid
1515
from contextlib import AsyncExitStack
16-
from typing import Dict, Tuple
16+
from typing import Dict, Tuple, Optional
1717

1818
from mcp.client.session import ClientSession
1919
from mcp.client.streamable_http import streamablehttp_client
@@ -26,7 +26,7 @@ class SimpleServerProcessManager:
2626
def __init__(
2727
self,
2828
script_path: str,
29-
python_executable: str = None,
29+
python_executable: Optional[str] = None,
3030
port_range: Tuple[int, int] = (10000, 11000),
3131
):
3232
"""

0 commit comments

Comments
 (0)