Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions eval_protocol/benchmarks/test_aime25.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import Any, Dict, List, Optional

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.models import (
EvaluateResult,
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
SingleTurnRolloutProcessor,
)
Expand All @@ -11,6 +17,14 @@
)


def _coerce_content_to_str(
content: str | list[ChatCompletionContentPartTextParam] | None,
) -> str:
if isinstance(content, list):
return "".join([getattr(p, "text", str(p)) for p in content])
return str(content or "")


def _extract_boxed_text(text: str) -> str:
import re

Expand Down Expand Up @@ -80,9 +94,10 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
)
def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
content = assistant_msgs[-1].content if assistant_msgs else ""
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
content_str = _coerce_content_to_str(raw_content)

extracted_text = _extract_boxed_text(content or "")
extracted_text = _extract_boxed_text(content_str)
extracted_int = _normalize_to_int_or_none(extracted_text)
gt_int = _normalize_to_int_or_none(row.ground_truth or "")

Expand Down
50 changes: 38 additions & 12 deletions eval_protocol/benchmarks/test_gpqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

import requests

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.models import (
EvaluateResult,
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
SingleTurnRolloutProcessor,
)
Expand Down Expand Up @@ -47,6 +53,14 @@ def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
return [messages_list]


def _coerce_content_to_str(
content: str | list[ChatCompletionContentPartTextParam] | None,
) -> str:
if isinstance(content, list):
return "".join([getattr(p, "text", str(p)) for p in content])
return str(content or "")


def _extract_abcd_letter(text: str) -> str | None:
if not text:
return None
Expand All @@ -58,9 +72,12 @@ def _extract_abcd_letter(text: str) -> str | None:


def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
# assert that all the messages just have a plain .content string field
assert all(isinstance(m.content, str) for m in msgs), "Messages must have a plain .content string field"
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
result: list[Message] = []
for m in msgs:
content_str = _coerce_content_to_str(m.content)
if not (m.role == "system" and content_str.startswith("__GT__:")):
result.append(m)
return result


class GPQAStripGTRolloutProcessor(RolloutProcessor):
Expand All @@ -75,15 +92,23 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
processed: list[EvaluationRow] = []

for r in rows:
gt_tokens = [
m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
]
gt_tokens: list[str] = []
for m in r.messages:
if m.role == "system":
content_str = _coerce_content_to_str(m.content)
if content_str.startswith("__GT__:"):
gt_tokens.append(content_str)
if gt_tokens:
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
r.ground_truth = gt_val
r.messages = [
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
]
filtered: list[Message] = []
for m in r.messages:
if m.role == "system":
content_str = _coerce_content_to_str(m.content)
if content_str.startswith("__GT__:"):
continue
filtered.append(m)
r.messages = filtered
processed.append(r)

# Delegate to SingleTurnRolloutProcessor
Expand All @@ -103,9 +128,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
)
def test_gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
content = assistant_msgs[-1].content if assistant_msgs else ""
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
content_str = _coerce_content_to_str(raw_content)

pred = _extract_abcd_letter(content or "")
pred = _extract_abcd_letter(content_str)
# GPQA diamond CSV constructs options so that the correct answer is always A
gt = "A"

Expand Down
27 changes: 21 additions & 6 deletions eval_protocol/benchmarks/test_livebench_data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import re
from typing import Any, Dict, List, Optional

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.models import (
EvaluateResult,
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
SingleTurnRolloutProcessor,
)
Expand Down Expand Up @@ -31,6 +37,12 @@
return matches[-1]


def _coerce_content_to_str(content: str | list[ChatCompletionContentPartTextParam] | None) -> str:
if isinstance(content, list):
return "".join([getattr(p, "text", str(p)) for p in content])
return str(content or "")


def _cta_process_results(ground_truth: str, llm_answer: str) -> int:
parsed_answer = llm_answer
if "\\boxed{" in parsed_answer or "\\framebox{" in parsed_answer:
Expand Down Expand Up @@ -275,6 +287,8 @@
return 0

# Compare
assert llm_df is not None, "LLM dataframe is None"
assert gt_df is not None, "GT dataframe is None"
try:
gt_df.columns = [str(s).strip() for s in gt_df.columns]
if "index" in gt_df.columns:
Expand Down Expand Up @@ -411,7 +425,7 @@
@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],

