From 2cf06ce758534ed1b99985069f686ee2f4cdb4cb Mon Sep 17 00:00:00 2001 From: Daniele Martinoli Date: Wed, 8 Apr 2026 14:43:34 +0200 Subject: [PATCH] feat: add per-invocation tool budget configuration Introduces the `MAX_TOOL_CALLS_PER_INVOCATION` environment variable to limit the number of MCP tool executions per agent run. This feature is enforced in-process and is not shared across replicas. Updates to documentation and configuration files reflect this new capability, along with implementation in the `UsageTrackingPlugin` to manage tool call limits during execution. Signed-off-by: Daniele Martinoli --- .env.example | 4 ++ CLAUDE.md | 3 + deploy/cloudrun/service.yaml | 4 ++ deploy/podman/lightspeed-agent-configmap.yaml | 2 + docs/configuration.md | 6 +- docs/metering.md | 35 +++++++++- docs/rate-limiting.md | 14 ++-- src/lightspeed_agent/api/a2a/usage_plugin.py | 49 ++++++++++++++ src/lightspeed_agent/config/settings.py | 8 +++ tests/test_usage_plugin.py | 66 ++++++++++++++++++- 10 files changed, 180 insertions(+), 11 deletions(-) diff --git a/.env.example b/.env.example index a286317b..e3ebb262 100644 --- a/.env.example +++ b/.env.example @@ -143,6 +143,10 @@ RATE_LIMIT_REQUESTS_PER_HOUR=1000 # How often to report usage to Google Cloud (in seconds) USAGE_REPORT_INTERVAL_SECONDS=3600 +# Max MCP tool executions per ADK invocation (single agent run). 0 = disabled. +# In-memory per process; not shared across Cloud Run replicas (see docs/metering.md). 
+MAX_TOOL_CALLS_PER_INVOCATION=0 + # ----------------------------------------------------------------------------- # Logging Configuration # ----------------------------------------------------------------------------- diff --git a/CLAUDE.md b/CLAUDE.md index 1c1a0afe..6d0f2d03 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -173,6 +173,9 @@ All configuration is via environment variables, managed through Pydantic setting **Agent:** - `AGENT_HOST`, `AGENT_PORT` +**Usage / metering:** +- `MAX_TOOL_CALLS_PER_INVOCATION` (per-run MCP tool cap; `0` disables; in-process only) + **Service Control:** - `SERVICE_CONTROL_SERVICE_NAME`, `SERVICE_CONTROL_ENABLED` diff --git a/deploy/cloudrun/service.yaml b/deploy/cloudrun/service.yaml index d5edf15d..b5c7515e 100644 --- a/deploy/cloudrun/service.yaml +++ b/deploy/cloudrun/service.yaml @@ -154,6 +154,10 @@ spec: value: "60" - name: RATE_LIMIT_REQUESTS_PER_HOUR value: "1000" + # Optional: max MCP tool executions per agent run (0 = disabled, i.e. no limit). + # Counter is in-memory per Cloud Run instance; see docs/metering.md. + - name: MAX_TOOL_CALLS_PER_INVOCATION + value: "0" # Health checks startupProbe: httpGet: diff --git a/deploy/podman/lightspeed-agent-configmap.yaml b/deploy/podman/lightspeed-agent-configmap.yaml index 79e998dd..d74d1f19 100644 --- a/deploy/podman/lightspeed-agent-configmap.yaml +++ b/deploy/podman/lightspeed-agent-configmap.yaml @@ -68,6 +68,8 @@ data: # Usage Reporting Configuration USAGE_REPORT_INTERVAL_SECONDS: "3600" + # Per-invocation MCP tool cap (0 = disabled). In-memory per agent process. 
+ MAX_TOOL_CALLS_PER_INVOCATION: "0" # Google Cloud Service Control (for usage reporting to Marketplace) # Set SERVICE_CONTROL_SERVICE_NAME in secrets for production diff --git a/docs/configuration.md b/docs/configuration.md index 8c224379..b3ccd270 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -225,10 +225,14 @@ GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json ### Usage Tracking -Usage tracking is built into the agent via the ADK plugin system. No configuration required for basic tracking. +Usage tracking is built into the agent via the ADK plugin system. No database configuration is required for basic tracking. | Variable | Default | Description | |----------|---------|-------------| +| `MAX_TOOL_CALLS_PER_INVOCATION` | `0` | Cap MCP tool executions per agent run (`0` disables). Per-process memory only; see [metering.md](metering.md#per-invocation-tool-budget). | +| `METERING_STALE_CLAIM_MINUTES` | `15` | Release usage rows claimed longer than this (worker crash recovery) | +| `METERING_BACKFILL_MAX_AGE_HOURS` | `168` | Backfill only periods within this window (default 7 days) | +| `METERING_BACKFILL_LIMIT_PER_RUN` | `20` | Max unreported periods to process per backfill run | | `LOG_LEVEL` | `INFO` | Set to `DEBUG` to see detailed usage logs | See [Usage Tracking and Metering](metering.md) for details on the plugin system and how to extend it. 
diff --git a/docs/metering.md b/docs/metering.md index 17a0d5a4..bb07db1c 100644 --- a/docs/metering.md +++ b/docs/metering.md @@ -35,7 +35,11 @@ All usage tracking is handled by the `UsageTrackingPlugin` in `src/lightspeed_ag │ │ │ │ │ │ after_model_callback ───► Extract token counts from response │ │ │ │ │ │ -│ │ after_tool_callback ────► Increment tool call counter │ │ +│ │ before_tool_callback ───► Optional per-run tool budget │ │ +│ │ │ │ +│ │ after_tool_callback ────► Increment tool call counter (DB) │ │ +│ │ │ │ +│ │ after_run_callback ─────► Clear in-memory tool budget state │ │ │ │ │ │ │ └─────────────────────────────────────────────────────────────────┘ │ │ │ │ @@ -89,7 +93,30 @@ app = App( ## UsageTrackingPlugin Implementation -The `UsageTrackingPlugin` (`src/lightspeed_agent/api/a2a/usage_plugin.py`) implements three callbacks: +The `UsageTrackingPlugin` (`src/lightspeed_agent/api/a2a/usage_plugin.py`) implements the metering callbacks below, plus optional **per-invocation tool budgeting** (`before_tool_callback` / `after_run_callback`). + +## Per-invocation tool budget + +When `MAX_TOOL_CALLS_PER_INVOCATION` is greater than zero, the plugin counts how many tools are **started** in a single ADK invocation (same `invocation_id` on `ToolContext`). Before each tool runs, if the count is already at the limit, `before_tool_callback` returns a short-circuit **dict** (ADK skips executing the tool and does not run `after_tool_callback` for that call). Allowed tools increment the in-memory counter; when the run ends, `after_run_callback` removes the counter for that invocation so memory does not grow unbounded across requests. + +- **Metering interaction:** Blocked tools never reach `after_tool_callback`, so they are **not** persisted as `tool_calls` in `UsageRepository`. +- **Relation to HTTP rate limits:** HTTP rate limiting (see [Rate limiting](rate-limiting.md)) bounds **incoming A2A requests**. 
The tool budget bounds the **depth of a single agent run** (tool–model loops). They address different abuse patterns. + +### Multi-instance caveat (initial limitation) + +The budget uses an **in-memory** `dict` guarded by `asyncio.Lock` inside each agent process. It is **not** shared across Cloud Run instances, Podman replicas, or multi-worker Uvicorn processes. A tenant that spreads traffic across N instances effectively gets up to **N × limit** tool starts per logical “burst” if requests land on different instances, and **invocation_id** is only consistent within the process that handled the run. + +**Operational mitigations today:** keep `maxScale` predictable, use **session affinity** if your platform supports it for long-lived runs, or set a conservative limit accepting per-instance enforcement. + +### Proposed enhancements (persistent / shared counters) + +For **cross-replica, accurate** per-run caps, a future iteration could: + +1. **Redis (recommended):** Keep a counter, incremented with `INCR` (or `INCRBY`), under a key such as `lightspeed:tool_budget:{invocation_id}` with a **TTL** slightly above the max agent timeout (e.g. Cloud Run `timeoutSeconds`). `before_tool_callback` would `INCR` the key and compare the result to `MAX_TOOL_CALLS_PER_INVOCATION`; optionally use a small Lua script for atomic check-and-set. Reuse the same Redis deployment as rate limiting, with a dedicated key prefix. +2. **Database:** A row keyed by `invocation_id` or `(order_id, run_id)` with a monotonic `tool_starts` column and optimistic locking; heavier than Redis but consistent with existing PostgreSQL. +3. **Runner hints:** If the platform exposes a stable **run identifier** in headers or metadata, prefer that over ad-hoc IDs so budgets survive internal retries more predictably. + +Until one of the above is implemented, treat `MAX_TOOL_CALLS_PER_INVOCATION` as a **per-process safety rail**, not a strict global quota. 
### Request Counting @@ -149,7 +176,7 @@ async def after_tool_callback( return None # Don't modify the result ``` -This callback fires after every MCP tool invocation, persisting a tool-call increment for the current order. +This callback fires after every MCP tool invocation, persisting a tool-call increment for the current order (skipped when `before_tool_callback` blocks the tool). ## Storage: UsageRepository @@ -302,11 +329,13 @@ app = App( - **Retry on failure**: Failed reports are queued and retried with configurable max attempts; rows are released on failure for re-claim on retry - **Stale claim recovery**: Rows claimed by a crashed worker (never marked or released) are released at the start of each hourly run; threshold configurable via `METERING_STALE_CLAIM_MINUTES` - **Automatic backfill**: Unreported periods (from scheduler downtime or stale releases) are reported on each hourly run; configurable via `METERING_BACKFILL_MAX_AGE_HOURS` (default 7 days) and `METERING_BACKFILL_LIMIT_PER_RUN` (default 20) +- **Optional per-run tool budget**: Configurable via `MAX_TOOL_CALLS_PER_INVOCATION`; in-process enforcement with documented multi-instance limits and a Redis/DB upgrade path (see [Per-invocation tool budget](#per-invocation-tool-budget)) ## Configuration | Variable | Default | Description | |----------|---------|-------------| +| `MAX_TOOL_CALLS_PER_INVOCATION` | 0 | Max tool starts per ADK invocation; `0` disables (in-process only; see [Per-invocation tool budget](#per-invocation-tool-budget)) | | `METERING_STALE_CLAIM_MINUTES` | 15 | Release rows claimed longer than this (worker crash recovery) | | `METERING_BACKFILL_MAX_AGE_HOURS` | 168 | Backfill only periods within this many hours (7 days) | | `METERING_BACKFILL_LIMIT_PER_RUN` | 20 | Max unreported periods to process per backfill run | diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md index 5794ea49..5e70b307 100644 --- a/docs/rate-limiting.md +++ b/docs/rate-limiting.md @@ -157,16 +157,18 @@ 
With authentication, rate limits apply per `order_id` and `user_id` (from the to ## Rate Limiting vs Usage Tracking -The agent has two separate systems for managing API usage: +The agent combines **HTTP-level throttling**, **usage metering**, and an optional **per-run tool budget**. They are separate layers: -| System | Purpose | Mechanism | -|--------|---------|-----------| -| **Rate Limiting** | Prevent abuse | FastAPI middleware, rejects excess requests | -| **Usage Tracking** | Monitor consumption | ADK plugin, counts tokens and tool calls | +| Layer | What it limits | When it runs | Shared across replicas? | +|-------|----------------|--------------|-------------------------| +| **HTTP rate limiting** | Incoming A2A POSTs per principal (minute/hour windows) | FastAPI middleware before the ADK runner | Yes, when all instances use the same Redis | +| **Usage tracking (DB)** | Requests, tokens, completed tool calls for billing/analytics | ADK plugin (`UsageTrackingPlugin`) | Yes, all instances write to the same database | +| **Per-invocation tool budget** | How many tools may **start** in one agent run | ADK `before_tool_callback` | **No (today):** in-memory per process; see [Per-invocation tool budget](metering.md#per-invocation-tool-budget) and [proposed shared counters](metering.md#proposed-enhancements-persistent--shared-counters) | -Rate limiting happens **before** the request is processed (at the middleware layer), while usage tracking happens **during** request processing (via ADK plugin callbacks). +**Comparison in plain terms:** Redis rate limits stop a client from opening too many **HTTP conversations**. The tool budget stops a **single** conversation from hammering MCP with an unbounded tool–model loop. Metering records what actually ran for reporting, including tools that completed (blocked tools are not counted in `after_tool_callback`). ## Notes - Rate limits are enforced across replicas as long as they share the same Redis instance. 
- The service verifies Redis connectivity at startup and fails fast when Redis is unavailable. +- Tool budgets are **not** distributed across replicas until a shared store (for example Redis with TTL keyed by `invocation_id`) is implemented; see [metering.md](metering.md#proposed-enhancements-persistent--shared-counters). diff --git a/src/lightspeed_agent/api/a2a/usage_plugin.py b/src/lightspeed_agent/api/a2a/usage_plugin.py index 58930cd5..53ed1060 100644 --- a/src/lightspeed_agent/api/a2a/usage_plugin.py +++ b/src/lightspeed_agent/api/a2a/usage_plugin.py @@ -1,5 +1,6 @@ """Usage tracking plugin with per-order metrics.""" +import asyncio import logging from typing import Any @@ -11,10 +12,13 @@ from google.adk.tools.tool_context import ToolContext from lightspeed_agent.auth.middleware import get_request_order_id +from lightspeed_agent.config import get_settings from lightspeed_agent.metering import get_usage_repository logger = logging.getLogger(__name__) +_TOOL_LIMIT_CODE = "usage_tool_call_limit_exceeded" + def _resolve_order_id() -> str | None: """Resolve the current request order_id from request context.""" @@ -27,6 +31,8 @@ class UsageTrackingPlugin(BasePlugin): def __init__(self) -> None: super().__init__(name="usage_tracking") self._usage_repo = get_usage_repository() + self._tool_budget_lock = asyncio.Lock() + self._tool_calls_by_invocation: dict[str, int] = {} async def before_run_callback(self, *, invocation_context: InvocationContext) -> None: """Track request count at start of each run.""" @@ -69,6 +75,49 @@ async def after_model_callback( return None # Don't modify the response + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> dict[str, Any] | None: + """Enforce optional per-invocation tool budget before MCP execution.""" + limit = get_settings().max_tool_calls_per_invocation + if limit <= 0: + return None + + inv_id = tool_context.invocation_id + async with 
self._tool_budget_lock: + current = self._tool_calls_by_invocation.get(inv_id, 0) + if current >= limit: + tool_name = getattr(tool, "name", type(tool).__name__) + logger.warning( + "Tool call blocked: invocation %s already used %s tool call(s) " + "(max_tool_calls_per_invocation=%s); attempted tool=%s", + inv_id, + current, + limit, + tool_name, + ) + return { + "error": ( + f"Exceeded maximum of {limit} tool call(s) for this agent run. " + "Start a new message or ask your administrator to raise " + "MAX_TOOL_CALLS_PER_INVOCATION if appropriate." + ), + "code": _TOOL_LIMIT_CODE, + } + self._tool_calls_by_invocation[inv_id] = current + 1 + return None + + async def after_run_callback(self, *, invocation_context: InvocationContext) -> None: + """Drop per-invocation tool budget state when the run completes.""" + inv_id = invocation_context.invocation_id + async with self._tool_budget_lock: + self._tool_calls_by_invocation.pop(inv_id, None) + return None + async def after_tool_callback( self, *, diff --git a/src/lightspeed_agent/config/settings.py b/src/lightspeed_agent/config/settings.py index c10debe6..2121de79 100644 --- a/src/lightspeed_agent/config/settings.py +++ b/src/lightspeed_agent/config/settings.py @@ -143,6 +143,14 @@ class Settings(BaseSettings): default=20, description="Max unreported periods to process per backfill run", ) + max_tool_calls_per_invocation: int = Field( + default=0, + description=( + "Maximum MCP tool executions per ADK invocation (single agent run). " + "0 disables the limit. Enforced in-process via UsageTrackingPlugin " + "(not shared across replicas; see docs/metering.md)." 
+ ), + ) # Rate Limiting (Redis-backed) rate_limit_requests_per_minute: int = Field( diff --git a/tests/test_usage_plugin.py b/tests/test_usage_plugin.py index 1611f72a..230a0344 100644 --- a/tests/test_usage_plugin.py +++ b/tests/test_usage_plugin.py @@ -1,6 +1,6 @@ """Tests for usage tracking plugin persistence behavior.""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -111,3 +111,67 @@ async def test_after_tool_persists_tool_call_increment(self): tool_calls=1, ) + +def _tool_context(invocation_id: str) -> MagicMock: + ctx = MagicMock() + ctx.invocation_id = invocation_id + return ctx + + +def _invocation_context(invocation_id: str) -> MagicMock: + ctx = MagicMock() + ctx.invocation_id = invocation_id + return ctx + + +class TestUsageToolCallBudget: + """Per-invocation tool budget (before_tool_callback).""" + + @pytest.mark.asyncio + async def test_before_tool_no_enforcement_when_limit_zero(self): + """With limit 0, every before_tool call is allowed.""" + settings = MagicMock(max_tool_calls_per_invocation=0) + tool = MagicMock() + tool.name = "t" + tc = _tool_context("inv-a") + with patch("lightspeed_agent.api.a2a.usage_plugin.get_settings", return_value=settings): + plugin = usage_plugin.UsageTrackingPlugin() + for _ in range(5): + assert await plugin.before_tool_callback( + tool=tool, tool_args={}, tool_context=tc + ) is None + + @pytest.mark.asyncio + async def test_before_tool_blocks_after_limit(self): + """Allow exactly N tool starts, then return a short-circuit error dict.""" + settings = MagicMock(max_tool_calls_per_invocation=2) + tool = MagicMock() + tool.name = "t" + tc = _tool_context("inv-limit") + with patch("lightspeed_agent.api.a2a.usage_plugin.get_settings", return_value=settings): + plugin = usage_plugin.UsageTrackingPlugin() + kwargs = {"tool": tool, "tool_args": {}, "tool_context": tc} + assert await plugin.before_tool_callback(**kwargs) is None + assert await 
plugin.before_tool_callback(**kwargs) is None + blocked = await plugin.before_tool_callback(**kwargs) + assert blocked is not None + assert blocked["code"] == usage_plugin._TOOL_LIMIT_CODE + assert "Exceeded maximum of 2" in blocked["error"] + + @pytest.mark.asyncio + async def test_after_run_clears_budget_for_next_run_same_plugin(self): + """after_run_callback drops counters so a new run can use the budget again.""" + settings = MagicMock(max_tool_calls_per_invocation=1) + tool = MagicMock() + tool.name = "t" + tc = _tool_context("inv-reset") + inv = _invocation_context("inv-reset") + with patch("lightspeed_agent.api.a2a.usage_plugin.get_settings", return_value=settings): + plugin = usage_plugin.UsageTrackingPlugin() + kwargs = {"tool": tool, "tool_args": {}, "tool_context": tc} + assert await plugin.before_tool_callback(**kwargs) is None + blocked = await plugin.before_tool_callback(**kwargs) + assert blocked is not None + await plugin.after_run_callback(invocation_context=inv) + assert await plugin.before_tool_callback(**kwargs) is None +