Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,26 @@ async def initialize_session(self, session: MCPSession) -> None:

exit_stack = AsyncExitStack()

client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
client_info._extra["session_id"] = session.session_id
# Attach client metadata for the server to consume (session_id, seed, config, etc.).
# The server inspects a private `_extra` dict on client_info, so we populate it here.
client_info = Implementation(name="reward-kit", version="1.0.0")
extra_data: Dict[str, Any] = {"session_id": session.session_id}
if session.seed is not None:
client_info._extra["seed"] = session.seed
extra_data["seed"] = session.seed
if session.dataset_row and session.dataset_row.environment_context:
client_info._extra["config"] = session.dataset_row.environment_context
extra_data["config"] = session.dataset_row.environment_context
if session.dataset_row and session.dataset_row.id:
client_info._extra["dataset_row_id"] = session.dataset_row.id
extra_data["dataset_row_id"] = session.dataset_row.id
if session.model_id:
client_info._extra["model_id"] = session.model_id
extra_data["model_id"] = session.model_id

# Merge with any existing _extra dict instead of overwriting
existing_extra = getattr(client_info, "_extra", None)
merged_extra: Dict[str, Any] = {}
if isinstance(existing_extra, dict):
merged_extra.update(existing_extra)
merged_extra.update(extra_data)
setattr(client_info, "_extra", merged_extra)