Check failure on line 428 in eval_protocol/benchmarks/test_livebench_data_analysis.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "list[dict[str, dict[str, str]]]" cannot be assigned to parameter "rollout_processor_kwargs" of type "RolloutProcessorInputParam | None" in function "evaluation_test"   Type "list[dict[str, dict[str, str]]]" is not assignable to type "RolloutProcessorInputParam | None"     "list[dict[str, dict[str, str]]]" is not assignable to "dict[str, Any]"     "list[dict[str, dict[str, str]]]" is not assignable to "None" (reportArgumentType)
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
Expand All @@ -420,7 +434,8 @@
)
def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
content = assistant_msgs[-1].content if assistant_msgs else ""
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
content = _coerce_content_to_str(raw_content)
payload = _extract_gt(row)
gt = payload.get("ground_truth")
gt_str = str(gt) if gt is not None else ""
Expand Down Expand Up @@ -453,7 +468,7 @@
@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],

Check failure on line 471 in eval_protocol/benchmarks/test_livebench_data_analysis.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "list[dict[str, dict[str, str]]]" cannot be assigned to parameter "rollout_processor_kwargs" of type "RolloutProcessorInputParam | None" in function "evaluation_test"   Type "list[dict[str, dict[str, str]]]" is not assignable to type "RolloutProcessorInputParam | None"     "list[dict[str, dict[str, str]]]" is not assignable to "dict[str, Any]"     "list[dict[str, dict[str, str]]]" is not assignable to "None" (reportArgumentType)
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
aggregation_method="mean",
passed_threshold=None,
Expand All @@ -462,9 +477,9 @@
)
def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
user_msgs = [m for m in row.messages if m.role == "user"]
question = user_msgs[-1].content if user_msgs else ""
question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "")
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
content = assistant_msgs[-1].content if assistant_msgs else ""
content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "")
payload = _extract_gt(row)
gt = payload.get("ground_truth")

Expand Down Expand Up @@ -495,8 +510,8 @@

@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],

Check failure on line 513 in eval_protocol/benchmarks/test_livebench_data_analysis.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "list[list[Message]]" cannot be assigned to parameter "input_messages" of type "Sequence[list[InputMessagesParam] | None] | None" in function "evaluation_test" (reportArgumentType)
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],

