Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 47 additions & 46 deletions eval_protocol/agent/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,33 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast
import importlib.util as _importlib_util

# Attempt to import OpenAI client
try:
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)

OPENAI_AVAILABLE = True
except ImportError:
OPENAI_AVAILABLE = False
# Define dummy types if openai is not installed, to avoid runtime errors on load
from typing import Any, Dict, List, Optional, Union

# Use simple class definitions for runtime and type checking
class OpenAI:
def __init__(self, **kwargs: Any) -> None:
pass
# Determine OpenAI availability without importing symbols for typing
OPENAI_AVAILABLE = _importlib_util.find_spec("openai") is not None

class AsyncOpenAI:
def __init__(self, **kwargs: Any) -> None:
pass
# Expose AsyncOpenAI/OpenAI at module level for tests/patching, even if we import lazily elsewhere
if OPENAI_AVAILABLE:
try:
from openai import AsyncOpenAI as AsyncOpenAI, OpenAI as OpenAI # type: ignore[import-not-found]
except Exception:

class ChatCompletionMessage:
content: str = ""
role: str = "assistant"
class AsyncOpenAI: # type: ignore[no-redef]
def __init__(self, **_: Any) -> None:
pass

class ChatCompletionToolParam:
pass
class OpenAI: # type: ignore[no-redef]
def __init__(self, **_: Any) -> None:
pass
else:

class ChatCompletionMessageToolCall:
pass
class AsyncOpenAI: # type: ignore[no-redef]
def __init__(self, **_: Any) -> None:
pass

class OpenAI: # type: ignore[no-redef]
def __init__(self, **_: Any) -> None:
pass


# Max steps for the inner loop within a single user turn
Expand Down Expand Up @@ -71,17 +65,19 @@ def __init__(self, task_definition: TaskDefinitionModel):
self.logger = logging.getLogger(f"Orchestrator.{self.task_definition.name}")
self.logger.setLevel(logging.DEBUG) # Ensure debug logs are processed
self.logger.info(f"Orchestrator initialized for task: {self.task_definition.name}")
self._openai_client: Optional[AsyncOpenAI] = None
# Use Any here to avoid pyright stubs mismatches across openai versions
self._openai_client: Optional[Any] = None

def _initialize_openai_client(self):
"""Initializes the AsyncOpenAI client if available and not already initialized."""
if not OPENAI_AVAILABLE:
self.logger.warning("OpenAI library not available. Cannot use OpenAI models.")
return
if self._openai_client is None:
# Consider adding error handling for missing API key
try:
self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
from openai import AsyncOpenAI # type: ignore[import-not-found]

self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) # type: ignore[call-arg]
self.logger.info("AsyncOpenAI client initialized.")
except Exception as e:
self.logger.error(f"Failed to initialize AsyncOpenAI client: {e}")
Expand All @@ -94,7 +90,9 @@ def _initialize_fireworks_client(self):
return
if self._openai_client is None:
try:
self._openai_client = AsyncOpenAI(
from openai import AsyncOpenAI # type: ignore[import-not-found]

self._openai_client = AsyncOpenAI( # type: ignore[call-arg]
api_key=os.environ.get("FIREWORKS_API_KEY"),
base_url="https://api.fireworks.ai/inference/v1",
)
Expand Down Expand Up @@ -469,18 +467,20 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
# Initialize the episode resource with sample data if provided
if sample_data:
self.logger.info(f"Initializing episode resource with sample data: {sample_data}")
if hasattr(episode_resource, "initialize"):
await episode_resource.initialize(**sample_data)
initializer = getattr(episode_resource, "initialize", None)
if callable(initializer):
await initializer(**sample_data) # type: ignore[misc]
else:
self.logger.warning(
f"Episode resource {type(episode_resource).__name__} does not have initialize method"
)

# Get initial state for injection into first prompt (for HTTP rollout)
initial_state_description = None
if hasattr(episode_resource, "get_initial_state_description"):
get_init_state = getattr(episode_resource, "get_initial_state_description", None)
if callable(get_init_state):
try:
initial_state_description = await episode_resource.get_initial_state_description()
initial_state_description = await get_init_state() # type: ignore[misc]
self.logger.info("Retrieved initial state description for first prompt")
except Exception as e:
self.logger.warning(f"Failed to get initial state description: {e}")
Expand Down Expand Up @@ -577,21 +577,21 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
) # Get adapters for execution

