Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,11 @@ class EvalMetadata(BaseModel):
class CostMetrics(BaseModel):
"""Cost metrics for LLM API calls."""

input_cost_usd: Optional[float] = Field(None, description="Cost in USD for input tokens.")
input_cost: Optional[float] = Field(None, description="Cost in USD for input tokens.")

output_cost_usd: Optional[float] = Field(None, description="Cost in USD for output tokens.")
output_cost: Optional[float] = Field(None, description="Cost in USD for output tokens.")

total_cost_usd: Optional[float] = Field(None, description="Total cost in USD for the API call.")
total_cost: Optional[float] = Field(None, description="Total cost in USD for the API call.")


class ExecutionMetadata(BaseModel):
Expand Down Expand Up @@ -560,6 +560,11 @@ class ExecutionMetadata(BaseModel):
description="Processing duration in seconds for this evaluation row. Note that if it gets retried, this will be the duration of the last attempt.",
)

experiment_duration_seconds: Optional[float] = Field(
default=None,
    description="Processing duration in seconds for an entire experiment. Note that this includes the time it took for retries.",
)


class EvaluationRow(BaseModel):
"""
Expand Down
6 changes: 6 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import os
import sys
import time
from collections import defaultdict
from typing import Any, Callable
from typing_extensions import Unpack
Expand Down Expand Up @@ -212,6 +213,7 @@ async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
all_results: list[list[EvaluationRow]] = [[] for _ in range(num_runs)]

experiment_id = generate_id()
experiment_start_time = time.perf_counter()

def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bool) -> None:
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)
Expand Down Expand Up @@ -506,6 +508,8 @@ async def execute_run_with_progress(run_idx: int, config):
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]

experiment_duration_seconds = time.perf_counter() - experiment_start_time

# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
# rollout_id is used to differentiate the result from different completion_params
if mode == "groupwise":
Expand All @@ -526,6 +530,7 @@ async def execute_run_with_progress(run_idx: int, config):
original_completion_params[rollout_id], # pyright: ignore[reportArgumentType]
test_func.__name__,
num_runs,
experiment_duration_seconds,
)
else:
postprocess(
Expand All @@ -537,6 +542,7 @@ async def execute_run_with_progress(run_idx: int, config):
completion_params, # pyright: ignore[reportArgumentType]
test_func.__name__,
num_runs,
experiment_duration_seconds,
)

except AssertionError:
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/evaluation_test_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def postprocess(
completion_params: CompletionParams,
test_func_name: str,
num_runs: int,
experiment_duration_seconds: float,
):
scores = [
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) for result in all_results
Expand Down Expand Up @@ -68,6 +69,7 @@ def postprocess(
if r.evaluation_result is not None:
r.evaluation_result.agg_score = agg_score
r.evaluation_result.standard_error = standard_error
r.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
active_logger.log(r)

# Optional: print and/or persist a summary artifact for CI
Expand Down
12 changes: 6 additions & 6 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ def add_cost_metrics(row: EvaluationRow) -> None:
# Can't calculate cost without usage stats or model info
if not row.execution_metadata.usage or not row.input_metadata.completion_params:
row.execution_metadata.cost_metrics = CostMetrics(
input_cost_usd=0.0,
output_cost_usd=0.0,
total_cost_usd=0.0,
input_cost=0.0,
output_cost=0.0,
total_cost=0.0,
)
return

Expand Down Expand Up @@ -348,7 +348,7 @@ def add_cost_metrics(row: EvaluationRow) -> None:

# Set all cost metrics on the row
row.execution_metadata.cost_metrics = CostMetrics(
input_cost_usd=input_cost,
output_cost_usd=output_cost,
total_cost_usd=total_cost,
input_cost=input_cost,
output_cost=output_cost,
total_cost=total_cost,
)
30 changes: 15 additions & 15 deletions tests/pytest/test_execution_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def test_single_model_with_provider(self):
add_cost_metrics(row)

assert row.execution_metadata.cost_metrics is not None
assert row.execution_metadata.cost_metrics.input_cost_usd is not None
assert row.execution_metadata.cost_metrics.output_cost_usd is not None
assert row.execution_metadata.cost_metrics.total_cost_usd is not None
assert row.execution_metadata.cost_metrics.input_cost is not None
assert row.execution_metadata.cost_metrics.output_cost is not None
assert row.execution_metadata.cost_metrics.total_cost is not None

@pytest.mark.skip(reason="Revisit when we figure out how to get cost metrics for multi-agent Pydantic.")
def test_pydantic_ai_multi_agent_model_dict(self):
Expand Down Expand Up @@ -54,9 +54,9 @@ def test_pydantic_ai_multi_agent_model_dict(self):
add_cost_metrics(row)

assert row.execution_metadata.cost_metrics is not None
assert row.execution_metadata.cost_metrics.input_cost_usd is not None
assert row.execution_metadata.cost_metrics.output_cost_usd is not None
assert row.execution_metadata.cost_metrics.total_cost_usd is not None
assert row.execution_metadata.cost_metrics.input_cost is not None
assert row.execution_metadata.cost_metrics.output_cost is not None
assert row.execution_metadata.cost_metrics.total_cost is not None

def test_no_usage_stats(self):
"""Test case with no usage statistics."""
Expand All @@ -69,9 +69,9 @@ def test_no_usage_stats(self):
add_cost_metrics(row)

assert row.execution_metadata.cost_metrics is not None
assert row.execution_metadata.cost_metrics.input_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.output_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.total_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.input_cost == 0.0
assert row.execution_metadata.cost_metrics.output_cost == 0.0
assert row.execution_metadata.cost_metrics.total_cost == 0.0

def test_no_completion_params(self):
"""Test case with empty completion parameters."""
Expand All @@ -86,9 +86,9 @@ def test_no_completion_params(self):
add_cost_metrics(row)

assert row.execution_metadata.cost_metrics is not None
assert row.execution_metadata.cost_metrics.input_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.output_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.total_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.input_cost == 0.0
assert row.execution_metadata.cost_metrics.output_cost == 0.0
assert row.execution_metadata.cost_metrics.total_cost == 0.0

def test_zero_tokens(self):
"""Test case with zero token usage."""
Expand All @@ -103,9 +103,9 @@ def test_zero_tokens(self):
add_cost_metrics(row)

assert row.execution_metadata.cost_metrics is not None
assert row.execution_metadata.cost_metrics.input_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.output_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.total_cost_usd == 0.0
assert row.execution_metadata.cost_metrics.input_cost == 0.0
assert row.execution_metadata.cost_metrics.output_cost == 0.0
assert row.execution_metadata.cost_metrics.total_cost == 0.0

def test_provider_mapping_variations(self):
"""Test different provider mappings."""
Expand Down
Loading