"""Braintrust adapter for Eval Protocol.

This adapter allows pulling data from Braintrust deployments and converting it
to EvaluationRow format for use in evaluation pipelines.
"""

import logging
import os
from typing import Any, Dict, List, Optional, Protocol

import requests

from eval_protocol.models import EvaluationRow, InputMetadata, Message
from .utils import extract_messages_from_data

# Keep backward compatibility: these names used to live in this module.
from ..integrations.braintrust import reward_fn_to_scorer, scorer_to_reward_fn

logger = logging.getLogger(__name__)


class TraceConverter(Protocol):
    """Protocol for custom trace-to-EvaluationRow converter functions.

    A converter function should take a Braintrust trace along with processing
    options and return an EvaluationRow or None to skip the trace.
    """

    def __call__(
        self,
        trace: Dict[str, Any],
        include_tool_calls: bool,
    ) -> Optional[EvaluationRow]:
        """Convert a Braintrust trace to an EvaluationRow.

        Args:
            trace: The Braintrust trace object to convert
            include_tool_calls: Whether to include tool calling information

        Returns:
            EvaluationRow or None if the trace should be skipped
        """
        ...


def convert_trace_to_evaluation_row(trace: Dict[str, Any], include_tool_calls: bool = True) -> Optional[EvaluationRow]:
    """Convert a Braintrust trace to EvaluationRow format.

    Args:
        trace: Braintrust trace object
        include_tool_calls: Whether to include tool calling information

    Returns:
        EvaluationRow or None if conversion fails or the trace has no messages
    """
    try:
        messages = extract_messages_from_trace(trace, include_tool_calls)

        # Tool definitions may live directly in metadata or nested under
        # hidden_params.optional_params — check both layouts.
        tools = None
        if include_tool_calls:
            metadata = trace.get("metadata", {})
            tools = metadata.get("tools")
            if not tools:
                hidden_params = metadata.get("hidden_params", {})
                optional_params = hidden_params.get("optional_params", {})
                tools = optional_params.get("tools")

        if not messages:
            return None

        return EvaluationRow(
            messages=messages,
            tools=tools,
            input_metadata=InputMetadata(
                session_data={
                    "braintrust_trace_id": trace.get("id"),
                }
            ),
        )

    except (AttributeError, ValueError, KeyError) as e:
        logger.error("Error converting trace %s: %s", trace.get("id", "unknown"), e)
        return None


def extract_messages_from_trace(trace: Dict[str, Any], include_tool_calls: bool = True) -> List[Message]:
    """Extract messages from Braintrust trace input and output.

    Args:
        trace: Braintrust trace object
        include_tool_calls: Whether to include tool calling information

    Returns:
        List of Message objects (empty when the trace lacks a usable
        input/output conversation pair)
    """
    messages: List[Message] = []

    try:
        input_data = trace.get("input")

        # The assistant reply is expected under output[0]["message"]
        # — TODO confirm against the Braintrust trace schema.
        output_data = None
        output_list = trace.get("output", [])
        if output_list:
            first_output = output_list[0]
            if isinstance(first_output, dict):
                output_data = first_output.get("message")

        # Skip spans without meaningful conversation data.
        if not input_data or not output_data:
            return messages

        # Both are guaranteed truthy here; no need to re-check before extending.
        messages.extend(extract_messages_from_data(input_data, include_tool_calls))
        messages.extend(extract_messages_from_data(output_data, include_tool_calls))

    except (AttributeError, ValueError, KeyError) as e:
        logger.warning("Error processing trace %s: %s", trace.get("id", "unknown"), e)

    return messages


