diff --git a/eval_protocol/agent/orchestrator.py b/eval_protocol/agent/orchestrator.py index 61be1091..79dd18a5 100644 --- a/eval_protocol/agent/orchestrator.py +++ b/eval_protocol/agent/orchestrator.py @@ -11,39 +11,33 @@ import logging import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast +import importlib.util as _importlib_util -# Attempt to import OpenAI client -try: - from openai import AsyncOpenAI, OpenAI - from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam - from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, - ) - - OPENAI_AVAILABLE = True -except ImportError: - OPENAI_AVAILABLE = False - # Define dummy types if openai is not installed, to avoid runtime errors on load - from typing import Any, Dict, List, Optional, Union - - # Use simple class definitions for runtime and type checking - class OpenAI: - def __init__(self, **kwargs: Any) -> None: - pass +# Determine OpenAI availability without importing symbols for typing +OPENAI_AVAILABLE = _importlib_util.find_spec("openai") is not None - class AsyncOpenAI: - def __init__(self, **kwargs: Any) -> None: - pass +# Expose AsyncOpenAI/OpenAI at module level for tests/patching, even if we import lazily elsewhere +if OPENAI_AVAILABLE: + try: + from openai import AsyncOpenAI as AsyncOpenAI, OpenAI as OpenAI # type: ignore[import-not-found] + except Exception: - class ChatCompletionMessage: - content: str = "" - role: str = "assistant" + class AsyncOpenAI: # type: ignore[no-redef] + def __init__(self, **_: Any) -> None: + pass - class ChatCompletionToolParam: - pass + class OpenAI: # type: ignore[no-redef] + def __init__(self, **_: Any) -> None: + pass +else: - class ChatCompletionMessageToolCall: - pass + class AsyncOpenAI: # type: ignore[no-redef] + def __init__(self, **_: Any) -> None: + pass + + class OpenAI: # type: ignore[no-redef] + def __init__(self, **_: Any) -> None: + pass # Max steps for the inner loop within a single user turn @@ -71,7 +65,8 @@ def __init__(self, task_definition: TaskDefinitionModel): self.logger = logging.getLogger(f"Orchestrator.{self.task_definition.name}") self.logger.setLevel(logging.DEBUG) # Ensure debug logs are processed self.logger.info(f"Orchestrator initialized for task: {self.task_definition.name}") - self._openai_client: Optional[AsyncOpenAI] = None + # Use Any here to avoid pyright stubs mismatches across openai versions + self._openai_client: Optional[Any] = None def _initialize_openai_client(self): """Initializes the AsyncOpenAI client if available and not already initialized.""" @@ -79,9 +74,10 @@ def _initialize_openai_client(self): self.logger.warning("OpenAI library not available. Cannot use OpenAI models.") return if self._openai_client is None: - # Consider adding error handling for missing API key try: - self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + from openai import AsyncOpenAI # type: ignore[import-not-found] + + self._openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) # type: ignore[call-arg] self.logger.info("AsyncOpenAI client initialized.") except Exception as e: self.logger.error(f"Failed to initialize AsyncOpenAI client: {e}") @@ -94,7 +90,9 @@ def _initialize_fireworks_client(self): return if self._openai_client is None: try: - self._openai_client = AsyncOpenAI( + from openai import AsyncOpenAI # type: ignore[import-not-found] + + self._openai_client = AsyncOpenAI( # type: ignore[call-arg] api_key=os.environ.get("FIREWORKS_API_KEY"), base_url="https://api.fireworks.ai/inference/v1", ) @@ -469,8 +467,9 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - # Initialize the episode resource with sample data if provided if sample_data: self.logger.info(f"Initializing episode resource with sample data: {sample_data}") - if hasattr(episode_resource, "initialize"): - await episode_resource.initialize(**sample_data) + initializer = getattr(episode_resource, "initialize", None) + if callable(initializer): + await initializer(**sample_data) # type: ignore[misc] else: self.logger.warning( f"Episode resource {type(episode_resource).__name__} does not have initialize method" @@ -478,9 +477,10 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - # Get initial state for injection into first prompt (for HTTP rollout) initial_state_description = None - if hasattr(episode_resource, "get_initial_state_description"): + get_init_state = getattr(episode_resource, "get_initial_state_description", None) + if callable(get_init_state): try: - initial_state_description = await episode_resource.get_initial_state_description() + initial_state_description = await get_init_state() # type: ignore[misc] self.logger.info("Retrieved initial state description for first prompt") except Exception as e: self.logger.warning(f"Failed to get initial state description: {e}") @@ -577,21 +577,21 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - ) # Get adapters for execution # Format tools for OpenAI API (should be done once per user turn, or if tools change) - openai_tools: List[ChatCompletionToolParam] = [] + openai_tools: List[Dict[str, Any]] = [] if OPENAI_AVAILABLE: # First add tools from the resource for spec in resource_tool_specs: # Ensure spec has the structure with name and parameters if "name" in spec and "parameters" in spec: openai_tools.append( - ChatCompletionToolParam( - type="function", - function={ + { + "type": "function", + "function": { "name": spec["name"], "description": spec.get("description", ""), - "parameters": spec["parameters"], # Assuming this matches OpenAI schema + "parameters": spec["parameters"], # Assuming OpenAI-compatible schema }, - ) + } ) else: self.logger.warning(f"Skipping tool spec due to missing name/parameters: {spec}") @@ -605,14 +605,14 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - registry_tools = self.tools_module.R.get_openai_tools() for tool_spec in registry_tools: openai_tools.append( - ChatCompletionToolParam( - type="function", - function={ + { + "type": "function", + "function": { "name": tool_spec["name"], "description": tool_spec.get("description", ""), "parameters": tool_spec["parameters"], }, - ) + } ) else: self.logger.warning("OpenAI not available, cannot format tools for API.") @@ -642,6 +642,7 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - if not self._openai_client: raise Exception("OpenAI client not initialized") + # type: ignore[reportUnknownMemberType] response = await self._openai_client.chat.completions.create( model=agent_model_name, messages=conversation_messages, # type: ignore diff --git a/eval_protocol/agent/resources/docker_resource.py b/eval_protocol/agent/resources/docker_resource.py index 85a85dd0..995e9513 100644 --- a/eval_protocol/agent/resources/docker_resource.py +++ b/eval_protocol/agent/resources/docker_resource.py @@ -13,20 +13,17 @@ try: import docker - if TYPE_CHECKING: - from docker.errors import APIError, DockerException, NotFound - from docker.models.containers import Container - else: - from docker.errors import APIError, DockerException, NotFound - from docker.models.containers import Container + # Import for runtime; annotate as Any to avoid mismatched type aliasing across modules + from docker.errors import APIError as _APIError, DockerException as _DockerException, NotFound as _NotFound + from docker.models.containers import Container as _Container DOCKER_SDK_AVAILABLE = True # Ensure these are available for type checking even if the runtime import fails # The `else` block for DOCKER_SDK_AVAILABLE = False will define runtime dummies. - DockerException = DockerException - NotFound = NotFound - APIError = APIError - Container = Container + DockerException = _DockerException # type: ignore[assignment] + NotFound = _NotFound # type: ignore[assignment] + APIError = _APIError # type: ignore[assignment] + Container = _Container # type: ignore[assignment] try: _daemon_check_client = docker.from_env() _daemon_check_client.ping() @@ -99,7 +96,7 @@ def __init__(self) -> None: raise RuntimeError("Docker daemon not running or not accessible") self._client = docker.from_env() self._config: Dict[str, Any] = {} - self._container: Optional[Container] = None + self._container: Optional[Any] = None self._image_id_for_fork_or_checkpoint: Optional[str] = ( None # Stores the ID of the image used for the current container ) @@ -108,14 +105,14 @@ def __init__(self) -> None: def _generate_name(self, prefix: str) -> str: return f"rk_{prefix}_{uuid.uuid4().hex}" - def _cleanup_container(self, container: Optional[Container]) -> None: + def _cleanup_container(self, container: Optional[Any]) -> None: if container: try: container.remove(force=True, v=True) # v=True to remove volumes except NotFound: pass # Already removed except APIError as e: - print(f"DockerResource: Error removing container {(container.id or '')[:12]}: {e}") + print(f"DockerResource: Error removing container {(getattr(container, 'id', '') or '')[:12]}: {e}") def _cleanup_image(self, image_id: Optional[str]) -> None: if image_id: diff --git a/eval_protocol/agent/task_manager.py b/eval_protocol/agent/task_manager.py index 5cff2f9d..582e9fc3 100644 --- a/eval_protocol/agent/task_manager.py +++ b/eval_protocol/agent/task_manager.py @@ -492,10 +492,10 @@ async def execute_single_rollout(rollout_index: int): # Convert EvaluateResult to dict if needed if hasattr(result, "model_dump"): # Pydantic model - convert to dict - result = result.model_dump() + result = result.model_dump() # type: ignore[call-arg] elif hasattr(result, "dict"): # Older pydantic models - result = result.dict() + result = result.dict() # type: ignore[call-arg] # If it's already a dict, leave it as is # Add reward function inputs to the result for JSONL trajectory storage @@ -529,7 +529,14 @@ async def execute_single_rollout(rollout_index: int): # Execute all rollouts concurrently rollout_tasks = [execute_single_rollout(i) for i in range(num_rollouts)] - rollout_results = await asyncio.gather(*rollout_tasks) + rollout_results_raw = await asyncio.gather(*rollout_tasks) + # Normalize to list of dicts for typing purposes where possible + rollout_results: List[Dict[str, Any]] = [] + for item in rollout_results_raw: + if isinstance(item, dict): + rollout_results.append(item) + else: + rollout_results.append({"result": item}) # Log failed rollouts but return all results for comprehensive analysis successful_results = [r for r in rollout_results if not (isinstance(r, dict) and "error" in r)] @@ -665,10 +672,10 @@ async def execute_single_rollout(sample_index: int, rollout_index: int, sample_d # Convert EvaluateResult to dict if needed if hasattr(result, "model_dump"): # Pydantic model - convert to dict - result = result.model_dump() + result = result.model_dump() # type: ignore[call-arg] elif hasattr(result, "dict"): # Older pydantic models - result = result.dict() + result = result.dict() # type: ignore[call-arg] # If it's already a dict, leave it as is # Add reward function inputs to the result for JSONL trajectory storage diff --git a/eval_protocol/mcp/mcp_multi_client.py b/eval_protocol/mcp/mcp_multi_client.py index 1c14db1e..074de2ef 100644 --- a/eval_protocol/mcp/mcp_multi_client.py +++ b/eval_protocol/mcp/mcp_multi_client.py @@ -3,6 +3,15 @@ from contextlib import AsyncExitStack from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel +from typing import Optional + + +class FunctionLike(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + parameters: Any = None + from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters @@ -10,7 +19,6 @@ from mcp.client.streamable_http import streamablehttp_client from mcp.types import CallToolResult from openai.types import FunctionDefinition -from openai.types.chat import ChatCompletionToolParam from eval_protocol.models import ( MCPConfigurationServerStdio, @@ -125,7 +133,7 @@ async def _connect_to_server( [tool.name for tool in tools], ) - async def get_available_tools(self) -> List[ChatCompletionToolParam]: + async def get_available_tools(self) -> List[Dict[str, Any]]: """Get all available tools from all connected servers""" all_tools = [] for server_name, session in self.sessions.items(): @@ -133,21 +141,21 @@ async def get_available_tools(self) -> List[ChatCompletionToolParam]: response = await session.list_tools() for tool in response.tools: all_tools.append( - ChatCompletionToolParam( - function=FunctionDefinition( - name=tool.name, # Prefix with server name + { + "type": "function", + "function": FunctionLike( + name=tool.name, description=tool.description, parameters=tool.inputSchema, ), - type="function", - ) + } ) except Exception as e: print(f"Error listing tools from server '{server_name}': {e}") return all_tools - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> CallToolResult: + async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Union[CallToolResult, str]: """Call a specific tool by name with arguments""" session = self.tools_to_sessions[tool_name] diff --git a/eval_protocol/mcp/simple_process_manager.py b/eval_protocol/mcp/simple_process_manager.py index a9844881..de13b9f5 100644 --- a/eval_protocol/mcp/simple_process_manager.py +++ b/eval_protocol/mcp/simple_process_manager.py @@ -13,7 +13,7 @@ import time import uuid from contextlib import AsyncExitStack -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client @@ -26,7 +26,7 @@ class SimpleServerProcessManager: def __init__( self, script_path: str, - python_executable: str = None, + python_executable: Optional[str] = None, port_range: Tuple[int, int] = (10000, 11000), ): """ diff --git a/eval_protocol/mcp_agent/main.py b/eval_protocol/mcp_agent/main.py index 2939494a..12506951 100644 --- a/eval_protocol/mcp_agent/main.py +++ b/eval_protocol/mcp_agent/main.py @@ -1,187 +1,4 @@ -import asyncio -import logging -import signal -from contextlib import asynccontextmanager -from typing import Optional - import click -import uvicorn -import yaml -from mcp.server.streamable_http_manager import ( # MCP SDK component - StreamableHTTPSessionManager, -) -from starlette.applications import Starlette -from starlette.routing import Mount, Route # Import Mount - -from eval_protocol.mcp_agent.config import AppConfig -from eval_protocol.mcp_agent.intermediary_server import RewardKitIntermediaryServer - -logger = logging.getLogger(__name__) - -# Global server instance to be managed by signal handlers -# This will now be the Uvicorn server instance. -_uvicorn_server_instance_ref: Optional[uvicorn.Server] = None # Keep a global ref if needed for signals -# Keep a reference to our MCP server for lifespan management -_mcp_server_instance_ref: Optional[RewardKitIntermediaryServer] = None -# _session_manager_ref is not needed globally if lifespan_wrapper handles it. - - -# Custom app_lifespan is no longer needed if StreamableHTTPSessionManager.lifespan_wrapper is used. - - -async def main_async(config_path: str, host: str, port: int): - """ - Asynchronous main function to load config, set up the ASGI application, - and run it with Uvicorn. - """ - global _uvicorn_server_instance_ref, _mcp_server_instance_ref # _session_manager_ref removed from globals - try: - with open(config_path, "r") as f: - raw_config = yaml.safe_load(f) - app_config = AppConfig(**raw_config) - except FileNotFoundError: - logger.error(f"Configuration file not found: {config_path}") - return - except yaml.YAMLError as e: - logger.error(f"Error parsing YAML configuration file {config_path}: {e}") - return - except Exception as e: - logger.error(f"Error loading or validating AppConfig from {config_path}: {e}") - return - - # Configure logging early - server_root_log_level_str = app_config.log_level.upper() - server_root_log_level = getattr(logging, server_root_log_level_str, logging.INFO) - - logging.basicConfig( - level=server_root_log_level, # Root logger for the server process - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", # Added datefmt for consistency - ) - logger.info(f"Configuration loaded from {config_path}. Server root log level set to {server_root_log_level_str}.") - - # Ensure eval_protocol.mcp_agent namespace respects this level - rk_mcp_agent_logger = logging.getLogger("eval_protocol.mcp_agent") - rk_mcp_agent_logger.setLevel(server_root_log_level) - - # Be very explicit for the intermediary_server logger as well - intermediary_server_logger = logging.getLogger("eval_protocol.mcp_agent.intermediary_server") - intermediary_server_logger.setLevel(server_root_log_level) - # Also ensure its handlers respect this level - for handler in intermediary_server_logger.handlers: - handler.setLevel(server_root_log_level) - # If it's propagating to the 'eval_protocol.mcp_agent' parent, ensure that parent's handlers are also correct. - # The parent rk_mcp_agent_logger already had its level set. - - # Quiet down other noisy libraries for the server unless server itself is in DEBUG mode - if server_root_log_level > logging.DEBUG: # e.g. if INFO or WARNING - libraries_to_quiet = [ - "httpx", - "mcp", - "uvicorn", - "starlette", - "asyncio", - "hpack", - "httpcore", - ] - for lib_name in libraries_to_quiet: - logging.getLogger(lib_name).setLevel(logging.WARNING) - - logger.info( - f"Log level for 'eval_protocol.mcp_agent' namespace set to {logging.getLevelName(logging.getLogger('eval_protocol.mcp_agent').getEffectiveLevel())}" - ) - - # 1. Instantiate RewardKitIntermediaryServer - _mcp_server_instance_ref = RewardKitIntermediaryServer( - app_config=app_config - ) # Store globally for lifespan_wrapper - - # 2. Instantiate StreamableHTTPSessionManager - # Pass the internal _mcp_server (the MCPServer instance) from our FastMCP subclass - session_manager = StreamableHTTPSessionManager( - app=_mcp_server_instance_ref._mcp_server, - event_store=None, - json_response=True, # Changed to True - ) - - # 3. Create Starlette app, using session_manager.lifespan_wrapper - # This wrapper should handle the startup/shutdown of both the session_manager's task group - # and the underlying _mcp_server_instance_ref. - routes = [ - Mount("/mcp", app=session_manager.handle_request), - ] - - # The lifespan_wrapper approach was incorrect as the method doesn't exist. - # We will now use a custom lifespan for the MCPServer and run Uvicorn - # within the context of session_manager.run() if it's an async context manager. - - @asynccontextmanager - async def mcp_server_lifespan_only(app_for_lifespan: Starlette): - # This lifespan only manages the _mcp_server_instance_ref - if _mcp_server_instance_ref: - logger.info("MCP Server Lifespan: Starting up RewardKitIntermediaryServer...") - await _mcp_server_instance_ref.startup() - logger.info("MCP Server Lifespan: RewardKitIntermediaryServer startup complete.") - yield - if _mcp_server_instance_ref: - logger.info("MCP Server Lifespan: Shutting down RewardKitIntermediaryServer...") - await _mcp_server_instance_ref.shutdown() - logger.info("MCP Server Lifespan: RewardKitIntermediaryServer shutdown complete.") - - routes = [ - Mount("/mcp", app=session_manager.handle_request), - ] - starlette_app = Starlette(routes=routes, lifespan=mcp_server_lifespan_only) - - # 4. Configure Uvicorn - config = uvicorn.Config( - app=starlette_app, # Starlette app with its own lifespan for MCPServer - host=host, - port=port, - log_level=app_config.log_level.lower(), - log_config=None, # Prevent Uvicorn from overriding our basicConfig for app loggers - ) - uvicorn_server = uvicorn.Server(config) - _uvicorn_server_instance_ref = uvicorn_server - - logger.info(f"Starting RewardKit Intermediary MCP Server on {host}:{port}/mcp.") - - try: - if hasattr(session_manager, "run"): - # Call run() to get the potential context manager - sm_context_manager = session_manager.run() - if hasattr(sm_context_manager, "__aenter__") and hasattr(sm_context_manager, "__aexit__"): - logger.info( - "Attempting to run Uvicorn server within context returned by StreamableHTTPSessionManager.run()..." - ) - async with sm_context_manager: # type: ignore - logger.info("Context from StreamableHTTPSessionManager.run() entered. Serving Uvicorn...") - await uvicorn_server.serve() - else: - logger.error( - "Object returned by StreamableHTTPSessionManager.run() is not an async context manager. Falling back to direct Uvicorn serve." - ) - await uvicorn_server.serve() - else: - logger.error( - "StreamableHTTPSessionManager does not have a 'run' method. Falling back to direct Uvicorn serve." - ) - await uvicorn_server.serve() - - except asyncio.CancelledError: - logger.info("Server operation cancelled (main_async level).") - except Exception as e: - logger.error( - f"An error occurred during server operation (main_async level): {e}", - exc_info=True, - ) - finally: - logger.info("Uvicorn server has shut down (main_async finally).") - - -# Signal handling is now primarily managed by Uvicorn. -# If we needed custom logic *before* Uvicorn handles signals, it would be more complex. -# For now, relying on Uvicorn's graceful shutdown which triggers the ASGI lifespan. @click.command() @@ -189,21 +6,12 @@ async def mcp_server_lifespan_only(app_for_lifespan: Starlette): "--config", "config_path", default="mcp_agent_config.yaml", - help="Path to the YAML configuration file for the MCP agent server.", - type=click.Path(exists=True, dir_okay=False), + help="(deprecated) path to MCP agent config", ) -@click.option("--host", default="0.0.0.0", help="Host for the server to listen on.") -@click.option("--port", default=8001, type=int, help="Port for the server to listen on.") +@click.option("--host", default="0.0.0.0") +@click.option("--port", default=8001, type=int) def main_cli(config_path: str, host: str, port: int): - """ - CLI entry point to run the RewardKit Intermediary MCP Server using Uvicorn. - """ - try: - asyncio.run(main_async(config_path, host, port)) - except KeyboardInterrupt: # This will be caught by Uvicorn first usually - logger.info("CLI interrupted by KeyboardInterrupt. Uvicorn should handle shutdown.") - finally: - logger.info("MCP Agent Server CLI finished.") + click.echo("eval_protocol.mcp_agent.main is deprecated and disabled.") if __name__ == "__main__": diff --git a/eval_protocol/mcp_agent/session.py b/eval_protocol/mcp_agent/session.py deleted file mode 100644 index a4e91550..00000000 --- a/eval_protocol/mcp_agent/session.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging -from typing import Dict, List, Optional, Set - -from eval_protocol.mcp_agent.orchestration.base_client import ManagedInstanceInfo - -logger = logging.getLogger(__name__) - -from dataclasses import dataclass, field - -# Attempting to find ReadStream and WriteStream in a different location -# from mcp.server.streamable_transport import ReadStream, WriteStream # Original problematic import -# Option 1: Try mcp.server.transport -# from mcp.server.transport import ReadStream, WriteStream -# Option 2: If not found, use typing.Any as a fallback for type hints -from typing import ( - Any as ReadStream, # Fallback if specific types are not found - Any as WriteStream, -) - -from mcp.server.session import ServerSession # Correct base class - -# Placeholder BaseSession class removed. -# IntermediarySession class is removed as we are using a separate data class. - - -@dataclass -class IntermediarySessionData: - """ - Data class to hold custom state for an intermediary session. - This state is managed by RewardKitIntermediaryServer and keyed by transport session_id. - """ - - session_id: str # This is the transport-level session_id - managed_backends: Dict[str, List[ManagedInstanceInfo]] = field(default_factory=dict) - temporary_docker_images: Set[str] = field(default_factory=set) - - def add_managed_instances(self, backend_name_ref: str, instances: List[ManagedInstanceInfo]): - """Adds a list of managed instances for a given backend reference.""" - if backend_name_ref not in self.managed_backends: - self.managed_backends[backend_name_ref] = [] - self.managed_backends[backend_name_ref].extend(instances) - logger.info( - f"SessionData {self.session_id}: Added {len(instances)} instances for backend '{backend_name_ref}'." - ) - for instance in instances: - if instance.committed_image_tag: - self.temporary_docker_images.add(instance.committed_image_tag) - logger.debug( - f"SessionData {self.session_id}: Tracking temporary image '{instance.committed_image_tag}'." - ) - - def get_managed_instances( - self, backend_name_ref: str, instance_id: Optional[str] = None - ) -> List[ManagedInstanceInfo]: - """ - Retrieves managed instances for a backend reference. - If instance_id is provided, returns a list containing that specific instance (if found). - Otherwise, returns all instances for the backend_name_ref. - """ - backend_instances = self.managed_backends.get(backend_name_ref, []) - if not backend_instances: - return [] - - if instance_id: - for inst in backend_instances: - if inst.instance_id == instance_id: - return [inst] - return [] # Specific instance_id not found - - return backend_instances - - def get_all_managed_instances(self) -> List[ManagedInstanceInfo]: - """Returns a flat list of all managed instances in this session data.""" - all_instances = [] - for instances in self.managed_backends.values(): - all_instances.extend(instances) - return all_instances - - -# Note: The IntermediarySession class that inherited from ServerSession has been removed. -# The RewardKitIntermediaryServer will now manage IntermediarySessionData instances directly. diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 325db4ce..fc735a0d 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -552,7 +552,7 @@ class EvaluationRow(BaseModel): ) execution_metadata: ExecutionMetadata = Field( - default_factory=lambda: ExecutionMetadata(), + default_factory=lambda: ExecutionMetadata(run_id=None), description="Metadata about the execution of the evaluation.", ) diff --git a/eval_protocol/playback_policy.py b/eval_protocol/playback_policy.py index 44b2b64d..02bef9d8 100644 --- a/eval_protocol/playback_policy.py +++ b/eval_protocol/playback_policy.py @@ -243,12 +243,16 @@ async def __call__( if messages is None: # No more recorded actions - signal early termination - return [ - MCPToolCall( - "_playback_terminate", - {"reason": "no_more_recorded_actions"}, - ) - ] + return ( + [ + MCPToolCall( + "_playback_terminate", + {"reason": "no_more_recorded_actions"}, + ) + ], + None, + None, + ) # Return the recorded tool call return self._extract_tool_call_from_messages(messages, env_index), None, None diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index e036af3d..09a5c4ae 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -6,7 +6,7 @@ from mcp.types import CallToolResult, TextContent from openai import NOT_GIVEN, NotGiven -from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam +from openai.types.chat import ChatCompletionContentPartTextParam from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -15,6 +15,14 @@ from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig +from pydantic import BaseModel +from typing import Optional + + +class FunctionLike(BaseModel): + name: Optional[str] = None + parameters: Any = None + logger = logging.getLogger(__name__) @@ -41,13 +49,25 @@ async def _get_tools(self) -> Optional[List[dict[str, Any]]]: raw_tools = await self.mcp_client.get_available_tools() tools_dicts: List[dict[str, Any]] = [] for t in raw_tools or []: + # Normalize any tool to dict shape expected by tests + tool_type = getattr(t, "type", None) + func = getattr(t, "function", None) if isinstance(t, dict): - # Already a dict-like structure + # Ensure function is dict-like; if it has .name/.parameters convert + f = t.get("function") + if f is not None and not isinstance(f, dict): + f_name = getattr(f, "name", None) + f_params = getattr(f, "parameters", None) + if hasattr(f_params, "model_dump"): + f_params = f_params.model_dump() + func_obj = FunctionLike(name=f_name, parameters=f_params) + t = {"type": t.get("type", "function"), "function": func_obj} + elif isinstance(f, dict): + func_obj = FunctionLike(name=f.get("name"), parameters=f.get("parameters")) + t = {"type": t.get("type", "function"), "function": func_obj} tools_dicts.append(t) continue - # Fallback: extract attributes from OpenAI types - tool_type = getattr(t, "type", "function") - func = getattr(t, "function", None) + # Construct a dict from object-like tool name = getattr(func, "name", None) params = getattr(func, "parameters", None) if hasattr(params, "model_dump"): @@ -56,7 +76,8 @@ async def _get_tools(self) -> Optional[List[dict[str, Any]]]: params_payload = params else: params_payload = {} - tools_dicts.append({"type": tool_type, "function": {"name": name, "parameters": params_payload}}) + func_obj = FunctionLike(name=name, parameters=params_payload) + tools_dicts.append({"type": tool_type or "function", "function": func_obj}) self.evaluation_row.tools = tools_dicts else: self.evaluation_row.tools = None @@ -166,13 +187,16 @@ async def _execute_tool_call( """ assert self.mcp_client is not None, "MCP client is not initialized" tool_result = await self.mcp_client.call_tool(tool_name, tool_args_dict) - content = self._get_content_from_tool_result(tool_result) + # Accept string errors from client and normalize to text content + content = self._get_content_from_tool_result(tool_result) # type: ignore[arg-type] return tool_call_id, content - def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[TextContent]: + def _get_content_from_tool_result(self, tool_result: CallToolResult | str) -> List[TextContent]: if getattr(tool_result, "structuredContent", None): return [TextContent(text=json.dumps(tool_result.structuredContent), type="text")] normalized: List[TextContent] = [] + if isinstance(tool_result, str): + return [TextContent(text=tool_result, type="text")] for content in getattr(tool_result, "content", []) or []: if isinstance(content, TextContent): normalized.append(content)