diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 9a15da88..4f11cd3f 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -147,6 +147,7 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool Returns: EvaluationRow or None if conversion fails """ + # TODO: move this logic into an adapter in llm_judge.py. langfuse.py should just return traces try: # Get observations (generations, spans) from the trace observations_response = self.client.api.observations.get_many(trace_id=trace.id, limit=100) @@ -183,7 +184,7 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool ground_truth = self._extract_ground_truth(trace) # Extract tools if available - tools = self._extract_tools(observations) if include_tool_calls else None + tools = self._extract_tools(observations, trace) if include_tool_calls else None return EvaluationRow( messages=messages, @@ -471,17 +472,25 @@ def _extract_ground_truth(self, trace: Any) -> Optional[str]: return None - def _extract_tools(self, observations: List[Any]) -> Optional[List[Dict[str, Any]]]: - """Extract tool definitions from observations. + def _extract_tools(self, observations: List[Any], trace: Any = None) -> Optional[List[Dict[str, Any]]]: + """Extract tool definitions from trace metadata or observations. Args: observations: List of observation objects + trace: Trace object that may contain metadata with tools Returns: List of tool definitions or None """ + # First, try to extract tools from trace metadata (preferred) + if trace and hasattr(trace, "metadata") and trace.metadata: + if isinstance(trace.metadata, dict) and "tools" in trace.metadata: + tools_from_metadata = trace.metadata["tools"] + if tools_from_metadata: + return tools_from_metadata + + # Fallback: extract from observations tools = [] - for obs in observations: if hasattr(obs, "input") and obs.input and isinstance(obs.input, dict): if "tools" in obs.input: diff --git a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py index 4c0edfc3..169c8b68 100644 --- a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py +++ b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py @@ -52,6 +52,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row with agent rollout.""" start_time = time.perf_counter() + tools = [] + for _, tool in agent._function_tools.items(): + tool_dict = { + "type": "function", + "function": { + "name": tool.name, + "parameters": tool.function_schema.json_schema, + }, + } + if tool.description: + tool_dict["function"]["description"] = tool.description + + tools.append(tool_dict) + row.tools = tools + model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages] response = await agent.run(message_history=model_messages, usage_limits=config.kwargs.get("usage_limits")) row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages()) diff --git a/tests/chinook/langfuse/generate_traces.py b/tests/chinook/langfuse/generate_traces.py index e2d7d011..ce184c9f 100644 --- a/tests/chinook/langfuse/generate_traces.py +++ b/tests/chinook/langfuse/generate_traces.py @@ -5,6 +5,7 @@ from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig from tests.chinook.pydantic.agent import setup_agent from tests.chinook.dataset import collect_dataset @@ -12,6 +13,7 @@ try: from langfuse import get_client, observe # pyright: ignore[reportPrivateImportUsage] from pydantic_ai.agent import Agent + from pydantic_ai.models.openai import OpenAIModel LANGFUSE_AVAILABLE = True langfuse_client = get_client() @@ -37,6 +39,13 @@ def decorator(func): ) +def agent_factory(config: RolloutProcessorConfig) -> Agent: + model_name = config.completion_params["model"] + provider = config.completion_params["provider"] + model = OpenAIModel(model_name, provider=provider) + return setup_agent(model) + + @pytest.mark.skipif( os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)", @@ -47,24 +56,20 @@ def decorator(func): input_rows=[collect_dataset()[0:1]], completion_params=[ { - "model": { - "orchestrator_agent_model": { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - } - } + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", }, ], - rollout_processor=PydanticAgentRolloutProcessor(), - rollout_processor_kwargs={"agent": setup_agent}, + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), mode="pointwise", ) async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow: """ Complex queries - PydanticAI automatically creates rich Langfuse traces. """ + # Have to postprocess tools because row.tools isn't set until during rollout if langfuse_client: - langfuse_client.update_current_trace(tags=["chinook_sql"]) + langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools}) return row @@ -79,16 +84,11 @@ async def test_complex_query_0(row: EvaluationRow) -> EvaluationRow: input_rows=[collect_dataset()[1:2]], completion_params=[ { - "model": { - "orchestrator_agent_model": { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - } - } + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", }, ], - rollout_processor=PydanticAgentRolloutProcessor(), - rollout_processor_kwargs={"agent": setup_agent}, + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), mode="pointwise", ) async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow: @@ -96,7 +96,7 @@ async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow: Complex queries - PydanticAI automatically creates rich Langfuse traces. """ if langfuse_client: - langfuse_client.update_current_trace(tags=["chinook_sql"]) + langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools}) return row @@ -111,16 +111,11 @@ async def test_complex_query_1(row: EvaluationRow) -> EvaluationRow: input_rows=[collect_dataset()[2:3]], completion_params=[ { - "model": { - "orchestrator_agent_model": { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - } - } + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", }, ], - rollout_processor=PydanticAgentRolloutProcessor(), - rollout_processor_kwargs={"agent": setup_agent}, + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), mode="pointwise", ) async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow: @@ -128,7 +123,7 @@ async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow: Complex queries - PydanticAI automatically creates rich Langfuse traces. """ if langfuse_client: - langfuse_client.update_current_trace(tags=["chinook_sql"]) + langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools}) return row @@ -143,16 +138,11 @@ async def test_complex_query_2(row: EvaluationRow) -> EvaluationRow: input_rows=[collect_dataset()[3:4]], completion_params=[ { - "model": { - "orchestrator_agent_model": { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - } - } + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", }, ], - rollout_processor=PydanticAgentRolloutProcessor(), - rollout_processor_kwargs={"agent": setup_agent}, + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), mode="pointwise", ) async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow: @@ -160,7 +150,7 @@ async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow: Complex queries - PydanticAI automatically creates rich Langfuse traces. """ if langfuse_client: - langfuse_client.update_current_trace(tags=["chinook_sql"]) + langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools}) return row @@ -175,16 +165,11 @@ async def test_complex_query_3(row: EvaluationRow) -> EvaluationRow: input_rows=[collect_dataset()[4:5]], completion_params=[ { - "model": { - "orchestrator_agent_model": { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - } - } + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", }, ], - rollout_processor=PydanticAgentRolloutProcessor(), - rollout_processor_kwargs={"agent": setup_agent}, + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), mode="pointwise", ) async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow: @@ -192,7 +177,7 @@ async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow: Complex queries - PydanticAI automatically creates rich Langfuse traces. """ if langfuse_client: - langfuse_client.update_current_trace(tags=["chinook_sql"]) + langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools}) return row @@ -207,16 +192,11 @@ async def test_complex_query_4(row: EvaluationRow) -> EvaluationRow: input_rows=[collect_dataset()[5:6]], completion_params=[ { - "model": { - "orchestrator_agent_model": { - "model": "accounts/fireworks/models/kimi-k2-instruct", - "provider": "fireworks", - } - } + "model": "accounts/fireworks/models/kimi-k2-instruct", + "provider": "fireworks", }, ], - rollout_processor=PydanticAgentRolloutProcessor(), - rollout_processor_kwargs={"agent": setup_agent}, + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), mode="pointwise", ) async def test_complex_query_5(row: EvaluationRow) -> EvaluationRow: @@ -224,6 +204,6 @@ async def test_complex_query_5(row: EvaluationRow) -> EvaluationRow: Complex queries - PydanticAI automatically creates rich Langfuse traces. """ if langfuse_client: - langfuse_client.update_current_trace(tags=["chinook_sql"]) + langfuse_client.update_current_trace(tags=["chinook_sql"], metadata={"tools": row.tools}) return row