class BraintrustAdapter:
    """Adapter to pull data from Braintrust and convert to EvaluationRow format.

    This adapter can pull both chat conversations and tool calling traces from
    Braintrust deployments and convert them into the EvaluationRow format expected
    by the evaluation protocol.

    Examples:
        Basic usage:
        >>> adapter = BraintrustAdapter(
        ...     api_key="your_api_key",
        ...     project_id="your_project_id"
        ... )
        >>> btql_query = "select: * from: project_logs('your_project_id') traces limit: 10"
        >>> rows = adapter.get_evaluation_rows(btql_query)

        Using BTQL for custom queries:
        >>> btql_query = '''
        ... select: *
        ... from: project_logs('your_project_id') traces
        ... filter: metadata.agent_name = 'agent_instance'
        ... limit: 50
        ... '''
        >>> rows = adapter.get_evaluation_rows(btql_query)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        project_id: Optional[str] = None,
    ):
        """Initialize the Braintrust adapter.

        Args:
            api_key: Braintrust API key (defaults to BRAINTRUST_API_KEY env var)
            api_url: Braintrust API URL (defaults to BRAINTRUST_API_URL env var)
            project_id: Project ID to fetch logs from (defaults to BRAINTRUST_PROJECT_ID env var)

        Raises:
            ValueError: If no API key or project ID can be resolved.
        """
        self.api_key = api_key or os.getenv("BRAINTRUST_API_KEY")
        self.api_url = api_url or os.getenv("BRAINTRUST_API_URL", "https://api.braintrust.dev")
        self.project_id = project_id or os.getenv("BRAINTRUST_PROJECT_ID")

        if not self.api_key:
            raise ValueError("BRAINTRUST_API_KEY environment variable or api_key parameter required")
        if not self.project_id:
            raise ValueError("BRAINTRUST_PROJECT_ID environment variable or project_id parameter required")

    def get_evaluation_rows(
        self,
        btql_query: str,
        include_tool_calls: bool = True,
        converter: Optional[TraceConverter] = None,
        timeout: float = 30.0,
    ) -> List[EvaluationRow]:
        """Get evaluation rows using a custom BTQL query.

        Args:
            btql_query: The BTQL query string to execute
            include_tool_calls: Whether to include tool calling information
            converter: Optional custom converter implementing TraceConverter protocol
            timeout: Seconds to wait for the BTQL HTTP request (requests would
                otherwise block indefinitely)

        Returns:
            List[EvaluationRow]: Converted evaluation rows

        Raises:
            requests.HTTPError: If the BTQL endpoint returns an error status.
        """
        eval_rows: List[EvaluationRow] = []

        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

        response = requests.post(
            f"{self.api_url}/btql",
            headers=headers,
            json={"query": btql_query, "fmt": "json"},
            timeout=timeout,
        )
        response.raise_for_status()
        query_response = response.json()

        if not query_response or not query_response.get("data"):
            logger.debug("No data returned from BTQL query")
            return eval_rows

        all_traces = query_response["data"]
        logger.debug("BTQL query returned %d traces", len(all_traces))

        # Convert each trace, skipping (not aborting on) individual failures.
        for trace in all_traces:
            try:
                if converter:
                    eval_row = converter(trace, include_tool_calls)
                else:
                    eval_row = convert_trace_to_evaluation_row(trace, include_tool_calls)
                if eval_row:
                    eval_rows.append(eval_row)
            except (AttributeError, ValueError, KeyError) as e:
                logger.warning("Failed to convert trace %s: %s", trace.get("id", "unknown"), e)
                continue

        logger.info("Successfully processed %d BTQL results into %d evaluation rows", len(all_traces), len(eval_rows))
        return eval_rows


def create_braintrust_adapter(
    api_key: Optional[str] = None,
    api_url: Optional[str] = None,
    project_id: Optional[str] = None,
) -> BraintrustAdapter:
    """Factory function to create a Braintrust adapter."""
    return BraintrustAdapter(
        api_key=api_key,
        api_url=api_url,
        project_id=project_id,
    )


__all__ = ["scorer_to_reward_fn", "reward_fn_to_scorer", "BraintrustAdapter", "create_braintrust_adapter"]
Dict, List, Optional, Protocol from eval_protocol.models import EvaluationRow, InputMetadata, Message +from .utils import extract_messages_from_data logger = logging.getLogger(__name__) @@ -112,7 +113,7 @@ def extract_messages_from_trace( if span_name: # Look for a generation tied to a span name try: # Find the final generation in the named span - gen: ObservationsView | None = find_final_generation_in_span(trace, span_name) + gen: ObservationsView | None = get_final_generation_in_span(trace, span_name) if not gen: return messages @@ -141,87 +142,8 @@ def extract_messages_from_trace( return messages -def extract_messages_from_data(data, include_tool_calls: bool) -> List[Message]: - """Extract messages from data (works for both input and output). - - Args: - data: Data from trace or generation (input or output) - include_tool_calls: Whether to include tool calling information - - Returns: - List of Message objects - """ - messages = [] - - if isinstance(data, dict): - if "messages" in data: - # OpenAI-style messages format - for msg in data["messages"]: - messages.append(dict_to_message(msg, include_tool_calls)) - elif "role" in data: - # Single message format - messages.append(dict_to_message(data, include_tool_calls)) - elif "prompt" in data: - # Simple prompt format - messages.append(Message(role="user", content=str(data["prompt"]))) - elif "content" in data: - # Simple content format - messages.append(Message(role="assistant", content=str(data["content"]))) - else: - # Fallback: treat as single message - messages.append(dict_to_message(data, include_tool_calls)) - elif isinstance(data, list): - # Direct list of message dicts - for msg in data: - if isinstance(msg, dict): - messages.append(dict_to_message(msg, include_tool_calls)) - elif isinstance(data, str): - # Simple string - role depends on context, default to user - messages.append(Message(role="user", content=data)) - - return messages - - -def dict_to_message(msg_dict: Dict[str, Any], include_tool_calls: 
bool = True) -> Message: - """Convert a dictionary to a Message object. - - Args: - msg_dict: Dictionary containing message data - include_tool_calls: Whether to include tool calling information - - Returns: - Message object - """ - # Extract basic message components - role = msg_dict.get("role", "assistant") - content = msg_dict.get("content") - name = msg_dict.get("name") - - # Handle tool calls if enabled - tool_calls = None - tool_call_id = None - function_call = None - - if include_tool_calls: - if "tool_calls" in msg_dict: - tool_calls = msg_dict["tool_calls"] - if "tool_call_id" in msg_dict: - tool_call_id = msg_dict["tool_call_id"] - if "function_call" in msg_dict: - function_call = msg_dict["function_call"] - - return Message( - role=role, - content=content, - name=name, - tool_call_id=tool_call_id, - tool_calls=tool_calls, - function_call=function_call, - ) - - -def find_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) -> ObservationsView | None: - """Find the final generation within a named span that contains full message history. +def get_final_generation_in_span(trace: TraceWithFullDetails, span_name: str) -> ObservationsView | None: + """Get the final generation within a named span that contains full message history. Args: trace: Langfuse trace object @@ -511,6 +433,36 @@ def get_evaluation_rows_by_ids( continue return eval_rows + def push_scores(self, rows: List[EvaluationRow], model_name: str, mean_score: float) -> None: + """Push evaluation scores back to Langfuse traces for tracking and analysis. + + Creates a score entry in Langfuse for each unique trace_id found in the evaluation + rows' session data. This allows you to see evaluation results directly in the + Langfuse UI alongside the original traces. 
+ + Args: + rows: List of EvaluationRow objects with session_data containing trace IDs + model_name: Name of the model (used as the score name in Langfuse) + mean_score: The calculated mean score to push to Langfuse + + Note: + Silently handles errors if rows lack session data + """ + try: + for trace_id in set( + row.input_metadata.session_data["langfuse_trace_id"] + for row in rows + if row.evaluation_result and row.input_metadata and row.input_metadata.session_data + ): + if trace_id: + self.client.create_score( + trace_id=trace_id, + name=model_name, + value=mean_score, + ) + except Exception as e: + logger.warning("Failed to push scores to Langfuse: %s", e) + def create_langfuse_adapter() -> LangfuseAdapter: """Factory function to create a Langfuse adapter.""" diff --git a/eval_protocol/adapters/utils.py b/eval_protocol/adapters/utils.py new file mode 100644 index 00000000..0ccf6caf --- /dev/null +++ b/eval_protocol/adapters/utils.py @@ -0,0 +1,98 @@ +"""Common utilities for adapter implementations. + +This module contains shared functions and utilities used across different +adapter implementations to avoid code duplication. +""" + +import logging +import time +from typing import Any, Dict, List + +from eval_protocol.models import Message + +logger = logging.getLogger(__name__) + + +def extract_messages_from_data(data, include_tool_calls: bool) -> List[Message]: + """Extract messages from data (works for both input and output). + + This is a common function used by multiple adapters to parse message data + from various formats (dict, list, string) into standardized Message objects. 
def extract_messages_from_data(data, include_tool_calls: bool) -> List[Message]:
    """Extract messages from data (works for both input and output).

    Shared helper used by multiple adapters to parse message payloads in any of
    the shapes they arrive in (dict, list of dicts, or bare string) into
    standardized Message objects.

    Args:
        data: Data from trace/log (input or output) - can be dict, list, or string
        include_tool_calls: Whether to include tool calling information

    Returns:
        List of Message objects (empty for unrecognized payload types)
    """
    # Bare string: no role information available, default to a user message.
    if isinstance(data, str):
        return [Message(role="user", content=data)]

    # Direct list of message dicts; non-dict entries are ignored.
    if isinstance(data, list):
        return [dict_to_message(item, include_tool_calls) for item in data if isinstance(item, dict)]

    if not isinstance(data, dict):
        return []

    # Dict payloads, checked from most to least structured.
    if "messages" in data:
        # OpenAI-style messages format
        return [dict_to_message(item, include_tool_calls) for item in data["messages"]]
    if "role" in data:
        # Single message format
        return [dict_to_message(data, include_tool_calls)]
    if "prompt" in data:
        # Simple prompt format
        return [Message(role="user", content=str(data["prompt"]))]
    if "content" in data:
        # Simple content format
        return [Message(role="assistant", content=str(data["content"]))]
    # Fallback: treat the whole dict as a single message.
    return [dict_to_message(data, include_tool_calls)]