read_stream, write_stream, _ = await exit_stack.enter_async_context(
streamablehttp_client(session.base_url, terminate_on_close=True)
Expand Down Expand Up @@ -92,7 +102,10 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
# Only fetch tools if not already cached for this base_url
if cache_key not in self._tools_cache:
logger.debug(f"Pre-warming tools cache for {cache_key}")
tools_response = await session._mcp_session.list_tools()
mcp_session_local = session._mcp_session
if mcp_session_local is None:
raise RuntimeError("Session not initialized during prewarm")
tools_response = await mcp_session_local.list_tools()
tools = tools_response.tools if hasattr(tools_response, "tools") else []

tool_schemas = []
Expand Down Expand Up @@ -213,7 +226,7 @@ async def get_initial_state(self, session: MCPSession) -> Any:
try:
# Use shorter timeout for playback mode, longer timeout for high-concurrency initialization
# (50+ concurrent sessions need more time for initial state setup)
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
timeout = 3.0 if bool(getattr(session, "_is_playback_mode", False)) else 15.0
async with httpx.AsyncClient(timeout=timeout) as client:
initial_state_response = await client.get(
f"{base_url}/control/initial_state",
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/mcp/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, intermediary_server_url: str):

async def connect(self):
"""Establishes connection and MCP session."""
if self._mcp_session and not self._mcp_session.is_closed:
if self._mcp_session is not None and not self._mcp_session.is_closed:
logger.debug("Already connected.")
return

Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/mcp/mcp_multi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ async def _connect_to_server(
if env_config:
self._validate_environment_variables(server_name, env_config)

# Use the current system environment (os.environ) - don't override with config
server_params = StdioServerParameters(command=command, args=args, env=os.environ)
# Use the current system environment (os.environ) - convert to plain dict for typing compatibility
server_params = StdioServerParameters(command=command, args=args, env=dict(os.environ))

stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
stdio, write = stdio_transport
Expand Down
9 changes: 9 additions & 0 deletions eval_protocol/mcp_agent/orchestration/local_docker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ async def provision_instances(
logger.info(
f"Creating template container for commit: {temp_cont_name} from {backend_config.docker_image}"
)
if not backend_config.docker_image:
raise ValueError(
f"docker_image is required for template commit for backend {backend_config.backend_name_ref}"
)
temp_c = self.docker_client.containers.run( # type: ignore
image=backend_config.docker_image,
name=temp_cont_name,
Expand Down Expand Up @@ -322,6 +326,11 @@ async def provision_instances(
logger.info(
f"Provisioning instance {container_name} (transport: {backend_config.mcp_transport}) from image {image_to_run_from}"
)
# Ensure the image reference is present before using it in Docker APIs
if not image_to_run_from:
raise ValueError(
f"docker_image is required to provision instance {container_name} for backend {backend_config.backend_name_ref}"
)
if backend_config.mcp_transport == "http":
# ... (HTTP provisioning logic, ensure it uses current_container_volumes) ...
if not self.docker_client:
Expand Down
3 changes: 3 additions & 0 deletions eval_protocol/mcp_agent/orchestration/remote_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ async def call_tool_on_instance(
logger.debug(f"Proxying tool {tool_name} to {target_url} for instance {instance.instance_id}")
else:
# Call tool directly on the instance's MCP endpoint
# mypy/pyright: instance.mcp_endpoint_url is Optional[str]; validate before assignment
if not instance.mcp_endpoint_url:
raise ValueError(f"Instance {instance.instance_id} missing mcp_endpoint_url for direct tool call")
target_url = instance.mcp_endpoint_url
logger.debug(f"Calling tool {tool_name} directly on {target_url} for instance {instance.instance_id}")

Expand Down
45 changes: 32 additions & 13 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,19 @@ class Message(BaseModel):

@classmethod
def model_validate(cls, obj, *args, **kwargs):
if isinstance(obj, dict) and "role" not in obj:
raise ValueError("Role is required")
if isinstance(obj, dict):
if "role" not in obj:
raise ValueError("Role is required")
# Be lenient: if tool_calls entries are missing required 'id', synthesize one
tool_calls_obj = obj.get("tool_calls")
if isinstance(tool_calls_obj, list):
fixed_tool_calls = []
for tc in tool_calls_obj:
if isinstance(tc, dict):
if not tc.get("id"):
tc = {**tc, "id": generate_id()}
fixed_tool_calls.append(tc)
obj = {**obj, "tool_calls": fixed_tool_calls}
return super().model_validate(obj, *args, **kwargs)


Expand Down Expand Up @@ -611,27 +622,35 @@ def get_steps(self) -> int:
def get_total_reward(self) -> float:
"""Get total reward from control_plane_step data."""
messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step]
return (
sum(msg.control_plane_step["reward"] for msg in messages_with_control_plane)
if messages_with_control_plane
else 0.0
)
if not messages_with_control_plane:
return 0.0
total = 0.0
for msg in messages_with_control_plane:
step = msg.control_plane_step or {}
try:
total += float(step.get("reward", 0.0))
except (TypeError, ValueError):
continue
return total

def get_terminated(self) -> bool:
"""Get termination status from control_plane_step data."""
messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step]
return (
any(msg.control_plane_step["terminated"] for msg in messages_with_control_plane)
if messages_with_control_plane
else False
)
if not messages_with_control_plane:
return False
for msg in messages_with_control_plane:
step = msg.control_plane_step or {}
if bool(step.get("terminated", False)):
return True
return False

def get_termination_reason(self) -> str:
"""Get termination reason from the final control_plane_step data."""
# Find the last message with control_plane_step that has termination_reason
for msg in reversed(self.messages):
if msg.control_plane_step and msg.control_plane_step.get("termination_reason"):
return msg.control_plane_step["termination_reason"]
reason = msg.control_plane_step.get("termination_reason")
return str(reason)
return "unknown"

def __hash__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig

# Conditional import for optional dependency
# Conditional import for optional dependencies
try:
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor

Expand Down
16 changes: 14 additions & 2 deletions eval_protocol/pytest/default_langchain_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import asyncio
from typing import List

from langchain_core.messages import BaseMessage
try:
from langchain_core.messages import BaseMessage
except Exception: # pragma: no cover - optional dependency path
# Minimal fallback base type to satisfy typing when langchain is not present
class BaseMessage: # type: ignore
pass


from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
Expand All @@ -25,7 +31,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):

async def _process_row(row: EvaluationRow) -> EvaluationRow:
# Build LC messages from EP row
from langchain_core.messages import HumanMessage
try:
from langchain_core.messages import HumanMessage
except Exception:
# Fallback minimal message if langchain_core is unavailable
class HumanMessage: # type: ignore
def __init__(self, content: str):
self.content = content

lm_messages: List[BaseMessage] = []
if row.messages:
Expand Down
11 changes: 9 additions & 2 deletions eval_protocol/rewards/code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def local_code_execution_reward(
},
)

response_content = messages[-1].content
# Normalize content to string; Message.content may be str or list of content parts
last_content = messages[-1].content
response_content = (
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
)
expected_output_str = ground_truth

code_blocks = extract_code_blocks(response_content, language)
Expand Down Expand Up @@ -935,7 +939,10 @@ def e2b_code_execution_reward(
},
)

response_content = messages[-1].content
last_content = messages[-1].content
response_content = (
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
)
expected_output_str = ground_truth

code_blocks = extract_code_blocks(response_content, language)
Expand Down
8 changes: 6 additions & 2 deletions eval_protocol/rewards/lean_prover.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def lean_prover_reward(
},
)

