Skip to content

Commit 3410973

Browse files
authored
Add tools to pydantic trace and make sure langfuse working (#168)
1 parent fb5c9e3 commit 3410973

File tree

3 files changed: 62 additions (+62) and 58 deletions (-58)

eval_protocol/adapters/langfuse.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,7 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool
147147
Returns:
148148
EvaluationRow or None if conversion fails
149149
"""
150+
# TODO: move this logic into an adapter in llm_judge.py. langfuse.py should just return traces
150151
try:
151152
# Get observations (generations, spans) from the trace
152153
observations_response = self.client.api.observations.get_many(trace_id=trace.id, limit=100)
@@ -183,7 +184,7 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool
183184
ground_truth = self._extract_ground_truth(trace)
184185

185186
# Extract tools if available
186-
tools = self._extract_tools(observations) if include_tool_calls else None
187+
tools = self._extract_tools(observations, trace) if include_tool_calls else None
187188

188189
return EvaluationRow(
189190
messages=messages,
@@ -471,17 +472,25 @@ def _extract_ground_truth(self, trace: Any) -> Optional[str]:
471472

472473
return None
473474

474-
def _extract_tools(self, observations: List[Any]) -> Optional[List[Dict[str, Any]]]:
475-
"""Extract tool definitions from observations.
475+
def _extract_tools(self, observations: List[Any], trace: Any = None) -> Optional[List[Dict[str, Any]]]:
476+
"""Extract tool definitions from trace metadata or observations.
476477
477478
Args:
478479
observations: List of observation objects
480+
trace: Trace object that may contain metadata with tools
479481
480482
Returns:
481483
List of tool definitions or None
482484
"""
485+
# First, try to extract tools from trace metadata (preferred)
486+
if trace and hasattr(trace, "metadata") and trace.metadata:
487+
if isinstance(trace.metadata, dict) and "tools" in trace.metadata:
488+
tools_from_metadata = trace.metadata["tools"]
489+
if tools_from_metadata:
490+
return tools_from_metadata
491+
492+
# Fallback: extract from observations
483493
tools = []
484-
485494
for obs in observations:
486495
if hasattr(obs, "input") and obs.input and isinstance(obs.input, dict):
487496
if "tools" in obs.input:

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
5252
"""Process a single row with agent rollout."""
5353
start_time = time.perf_counter()
5454

55+
tools = []
56+
for _, tool in agent._function_tools.items():
57+
tool_dict = {
58+
"type": "function",
59+
"function": {
60+
"name": tool.name,
61+
"parameters": tool.function_schema.json_schema,
62+
},
63+
}
64+
if tool.description:
65+
tool_dict["function"]["description"] = tool.description
66+
67+
tools.append(tool_dict)
68+
row.tools = tools
69+
5570
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
5671
response = await agent.run(message_history=model_messages, usage_limits=config.kwargs.get("usage_limits"))
5772
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())

tests/chinook/langfuse/generate_traces.py

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from eval_protocol.pytest import evaluation_test
66

77
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
8+
from eval_protocol.pytest.types import RolloutProcessorConfig
89
from tests.chinook.pydantic.agent import setup_agent
910

1011
from tests.chinook.dataset import collect_dataset
1112

1213
try:
1314
from langfuse import get_client, observe # pyright: ignore[reportPrivateImportUsage]
1415
from pydantic_ai.agent import Agent
16+
from pydantic_ai.models.openai import OpenAIModel
1517

1618
LANGFUSE_AVAILABLE = True
1719
langfuse_client = get_client()
@@ -37,6 +39,13 @@ def decorator(func):
3739
)
3840

3941

42+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
    """Build the Chinook agent for a single rollout configuration.

    Args:
        config: Rollout configuration whose ``completion_params`` mapping
            supplies the ``"model"`` and ``"provider"`` entries. Both keys
            are looked up with plain indexing, so a missing key raises
            ``KeyError``.

    Returns:
        The value returned by ``setup_agent`` when called with an
        ``OpenAIModel`` built from the configured model name and provider.
    """
    params = config.completion_params
    # Look up "model" before "provider", matching the original access order.
    llm = OpenAIModel(params["model"], provider=params["provider"])
    return setup_agent(llm)
47+
48+
4049
@pytest.mark.skipif(
4150
os.environ.get("CI") == "true",
4251
reason="Only run this test locally (skipped in CI)",
@@ -47,24 +56,20 @@ def decorator(func):
4756
input_rows=[collect_dataset()[0:1]],
4857
completion_params=[
4958
{
50-
"model": {
51-
"orchestrator_agent_model": {
52-
"model": "accounts/fireworks/models/kimi-k2-instruct",
53-
"provider": "fireworks",
54-
}
55-
}
59+
"model": "accounts/fireworks/models/kimi-k2-instruct",
60+
"provider": "fireworks",
5661
},
5762
],
58-
rollout_processor=PydanticAgentRolloutProcessor(),
59-
rollout_processor_kwargs={"agent": setup_agent},
63+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
6064
mode="pointwise",
6165
)
6266
async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow:
6367
"""
6468
Complex queries - PydanticAI automatically creates rich Langfuse traces.
6569
"""
70+
# Have to postprocess tools because row.tools isn't set until during rollout
6671
if langfuse_client:
67-
langfuse_client.update_current_trace(tags=["chinook_sql"])
72+
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})
6873

