Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 28 additions & 29 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# pyright: reportPrivateUsage=false

import asyncio
import logging
import types
from typing import List

from attr import dataclass
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam

from pydantic_ai.models import Model
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.openai import OpenAIModel
Expand All @@ -25,7 +24,6 @@
UserPromptPart,
)
from pydantic_ai.providers.openai import OpenAIProvider
from typing_extensions import TypedDict

logger = logging.getLogger(__name__)

Expand All @@ -36,9 +34,10 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):

def __init__(self):
# dummy model used for its helper functions for processing messages
self.util = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
@override
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
"""Create agent rollout tasks and return them for external handling."""

max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
Expand All @@ -60,34 +59,34 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
raise ValueError(
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
)
kwargs = {}
for k, v in config.completion_params["model"].items():
if v["model"] and v["model"].startswith("anthropic:"):
kwargs: dict[str, Model] = {}
for k, v in config.completion_params["model"].items(): # pyright: ignore[reportUnknownVariableType]
if v["model"] and v["model"].startswith("anthropic:"): # pyright: ignore[reportUnknownMemberType]
kwargs[k] = AnthropicModel(
v["model"].removeprefix("anthropic:"),
v["model"].removeprefix("anthropic:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
)
elif v["model"] and v["model"].startswith("google:"):
elif v["model"] and v["model"].startswith("google:"): # pyright: ignore[reportUnknownMemberType]
kwargs[k] = GoogleModel(
v["model"].removeprefix("google:"),
v["model"].removeprefix("google:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
)
else:
kwargs[k] = OpenAIModel(
v["model"],
provider=v["provider"],
v["model"], # pyright: ignore[reportUnknownArgumentType]
provider=v["provider"], # pyright: ignore[reportUnknownArgumentType]
)
agent = setup_agent(**kwargs)
agent_instance: Agent = setup_agent(**kwargs) # pyright: ignore[reportAny]
model = None
else:
agent = config.kwargs["agent"]
agent_instance = config.kwargs["agent"] # pyright: ignore[reportAssignmentType]
model = OpenAIModel(
config.completion_params["model"],
provider=config.completion_params["provider"],
config.completion_params["model"], # pyright: ignore[reportAny]
provider=config.completion_params["provider"], # pyright: ignore[reportAny]
)

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with agent rollout."""
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
response = await agent.run(
response = await agent_instance.run(
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
)
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
Expand All @@ -104,11 +103,11 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:

async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
return [Message(**m) for m in oai_messages]
return [Message(**m) for m in oai_messages] # pyright: ignore[reportArgumentType]

def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
if message.role == "assistant":
type_adapter = TypeAdapter(ChatCompletionAssistantMessageParam)
type_adapter = TypeAdapter(ChatCompletionMessage)
oai_message = type_adapter.validate_python(message)
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
return self.util._process_response(
Expand All @@ -117,23 +116,23 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
object="chat.completion",
model="",
id="",
created=(
int(row.created_at.timestamp())
if hasattr(row.created_at, "timestamp")
else int(row.created_at)
),
created=int(row.created_at.timestamp()),
)
)
elif message.role == "user":
if isinstance(message.content, str):
return ModelRequest(parts=[UserPromptPart(content=message.content)])
elif isinstance(message.content, list):
return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
else:
raise ValueError(f"Unsupported content type for user message: {type(message.content)}")
elif message.role == "system":
if isinstance(message.content, str):
return ModelRequest(parts=[SystemPromptPart(content=message.content)])
elif isinstance(message.content, list):
return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
else:
raise ValueError(f"Unsupported content type for system message: {type(message.content)}")
elif message.role == "tool":
return ModelRequest(
parts=[
Expand Down
6 changes: 5 additions & 1 deletion tests/chinook/test_pydantic_chinook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
from tests.chinook.agent import setup_agent
import os
from pydantic_ai.models.openai import OpenAIModel

from tests.chinook.dataset import collect_dataset
Expand Down Expand Up @@ -82,7 +83,10 @@ class Response(BaseModel):
return row


@pytest.mark.skip(reason="takes too long to run")
@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Only run this test locally (skipped in CI)",
)
@pytest.mark.asyncio
@evaluation_test(
input_rows=[collect_dataset()],
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion vite-app/dist/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>EP | Log Viewer</title>
<link rel="icon" href="/assets/favicon-BkAAWQga.png" />
<script type="module" crossorigin src="/assets/index-CMzLJz8S.js"></script>
<script type="module" crossorigin src="/assets/index-D04dO2VH.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-DLyzGYL0.css">
</head>
<body>
Expand Down
10 changes: 5 additions & 5 deletions vite-app/src/components/EvaluationRow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ export const EvaluationRow = observer(
<ExpandIcon rolloutId={rolloutId} />
</TableCell>

{/* Created */}
<TableCell className="py-3 text-xs">
<RowCreated created_at={row.created_at} />
</TableCell>

{/* Name */}
<TableCell className="py-3 text-xs">
<RowName name={row.eval_metadata?.name} />
Expand Down Expand Up @@ -461,11 +466,6 @@ export const EvaluationRow = observer(
<TableCell className="py-3 text-xs">
<RowScore score={row.evaluation_result?.score} />
</TableCell>

{/* Created */}
<TableCell className="py-3 text-xs">
<RowCreated created_at={row.created_at} />
</TableCell>
</TableRowInteractive>

{/* Expanded Content Row */}
Expand Down
16 changes: 8 additions & 8 deletions vite-app/src/components/EvaluationTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ export const EvaluationTable = observer(() => {
<TableHead>
<tr>
<TableHeader className="w-8">&nbsp;</TableHeader>
<SortableTableHeader
sortField="created_at"
currentSortField={state.sortField}
currentSortDirection={state.sortDirection}
onSort={handleSort}
>
Created
</SortableTableHeader>
<SortableTableHeader
sortField="$.eval_metadata.name"
currentSortField={state.sortField}
Expand Down Expand Up @@ -245,14 +253,6 @@ export const EvaluationTable = observer(() => {
>
Score
</SortableTableHeader>
<SortableTableHeader
sortField="created_at"
currentSortField={state.sortField}
currentSortDirection={state.sortDirection}
onSort={handleSort}
>
Created
</SortableTableHeader>
</tr>
</TableHead>

Expand Down
Loading