response = messages[-1].content
last_content = messages[-1].content
response = last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
if not response:
return EvaluateResult(
score=0.0,
Expand Down Expand Up @@ -230,7 +231,10 @@ def deepseek_prover_v2_reward(
and messages[-1].role == "assistant"
and messages[-1].content is not None
):
response_content = messages[-1].content
last_content = messages[-1].content
response_content = (
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
)

final_score = base_score
subgoal_count = 0
Expand Down
39 changes: 29 additions & 10 deletions eval_protocol/typed_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ def decorator(func: F) -> F:
# Detect if the user supplied function is a coroutine (async def)
_is_async_function = inspect.iscoroutinefunction(func)

def _is_list_of_message_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
args = get_args(annotation)
# Direct List[Message]
if origin in (list, List) and args and args[0] == Message:
return True
# Optional[List[Message]] or Union[List[Message], None]
if origin is Union and args:
# Filter out NoneType
non_none = [a for a in args if a is not type(None)] # noqa: E721
if len(non_none) == 1:
inner = non_none[0]
inner_origin = get_origin(inner)
inner_args = get_args(inner)
return inner_origin in (list, List) and inner_args and inner_args[0] == Message
return False

def _prepare_final_args(*args: Any, **kwargs: Any):
"""Prepare final positional and keyword arguments for the user function call.
This includes Pydantic coercion and resource injection. Returns a tuple of
Expand All @@ -119,19 +136,15 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
if isinstance(item_data, Message):
typed_list.append(item_data)
elif isinstance(item_data, dict):
typed_list.append(Message(**item_data))
typed_list.append(Message.model_validate(item_data))
else:
raise TypeError(f"Unexpected type for item {i} in '{arg_name_for_error}': {type(item_data)}")
return typed_list

# 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
if mode == "pointwise" and "messages" in params and "messages" in final_func_args:
messages_param_annotation = params["messages"].annotation
if (
get_origin(messages_param_annotation) in (list, List)
and get_args(messages_param_annotation)
and get_args(messages_param_annotation)[0] == Message
):
if _is_list_of_message_annotation(messages_param_annotation):
try:
final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
except Exception as err:
Expand All @@ -155,12 +168,18 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
# Ground truth coercion (if needed)
if "ground_truth" in params and "ground_truth" in final_func_args:
gt_ann = params["ground_truth"].annotation
if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message:
if _is_list_of_message_annotation(gt_ann):
if final_func_args["ground_truth"] is not None:
gt_val = final_func_args["ground_truth"]
try:
final_func_args["ground_truth"] = _coerce_to_list_message(
final_func_args["ground_truth"], "ground_truth"
)
if isinstance(gt_val, list):
final_func_args["ground_truth"] = _coerce_to_list_message(gt_val, "ground_truth")
elif isinstance(gt_val, dict):
final_func_args["ground_truth"] = _coerce_to_list_message([gt_val], "ground_truth")
elif isinstance(gt_val, str):
final_func_args["ground_truth"] = _coerce_to_list_message(
[{"role": "system", "content": gt_val}], "ground_truth"
)
except Exception as err:
raise ValueError(
f"Input 'ground_truth' failed Pydantic validation for List[Message]: {err}"
Expand Down
Loading