|
53 | 53 | from .recovery import RecoveryCheckpoint, RecoveryState |
54 | 54 |
|
55 | 55 |
|
| 56 | +# --------------------------------------------------------------------------- |
| 57 | +# Token Usage Tracking |
| 58 | +# --------------------------------------------------------------------------- |
| 59 | + |
| 60 | + |
@dataclass
class TokenUsageTotals:
    """Running token-count accumulator for a single role or model.

    Attributes:
        calls: Number of LLM responses folded in so far.
        prompt_tokens: Sum of prompt-side tokens across those calls.
        completion_tokens: Sum of completion-side tokens across those calls.
        total_tokens: Sum of reported totals; when a response reports no
            integer total, prompt + completion is used instead.
    """

    calls: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def add(self, resp: LLMResponse) -> None:
        """Fold the token counts of one LLM response into these totals."""
        # Non-int counts (e.g. None) are treated as zero rather than failing.
        prompt = resp.prompt_tokens if isinstance(resp.prompt_tokens, int) else 0
        completion = (
            resp.completion_tokens if isinstance(resp.completion_tokens, int) else 0
        )
        # Fall back to prompt + completion when no explicit total was reported.
        total = (
            resp.total_tokens
            if isinstance(resp.total_tokens, int)
            else prompt + completion
        )
        self.calls += 1
        # Clamp at zero so a bogus negative count can never shrink the totals.
        self.prompt_tokens += max(0, int(prompt))
        self.completion_tokens += max(0, int(completion))
        self.total_tokens += max(0, int(total))
| 79 | + |
| 80 | + |
class _TokenUsageCollector:
    """Collects token usage statistics by role (planner/executor) and model.

    Roles are caller-supplied labels (e.g. "planner", "executor", "replan");
    model names come from the response itself.
    """

    def __init__(self) -> None:
        self._by_role: dict[str, TokenUsageTotals] = {}
        self._by_model: dict[str, TokenUsageTotals] = {}

    def record(self, *, role: str, resp: LLMResponse) -> None:
        """Record token usage from an LLM response under its role and model."""
        self._by_role.setdefault(role, TokenUsageTotals()).add(resp)
        # Blank or missing model names are bucketed under "unknown".
        m = str(resp.model_name or "").strip() or "unknown"
        self._by_model.setdefault(m, TokenUsageTotals()).add(resp)

    def reset(self) -> None:
        """Clear all recorded statistics."""
        self._by_role.clear()
        self._by_model.clear()

    @staticmethod
    def _as_dict(totals: TokenUsageTotals) -> dict[str, int]:
        """Serialize one totals record as a plain dict (defined once, used thrice)."""
        return {
            "calls": totals.calls,
            "prompt_tokens": totals.prompt_tokens,
            "completion_tokens": totals.completion_tokens,
            "total_tokens": totals.total_tokens,
        }

    def summary(self) -> dict[str, Any]:
        """
        Get a summary of all token usage.

        Returns:
            Dictionary with:
            - total: aggregate counts across all calls
            - by_role: breakdown by role (planner, executor, replan)
            - by_model: breakdown by model name
        """
        # Aggregate over the by-role table: every recorded call lands under
        # exactly one role, so this counts all traffic exactly once.
        total = TokenUsageTotals()
        for t in self._by_role.values():
            total.calls += t.calls
            total.prompt_tokens += t.prompt_tokens
            total.completion_tokens += t.completion_tokens
            total.total_tokens += t.total_tokens
        return {
            "total": self._as_dict(total),
            "by_role": {k: self._as_dict(v) for k, v in self._by_role.items()},
            "by_model": {k: self._as_dict(v) for k, v in self._by_model.items()},
        }
