|
11 | 11 | from eval_protocol.pytest.types import RolloutProcessorConfig |
12 | 12 | from openai.types.chat import ChatCompletion, ChatCompletionMessageParam |
13 | 13 | from openai.types.chat.chat_completion import Choice as ChatCompletionChoice |
14 | | - |
| 14 | +from pydantic_ai.models.anthropic import AnthropicModel |
15 | 15 | from pydantic_ai.models.openai import OpenAIModel |
| 16 | +from pydantic_ai.models.google import GoogleModel |
16 | 17 | from pydantic import TypeAdapter |
17 | 18 | from pydantic_ai.messages import ModelMessage |
18 | 19 | from pydantic_ai._utils import generate_tool_call_id |
@@ -61,10 +62,19 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> |
61 | 62 | ) |
62 | 63 | kwargs = {} |
63 | 64 | for k, v in config.completion_params["model"].items(): |
64 | | - kwargs[k] = OpenAIModel( |
65 | | - v["model"], |
66 | | - provider=v["provider"], |
67 | | - ) |
| 65 | + if v["model"] and v["model"].startswith("anthropic:"): |
| 66 | + kwargs[k] = AnthropicModel( |
| 67 | + v["model"].removeprefix("anthropic:"), |
| 68 | + ) |
| 69 | + elif v["model"] and v["model"].startswith("google:"): |
| 70 | + kwargs[k] = GoogleModel( |
| 71 | + v["model"].removeprefix("google:"), |
| 72 | + ) |
| 73 | + else: |
| 74 | + kwargs[k] = OpenAIModel( |
| 75 | + v["model"], |
| 76 | + provider=v["provider"], |
| 77 | + ) |
68 | 78 | agent = setup_agent(**kwargs) |
69 | 79 | model = None |
70 | 80 | else: |
|
0 commit comments