Skip to content

Commit 328f047

Browse files
author
SentienceDEV
committed
token usage summary
1 parent 0aa9f51 commit 328f047

File tree

2 files changed

+323
-0
lines changed

2 files changed

+323
-0
lines changed

predicate/agents/planner_executor_agent.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,97 @@
5353
from .recovery import RecoveryCheckpoint, RecoveryState
5454

5555

56+
# ---------------------------------------------------------------------------
57+
# Token Usage Tracking
58+
# ---------------------------------------------------------------------------
59+
60+
61+
@dataclass
62+
class TokenUsageTotals:
63+
"""Accumulated token counts for a single role or model."""
64+
65+
calls: int = 0
66+
prompt_tokens: int = 0
67+
completion_tokens: int = 0
68+
total_tokens: int = 0
69+
70+
def add(self, resp: LLMResponse) -> None:
71+
"""Add token counts from an LLM response."""
72+
self.calls += 1
73+
pt = resp.prompt_tokens if isinstance(resp.prompt_tokens, int) else 0
74+
ct = resp.completion_tokens if isinstance(resp.completion_tokens, int) else 0
75+
tt = resp.total_tokens if isinstance(resp.total_tokens, int) else (pt + ct)
76+
self.prompt_tokens += max(0, int(pt))
77+
self.completion_tokens += max(0, int(ct))
78+
self.total_tokens += max(0, int(tt))
79+
80+
81+
class _TokenUsageCollector:
    """Collects token usage statistics by role (planner/executor) and model."""

    def __init__(self) -> None:
        # Independent accumulators keyed by role name and by model name.
        self._by_role: dict[str, TokenUsageTotals] = {}
        self._by_model: dict[str, TokenUsageTotals] = {}

    def record(self, *, role: str, resp: "LLMResponse") -> None:
        """Fold one LLM response into the per-role and per-model totals."""
        self._by_role.setdefault(role, TokenUsageTotals()).add(resp)
        # Normalize a blank/absent model name to "unknown" so it still buckets.
        model = str(resp.model_name or "").strip() or "unknown"
        self._by_model.setdefault(model, TokenUsageTotals()).add(resp)

    def reset(self) -> None:
        """Clear all recorded statistics."""
        self._by_role.clear()
        self._by_model.clear()

    def summary(self) -> dict[str, Any]:
        """
        Get a summary of all token usage.

        Returns:
            Dictionary with:
            - total: aggregate counts across all calls
            - by_role: breakdown by role (planner, executor, replan)
            - by_model: breakdown by model name
        """
        def _flatten(t: TokenUsageTotals) -> dict[str, int]:
            # Serialize one accumulator into the plain-int summary shape.
            return {
                "calls": t.calls,
                "prompt_tokens": t.prompt_tokens,
                "completion_tokens": t.completion_tokens,
                "total_tokens": t.total_tokens,
            }

        # The grand total is the sum over roles (every call has exactly one role).
        grand = TokenUsageTotals()
        for totals in self._by_role.values():
            grand.calls += totals.calls
            grand.prompt_tokens += totals.prompt_tokens
            grand.completion_tokens += totals.completion_tokens
            grand.total_tokens += totals.total_tokens

        return {
            "total": _flatten(grand),
            "by_role": {name: _flatten(t) for name, t in self._by_role.items()},
            "by_model": {name: _flatten(t) for name, t in self._by_model.items()},
        }
145+
146+
56147
# ---------------------------------------------------------------------------
57148
# IntentHeuristics Protocol
58149
# ---------------------------------------------------------------------------
@@ -729,6 +820,7 @@ class RunOutcome:
729820
step_outcomes: list[StepOutcome] = field(default_factory=list)
730821
total_duration_ms: int = 0
731822
error: str | None = None
823+
token_usage: dict[str, Any] | None = None # Token usage summary from get_token_stats()
732824

733825

