From d94d2568071bdd4f79ffef63f9a4ac92de623969 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 27 Aug 2025 18:07:43 -0700 Subject: [PATCH] Added cost --- eval_protocol/mcp/execution/manager.py | 2 +- eval_protocol/models.py | 21 +++++++++++---- eval_protocol/pytest/evaluation_test.py | 10 ++++++-- eval_protocol/pytest/utils.py | 34 +++++++++++++++++++++++-- 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index d6cb2b83..e22f39e5 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -126,7 +126,7 @@ async def _execute_with_semaphore(idx): evaluation_row.messages = messages evaluation_row.tools = shared_tool_schema - evaluation_row.usage = CompletionUsage(**trajectory.usage) + evaluation_row.execution_metadata.usage = CompletionUsage(**trajectory.usage) evaluation_row.input_metadata.completion_params = { "model": policy.model_id, "temperature": getattr(policy, "temperature", None), diff --git a/eval_protocol/models.py b/eval_protocol/models.py index f930c717..47099903 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -462,6 +462,16 @@ class EvalMetadata(BaseModel): passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold") +class CostMetrics(BaseModel): + """Cost metrics for LLM API calls.""" + + input_cost_usd: Optional[float] = Field(None, description="Cost in USD for input tokens.") + + output_cost_usd: Optional[float] = Field(None, description="Cost in USD for output tokens.") + + total_cost_usd: Optional[float] = Field(None, description="Total cost in USD for the API call.") + + class ExecutionMetadata(BaseModel): """Metadata about the execution of the evaluation.""" @@ -485,6 +495,12 @@ class ExecutionMetadata(BaseModel): description=("The ID of the run that this row belongs to."), ) + usage: Optional[CompletionUsage] = Field( + default=None, description="Token usage statistics from LLM calls during execution." + ) + + cost_metrics: Optional[CostMetrics] = Field(default=None, description="Cost breakdown for LLM API calls.") + class EvaluationRow(BaseModel): """ @@ -530,11 +546,6 @@ class EvaluationRow(BaseModel): description="Metadata about the execution of the evaluation.", ) - # LLM usage statistics - usage: Optional[CompletionUsage] = Field( - default=None, description="Token usage statistics from LLM calls during execution." 
- ) - created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the row was created.") eval_metadata: Optional[EvalMetadata] = Field( diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 47d98eb6..04551af1 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -47,8 +47,8 @@ from eval_protocol.pytest.utils import ( AggregationMethod, aggregate, + calculate_cost_metrics_for_row, create_dynamically_parameterized_wrapper, - deep_update_dict, extract_effort_tag, generate_parameter_combinations, log_eval_status_and_rows, @@ -633,7 +633,11 @@ async def _execute_eval_with_semaphore(**inner_kwargs): processed_dataset=inner_kwargs["rows"], evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, ) - if results is None or not isinstance(results, list): + if ( + results is None + or not isinstance(results, list) + or not all(isinstance(r, EvaluationRow) for r in results) + ): raise ValueError( f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." ) @@ -724,6 +728,8 @@ async def _collect_result(config, lst): all_results[i] = results for r in results: + calculate_cost_metrics_for_row(r) + print(r.execution_metadata.cost_metrics) if r.eval_metadata is not None: if r.rollout_status.is_error(): r.eval_metadata.status = Status.error( diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index f4bbaebe..fc451bcb 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -3,10 +3,12 @@ import os import re from dataclasses import replace -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union + +from litellm import cost_per_token from eval_protocol.dataset_logger.dataset_logger import DatasetLogger -from eval_protocol.models import EvalMetadata, EvaluationRow, Status +from eval_protocol.models import CostMetrics, EvalMetadata, EvaluationRow, Status from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import ( CompletionParams, @@ -435,3 +437,31 @@ def extract_effort_tag(params: dict) -> Optional[str]: except Exception: return None return None + + +def calculate_cost_metrics_for_row(row: EvaluationRow) -> None: + """Calculate and set cost metrics for an EvaluationRow based on its usage data.""" + if not row.execution_metadata.usage: + return + + model_id = ( + row.input_metadata.completion_params.get("model", "unknown") + if row.input_metadata.completion_params + else "unknown" + ) + usage = row.execution_metadata.usage + + input_tokens = usage.prompt_tokens or 0 + output_tokens = usage.completion_tokens or 0 + + input_cost, output_cost = cost_per_token( + model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens + ) + total_cost = input_cost + output_cost + + # Set all cost metrics on the row + row.execution_metadata.cost_metrics = CostMetrics( + input_cost_usd=input_cost, + output_cost_usd=output_cost, + total_cost_usd=total_cost, + )
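
For context, the sketch below reproduces the cost arithmetic that calculate_cost_metrics_for_row performs, outside of the EvaluationRow plumbing. The model name and token counts are illustrative placeholders, not values taken from this patch; cost_per_token is the same litellm helper imported in eval_protocol/pytest/utils.py above.

# Minimal standalone sketch of the cost computation introduced by this patch.
# "gpt-4o-mini" and the token counts are placeholder values for illustration only.
from litellm import cost_per_token

prompt_tokens = 1200
completion_tokens = 350

# cost_per_token returns an (input_cost_usd, output_cost_usd) tuple for the given model.
input_cost, output_cost = cost_per_token(
    model="gpt-4o-mini",
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
total_cost = input_cost + output_cost
print(f"input=${input_cost:.6f} output=${output_cost:.6f} total=${total_cost:.6f}")

One caveat: cost_per_token may raise for models missing from litellm's pricing map, so the "unknown" fallback model id used in calculate_cost_metrics_for_row could surface an exception rather than a zero cost; wrapping the call in a try/except is a reasonable guard if that path is reachable.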