diff --git a/eval_protocol/benchmarks/test_livebench_data_analysis.py b/eval_protocol/benchmarks/test_livebench_data_analysis.py index a54ccc2b..75dc4613 100644 --- a/eval_protocol/benchmarks/test_livebench_data_analysis.py +++ b/eval_protocol/benchmarks/test_livebench_data_analysis.py @@ -258,6 +258,7 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]): reader = _read_df_v1 if version == "v1" else _read_df_v2 gt_df = reader(output_fmt, ground_truth) + assert gt_df is not None, "GT dataframe is None" llm_clean = _clean_llm_output(llm_answer) llm_clean = _remove_initial_phrase(llm_clean) diff --git a/eval_protocol/benchmarks/test_tau_bench_airline.py b/eval_protocol/benchmarks/test_tau_bench_airline.py index c2b260ee..24417ddc 100644 --- a/eval_protocol/benchmarks/test_tau_bench_airline.py +++ b/eval_protocol/benchmarks/test_tau_bench_airline.py @@ -198,6 +198,7 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow: task = Task( id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler") ) # id and user_scenario are required for the Task type but not used in calculating reward + assert task.evaluation_criteria is not None, "Task evaluation criteria is None" if RewardType.DB in task.evaluation_criteria.reward_basis: env_reward_info = EnvironmentEvaluator.calculate_reward( diff --git a/eval_protocol/execution/pipeline.py b/eval_protocol/execution/pipeline.py index e644ba32..5cd7b4d9 100644 --- a/eval_protocol/execution/pipeline.py +++ b/eval_protocol/execution/pipeline.py @@ -212,6 +212,7 @@ async def _execute_standard_generation( if system_prompt_content: current_messages_for_rollout.append({"role": "system", "content": system_prompt_content}) current_messages_for_rollout.append({"role": "user", "content": user_query}) + assert self.model_client is not None, "at this point model client needs to be initialized" generation_output_std = await self.model_client.generate( messages=current_messages_for_rollout, diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index d8c13f3b..a6fcd53d 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -11,7 +11,7 @@ import logging import time from contextlib import AsyncExitStack -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast import httpx from mcp.client.session import ClientSession @@ -276,7 +276,10 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any try: # List available resources - this is where initial state should come from logger.debug(f"Session {session.session_id}: Discovering MCP resources for initial state...") - resources_response = await mcp_session.list_resources() + mcp_session_local = session._mcp_session + if mcp_session_local is None: + raise RuntimeError("Session not initialized while listing resources") + resources_response = await mcp_session_local.list_resources() resources = resources_response.resources if hasattr(resources_response, "resources") else [] logger.debug(f"Session {session.session_id}: Found {len(resources)} MCP resources") for resource in resources: @@ -303,7 +306,10 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any f"Session {session.session_id}: Reading initial state from resource: {initial_state_resource.uri}" ) - resource_content = await mcp_session.read_resource(initial_state_resource.uri) + mcp_session_for_read = session._mcp_session + if mcp_session_for_read is None: + raise RuntimeError("Session not initialized while reading resource") + resource_content = await mcp_session_for_read.read_resource(initial_state_resource.uri) # Handle the new ResourceContents format text_value = getattr(resource_content, "text", None) @@ -348,7 +354,10 @@ async def _get_initial_state_from_mcp_resource(self, session: MCPSession) -> Any f"Session {session.session_id}: About to call mcp_session.read_resource with fallback URI: {first_resource.uri}" ) - resource_content = await mcp_session.read_resource(first_resource.uri) + mcp_session_for_fallback_read = session._mcp_session + if mcp_session_for_fallback_read is None: + raise RuntimeError("Session not initialized while reading fallback resource") + resource_content = await mcp_session_for_fallback_read.read_resource(first_resource.uri) logger.debug( f"Session {session.session_id}: fallback read_resource returned type: {type(resource_content)}" diff --git a/eval_protocol/mcp_agent/intermediary_server.py b/eval_protocol/mcp_agent/intermediary_server.py deleted file mode 100644 index 368f0232..00000000 --- a/eval_protocol/mcp_agent/intermediary_server.py +++ /dev/null @@ -1,541 +0,0 @@ -import asyncio -import logging -import uuid -from typing import Any, Dict, List, Optional - -import anyio # Added for debugging cancel scopes and tasks -from mcp import types as mcp_types # Added for type hinting -from pydantic import BaseModel, Field - -from eval_protocol.mcp_agent.config import AppConfig, BackendServerConfig -from eval_protocol.mcp_agent.orchestration.base_client import ( - AbstractOrchestrationClient, - ManagedInstanceInfo, -) -from eval_protocol.mcp_agent.orchestration.local_docker_client import ( - LocalDockerOrchestrationClient, -) -from eval_protocol.mcp_agent.orchestration.remote_http_client import ( - RemoteHttpOrchestrationClient, -) -from eval_protocol.mcp_agent.session import IntermediarySessionData - -logger = logging.getLogger(__name__) -# logger.setLevel(logging.DEBUG) # Removed: Let level be set by main config - -from mcp.server.fastmcp.server import Context as FastMCPContext, FastMCP - -# RequestContext is not directly used by handlers anymore, mcp_ctx is. - - -# Backend initialization models (moved here to avoid separate backends module) -class BackendInitRequest(BaseModel): - backend_name_ref: str = Field( - ..., - description="The unique reference name of the backend configuration to use (must match one in AppConfig.backends).", - ) - num_instances: int = Field( - 1, - ge=1, - description="Number of instances of this backend to provision for the session.", - ) - template_details: Optional[Any] = Field( - None, - description="Backend-specific details for initializing stateful instances from a template.", - ) - - class Config: - extra = "forbid" - - -class BackendInitResult(BaseModel): - backend_name_ref: str - instances: List[ManagedInstanceInfo] - - -# Pydantic models for tool arguments -class InitializeSessionArgs(BaseModel): - backends: List[BackendInitRequest] - - -class CallBackendToolArgs(BaseModel): - rk_session_id: str = Field(..., description="The session ID obtained from initialize_session.") - backend_name_ref: str = Field(..., description="The reference name of the backend to target.") - instance_id: str = Field(..., description="The ID of the specific backend instance to use.") - tool_name: str = Field(..., description="The name of the tool to call on the backend instance.") - tool_args: Dict[str, Any] = Field(default_factory=dict, description="Arguments for the backend tool.") - - -class ListBackendToolsArgs(BaseModel): - rk_session_id: str = Field(..., description="The session ID obtained from initialize_session.") - backend_name_ref: str = Field(..., description="The reference name of the backend to target.") - instance_id: str = Field(..., description="The ID of the specific backend instance to query for tools.") - - -class CleanupSessionArgs(BaseModel): - rk_session_id: str = Field(..., description="The session ID to clean up.") - - -# Ping might not need specific args if it uses session from mcp_ctx, or could take rk_session_id -class PingArgs(BaseModel): - rk_session_id: Optional[str] = Field(default=None, description="Optional session ID for context.") - - -class RewardKitIntermediaryServer(FastMCP): - def __init__(self, app_config: AppConfig, **kwargs_for_fastmcp): - super().__init__( - name="RewardKitIntermediaryMCP", - instructions="Intermediary Server for managing backend MCP resources for RewardKit RL rollouts.", - **kwargs_for_fastmcp, - ) - - self.app_config = app_config - self._local_docker_orchestrator: Optional[LocalDockerOrchestrationClient] = None - self._remote_http_orchestrators: Dict[str, RemoteHttpOrchestrationClient] = {} - self._shared_global_instances: Dict[str, ManagedInstanceInfo] = {} - self._shared_instance_locks: Dict[str, asyncio.Lock] = {} - self.intermediary_session_data: Dict[str, IntermediarySessionData] = {} - - logger.info("RewardKitIntermediaryServer (FastMCP based) initialized. AppConfig loaded.") - - # Register tools directly - self.add_tool(self._initialize_session_actual, name="initialize_session") - self.add_tool(self._call_backend_tool_actual, name="call_backend_tool") - self.add_tool(self._list_backend_tools_actual, name="list_backend_tools") # New tool - self.add_tool(self._cleanup_session_actual, name="cleanup_session") - self.add_tool(self._ping_actual, name="ping") - - logger.info("Registered tools directly with FastMCP.") - - # Explicitly set this module's logger level based on app_config - # This is to ensure it overrides any prior default or hardcoded DEBUG level - # if external configuration in main.py isn't fully effective. - try: - config_log_level_str = app_config.log_level.upper() - config_log_level_int = getattr(logging, config_log_level_str, logging.INFO) - if logger.getEffectiveLevel() != config_log_level_int: - logger.info( - f"Overriding intermediary_server logger level from {logging.getLevelName(logger.getEffectiveLevel())} to {config_log_level_str}" - ) - logger.setLevel(config_log_level_int) - # Also ensure handlers attached directly to this logger respect it (if any) - for handler in logger.handlers: - handler.setLevel(config_log_level_int) - logger.info( - f"IntermediaryServer logger effective level: {logging.getLevelName(logger.getEffectiveLevel())}" - ) - - except Exception as e_log: - logger.error(f"Error trying to set intermediary_server logger level: {e_log}") - - # Removed _execute_proxied_tool_impl and _internal_tool_handlers - - async def _initialize_orchestrators(self): - logger.info("Initializing orchestration clients...") - if any(b.orchestration_mode == "local_docker" for b in self.app_config.backends): - self._local_docker_orchestrator = LocalDockerOrchestrationClient(self.app_config) - await self._local_docker_orchestrator.startup() - logger.info("LocalDockerOrchestrationClient initialized and started.") - - unique_remote_api_refs = set() - for backend_cfg in self.app_config.backends: - if backend_cfg.orchestration_mode == "remote_http_api": - if backend_cfg.remote_api_config_ref: - unique_remote_api_refs.add(backend_cfg.remote_api_config_ref) - elif backend_cfg.remote_api_config_inline: - logger.warning( - f"Inline remote_api_config for {backend_cfg.backend_name_ref}. Consider using global_remote_apis." - ) - key = backend_cfg.remote_api_config_inline.base_url - if key not in self._remote_http_orchestrators: - temp_app_config_for_inline = AppConfig( - global_remote_apis={key: backend_cfg.remote_api_config_inline} - ) - client = RemoteHttpOrchestrationClient(temp_app_config_for_inline) - await client.startup() - self._remote_http_orchestrators[key] = client - logger.info(f"RemoteHttpOrchestrationClient for inline config {key} initialized.") - - for ref_name in unique_remote_api_refs: - if ref_name not in self.app_config.global_remote_apis: - logger.error(f"Remote API ref '{ref_name}' not in global_remote_apis.") - continue - if ref_name not in self._remote_http_orchestrators: - isolated_app_cfg = AppConfig( - global_remote_apis={ref_name: self.app_config.global_remote_apis[ref_name]}, - global_remote_api_defaults=self.app_config.global_remote_api_defaults, - ) - client = RemoteHttpOrchestrationClient(isolated_app_cfg) - await client.startup() - self._remote_http_orchestrators[ref_name] = client - logger.info(f"RemoteHttpOrchestrationClient for '{ref_name}' initialized.") - logger.info("Orchestration clients initialization complete.") - - def _get_orchestration_client(self, backend_cfg: BackendServerConfig) -> AbstractOrchestrationClient: - if backend_cfg.orchestration_mode == "local_docker": - if not self._local_docker_orchestrator: - raise RuntimeError("Local Docker orchestrator not initialized.") - return self._local_docker_orchestrator - elif backend_cfg.orchestration_mode == "remote_http_api": - key = backend_cfg.remote_api_config_ref - if not key: - if backend_cfg.remote_api_config_inline: - key = backend_cfg.remote_api_config_inline.base_url - else: - raise ValueError(f"Remote API config missing for {backend_cfg.backend_name_ref}") - client = self._remote_http_orchestrators.get(key) - if not client: - raise RuntimeError(f"Remote HTTP orchestrator for '{key}' not initialized.") - return client - else: - raise ValueError(f"Unsupported orchestration mode: {backend_cfg.orchestration_mode}") - - async def _get_or_provision_shared_global_instance(self, backend_name_ref: str) -> ManagedInstanceInfo: - if backend_name_ref not in self._shared_instance_locks: - self._shared_instance_locks[backend_name_ref] = asyncio.Lock() - async with self._shared_instance_locks[backend_name_ref]: - if backend_name_ref in self._shared_global_instances: - logger.info(f"Returning existing shared global instance for '{backend_name_ref}'.") - return self._shared_global_instances[backend_name_ref] - logger.info(f"Provisioning new shared global instance for '{backend_name_ref}'.") - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == backend_name_ref), - None, - ) - if not backend_cfg or backend_cfg.instance_scoping != "shared_global": - raise ValueError(f"Backend '{backend_name_ref}' not for shared_global scoping.") - orchestration_client = self._get_orchestration_client(backend_cfg) - provisioned_list = await orchestration_client.provision_instances( - backend_config=backend_cfg, - num_instances=1, - session_id="global_shared_session", - template_details=backend_cfg.template_data_path_host, - ) - if not provisioned_list: - raise RuntimeError(f"Failed to provision shared global for '{backend_name_ref}'.") - instance_info = provisioned_list[0] - self._shared_global_instances[backend_name_ref] = instance_info - logger.info(f"Provisioned shared global for '{backend_name_ref}': {instance_info.instance_id}") - return instance_info - - async def _provision_shared_global_instances(self): - logger.info("Pre-provisioning all shared_global instances...") - for backend_cfg in self.app_config.backends: - if backend_cfg.instance_scoping == "shared_global": - try: - await self._get_or_provision_shared_global_instance(backend_cfg.backend_name_ref) - except Exception as e: - logger.error( - f"Failed to pre-provision for '{backend_cfg.backend_name_ref}': {e}", - exc_info=True, - ) - logger.info("Shared_global instances pre-provisioning complete.") - - async def _initialize_session_actual(self, mcp_ctx: FastMCPContext, args: InitializeSessionArgs) -> Dict[str, Any]: - task_name = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug( - f"ENTERING _initialize_session_actual: task='{task_name}', mcp_ctx type: {type(mcp_ctx)}, args: {args}" - ) - - transport_session_id: Optional[str] = None - if ( - hasattr(mcp_ctx, "session") - and mcp_ctx.session - and hasattr(mcp_ctx.session, "client_params") - and mcp_ctx.session.client_params - and hasattr(mcp_ctx.session.client_params, "session_id") - and mcp_ctx.session.client_params.session_id - ): - transport_session_id = mcp_ctx.session.client_params.session_id - logger.info(f"Retrieved transport_session_id: {transport_session_id}") - - rk_session_id = transport_session_id if transport_session_id else uuid.uuid4().hex - if not transport_session_id: - logger.warning(f"Transport session ID not found. Generated new rk_session_id: {rk_session_id}") - else: - logger.info(f"Using transport_session_id as rk_session_id: {rk_session_id}") - - if rk_session_id in self.intermediary_session_data: - logger.warning(f"rk_session_id '{rk_session_id}' already exists. Overwriting.") - session_data = IntermediarySessionData(session_id=rk_session_id) - self.intermediary_session_data[rk_session_id] = session_data - - logger.info( - f"Initializing IntermediarySessionData for rk_session_id '{rk_session_id}' with {len(args.backends)} backend requests." - ) - initialized_backends_results: List[BackendInitResult] = [] - - for backend_req in args.backends: - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == backend_req.backend_name_ref), - None, - ) - if not backend_cfg: - logger.error(f"Session {rk_session_id}: Config for '{backend_req.backend_name_ref}' not found.") - initialized_backends_results.append( - BackendInitResult(backend_name_ref=backend_req.backend_name_ref, instances=[]) - ) - continue - try: - if backend_cfg.instance_scoping == "shared_global": - shared_instance_info = await self._get_or_provision_shared_global_instance( - backend_req.backend_name_ref - ) - instances_for_this_backend = [shared_instance_info] * backend_req.num_instances - else: - orchestration_client = self._get_orchestration_client(backend_cfg) - instances_for_this_backend = await orchestration_client.provision_instances( - backend_config=backend_cfg, - num_instances=backend_req.num_instances, - session_id=session_data.session_id, - template_details=backend_req.template_details, - ) - session_data.add_managed_instances(backend_req.backend_name_ref, instances_for_this_backend) - initialized_backends_results.append( - BackendInitResult( - backend_name_ref=backend_req.backend_name_ref, - instances=instances_for_this_backend, - ) - ) - except Exception as e: - logger.error( - f"Session {rk_session_id}: Error initializing '{backend_req.backend_name_ref}': {e}", - exc_info=True, - ) - initialized_backends_results.append( - BackendInitResult( - backend_name_ref=backend_req.backend_name_ref, - instances=[], - error_message=str(e), - ) - ) - - task_name_exit = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug(f"EXITING _initialize_session_actual: task='{task_name_exit}'") - return { - "rk_session_id": rk_session_id, - "initialized_backends": [res.model_dump(exclude_none=True) for res in initialized_backends_results], - } - - async def _call_backend_tool_actual(self, mcp_ctx: FastMCPContext, args: CallBackendToolArgs) -> Dict[str, Any]: - task_name_entry = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug( - f"ENTERING _call_backend_tool_actual: task='{task_name_entry}', mcp_ctx type: {type(mcp_ctx)}, args: {args}" - ) - - session_data = self.intermediary_session_data.get(args.rk_session_id) - if not session_data: - task_name_error = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.error( - f"ERROR in _call_backend_tool_actual (session not found): task='{task_name_error}', rk_session_id='{args.rk_session_id}'" - ) - raise ValueError(f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' not found.") - - target_instances = session_data.get_managed_instances(args.backend_name_ref, args.instance_id) - if not target_instances: - raise ValueError( - f"Instance '{args.instance_id}' for backend '{args.backend_name_ref}' not found in session '{args.rk_session_id}'." - ) - managed_instance_info = target_instances[0] - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == args.backend_name_ref), - None, - ) - if not backend_cfg: - raise ValueError(f"Backend config '{args.backend_name_ref}' not found.") - orchestration_client = self._get_orchestration_client(backend_cfg) - - task_name_before_call = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug( - f"BEFORE orchestrator.call_tool_on_instance in _call_backend_tool_actual: task='{task_name_before_call}'" - ) - - try: - result = await orchestration_client.call_tool_on_instance( - instance=managed_instance_info, - tool_name=args.tool_name, - tool_args=args.tool_args, - ) - task_name_after_call = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug( - f"AFTER orchestrator.call_tool_on_instance in _call_backend_tool_actual: task='{task_name_after_call}'" - ) - - task_name_exit = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug(f"EXITING _call_backend_tool_actual (SUCCESS): task='{task_name_exit}'") - return result - except Exception as e: - task_name_exception = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.error( - f"EXCEPTION in _call_backend_tool_actual: task='{task_name_exception}'. Session {args.rk_session_id}: Error calling tool '{args.tool_name}' on instance '{args.instance_id}': {e}", - exc_info=True, - ) - raise - - async def _list_backend_tools_actual( - self, mcp_ctx: FastMCPContext, args: ListBackendToolsArgs - ) -> Dict[str, Any]: # Returning dict for FastMCP, will be ListToolsResult internally - task_name_entry = anyio.get_current_task().name if anyio.get_current_task() else "unknown_task" - logger.debug(f"ENTERING _list_backend_tools_actual: task='{task_name_entry}', args: {args}") - - session_data = self.intermediary_session_data.get(args.rk_session_id) - if not session_data: - logger.error( - f"ERROR in _list_backend_tools_actual (session not found): rk_session_id='{args.rk_session_id}'" - ) - raise ValueError(f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' not found.") - - target_instances = session_data.get_managed_instances(args.backend_name_ref, args.instance_id) - if not target_instances: - raise ValueError( - f"Instance '{args.instance_id}' for backend '{args.backend_name_ref}' not found in session '{args.rk_session_id}'." - ) - managed_instance_info = target_instances[0] - - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == args.backend_name_ref), - None, - ) - if not backend_cfg: - raise ValueError(f"Backend config '{args.backend_name_ref}' not found.") - orchestration_client = self._get_orchestration_client(backend_cfg) - - logger.debug( - f"Calling orchestrator.list_tools_on_instance for backend '{args.backend_name_ref}', instance '{args.instance_id}'" - ) - try: - list_tools_result: mcp_types.ListToolsResult = await orchestration_client.list_tools_on_instance( - instance=managed_instance_info - ) - # FastMCP tools expect to return a dictionary that can be JSON serialized. - # ListToolsResult is a Pydantic model, so model_dump() is appropriate. - return list_tools_result.model_dump(exclude_none=True) - except Exception as e: - logger.error( - f"EXCEPTION in _list_backend_tools_actual for session {args.rk_session_id}, backend {args.backend_name_ref}, instance {args.instance_id}: {e}", - exc_info=True, - ) - raise # Re-raise to let FastMCP handle error reporting to client - - async def cleanup_session_internal(self, session_data_to_clean: IntermediarySessionData, rk_session_id: str): - logger.info(f"Starting internal cleanup for IntermediarySessionData (rk_session_id: '{rk_session_id}').") - all_session_instances = session_data_to_clean.get_all_managed_instances() - local_docker_instances = [inst for inst in all_session_instances if inst.orchestration_mode == "local_docker"] - if local_docker_instances and self._local_docker_orchestrator: - try: - await self._local_docker_orchestrator.deprovision_instances(local_docker_instances) - except Exception as e: - logger.error( - f"Session {rk_session_id}: Error deprovisioning local Docker: {e}", - exc_info=True, - ) - - remote_instances_by_key: Dict[str, List[ManagedInstanceInfo]] = {} - for inst in all_session_instances: - if inst.orchestration_mode == "remote_http_api": - key = self._get_orchestration_client_key_for_instance(inst) - if key: - remote_instances_by_key.setdefault(key, []).append(inst) - for key, remote_list in remote_instances_by_key.items(): - orchestrator = self._remote_http_orchestrators.get(key) - if orchestrator and remote_list: - try: - await orchestrator.deprovision_instances(remote_list) - except Exception as e: - logger.error( - f"Session {rk_session_id}: Error deprovisioning remote for '{key}': {e}", - exc_info=True, - ) - logger.info(f"Internal cleanup for session data (rk_session_id: '{rk_session_id}') complete.") - - async def _cleanup_session_actual(self, mcp_ctx: FastMCPContext, args: CleanupSessionArgs) -> Dict[str, str]: - logger.debug(f"_cleanup_session_actual called. mcp_ctx type: {type(mcp_ctx)}, args: {args}") - session_data_obj = self.intermediary_session_data.pop(args.rk_session_id, None) - if not session_data_obj: - logger.warning( - f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' not found or already cleaned." - ) - return { - "status": "custom_session_data_not_found_or_already_cleaned", - "rk_session_id": args.rk_session_id, - } - await self.cleanup_session_internal(session_data_obj, args.rk_session_id) - logger.info(f"IntermediarySessionData for rk_session_id '{args.rk_session_id}' fully cleaned up.") - return {"status": "cleaned", "rk_session_id": args.rk_session_id} - - async def startup(self): - logger.info("RewardKitIntermediaryServer performing custom startup tasks...") - try: - await self._initialize_orchestrators() - await self._provision_shared_global_instances() - logger.info("RewardKitIntermediaryServer custom startup tasks complete.") - except Exception as e: - logger.error( - f"Error during RewardKitIntermediaryServer custom startup: {e}", - exc_info=True, - ) - raise - - async def _ping_actual(self, mcp_ctx: FastMCPContext, args: PingArgs) -> Dict[str, str]: - logger.debug(f"_ping_actual called. mcp_ctx type: {type(mcp_ctx)}, args: {args}") - ping_session_id: Optional[str] = None - if args.rk_session_id: # If client provides its known rk_session_id - ping_session_id = args.rk_session_id - logger.info(f"Ping using rk_session_id from args: {ping_session_id}") - elif ( - hasattr(mcp_ctx, "session") - and mcp_ctx.session - and hasattr(mcp_ctx.session, "client_params") - and mcp_ctx.session.client_params - and hasattr(mcp_ctx.session.client_params, "session_id") - and mcp_ctx.session.client_params.session_id - ): - ping_session_id = mcp_ctx.session.client_params.session_id - logger.info(f"Ping using transport_session_id from mcp_ctx: {ping_session_id}") - else: - ping_session_id = "unknown_session_for_ping" - logger.warning(f"Session ID for ping not found in args or mcp_ctx, using fallback: {ping_session_id}") - return {"reply": "pong", "session_id": ping_session_id or ""} - - async def shutdown(self): - logger.info("RewardKitIntermediaryServer (FastMCP based) performing custom shutdown tasks...") - logger.info(f"Cleaning up {len(self.intermediary_session_data)} IntermediarySessionData entries...") - for session_id_key in list(self.intermediary_session_data.keys()): - session_data_obj = self.intermediary_session_data.pop(session_id_key, None) - if session_data_obj: - await self.cleanup_session_internal(session_data_obj, session_id_key) - - shared_instances = list(self._shared_global_instances.values()) - if shared_instances: - logger.info(f"Deprovisioning {len(shared_instances)} shared global instances.") - local_shared = [i for i in shared_instances if i.orchestration_mode == "local_docker"] - if local_shared and self._local_docker_orchestrator: - await self._local_docker_orchestrator.deprovision_instances(local_shared) - remote_shared_by_key: Dict[str, List[ManagedInstanceInfo]] = {} - for inst_info in shared_instances: - if inst_info.orchestration_mode == "remote_http_api": - key = self._get_orchestration_client_key_for_instance(inst_info) - if key: - remote_shared_by_key.setdefault(key, []).append(inst_info) - for key, instances_list in remote_shared_by_key.items(): - orchestrator = self._remote_http_orchestrators.get(key) - if orchestrator: - await orchestrator.deprovision_instances(instances_list) - - if self._local_docker_orchestrator: - await self._local_docker_orchestrator.shutdown() - for orch in self._remote_http_orchestrators.values(): - await orch.shutdown() - logger.info("RewardKitIntermediaryServer custom shutdown tasks complete.") - - def _get_orchestration_client_key_for_instance(self, instance_info: ManagedInstanceInfo) -> Optional[str]: - if instance_info.orchestration_mode == "remote_http_api": - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == instance_info.backend_name_ref), - None, - ) - if backend_cfg: - return backend_cfg.remote_api_config_ref or ( - backend_cfg.remote_api_config_inline.base_url if backend_cfg.remote_api_config_inline else None - ) - return None diff --git a/eval_protocol/mcp_agent/orchestration/remote_http_client.py b/eval_protocol/mcp_agent/orchestration/remote_http_client.py deleted file mode 100644 index d6749b61..00000000 --- a/eval_protocol/mcp_agent/orchestration/remote_http_client.py +++ /dev/null @@ -1,307 +0,0 @@ -import logging -from typing import Any, Dict, List, Optional - -import httpx -from mcp import types as mcp_types -from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamablehttp_client - -from eval_protocol.mcp_agent.config import ( - AppConfig, - BackendServerConfig, - RemoteApiConfig, -) -from eval_protocol.mcp_agent.orchestration.base_client import ( - AbstractOrchestrationClient, - ManagedInstanceInfo, -) - -logger = logging.getLogger(__name__) - - -class RemoteHttpOrchestrationClient(AbstractOrchestrationClient): - """ - Orchestrates backend MCP server instances by communicating with a remote HTTP API. - This client translates provisioning, deprovisioning, and tool call requests - into HTTP requests to a customer-defined remote orchestration service. - """ - - def __init__(self, app_config: AppConfig): - self.app_config = app_config - self.http_client: Optional[httpx.AsyncClient] = None - - async def startup(self) -> None: - """Initializes the httpx client.""" - # Default timeout can be overridden by specific remote_api_config later - timeout_config = httpx.Timeout( - self.app_config.global_remote_api_defaults.get("timeout", 30.0), - connect=self.app_config.global_remote_api_defaults.get("connect_timeout", 5.0), - ) - self.http_client = httpx.AsyncClient(timeout=timeout_config) - logger.info("RemoteHttpOrchestrationClient started.") - - async def shutdown(self) -> None: - """Closes the httpx client.""" - if self.http_client: - await self.http_client.aclose() - logger.info("HTTPX client for RemoteHttpOrchestrationClient closed.") - logger.info("RemoteHttpOrchestrationClient shut down.") - - def _get_auth_headers(self, remote_api_config: RemoteApiConfig) -> Dict[str, str]: - """Constructs authentication headers based on the remote API config.""" - headers = {} - if remote_api_config.auth_type == "bearer_token": - token = remote_api_config.auth_details.get("token") - if token: - headers["Authorization"] = f"Bearer {token}" - else: - logger.warning("Bearer token auth selected but no token provided.") - elif remote_api_config.auth_type == "custom_header": - header_name = remote_api_config.auth_details.get("header_name") - header_value = remote_api_config.auth_details.get("header_value") - if header_name and header_value: - headers[header_name] = header_value - else: - logger.warning("Custom header auth selected but header_name or header_value missing.") - return headers - - async def _make_request( - self, - method: str, - url: str, - remote_api_config: RemoteApiConfig, - json_payload: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, - ) -> httpx.Response: - """Helper method to make HTTP requests with authentication and error handling.""" - if not self.http_client: - raise RuntimeError("HTTP client not initialized. Call startup() first.") - - headers = self._get_auth_headers(remote_api_config) - headers["Content-Type"] = "application/json" # Assume JSON requests - - try: - logger.debug(f"Making {method} request to {url} with payload: {json_payload} and params: {params}") - response = await self.http_client.request(method, url, headers=headers, json=json_payload, params=params) - response.raise_for_status() # Raise an exception for 4xx/5xx responses - return response - except httpx.RequestError as e: - logger.error(f"Request error during {method} to {url}: {e}") - raise RuntimeError(f"Remote API request failed: Network error calling {url}") from e - except httpx.HTTPStatusError as e: - logger.error(f"HTTP status error during {method} to {url}: {e.response.status_code} - {e.response.text}") - try: - error_details = e.response.json() - except Exception: - error_details = e.response.text - raise RuntimeError( - f"Remote API request failed: Server returned error {e.response.status_code}. Details: {error_details}" - ) from e - - async def provision_instances( - self, - backend_config: BackendServerConfig, - num_instances: int, - session_id: str, - template_details: Optional[Any] = None, - ) -> List[ManagedInstanceInfo]: - if backend_config.orchestration_mode != "remote_http_api": - raise ValueError("RemoteHttpOrchestrationClient can only handle 'remote_http_api' mode.") - - remote_api_config = self.app_config.get_remote_api_config(backend_config) - if not remote_api_config: - raise ValueError(f"RemoteApiConfig not found for backend {backend_config.backend_name_ref}.") - - create_url = ( - f"{remote_api_config.base_url.rstrip('/')}/{remote_api_config.create_instance_endpoint.lstrip('/')}" - ) - - provisioned_instances_info: List[ManagedInstanceInfo] = [] - - # The remote API might support batch creation or require individual calls. - # This example assumes the remote API can take num_instances and returns a list. - # Adjust if the API requires one call per instance. - payload = { - "resource_type_identifier": backend_config.remote_resource_type_identifier, - "num_instances": num_instances, - "session_id": session_id, - "instance_scoping": backend_config.instance_scoping, - "template_details": template_details, # Pass along any template info - # Add any other necessary parameters the remote API expects - } - - logger.info( - f"Requesting {num_instances} instances of type '{backend_config.remote_resource_type_identifier}' from {create_url}" - ) - - response = await self._make_request("POST", create_url, remote_api_config, json_payload=payload) - response_data = response.json() # Expecting a list of instance details - - if not isinstance(response_data, list): - raise ValueError( - f"Remote API at {create_url} did not return a list of instances. Response: {response_data}" - ) - - for i, inst_data in enumerate(response_data): - # The remote API response should provide necessary details for ManagedInstanceInfo - # Required: instance_id (client-facing), mcp_endpoint_url, internal_instance_details (like remote_instance_id) - remote_instance_id = inst_data.get("remote_instance_id") - mcp_endpoint_url = inst_data.get("mcp_endpoint_url") - client_facing_instance_id = inst_data.get( - "instance_id", f"{session_id}-{backend_config.backend_name_ref}-{i}" - ) - - if not remote_instance_id or not mcp_endpoint_url: - logger.error( - f"Remote API response for instance missing 'remote_instance_id' or 'mcp_endpoint_url'. Data: {inst_data}" - ) - # Decide on error handling: skip this instance, or fail all? - # For now, let's raise an error if critical info is missing. - raise ValueError(f"Remote API response for instance creation is incomplete: {inst_data}") - - provisioned_instances_info.append( - ManagedInstanceInfo( - instance_id=client_facing_instance_id, - backend_name_ref=backend_config.backend_name_ref, - orchestration_mode="remote_http_api", - mcp_endpoint_url=mcp_endpoint_url, - internal_instance_details={ - "remote_instance_id": remote_instance_id, - **inst_data.get("additional_details", {}), # Any other info from remote - }, - ) - ) - logger.info( - f"Instance {client_facing_instance_id} (Remote ID: {remote_instance_id}) provisioned. MCP Endpoint: {mcp_endpoint_url}" - ) - - if ( - len(provisioned_instances_info) != num_instances and num_instances > 0 and len(response_data) > 0 - ): # if API supports batch and returns partial - logger.warning( - f"Requested {num_instances} but remote API returned details for {len(provisioned_instances_info)} instances." - ) - - return provisioned_instances_info - - async def deprovision_instances(self, instances: List[ManagedInstanceInfo]) -> None: - for instance in instances: - if instance.orchestration_mode != "remote_http_api": - logger.warning( - f"Skipping deprovision for instance {instance.instance_id} as it's not remote_http_api." - ) - continue - - # Need to find the BackendServerConfig that led to this instance to get its RemoteApiConfig - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == instance.backend_name_ref), - None, - ) - if not backend_cfg: - logger.error( - f"Could not find BackendServerConfig for {instance.backend_name_ref} during deprovision of {instance.instance_id}" - ) - continue - - remote_api_config = self.app_config.get_remote_api_config(backend_cfg) - if not remote_api_config: - logger.error( - f"RemoteApiConfig not found for backend {instance.backend_name_ref} during deprovision of {instance.instance_id}." - ) - continue - - remote_instance_id = instance.internal_instance_details.get("remote_instance_id") - if not remote_instance_id: - logger.warning(f"No remote_instance_id found for instance {instance.instance_id}. Cannot deprovision.") - continue - - delete_url_template = remote_api_config.delete_instance_endpoint_template - delete_url = f"{remote_api_config.base_url.rstrip('/')}/{delete_url_template.lstrip('/').format(remote_instance_id=remote_instance_id)}" - - logger.info(f"Requesting deprovision of remote instance {remote_instance_id} via {delete_url}") - try: - await self._make_request("DELETE", delete_url, remote_api_config) - logger.info(f"Successfully requested deprovision for remote instance {remote_instance_id}.") - except Exception as e: - # Log error but continue trying to deprovision other instances - logger.error(f"Failed to deprovision remote instance {remote_instance_id}: {e}") - - async def call_tool_on_instance( - self, instance: ManagedInstanceInfo, tool_name: str, tool_args: Dict[str, Any] - ) -> Dict[str, Any]: - if instance.orchestration_mode != "remote_http_api": - raise ValueError("This client only handles remote_http_api instances.") - - backend_cfg = next( - (b for b in self.app_config.backends if b.backend_name_ref == instance.backend_name_ref), - None, - ) - if not backend_cfg: - raise RuntimeError(f"Could not find BackendServerConfig for {instance.backend_name_ref}") - - remote_api_config = self.app_config.get_remote_api_config(backend_cfg) - if not remote_api_config: - raise RuntimeError(f"RemoteApiConfig not found for backend {instance.backend_name_ref}.") - - mcp_payload = {"tool_name": tool_name, "arguments": tool_args} - - target_url: str - # Check if tool calls are proxied through the orchestrator or made directly to the instance - if remote_api_config.call_tool_endpoint_template: - remote_instance_id = instance.internal_instance_details.get("remote_instance_id") - if not remote_instance_id: - raise ValueError( - f"Missing remote_instance_id for instance {instance.instance_id} when proxying tool call." - ) - - call_template = remote_api_config.call_tool_endpoint_template - # The template might need remote_instance_id and potentially tool_name if it's part of the path - # Assuming a generic proxy endpoint for now that takes tool_name in payload - target_url = f"{remote_api_config.base_url.rstrip('/')}/{call_template.lstrip('/').format(remote_instance_id=remote_instance_id)}" - # The payload to the proxy might need to be wrapped, e.g. including the actual MCP payload - # For now, assume the proxy forwards the mcp_payload directly. - logger.debug(f"Proxying tool {tool_name} to {target_url} for instance {instance.instance_id}") - else: - # Call tool directly on the instance's MCP endpoint - # mypy/pyright: instance.mcp_endpoint_url is Optional[str]; validate before assignment - if not instance.mcp_endpoint_url: - raise ValueError(f"Instance {instance.instance_id} missing mcp_endpoint_url for direct tool call") - target_url = instance.mcp_endpoint_url - logger.debug(f"Calling tool {tool_name} directly on {target_url} for instance {instance.instance_id}") - - response = await self._make_request("POST", target_url, remote_api_config, json_payload=mcp_payload) - return response.json() - - async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> mcp_types.ListToolsResult: - if instance.orchestration_mode != "remote_http_api": - raise ValueError("RemoteHttpOrchestrationClient can only list tools for 'remote_http_api' instances.") - if instance.mcp_transport != "http" or not instance.mcp_endpoint_url: - raise ValueError( - f"Instance {instance.instance_id} ({instance.backend_name_ref}) is not configured for HTTP MCP transport or mcp_endpoint_url is missing." - ) - - # Assuming instance.mcp_endpoint_url is the base URL of the target MCP server - # e.g., "http://localhost:12345" - target_base_url = instance.mcp_endpoint_url.rstrip("/") - - logger.info( - f"Listing tools for remote HTTP instance {instance.instance_id} ({instance.backend_name_ref}) at base URL {target_base_url}" - ) - - try: - # streamablehttp_client will manage its own httpx.AsyncClient if one is not provided. - # The context manager handles session.initialize() and session.close(). - async with streamablehttp_client(base_url=target_base_url) as session: # type: ClientSession - list_tools_result = await session.list_tools() - logger.info( - f"Successfully listed {len(list_tools_result.tools)} tools from {target_base_url} for instance {instance.instance_id} ({instance.backend_name_ref})" - ) - return list_tools_result - except Exception as e: - logger.error( - f"Error listing tools from {target_base_url} for instance {instance.instance_id} ({instance.backend_name_ref}): {e}", - exc_info=True, - ) - raise RuntimeError( - f"Failed to list tools from backend instance {instance.instance_id} ({instance.backend_name_ref}) at {target_base_url}" - ) from e diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index f5d09ba0..ca32b327 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -136,7 +136,11 @@ def make( if evaluation_rows: for i, row in enumerate(evaluation_rows): - dataset_info = row.input_metadata.dataset_info if row.input_metadata else {} + dataset_info = ( + row.input_metadata.dataset_info + if (row.input_metadata and row.input_metadata.dataset_info is not None) + else {} + ) system_message = row.get_system_message() system_prompt = system_message.content or "" diff --git a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py index 97acd49b..9be30adf 100644 --- a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py +++ b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py @@ -68,6 +68,7 @@ def close(self): def _execute_airline_action(self, action_name: str, parameters: Dict[str, Any]) -> Dict[str, Any]: """Execute action using airline tools.""" + assert isinstance(self.airline_tools, AirlineTools), "Airline tools not initialized" action_map = { "book_reservation": self.airline_tools.book_reservation, "cancel_reservation": self.airline_tools.cancel_reservation, diff --git a/eval_protocol/playback_policy.py b/eval_protocol/playback_policy.py index a84fa834..44b2b64d 100644 --- a/eval_protocol/playback_policy.py +++ b/eval_protocol/playback_policy.py @@ -224,7 +224,7 @@ async def __call__( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ): + ) -> Tuple[List["MCPToolCall"], Optional[Dict[str, int]], Optional[str]]: """ Main policy call method. Delegates to playback or live mode. diff --git a/eval_protocol/rewards/apps_coding_reward.py b/eval_protocol/rewards/apps_coding_reward.py index 6cbd63b9..5ab7465b 100644 --- a/eval_protocol/rewards/apps_coding_reward.py +++ b/eval_protocol/rewards/apps_coding_reward.py @@ -84,7 +84,7 @@ def evaluate_apps_solution(messages: List[Message], ground_truth: Optional[str], reason="No messages provided.", ) - raw_solution_content = messages[-1].content + raw_solution_content = messages[-1].content if isinstance(messages[-1].content, str) else "" code_solution = _extract_python_code(raw_solution_content) if not code_solution or not code_solution.strip(): diff --git a/eval_protocol/rewards/apps_testing_util.py b/eval_protocol/rewards/apps_testing_util.py index 84f13f8a..27f52cc5 100644 --- a/eval_protocol/rewards/apps_testing_util.py +++ b/eval_protocol/rewards/apps_testing_util.py @@ -174,7 +174,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(last_block, ast.If): condition = last_block.test if ast.unparse(condition).strip() == "__name__ == '__main__'": - test = ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) + # Build modules for unparse to avoid passing lists to ast.unparse + prefix_module = ast.Module(body=astree.body[:-1], type_ignores=[]) + body_module = ast.Module(body=last_block.body, type_ignores=[]) + test = ast.unparse(prefix_module) + "\n" + ast.unparse(body_module) except Exception: pass diff --git a/eval_protocol/rewards/bfcl_reward.py b/eval_protocol/rewards/bfcl_reward.py index ccf1cda4..6b2b3ecf 100644 --- a/eval_protocol/rewards/bfcl_reward.py +++ b/eval_protocol/rewards/bfcl_reward.py @@ -262,7 +262,8 @@ def bfcl_reward( assistant_message_found = True total_assistant_messages += 1 # Check for any content or any tool_call - if (msg.content and msg.content.strip()) or msg.tool_calls: + content_str = msg.content if isinstance(msg.content, str) else "" + if (content_str and content_str.strip()) or msg.tool_calls: valid_assistant_messages += 1 if not assistant_message_found: diff --git a/eval_protocol/rewards/cpp_code.py b/eval_protocol/rewards/cpp_code.py index cb324273..4f0a27a7 100644 --- a/eval_protocol/rewards/cpp_code.py +++ b/eval_protocol/rewards/cpp_code.py @@ -608,7 +608,7 @@ def _ioi_cpp_code_reward_impl( }, ) - response_content = messages[-1].content + response_content = messages[-1].content if isinstance(messages[-1].content, str) else "" expected_output_str_from_gt: Optional[str] = None test_cases_from_gt: Optional[List[Dict[str, Any]]] = None diff --git a/eval_protocol/rewards/json_schema.py b/eval_protocol/rewards/json_schema.py index 3c729936..06f2fc5e 100644 --- a/eval_protocol/rewards/json_schema.py +++ b/eval_protocol/rewards/json_schema.py @@ -77,7 +77,8 @@ def json_schema_reward( ) elif isinstance(last_message, dict): if last_message.get("role") == "assistant" and last_message.get("content") is not None: - content_text = last_message.get("content", "") + raw_content = last_message.get("content", "") + content_text = raw_content if isinstance(raw_content, str) else "" else: return EvaluateResult( score=0.0, diff --git a/eval_protocol/rewards/length.py b/eval_protocol/rewards/length.py index 75798430..55a84fb0 100644 --- a/eval_protocol/rewards/length.py +++ b/eval_protocol/rewards/length.py @@ -93,7 +93,7 @@ def length_reward( ) }, ) - text = response.content + text = response.content if isinstance(response.content, str) else "" elif isinstance(response, dict): if response.get("role") != "assistant" or not response.get("content"): return EvaluateResult( @@ -107,7 +107,8 @@ def length_reward( ) }, ) - text = response.get("content", "") + text_val = response.get("content", "") + text = text_val if isinstance(text_val, str) else "" else: return EvaluateResult( score=0.0, @@ -294,7 +295,7 @@ def cosine_length_reward( ) }, ) - text = response.content + text = response.content if isinstance(response.content, str) else "" elif isinstance(response, dict): if response.get("role") != "assistant" or not response.get("content"): return EvaluateResult( @@ -308,7 +309,8 @@ def cosine_length_reward( ) }, ) - text = response.get("content", "") + text_val = response.get("content", "") + text = text_val if isinstance(text_val, str) else "" else: return EvaluateResult( score=0.0, diff --git a/eval_protocol/rewards/math.py b/eval_protocol/rewards/math.py index e0406314..3e03f022 100644 --- a/eval_protocol/rewards/math.py +++ b/eval_protocol/rewards/math.py @@ -587,7 +587,7 @@ def math_reward( ) }, ) - model_response_content = messages[-1].content + model_response_content = messages[-1].content if isinstance(messages[-1].content, str) else "" if ground_truth is None or ground_truth == "": return EvaluateResult( score=0.0, diff --git a/eval_protocol/rewards/reasoning_steps.py b/eval_protocol/rewards/reasoning_steps.py index 98da3d85..57a17611 100644 --- a/eval_protocol/rewards/reasoning_steps.py +++ b/eval_protocol/rewards/reasoning_steps.py @@ -60,7 +60,7 @@ def reasoning_steps_reward( ) }, ) - text: str = response.content + text: str = response.content if isinstance(response.content, str) else "" # Default patterns for detecting reasoning steps default_patterns = [ @@ -199,7 +199,7 @@ def sequence_reward( ) }, ) - text: str = response.content + text: str = response.content if isinstance(response.content, str) else "" if not sequence_terms: sequence_terms = [ diff --git a/eval_protocol/rewards/repetition.py b/eval_protocol/rewards/repetition.py index ae366a23..27c7d644 100644 --- a/eval_protocol/rewards/repetition.py +++ b/eval_protocol/rewards/repetition.py @@ -248,7 +248,7 @@ def diversity_reward( ) }, ) - text = response.content or "" + text = _to_text(response.content) elif isinstance(response, dict): if response.get("role") != "assistant": return EvaluateResult( @@ -262,7 +262,8 @@ def diversity_reward( ) }, ) - text = response.get("content", "") + text_val = response.get("content", "") + text = text_val if isinstance(text_val, str) else "" else: return EvaluateResult( score=0.0,