55from eval_protocol .pytest import evaluation_test
66
77from eval_protocol .pytest .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
8+ from eval_protocol .pytest .types import RolloutProcessorConfig
89from tests .chinook .pydantic .agent import setup_agent
910
1011from tests .chinook .dataset import collect_dataset
1112
1213try :
1314 from langfuse import get_client , observe # pyright: ignore[reportPrivateImportUsage]
1415 from pydantic_ai .agent import Agent
16+ from pydantic_ai .models .openai import OpenAIModel
1517
1618 LANGFUSE_AVAILABLE = True
1719 langfuse_client = get_client ()
@@ -37,6 +39,13 @@ def decorator(func):
3739)
3840
3941
42+ def agent_factory (config : RolloutProcessorConfig ) -> Agent :
43+ model_name = config .completion_params ["model" ]
44+ provider = config .completion_params ["provider" ]
45+ model = OpenAIModel (model_name , provider = provider )
46+ return setup_agent (model )
47+
48+
4049@pytest .mark .skipif (
4150 os .environ .get ("CI" ) == "true" ,
4251 reason = "Only run this test locally (skipped in CI)" ,
@@ -47,24 +56,20 @@ def decorator(func):
4756 input_rows = [collect_dataset ()[0 :1 ]],
4857 completion_params = [
4958 {
50- "model" : {
51- "orchestrator_agent_model" : {
52- "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
53- "provider" : "fireworks" ,
54- }
55- }
59+ "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
60+ "provider" : "fireworks" ,
5661 },
5762 ],
58- rollout_processor = PydanticAgentRolloutProcessor (),
59- rollout_processor_kwargs = {"agent" : setup_agent },
63+ rollout_processor = PydanticAgentRolloutProcessor (agent_factory ),
6064 mode = "pointwise" ,
6165)
6266async def test_complex_query_0 (row : EvaluationRow ) -> EvaluationRow :
6367 """
6468 Complex queries - PydanticAI automatically creates rich Langfuse traces.
6569 """
70+ # Have to postprocess tools because row.tools isn't set until during rollout
6671 if langfuse_client :
67- langfuse_client .update_current_trace (tags = ["chinook_sql" ])
72+ langfuse_client .update_current_trace (tags = ["chinook_sql" ], metadata = { "tools" : row . tools } )
6873
6974 return row
7075
@@ -79,24 +84,19 @@ async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow:
7984 input_rows = [collect_dataset ()[1 :2 ]],
8085 completion_params = [
8186 {
82- "model" : {
83- "orchestrator_agent_model" : {
84- "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
85- "provider" : "fireworks" ,
86- }
87- }
87+ "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
88+ "provider" : "fireworks" ,
8889 },
8990 ],
90- rollout_processor = PydanticAgentRolloutProcessor (),
91- rollout_processor_kwargs = {"agent" : setup_agent },
91+ rollout_processor = PydanticAgentRolloutProcessor (agent_factory ),
9292 mode = "pointwise" ,
9393)
9494async def test_complex_query_1 (row : EvaluationRow ) -> EvaluationRow :
9595 """
9696 Complex queries - PydanticAI automatically creates rich Langfuse traces.
9797 """
9898 if langfuse_client :
99- langfuse_client .update_current_trace (tags = ["chinook_sql" ])
99+ langfuse_client .update_current_trace (tags = ["chinook_sql" ], metadata = { "tools" : row . tools } )
100100
101101 return row
102102
@@ -111,24 +111,19 @@ async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow:
111111 input_rows = [collect_dataset ()[2 :3 ]],
112112 completion_params = [
113113 {
114- "model" : {
115- "orchestrator_agent_model" : {
116- "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
117- "provider" : "fireworks" ,
118- }
119- }
114+ "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
115+ "provider" : "fireworks" ,
120116 },
121117 ],
122- rollout_processor = PydanticAgentRolloutProcessor (),
123- rollout_processor_kwargs = {"agent" : setup_agent },
118+ rollout_processor = PydanticAgentRolloutProcessor (agent_factory ),
124119 mode = "pointwise" ,
125120)
126121async def test_complex_query_2 (row : EvaluationRow ) -> EvaluationRow :
127122 """
128123 Complex queries - PydanticAI automatically creates rich Langfuse traces.
129124 """
130125 if langfuse_client :
131- langfuse_client .update_current_trace (tags = ["chinook_sql" ])
126+ langfuse_client .update_current_trace (tags = ["chinook_sql" ], metadata = { "tools" : row . tools } )
132127
133128 return row
134129
@@ -143,24 +138,19 @@ async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow:
143138 input_rows = [collect_dataset ()[3 :4 ]],
144139 completion_params = [
145140 {
146- "model" : {
147- "orchestrator_agent_model" : {
148- "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
149- "provider" : "fireworks" ,
150- }
151- }
141+ "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
142+ "provider" : "fireworks" ,
152143 },
153144 ],
154- rollout_processor = PydanticAgentRolloutProcessor (),
155- rollout_processor_kwargs = {"agent" : setup_agent },
145+ rollout_processor = PydanticAgentRolloutProcessor (agent_factory ),
156146 mode = "pointwise" ,
157147)
158148async def test_complex_query_3 (row : EvaluationRow ) -> EvaluationRow :
159149 """
160150 Complex queries - PydanticAI automatically creates rich Langfuse traces.
161151 """
162152 if langfuse_client :
163- langfuse_client .update_current_trace (tags = ["chinook_sql" ])
153+ langfuse_client .update_current_trace (tags = ["chinook_sql" ], metadata = { "tools" : row . tools } )
164154
165155 return row
166156
@@ -175,24 +165,19 @@ async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow:
175165 input_rows = [collect_dataset ()[4 :5 ]],
176166 completion_params = [
177167 {
178- "model" : {
179- "orchestrator_agent_model" : {
180- "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
181- "provider" : "fireworks" ,
182- }
183- }
168+ "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
169+ "provider" : "fireworks" ,
184170 },
185171 ],
186- rollout_processor = PydanticAgentRolloutProcessor (),
187- rollout_processor_kwargs = {"agent" : setup_agent },
172+ rollout_processor = PydanticAgentRolloutProcessor (agent_factory ),
188173 mode = "pointwise" ,
189174)
190175async def test_complex_query_4 (row : EvaluationRow ) -> EvaluationRow :
191176 """
192177 Complex queries - PydanticAI automatically creates rich Langfuse traces.
193178 """
194179 if langfuse_client :
195- langfuse_client .update_current_trace (tags = ["chinook_sql" ])
180+ langfuse_client .update_current_trace (tags = ["chinook_sql" ], metadata = { "tools" : row . tools } )
196181
197182 return row
198183
@@ -207,23 +192,18 @@ async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow:
207192 input_rows = [collect_dataset ()[5 :6 ]],
208193 completion_params = [
209194 {
210- "model" : {
211- "orchestrator_agent_model" : {
212- "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
213- "provider" : "fireworks" ,
214- }
215- }
195+ "model" : "accounts/fireworks/models/kimi-k2-instruct" ,
196+ "provider" : "fireworks" ,
216197 },
217198 ],
218- rollout_processor = PydanticAgentRolloutProcessor (),
219- rollout_processor_kwargs = {"agent" : setup_agent },
199+ rollout_processor = PydanticAgentRolloutProcessor (agent_factory ),
220200 mode = "pointwise" ,
221201)
222202async def test_complex_query_5 (row : EvaluationRow ) -> EvaluationRow :
223203 """
224204 Complex queries - PydanticAI automatically creates rich Langfuse traces.
225205 """
226206 if langfuse_client :
227- langfuse_client .update_current_trace (tags = ["chinook_sql" ])
207+ langfuse_client .update_current_trace (tags = ["chinook_sql" ], metadata = { "tools" : row . tools } )
228208
229209 return row
0 commit comments