Skip to content

Commit d94d256

Browse files
committed
Added cost
1 parent c0bece6 commit d94d256

File tree

4 files changed

+57
-10
lines changed

4 files changed

+57
-10
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def _execute_with_semaphore(idx):
126126

127127
evaluation_row.messages = messages
128128
evaluation_row.tools = shared_tool_schema
129-
evaluation_row.usage = CompletionUsage(**trajectory.usage)
129+
evaluation_row.execution_metadata.usage = CompletionUsage(**trajectory.usage)
130130
evaluation_row.input_metadata.completion_params = {
131131
"model": policy.model_id,
132132
"temperature": getattr(policy, "temperature", None),

eval_protocol/models.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,16 @@ class EvalMetadata(BaseModel):
462462
passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold")
463463

464464

465+
class CostMetrics(BaseModel):
    """Cost metrics for LLM API calls.

    All fields are optional: they are populated after a rollout, once the
    row's token usage has been priced (see the cost-calculation helper in
    pytest/utils.py), and remain ``None`` when pricing was not possible.
    """

    # USD cost attributable to prompt (input) tokens.
    input_cost_usd: Optional[float] = Field(None, description="Cost in USD for input tokens.")

    # USD cost attributable to completion (output) tokens.
    output_cost_usd: Optional[float] = Field(None, description="Cost in USD for output tokens.")

    # Sum of input and output costs for the call.
    total_cost_usd: Optional[float] = Field(None, description="Total cost in USD for the API call.")
473+
474+
465475
class ExecutionMetadata(BaseModel):
466476
"""Metadata about the execution of the evaluation."""
467477

@@ -485,6 +495,12 @@ class ExecutionMetadata(BaseModel):
485495
description=("The ID of the run that this row belongs to."),
486496
)
487497

498+
usage: Optional[CompletionUsage] = Field(
499+
default=None, description="Token usage statistics from LLM calls during execution."
500+
)
501+
502+
cost_metrics: Optional[CostMetrics] = Field(default=None, description="Cost breakdown for LLM API calls.")
503+
488504

489505
class EvaluationRow(BaseModel):
490506
"""
@@ -530,11 +546,6 @@ class EvaluationRow(BaseModel):
530546
description="Metadata about the execution of the evaluation.",
531547
)
532548

533-
# LLM usage statistics
534-
usage: Optional[CompletionUsage] = Field(
535-
default=None, description="Token usage statistics from LLM calls during execution."
536-
)
537-
538549
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the row was created.")
539550

540551
eval_metadata: Optional[EvalMetadata] = Field(

eval_protocol/pytest/evaluation_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
from eval_protocol.pytest.utils import (
4848
AggregationMethod,
4949
aggregate,
50+
calculate_cost_metrics_for_row,
5051
create_dynamically_parameterized_wrapper,
51-
deep_update_dict,
5252
extract_effort_tag,
5353
generate_parameter_combinations,
5454
log_eval_status_and_rows,
@@ -633,7 +633,11 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
633633
processed_dataset=inner_kwargs["rows"],
634634
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
635635
)
636-
if results is None or not isinstance(results, list):
636+
if (
637+
results is None
638+
or not isinstance(results, list)
639+
or not all(isinstance(r, EvaluationRow) for r in results)
640+
):
637641
raise ValueError(
638642
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
639643
)
@@ -724,6 +728,8 @@ async def _collect_result(config, lst):
724728
all_results[i] = results
725729

726730
for r in results:
731+
calculate_cost_metrics_for_row(r)
732+
print(r.execution_metadata.cost_metrics)
727733
if r.eval_metadata is not None:
728734
if r.rollout_status.is_error():
729735
r.eval_metadata.status = Status.error(

eval_protocol/pytest/utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import os
44
import re
55
from dataclasses import replace
6-
from typing import Any, Callable, Dict, List, Literal, Optional, Union
6+
from typing import Any, Callable, List, Literal, Optional, Union
7+
8+
from litellm import cost_per_token
79

810
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
9-
from eval_protocol.models import EvalMetadata, EvaluationRow, Status
11+
from eval_protocol.models import CostMetrics, EvalMetadata, EvaluationRow, Status
1012
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1113
from eval_protocol.pytest.types import (
1214
CompletionParams,
@@ -435,3 +437,31 @@ def extract_effort_tag(params: dict) -> Optional[str]:
435437
except Exception:
436438
return None
437439
return None
440+
441+
442+
def calculate_cost_metrics_for_row(row: EvaluationRow) -> None:
    """Calculate and attach USD cost metrics to ``row`` from its token usage.

    Reads token counts from ``row.execution_metadata.usage`` and the model id
    from ``row.input_metadata.completion_params``, prices the call with
    litellm's ``cost_per_token``, and stores the result on
    ``row.execution_metadata.cost_metrics``.

    Pricing is best-effort: the function is a no-op when usage data is absent,
    and it leaves ``cost_metrics`` unset — instead of raising — when litellm
    cannot price the model. Without this guard the fallback model id
    ``"unknown"`` (or any model missing from litellm's pricing table) would
    raise out of ``cost_per_token`` and abort the caller's result-collection
    loop for every row.

    Args:
        row: The evaluation row to annotate in place.

    Returns:
        None. Mutates ``row.execution_metadata.cost_metrics``.
    """
    usage = row.execution_metadata.usage
    if not usage:
        # Nothing to price without token counts.
        return

    model_id = (
        row.input_metadata.completion_params.get("model", "unknown")
        if row.input_metadata.completion_params
        else "unknown"
    )

    # Treat missing token counts as zero so the pricing call never sees None.
    input_tokens = usage.prompt_tokens or 0
    output_tokens = usage.completion_tokens or 0

    try:
        input_cost, output_cost = cost_per_token(
            model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens
        )
    except Exception:
        # litellm raises for models absent from its pricing table; cost
        # tracking must not crash the evaluation run, so skip pricing.
        return

    # Set all cost metrics on the row.
    row.execution_metadata.cost_metrics = CostMetrics(
        input_cost_usd=input_cost,
        output_cost_usd=output_cost,
        total_cost_usd=input_cost + output_cost,
    )

0 commit comments

Comments
 (0)