Check failure on line 514 in eval_protocol/benchmarks/test_livebench_data_analysis.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Argument of type "list[dict[str, dict[str, str]]]" cannot be assigned to parameter "rollout_processor_kwargs" of type "RolloutProcessorInputParam | None" in function "evaluation_test"   Type "list[dict[str, dict[str, str]]]" is not assignable to type "RolloutProcessorInputParam | None"     "list[dict[str, dict[str, str]]]" is not assignable to "dict[str, Any]"     "list[dict[str, dict[str, str]]]" is not assignable to "None" (reportArgumentType)
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS),
aggregation_method="mean",
passed_threshold=None,
Expand All @@ -505,9 +520,9 @@
)
def test_livebench_tablereformat_pointwise(row: EvaluationRow) -> EvaluationRow:
user_msgs = [m for m in row.messages if m.role == "user"]
question = user_msgs[-1].content if user_msgs else ""
question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "")
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
content = assistant_msgs[-1].content if assistant_msgs else ""
content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "")
payload = _extract_gt(row)
gt = payload.get("ground_truth")
release = payload.get("release") or ""
Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/integrations/braintrust.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def scorer_to_reward_fn(
"""Wrap a Braintrust scorer as an Eval Protocol reward function."""

@reward_function
def reward_fn(messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs) -> EvaluateResult:
def reward_fn(
messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any
) -> EvaluateResult:
input_val = messages_to_input(messages) if messages_to_input else messages[0].content
output_val = messages[-1].content
expected_val = None
Expand Down
13 changes: 11 additions & 2 deletions eval_protocol/integrations/deepeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def _build_case_kwargs() -> Dict[str, Any]:
case_kwargs["actual_output"] = output
return case_kwargs

if isinstance(metric, BaseConversationalMetric):
if BaseConversationalMetric is not None and isinstance(metric, BaseConversationalMetric):
# Narrow types for optional imports to satisfy the type checker
assert LLMTestCase is not None
assert ConversationalTestCase is not None
turns = []
for i, msg in enumerate(messages):
turn_input = messages[i - 1].get("content", "") if i > 0 else ""
Expand All @@ -93,10 +96,16 @@ def _build_case_kwargs() -> Dict[str, Any]:
output = messages[-1].get("content", "")
test_case = ConversationalTestCase(turns=turns)
else:
# Narrow types for optional imports to satisfy the type checker
assert LLMTestCase is not None
case_kwargs = _build_case_kwargs()
test_case = LLMTestCase(**case_kwargs)

metric.measure(test_case, **kwargs)
# Guard against metric.measure being None or non-callable
measure_fn = getattr(metric, "measure", None)
if not callable(measure_fn):
raise TypeError("Provided metric does not have a callable 'measure' method")
measure_fn(test_case, **kwargs)
score = float(metric.score or 0.0)
reason = getattr(metric, "reason", None)
name = _metric_name(metric)
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/mcp/mcpgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def format_observation(self, obs: Any, env: Any) -> Dict[str, Any]:
else:
return {"observation": serialized_obs}

def run(self, transport: str = "streamable-http", **kwargs):
def run(self, transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http", **kwargs):
"""Run the unified MCP-Gym server with high concurrency settings."""
if transport == "streamable-http":
# Run with custom high-concurrency uvicorn config
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/mcp/simulation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def reset_environment(self, env, seed): ...
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable
from pydantic import AnyUrl

import uvicorn
from mcp.server.lowlevel import Server
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/mcp_agent/orchestration/local_docker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ async def startup(self) -> None:
except docker.errors.DockerException as e:
logger.warning(f"docker.from_env() failed: {e}. Trying explicit base_url.")
try:
# docker.from_env is preferred, but as a fallback use DockerClient with url param name 'base_url'
self.docker_client = docker.DockerClient(base_url="unix://var/run/docker.sock")
if not self.docker_client.ping(): # type: ignore
raise ConnectionError("Failed to connect to Docker daemon with explicit base_url.")
Expand Down Expand Up @@ -649,7 +650,7 @@ async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> types.L
)
target_base_url = instance.mcp_endpoint_url.rstrip("/")
try:
async with streamablehttp_client(base_url=target_base_url) as (
async with streamablehttp_client(base_url=target_base_url) as ( # type: ignore
read_s,
write_s,
_, # get_session_id_func usually not needed for a single call
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/mcp_servers/tau2/tau2_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, seed: Optional[int] = None, **kwargs):

self.adapter = EnvironmentAdapter(env_class=AirlineEnvironment, default_config=default_config)

# Ensure name is a str and not None
super().__init__("airline", self.adapter, seed, **kwargs)

def _register_tools(self):
Expand Down Expand Up @@ -421,7 +422,7 @@ def _register_tools(self):
"""Register mock-specific MCP tools matching τ²-Bench schemas"""

@self.mcp.tool(name="create_task", description="Create a new task for a user.")
def create_task(user_id: str, title: str, ctx: Context, description: str = None) -> Dict[str, Any]:
def create_task(user_id: str, title: str, ctx: Context, description: Optional[str] = None) -> Dict[str, Any]:
"""Create a new task for a user"""
session_id = self._get_session_id(ctx)

Expand Down
34 changes: 32 additions & 2 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,33 @@ class ChatCompletionContentPartTextParam(BaseModel):
text: str = Field(..., description="The text content.")
type: Literal["text"] = Field("text", description="The type of the content part.")

# Provide dict-like access for tests and ergonomic usage
def __getitem__(self, key: str) -> Any:
if key == "text":
return self.text
if key == "type":
return self.type
raise KeyError(key)

def get(self, key: str, default: Any = None) -> Any:
    """Return ``self[key]``, or ``default`` when the lookup raises ``KeyError``."""
    try:
        found = self[key]
    except KeyError:
        found = default
    return found

def keys(self):
    """Lazily yield the field names, ``"text"`` then ``"type"``.

    The result is a one-shot generator, not a reusable view.
    """
    yield from ("text", "type")

def values(self):
    """Return the field values as a ``(text, type)`` tuple, matching key order."""
    text_value, type_value = self.text, self.type
    return text_value, type_value

def items(self):
    """Return ``[("text", <text>), ("type", <type>)]`` as a list of pairs."""
    return list(zip(("text", "type"), (self.text, self.type)))

def __iter__(self):
# Iterate over keys only
return iter(["text", "type"])


class Message(BaseModel):
"""Chat message model with trajectory evaluation support."""
Expand Down Expand Up @@ -271,6 +298,7 @@ class MetricResult(BaseModel):
is_score_valid: bool = True
score: float = Field(..., ge=0.0, le=1.0)
reason: str
data: Dict[str, Any] = Field(default_factory=dict, description="Optional extra metric data for debugging.")

def __getitem__(self, key: str) -> Any:
if key in self.__fields__: # Changed to __fields__ for Pydantic v1 compatibility
Expand All @@ -292,10 +320,12 @@ def values(self):
return [getattr(self, key) for key in self.__fields__.keys()] # Changed to __fields__

def items(self):
return [(key, getattr(self, key)) for key in self.__fields__.keys()] # Changed to __fields__
# Exclude 'data' from items to keep items hashable and match tests
return [(key, getattr(self, key)) for key in self.__fields__.keys() if key != "data"] # Changed to __fields__

def __iter__(self):
return iter(self.__fields__.keys()) # Changed to __fields__
# Exclude 'data' to match expectations in tests
return iter([k for k in self.__fields__.keys() if k != "data"]) # Changed to __fields__


class StepOutput(BaseModel):
Expand Down
Loading
Loading