def dict_to_message(msg_dict: Dict[str, Any], include_tool_calls: bool = True) -> Message:
    """Convert a dictionary to a Message object.

    Shared helper used by multiple adapters to convert dictionary
    representations of messages into standardized Message objects.

    Args:
        msg_dict: Dictionary containing message data
        include_tool_calls: Whether to include tool calling information

    Returns:
        Message object
    """
    # Tool-calling fields are only carried over when requested.
    tool_calls = tool_call_id = function_call = None
    if include_tool_calls:
        tool_calls = msg_dict.get("tool_calls")
        tool_call_id = msg_dict.get("tool_call_id")
        function_call = msg_dict.get("function_call")

    return Message(
        role=msg_dict.get("role", "assistant"),
        content=msg_dict.get("content"),
        name=msg_dict.get("name"),
        tool_call_id=tool_call_id,
        tool_calls=tool_calls,
        function_call=function_call,
    )
+""" + +import os +from datetime import datetime +from typing import List, Dict, Any, Optional +from tqdm import tqdm + +import pytest + +from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor +from eval_protocol.quickstart.utils import ( + split_multi_turn_rows, + JUDGE_CONFIGS, + calculate_bootstrap_scores, + run_judgment_async, +) +import asyncio +from openai import AsyncOpenAI +from eval_protocol.adapters.braintrust import create_braintrust_adapter + +adapter = create_braintrust_adapter() + + +@pytest.mark.asyncio +@evaluation_test( + input_rows=[ + adapter.get_evaluation_rows( + btql_query=f""" +select: * +from: project_logs('{os.getenv("BRAINTRUST_PROJECT_ID")}') traces +filter: is_root = true +limit: 5 +""" + ) + ], + completion_params=[ + {"model": "gpt-4.1"}, + { + "max_tokens": 131000, + "extra_body": {"reasoning_effort": "medium"}, + "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", + }, + { + "max_tokens": 131000, + "extra_body": {"reasoning_effort": "low"}, + "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b", + }, + ], + rollout_processor=SingleTurnRolloutProcessor(), + preprocess_fn=split_multi_turn_rows, + max_concurrent_rollouts=64, + mode="all", +) +async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]: + """ + LLM Judge evaluation using Arena-Hard-Auto style pairwise comparisons. + + Compares model responses against ground truth using an LLM judge. For each row: + 1. Extracts the question from messages[:-1] + 2. Compares messages[-1] (new model response) vs ground_truth (baseline response) + 3. Runs two judgment rounds (A vs B, B vs A) to reduce position bias + 4. Calculates bootstrap scores across all comparisons + 5. 
Updates evaluation_result with final scores and confidence intervals + + Args: + rows: List of EvaluationRow objects with messages, ground_truth, and tools + + Returns: + Same rows with updated evaluation_result containing scores and judgments + """ + + judge_name = "gemini-2.5-pro" # Edit to which judge you'd like to use. Configs are in utils.py. + + if not rows: + print("❌ No evaluation rows provided") + return rows + + print(f"🔄 Processing {len(rows)} evaluation rows for LLM judging...") + + model_name = rows[0].input_metadata.completion_params.get("model", "unknown_model") + + judgments = [] + max_concurrency = JUDGE_CONFIGS[judge_name]["max_concurrency"] + + judge_config = JUDGE_CONFIGS[judge_name] + + async with AsyncOpenAI( + api_key=judge_config.get("api_key"), base_url=judge_config.get("base_url") + ) as shared_client: + semaphore = asyncio.Semaphore(max_concurrency) + + async def run_judgment(row): + async with semaphore: + return await run_judgment_async(row, model_name, judge_name, shared_client) + + tasks = [run_judgment(row) for row in rows] + + for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Generating judgments"): + result = await coro + if result and result["games"][0] and result["games"][1]: + judgments.append(result) + + if not judgments: + print("❌ No valid judgments generated") + return rows + + print(f"✅ Generated {len(judgments)} valid judgments") + + # Calculate bootstrap scores + result = calculate_bootstrap_scores(judgments) + if not result: + print("❌ No valid scores extracted") + return rows + + mean_score, lower_score, upper_score = result + + # Print leaderboard + print("\n##### LLM Judge Results (90th percentile CI) #####") + + clean_model_name = model_name.split("/")[-1] # Clean model name + + print(f"{clean_model_name}: {mean_score:.1%} (CI: {lower_score:.1%} - {upper_score:.1%})") + print("original: 50.0% (CI: 50.0% - 50.0%)") + + for row in rows: + if row.evaluation_result: + row.evaluation_result.score = 
mean_score + + return rows diff --git a/eval_protocol/quickstart/utils.py b/eval_protocol/quickstart/utils.py index d862a472..9fda11b5 100644 --- a/eval_protocol/quickstart/utils.py +++ b/eval_protocol/quickstart/utils.py @@ -280,39 +280,3 @@ def calculate_bootstrap_scores(judgments: List[Dict[str, Any]]) -> Optional[tupl upper_score = bootstraps.quantile(0.95) return mean_score, lower_score, upper_score - - -def push_scores_to_langfuse(rows: List[EvaluationRow], model_name: str, mean_score: float) -> None: - """ - Push evaluation scores back to Langfuse traces for tracking and analysis. - - Creates a score entry in Langfuse for each unique trace_id found in the evaluation - rows' session data. This allows you to see evaluation results directly in the - Langfuse UI alongside the original traces. - - Args: - rows: List of EvaluationRow objects with session_data containing trace IDs - model_name: Name of the model (used as the score name in Langfuse) - mean_score: The calculated mean score to push to Langfuse - - Note: - Silently handles errors if Langfuse is unavailable or if rows lack session data - """ - try: - from eval_protocol.adapters.langfuse import create_langfuse_adapter - - langfuse = create_langfuse_adapter().client - - for trace_id in set( - row.input_metadata.session_data["langfuse_trace_id"] - for row in rows - if row.evaluation_result and row.input_metadata and row.input_metadata.session_data - ): - if trace_id: - langfuse.create_score( - trace_id=trace_id, - name=model_name, - value=mean_score, - ) - except Exception as e: - print(f"⚠️ Failed to push scores to Langfuse: {e}") diff --git a/pyproject.toml b/pyproject.toml index b55ace62..4724328b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,6 +136,9 @@ chinook = [ langchain = [ "langchain-core>=0.3.0", ] +braintrust = [ + "braintrust[otel]", +] # Optional deps for LangGraph example/tests langgraph = [ diff --git a/tests/chinook/braintrust/generate_traces.py 
import os

