diff --git a/CHANGELOG.md b/CHANGELOG.md index 457b3b7..b7876c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,24 @@ All notable changes to the `fipsagents` package will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/). +## [0.14.0] - 2026-04-27 + +### Added + +- **`SessionStore.update()`** — partial-update method on the ABC for recording per-session accumulator state without rewriting message history. Signature: `update(session_id, *, cost_data: dict | None = None) -> bool`. Implementations on `Null` (no-op `False`), `Sqlite` (Python-side shallow merge), `Postgres` (native `||` JSONB merge), and `Http` (maps to `PATCH /v1/sessions/{id}`). First slice of [#104](https://github.com/fips-agents/agent-template/issues/104) (Cost Tracking). +- **`SessionStore.get_cost_data()`** — symmetric reader so the server-side accumulator can read existing totals before writing cumulative ones back. Implemented on `Null` / `Sqlite` / `Postgres`; `Http` raises `NotImplementedError` until the platform exposes a GET endpoint (tracked at [fipsagents-platform#4](https://github.com/fips-agents/fipsagents-platform/issues/4)). +- **Per-turn token-usage persistence** — `OpenAIChatServer` extracts `prompt_tokens` / `completion_tokens` from each turn's terminal `StreamComplete` event (sync and streaming paths) and accumulates `input_tokens`, `output_tokens`, `cached_tokens`, `model`, and `turn_count` onto the session's `cost_data` via `SessionStore.update()`. Persistence failures are caught and logged so cost-tracking issues never break the chat response. +- **`cost_data` column** on the `sessions` table — `TEXT NOT NULL DEFAULT '{}'` on SQLite, `JSONB NOT NULL DEFAULT '{}'::jsonb` on Postgres. Existing databases pick up the column on first connect via idempotent `ALTER TABLE ADD COLUMN` migrations; no operator action required. + +### Changed + +- `SqliteSessionStore.save()` switches from `INSERT OR REPLACE` to `ON CONFLICT(session_id) DO UPDATE SET messages, updated_at` so `cost_data` survives saves of new messages. Postgres's `save()` already had the right shape. + +### Notes + +- HTTP-backed deployments currently fall back to per-turn-delta writes (last-write-wins) for `cost_data` because the platform doesn't yet expose a read endpoint — see [fipsagents-platform#4](https://github.com/fips-agents/fipsagents-platform/issues/4). SQLite/Postgres backends get cumulative semantics for free. +- Cost data shape (`input_tokens`, `output_tokens`, `cached_tokens`, `model`, `turn_count`) is owned by the server layer; pricing, budget enforcement, and aggregation endpoints are deferred follow-ups on [#104](https://github.com/fips-agents/agent-template/issues/104). + ## [0.13.0] - 2026-04-27 ### Added diff --git a/packages/fipsagents/pyproject.toml b/packages/fipsagents/pyproject.toml index 7da48c4..7cf2c27 100644 --- a/packages/fipsagents/pyproject.toml +++ b/packages/fipsagents/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "fipsagents" -version = "0.13.0" +version = "0.14.0" description = "Production-ready AI agent framework for FIPS/OpenShift environments" readme = "README.md" license = {file = "LICENSE"} diff --git a/packages/fipsagents/src/fipsagents/server/app.py b/packages/fipsagents/src/fipsagents/server/app.py index aef960b..1e7ed6f 100644 --- a/packages/fipsagents/src/fipsagents/server/app.py +++ b/packages/fipsagents/src/fipsagents/server/app.py @@ -558,6 +558,9 @@ async def _chat_completions(self, request: Request, req: ChatCompletionRequest): # Session: save after sync response. if req.session_id and self._session_store: await self._session_store.save(req.session_id, agent.messages) + await self._persist_cost_data( + req.session_id, metrics, model_name, + ) if collector: await collector.end_request() if self._metrics_collector and metrics_start is not None: @@ -677,6 +680,7 @@ async def _stream( self._agent.messages = list(incoming) stream_status = "ok" + captured_metrics: StreamMetrics | None = None try: events = self._agent.astep_stream(max_iterations=10, **(overrides or {})) if self._metrics_collector is not None: @@ -685,6 +689,18 @@ async def _stream( ) if collector: events = collector.observe(events) + + # Pass-through observer that snapshots the StreamMetrics + # from the StreamComplete event so the post-stream + # cost-data accumulator can read them. + async def _capture_metrics(stream): + nonlocal captured_metrics + async for ev in stream: + if isinstance(ev, StreamComplete): + captured_metrics = ev.metrics + yield ev + + events = _capture_metrics(events) async for chunk in stream_events_as_sse( events, model_name, trace_id=trace_id, ): @@ -706,6 +722,70 @@ async def _stream( # Session: save after streaming completes. if session_id and self._session_store: await self._session_store.save(session_id, self._agent.messages) + await self._persist_cost_data( + session_id, captured_metrics, model_name, + ) + + # -- Cost-data accumulator ----------------------------------------------- + + async def _persist_cost_data( + self, + session_id: str, + metrics: StreamMetrics | None, + model_name: str, + ) -> None: + """Accumulate this turn's token usage into the session's cost_data. + + Cumulative-for-the-session: read the existing accumulator, add + this turn's deltas, write it back. Failures are logged and + swallowed -- cost tracking must never break the chat response. + + Backends that don't support reading cost_data (eg + :class:`HttpSessionStore`) raise :class:`NotImplementedError` + from ``get_cost_data``; in that case we treat the existing + total as empty and the next ``update`` records this turn's + delta only. A follow-up issue tracks exposing the platform + read endpoint so HTTP-backed deployments get cumulative totals. + """ + if metrics is None or self._session_store is None: + return + + prompt = metrics.prompt_tokens + completion = metrics.completion_tokens + # Nothing useful to record when the provider didn't report usage. + if prompt is None and completion is None: + return + + try: + existing = await self._session_store.get_cost_data(session_id) + except NotImplementedError: + existing = {} + except Exception: # noqa: BLE001 — keep chat response alive + logger.warning( + "Failed to read cost_data for %s; using empty baseline", + session_id, + exc_info=True, + ) + existing = {} + + new_data = { + "input_tokens": int(existing.get("input_tokens", 0) or 0) + + int(prompt or 0), + "output_tokens": int(existing.get("output_tokens", 0) or 0) + + int(completion or 0), + "cached_tokens": int(existing.get("cached_tokens", 0) or 0), + "model": model_name or existing.get("model"), + "turn_count": int(existing.get("turn_count", 0) or 0) + 1, + } + + try: + await self._session_store.update(session_id, cost_data=new_data) + except Exception: # noqa: BLE001 — keep chat response alive + logger.warning( + "Failed to persist cost_data for %s", + session_id, + exc_info=True, + ) # -- Run ----------------------------------------------------------------- diff --git a/packages/fipsagents/src/fipsagents/server/http.py b/packages/fipsagents/src/fipsagents/server/http.py index 3d982a5..6a2debb 100644 --- a/packages/fipsagents/src/fipsagents/server/http.py +++ b/packages/fipsagents/src/fipsagents/server/http.py @@ -194,6 +194,7 @@ class HttpSessionStore(SessionStore): - ``create`` → ``POST /v1/sessions`` - ``load`` → ``GET /v1/sessions/{id}`` - ``save`` → ``PUT /v1/sessions/{id}`` (upsert) + - ``update`` → ``PATCH /v1/sessions/{id}`` - ``exists`` → ``HEAD /v1/sessions/{id}`` - ``delete`` → ``DELETE /v1/sessions/{id}`` - ``delete_before`` → no platform endpoint; logged no-op (the @@ -233,6 +234,36 @@ async def save(self, session_id: str, messages: list[dict]) -> None: json={"messages": messages}, ) + async def update( + self, + session_id: str, + *, + cost_data: dict | None = None, + ) -> bool: + if cost_data is None: + return await self.exists(session_id) + body: dict[str, Any] = {"cost_data": cost_data} + status, _ = await self._client.request( + "PATCH", + f"/v1/sessions/{session_id}", + json=body, + not_found_returns_none=True, + ) + return status != 404 + + async def get_cost_data(self, session_id: str) -> dict: + # The platform service has no GET /v1/sessions/{id}/cost_data + # endpoint yet. Until it does, callers must treat the HTTP + # backend as write-only for cost accumulator state. The server's + # per-turn accumulator catches NotImplementedError and treats + # the existing total as empty (so the next write is the turn's + # delta rather than a true cumulative). A follow-up issue tracks + # exposing the read endpoint on the platform. + raise NotImplementedError( + "HttpSessionStore.get_cost_data: the platform service does " + "not expose a GET endpoint for cost_data yet." + ) + async def delete(self, session_id: str) -> bool: status, _ = await self._client.request( "DELETE", diff --git a/packages/fipsagents/src/fipsagents/server/sessions.py b/packages/fipsagents/src/fipsagents/server/sessions.py index 83b40fc..78d1111 100644 --- a/packages/fipsagents/src/fipsagents/server/sessions.py +++ b/packages/fipsagents/src/fipsagents/server/sessions.py @@ -41,6 +41,33 @@ async def load(self, session_id: str) -> list[dict] | None: async def save(self, session_id: str, messages: list[dict]) -> None: """Persist the full message history for a session.""" + @abstractmethod + async def update( + self, + session_id: str, + *, + cost_data: dict | None = None, + ) -> bool: + """Partial update of a session. + + Currently supports merging ``cost_data`` (shallow merge per top-level key, + write-wins). Returns True if the session existed, False otherwise. + Designed to be additive -- future fields can be added as keyword-only args. + """ + + @abstractmethod + async def get_cost_data(self, session_id: str) -> dict: + """Return the current accumulated ``cost_data`` for a session. + + Symmetric companion to :meth:`update` so callers (notably the + server's per-turn cost accumulator) can read the existing totals + before computing the next write. + + Returns an empty dict if the session is missing or has no + cost_data yet. Backends without a read endpoint (notably the + HTTP-backed store) raise :class:`NotImplementedError`. + """ + @abstractmethod async def delete(self, session_id: str) -> bool: """Remove a session. Return True if it existed.""" @@ -71,6 +98,17 @@ async def load(self, session_id: str) -> list[dict] | None: async def save(self, session_id: str, messages: list[dict]) -> None: pass + async def update( + self, + session_id: str, + *, + cost_data: dict | None = None, + ) -> bool: + return False + + async def get_cost_data(self, session_id: str) -> dict: + return {} + async def delete(self, session_id: str) -> bool: return False @@ -89,7 +127,8 @@ class SqliteSessionStore(SessionStore): session_id TEXT PRIMARY KEY, messages TEXT NOT NULL, created_at TEXT NOT NULL, - updated_at TEXT NOT NULL + updated_at TEXT NOT NULL, + cost_data TEXT NOT NULL DEFAULT '{}' )""" def __init__(self, db_path: str = "./agent.db", *, connection: Any = None) -> None: @@ -110,6 +149,14 @@ async def _get_db(self) -> Any: async def _ensure_table(self) -> None: db = self._db await db.execute(self._CREATE_TABLE) + # Migrate older databases that predate cost_data. + cursor = await db.execute("PRAGMA table_info(sessions)") + cols = {row[1] for row in await cursor.fetchall()} + if "cost_data" not in cols: + await db.execute( + "ALTER TABLE sessions ADD COLUMN cost_data TEXT NOT NULL DEFAULT '{}'" + ) + logger.debug("SqliteSessionStore: migrated schema (added cost_data)") await db.commit() self._initialized = True @@ -141,17 +188,62 @@ async def load(self, session_id: str) -> list[dict] | None: async def save(self, session_id: str, messages: list[dict]) -> None: now = _utc_now_iso() db = await self._get_db() + # Upsert that preserves cost_data on conflict (it's accumulator state + # owned by ``update()`` and must not be reset by every save). await db.execute( - "INSERT OR REPLACE INTO sessions " + "INSERT INTO sessions " "(session_id, messages, created_at, updated_at) " - "VALUES (?, ?, COALESCE(" - " (SELECT created_at FROM sessions WHERE session_id = ?), ?" - "), ?)", - (session_id, json.dumps(messages), session_id, now, now), + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(session_id) DO UPDATE SET " + " messages = excluded.messages, " + " updated_at = excluded.updated_at", + (session_id, json.dumps(messages), now, now), ) await db.commit() logger.debug("SqliteSessionStore: saved %s (%d messages)", session_id, len(messages)) + async def update( + self, + session_id: str, + *, + cost_data: dict | None = None, + ) -> bool: + db = await self._get_db() + if cost_data is None: + return await self.exists(session_id) + cursor = await db.execute( + "SELECT cost_data FROM sessions WHERE session_id = ?", + (session_id,), + ) + row = await cursor.fetchone() + if row is None: + return False + existing = json.loads(row[0]) if row[0] else {} + existing.update(cost_data) + now = _utc_now_iso() + await db.execute( + "UPDATE sessions SET cost_data = ?, updated_at = ? " + "WHERE session_id = ?", + (json.dumps(existing), now, session_id), + ) + await db.commit() + logger.debug("SqliteSessionStore: updated %s cost_data", session_id) + return True + + async def get_cost_data(self, session_id: str) -> dict: + db = await self._get_db() + cursor = await db.execute( + "SELECT cost_data FROM sessions WHERE session_id = ?", + (session_id,), + ) + row = await cursor.fetchone() + if row is None or not row[0]: + return {} + try: + return json.loads(row[0]) + except (TypeError, ValueError): + return {} + async def delete(self, session_id: str) -> bool: db = await self._get_db() cursor = await db.execute( @@ -196,12 +288,17 @@ class PostgresSessionStore(SessionStore): session_id TEXT PRIMARY KEY, messages JSONB NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + cost_data JSONB NOT NULL DEFAULT '{}'::jsonb )""" _CREATE_INDEX = ( "CREATE INDEX IF NOT EXISTS idx_sessions_updated " "ON sessions (updated_at)" ) + _ADD_COST_DATA = ( + "ALTER TABLE sessions " + "ADD COLUMN IF NOT EXISTS cost_data JSONB NOT NULL DEFAULT '{}'::jsonb" + ) def __init__(self, database_url: str) -> None: self._database_url = database_url @@ -221,6 +318,7 @@ async def _ensure_table(self) -> None: pool = self._pool async with pool.acquire() as conn: await conn.execute(self._CREATE_TABLE) + await conn.execute(self._ADD_COST_DATA) await conn.execute(self._CREATE_INDEX) self._initialized = True @@ -268,6 +366,50 @@ async def save(self, session_id: str, messages: list[dict]) -> None: "PostgresSessionStore: saved %s (%d messages)", session_id, len(messages), ) + async def update( + self, + session_id: str, + *, + cost_data: dict | None = None, + ) -> bool: + pool = await self._get_pool() + if cost_data is None: + return await self.exists(session_id) + now = datetime.now(timezone.utc) + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE sessions " + "SET cost_data = cost_data || $2::jsonb, " + " updated_at = $3 " + "WHERE session_id = $1", + session_id, json.dumps(cost_data), now, + ) + # asyncpg returns "UPDATE N" + updated = not result.endswith("0") + if updated: + logger.debug("PostgresSessionStore: updated %s cost_data", session_id) + return updated + + async def get_cost_data(self, session_id: str) -> dict: + pool = await self._get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT cost_data FROM sessions WHERE session_id = $1", + session_id, + ) + if row is None: + return {} + cost = row["cost_data"] + # asyncpg auto-decodes JSONB to Python objects, but be defensive. + if cost is None: + return {} + if isinstance(cost, str): + try: + return json.loads(cost) + except (TypeError, ValueError): + return {} + return cost + async def delete(self, session_id: str) -> bool: pool = await self._get_pool() async with pool.acquire() as conn: diff --git a/packages/fipsagents/tests/test_http_stores.py b/packages/fipsagents/tests/test_http_stores.py index b27a60d..de2010e 100644 --- a/packages/fipsagents/tests/test_http_stores.py +++ b/packages/fipsagents/tests/test_http_stores.py @@ -159,6 +159,95 @@ async def test_session_delete_before_is_noop() -> None: await store.close() +@pytest.mark.asyncio +async def test_session_update_sends_patch() -> None: + rec = _Recorder([_ok(200, { + "session_id": "sess_abc", + "messages": [], + "created_at": "2026-04-27T12:00:00+00:00", + "updated_at": "2026-04-27T12:00:01+00:00", + "cost_data": {"input_tokens": 100, "output_tokens": 50}, + })]) + store = HttpSessionStore( + "http://platform.test", transport=httpx.MockTransport(rec), + ) + tokens = set_request_context( + authorization="Bearer per-request-jwt", traceparent=None, + ) + try: + result = await store.update( + "sess_abc", + cost_data={"input_tokens": 100, "output_tokens": 50}, + ) + finally: + reset_request_context(tokens) + assert result is True + assert rec.requests[0].method == "PATCH" + assert rec.requests[0].url.path == "/v1/sessions/sess_abc" + assert json.loads(rec.requests[0].content) == { + "cost_data": {"input_tokens": 100, "output_tokens": 50}, + } + assert rec.requests[0].headers["authorization"] == "Bearer per-request-jwt" + await store.close() + + +@pytest.mark.asyncio +async def test_session_update_404_returns_false() -> None: + rec = _Recorder([_ok(404, {"detail": "not found"})]) + store = HttpSessionStore( + "http://platform.test", transport=httpx.MockTransport(rec), + ) + result = await store.update( + "missing", cost_data={"input_tokens": 1}, + ) + assert result is False + assert rec.requests[0].method == "PATCH" + await store.close() + + +@pytest.mark.asyncio +async def test_session_update_none_cost_data_delegates_to_exists() -> None: + """When cost_data is None, update() delegates to exists() (HEAD probe).""" + rec = _Recorder([_ok(200), _ok(404)]) + store = HttpSessionStore( + "http://platform.test", transport=httpx.MockTransport(rec), + ) + assert await store.update("sess_1", cost_data=None) is True + assert await store.update("missing", cost_data=None) is False + # Both calls should have been HEAD, not PATCH. + assert rec.requests[0].method == "HEAD" + assert rec.requests[1].method == "HEAD" + await store.close() + + +@pytest.mark.asyncio +async def test_session_get_cost_data_raises_not_implemented() -> None: + """The HTTP backend has no GET cost_data endpoint yet -- raise so the + server-side accumulator can fall back to a delta-only write.""" + rec = _Recorder([]) + store = HttpSessionStore( + "http://platform.test", transport=httpx.MockTransport(rec), + ) + with pytest.raises(NotImplementedError): + await store.get_cost_data("sess_anything") + # No HTTP request should have been issued. + assert rec.requests == [] + await store.close() + + +@pytest.mark.asyncio +async def test_session_update_5xx_raises() -> None: + rec = _Recorder([httpx.Response(500, json={"detail": "boom"})]) + store = HttpSessionStore( + "http://platform.test", transport=httpx.MockTransport(rec), + ) + with pytest.raises(PlatformError) as exc_info: + await store.update("sess_1", cost_data={"input_tokens": 1}) + assert exc_info.value.status_code == 500 + assert "500" in str(exc_info.value) + await store.close() + + # --------------------------------------------------------------------------- # HttpTraceStore # --------------------------------------------------------------------------- diff --git a/packages/fipsagents/tests/test_http_stores_e2e.py b/packages/fipsagents/tests/test_http_stores_e2e.py index ab3e597..2524e00 100644 --- a/packages/fipsagents/tests/test_http_stores_e2e.py +++ b/packages/fipsagents/tests/test_http_stores_e2e.py @@ -116,6 +116,64 @@ async def test_session_save_creates_when_missing(platform_transport) -> None: await store.close() +@pytest.mark.asyncio +async def test_session_update_round_trip(platform_transport) -> None: + """PATCH /v1/sessions/{id} merges cost_data; later writes win per top-level key.""" + store = HttpSessionStore( + "http://platform.test", transport=platform_transport, + ) + sid = await store.create("sess_e2e_update") + assert await store.update(sid, cost_data={"a": 1}) is True + assert await store.update(sid, cost_data={"b": 2, "a": 5}) is True + + # Round-trip through the platform's underlying SqliteSessionStore — we + # use it directly here since GET /v1/sessions/{id} only returns messages. + from fipsagents_platform.config import get_settings + + settings = get_settings() + from fipsagents.server.sessions import SqliteSessionStore + + direct = SqliteSessionStore(settings.sqlite_path) + try: + # Re-use the same DB the platform writes to and read cost_data + # straight out of the table. + db = await direct._get_db() + cursor = await db.execute( + "SELECT cost_data FROM sessions WHERE session_id = ?", (sid,), + ) + row = await cursor.fetchone() + import json as _json + assert row is not None + assert _json.loads(row[0]) == {"a": 5, "b": 2} + finally: + await direct.close() + await store.close() + + +@pytest.mark.asyncio +async def test_session_update_missing_returns_false(platform_transport) -> None: + store = HttpSessionStore( + "http://platform.test", transport=platform_transport, + ) + assert await store.update( + "sess_does_not_exist", cost_data={"input_tokens": 1}, + ) is False + await store.close() + + +@pytest.mark.asyncio +async def test_session_update_none_cost_data_returns_existence( + platform_transport, +) -> None: + store = HttpSessionStore( + "http://platform.test", transport=platform_transport, + ) + sid = await store.create("sess_e2e_update_none") + assert await store.update(sid, cost_data=None) is True + assert await store.update("sess_does_not_exist", cost_data=None) is False + await store.close() + + # --------------------------------------------------------------------------- # Traces # --------------------------------------------------------------------------- diff --git a/packages/fipsagents/tests/test_server_openai.py b/packages/fipsagents/tests/test_server_openai.py index 408c7e5..7a3e9db 100644 --- a/packages/fipsagents/tests/test_server_openai.py +++ b/packages/fipsagents/tests/test_server_openai.py @@ -854,3 +854,308 @@ def test_create_feedback_rejects_invalid_rating(): with pytest.raises(ValidationError): CreateFeedbackRequest(rating=0) + + +# --------------------------------------------------------------------------- +# Cost-data accumulator (per-turn token usage → SessionStore.update) +# --------------------------------------------------------------------------- + + +def _build_server_with_sqlite_sessions(tmp_path, events, *, model_name="stub"): + """Build a server with sessions enabled and backed by SQLite.""" + AgentClass = _make_agent_class(events, model_name=model_name) + db_path = str(tmp_path / "sessions.db") + + class _A(AgentClass): # type: ignore[misc, valid-type] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config.server.storage = types.SimpleNamespace( + backend="sqlite", + sqlite_path=db_path, + database_url="", + platform_url="", + platform_token="", + ) + self.config.server.sessions = types.SimpleNamespace( + enabled=True, + max_age_hours=0, + backend=None, + ) + + return OpenAIChatServer(_A) + + +def test_cost_data_persisted_across_turns(tmp_path): + """Two completions on the same session_id accumulate cumulative totals.""" + metrics = StreamMetrics( + prompt_tokens=10, completion_tokens=4, total_tokens=14, + ) + events = [ + ContentDelta(content="ok"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + server = _build_server_with_sqlite_sessions( + tmp_path, events, model_name="stub-model", + ) + with TestClient(server.app) as client: + # Pre-create the session so the first save's upsert finds it. + resp = client.post("/v1/sessions", json={"session_id": "sess_cost"}) + assert resp.status_code == 201 + + for _ in range(2): + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "session_id": "sess_cost", + }, + ) + assert resp.status_code == 200 + + # Read cost_data directly from the sqlite store to verify. + store = server._session_store + cost = asyncio.get_event_loop().run_until_complete( + store.get_cost_data("sess_cost") + ) + + assert cost == { + "input_tokens": 20, + "output_tokens": 8, + "cached_tokens": 0, + "model": "stub-model", + "turn_count": 2, + } + + +def test_cost_data_no_session_no_persist(tmp_path): + """Without a session_id, update() must not be invoked.""" + metrics = StreamMetrics(prompt_tokens=10, completion_tokens=4) + events = [ + ContentDelta(content="ok"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + server = _build_server_with_sqlite_sessions(tmp_path, events) + + with TestClient(server.app) as client: + store = server._session_store + update_calls: list = [] + original_update = store.update + + async def _spy(session_id, *, cost_data=None): + update_calls.append((session_id, cost_data)) + return await original_update(session_id, cost_data=cost_data) + + store.update = _spy # type: ignore[method-assign] + + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + }, + ) + assert resp.status_code == 200 + + assert update_calls == [] + + +def test_cost_data_persist_failure_does_not_break_response(tmp_path): + """If update() raises, the chat response still completes successfully.""" + metrics = StreamMetrics(prompt_tokens=10, completion_tokens=4) + events = [ + ContentDelta(content="ok"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + server = _build_server_with_sqlite_sessions(tmp_path, events) + + with TestClient(server.app) as client: + resp = client.post( + "/v1/sessions", json={"session_id": "sess_boom"}, + ) + assert resp.status_code == 201 + + async def _boom(session_id, *, cost_data=None): # noqa: ARG001 + raise RuntimeError("simulated platform 500") + + server._session_store.update = _boom # type: ignore[method-assign] + + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "session_id": "sess_boom", + }, + ) + + assert resp.status_code == 200 + assert resp.json()["choices"][0]["message"]["content"] == "ok" + + +def test_cost_data_null_session_store_is_noop(): + """With NullSessionStore (default), the server doesn't crash.""" + metrics = StreamMetrics(prompt_tokens=10, completion_tokens=4) + events = [ + ContentDelta(content="ok"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + # _build_server uses the default stub config (sessions disabled, + # NullSessionStore). The chat request without a session_id must + # succeed. + server = _build_server(events) + with TestClient(server.app) as client: + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + }, + ) + assert resp.status_code == 200 + # Even when session_id is provided but store is the Null backend + # (sessions disabled in config), no crash. + with TestClient(server.app) as client: + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "session_id": "sess_anything", + }, + ) + assert resp.status_code == 200 + + +def test_cost_data_persisted_across_streaming_turns(tmp_path): + """The streaming path also accumulates per-turn token usage.""" + metrics = StreamMetrics( + prompt_tokens=7, completion_tokens=3, total_tokens=10, + ) + events = [ + ContentDelta(content="hi"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + server = _build_server_with_sqlite_sessions( + tmp_path, events, model_name="stream-stub", + ) + with TestClient(server.app) as client: + resp = client.post( + "/v1/sessions", json={"session_id": "sess_stream"}, + ) + assert resp.status_code == 201 + + for _ in range(2): + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + "session_id": "sess_stream", + }, + ) + assert resp.status_code == 200 + # Drain so the streaming task completes before next turn. + _ = resp.text + + cost = asyncio.get_event_loop().run_until_complete( + server._session_store.get_cost_data("sess_stream") + ) + + assert cost["input_tokens"] == 14 + assert cost["output_tokens"] == 6 + assert cost["turn_count"] == 2 + assert cost["model"] == "stream-stub" + + +def test_cost_data_no_usage_no_persist(tmp_path): + """When the model didn't report usage, no cost_data is written.""" + # No prompt_tokens / completion_tokens → metrics has all-None counts. + metrics = StreamMetrics() + events = [ + ContentDelta(content="ok"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + server = _build_server_with_sqlite_sessions(tmp_path, events) + + with TestClient(server.app) as client: + resp = client.post( + "/v1/sessions", json={"session_id": "sess_nousage"}, + ) + assert resp.status_code == 201 + + update_calls: list = [] + original_update = server._session_store.update + + async def _spy(session_id, *, cost_data=None): + update_calls.append((session_id, cost_data)) + return await original_update(session_id, cost_data=cost_data) + + server._session_store.update = _spy # type: ignore[method-assign] + + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "session_id": "sess_nousage", + }, + ) + assert resp.status_code == 200 + + assert update_calls == [] + + +def test_cost_data_http_get_not_implemented_falls_back_to_delta(tmp_path): + """When get_cost_data raises NotImplementedError, the next write is the delta only. + + This emulates the HttpSessionStore case until the platform exposes a + GET cost_data endpoint. The accumulator must NOT crash; it simply + treats the existing total as empty so the write becomes a per-turn + delta rather than a true cumulative. + """ + metrics = StreamMetrics(prompt_tokens=10, completion_tokens=4) + events = [ + ContentDelta(content="ok"), + StreamComplete(finish_reason="stop", metrics=metrics), + ] + server = _build_server_with_sqlite_sessions(tmp_path, events) + + with TestClient(server.app) as client: + resp = client.post( + "/v1/sessions", json={"session_id": "sess_http_like"}, + ) + assert resp.status_code == 201 + + async def _no_read(session_id): # noqa: ARG001 + raise NotImplementedError("simulated http backend") + + server._session_store.get_cost_data = _no_read # type: ignore[method-assign] + + # Two turns: each writes the per-turn delta because the read + # raises NotImplementedError. With a real Sqlite backend the + # update() path still merges into the row, so we expect two + # separate writes that DO accumulate via update()'s shallow + # merge -- but turn_count won't be cumulative since we can't + # read the prior value. That is the documented behaviour. + for _ in range(2): + resp = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "session_id": "sess_http_like", + }, + ) + assert resp.status_code == 200 + + # Restore the real reader to inspect what actually got written. + del server._session_store.get_cost_data # type: ignore[attr-defined] + cost = asyncio.get_event_loop().run_until_complete( + server._session_store.get_cost_data("sess_http_like") + ) + + # Each write replaced (write-wins) the per-turn delta, not the cumulative. + assert cost["input_tokens"] == 10 + assert cost["output_tokens"] == 4 + assert cost["turn_count"] == 1 diff --git a/packages/fipsagents/tests/test_sessions.py b/packages/fipsagents/tests/test_sessions.py index 081d782..95d945b 100644 --- a/packages/fipsagents/tests/test_sessions.py +++ b/packages/fipsagents/tests/test_sessions.py @@ -63,6 +63,17 @@ async def test_exists_returns_false(self): store = NullSessionStore() assert await store.exists("anything") is False + @pytest.mark.asyncio + async def test_update_returns_false(self): + store = NullSessionStore() + assert await store.update("anything", cost_data={"x": 1}) is False + assert await store.update("anything") is False + + @pytest.mark.asyncio + async def test_get_cost_data_returns_empty_dict(self): + store = NullSessionStore() + assert await store.get_cost_data("anything") == {} + # --------------------------------------------------------------------------- # SqliteSessionStore @@ -189,6 +200,120 @@ async def test_close_and_reopen(self, tmp_path): assert loaded == msgs + # -- update() / cost_data ------------------------------------------------ + + @staticmethod + async def _read_cost_data(store, session_id): + """Read raw cost_data JSON via direct DB query.""" + import json as _json + + db = await store._get_db() + cursor = await db.execute( + "SELECT cost_data FROM sessions WHERE session_id = ?", + (session_id,), + ) + row = await cursor.fetchone() + return _json.loads(row[0]) if row else None + + @pytest.mark.asyncio + async def test_update_merges_cost_data(self, sqlite_store): + """Successive update() calls shallow-merge with write-wins.""" + sid = await sqlite_store.create() + + assert await sqlite_store.update(sid, cost_data={"a": 1}) is True + assert await sqlite_store.update(sid, cost_data={"b": 2, "a": 5}) is True + + merged = await self._read_cost_data(sqlite_store, sid) + assert merged == {"a": 5, "b": 2} + + @pytest.mark.asyncio + async def test_update_missing_session(self, sqlite_store): + """update() on a nonexistent session returns False.""" + assert await sqlite_store.update("doesnotexist", cost_data={"x": 1}) is False + + @pytest.mark.asyncio + async def test_update_none_returns_existence(self, sqlite_store): + """cost_data=None means: just confirm whether the session exists.""" + sid = await sqlite_store.create() + assert await sqlite_store.update(sid) is True + assert await sqlite_store.update("missing-session") is False + + @pytest.mark.asyncio + async def test_save_preserves_cost_data(self, sqlite_store): + """save() must not clobber cost_data accumulated via update().""" + sid = await sqlite_store.create() + await sqlite_store.update(sid, cost_data={"tokens": 100, "usd": 0.01}) + + await sqlite_store.save(sid, [{"role": "user", "content": "first"}]) + await sqlite_store.save(sid, [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "second"}, + ]) + + cost = await self._read_cost_data(sqlite_store, sid) + assert cost == {"tokens": 100, "usd": 0.01} + + @pytest.mark.asyncio + async def test_get_cost_data_empty_by_default(self, sqlite_store): + """A freshly created session has no cost_data.""" + sid = await sqlite_store.create() + assert await sqlite_store.get_cost_data(sid) == {} + + @pytest.mark.asyncio + async def test_get_cost_data_returns_merged_state(self, sqlite_store): + """get_cost_data round-trips the accumulator state.""" + sid = await sqlite_store.create() + await sqlite_store.update( + sid, cost_data={"input_tokens": 11, "model": "stub"}, + ) + await sqlite_store.update( + sid, cost_data={"output_tokens": 5, "model": "stub-2"}, + ) + cost = await sqlite_store.get_cost_data(sid) + assert cost == { + "input_tokens": 11, + "output_tokens": 5, + "model": "stub-2", + } + + @pytest.mark.asyncio + async def test_get_cost_data_missing_returns_empty(self, sqlite_store): + """An unknown session returns an empty dict, not None.""" + assert await sqlite_store.get_cost_data("nope") == {} + + @pytest.mark.asyncio + async def test_migration_adds_cost_data_column(self, tmp_path): + """A pre-existing DB without cost_data is migrated transparently.""" + import aiosqlite + + db_path = str(tmp_path / "legacy.db") + + # Create the old schema by hand (no cost_data column). + async with aiosqlite.connect(db_path) as legacy: + await legacy.execute( + "CREATE TABLE sessions (" + " session_id TEXT PRIMARY KEY, " + " messages TEXT NOT NULL, " + " created_at TEXT NOT NULL, " + " updated_at TEXT NOT NULL" + ")" + ) + await legacy.execute( + "INSERT INTO sessions (session_id, messages, created_at, updated_at) " + "VALUES (?, ?, ?, ?)", + ("legacy-1", "[]", "2025-01-01T00:00:00+00:00", "2025-01-01T00:00:00+00:00"), + ) + await legacy.commit() + + # Open via SqliteSessionStore -- _ensure_table() should add cost_data. + store = SqliteSessionStore(db_path) + try: + assert await store.update("legacy-1", cost_data={"a": 1}) is True + cost = await self._read_cost_data(store, "legacy-1") + assert cost == {"a": 1} + finally: + await store.close() + # --------------------------------------------------------------------------- # Factory