Skip to content

Commit 3dbbfed

Browse files
authored
Add experiment timing and remove _usd (#158)
1 parent ca994a0 commit 3dbbfed

File tree

5 files changed

+37
-24
lines changed

5 files changed

+37
-24
lines changed

eval_protocol/models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,11 +519,11 @@ class EvalMetadata(BaseModel):
519519
class CostMetrics(BaseModel):
520520
"""Cost metrics for LLM API calls."""
521521

522-
input_cost_usd: Optional[float] = Field(None, description="Cost in USD for input tokens.")
522+
input_cost: Optional[float] = Field(None, description="Cost in USD for input tokens.")
523523

524-
output_cost_usd: Optional[float] = Field(None, description="Cost in USD for output tokens.")
524+
output_cost: Optional[float] = Field(None, description="Cost in USD for output tokens.")
525525

526-
total_cost_usd: Optional[float] = Field(None, description="Total cost in USD for the API call.")
526+
total_cost: Optional[float] = Field(None, description="Total cost in USD for the API call.")
527527

528528

529529
class ExecutionMetadata(BaseModel):
@@ -560,6 +560,11 @@ class ExecutionMetadata(BaseModel):
560560
description="Processing duration in seconds for this evaluation row. Note that if it gets retried, this will be the duration of the last attempt.",
561561
)
562562

563+
experiment_duration_seconds: Optional[float] = Field(
564+
default=None,
565+
description="Processing duration in seconds for an entire experiment. Note that this includes time it took for retries.",
566+
)
567+
563568

