Skip to content

Commit 649c6d1

Browse files
authored
Cost Metrics added (#156)
* Cost Metrics added
* remove pydantic usage tracking for now
1 parent 7c10e74 commit 649c6d1

File tree

8 files changed

+141
-24
lines changed

8 files changed

+141
-24
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,10 @@ async def _execute_with_semaphore(idx):
124124

125125
evaluation_row.messages = messages
126126
evaluation_row.tools = shared_tool_schema
127-
# Some OpenAI SDK versions type CompletionUsage as a TypedDict; construct via cast to avoid ctor mismatches
128-
evaluation_row.usage = cast(
129-
CompletionUsage,
130-
{
131-
"prompt_tokens": trajectory.usage.get("prompt_tokens", 0),
132-
"completion_tokens": trajectory.usage.get("completion_tokens", 0),
133-
"total_tokens": trajectory.usage.get("total_tokens", 0),
134-
},
127+
evaluation_row.execution_metadata.usage = CompletionUsage(
128+
prompt_tokens=trajectory.usage.get("prompt_tokens", 0),
129+
completion_tokens=trajectory.usage.get("completion_tokens", 0),
130+
total_tokens=trajectory.usage.get("total_tokens", 0),
135131
)
136132
evaluation_row.input_metadata.completion_params = {
137133
"model": policy.model_id,

eval_protocol/models.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,16 @@ class EvalMetadata(BaseModel):
516516
passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold")
517517

518518

519+
class CostMetrics(BaseModel):
520+
"""Cost metrics for LLM API calls."""
521+
522+
input_cost_usd: Optional[float] = Field(None, description="Cost in USD for input tokens.")
523+
524+
output_cost_usd: Optional[float] = Field(None, description="Cost in USD for output tokens.")
525+
526+
total_cost_usd: Optional[float] = Field(None, description="Total cost in USD for the API call.")
527+
528+
519529
class ExecutionMetadata(BaseModel):
520530
"""Metadata about the execution of the evaluation."""
521531

@@ -535,10 +545,16 @@ class ExecutionMetadata(BaseModel):
535545
)
536546

537547
run_id: Optional[str] = Field(
538-
None,
548+
default=None,
539549
description=("The ID of the run that this row belongs to."),
540550
)
541551

552+
usage: Optional[CompletionUsage] = Field(
553+
default=None, description="Token usage statistics from LLM calls during execution."
554+
)
555+
556+
cost_metrics: Optional[CostMetrics] = Field(default=None, description="Cost breakdown for LLM API calls.")
557+
542558

543559
class EvaluationRow(BaseModel):
544560
"""
@@ -586,11 +602,6 @@ class EvaluationRow(BaseModel):
586602
description="Metadata about the execution of the evaluation.",
587603
)
588604

589-
# LLM usage statistics
590-
usage: Optional[CompletionUsage] = Field(
591-
default=None, description="Token usage statistics from LLM calls during execution."
592-
)
593-
594605
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the row was created.")
595606

596607
eval_metadata: Optional[EvalMetadata] = Field(

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
1414
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
1515
from eval_protocol.models import EvaluationRow, Message, ChatCompletionContentPartTextParam
16+
from openai.types import CompletionUsage
1617
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1718
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
1819
from pydantic import BaseModel
@@ -38,6 +39,11 @@ def __init__(self, model: str, row: EvaluationRow, config_path: str, logger: Dat
3839
self._policy = LiteLLMPolicy(model_id=model)
3940
self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
4041
self.logger: DatasetLogger = logger
42+
self.usage = {
43+
"prompt_tokens": 0,
44+
"completion_tokens": 0,
45+
"total_tokens": 0,
46+
}
4147

4248
async def setup(self):
4349
if self.mcp_client:
@@ -166,6 +172,11 @@ async def _call_model(self, messages: list[Message], tools: Optional[List[dict[s
166172
payload_tools.append({"type": tool_type, "function": {"name": name, "parameters": params_payload}})
167173

168174
response = await self._policy._make_llm_call(messages=messages_payload, tools=payload_tools)
175+
176+
self.usage["prompt_tokens"] += response["usage"]["prompt_tokens"]
177+
self.usage["completion_tokens"] += response["usage"]["completion_tokens"]
178+
self.usage["total_tokens"] += response["usage"]["total_tokens"]
179+
169180
# Coerce content to a string to align with our Message model type expectations
170181
raw_content = response["choices"][0]["message"].get("content")
171182
if isinstance(raw_content, list):
@@ -238,6 +249,13 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
238249
try:
239250
await agent.setup()
240251
await agent.call_agent()
252+
253+
agent.evaluation_row.execution_metadata.usage = CompletionUsage(
254+
prompt_tokens=agent.usage["prompt_tokens"],
255+
completion_tokens=agent.usage["completion_tokens"],
256+
total_tokens=agent.usage["total_tokens"],
257+
)
258+
241259
return agent.evaluation_row
242260
finally:
243261
if agent.mcp_client:

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class BaseMessage: # type: ignore
1010

1111

1212
from eval_protocol.models import EvaluationRow, Message
13+
from openai.types import CompletionUsage
1314
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1415
from eval_protocol.pytest.types import RolloutProcessorConfig
1516

@@ -86,6 +87,28 @@ async def _invoke_wrapper(payload):
8687
else:
8788
result_messages = getattr(result_obj, "messages", [])
8889

90+
# TODO: i didn't see a langgraph example so couldn't fully test this. should uncomment and test when we have example ready.
91+
# total_input_tokens = 0
92+
# total_output_tokens = 0
93+
# total_tokens = 0
94+
95+
# for msg in result_messages:
96+
# if isinstance(msg, BaseMessage):
97+
# usage = getattr(msg, 'response_metadata', {})
98+
# else:
99+
# usage = msg.get("response_metadata", {})
100+
101+
# if usage:
102+
# total_input_tokens += usage.get("prompt_tokens", 0)
103+
# total_output_tokens += usage.get("completion_tokens", 0)
104+
# total_tokens += usage.get("total_tokens", 0)
105+
106+
# row.execution_metadata.usage = CompletionUsage(
107+
# prompt_tokens=total_input_tokens,
108+
# completion_tokens=total_output_tokens,
109+
# total_tokens=total_tokens,
110+
# )
111+
89112
def _serialize_message(msg: BaseMessage) -> Message:
90113
# Prefer SDK-level serializer
91114
try:

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pydantic_ai.models import Model
77
from typing_extensions import override
88
from eval_protocol.models import EvaluationRow, Message
9+
from openai.types import CompletionUsage
910
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1011
from eval_protocol.pytest.types import RolloutProcessorConfig
1112
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
@@ -89,6 +90,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
8990
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
9091
)
9192
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
93+
94+
# TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
95+
# usage_info = response.usage()
96+
# row.execution_metadata.usage = CompletionUsage(
97+
# prompt_tokens=usage_info.request_tokens or 0,
98+
# completion_tokens=usage_info.response_tokens or 0,
99+
# total_tokens=usage_info.total_tokens or 0,
100+
# )
101+
92102
return row
93103

94104
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from eval_protocol.dataset_logger import default_logger
1111
from eval_protocol.models import EvaluationRow, Message
12+
from openai.types import CompletionUsage
1213
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1314
from eval_protocol.pytest.types import RolloutProcessorConfig
1415

@@ -108,6 +109,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
108109
)
109110
]
110111

112+
row.execution_metadata.usage = CompletionUsage(
113+
prompt_tokens=response.usage.prompt_tokens,
114+
completion_tokens=response.usage.completion_tokens,
115+
total_tokens=response.usage.total_tokens,
116+
)
117+
111118
row.messages = messages
112119
default_logger.log(row)
113120
return row

eval_protocol/pytest/evaluation_test.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
from eval_protocol.pytest.utils import (
5151
AggregationMethod,
52+
add_cost_metrics,
5253
log_eval_status_and_rows,
5354
parse_ep_completion_params,
5455
parse_ep_max_concurrent_rollouts,
@@ -430,25 +431,22 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
430431
processed_dataset=input_dataset, # pyright: ignore[reportUnknownArgumentType]
431432
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
432433
)
433-
if results is None: # pyright: ignore[reportUnnecessaryComparison]
434-
raise ValueError(
435-
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
436-
)
437-
if not isinstance(results, list):
434+
if (
435+
results is None
436+
or not isinstance(results, list)
437+
or not all(isinstance(r, EvaluationRow) for r in results)
438+
):
438439
raise ValueError(
439440
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
440441
)
441442
if not results:
442443
raise ValueError(
443444
f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test."
444445
)
445-
if not all(isinstance(r, EvaluationRow) for r in results): # pyright: ignore[reportUnnecessaryIsInstance]
446-
raise ValueError(
447-
f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
448-
)
449446
all_results[run_idx] = results
450447

451448
for r in results:
449+
add_cost_metrics(r)
452450
if r.eval_metadata is not None:
453451
if r.rollout_status.is_error():
454452
r.eval_metadata.status = Status.error(

eval_protocol/pytest/utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
from dataclasses import replace
77
from typing import Any, Literal
88

9+
from litellm.cost_calculator import cost_per_token
910
from tqdm import tqdm
1011

1112
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1213
from eval_protocol.models import (
14+
CostMetrics,
15+
CompletionParams,
1316
EvalMetadata,
1417
EvaluationRow,
1518
EvaluationThreshold,
1619
EvaluationThresholdDict,
1720
Status,
18-
CompletionParams,
1921
)
2022
from eval_protocol.pytest.rollout_processor import RolloutProcessor
2123
from eval_protocol.pytest.types import (
@@ -298,3 +300,55 @@ def extract_effort_tag(params: dict) -> str | None: # pyright: ignore[reportMis
298300
except Exception:
299301
return None
300302
return None
303+
304+
305+
def add_cost_metrics(row: EvaluationRow) -> None:
306+
"""Calculate and add cost metrics for an EvaluationRow based on its usage data."""
307+
# Can't calculate cost without usage stats or model info
308+
if not row.execution_metadata.usage or not row.input_metadata.completion_params:
309+
row.execution_metadata.cost_metrics = CostMetrics(
310+
input_cost_usd=0.0,
311+
output_cost_usd=0.0,
312+
total_cost_usd=0.0,
313+
)
314+
return
315+
316+
model = row.input_metadata.completion_params.get("model", "unknown")
317+
provider = row.input_metadata.completion_params.get("provider")
318+
319+
# Pydantic AI mapping to LiteLLM format
320+
# TODO: make more generic for other frameworks too.
321+
provider_mapping = {
322+
"fireworks": "fireworks_ai",
323+
"together": "together_ai",
324+
"openai": "", # No prefix needed
325+
"azure": "azure",
326+
"deepseek": "deepseek",
327+
"openrouter": "openrouter",
328+
"grok": "grok",
329+
"github": "github",
330+
"heroku": "heroku",
331+
}
332+
333+
if provider and provider in provider_mapping:
334+
litellm_prefix = provider_mapping[provider]
335+
model_id = f"{litellm_prefix}/{model}" if litellm_prefix else model
336+
else:
337+
model_id = model
338+
339+
usage = row.execution_metadata.usage
340+
341+
input_tokens = usage.prompt_tokens or 0
342+
output_tokens = usage.completion_tokens or 0
343+
344+
input_cost, output_cost = cost_per_token(
345+
model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens
346+
)
347+
total_cost = input_cost + output_cost
348+
349+
# Set all cost metrics on the row
350+
row.execution_metadata.cost_metrics = CostMetrics(
351+
input_cost_usd=input_cost,
352+
output_cost_usd=output_cost,
353+
total_cost_usd=total_cost,
354+
)

0 commit comments

Comments (0)