| 145 | + |
| 146 | + |
56 | 147 | # --------------------------------------------------------------------------- |
57 | 148 | # IntentHeuristics Protocol |
58 | 149 | # --------------------------------------------------------------------------- |
@@ -729,6 +820,7 @@ class RunOutcome: |
729 | 820 | step_outcomes: list[StepOutcome] = field(default_factory=list) |
730 | 821 | total_duration_ms: int = 0 |
731 | 822 | error: str | None = None |
| 823 | + token_usage: dict[str, Any] | None = None # Token usage summary from get_token_stats() |
732 | 824 |
|
733 | 825 |
|
734 | 826 | # --------------------------------------------------------------------------- |
@@ -1323,6 +1415,37 @@ def __init__( |
1323 | 1415 | # Current automation task (for run-level context) |
1324 | 1416 | self._current_task: AutomationTask | None = None |
1325 | 1417 |
|
| 1418 | + # Token usage tracking |
| 1419 | + self._token_collector = _TokenUsageCollector() |
| 1420 | + |
| 1421 | + def get_token_stats(self) -> dict[str, Any]: |
| 1422 | + """ |
| 1423 | + Get token usage statistics for the agent session. |
| 1424 | +
|
| 1425 | + Returns: |
| 1426 | + Dictionary with: |
| 1427 | + - total: aggregate counts (calls, prompt_tokens, completion_tokens, total_tokens) |
| 1428 | + - by_role: breakdown by role (planner, executor, replan, vision) |
| 1429 | + - by_model: breakdown by model name |
| 1430 | +
|
| 1431 | + Example: |
| 1432 | + >>> stats = agent.get_token_stats() |
| 1433 | + >>> print(f"Total tokens: {stats['total']['total_tokens']}") |
| 1434 | + >>> print(f"Planner tokens: {stats['by_role'].get('planner', {}).get('total_tokens', 0)}") |
| 1435 | + """ |
| 1436 | + return self._token_collector.summary() |
| 1437 | + |
| 1438 | + def reset_token_stats(self) -> None: |
| 1439 | + """Reset token usage statistics to zero.""" |
| 1440 | + self._token_collector.reset() |
| 1441 | + |
| 1442 | + def _record_token_usage(self, role: str, resp: LLMResponse) -> None: |
| 1443 | + """Record token usage from an LLM response.""" |
| 1444 | + try: |
| 1445 | + self._token_collector.record(role=role, resp=resp) |
| 1446 | + except Exception: |
| 1447 | + pass # Don't fail on token tracking errors |
| 1448 | + |
1326 | 1449 | def _format_context(self, snap: Snapshot, goal: str) -> str: |
1327 | 1450 | """ |
1328 | 1451 | Format snapshot for LLM context. |
@@ -2069,6 +2192,7 @@ async def plan( |
2069 | 2192 | temperature=self.config.planner_temperature, |
2070 | 2193 | max_new_tokens=max_tokens, |
2071 | 2194 | ) |
| 2195 | + self._record_token_usage("planner", resp) |
2072 | 2196 | last_output = resp.content |
2073 | 2197 |
|
2074 | 2198 | if self.config.verbose: |
@@ -2169,6 +2293,7 @@ async def replan( |
2169 | 2293 | temperature=self.config.planner_temperature, |
2170 | 2294 | max_new_tokens=1024, |
2171 | 2295 | ) |
| 2296 | + self._record_token_usage("replan", resp) |
2172 | 2297 | last_output = resp.content |
2173 | 2298 |
|
2174 | 2299 | try: |
@@ -2327,6 +2452,7 @@ async def _scroll_to_find_element( |
2327 | 2452 | temperature=self.config.executor_temperature, |
2328 | 2453 | max_new_tokens=self.config.executor_max_tokens, |
2329 | 2454 | ) |
| 2455 | + self._record_token_usage("executor", resp) |
2330 | 2456 | parsed_action, parsed_args = self._parse_action(resp.content) |
2331 | 2457 |
|
2332 | 2458 | if parsed_action == "CLICK" and parsed_args: |
@@ -2398,6 +2524,7 @@ async def _execute_optional_substeps( |
2398 | 2524 | temperature=self.config.executor_temperature, |
2399 | 2525 | max_new_tokens=self.config.executor_max_tokens, |
2400 | 2526 | ) |
| 2527 | + self._record_token_usage("executor", resp) |
2401 | 2528 | parsed_action, parsed_args = self._parse_action(resp.content) |
2402 | 2529 | if parsed_action == "CLICK" and parsed_args: |
2403 | 2530 | element_id = parsed_args[0] |
@@ -2917,6 +3044,7 @@ async def _execute_step( |
2917 | 3044 | temperature=self.config.executor_temperature, |
2918 | 3045 | max_new_tokens=self.config.executor_max_tokens, |
2919 | 3046 | ) |
| 3047 | + self._record_token_usage("executor", resp) |
2920 | 3048 | llm_response = resp.content |
2921 | 3049 |
|
2922 | 3050 | if self.config.verbose: |
@@ -3467,6 +3595,7 @@ async def run( |
3467 | 3595 | step_outcomes=step_outcomes, |
3468 | 3596 | total_duration_ms=int((time.time() - start_time) * 1000), |
3469 | 3597 | error=error, |
| 3598 | + token_usage=self.get_token_stats(), |
3470 | 3599 | ) |
3471 | 3600 |
|
3472 | 3601 | # Emit run end |
|
0 commit comments