Skip to content

Commit b297a50

Browse files
committed
reuse pydantic example for local model picking
1 parent 3b84fe4 commit b297a50

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,20 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
4646
"""Create agent rollout tasks and return them for external handling."""
4747

4848
semaphore = config.semaphore
49-
agent = self._setup_agent(config)
49+
# NOTE: Do not create the agent outside of the semaphore or multiple rows
50+
# will initialize clients and start network calls concurrently. This can
51+
# overwhelm local providers like Ollama where only one request should be
52+
# active at a time. Instead, construct the agent within the semaphore-guarded
53+
# section per row.
5054

5155
async def process_row(row: EvaluationRow) -> EvaluationRow:
5256
"""Process a single row with agent rollout."""
5357
start_time = time.perf_counter()
5458

59+
# Build the agent lazily inside the semaphore guard to ensure we fully
60+
# respect max_concurrent_rollouts across both setup and run phases.
61+
agent = self._setup_agent(config)
62+
5563
tools = []
5664
for toolset in agent.toolsets:
5765
if isinstance(toolset, FunctionToolset):

tests/chinook/pydantic/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def execute_sql(ctx: RunContext, query: str) -> str:
6161

6262
return "\n".join(table_lines)
6363
except Exception as e:
64+
print("Show exception: ", e)
6465
connection.rollback()
6566
raise ModelRetry("Please try again with a different query. Here is the error: " + str(e))
6667

tests/chinook/pydantic/test_pydantic_complex_queries.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pydantic import BaseModel
33
from pydantic_ai import Agent
44
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings
5+
from pydantic_ai.providers.openai import OpenAIProvider
56
import pytest
67

78
from eval_protocol.models import EvaluateResult, EvaluationRow
@@ -24,10 +25,23 @@
2425

2526
def agent_factory(config: RolloutProcessorConfig) -> Agent:
2627
model_name = config.completion_params["model"]
27-
provider = config.completion_params.get("provider")
28+
provider_param = config.completion_params.get("provider")
2829
reasoning = config.completion_params.get("reasoning")
29-
settings = OpenAIChatModelSettings(openai_reasoning_effort=reasoning)
30-
model = OpenAIChatModel(model_name, provider=provider or "openai", settings=settings)
30+
# gpt-4o-mini does not support reasoning
31+
if model_name == "gpt-4o-mini":
32+
settings = OpenAIChatModelSettings()
33+
else:
34+
settings = OpenAIChatModelSettings(openai_reasoning_effort=reasoning)
35+
base_url = config.completion_params.get("base_url")
36+
api_key = config.completion_params.get("api_key") or os.getenv("OPENAI_API_KEY") or "dummy"
37+
if base_url or provider_param == "ollama":
38+
provider = OpenAIProvider(
39+
api_key=api_key,
40+
base_url=base_url or os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1"),
41+
)
42+
else:
43+
provider = provider_param or "openai"
44+
model = OpenAIChatModel(model_name, provider=provider, settings=settings)
3145
return setup_agent(model)
3246

3347

@@ -51,7 +65,11 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent:
5165
# "model": "accounts/fireworks/models/kimi-k2-instruct-0905",
5266
# "provider": "fireworks",
5367
# },
54-
{"model": "gpt-5"},
68+
# {"model": "gpt-4o-mini"},
69+
{"model": "gpt-5-nano-2025-08-07"},
70+
# {"model": "qwen3:4b", "provider": "ollama", "base_url": os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1")},
71+
# {"model": "qwen3:8b", "provider": "ollama", "base_url": os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1")},
72+
# {"model": "granite4:micro", "provider": "ollama", "base_url": os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1")},
5573
# {"model": "gpt-5", "reasoning": "high"},
5674
],
5775
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),

0 commit comments

Comments (0)