Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions eval_protocol/adapters/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool
Returns:
EvaluationRow or None if conversion fails
"""
# TODO: move this logic into an adapter in llm_judge.py. langfuse.py should just return traces
try:
# Get observations (generations, spans) from the trace
observations_response = self.client.api.observations.get_many(trace_id=trace.id, limit=100)
Expand Down Expand Up @@ -183,7 +184,7 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool
ground_truth = self._extract_ground_truth(trace)

# Extract tools if available
tools = self._extract_tools(observations) if include_tool_calls else None
tools = self._extract_tools(observations, trace) if include_tool_calls else None

return EvaluationRow(
messages=messages,
Expand Down Expand Up @@ -471,17 +472,25 @@ def _extract_ground_truth(self, trace: Any) -> Optional[str]:

return None

def _extract_tools(self, observations: List[Any]) -> Optional[List[Dict[str, Any]]]:
"""Extract tool definitions from observations.
def _extract_tools(self, observations: List[Any], trace: Any = None) -> Optional[List[Dict[str, Any]]]:
"""Extract tool definitions from trace metadata or observations.

Args:
observations: List of observation objects
trace: Trace object that may contain metadata with tools

Returns:
List of tool definitions or None
"""
# First, try to extract tools from trace metadata (preferred)
if trace and hasattr(trace, "metadata") and trace.metadata:
if isinstance(trace.metadata, dict) and "tools" in trace.metadata:
tools_from_metadata = trace.metadata["tools"]
if tools_from_metadata:
return tools_from_metadata

# Fallback: extract from observations
tools = []

for obs in observations:
if hasattr(obs, "input") and obs.input and isinstance(obs.input, dict):
if "tools" in obs.input:
Expand Down
15 changes: 15 additions & 0 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with agent rollout."""
start_time = time.perf_counter()

tools = []
for _, tool in agent._function_tools.items():
tool_dict = {
"type": "function",
"function": {
"name": tool.name,
"parameters": tool.function_schema.json_schema,
},
}
if tool.description:
tool_dict["function"]["description"] = tool.description

tools.append(tool_dict)
row.tools = tools

model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
response = await agent.run(message_history=model_messages, usage_limits=config.kwargs.get("usage_limits"))
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
Expand Down
88 changes: 34 additions & 54 deletions tests/chinook/langfuse/generate_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from eval_protocol.pytest import evaluation_test

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from tests.chinook.pydantic.agent import setup_agent

from tests.chinook.dataset import collect_dataset

try:
from langfuse import get_client, observe # pyright: ignore[reportPrivateImportUsage]
from pydantic_ai.agent import Agent
from pydantic_ai.models.openai import OpenAIModel

LANGFUSE_AVAILABLE = True
langfuse_client = get_client()
Expand All @@ -37,6 +39,13 @@ def decorator(func):
)


def agent_factory(config: RolloutProcessorConfig) -> Agent:
    """Construct a pydantic-ai Agent from rollout-processor configuration.

    Pulls the model name and provider out of ``config.completion_params``
    (both keys are required; a missing key raises ``KeyError``), wraps them
    in an ``OpenAIModel``, and hands that model to ``setup_agent``.

    Args:
        config: Rollout processor config whose ``completion_params`` dict
            carries ``"model"`` and ``"provider"`` entries.

    Returns:
        The Agent produced by ``setup_agent`` for the configured model.
    """
    params = config.completion_params
    return setup_agent(OpenAIModel(params["model"], provider=params["provider"]))


@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Only run this test locally (skipped in CI)",
Expand All @@ -47,24 +56,20 @@ def decorator(func):
input_rows=[collect_dataset()[0:1]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow:
"""
Complex queries - PydanticAI automatically creates rich Langfuse traces.
"""
# Have to postprocess tools because row.tools isn't set until during rollout
if langfuse_client:
langfuse_client.update_current_trace(tags=["chinook_sql"])
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})

return row

Expand All @@ -79,24 +84,19 @@ async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow:
input_rows=[collect_dataset()[1:2]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow:
"""
Complex queries - PydanticAI automatically creates rich Langfuse traces.
"""
if langfuse_client:
langfuse_client.update_current_trace(tags=["chinook_sql"])
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})

return row

Expand All @@ -111,24 +111,19 @@ async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow:
input_rows=[collect_dataset()[2:3]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow:
"""
Complex queries - PydanticAI automatically creates rich Langfuse traces.
"""
if langfuse_client:
langfuse_client.update_current_trace(tags=["chinook_sql"])
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})

return row

Expand All @@ -143,24 +138,19 @@ async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow:
input_rows=[collect_dataset()[3:4]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow:
"""
Complex queries - PydanticAI automatically creates rich Langfuse traces.
"""
if langfuse_client:
langfuse_client.update_current_trace(tags=["chinook_sql"])
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})

return row

Expand All @@ -175,24 +165,19 @@ async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow:
input_rows=[collect_dataset()[4:5]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow:
"""
Complex queries - PydanticAI automatically creates rich Langfuse traces.
"""
if langfuse_client:
langfuse_client.update_current_trace(tags=["chinook_sql"])
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})

return row

Expand All @@ -207,23 +192,18 @@ async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow:
input_rows=[collect_dataset()[5:6]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_query_5(row: EvaluationRow) -> EvaluationRow:
"""
Complex queries - PydanticAI automatically creates rich Langfuse traces.
"""
if langfuse_client:
langfuse_client.update_current_trace(tags=["chinook_sql"])
langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools})

return row
Loading