import pytest

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor

from tests.chinook.dataset import collect_dataset
from tests.chinook.pydantic.agent import setup_agent

dataset = collect_dataset()
# Index into `dataset`, advanced by the span processor as agent runs start.
current_idx = 0


class SpanIDCapturingProcessor:
    """Span processor that stamps each "agent run" span with its ground truth.

    Ground truths are taken from `dataset` in order, relying on agent runs
    starting in the same order as the dataset slices used below.
    """

    def on_start(self, span, parent_context=None):
        """Called when span starts - attach the next ground truth to agent-run spans."""
        global current_idx
        if span.name == "agent run":
            span.set_attribute("ground_truth", dataset[current_idx].ground_truth)
            current_idx += 1

    def on_end(self, span):
        pass

    def shutdown(self):
        pass

    def force_flush(self, timeout_millis=30000):
        pass


try:
    from braintrust import init_logger
    from braintrust.otel import BraintrustSpanProcessor
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from pydantic_ai.agent import Agent

    BRAINTRUST_AVAILABLE = True

    provider = TracerProvider()
    trace.set_tracer_provider(provider)
    provider.add_span_processor(BraintrustSpanProcessor())  # pyright: ignore[reportArgumentType]
    provider.add_span_processor(SpanIDCapturingProcessor())  # pyright: ignore[reportArgumentType]

    logger = init_logger(project="default-otel-project")

    Agent.instrument_all()

