diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index ee2c92e9..6916cb08 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -53,16 +53,26 @@ async def initialize_session(self, session: MCPSession) -> None: exit_stack = AsyncExitStack() - client_info = Implementation(name="reward-kit", version="1.0.0", _extra={}) - client_info._extra["session_id"] = session.session_id + # Attach client metadata for the server to consume (session_id, seed, config, etc.). + # The server inspects a private `_extra` dict on client_info, so we populate it here. + client_info = Implementation(name="reward-kit", version="1.0.0") + extra_data: Dict[str, Any] = {"session_id": session.session_id} if session.seed is not None: - client_info._extra["seed"] = session.seed + extra_data["seed"] = session.seed if session.dataset_row and session.dataset_row.environment_context: - client_info._extra["config"] = session.dataset_row.environment_context + extra_data["config"] = session.dataset_row.environment_context if session.dataset_row and session.dataset_row.id: - client_info._extra["dataset_row_id"] = session.dataset_row.id + extra_data["dataset_row_id"] = session.dataset_row.id if session.model_id: - client_info._extra["model_id"] = session.model_id + extra_data["model_id"] = session.model_id + + # Merge with any existing _extra dict instead of overwriting + existing_extra = getattr(client_info, "_extra", None) + merged_extra: Dict[str, Any] = {} + if isinstance(existing_extra, dict): + merged_extra.update(existing_extra) + merged_extra.update(extra_data) + setattr(client_info, "_extra", merged_extra) read_stream, write_stream, _ = await exit_stack.enter_async_context( streamablehttp_client(session.base_url, terminate_on_close=True) @@ -92,7 +102,10 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None: # Only fetch tools if not already cached for this base_url if cache_key not in self._tools_cache: logger.debug(f"Pre-warming tools cache for {cache_key}") - tools_response = await session._mcp_session.list_tools() + mcp_session_local = session._mcp_session + if mcp_session_local is None: + raise RuntimeError("Session not initialized during prewarm") + tools_response = await mcp_session_local.list_tools() tools = tools_response.tools if hasattr(tools_response, "tools") else [] tool_schemas = [] @@ -213,7 +226,7 @@ async def get_initial_state(self, session: MCPSession) -> Any: try: # Use shorter timeout for playback mode, longer timeout for high-concurrency initialization # (50+ concurrent sessions need more time for initial state setup) - timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0 + timeout = 3.0 if bool(getattr(session, "_is_playback_mode", False)) else 15.0 async with httpx.AsyncClient(timeout=timeout) as client: initial_state_response = await client.get( f"{base_url}/control/initial_state", diff --git a/eval_protocol/mcp/clients.py b/eval_protocol/mcp/clients.py index 12bb3e61..b1287fe1 100644 --- a/eval_protocol/mcp/clients.py +++ b/eval_protocol/mcp/clients.py @@ -29,7 +29,7 @@ def __init__(self, intermediary_server_url: str): async def connect(self): """Establishes connection and MCP session.""" - if self._mcp_session and not self._mcp_session.is_closed: + if self._mcp_session is not None and not self._mcp_session.is_closed: logger.debug("Already connected.") return diff --git a/eval_protocol/mcp/mcp_multi_client.py b/eval_protocol/mcp/mcp_multi_client.py index 5b0c676c..1c14db1e 100644 --- a/eval_protocol/mcp/mcp_multi_client.py +++ b/eval_protocol/mcp/mcp_multi_client.py @@ -90,8 +90,8 @@ async def _connect_to_server( if env_config: self._validate_environment_variables(server_name, env_config) - # Use the current system environment (os.environ) - don't override with config - server_params = StdioServerParameters(command=command, args=args, env=os.environ) + # Use the current system environment (os.environ) - convert to plain dict for typing compatibility + server_params = StdioServerParameters(command=command, args=args, env=dict(os.environ)) stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) stdio, write = stdio_transport diff --git a/eval_protocol/mcp_agent/orchestration/local_docker_client.py b/eval_protocol/mcp_agent/orchestration/local_docker_client.py index d6660487..6383c210 100644 --- a/eval_protocol/mcp_agent/orchestration/local_docker_client.py +++ b/eval_protocol/mcp_agent/orchestration/local_docker_client.py @@ -244,6 +244,10 @@ async def provision_instances( logger.info( f"Creating template container for commit: {temp_cont_name} from {backend_config.docker_image}" ) + if not backend_config.docker_image: + raise ValueError( + f"docker_image is required for template commit for backend {backend_config.backend_name_ref}" + ) temp_c = self.docker_client.containers.run( # type: ignore image=backend_config.docker_image, name=temp_cont_name, @@ -322,6 +326,11 @@ async def provision_instances( logger.info( f"Provisioning instance {container_name} (transport: {backend_config.mcp_transport}) from image {image_to_run_from}" ) + # Ensure the image reference is present before using it in Docker APIs + if not image_to_run_from: + raise ValueError( + f"docker_image is required to provision instance {container_name} for backend {backend_config.backend_name_ref}" + ) if backend_config.mcp_transport == "http": # ... (HTTP provisioning logic, ensure it uses current_container_volumes) ... if not self.docker_client: diff --git a/eval_protocol/mcp_agent/orchestration/remote_http_client.py b/eval_protocol/mcp_agent/orchestration/remote_http_client.py index dd432783..d6749b61 100644 --- a/eval_protocol/mcp_agent/orchestration/remote_http_client.py +++ b/eval_protocol/mcp_agent/orchestration/remote_http_client.py @@ -263,6 +263,9 @@ async def call_tool_on_instance( logger.debug(f"Proxying tool {tool_name} to {target_url} for instance {instance.instance_id}") else: # Call tool directly on the instance's MCP endpoint + # mypy/pyright: instance.mcp_endpoint_url is Optional[str]; validate before assignment + if not instance.mcp_endpoint_url: + raise ValueError(f"Instance {instance.instance_id} missing mcp_endpoint_url for direct tool call") target_url = instance.mcp_endpoint_url logger.debug(f"Calling tool {tool_name} directly on {target_url} for instance {instance.instance_id}") diff --git a/eval_protocol/models.py b/eval_protocol/models.py index dde91e89..325db4ce 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -243,8 +243,19 @@ class Message(BaseModel): @classmethod def model_validate(cls, obj, *args, **kwargs): - if isinstance(obj, dict) and "role" not in obj: - raise ValueError("Role is required") + if isinstance(obj, dict): + if "role" not in obj: + raise ValueError("Role is required") + # Be lenient: if tool_calls entries are missing required 'id', synthesize one + tool_calls_obj = obj.get("tool_calls") + if isinstance(tool_calls_obj, list): + fixed_tool_calls = [] + for tc in tool_calls_obj: + if isinstance(tc, dict): + if not tc.get("id"): + tc = {**tc, "id": generate_id()} + fixed_tool_calls.append(tc) + obj = {**obj, "tool_calls": fixed_tool_calls} return super().model_validate(obj, *args, **kwargs) @@ -611,27 +622,35 @@ def get_steps(self) -> int: def get_total_reward(self) -> float: """Get total reward from control_plane_step data.""" messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step] - return ( - sum(msg.control_plane_step["reward"] for msg in messages_with_control_plane) - if messages_with_control_plane - else 0.0 - ) + if not messages_with_control_plane: + return 0.0 + total = 0.0 + for msg in messages_with_control_plane: + step = msg.control_plane_step or {} + try: + total += float(step.get("reward", 0.0)) + except (TypeError, ValueError): + continue + return total def get_terminated(self) -> bool: """Get termination status from control_plane_step data.""" messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step] - return ( - any(msg.control_plane_step["terminated"] for msg in messages_with_control_plane) - if messages_with_control_plane - else False - ) + if not messages_with_control_plane: + return False + for msg in messages_with_control_plane: + step = msg.control_plane_step or {} + if bool(step.get("terminated", False)): + return True + return False def get_termination_reason(self) -> str: """Get termination reason from the final control_plane_step data.""" # Find the last message with control_plane_step that has termination_reason for msg in reversed(self.messages): if msg.control_plane_step and msg.control_plane_step.get("termination_reason"): - return msg.control_plane_step["termination_reason"] + reason = msg.control_plane_step.get("termination_reason") + return str(reason) return "unknown" def __hash__(self) -> int: diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index 3c4bf795..b6d02ae2 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -8,7 +8,7 @@ from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig -# Conditional import for optional dependency +# Conditional import for optional dependencies try: from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor diff --git a/eval_protocol/pytest/default_langchain_rollout_processor.py b/eval_protocol/pytest/default_langchain_rollout_processor.py index f60c966a..e3c86e1a 100644 --- a/eval_protocol/pytest/default_langchain_rollout_processor.py +++ b/eval_protocol/pytest/default_langchain_rollout_processor.py @@ -1,7 +1,13 @@ import asyncio from typing import List -from langchain_core.messages import BaseMessage +try: + from langchain_core.messages import BaseMessage +except Exception: # pragma: no cover - optional dependency path + # Minimal fallback base type to satisfy typing when langchain is not present + class BaseMessage: # type: ignore + pass + from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.rollout_processor import RolloutProcessor @@ -25,7 +31,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig): async def _process_row(row: EvaluationRow) -> EvaluationRow: # Build LC messages from EP row - from langchain_core.messages import HumanMessage + try: + from langchain_core.messages import HumanMessage + except Exception: + # Fallback minimal message if langchain_core is unavailable + class HumanMessage: # type: ignore + def __init__(self, content: str): + self.content = content lm_messages: List[BaseMessage] = [] if row.messages: diff --git a/eval_protocol/rewards/code_execution.py b/eval_protocol/rewards/code_execution.py index 52db7db2..2723d5b7 100644 --- a/eval_protocol/rewards/code_execution.py +++ b/eval_protocol/rewards/code_execution.py @@ -171,7 +171,11 @@ def local_code_execution_reward( }, ) - response_content = messages[-1].content + # Normalize content to string; Message.content may be str or list of content parts + last_content = messages[-1].content + response_content = ( + last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])]) + ) expected_output_str = ground_truth code_blocks = extract_code_blocks(response_content, language) @@ -935,7 +939,10 @@ def e2b_code_execution_reward( }, ) - response_content = messages[-1].content + last_content = messages[-1].content + response_content = ( + last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])]) + ) expected_output_str = ground_truth code_blocks = extract_code_blocks(response_content, language) diff --git a/eval_protocol/rewards/lean_prover.py b/eval_protocol/rewards/lean_prover.py index fae390d0..f134fcfe 100644 --- a/eval_protocol/rewards/lean_prover.py +++ b/eval_protocol/rewards/lean_prover.py @@ -57,7 +57,8 @@ def lean_prover_reward( }, ) - response = messages[-1].content + last_content = messages[-1].content + response = last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])]) if not response: return EvaluateResult( score=0.0, @@ -230,7 +231,10 @@ def deepseek_prover_v2_reward( and messages[-1].role == "assistant" and messages[-1].content is not None ): - response_content = messages[-1].content + last_content = messages[-1].content + response_content = ( + last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])]) + ) final_score = base_score subgoal_count = 0 diff --git a/eval_protocol/typed_interface.py b/eval_protocol/typed_interface.py index 054babee..696bd538 100644 --- a/eval_protocol/typed_interface.py +++ b/eval_protocol/typed_interface.py @@ -99,6 +99,23 @@ def decorator(func: F) -> F: # Detect if the user supplied function is a coroutine (async def) _is_async_function = inspect.iscoroutinefunction(func) + def _is_list_of_message_annotation(annotation: Any) -> bool: + origin = get_origin(annotation) + args = get_args(annotation) + # Direct List[Message] + if origin in (list, List) and args and args[0] == Message: + return True + # Optional[List[Message]] or Union[List[Message], None] + if origin is Union and args: + # Filter out NoneType + non_none = [a for a in args if a is not type(None)] # noqa: E721 + if len(non_none) == 1: + inner = non_none[0] + inner_origin = get_origin(inner) + inner_args = get_args(inner) + return inner_origin in (list, List) and inner_args and inner_args[0] == Message + return False + def _prepare_final_args(*args: Any, **kwargs: Any): """Prepare final positional and keyword arguments for the user function call. This includes Pydantic coercion and resource injection. Returns a tuple of @@ -119,7 +136,7 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes if isinstance(item_data, Message): typed_list.append(item_data) elif isinstance(item_data, dict): - typed_list.append(Message(**item_data)) + typed_list.append(Message.model_validate(item_data)) else: raise TypeError(f"Unexpected type for item {i} in '{arg_name_for_error}': {type(item_data)}") return typed_list @@ -127,11 +144,7 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes # 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch) if mode == "pointwise" and "messages" in params and "messages" in final_func_args: messages_param_annotation = params["messages"].annotation - if ( - get_origin(messages_param_annotation) in (list, List) - and get_args(messages_param_annotation) - and get_args(messages_param_annotation)[0] == Message - ): + if _is_list_of_message_annotation(messages_param_annotation): try: final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages") except Exception as err: @@ -155,12 +168,18 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes # Ground truth coercion (if needed) if "ground_truth" in params and "ground_truth" in final_func_args: gt_ann = params["ground_truth"].annotation - if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message: + if _is_list_of_message_annotation(gt_ann): if final_func_args["ground_truth"] is not None: + gt_val = final_func_args["ground_truth"] try: - final_func_args["ground_truth"] = _coerce_to_list_message( - final_func_args["ground_truth"], "ground_truth" - ) + if isinstance(gt_val, list): + final_func_args["ground_truth"] = _coerce_to_list_message(gt_val, "ground_truth") + elif isinstance(gt_val, dict): + final_func_args["ground_truth"] = _coerce_to_list_message([gt_val], "ground_truth") + elif isinstance(gt_val, str): + final_func_args["ground_truth"] = _coerce_to_list_message( + [{"role": "system", "content": gt_val}], "ground_truth" + ) except Exception as err: raise ValueError( f"Input 'ground_truth' failed Pydantic validation for List[Message]: {err}"