Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions eval_protocol/benchmarks/test_aime25.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
Expand All @@ -18,10 +19,12 @@


def _coerce_content_to_str(
content: str | list[ChatCompletionContentPartTextParam] | None,
content: str | list[ChatCompletionContentPartParam] | None,
) -> str:
if isinstance(content, list):
return "".join([getattr(p, "text", str(p)) for p in content])
return "".join(
getattr(p, "text", str(p)) if isinstance(p, ChatCompletionContentPartTextParam) else "" for p in content
)
return str(content or "")


Expand Down
7 changes: 5 additions & 2 deletions eval_protocol/benchmarks/test_gpqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
Expand Down Expand Up @@ -54,10 +55,12 @@ def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:


def _coerce_content_to_str(
content: str | list[ChatCompletionContentPartTextParam] | None,
content: str | list[ChatCompletionContentPartParam] | None,
) -> str:
if isinstance(content, list):
return "".join([getattr(p, "text", str(p)) for p in content])
return "".join(
getattr(p, "text", str(p)) if isinstance(p, ChatCompletionContentPartTextParam) else "" for p in content
)
return str(content or "")


Expand Down
7 changes: 5 additions & 2 deletions eval_protocol/benchmarks/test_livebench_data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
Expand Down Expand Up @@ -37,9 +38,11 @@ def _extract_last_boxed_segment(text: str) -> Optional[str]:
return matches[-1]


def _coerce_content_to_str(content: str | list[ChatCompletionContentPartTextParam] | None) -> str:
def _coerce_content_to_str(content: str | list[ChatCompletionContentPartParam] | None) -> str:
if isinstance(content, list):
return "".join([getattr(p, "text", str(p)) for p in content])
return "".join(
getattr(p, "text", str(p)) if isinstance(p, ChatCompletionContentPartTextParam) else "" for p in content
)
return str(content or "")


Expand Down
37 changes: 36 additions & 1 deletion eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,46 @@ def __iter__(self):
return iter(["text", "type"])


class ChatCompletionContentPartImageParam(BaseModel):
type: Literal["image_url"] = Field("image_url", description="The type of the content part.")
image_url: Dict[str, Any] = Field(
..., description="Image descriptor (e.g., {'url': 'data:image/png;base64,...', 'detail': 'high'})."
)

def __getitem__(self, key: str) -> Any:
if key == "image_url":
return self.image_url
if key == "type":
return self.type
raise KeyError(key)

def get(self, key: str, default: Any = None) -> Any:
try:
return self[key]
except KeyError:
return default

def keys(self):
return (k for k in ("image_url", "type"))

def values(self):
return (self.image_url, self.type)

def items(self):
return [("image_url", self.image_url), ("type", self.type)]

def __iter__(self):
return iter(["image_url", "type"])


ChatCompletionContentPartParam = Union[ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam]


class Message(BaseModel):
"""Chat message model with trajectory evaluation support."""

role: str # assistant, user, system, tool
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
content: Optional[Union[str, List[ChatCompletionContentPartParam]]] = Field(
default="", description="The content of the message."
)
reasoning_content: Optional[str] = Field(
Expand Down
11 changes: 8 additions & 3 deletions eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
from eval_protocol.models import EvaluationRow, Message, ChatCompletionContentPartTextParam
from eval_protocol.models import (
EvaluationRow,
Message,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
Expand Down Expand Up @@ -98,7 +103,7 @@ def append_message_and_log(self, message: Message):
self.messages.append(message)
self.logger.log(self.evaluation_row)

async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPartTextParam]]]:
async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPartParam]]]:
"""
Call the assistant with the user query.
"""
Expand Down Expand Up @@ -222,7 +227,7 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult | str) -> Li