except ImportError:
    BRAINTRUST_AVAILABLE = False


_SKIP_IN_CI = pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Only run this test locally (skipped in CI)",
)

_COMPLETION_PARAMS = [
    {
        "model": {
            "orchestrator_agent_model": {
                "model": "accounts/fireworks/models/kimi-k2-instruct",
                "provider": "fireworks",
            }
        }
    },
]


def _make_complex_query_test(idx: int):
    """Build one local-only pointwise eval test over dataset slice [idx:idx+1].

    Replaces six hand-copied, byte-identical test functions that differed only
    in the slice bounds.
    """

    async def _test(row: EvaluationRow) -> EvaluationRow:
        """Complex queries - PydanticAI automatically creates rich Braintrust traces."""
        return row

    # Rename BEFORE decorating so pytest collects each test under its own name.
    # NOTE(review): assumes evaluation_test does not depend on the function
    # name at decoration time beyond functools.wraps-style copying — confirm.
    _test.__name__ = f"test_complex_query_{idx}"
    _test.__qualname__ = _test.__name__

    decorated = evaluation_test(
        input_rows=[collect_dataset()[idx : idx + 1]],
        completion_params=_COMPLETION_PARAMS,
        rollout_processor=PydanticAgentRolloutProcessor(),
        rollout_processor_kwargs={"agent": setup_agent},
        mode="pointwise",
    )(_test)
    return _SKIP_IN_CI(pytest.mark.asyncio(decorated))


