|
6 | 6 | from dataclasses import replace |
7 | 7 | from typing import Any, Literal |
8 | 8 |
|
| 9 | +from litellm.cost_calculator import cost_per_token |
9 | 10 | from tqdm import tqdm |
10 | 11 |
|
11 | 12 | from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
12 | 13 | from eval_protocol.models import ( |
| 14 | + CostMetrics, |
| 15 | + CompletionParams, |
13 | 16 | EvalMetadata, |
14 | 17 | EvaluationRow, |
15 | 18 | EvaluationThreshold, |
16 | 19 | EvaluationThresholdDict, |
17 | 20 | Status, |
18 | | - CompletionParams, |
19 | 21 | ) |
20 | 22 | from eval_protocol.pytest.rollout_processor import RolloutProcessor |
21 | 23 | from eval_protocol.pytest.types import ( |
@@ -298,3 +300,55 @@ def extract_effort_tag(params: dict) -> str | None: # pyright: ignore[reportMis |
298 | 300 | except Exception: |
299 | 301 | return None |
300 | 302 | return None |
| 303 | + |
| 304 | + |
def add_cost_metrics(row: EvaluationRow) -> None:
    """Calculate and attach USD cost metrics to *row* from its token usage.

    Prices are looked up via litellm's ``cost_per_token``. When usage stats
    or model info are missing, or the model is not in litellm's pricing
    table, zero-cost metrics are recorded instead of raising, so a single
    unpriced model cannot abort an evaluation run.

    Args:
        row: The evaluation row to annotate; ``row.execution_metadata.cost_metrics``
            is set in place.

    Returns:
        None. Mutates ``row.execution_metadata.cost_metrics``.
    """
    zero_cost = CostMetrics(
        input_cost_usd=0.0,
        output_cost_usd=0.0,
        total_cost_usd=0.0,
    )

    # Can't calculate cost without usage stats or model info.
    if not row.execution_metadata.usage or not row.input_metadata.completion_params:
        row.execution_metadata.cost_metrics = zero_cost
        return

    model = row.input_metadata.completion_params.get("model", "unknown")
    provider = row.input_metadata.completion_params.get("provider")

    # Map framework provider names (Pydantic AI) to litellm model-id prefixes.
    # An empty prefix means the bare model name is already a valid litellm id.
    # TODO: make more generic for other frameworks too.
    provider_mapping = {
        "fireworks": "fireworks_ai",
        "together": "together_ai",
        "openai": "",  # No prefix needed
        "azure": "azure",
        "deepseek": "deepseek",
        "openrouter": "openrouter",
        "grok": "grok",
        "github": "github",
        "heroku": "heroku",
    }

    if provider and provider in provider_mapping:
        litellm_prefix = provider_mapping[provider]
        model_id = f"{litellm_prefix}/{model}" if litellm_prefix else model
    else:
        # Unknown provider: pass the model through unprefixed and let
        # litellm resolve (or reject) it.
        model_id = model

    usage = row.execution_metadata.usage

    # Usage fields may be None; treat missing counts as zero tokens.
    input_tokens = usage.prompt_tokens or 0
    output_tokens = usage.completion_tokens or 0

    try:
        input_cost, output_cost = cost_per_token(
            model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens
        )
    except Exception:
        # Model absent from litellm's pricing table (or any other pricing
        # failure, e.g. the "unknown" fallback model): record zero cost
        # rather than crashing the evaluation, mirroring the missing-data
        # branch above.
        row.execution_metadata.cost_metrics = zero_cost
        return

    # Set all cost metrics on the row.
    row.execution_metadata.cost_metrics = CostMetrics(
        input_cost_usd=input_cost,
        output_cost_usd=output_cost,
        total_cost_usd=input_cost + output_cost,
    )
0 commit comments