734826
# ---------------------------------------------------------------------------
@@ -1323,6 +1415,37 @@ def __init__(
13231415
# Current automation task (for run-level context)
13241416
self._current_task: AutomationTask | None = None
13251417

1418+
# Token usage tracking
1419+
self._token_collector = _TokenUsageCollector()
1420+
1421+
def get_token_stats(self) -> dict[str, Any]:
    """
    Return token usage statistics for the agent session.

    Returns:
        Dictionary with:
        - total: aggregate counts (calls, prompt_tokens, completion_tokens, total_tokens)
        - by_role: breakdown by role (planner, executor, replan, vision)
        - by_model: breakdown by model name

    Example:
        >>> stats = agent.get_token_stats()
        >>> print(f"Total tokens: {stats['total']['total_tokens']}")
        >>> print(f"Planner tokens: {stats['by_role'].get('planner', {}).get('total_tokens', 0)}")
    """
    return self._token_collector.summary()
1437+
1438+
def reset_token_stats(self) -> None:
    """Clear the accumulated token usage statistics back to zero."""
    self._token_collector.reset()
1441+
1442+
def _record_token_usage(self, role: str, resp: "LLMResponse") -> None:
    """Best-effort: fold one LLM response into the token collector.

    Any exception is swallowed deliberately — token bookkeeping must
    never be able to fail an agent run.
    """
    try:
        self._token_collector.record(role=role, resp=resp)
    except Exception:
        pass  # Don't fail on token tracking errors
1448+
13261449
def _format_context(self, snap: Snapshot, goal: str) -> str:
13271450
"""
13281451
Format snapshot for LLM context.
@@ -2069,6 +2192,7 @@ async def plan(
20692192
temperature=self.config.planner_temperature,
20702193
max_new_tokens=max_tokens,
20712194
)
2195+
self._record_token_usage("planner", resp)
20722196
last_output = resp.content
20732197

20742198
if self.config.verbose:
@@ -2169,6 +2293,7 @@ async def replan(
21692293
temperature=self.config.planner_temperature,
21702294
max_new_tokens=1024,
21712295
)
2296+
self._record_token_usage("replan", resp)
21722297
last_output = resp.content
21732298

21742299
try:
@@ -2327,6 +2452,7 @@ async def _scroll_to_find_element(
23272452
temperature=self.config.executor_temperature,
23282453
max_new_tokens=self.config.executor_max_tokens,
23292454
)
2455+
self._record_token_usage("executor", resp)
23302456
parsed_action, parsed_args = self._parse_action(resp.content)
23312457

23322458
if parsed_action == "CLICK" and parsed_args:
@@ -2398,6 +2524,7 @@ async def _execute_optional_substeps(
23982524
temperature=self.config.executor_temperature,
23992525
max_new_tokens=self.config.executor_max_tokens,
24002526
)
2527+
self._record_token_usage("executor", resp)
24012528
parsed_action, parsed_args = self._parse_action(resp.content)
24022529
if parsed_action == "CLICK" and parsed_args:
24032530
element_id = parsed_args[0]
@@ -2917,6 +3044,7 @@ async def _execute_step(
29173044
temperature=self.config.executor_temperature,
29183045
max_new_tokens=self.config.executor_max_tokens,
29193046
)
3047+
self._record_token_usage("executor", resp)
29203048
llm_response = resp.content
29213049

29223050
if self.config.verbose:
@@ -3467,6 +3595,7 @@ async def run(
34673595
step_outcomes=step_outcomes,
34683596
total_duration_ms=int((time.time() - start_time) * 1000),
34693597
error=error,
3598+
token_usage=self.get_token_stats(),
34703599
)
34713600

34723601
# Emit run end

tests/unit/test_planner_executor_agent.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,3 +1351,197 @@ def test_modal_config_has_required_fields_for_drawer_dismissal(self) -> None:
13511351
assert hasattr(config, "role_filter")
13521352
assert hasattr(config, "max_attempts")
13531353
assert hasattr(config, "min_new_elements")
1354+
1355+
1356+
# ---------------------------------------------------------------------------
1357+
# Test Token Usage Tracking
1358+
# ---------------------------------------------------------------------------
1359+
1360+
1361+
class TestTokenUsageTracking:
    """Tests for token usage tracking in PlannerExecutorAgent."""

    def test_token_usage_totals_add(self) -> None:
        """TokenUsageTotals should accumulate tokens correctly."""
        from predicate.agents.planner_executor_agent import TokenUsageTotals
        from predicate.llm_provider import LLMResponse

        totals = TokenUsageTotals()
        assert (
            totals.calls,
            totals.prompt_tokens,
            totals.completion_tokens,
            totals.total_tokens,
        ) == (0, 0, 0, 0)

        # First response: counts land directly in the accumulator.
        totals.add(
            LLMResponse(
                content="test",
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150,
            )
        )
        assert (
            totals.calls,
            totals.prompt_tokens,
            totals.completion_tokens,
            totals.total_tokens,
        ) == (1, 100, 50, 150)

        # Second response: counts add onto the first.
        totals.add(
            LLMResponse(
                content="test2",
                prompt_tokens=200,
                completion_tokens=75,
                total_tokens=275,
            )
        )
        assert (
            totals.calls,
            totals.prompt_tokens,
            totals.completion_tokens,
            totals.total_tokens,
        ) == (2, 300, 125, 425)

    def test_token_usage_totals_handles_none_values(self) -> None:
        """TokenUsageTotals should handle None token counts gracefully."""
        from predicate.agents.planner_executor_agent import TokenUsageTotals
        from predicate.llm_provider import LLMResponse

        totals = TokenUsageTotals()
        totals.add(
            LLMResponse(
                content="test",
                prompt_tokens=None,
                completion_tokens=None,
                total_tokens=None,
            )
        )
        # The call is counted, but every unknown count is treated as zero.
        assert (
            totals.calls,
            totals.prompt_tokens,
            totals.completion_tokens,
            totals.total_tokens,
        ) == (1, 0, 0, 0)

    def test_token_usage_collector_records_by_role(self) -> None:
        """_TokenUsageCollector should track tokens by role."""
        from predicate.agents.planner_executor_agent import _TokenUsageCollector
        from predicate.llm_provider import LLMResponse

        collector = _TokenUsageCollector()
        collector.record(
            role="planner",
            resp=LLMResponse(
                content="plan",
                prompt_tokens=500,
                completion_tokens=200,
                total_tokens=700,
                model_name="gpt-4o",
            ),
        )
        collector.record(
            role="executor",
            resp=LLMResponse(
                content="action",
                prompt_tokens=100,
                completion_tokens=20,
                total_tokens=120,
                model_name="gpt-4o-mini",
            ),
        )

        summary = collector.summary()

        # Grand total spans both calls.
        total = summary["total"]
        assert total["calls"] == 2
        assert total["prompt_tokens"] == 600
        assert total["completion_tokens"] == 220
        assert total["total_tokens"] == 820

        # Each role keeps its own breakdown.
        by_role = summary["by_role"]
        assert "planner" in by_role
        assert by_role["planner"]["calls"] == 1
        assert by_role["planner"]["total_tokens"] == 700
        assert "executor" in by_role
        assert by_role["executor"]["calls"] == 1
        assert by_role["executor"]["total_tokens"] == 120

    def test_token_usage_collector_records_by_model(self) -> None:
        """_TokenUsageCollector should track tokens by model name."""
        from predicate.agents.planner_executor_agent import _TokenUsageCollector
        from predicate.llm_provider import LLMResponse

        collector = _TokenUsageCollector()
        collector.record(
            role="planner",
            resp=LLMResponse(
                content="test",
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150,
                model_name="gpt-4o",
            ),
        )
        collector.record(
            role="executor",
            resp=LLMResponse(
                content="test",
                prompt_tokens=50,
                completion_tokens=25,
                total_tokens=75,
                model_name="gpt-4o-mini",
            ),
        )

        # Each model name gets its own bucket.
        by_model = collector.summary()["by_model"]
        assert "gpt-4o" in by_model
        assert by_model["gpt-4o"]["total_tokens"] == 150
        assert "gpt-4o-mini" in by_model
        assert by_model["gpt-4o-mini"]["total_tokens"] == 75

    def test_token_usage_collector_reset(self) -> None:
        """_TokenUsageCollector reset should clear all data."""
        from predicate.agents.planner_executor_agent import _TokenUsageCollector
        from predicate.llm_provider import LLMResponse

        collector = _TokenUsageCollector()
        collector.record(
            role="planner",
            resp=LLMResponse(
                content="test",
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150,
            ),
        )
        assert collector.summary()["total"]["calls"] == 1

        collector.reset()
        cleared = collector.summary()
        assert cleared["total"]["calls"] == 0
        assert cleared["total"]["total_tokens"] == 0
        assert cleared["by_role"] == {}
        assert cleared["by_model"] == {}

    def test_run_outcome_has_token_usage_field(self) -> None:
        """RunOutcome should have token_usage field."""
        from predicate.agents.planner_executor_agent import RunOutcome

        base_kwargs = dict(
            run_id="test-run",
            task="test task",
            success=True,
            steps_completed=3,
            steps_total=3,
            replans_used=0,
        )

        # The field defaults to None when not supplied.
        assert RunOutcome(**base_kwargs).token_usage is None

        # And it round-trips an arbitrary usage summary dict.
        usage = {
            "total": {"calls": 5, "total_tokens": 1000},
            "by_role": {"planner": {"calls": 2, "total_tokens": 700}},
            "by_model": {"gpt-4o": {"calls": 2, "total_tokens": 700}},
        }
        outcome_with_tokens = RunOutcome(**base_kwargs, token_usage=usage)
        assert outcome_with_tokens.token_usage is not None
        assert outcome_with_tokens.token_usage["total"]["total_tokens"] == 1000

0 commit comments

Comments
 (0)