6974
return row
7075

@@ -79,24 +84,19 @@ async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow:
7984
input_rows=[collect_dataset()[1:2]],
8085
completion_params=[
8186
{
82-
"model": {
83-
"orchestrator_agent_model": {
84-
"model": "accounts/fireworks/models/kimi-k2-instruct",
85-
"provider": "fireworks",
86-
}
87-
}
87+
"model": "accounts/fireworks/models/kimi-k2-instruct",
88+
"provider": "fireworks",
8889
},
8990
],
90-
rollout_processor=PydanticAgentRolloutProcessor(),
91-
rollout_processor_kwargs={"agent": setup_agent},
91+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
9292
mode="pointwise",
9393
)
9494
async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow:
9595
"""
9696
Complex queries - PydanticAI automatically creates rich Langfuse traces.
9797
"""
9898
if langfuse_client:
99-
langfuse_client.update_current_trace(tags=["chinook_sql"])
99+
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})
100100

101101
return row
102102

@@ -111,24 +111,19 @@ async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow:
111111
input_rows=[collect_dataset()[2:3]],
112112
completion_params=[
113113
{
114-
"model": {
115-
"orchestrator_agent_model": {
116-
"model": "accounts/fireworks/models/kimi-k2-instruct",
117-
"provider": "fireworks",
118-
}
119-
}
114+
"model": "accounts/fireworks/models/kimi-k2-instruct",
115+
"provider": "fireworks",
120116
},
121117
],
122-
rollout_processor=PydanticAgentRolloutProcessor(),
123-
rollout_processor_kwargs={"agent": setup_agent},
118+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
124119
mode="pointwise",
125120
)
126121
async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow:
127122
"""
128123
Complex queries - PydanticAI automatically creates rich Langfuse traces.
129124
"""
130125
if langfuse_client:
131-
langfuse_client.update_current_trace(tags=["chinook_sql"])
126+
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})
132127

133128
return row
134129

@@ -143,24 +138,19 @@ async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow:
143138
input_rows=[collect_dataset()[3:4]],
144139
completion_params=[
145140
{
146-
"model": {
147-
"orchestrator_agent_model": {
148-
"model": "accounts/fireworks/models/kimi-k2-instruct",
149-
"provider": "fireworks",
150-
}
151-
}
141+
"model": "accounts/fireworks/models/kimi-k2-instruct",
142+
"provider": "fireworks",
152143
},
153144
],
154-
rollout_processor=PydanticAgentRolloutProcessor(),
155-
rollout_processor_kwargs={"agent": setup_agent},
145+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
156146
mode="pointwise",
157147
)
158148
async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow:
159149
"""
160150
Complex queries - PydanticAI automatically creates rich Langfuse traces.
161151
"""
162152
if langfuse_client:
163-
langfuse_client.update_current_trace(tags=["chinook_sql"])
153+
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})
164154

165155
return row
166156

@@ -175,24 +165,19 @@ async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow:
175165
input_rows=[collect_dataset()[4:5]],
176166
completion_params=[
177167
{
178-
"model": {
179-
"orchestrator_agent_model": {
180-
"model": "accounts/fireworks/models/kimi-k2-instruct",
181-
"provider": "fireworks",
182-
}
183-
}
168+
"model": "accounts/fireworks/models/kimi-k2-instruct",
169+
"provider": "fireworks",
184170
},
185171
],
186-
rollout_processor=PydanticAgentRolloutProcessor(),
187-
rollout_processor_kwargs={"agent": setup_agent},
172+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
188173
mode="pointwise",
189174
)
190175
async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow:
191176
"""
192177
Complex queries - PydanticAI automatically creates rich Langfuse traces.
193178
"""
194179
if langfuse_client:
195-
langfuse_client.update_current_trace(tags=["chinook_sql"])
180+
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})
196181

197182
return row
198183

@@ -207,23 +192,18 @@ async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow:
207192
input_rows=[collect_dataset()[5:6]],
208193
completion_params=[
209194
{
210-
"model": {
211-
"orchestrator_agent_model": {
212-
"model": "accounts/fireworks/models/kimi-k2-instruct",
213-
"provider": "fireworks",
214-
}
215-
}
195+
"model": "accounts/fireworks/models/kimi-k2-instruct",
196+
"provider": "fireworks",
216197
},
217198
],
218-
rollout_processor=PydanticAgentRolloutProcessor(),
219-
rollout_processor_kwargs={"agent": setup_agent},
199+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
220200
mode="pointwise",
221201
)
222202
async def test_complex_query_5(row: EvaluationRow) -> EvaluationRow:
223203
"""
224204
Complex queries - PydanticAI automatically creates rich Langfuse traces.
225205
"""
226206
if langfuse_client:
227-
langfuse_client.update_current_trace(tags=["chinook_sql"])
207+
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})
228208

229209
return row

0 commit comments

Comments
 (0)