-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_pydantic_ai_rollout_processor.py
More file actions
139 lines (121 loc) · 6.24 KB
/
default_pydantic_ai_rollout_processor.py
File metadata and controls
139 lines (121 loc) · 6.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# pyright: reportPrivateUsage=false
import asyncio
from collections.abc import Callable
import logging
import time
from pydantic_ai.toolsets import FunctionToolset
from pydantic_ai.usage import UsageLimits
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from pydantic import TypeAdapter
from pydantic_ai import Agent
from pydantic_ai._utils import generate_tool_call_id
from pydantic_ai.messages import ModelMessage
from pydantic_ai.messages import (
ModelRequest,
SystemPromptPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider
# Module-level logger, named after this module per the stdlib convention.
logger = logging.getLogger(__name__)
class PydanticAgentRolloutProcessor(RolloutProcessor):
    """Rollout processor for Pydantic AI agents.

    Runs a Pydantic AI ``Agent`` once per :class:`EvaluationRow`, converting
    ``EvaluationRow.messages`` to and from Pydantic AI's ``ModelMessage``
    format, and records per-row wall-clock duration in the row's execution
    metadata.
    """

    def __init__(
        self,
        agent_factory: Callable[[RolloutProcessorConfig], Agent],
        usage_limits: UsageLimits | None = None,
    ):
        """
        Args:
            agent_factory: Builds the ``Agent`` for a given rollout config.
            usage_limits: Default usage limits applied to every ``agent.run``
                call unless the config supplies ``kwargs["usage_limits"]``.
        """
        # Dummy model used only for its message-mapping helper functions
        # (_map_messages / _process_response); it never makes API calls,
        # so the api_key is a placeholder.
        self._util: OpenAIChatModel = OpenAIChatModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
        self._setup_agent = agent_factory
        # Bug fix: this parameter was previously accepted but silently ignored;
        # it is now used as the default when the config provides no limits.
        self._usage_limits: UsageLimits | None = usage_limits

    @staticmethod
    def _extract_tool_schemas(agent: Agent) -> list[dict]:
        """Collect OpenAI-style tool schemas from the agent's FunctionToolsets."""
        tools: list[dict] = []
        for toolset in agent.toolsets:
            if not isinstance(toolset, FunctionToolset):
                continue
            for tool in toolset.tools.values():
                tool_dict: dict = {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "parameters": tool.function_schema.json_schema,
                    },
                }
                # "description" is optional in the OpenAI tool schema; only
                # include it when the tool actually defines one.
                if tool.description:
                    tool_dict["function"]["description"] = tool.description
                tools.append(tool_dict)
        return tools

    @override
    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
        """Create agent rollout tasks and return them for external handling.

        Args:
            rows: Rows to roll out; each row's ``messages`` seed the agent run.
            config: Rollout configuration; its ``semaphore`` bounds concurrency
                and ``kwargs["usage_limits"]`` (if set) overrides the
                constructor-level usage limits.

        Returns:
            One ``asyncio.Task`` per input row; the caller awaits them.
        """
        semaphore = config.semaphore
        agent = self._setup_agent(config)
        # Per-call limits from the config take precedence over the
        # constructor-supplied default.
        usage_limits = config.kwargs.get("usage_limits") or self._usage_limits

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with agent rollout, mutating it in place."""
            start_time = time.perf_counter()
            row.tools = self._extract_tool_schemas(agent)
            model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
            response = await agent.run(message_history=model_messages, usage_limits=usage_limits)
            row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
            # TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
            # usage_info = response.usage()
            # row.execution_metadata.usage = CompletionUsage(
            #     prompt_tokens=usage_info.request_tokens or 0,
            #     completion_tokens=usage_info.response_tokens or 0,
            #     total_tokens=usage_info.total_tokens or 0,
            # )
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time
            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrent rollouts with the config's semaphore.
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling.
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
        """Convert Pydantic AI messages to EP Messages via the dummy model's OpenAI mapper."""
        oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
        return [Message(**m) for m in oai_messages]  # pyright: ignore[reportArgumentType]

    def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
        """Convert a single EP Message into a Pydantic AI ``ModelMessage``.

        Args:
            message: The EP message; its ``role`` selects the conversion path.
            row: The owning row; only ``row.created_at`` is read, to stamp the
                synthetic ChatCompletion for assistant messages.

        Raises:
            ValueError: For unsupported content types or an unknown role.
        """
        if message.role == "assistant":
            type_adapter = TypeAdapter(ChatCompletionMessage)
            oai_message = type_adapter.validate_python(message)
            # Wrap in a minimal ChatCompletion so the dummy model's response
            # processor can turn it into a ModelMessage. finish_reason/index
            # are required by the OpenAI schema; created must be an int
            # timestamp.
            return self._util._process_response(
                ChatCompletion(
                    choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                    object="chat.completion",
                    model="",
                    id="",
                    created=int(row.created_at.timestamp()),
                )
            )
        elif message.role == "user":
            if isinstance(message.content, str):
                return ModelRequest(parts=[UserPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # NOTE(review): only the first content part is kept; multi-part
                # user content beyond index 0 is dropped — confirm intended.
                return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
            else:
                raise ValueError(f"Unsupported content type for user message: {type(message.content)}")
        elif message.role == "system":
            if isinstance(message.content, str):
                return ModelRequest(parts=[SystemPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # NOTE(review): same first-part-only behavior as user messages.
                return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
            else:
                raise ValueError(f"Unsupported content type for system message: {type(message.content)}")
        elif message.role == "tool":
            return ModelRequest(
                parts=[
                    ToolReturnPart(
                        content=message.content,
                        # NOTE(review): the original tool name is not available
                        # on the EP message, so it is left empty here.
                        tool_name="",
                        tool_call_id=message.tool_call_id or generate_tool_call_id(),
                    )
                ]
            )
        raise ValueError(f"Unknown role: {message.role}")