# default_pydantic_ai_rollout_processor.py
# NOTE: web-scrape residue (GitHub UI chrome and gutter line numbers) removed;
# only the original module content follows.
# pyright: reportPrivateUsage=false
import asyncio
import logging
import types
from pydantic_ai.models import Model
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.google import GoogleModel
from pydantic import TypeAdapter
from pydantic_ai.messages import ModelMessage
from pydantic_ai._utils import generate_tool_call_id
from pydantic_ai import Agent
from pydantic_ai.messages import (
ModelRequest,
SystemPromptPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.providers.openai import OpenAIProvider
logger = logging.getLogger(__name__)
class PydanticAgentRolloutProcessor(RolloutProcessor):
    """Rollout processor for Pydantic AI agents.

    Mainly converts EvaluationRow.messages to and from Pydantic AI
    ModelMessage format, runs the configured agent over each row, and
    returns asyncio tasks so the caller controls scheduling/awaiting.
    """

    def __init__(self):
        # dummy model used for its helper functions for processing messages
        # (_map_messages / _process_response). It is never used to issue
        # requests, so the placeholder model name and API key are never sent.
        self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))

    @override
    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
        """Create agent rollout tasks and return them for external handling.

        Args:
            rows: Evaluation rows whose ``messages`` are replayed through the agent.
            config: Must carry ``kwargs["agent"]`` — either a Pydantic AI
                ``Agent`` instance or a factory function that accepts keyword
                ``Model`` arguments and returns an ``Agent`` — plus
                ``completion_params`` describing the model(s).
                ``config.semaphore`` bounds rollout concurrency.

        Returns:
            One ``asyncio.Task`` per input row; each resolves to the same row
            with ``messages`` replaced by the full agent conversation.

        Raises:
            ValueError: If ``kwargs["agent"]`` is missing or invalid, or if a
                factory is supplied without a dict-valued
                ``completion_params["model"]``.
        """
        semaphore = config.semaphore
        # validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
        if "agent" not in config.kwargs:
            raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
        if not isinstance(config.kwargs["agent"], Agent) and not isinstance(
            config.kwargs["agent"], types.FunctionType
        ):
            raise ValueError(
                "kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
            )
        if isinstance(config.kwargs["agent"], types.FunctionType):
            # Factory form: build one Model per named agent argument from
            # completion_params["model"], then call the factory with them.
            setup_agent = config.kwargs["agent"]
            if not isinstance(config.completion_params["model"], dict):
                raise ValueError(
                    "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
                )
            kwargs: dict[str, Model] = {}
            for k, v in config.completion_params["model"].items():  # pyright: ignore[reportUnknownVariableType]
                # Dispatch on the provider prefix of the model string;
                # anything without a recognized prefix falls through to OpenAI.
                if v["model"] and v["model"].startswith("anthropic:"):  # pyright: ignore[reportUnknownMemberType]
                    kwargs[k] = AnthropicModel(
                        v["model"].removeprefix("anthropic:"),  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
                    )
                elif v["model"] and v["model"].startswith("google:"):  # pyright: ignore[reportUnknownMemberType]
                    kwargs[k] = GoogleModel(
                        v["model"].removeprefix("google:"),  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
                    )
                else:
                    kwargs[k] = OpenAIModel(
                        v["model"],  # pyright: ignore[reportUnknownArgumentType]
                        provider=v["provider"],  # pyright: ignore[reportUnknownArgumentType]
                    )
            agent_instance: Agent = setup_agent(**kwargs)  # pyright: ignore[reportAny]
            # The factory wires models into the agent itself, so no override
            # model is passed to Agent.run below.
            model = None
        else:
            agent_instance = config.kwargs["agent"]  # pyright: ignore[reportAssignmentType]
            model = OpenAIModel(
                config.completion_params["model"],  # pyright: ignore[reportAny]
                provider=config.completion_params["provider"],  # pyright: ignore[reportAny]
            )

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with agent rollout."""
            model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
            response = await agent_instance.run(
                message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
            )
            # Replace the row's messages with the complete conversation
            # (input history plus everything the agent produced).
            row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
            # TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
            # usage_info = response.usage()
            # row.execution_metadata.usage = CompletionUsage(
            #     prompt_tokens=usage_info.request_tokens or 0,
            #     completion_tokens=usage_info.response_tokens or 0,
            #     total_tokens=usage_info.total_tokens or 0,
            # )
            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrent rollouts with the caller-supplied semaphore.
            async with semaphore:
                result = await process_row(r)
            return result

        # Create and return tasks for external handling
        tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
        return tasks

    async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
        """Convert Pydantic AI ModelMessages to eval-protocol Messages.

        Uses the dummy model's private ``_map_messages`` helper to obtain
        OpenAI chat-completion message dicts, then re-wraps each as ``Message``.
        """
        oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
        return [Message(**m) for m in oai_messages]  # pyright: ignore[reportArgumentType]

    def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
        """Convert one eval-protocol Message into a Pydantic AI ModelMessage.

        Dispatches on ``message.role`` ("assistant" / "user" / "system" /
        "tool"); ``row`` supplies the creation timestamp for the synthetic
        ChatCompletion built for assistant messages.

        Raises:
            ValueError: For unknown roles or unsupported content types.
        """
        if message.role == "assistant":
            # Round-trip through a synthetic ChatCompletion so the dummy
            # model's private _process_response does the conversion.
            type_adapter = TypeAdapter(ChatCompletionMessage)
            oai_message = type_adapter.validate_python(message)
            # Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
            return self.util._process_response(
                ChatCompletion(
                    choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                    object="chat.completion",
                    model="",
                    id="",
                    # ChatCompletion.created must be an int epoch-seconds value;
                    # assumes row.created_at is a datetime — TODO confirm.
                    created=int(row.created_at.timestamp()),
                )
            )
        elif message.role == "user":
            if isinstance(message.content, str):
                return ModelRequest(parts=[UserPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # NOTE(review): only the first content part is converted —
                # confirm multi-part user content is not expected here.
                return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
            else:
                raise ValueError(f"Unsupported content type for user message: {type(message.content)}")
        elif message.role == "system":
            if isinstance(message.content, str):
                return ModelRequest(parts=[SystemPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # NOTE(review): same first-part-only behavior as user messages.
                return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
            else:
                raise ValueError(f"Unsupported content type for system message: {type(message.content)}")
        elif message.role == "tool":
            return ModelRequest(
                parts=[
                    ToolReturnPart(
                        content=message.content,
                        # tool name is not recoverable from the Message, so it
                        # is left empty; a fresh id is generated when missing.
                        tool_name="",
                        tool_call_id=message.tool_call_id or generate_tool_call_id(),
                    )
                ]
            )
        else:
            raise ValueError(f"Unknown role: {message.role}")