diff --git a/pyproject.toml b/pyproject.toml index 3e5e92a..b85181e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] keywords = ["ai", "agent", "fips", "openshift", "fastapi"] dependencies = [ - "fipsagents[feedback,server]>=0.12.0", + "fipsagents[feedback,server]>=0.14.0", "fastapi>=0.110", "uvicorn[standard]>=0.27", "pydantic>=2.0", diff --git a/src/fipsagents_platform/routes/sessions.py b/src/fipsagents_platform/routes/sessions.py index f05645b..88b76a7 100644 --- a/src/fipsagents_platform/routes/sessions.py +++ b/src/fipsagents_platform/routes/sessions.py @@ -24,6 +24,17 @@ class SaveSessionRequest(BaseModel): messages: list[dict] = Field(default_factory=list) +class UpdateSessionRequest(BaseModel): + """Request body for PATCH /v1/sessions/{session_id}. + + Partial update — only the fields present are touched. ``cost_data`` is + shallow-merged with any existing cost data on the platform side + (write-wins per top-level key), matching ``SessionStore.update()``. + """ + + cost_data: dict | None = None + + @router.post("", status_code=201) async def create_session( request: Request, @@ -66,6 +77,30 @@ async def save_session( return JSONResponse({"session_id": session_id, "saved": True}) +@router.patch("/{session_id}") +async def update_session( + session_id: str, + body: UpdateSessionRequest, + request: Request, + _user: str = Depends(require_user), +) -> JSONResponse: + """Partial update for a session (currently: ``cost_data``). + + Delegates to ``SessionStore.update()`` which shallow-merges the supplied + ``cost_data`` with any existing accumulator state. Returns the full + session shape so the agent side can observe the merged result. + """ + store = request.app.state.session_store + updated = await store.update(session_id, cost_data=body.cost_data) + if not updated: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + messages = await store.load(session_id) + return JSONResponse({ + "session_id": session_id, + "messages": messages or [], + }) + + @router.head("/{session_id}") async def session_exists( session_id: str, diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 5ca4f02..5c7dc2f 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -7,9 +7,29 @@ from __future__ import annotations +import json + import pytest +async def _read_cost_data(client, session_id: str) -> dict: + """Read ``cost_data`` straight out of the SQLite store. + + The public ``GET /v1/sessions/{id}`` route only returns messages, but the + PATCH route is about ``cost_data``. We peek at the underlying aiosqlite + connection via ``app.state.session_store`` to confirm the merge happened. + """ + store = client._test_app.state.session_store + db = await store._get_db() + cursor = await db.execute( + "SELECT cost_data FROM sessions WHERE session_id = ?", + (session_id,), + ) + row = await cursor.fetchone() + assert row is not None, f"session {session_id!r} not found in store" + return json.loads(row[0]) if row[0] else {} + + @pytest.mark.asyncio async def test_create_generates_id(client) -> None: resp = await client.post("/v1/sessions", json={}) @@ -138,3 +158,76 @@ async def test_head_missing_session_returns_404(client) -> None: resp = await client.head("/v1/sessions/not-here") assert resp.status_code == 404 assert resp.content == b"" + + +# --------------------------------------------------------------------------- +# PATCH /v1/sessions/{session_id} — partial update for cost_data accumulator. +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_patch_session_updates_cost_data(client) -> None: + """Two PATCH calls shallow-merge per top-level key (write-wins).""" + sid = "patch-cost" + await client.post("/v1/sessions", json={"session_id": sid}) + + first = await client.patch(f"/v1/sessions/{sid}", json={"cost_data": {"a": 1}}) + assert first.status_code == 200 + body = first.json() + assert body == {"session_id": sid, "messages": []} + + second = await client.patch( + f"/v1/sessions/{sid}", json={"cost_data": {"b": 2, "a": 5}} + ) + assert second.status_code == 200 + + merged = await _read_cost_data(client, sid) + assert merged == {"a": 5, "b": 2} + + +@pytest.mark.asyncio +async def test_patch_session_404_when_missing(client) -> None: + resp = await client.patch( + "/v1/sessions/never-existed", json={"cost_data": {"a": 1}} + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_patch_session_none_cost_data(client) -> None: + """``cost_data: null`` is a no-op merge but still confirms existence.""" + sid = "patch-noop" + await client.post("/v1/sessions", json={"session_id": sid}) + await client.patch(f"/v1/sessions/{sid}", json={"cost_data": {"seed": 1}}) + + resp = await client.patch(f"/v1/sessions/{sid}", json={"cost_data": None}) + assert resp.status_code == 200 + + # Existing accumulator state is left alone. + assert await _read_cost_data(client, sid) == {"seed": 1} + + +@pytest.mark.asyncio +async def test_patch_session_404_when_missing_with_none_cost_data(client) -> None: + """Even with ``cost_data: null`` the route must 404 for unknown sessions.""" + resp = await client.patch( + "/v1/sessions/ghost", json={"cost_data": None} + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_patch_session_returns_messages(client) -> None: + """PATCH echoes the persisted message history alongside the merge result.""" + sid = "patch-with-history" + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + await client.put(f"/v1/sessions/{sid}", json={"messages": messages}) + + resp = await client.patch( + f"/v1/sessions/{sid}", json={"cost_data": {"prompt_tokens": 42}} + ) + assert resp.status_code == 200 + assert resp.json() == {"session_id": sid, "messages": messages}