# Same six tests the file previously spelled out by hand.
for _idx in range(6):
    globals()[f"test_complex_query_{_idx}"] = _make_complex_query_test(_idx)
del _idx
def fetch_braintrust_traces_as_evaluation_rows(hours_back: int = 24) -> List[EvaluationRow]:
    """
    Dataset adapter: Use BraintrustAdapter to fetch traces from project logs.

    Args:
        hours_back: How far back (in hours) to look for traces.

    Returns:
        List of EvaluationRow objects; empty when Braintrust is unavailable or
        the fetch fails (so the test simply collects zero rows).
    """
    if not BRAINTRUST_AVAILABLE or not create_braintrust_adapter:
        print("⚠️ Braintrust unavailable - no traces to evaluate")
        return []

    try:
        print("🧠 Using BraintrustAdapter to fetch Chinook traces")

        adapter = create_braintrust_adapter(
            project_id="df6863de-6ce2-4fcc-9995-1fa6605f8623"  # Your Braintrust project
        )

        # BraintrustAdapter.get_evaluation_rows takes a BTQL query, not
        # from/to timestamps — the old keyword call raised TypeError. Express
        # the time window as a BTQL filter instead.
        # NOTE(review): confirm the `created` field name / comparison syntax
        # against the BTQL reference for your deployment.
        from_timestamp = datetime.now() - timedelta(hours=hours_back)
        btql_query = f"""
select: *
from: project_logs('{adapter.project_id}') traces
filter: created > '{from_timestamp.isoformat()}'
limit: 100
"""
        evaluation_rows = adapter.get_evaluation_rows(btql_query=btql_query)

        print(f"✅ BraintrustAdapter extracted {len(evaluation_rows)} evaluation rows")
        return evaluation_rows

    except Exception as e:
        print(f"❌ BraintrustAdapter failed: {e}")
        return []


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Only run this test locally (skipped in CI)",
)
@pytest.mark.asyncio
@evaluation_test(
    input_rows=[fetch_braintrust_traces_as_evaluation_rows(hours_back=168)],  # 1 week back
    rollout_processor=NoOpRolloutProcessor(),  # No-op since traces already exist
    mode="pointwise",
)
async def test_braintrust_trace_evaluation(row: EvaluationRow) -> EvaluationRow:
    """
    This test acts as an external evaluation pipeline for Braintrust traces.
    It:
    1. Gets traces from Braintrust (via dataset adapter)
    2. Uses NoOpRolloutProcessor (traces already exist)
    3. Evaluates each trace using same LLM judge as PydanticAI test
    4. Pushes scores back to Braintrust (if API supports it)
    """
    last_assistant_message = row.last_assistant_message()
    if last_assistant_message is None or not last_assistant_message.content:
        # Distinguish "no assistant turn at all" from "assistant turn with
        # empty content" in the recorded reason.
        reason = (
            "No assistant message found"
            if last_assistant_message is None
            else "Assistant message has no content"
        )
        row.evaluation_result = EvaluateResult(score=0.0, reason=reason)
        return row

    model = OpenAIModel(
        "accounts/fireworks/models/kimi-k2-instruct",
        provider="fireworks",
    )

    # Reuse the module-level Response schema instead of redefining an
    # identical shadowing class inside the test body.
    comparison_agent = Agent(
        model=model,
        system_prompt=LLM_JUDGE_PROMPT,
        output_type=Response,
        output_retries=5,
    )

    result = await comparison_agent.run(
        f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}"
    )
    row.evaluation_result = EvaluateResult(
        score=result.output.score,
        reason=result.output.reason,
    )

    return row