564569
class EvaluationRow(BaseModel):
565570
"""

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import os
44
import sys
5+
import time
56
from collections import defaultdict
67
from typing import Any, Callable
78
from typing_extensions import Unpack
@@ -212,6 +213,7 @@ async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
212213
all_results: list[list[EvaluationRow]] = [[] for _ in range(num_runs)]
213214

214215
experiment_id = generate_id()
216+
experiment_start_time = time.perf_counter()
215217

216218
def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bool) -> None:
217219
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)
@@ -506,6 +508,8 @@ async def execute_run_with_progress(run_idx: int, config):
506508
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
507509
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]
508510

511+
experiment_duration_seconds = time.perf_counter() - experiment_start_time
512+
509513
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
510514
# rollout_id is used to differentiate the result from different completion_params
511515
if mode == "groupwise":
@@ -526,6 +530,7 @@ async def execute_run_with_progress(run_idx: int, config):
526530
original_completion_params[rollout_id], # pyright: ignore[reportArgumentType]
527531
test_func.__name__,
528532
num_runs,
533+
experiment_duration_seconds,
529534
)
530535
else:
531536
postprocess(
@@ -537,6 +542,7 @@ async def execute_run_with_progress(run_idx: int, config):
537542
completion_params, # pyright: ignore[reportArgumentType]
538543
test_func.__name__,
539544
num_runs,
545+
experiment_duration_seconds,
540546
)
541547

542548
except AssertionError:

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def postprocess(
2323
completion_params: CompletionParams,
2424
test_func_name: str,
2525
num_runs: int,
26+
experiment_duration_seconds: float,
2627
):
2728
scores = [
2829
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) for result in all_results
@@ -68,6 +69,7 @@ def postprocess(
6869
if r.evaluation_result is not None:
6970
r.evaluation_result.agg_score = agg_score
7071
r.evaluation_result.standard_error = standard_error
72+
r.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
7173
active_logger.log(r)
7274

7375
# Optional: print and/or persist a summary artifact for CI

eval_protocol/pytest/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,9 @@ def add_cost_metrics(row: EvaluationRow) -> None:
307307
# Can't calculate cost without usage stats or model info
308308
if not row.execution_metadata.usage or not row.input_metadata.completion_params:
309309
row.execution_metadata.cost_metrics = CostMetrics(
310-
input_cost_usd=0.0,
311-
output_cost_usd=0.0,
312-
total_cost_usd=0.0,
310+
input_cost=0.0,
311+
output_cost=0.0,
312+
total_cost=0.0,
313313
)
314314
return
315315

@@ -348,7 +348,7 @@ def add_cost_metrics(row: EvaluationRow) -> None:
348348

349349
# Set all cost metrics on the row
350350
row.execution_metadata.cost_metrics = CostMetrics(
351-
input_cost_usd=input_cost,
352-
output_cost_usd=output_cost,
353-
total_cost_usd=total_cost,
351+
input_cost=input_cost,
352+
output_cost=output_cost,
353+
total_cost=total_cost,
354354
)

tests/pytest/test_execution_metadata.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def test_single_model_with_provider(self):
2323
add_cost_metrics(row)
2424

2525
assert row.execution_metadata.cost_metrics is not None
26-
assert row.execution_metadata.cost_metrics.input_cost_usd is not None
27-
assert row.execution_metadata.cost_metrics.output_cost_usd is not None
28-
assert row.execution_metadata.cost_metrics.total_cost_usd is not None
26+
assert row.execution_metadata.cost_metrics.input_cost is not None
27+
assert row.execution_metadata.cost_metrics.output_cost is not None
28+
assert row.execution_metadata.cost_metrics.total_cost is not None
2929

3030
@pytest.mark.skip(reason="Revisit when we figure out how to get cost metrics for multi-agent Pydantic.")
3131
def test_pydantic_ai_multi_agent_model_dict(self):
@@ -54,9 +54,9 @@ def test_pydantic_ai_multi_agent_model_dict(self):
5454
add_cost_metrics(row)
5555

5656
assert row.execution_metadata.cost_metrics is not None
57-
assert row.execution_metadata.cost_metrics.input_cost_usd is not None
58-
assert row.execution_metadata.cost_metrics.output_cost_usd is not None
59-
assert row.execution_metadata.cost_metrics.total_cost_usd is not None
57+
assert row.execution_metadata.cost_metrics.input_cost is not None
58+
assert row.execution_metadata.cost_metrics.output_cost is not None
59+
assert row.execution_metadata.cost_metrics.total_cost is not None
6060

6161
def test_no_usage_stats(self):
6262
"""Test case with no usage statistics."""
@@ -69,9 +69,9 @@ def test_no_usage_stats(self):
6969
add_cost_metrics(row)
7070

7171
assert row.execution_metadata.cost_metrics is not None
72-
assert row.execution_metadata.cost_metrics.input_cost_usd == 0.0
73-
assert row.execution_metadata.cost_metrics.output_cost_usd == 0.0
74-
assert row.execution_metadata.cost_metrics.total_cost_usd == 0.0
72+
assert row.execution_metadata.cost_metrics.input_cost == 0.0
73+
assert row.execution_metadata.cost_metrics.output_cost == 0.0
74+
assert row.execution_metadata.cost_metrics.total_cost == 0.0
7575

7676
def test_no_completion_params(self):
7777
"""Test case with empty completion parameters."""
@@ -86,9 +86,9 @@ def test_no_completion_params(self):
8686
add_cost_metrics(row)
8787

8888
assert row.execution_metadata.cost_metrics is not None
89-
assert row.execution_metadata.cost_metrics.input_cost_usd == 0.0
90-
assert row.execution_metadata.cost_metrics.output_cost_usd == 0.0
91-
assert row.execution_metadata.cost_metrics.total_cost_usd == 0.0
89+
assert row.execution_metadata.cost_metrics.input_cost == 0.0
90+
assert row.execution_metadata.cost_metrics.output_cost == 0.0
91+
assert row.execution_metadata.cost_metrics.total_cost == 0.0
9292

9393
def test_zero_tokens(self):
9494
"""Test case with zero token usage."""
@@ -103,9 +103,9 @@ def test_zero_tokens(self):
103103
add_cost_metrics(row)
104104

105105
assert row.execution_metadata.cost_metrics is not None
106-
assert row.execution_metadata.cost_metrics.input_cost_usd == 0.0
107-
assert row.execution_metadata.cost_metrics.output_cost_usd == 0.0
108-
assert row.execution_metadata.cost_metrics.total_cost_usd == 0.0
106+
assert row.execution_metadata.cost_metrics.input_cost == 0.0
107+
assert row.execution_metadata.cost_metrics.output_cost == 0.0
108+
assert row.execution_metadata.cost_metrics.total_cost == 0.0
109109

110110
def test_provider_mapping_variations(self):
111111
"""Test different provider mappings."""

0 commit comments

Comments
 (0)