# Format tools for OpenAI API (should be done once per user turn, or if tools change)
openai_tools: List[ChatCompletionToolParam] = []
openai_tools: List[Dict[str, Any]] = []
if OPENAI_AVAILABLE:
# First add tools from the resource
for spec in resource_tool_specs:
# Ensure spec has the structure with name and parameters
if "name" in spec and "parameters" in spec:
openai_tools.append(
ChatCompletionToolParam(
type="function",
function={
{
"type": "function",
"function": {
"name": spec["name"],
"description": spec.get("description", ""),
"parameters": spec["parameters"], # Assuming this matches OpenAI schema
"parameters": spec["parameters"], # Assuming OpenAI-compatible schema
},
)
}
)
else:
self.logger.warning(f"Skipping tool spec due to missing name/parameters: {spec}")
Expand All @@ -605,14 +605,14 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
registry_tools = self.tools_module.R.get_openai_tools()
for tool_spec in registry_tools:
openai_tools.append(
ChatCompletionToolParam(
type="function",
function={
{
"type": "function",
"function": {
"name": tool_spec["name"],
"description": tool_spec.get("description", ""),
"parameters": tool_spec["parameters"],
},
)
}
)
else:
self.logger.warning("OpenAI not available, cannot format tools for API.")
Expand Down Expand Up @@ -642,6 +642,7 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) -
if not self._openai_client:
raise Exception("OpenAI client not initialized")

# type: ignore[reportUnknownMemberType]
response = await self._openai_client.chat.completions.create(
model=agent_model_name,
messages=conversation_messages, # type: ignore
Expand Down
23 changes: 10 additions & 13 deletions eval_protocol/agent/resources/docker_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@
try:
import docker

if TYPE_CHECKING:
from docker.errors import APIError, DockerException, NotFound
from docker.models.containers import Container
else:
from docker.errors import APIError, DockerException, NotFound
from docker.models.containers import Container
# Import for runtime; annotate as Any to avoid mismatched type aliasing across modules
from docker.errors import APIError as _APIError, DockerException as _DockerException, NotFound as _NotFound
from docker.models.containers import Container as _Container

DOCKER_SDK_AVAILABLE = True
# Ensure these are available for type checking even if the runtime import fails
# The `else` block for DOCKER_SDK_AVAILABLE = False will define runtime dummies.
DockerException = DockerException
NotFound = NotFound
APIError = APIError
Container = Container
DockerException = _DockerException # type: ignore[assignment]
NotFound = _NotFound # type: ignore[assignment]
APIError = _APIError # type: ignore[assignment]
Container = _Container # type: ignore[assignment]
try:
_daemon_check_client = docker.from_env()
_daemon_check_client.ping()
Expand Down Expand Up @@ -99,7 +96,7 @@ def __init__(self) -> None:
raise RuntimeError("Docker daemon not running or not accessible")
self._client = docker.from_env()
self._config: Dict[str, Any] = {}
self._container: Optional[Container] = None
self._container: Optional[Any] = None
self._image_id_for_fork_or_checkpoint: Optional[str] = (
None # Stores the ID of the image used for the current container
)
Expand All @@ -108,14 +105,14 @@ def __init__(self) -> None:
def _generate_name(self, prefix: str) -> str:
    """Return a unique resource name of the form ``rk_<prefix>_<32-char hex>``."""
    return f"rk_{prefix}_{uuid.uuid4().hex}"

def _cleanup_container(self, container: Optional[Any]) -> None:
    """Force-remove a Docker container and its volumes, tolerating absence.

    Args:
        container: Docker container object (annotated ``Any`` to avoid a hard
            dependency on docker-py type stubs); ``None`` is a no-op.
    """
    # Defect fixed: this span contained both the pre- and post-change diff
    # lines (duplicate signature and duplicate print); collapsed to the
    # post-change version, which uses Any and a safe getattr on .id.
    if container:
        try:
            # v=True also removes anonymous volumes attached to the container
            container.remove(force=True, v=True)
        except NotFound:
            pass  # Already removed
        except APIError as e:
            print(f"DockerResource: Error removing container {(getattr(container, 'id', '') or '')[:12]}: {e}")

def _cleanup_image(self, image_id: Optional[str]) -> None:
if image_id:
Expand Down
17 changes: 12 additions & 5 deletions eval_protocol/agent/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,10 @@
# Convert EvaluateResult to dict if needed
if hasattr(result, "model_dump"):
# Pydantic model - convert to dict
result = result.model_dump()
result = result.model_dump() # type: ignore[call-arg]
elif hasattr(result, "dict"):
# Older pydantic models
result = result.dict()
result = result.dict() # type: ignore[call-arg]
# If it's already a dict, leave it as is

# Add reward function inputs to the result for JSONL trajectory storage
Expand Down Expand Up @@ -529,7 +529,14 @@

# Execute all rollouts concurrently
rollout_tasks = [execute_single_rollout(i) for i in range(num_rollouts)]
rollout_results = await asyncio.gather(*rollout_tasks)
rollout_results_raw = await asyncio.gather(*rollout_tasks)
# Normalize to list of dicts for typing purposes where possible
rollout_results: List[Dict[str, Any]] = []
for item in rollout_results_raw:
if isinstance(item, dict):
rollout_results.append(item)
else:
rollout_results.append({"result": item})

# Log failed rollouts but return all results for comprehensive analysis
successful_results = [r for r in rollout_results if not (isinstance(r, dict) and "error" in r)]
Expand Down Expand Up @@ -665,10 +672,10 @@
# Convert EvaluateResult to dict if needed
if hasattr(result, "model_dump"):
# Pydantic model - convert to dict
result = result.model_dump()
result = result.model_dump() # type: ignore[call-arg]
elif hasattr(result, "dict"):
# Older pydantic models
result = result.dict()
result = result.dict() # type: ignore[call-arg]
# If it's already a dict, leave it as is

# Add reward function inputs to the result for JSONL trajectory storage
Expand All @@ -677,9 +684,9 @@

# Add sample metadata to the result
if isinstance(result, dict):
result["sample_data"] = sample_data

Check failure on line 687 in eval_protocol/agent/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "Dict[str, Any]" cannot be assigned to parameter "value" of type "str" in function "__setitem__"   "Dict[str, Any]" is not assignable to "str" (reportArgumentType)
result["sample_index"] = sample_index

Check failure on line 688 in eval_protocol/agent/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "int" cannot be assigned to parameter "value" of type "str" in function "__setitem__"   "int" is not assignable to "str" (reportArgumentType)
result["rollout_index"] = rollout_index

Check failure on line 689 in eval_protocol/agent/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "int" cannot be assigned to parameter "value" of type "str" in function "__setitem__"   "int" is not assignable to "str" (reportArgumentType)

score = result.get("score", "N/A") if isinstance(result, dict) else "N/A"
self.logger.info(
Expand Down Expand Up @@ -913,9 +920,9 @@
if chosen_dir is None:
chosen_dir = Path(".")

output_file = chosen_dir / f"trajectory_{task_id}_{timestamp}.jsonl"

Check failure on line 923 in eval_protocol/agent/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Type "Path" is not assignable to declared type "str | None"   Type "Path" is not assignable to type "str | None"     "Path" is not assignable to "str"     "Path" is not assignable to "None" (reportAssignmentType)

output_path = Path(output_file)

Check failure on line 925 in eval_protocol/agent/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "str | None" cannot be assigned to parameter "args" of type "StrPath" in function "__new__"   Type "str | None" is not assignable to type "StrPath"     Type "None" is not assignable to type "StrPath"       "None" is not assignable to "str"       "None" is incompatible with protocol "PathLike[str]"         "__fspath__" is not present (reportArgumentType)

try:
self.logger.info("=== TRAJECTORY SAVE DEBUG START ===")
Expand Down
24 changes: 16 additions & 8 deletions eval_protocol/mcp/mcp_multi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
from contextlib import AsyncExitStack
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
from typing import Optional


class FunctionLike(BaseModel):
    """Minimal pydantic stand-in for an OpenAI-style function/tool definition.

    Carries the name, description, and parameters of a callable tool without
    requiring the ``openai`` package's own types.
    """

    # Tool/function name as exposed to the model
    name: Optional[str] = None
    # Human-readable description of what the tool does
    description: Optional[str] = None
    # Presumably a JSON-schema-like dict for the tool's arguments — not validated here; TODO confirm
    parameters: Any = None


from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult
from openai.types import FunctionDefinition
from openai.types.chat import ChatCompletionToolParam

from eval_protocol.models import (
MCPConfigurationServerStdio,
Expand Down Expand Up @@ -125,29 +133,29 @@ async def _connect_to_server(
[tool.name for tool in tools],
)

async def get_available_tools(self) -> List[Dict[str, Any]]:
    """Get all available tools from all connected servers.

    Returns:
        A list of OpenAI-style tool dicts of the form
        ``{"type": "function", "function": FunctionLike(...)}``.
        Servers whose ``list_tools`` call fails are skipped with a
        printed warning rather than aborting the whole collection.
    """
    # Defect fixed: this span contained both the pre- and post-change diff
    # lines (two signatures; the old ChatCompletionToolParam construction
    # interleaved with the new dict form); collapsed to the post-change
    # version, which avoids importing OpenAI types.
    all_tools = []
    for server_name, session in self.sessions.items():
        try:
            response = await session.list_tools()
            for tool in response.tools:
                all_tools.append(
                    {
                        "type": "function",
                        "function": FunctionLike(
                            name=tool.name,
                            description=tool.description,
                            parameters=tool.inputSchema,
                        ),
                    }
                )
        except Exception as e:
            print(f"Error listing tools from server '{server_name}': {e}")

    return all_tools

async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> CallToolResult:
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Union[CallToolResult, str]:
"""Call a specific tool by name with arguments"""

session = self.tools_to_sessions[tool_name]
Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/mcp/simple_process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import time
import uuid
from contextlib import AsyncExitStack
from typing import Dict, Tuple
from typing import Dict, Tuple, Optional

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
Expand All @@ -26,7 +26,7 @@ class SimpleServerProcessManager:
def __init__(
self,
script_path: str,
python_executable: str = None,
python_executable: Optional[str] = None,
port_range: Tuple[int, int] = (10000, 11000),
):
"""
Expand Down
Loading
Loading