diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 4f11cd3f..04b34d45 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -4,8 +4,9 @@ to EvaluationRow format for use in evaluation pipelines. """ +from langfuse.api.resources.commons.types.observations_view import ObservationsView import logging -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Dict, Iterator, List, Optional, cast from eval_protocol.models import EvaluationRow, InputMetadata, Message @@ -13,7 +14,10 @@ logger = logging.getLogger(__name__) try: - from langfuse import Langfuse # pyright: ignore[reportPrivateImportUsage] + from langfuse import get_client # pyright: ignore[reportPrivateImportUsage] + from langfuse.api.resources.trace.types.traces import Traces + from langfuse.api.resources.commons.types.trace import Trace + from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails LANGFUSE_AVAILABLE = True except ImportError: @@ -45,26 +49,12 @@ class LangfuseAdapter: ... )) """ - def __init__( - self, - public_key: str, - secret_key: str, - host: str = "https://cloud.langfuse.com", - project_id: Optional[str] = None, - ): - """Initialize the Langfuse adapter. - - Args: - public_key: Langfuse public key - secret_key: Langfuse secret key - host: Langfuse host URL (default: https://cloud.langfuse.com) - project_id: Optional project ID to filter traces - """ + def __init__(self): + """Initialize the Langfuse adapter.""" if not LANGFUSE_AVAILABLE: raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'") - self.client = cast(Any, Langfuse)(public_key=public_key, secret_key=secret_key, host=host) - self.project_id = project_id + self.client = get_client() def get_evaluation_rows( self, @@ -72,8 +62,7 @@ def get_evaluation_rows( tags: Optional[List[str]] = None, user_id: Optional[str] = None, session_id: Optional[str] = None, - from_timestamp: Optional[datetime] = None, - to_timestamp: Optional[datetime] = None, + hours_back: Optional[int] = None, include_tool_calls: bool = True, ) -> List[EvaluationRow]: """Pull traces from Langfuse and convert to EvaluationRow format. @@ -83,16 +72,24 @@ def get_evaluation_rows( tags: Filter by specific tags user_id: Filter by user ID session_id: Filter by session ID - from_timestamp: Filter traces after this timestamp - to_timestamp: Filter traces before this timestamp + hours_back: Filter traces from this many hours ago include_tool_calls: Whether to include tool calling traces Yields: EvaluationRow: Converted evaluation rows """ # Get traces from Langfuse using new API + + if hours_back: + to_timestamp = datetime.now() + from_timestamp = to_timestamp - timedelta(hours=hours_back) + else: + to_timestamp = None + from_timestamp = None + eval_rows = [] - traces = self.client.api.trace.list( + + traces: Traces = self.client.api.trace.list( limit=limit, tags=tags, user_id=user_id, @@ -128,7 +125,7 @@ def get_evaluation_rows_by_ids( eval_rows = [] for trace_id in trace_ids: try: - trace = self.client.api.trace.get(trace_id) + trace: TraceWithFullDetails = self.client.api.trace.get(trace_id) eval_row = self._convert_trace_to_evaluation_row(trace, include_tool_calls) if eval_row: eval_rows.append(eval_row) @@ -137,7 +134,9 @@ def get_evaluation_rows_by_ids( continue return eval_rows - def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool = True) -> Optional[EvaluationRow]: + def _convert_trace_to_evaluation_row( + self, trace: Trace, include_tool_calls: bool = True + ) -> Optional[EvaluationRow]: """Convert a Langfuse trace to EvaluationRow format. Args: @@ -147,63 +146,34 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool Returns: EvaluationRow or None if conversion fails """ - # TODO: move this logic into an adapter in llm_judge.py. langfuse.py should just return traces try: - # Get observations (generations, spans) from the trace - observations_response = self.client.api.observations.get_many(trace_id=trace.id, limit=100) - observations = ( - observations_response.data if hasattr(observations_response, "data") else list(observations_response) - ) + trace = self.client.api.trace.get("2d9f3474-83ab-4431-9788-049ca4219023") + + # Extract messages from trace input and output + messages = self._extract_messages_from_trace(trace, include_tool_calls) - # Look for conversation history in trace output or observations - messages = [] - conversation_found = False - - # Look for complete conversation in observations - if not conversation_found: - for obs in observations: - # Check each observation's output for complete conversation array - if hasattr(obs, "output") and obs.output: - conversation = self._extract_conversation_from_output(obs.output) - if conversation: - messages = conversation - conversation_found = True - break - - # Fallback: try extracting from observations using old method - if not conversation_found: - messages = self._extract_messages_from_observations(observations, include_tool_calls) + # Extract tools if available + tools = None + if include_tool_calls and isinstance(trace.input, dict) and "tools" in trace.input: + tools = trace.input["tools"] if not messages: return None - # Extract metadata - input_metadata = self._create_input_metadata(trace, observations) - - # Extract ground truth if available (from trace metadata or tags) - ground_truth = self._extract_ground_truth(trace) - - # Extract tools if available - tools = self._extract_tools(observations, trace) if include_tool_calls else None - return EvaluationRow( messages=messages, tools=tools, - input_metadata=input_metadata, - ground_truth=ground_truth, ) except (AttributeError, ValueError, KeyError) as e: logger.error("Error converting trace %s: %s", trace.id, e) return None - def _extract_messages_from_observations( - self, observations: List[Any], include_tool_calls: bool = True - ) -> List[Message]: - """Extract messages from Langfuse observations. + def _extract_messages_from_trace(self, trace: Any, include_tool_calls: bool = True) -> List[Message]: + """Extract messages from Langfuse trace input and output. Args: - observations: List of Langfuse observation objects + trace: Langfuse trace object include_tool_calls: Whether to include tool calling information Returns: @@ -211,45 +181,44 @@ def _extract_messages_from_observations( """ messages = [] - # Sort observations by timestamp - sorted_observations = sorted(observations, key=lambda x: x.start_time or datetime.min) - - for obs in sorted_observations: - try: - if hasattr(obs, "input") and obs.input: - # Handle different input formats - if isinstance(obs.input, dict): - if "messages" in obs.input: - # OpenAI-style messages format - for msg in obs.input["messages"]: - messages.append(self._dict_to_message(msg, include_tool_calls)) - elif "role" in obs.input: - # Single message format - messages.append(self._dict_to_message(obs.input, include_tool_calls)) - elif "prompt" in obs.input: - # Simple prompt format - messages.append(Message(role="user", content=str(obs.input["prompt"]))) - elif isinstance(obs.input, str): - # Simple string input - messages.append(Message(role="user", content=obs.input)) - - if hasattr(obs, "output") and obs.output: - # Handle output - if isinstance(obs.output, dict): - if "content" in obs.output: - messages.append(Message(role="assistant", content=str(obs.output["content"]))) - elif "message" in obs.output: - msg_dict = obs.output["message"] - messages.append(self._dict_to_message(msg_dict, include_tool_calls)) - else: - # Fallback: convert entire output to string - messages.append(Message(role="assistant", content=str(obs.output))) - elif isinstance(obs.output, str): - messages.append(Message(role="assistant", content=obs.output)) + try: + # Handle trace input + if hasattr(trace, "input") and trace.input: + if isinstance(trace.input, dict): + if "messages" in trace.input: + # OpenAI-style messages format + for msg in trace.input["messages"]: + messages.append(self._dict_to_message(msg, include_tool_calls)) + elif "role" in trace.input: + # Single message format + messages.append(self._dict_to_message(trace.input, include_tool_calls)) + elif "prompt" in trace.input: + # Simple prompt format + messages.append(Message(role="user", content=str(trace.input["prompt"]))) + elif isinstance(trace.input, list): + # Direct list of message dicts + for msg in trace.input: + messages.append(self._dict_to_message(msg, include_tool_calls)) + elif isinstance(trace.input, str): + # Simple string input + messages.append(Message(role="user", content=trace.input)) + + # Handle trace output + if hasattr(trace, "output") and trace.output: + if isinstance(trace.output, dict): + if "content" in trace.output: + messages.append(Message(role="assistant", content=str(trace.output["content"]))) + elif "message" in trace.output: + msg_dict = trace.output["message"] + messages.append(self._dict_to_message(msg_dict, include_tool_calls)) + else: + # Fallback: convert entire output to string + messages.append(Message(role="assistant", content=str(trace.output))) + elif isinstance(trace.output, str): + messages.append(Message(role="assistant", content=trace.output)) - except (AttributeError, ValueError, KeyError) as e: - logger.warning("Error processing observation %s: %s", obs.id, e) - continue + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Error processing trace %s: %s", trace.id, e) return messages @@ -290,239 +259,8 @@ def _dict_to_message(self, msg_dict: Dict[str, Any], include_tool_calls: bool = function_call=function_call, ) - def _extract_conversation_from_output(self, output: Any) -> Optional[List[Message]]: - """Extract conversation history from PydanticAI agent run output. - - This looks for the conversation format like: - [ - {"role": "user", "content": "..."}, - {"role": "assistant", "content": "...", "tool_calls": [...]}, - {"role": "tool", "content": "...", "name": "execute_sql"}, - ... - ] - - Args: - output: The output object to search for conversation history - - Returns: - List of Message objects or None if no conversation found - """ - messages = [] - - try: - # Handle different output formats - conversation_data = None - - if isinstance(output, list): - # Direct list of messages - conversation_data = output - elif isinstance(output, dict): - # Look for conversation in various nested formats - if "messages" in output: - conversation_data = output["messages"] - elif "conversation" in output: - conversation_data = output["conversation"] - elif "history" in output: - conversation_data = output["history"] - elif "agent_run" in output: # Handle nested conversation data PydanticAI style - agent_run = output["agent_run"] - if isinstance(agent_run, dict) and "messages" in agent_run: - conversation_data = agent_run["messages"] - elif len(output.keys()) == 1: - # Single key, check if its value is a list - single_key = list(output.keys())[0] - if isinstance(output[single_key], list): - conversation_data = output[single_key] - elif isinstance(output, str): - # Try to parse JSON string - import json - - try: - parsed = json.loads(output) - return self._extract_conversation_from_output(parsed) - except (json.JSONDecodeError, ValueError): - pass - - # Parse conversation data into messages - if conversation_data and isinstance(conversation_data, list): - for msg_data in conversation_data: - if isinstance(msg_data, dict) and "role" in msg_data: - role = msg_data.get("role") - if role is None: - continue - content = msg_data.get("content", "") - - # Handle tool calls in assistant messages - tool_calls = None - if role == "assistant" and "tool_calls" in msg_data: - tool_calls = msg_data["tool_calls"] - - # Handle tool responses - name = None - if role == "tool": - name = msg_data.get("name") - - messages.append(Message(role=role, content=content, name=name, tool_calls=tool_calls)) - - return messages if messages else None - - except Exception as e: - logger.debug("Error extracting conversation from output: %s", e) - return None - - def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMetadata: - """Create InputMetadata from trace and observations. - - Args: - trace: Langfuse trace object - observations: List of observation objects - - Returns: - InputMetadata object - """ - # Extract completion parameters from trace input first, then observations - completion_params = {} - - # First check trace input for evaluation test completion_params - if hasattr(trace, "input") and trace.input: - if isinstance(trace.input, dict): - kwargs = trace.input.get("kwargs", {}) - if "completion_params" in kwargs: - trace_completion_params = kwargs["completion_params"] - if trace_completion_params and isinstance(trace_completion_params, dict): - completion_params.update(trace_completion_params) - - # Fallback: Look for model parameters in observations if not found in trace input - if not completion_params: - for obs in observations: - if hasattr(obs, "model") and obs.model: - completion_params["model"] = obs.model - if hasattr(obs, "model_parameters") and obs.model_parameters: - params = obs.model_parameters - if "temperature" in params: - completion_params["temperature"] = params["temperature"] - if "max_tokens" in params: - completion_params["max_tokens"] = params["max_tokens"] - if "top_p" in params: - completion_params["top_p"] = params["top_p"] - break - - # Create dataset info from trace metadata - dataset_info = { - "trace_id": trace.id, - "trace_name": getattr(trace, "name", None), - "trace_tags": getattr(trace, "tags", []), - "langfuse_project_id": self.project_id, - } - - # Add trace metadata if available - if hasattr(trace, "metadata") and trace.metadata: - dataset_info["trace_metadata"] = trace.metadata - - # Create session data - session_data = { - "session_id": getattr(trace, "session_id", None), - "user_id": getattr(trace, "user_id", None), - "timestamp": getattr(trace, "timestamp", None), - "langfuse_trace_url": ( - f"{self.client.host}/project/{self.project_id}/traces/{trace.id}" if self.project_id else None - ), - } - - return InputMetadata( - row_id=trace.id, - completion_params=completion_params, - dataset_info=dataset_info, - session_data=session_data, - ) - - def _extract_ground_truth(self, trace: Any) -> Optional[str]: - """Extract ground truth from trace if available. - - Args: - trace: Langfuse trace object - - Returns: - Ground truth string or None - """ - # First check trace input for evaluation test data structure - if hasattr(trace, "input") and trace.input: - if isinstance(trace.input, dict): - # Handle EP test format: kwargs.input_rows[0].ground_truth - kwargs = trace.input.get("kwargs", {}) - if "input_rows" in kwargs: - input_rows = kwargs["input_rows"] - if input_rows and len(input_rows) > 0: - first_row = input_rows[0] - if isinstance(first_row, dict) and "ground_truth" in first_row: - ground_truth = first_row["ground_truth"] - if ground_truth: # Only return if not None/empty - return str(ground_truth) - - # Check trace metadata for ground truth - if hasattr(trace, "metadata") and trace.metadata: - if isinstance(trace.metadata, dict): - return trace.metadata.get("ground_truth") or trace.metadata.get("expected_answer") - - # Check tags for ground truth indicators - if hasattr(trace, "tags") and trace.tags: - for tag in trace.tags: - if tag.startswith("ground_truth:"): - return tag.replace("ground_truth:", "", 1) - - return None - - def _extract_tools(self, observations: List[Any], trace: Any = None) -> Optional[List[Dict[str, Any]]]: - """Extract tool definitions from trace metadata or observations. - Args: - observations: List of observation objects - trace: Trace object that may contain metadata with tools +def create_langfuse_adapter() -> LangfuseAdapter: + """Factory function to create a Langfuse adapter.""" - Returns: - List of tool definitions or None - """ - # First, try to extract tools from trace metadata (preferred) - if trace and hasattr(trace, "metadata") and trace.metadata: - if isinstance(trace.metadata, dict) and "tools" in trace.metadata: - tools_from_metadata = trace.metadata["tools"] - if tools_from_metadata: - return tools_from_metadata - - # Fallback: extract from observations - tools = [] - for obs in observations: - if hasattr(obs, "input") and obs.input and isinstance(obs.input, dict): - if "tools" in obs.input: - tools.extend(obs.input["tools"]) - elif "functions" in obs.input: - # Convert functions to tools format - for func in obs.input["functions"]: - tools.append({"type": "function", "function": func}) - - return tools if tools else None - - -def create_langfuse_adapter( - public_key: str, - secret_key: str, - host: str = "https://cloud.langfuse.com", - project_id: Optional[str] = None, -) -> LangfuseAdapter: - """Factory function to create a Langfuse adapter. - - Args: - public_key: Langfuse public key - secret_key: Langfuse secret key - host: Langfuse host URL - project_id: Optional project ID - - Returns: - LangfuseAdapter instance - """ - return LangfuseAdapter( - public_key=public_key, - secret_key=secret_key, - host=host, - project_id=project_id, - ) + return LangfuseAdapter() diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 96df88b2..2b4bf893 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -30,7 +30,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + messages_payload = [message.model_dump() for message in row.messages] request_params = {"messages": messages_payload, **config.completion_params} # Ensure caching is disabled only for this request (review feedback) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 612a47a6..a7ec65f3 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -84,6 +84,7 @@ def evaluation_test( steps: int = 30, mode: EvaluationTestMode = "pointwise", combine_datasets: bool = True, + preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = None, logger: DatasetLogger | None = None, exception_handler_config: ExceptionHandlerConfig | None = None, ) -> Callable[[TestFunction], TestFunction]: @@ -150,6 +151,9 @@ def evaluation_test( mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result). "groupwise" applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo). "all" applies test function to the whole dataset. + preprocess_fn: Optional preprocessing function that takes a list of EvaluationRow objects + and returns a modified list. Useful for transformations like splitting multi-turn conversations, + filtering data, or other preprocessing steps before rollout execution. logger: DatasetLogger to use for logging. If not provided, a default logger will be used. exception_handler_config: Configuration for exception handling and backoff retry logic. If not provided, a default configuration will be used with common retryable exceptions. @@ -244,6 +248,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo else: raise ValueError("No input dataset, input messages, or input rows provided") + if preprocess_fn: + data = preprocess_fn(data) + for row in data: # generate a stable row_id for each row if row.input_metadata.row_id is None: @@ -266,11 +273,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo passed=None, ) for row in data: - # Only set completion_params if they don't already exist - if not row.input_metadata.completion_params: - row.input_metadata.completion_params = ( - completion_params if completion_params is not None else {} - ) + row.input_metadata.completion_params = ( + completion_params if completion_params is not None else {} + ) # Add mode to session_data if row.input_metadata.session_data is None: row.input_metadata.session_data = {} diff --git a/eval_protocol/pytest/evaluation_test_postprocess.py b/eval_protocol/pytest/evaluation_test_postprocess.py index 8b069cec..6e44c620 100644 --- a/eval_protocol/pytest/evaluation_test_postprocess.py +++ b/eval_protocol/pytest/evaluation_test_postprocess.py @@ -62,15 +62,17 @@ def postprocess( passed = success_passed and standard_error_passed # Update eval metadata passed field for all results - for result in all_results: - for r in result: - if r.eval_metadata is not None: - r.eval_metadata.passed = passed - if r.evaluation_result is not None: - r.evaluation_result.agg_score = agg_score - r.evaluation_result.standard_error = standard_error - r.execution_metadata.experiment_duration_seconds = experiment_duration_seconds - active_logger.log(r) + for results in all_results: + for result in results: + if result.eval_metadata is not None: + result.eval_metadata.passed = passed + if result.evaluation_result is not None: + if result.evaluation_result.agg_score is None: + result.evaluation_result.agg_score = agg_score + if result.evaluation_result.standard_error is None: + result.evaluation_result.standard_error = standard_error + result.execution_metadata.experiment_duration_seconds = experiment_duration_seconds + active_logger.log(result) # Optional: print and/or persist a summary artifact for CI try: diff --git a/eval_protocol/quickstart/llm_judge.py b/eval_protocol/quickstart/llm_judge.py new file mode 100644 index 00000000..ab4ca37a --- /dev/null +++ b/eval_protocol/quickstart/llm_judge.py @@ -0,0 +1,221 @@ +""" +Default LLM judge for Eval Protocol. Inspired by Arena-Hard-Auto. +""" + +import os +from datetime import datetime, timedelta +from typing import List, Dict, Any, Optional +import pandas as pd +from tqdm import tqdm + +import pytest + +from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor +from eval_protocol.quickstart.utils import pairwise_judgment, split_multi_turn_rows, serialize_message +from eval_protocol.adapters.langfuse import create_langfuse_adapter + +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor + +JUDGE_CONFIGS = { + "gpt-4.1": { + "model": "gpt-4.1", + "temperature": 0.0, + "max_tokens": 16000, + "max_concurrency": 64, + }, + "gemini-2.5-pro": { + "model": "gemini-2.5-pro", + "temperature": 1.0, + "max_tokens": 32000, + "api_key": os.getenv("GEMINI_API_KEY"), + "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + "max_concurrency": 32, + }, +} + + +def fetch_langfuse_traces_as_evaluation_rows( + limit: int = 100, + tags: Optional[List[str]] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + hours_back: Optional[int] = None, + include_tool_calls: bool = True, +) -> List[EvaluationRow]: + try: + adapter = create_langfuse_adapter() + + return adapter.get_evaluation_rows( + limit=limit, + tags=tags, + user_id=user_id, + session_id=session_id, + hours_back=hours_back, + include_tool_calls=include_tool_calls, + ) + + except Exception as e: + print(f"❌ LangfuseAdapter failed: {e}") + return [] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") +@pytest.mark.asyncio +@evaluation_test( + input_rows=[fetch_langfuse_traces_as_evaluation_rows()], + completion_params=[ + {"model": "gpt-5"}, + { + # "max_tokens": 131000, + # "extra_body": {"reasoning_effort": "low"}, + "model": "fireworks_ai/accounts/fireworks/models/qwen3-235b-a22b-instruct-2507", + }, + ], + rollout_processor=SingleTurnRolloutProcessor(), + preprocess_fn=split_multi_turn_rows, + mode="all", +) +async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]: + """ + Simplified LLM Judge for Arena-Hard-Auto style pairwise comparisons. + + Each row contains: + - messages[:-1]: Question/prompt (conversation context) + - messages[-1]: Model B's answer (comparison model response) + - ground_truth: Model A's answer (original assistant response) + """ + + judge_name = "gemini-2.5-pro" # Edit to which judge you'd like to use. Configs at top of file. + + if not rows: + print("❌ No evaluation rows provided") + return rows + + print(f"🔄 Processing {len(rows)} evaluation rows for LLM judging...") + + model_name = rows[0].input_metadata.completion_params.get("model", "unknown_model") + + def run_judgment(row: EvaluationRow) -> Optional[Dict[str, Any]]: + """Run pairwise judgment for a single evaluation row.""" + if not row.messages: + return None + + question_text = "\n".join([serialize_message(msg) for msg in row.messages[:-1]]) + model_a_answer = row.ground_truth + model_b_answer = serialize_message(row.messages[-1]) + + games = [] + + # Round 1: A vs B (original vs comparison) + result1 = pairwise_judgment( + question_text=question_text, + answer_a=model_a_answer, + answer_b=model_b_answer, + tools=row.tools, + judge_config=JUDGE_CONFIGS[judge_name], + ) + games.append(result1) + + # Round 2: B vs A (comparison vs original) + result2 = pairwise_judgment( + question_text=question_text, + answer_a=model_b_answer, + answer_b=model_a_answer, + tools=row.tools, + judge_config=JUDGE_CONFIGS[judge_name], + ) + games.append(result2) + + row.evaluation_result = EvaluateResult( + score=0.0, + reason=f"LLM Judge comparison: Round 1: {result1['score']}, Round 2: {result2['score']}" + if result1 and result2 + else "Failed to get judgement scores", + metrics={ + "round1_judgment": MetricResult( + score=0.0, reason=result1["judgment"] if result1 else "Failed to get judgment reason" + ), + "round2_judgment": MetricResult( + score=0.0, reason=result2["judgment"] if result2 else "Failed to get judgment reason" + ), + }, + ) + + return {"model": model_name, "games": games} + + judgments = [] + max_concurrency = JUDGE_CONFIGS[judge_name]["max_concurrency"] + + with ThreadPoolExecutor(max_workers=max_concurrency) as executor: + futures = [executor.submit(run_judgment, row) for row in rows] + + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating judgments"): + result = future.result() + if result and result["games"][0] and result["games"][1]: + judgments.append(result) + + if not judgments: + print("❌ No valid judgments generated") + return rows + + print(f"✅ Generated {len(judgments)} valid judgments") + + # Convert to scores for leaderboard + label_to_score = { + "A>B": [1], + "A>>B": [1] * 3, + "A=B": [0.5], + "A<A": [0], + "B>>A": [0] * 3, + "B=A": [0.5], + "B<>B]] +2. Assistant A is slightly better: [[A>B]] +3. Tie, relatively the same: [[A=B]] +4. Assistant B is slightly better: [[B>A]] +5. Assistant B is significantly better: [[B>>A]] + +Example output: "My final verdict is tie: [[A=B]]".""" + + +def get_score(judgment, patterns): + """Extract judgment score from text. From arena-hard-auto/gen_judgment.py""" + for pattern in patterns: + pattern = re.compile(pattern) + + matches = pattern.findall(judgment.upper()) + matches = [m for m in matches if m != ""] + + if len(set(matches)) > 0: + return matches[-1].strip("\n") + return None + + +def serialize_message(msg: Message) -> str: + parts = [f"{msg.role}: {msg.content}"] + + if msg.tool_calls: + for tool_call in msg.tool_calls: + tool_name = tool_call.function.name + tool_args = tool_call.function.arguments + parts.append(f"[Tool Call: {tool_name}({tool_args})]") + + return "\n".join(parts) + + +def split_multi_turn_rows(data: list[EvaluationRow]) -> list[EvaluationRow]: + """ + Split multi-turn conversation rows into individual evaluation rows for each assistant message. + + Args: + data: List of EvaluationRow objects + + Returns: + List of expanded EvaluationRow objects, one for each assistant message + """ + expanded_rows = [] + + for row in data: + messages = row.messages + tools = row.tools + input_metadata = row.input_metadata + + assistant_positions = [] + for i, message in enumerate(messages): + if message.role == "assistant": + assistant_positions.append(i) + + # Create separate evaluation rows on each assistant message (where the comparison model will respond) + for pos in assistant_positions: + messages_before_assistant = messages[:pos] + assistant_message = messages[pos] + + ground_truth_message = serialize_message(assistant_message) + + expanded_rows.append( + EvaluationRow( + messages=messages_before_assistant, + tools=tools, + input_metadata=input_metadata, + ground_truth=ground_truth_message, + ) + ) + + return expanded_rows + + +def pairwise_judgment(question_text, answer_a, answer_b, tools, judge_config): + """Pairwise judgment function. Adapted from arena-hard-auto/gen_judgment.py""" + user_prompt = f"""<|User Prompt|> +{question_text} + +<|The Start of Assistant A's Answer|> +{answer_a} +<|The End of Assistant A's Answer|> + +<|The Start of Assistant B's Answer|> +{answer_b} +<|The End of Assistant B's Answer|>""" + + messages = [ + { + "role": "system", + "content": OG_ARENA_HARD_PROMPT, + }, + { + "role": "user", + "content": user_prompt, + }, + ] + + try: + from openai import OpenAI + + client = OpenAI(api_key=judge_config["api_key"], base_url=judge_config["base_url"]) + + api_params = { + "model": judge_config["model"], + "messages": messages, # type: ignore + "temperature": judge_config["temperature"], + "max_tokens": judge_config["max_tokens"], + } + + if tools: + api_params["tools"] = tools + api_params["tool_choice"] = ( + "none" # Judge can see tools to help in response, but won't actually try to call them + ) + + response = client.chat.completions.create(**api_params) + + judgment_text = response.choices[0].message.content + if not judgment_text: + return None + + except Exception as e: + print(f"Error getting judgment from OpenAI: {e}") + return None + + score = get_score(judgment_text, [r"\[\[([AB<>=]+)\]\]", r"\[([AB<>=]+)\]"]) + + result = { + "score": score, + "judgment": judgment_text, + "prompt": messages, + } + return result diff --git a/tests/chinook/pydantic/agent.py b/tests/chinook/pydantic/agent.py index bdf20b08..2b260fd4 100644 --- a/tests/chinook/pydantic/agent.py +++ b/tests/chinook/pydantic/agent.py @@ -27,7 +27,7 @@ def setup_agent(orchestrator_agent_model: Model): """ agent = Agent( - system_prompt=SYSTEM_PROMPT, + instructions=SYSTEM_PROMPT, model=orchestrator_agent_model, instrument=True, ) diff --git a/tests/chinook/pydantic/test_pydantic_chinook.py b/tests/chinook/pydantic/test_pydantic_chinook.py index 4d88dfb5..9cdac0ee 100644 --- a/tests/chinook/pydantic/test_pydantic_chinook.py +++ b/tests/chinook/pydantic/test_pydantic_chinook.py @@ -61,6 +61,9 @@ async def test_simple_query(row: EvaluationRow) -> EvaluationRow: assert hasattr(row, "tools"), "Row missing 'tools' attribute" assert row.tools == expected_tools, f"Tools validation failed. Expected: {expected_tools}, Got: {row.tools}" + # assert that there is a system message + assert row.messages[0].role == "system" + last_assistant_message = row.last_assistant_message() if last_assistant_message is None: row.evaluation_result = EvaluateResult( diff --git a/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py b/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py new file mode 100644 index 00000000..d94e98fa --- /dev/null +++ b/tests/chinook/pydantic/test_pydantic_complex_queries_responses.py @@ -0,0 +1,47 @@ +from collections.abc import Awaitable, Callable +import os +from typing_extensions import cast +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings +import pytest + +from eval_protocol.models import EvaluationRow +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.types import RolloutProcessorConfig +from tests.chinook.dataset import collect_dataset +from tests.chinook.pydantic.agent import setup_agent +from tests.pytest.test_pydantic_agent import PydanticAgentRolloutProcessor + +# IMPORTANT: import must be renamed to something without the "test_" prefix to +# avoid pytest discovering the import as a test +from tests.chinook.pydantic.test_pydantic_complex_queries import test_pydantic_complex_queries as eval + + +def agent_factory(config: RolloutProcessorConfig) -> Agent: + model_name = config.completion_params["model"] + model_settings = OpenAIResponsesModelSettings() + model = OpenAIResponsesModel(model_name) + return setup_agent(model) + + +@pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="This was only run locally to generate traces in Responses API", +) +@pytest.mark.asyncio +@evaluation_test( + input_rows=[collect_dataset()], + completion_params=[ + { + "model": "gpt-4o", + }, + ], + rollout_processor=PydanticAgentRolloutProcessor(agent_factory), +) +async def test_pydantic_complex_queries_responses(row: EvaluationRow) -> EvaluationRow: + """ + Evaluation of complex queries for the Chinook database using PydanticAI + """ + casted_evaluation_test = cast(Callable[[EvaluationRow], Awaitable[EvaluationRow]], eval) + evaluated_row = await casted_evaluation_test(row) + return evaluated_row