def _format_tool_message_content(
self, content: List[TextContent]
) -> Union[str, List[ChatCompletionContentPartTextParam]]:
) -> Union[str, List[ChatCompletionContentPartParam]]:
"""Format tool result content for inclusion in a tool message.

- If a single text item, return plain string per OpenAI semantics.
Expand Down
16 changes: 13 additions & 3 deletions eval_protocol/rewards/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,28 @@
import re
from typing import Any, Callable, Dict, List, Optional, Union, cast

from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..models import (
EvaluateResult,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)


def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartParam]]]) -> str:
"""Coerce Message.content into a plain string for regex and comparisons."""
if content is None:
return ""
if isinstance(content, str):
return content
# List[ChatCompletionContentPartTextParam]
try:
return "\n".join(part.text for part in content)
texts: List[str] = []
for part in content:
if isinstance(part, ChatCompletionContentPartTextParam):
texts.append(part.text)
return "\n".join(texts)
except Exception:
return ""

Expand Down
14 changes: 11 additions & 3 deletions eval_protocol/rewards/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
import re
from typing import Any, Dict, List, Optional, Union

from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..models import (
EvaluateResult,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from ..typed_interface import reward_function
from .function_calling import (
calculate_jaccard_similarity,
Expand Down Expand Up @@ -59,8 +65,10 @@ def json_schema_reward(
content_text = last_message.content
else:
try:
parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment]
content_text = "\n".join(getattr(p, "text", "") for p in parts)
parts: List[ChatCompletionContentPartParam] = last_message.content # type: ignore[assignment]
content_text = "\n".join(
getattr(p, "text", "") for p in parts if isinstance(p, ChatCompletionContentPartTextParam)
)
except Exception:
content_text = ""
else:
Expand Down
16 changes: 13 additions & 3 deletions eval_protocol/rewards/language_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..models import (
EvaluateResult,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from ..typed_interface import reward_function

# Dictionary mapping language codes to common words/patterns in that language
Expand Down Expand Up @@ -573,13 +579,17 @@ def language_consistency_reward(
},
)

def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
def _to_text(content: Union[str, List[ChatCompletionContentPartParam], None]) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
try:
return "\n".join(part.text for part in content)
texts: List[str] = []
for part in content:
if isinstance(part, ChatCompletionContentPartTextParam):
texts.append(part.text)
return "\n".join(texts)
except Exception:
return ""

Expand Down
16 changes: 13 additions & 3 deletions eval_protocol/rewards/repetition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,26 @@
import re
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..models import (
EvaluateResult,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)


def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartParam]]]) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
try:
return "\n".join(part.text for part in content)
texts: List[str] = []
for part in content:
if isinstance(part, ChatCompletionContentPartTextParam):
texts.append(part.text)
return "\n".join(texts)
except Exception:
return ""

Expand Down
16 changes: 13 additions & 3 deletions eval_protocol/rewards/tag_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,26 @@
import re
from typing import Any, Dict, List, Set, Union

from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
from ..models import (
EvaluateResult,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)


def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
def _to_text(content: Union[str, List[ChatCompletionContentPartParam], None]) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
try:
return "\n".join(part.text for part in content)
texts: List[str] = []
for part in content:
if isinstance(part, ChatCompletionContentPartTextParam):
texts.append(part.text)
return "\n".join(texts)
except Exception:
return ""

Expand Down
47 changes: 47 additions & 0 deletions tests/pytest/test_single_turn_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,50 @@ async def fake_acompletion(**kwargs):
assert [m["role"] for m in sent_msgs] == ["user", "assistant"]
assert [m.role for m in out.messages] == ["user", "assistant", "assistant"]
assert out.messages[-1].content == "Hello again"


@pytest.mark.asyncio
async def test_single_turn_handles_missing_usage_block(monkeypatch):
row = EvaluationRow(messages=[Message(role="user", content="Describe the picture")])

import eval_protocol.pytest.default_single_turn_rollout_process as mod

class StubChoices:
pass

class StubModelResponse:
def __init__(self, text: str):
self.choices = [StubChoices()]
self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
self.usage = None

async def fake_acompletion(**kwargs):
return StubModelResponse(text="It looks like creme brulee")

class StubLogger:
def __init__(self):
self.logged = []

def log(self, row):
self.logged.append(row)

def read(self, rollout_id=None):
return list(self.logged)

stub_logger = StubLogger()

monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)
monkeypatch.setattr(mod, "default_logger", stub_logger, raising=False)

processor = SingleTurnRolloutProcessor()
config = _DummyConfig()

tasks = processor([row], config)
out = await tasks[0]

assert [m.role for m in out.messages] == ["user", "assistant"]
assert out.messages[-1].content == "It looks like creme brulee"
# Usage should remain unset when the provider omits it
assert out.execution_metadata.usage is None
Comment thread
benjibc marked this conversation as resolved.
Binary file added vite-app/bun.lockb
Binary file not shown.
1 change: 0 additions & 1 deletion vite-app/dist/assets/index-BIhepl19.css

This file was deleted.

46 changes: 46 additions & 0 deletions vite-app/dist/assets/index-CuQbfdPD.js

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions vite-app/dist/assets/index-CuQbfdPD.js.map

Large diffs are not rendered by default.

Loading
Loading