-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_pydantic_ai_rollout_processor.py
More file actions
122 lines (105 loc) · 5.49 KB
/
default_pydantic_ai_rollout_processor.py
File metadata and controls
122 lines (105 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# pyright: reportPrivateUsage=false
import asyncio
from collections.abc import Callable
import logging
import time
from pydantic_ai.usage import UsageLimits
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from pydantic import TypeAdapter
from pydantic_ai import Agent
from pydantic_ai._utils import generate_tool_call_id
from pydantic_ai.messages import ModelMessage
from pydantic_ai.messages import (
ModelRequest,
SystemPromptPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
logger = logging.getLogger(__name__)
class PydanticAgentRolloutProcessor(RolloutProcessor):
    """Rollout processor for Pydantic AI agents.

    Mainly converts EvaluationRow.messages to and from the Pydantic AI
    ModelMessage format, runs the configured agent over each row, and
    records per-row execution duration.
    """

    def __init__(
        self,
        agent_factory: Callable[[RolloutProcessorConfig], Agent],
        usage_limits: UsageLimits | None = None,
    ):
        """Store the agent factory and optional default usage limits.

        Args:
            agent_factory: Builds the Agent from the rollout config at call time.
            usage_limits: Default limits applied to every agent run. A
                per-call override may be supplied via
                ``config.kwargs["usage_limits"]``.
        """
        # dummy model used for its helper functions for processing messages;
        # it is never used to make real API calls.
        self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
        self._setup_agent = agent_factory
        # Fix: this parameter was previously accepted but silently discarded.
        self._usage_limits = usage_limits

    @override
    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
        """Create agent rollout tasks and return them for external handling.

        Concurrency is bounded by ``config.semaphore``; one asyncio.Task is
        created per input row.
        """
        semaphore = config.semaphore
        agent = self._setup_agent(config)
        # Per-call kwargs override the constructor default (fixes the
        # previously-ignored ``usage_limits`` __init__ argument).
        usage_limits = config.kwargs.get("usage_limits", self._usage_limits)

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Run the agent once for a single row and record its duration."""
            start_time = time.perf_counter()
            model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
            response = await agent.run(message_history=model_messages, usage_limits=usage_limits)
            row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
            # TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
            # usage_info = response.usage()
            # row.execution_metadata.usage = CompletionUsage(
            #     prompt_tokens=usage_info.request_tokens or 0,
            #     completion_tokens=usage_info.response_tokens or 0,
            #     total_tokens=usage_info.total_tokens or 0,
            # )
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time
            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrency with the caller-provided semaphore.
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
        """Convert Pydantic AI messages to EP Messages.

        Uses the dummy model's private OpenAI mapping helper, so each message
        round-trips through the OpenAI chat-completion wire format.
        """
        oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
        return [Message(**m) for m in oai_messages]  # pyright: ignore[reportArgumentType]

    def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
        """Convert a single EP Message into a Pydantic AI ModelMessage.

        Args:
            message: The EP message to convert (assistant/user/system/tool).
            row: Owning row; its ``created_at`` timestamps synthesized
                assistant ChatCompletion payloads.

        Raises:
            ValueError: For unsupported content types or unknown roles.
        """
        if message.role == "assistant":
            type_adapter = TypeAdapter(ChatCompletionMessage)
            oai_message = type_adapter.validate_python(message)
            # Provide required finish_reason and index, and ensure created is
            # an int timestamp, so _process_response accepts the payload.
            return self._util._process_response(
                ChatCompletion(
                    choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                    object="chat.completion",
                    model="",
                    id="",
                    created=int(row.created_at.timestamp()),
                )
            )
        elif message.role == "user":
            if isinstance(message.content, str):
                return ModelRequest(parts=[UserPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # NOTE(review): only the first content part is kept — confirm
                # multi-part user content is never expected here.
                return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
            else:
                raise ValueError(f"Unsupported content type for user message: {type(message.content)}")
        elif message.role == "system":
            if isinstance(message.content, str):
                return ModelRequest(parts=[SystemPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # NOTE(review): same first-part-only behavior as user messages.
                return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
            else:
                raise ValueError(f"Unsupported content type for system message: {type(message.content)}")
        elif message.role == "tool":
            return ModelRequest(
                parts=[
                    ToolReturnPart(
                        content=message.content,
                        # tool_name is not recoverable from the EP message; the
                        # empty string is a placeholder — confirm downstream
                        # consumers tolerate it.
                        tool_name="",
                        tool_call_id=message.tool_call_id or generate_tool_call_id(),
                    )
                ]
            )
        raise ValueError(f"Unknown role: {message.role}")