"python-slugify" }, + { name = "requests" }, + { name = "sseclient-py" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/ab/199161c7810f9c22fd04dff374075536fc66aabcca5ea522296aedeb6378/braintrust-0.2.7.tar.gz", hash = "sha256:faa9d54c2d6dac30b11d9b4b68817aa1258aeab5945758159107fb6402ac5b80", size = 184823, upload-time = "2025-09-11T23:44:58.661Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/02/f704c8ea68622286dd7aaa16a3a223a9ee2f8b337c86c652d111aa05b442/braintrust-0.2.7-py3-none-any.whl", hash = "sha256:735e1b32a785e144756c4821e0515dd40dca921c86c417000f4b5617024f1349", size = 214417, upload-time = "2025-09-11T23:44:57.028Z" }, +] + +[package.optional-dependencies] +otel = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-sdk" }, +] + [[package]] name = "brotli" version = "1.1.0" @@ -744,6 +771,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" }, ] +[[package]] +name = "chevron" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/1f/ca74b65b19798895d63a6e92874162f44233467c9e7c1ed8afd19016ebe9/chevron-0.14.0.tar.gz", hash = "sha256:87613aafdf6d77b6a90ff073165a61ae5086e21ad49057aa0e53681601800ebf", size = 11440, upload-time = "2021-01-02T22:47:59.233Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/93/342cc62a70ab727e093ed98e02a725d85b746345f05d2b5e5034649f4ec8/chevron-0.14.0-py3-none-any.whl", hash = "sha256:fbf996a709f8da2e745ef763f482ce2d311aa817d287593a5b990d6d6e4f0443", size = 11595, upload-time = "2021-01-02T22:47:57.847Z" }, +] + [[package]] 
name = "click" version = "8.2.1" @@ -1227,6 +1263,9 @@ box2d = [ { name = "pillow" }, { name = "swig" }, ] +braintrust = [ + { name = "braintrust", extra = ["otel"] }, +] chinook = [ { name = "psycopg2-binary" }, ] @@ -1315,6 +1354,7 @@ requires-dist = [ { name = "aiosqlite" }, { name = "anthropic", specifier = ">=0.59.0" }, { name = "backoff", specifier = ">=2.2.0" }, + { name = "braintrust", extras = ["otel"], marker = "extra == 'braintrust'" }, { name = "build", marker = "extra == 'dev'" }, { name = "dataclasses-json", specifier = ">=0.5.7" }, { name = "datasets", specifier = ">=3.0.0" }, @@ -1394,7 +1434,7 @@ requires-dist = [ { name = "websockets", specifier = ">=15.0.1" }, { name = "werkzeug", marker = "extra == 'dev'", specifier = ">=2.0.0" }, ] -provides-extras = ["dev", "trl", "openevals", "fireworks", "box2d", "langfuse", "huggingface", "adapters", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "langgraph", "langgraph-tools"] +provides-extras = ["dev", "trl", "openevals", "fireworks", "box2d", "langfuse", "huggingface", "adapters", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "braintrust", "langgraph", "langgraph-tools"] [package.metadata.requires-dev] dev = [ @@ -1750,6 +1790,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/12/41fcfba4ae0f6b4805f09d11f0e6d6417df2572cea13208c0f439170ee0c/genai_prices-0.0.25-py3-none-any.whl", hash = "sha256:47b412e6927787caa00717a5d99b2e4c0858bed507bb16473b1bcaff48d5aae9", size = 47002, upload-time = "2025-09-01T17:30:41.012Z" }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = 
"2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.45" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/c8/dd58967d119baab745caec2f9d853297cec1989ec1d63f677d3880632b88/gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c", size = 215076, upload-time = "2025-07-24T03:45:54.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, +] + [[package]] name = "google-api-core" version = "2.25.1" @@ -5420,6 +5484,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" }, ] +[[package]] +name = "python-slugify" +version = "8.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "text-unidecode" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/c7/5e1547c44e31da50a460df93af11a535ace568ef89d7a811069ead340c4a/python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856", size = 10921, upload-time = "2024-02-08T18:32:45.488Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a4/62/02da182e544a51a5c3ccf4b03ab79df279f9c60c5e82d5e8bec7ca26ac11/python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8", size = 10051, upload-time = "2024-02-08T18:32:43.911Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -6168,6 +6244,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -6252,6 +6337,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e4/f1/6c7eaa8187ba789a6dd6d74430307478d2a91c23a5452ab339b6fbe15a08/sse_starlette-2.4.1-py3-none-any.whl", hash = "sha256:08b77ea898ab1a13a428b2b6f73cfe6d0e607a7b4e15b9bb23e4a37b087fd39a", size = 10824, upload-time = "2025-07-06T09:41:32.321Z" }, ] +[[package]] +name = "sseclient-py" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/ed/3df5ab8bb0c12f86c28d0cadb11ed1de44a92ed35ce7ff4fd5518a809325/sseclient-py-1.8.0.tar.gz", hash 
= "sha256:c547c5c1a7633230a38dc599a21a2dc638f9b5c297286b48b46b935c71fac3e8", size = 7791, upload-time = "2023-09-01T19:39:20.45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/58/97655efdfeb5b4eeab85b1fc5d3fa1023661246c2ab2a26ea8e47402d4f2/sseclient_py-1.8.0-py2.py3-none-any.whl", hash = "sha256:4ecca6dc0b9f963f8384e9d7fd529bf93dd7d708144c4fb5da0e0a1a926fee83", size = 8828, upload-time = "2023-09-01T19:39:17.627Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -6423,6 +6517,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154, upload-time = "2024-03-12T14:34:36.569Z" }, ] +[[package]] +name = "text-unidecode" +version = "1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ab/e2/e9a00f0ccb71718418230718b3d900e71a5d16e701a3dae079a21e9cd8f8/text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93", size = 76885, upload-time = "2019-08-30T21:36:45.405Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" }, +] + [[package]] name = "tiktoken" version = "0.9.0"