From 456c9956cc25bc884d4d158e536b74b4844305dd Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 12:51:45 +0100 Subject: [PATCH 01/20] docs: add test suite overhaul design spec Comprehensive spec for eliminating mock abuse, consolidating duplicate fixtures, adding missing E2E tests, and establishing testing guidelines in TESTING.md. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../2026-03-26-test-suite-overhaul-design.md | 161 ++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md diff --git a/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md b/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md new file mode 100644 index 000000000..661ba9342 --- /dev/null +++ b/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md @@ -0,0 +1,161 @@ +# Test Suite Overhaul — Design Spec + +**Date:** 2026-03-26 +**Status:** Draft +**Trigger:** `uv run zndraw file.h5` broken with 401 — no E2E test covered it. Audit revealed systemic testing issues. + +## Problem Statement + +The test suite has 500+ tests but missed a critical bug because: + +1. **Mocked services hide real failures.** 15 route test files override Redis with `AsyncMock()` objects that silently swallow all calls. The shared `conftest.py` `client` fixture uses `lambda: (None)` for Redis. The `StateFileSource` is patched to `return_value={}` in client tests. These patterns prevent tests from catching real integration failures. +2. **Massive fixture duplication.** 15+ route test files each define identical session/client/Redis fixtures with different prefixes (~3000 lines of copy-paste). Helper functions (`_create_user`, `_create_room`) are redefined in 10+ files despite existing in `helpers.py`. +3. **Missing E2E tests.** No test exercises the main user flow (`uv run zndraw file.h5`). No test verifies ZnDraw client persistence across disconnect/reconnect. No test exercises guest auth → room operations end-to-end. +4. **Excessive monkeypatching.** ~50 `monkeypatch.setattr` / `@patch` calls plus ~30 `AsyncMock()` Redis overrides patch module-level functions to avoid testing real interactions. Many could be eliminated by using real servers. + +## Constraints + +- **Redis is always available** — required for all tests, no mock fallback. +- **SQLite in-memory is acceptable** for the DB layer — SQLModel handles backend abstraction. +- **ASGITransport for route tests** — fast, with real Redis and real DB. +- **Real uvicorn (`server_factory`) for E2E/integration** — Socket.IO, workers, CLI client, cross-process flows. +- **Mocking only for side effects** — `webbrowser.open`, `time.sleep`, `os.kill`. Never for services under test. +- **Big-bang PR** — all changes land in a single PR. Phases represent commit ordering within the branch. +- **Class-to-function conversion is out of scope** — cosmetic, no quality impact. +- **Never modify tests marked `@pytest.mark.protected`.** + +## Design + +### Phase 1: Testing Guidelines (`tests/README.md`) + +A `tests/README.md` co-located with the test suite. First commit in the PR so all subsequent changes have a reference. + +#### Test Hierarchy + +1. **Route tests** — `ASGITransport` + real Redis + real SQLite. For HTTP request/response behavior. `MockSioServer` is acceptable here for verifying broadcast emissions. +2. **Integration/E2E tests** — Real uvicorn via `server_factory`. For flows crossing process boundaries: Socket.IO, CLI client, workers, StateFileSource resolution. +3. **Unit tests** — Pure logic with no I/O. Mocking allowed only here, and only when better design can't avoid it. + +#### Mocking Rules + +**Banned:** +- `lambda: None` or `lambda: (None)` for Redis +- `AsyncMock()` (with or without configured return values) as a substitute for real Redis or result backends +- `patch("StateFileSource.__call__")` to hide token resolution +- `patch("httpx.Client")` to avoid real HTTP (use `server_factory` instead) +- `monkeypatch.setattr` on module-level functions to skip service interactions (e.g., `wait_for_server_ready`, `_acquire_admin_jwt`, `_is_url_healthy` when a real server could be used instead) + +**Allowed:** +- `@patch("os.kill")` — prevents real process termination +- `monkeypatch.setattr("webbrowser.open", ...)` — prevents browser opening +- `monkeypatch.setattr("time.sleep", ...)` — prevents test delays +- `MockSioServer` in route tests — lightweight fake for broadcast verification +- `monkeypatch.setenv` / `monkeypatch.delenv` — standard environment isolation + +**Questionable (evaluate during implementation):** +- `patch.dict("sys.modules", {"PIL": None})` in `test_gif.py` — tests what happens when PIL is missing, but PIL is a required dependency. Consider removing the test entirely. + +**Last resort (requires justifying comment in the test):** +- `monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=tmp_path))` — only if the code under test cannot accept a `StateFile` parameter. StateFile is critical infrastructure and should be tested for real. Redirecting to `tmp_path` is acceptable for filesystem isolation but question whether the test even needs it. + +**Guiding principle:** If you're patching something to avoid testing it, you're writing the wrong kind of test. + +#### Fixture Rules + +- All shared fixtures live in `conftest.py` — no per-file session/client/Redis definitions. +- `helpers.py` functions (`create_test_user_in_db`, `create_test_room`, `auth_header`, `create_test_token`) are the ONLY way to create test data. The per-file `_create_user` copies do both user creation AND token creation — the consolidated path uses `create_test_user_in_db()` (returns `(user, token)` tuple) which already handles both. +- Fixture scoping: function-scoped for data, session-scoped only for truly stateless infrastructure. + +#### The Iron Law: When Writing New Tests You MUST Follow This Pattern + +1. **`/brainstorming` first** — plan what to test, which test tier (route / E2E / unit), which fixtures to use. Do not write tests blindly. +2. **`/test-driven-development`** — RED/GREEN/REFACTOR cycle. Write the failing test first, watch it fail, then implement. +3. **Review against these guidelines** — does the test use real services? Does it avoid banned mock patterns? Does it use shared fixtures? + +This applies to ALL test writing — manual, AI-assisted, or agent-driven. No exceptions. + +#### Required E2E Coverage + +Critical user flows MUST have E2E tests against real servers: +- CLI upload flow (`_acquire_admin_jwt` → store JWT → ZnDraw client → write frames) +- ZnDraw Python client connect → write → read → disconnect → reconnect → verify +- Guest auth → room operations +- StateFileSource with real running server + +### Phase 2: Shared Fixture Consolidation + +Replace per-file fixture sets with shared fixtures in `tests/zndraw/conftest.py`: + +``` +conftest.py fixtures: + session — in-memory SQLite (already exists) + redis_client — real Redis, flushed per test (already exists) + client — ASGITransport + real session + real Redis + MockSioServer + test_user — creates user via helpers.create_test_user_in_db() + test_room — creates room via helpers.create_test_room() +``` + +Key changes to the existing `client` fixture: +- Replace `app.dependency_overrides[get_redis] = lambda: (None)` with real `redis_client` +- Add `FrameStorage` backed by real Redis +- Add `ResultBackend` dependency — extract `InMemoryResultBackend` from `zndraw_joblib/conftest.py` to a shared location (e.g., `tests/shared_helpers.py` or `tests/conftest.py`) so it can be used by `tests/zndraw/` without cross-package conftest imports +- Keep `MockSioServer` for route tests (legitimate fake for broadcast verification) + +### Phase 3: Route Test File Conversion + +All route test files delete their per-file fixtures and use shared ones: + +**Files:** `test_routes_bookmarks.py`, `test_routes_edit_lock.py`, `test_routes_figures.py`, `test_routes_frame_selection.py`, `test_routes_frames.py`, `test_routes_geometries.py`, `test_routes_presets.py`, `test_routes_selection_groups.py`, `test_routes_step.py`, `test_screenshots.py`, `test_chat.py`, `test_isosurface.py`, `test_progress.py`, `test_default_camera.py`, `test_frames_provider_dispatch.py`, `test_trajectory.py` + +**Per file:** +- Delete per-file session fixture (`bm_session`, `el_session`, etc.) +- Delete per-file client fixture (`bm_client`, `el_client`, etc.) +- Delete per-file `_create_user()`, `_create_room()`, `_auth()` — use `helpers.py` +- Update test function signatures to use shared fixture names + +### Phase 4: Mock Cleanup + +| Target | Files | Replacement | +|--------|-------|-------------| +| `AsyncMock()` for result backend | `test_isosurface.py`, `test_routes_frames.py` | `InMemoryResultBackend` (extracted to shared location in Phase 2) | +| `patch("StateFileSource.__call__", return_value={})` | `test_client_settings.py`, `test_client_api.py` | Real `StateFile(directory=tmp_path)` via constructor injection | +| `patch("zndraw.auth_utils.httpx.Client")` | `test_resolve_token.py` | Test against real server with `server_factory` | +| `patch("zndraw.cli_agent.auth.httpx.Client")` | `test_cli_auth.py`, `test_cli_agent/test_auth.py` | Test against real server | +| `patch("_is_url_healthy")` + `_is_pid_alive` | `test_state_file_source.py` (~19 occurrences) | Split: pure logic unit tests (no patches) + integration tests with `server_factory` | + +**`test_cli.py` evaluation:** This file has ~24 monkeypatch calls that mock `uvicorn.Server.run`, `wait_for_server_ready`, `upload_file`, `_acquire_admin_jwt`, etc. These are evaluated case-by-case during implementation: +- Patches for call-ordering tests (browser-before-upload) may stay if they test orchestration logic +- Patches that skip real service interactions should be converted to `server_factory` integration tests where feasible + +### Phase 5: Missing E2E Tests + +New test files using `server_factory`: + +1. **Client persistence E2E** — connect → write frames → disconnect → reconnect → verify data +2. **Guest auth E2E** — `POST /v1/auth/guest` → JWT → create room → write frames → read back +3. **StateFileSource integration** — start server → StateFileSource discovers it → resolves token → client connects + +(CLI upload E2E already added in PR #896) + +### Phase 6: Parametrize Opportunities + +Convert obvious candidates: +- `test_routes_geometries.py`: 404-assertion tests (room-not-found + geometry-not-found) → parametrized +- `test_routes_edit_lock.py`: 3 similar 403 authorization tests → parametrized +- `test_auth_endpoints.py`: 8 login/register tests → 2 parametrized tests + +## Out of Scope + +- Class-to-function test conversion (cosmetic) +- `zndraw_auth/conftest.py` refactoring (separate package) +- `zndraw_joblib/conftest.py` refactoring (well-designed fakes, except extracting `InMemoryResultBackend` to shared location) +- Subprocess-based worker tests (tracked in #898) + +## Success Criteria + +- No Redis dependency override uses `AsyncMock()`, `lambda: None`, or `lambda: (None)` anywhere in `tests/` +- No test file outside `conftest.py` defines any session fixture (including prefixed variants like `bm_session`, `el_session`) or any client fixture +- All 4 critical E2E gaps covered with real-server tests +- `tests/README.md` exists and documents the rules +- Every remaining `monkeypatch.setattr` / `@patch` call has a short inline comment explaining WHY it is justified +- All existing tests still pass From 0824db9ddb92d946484f538a03c724c3f04d498c Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 13:30:32 +0100 Subject: [PATCH 02/20] docs: add test suite guidelines (tests/README.md) --- tests/zndraw/README.md | 58 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/zndraw/README.md diff --git a/tests/zndraw/README.md b/tests/zndraw/README.md new file mode 100644 index 000000000..248dec525 --- /dev/null +++ b/tests/zndraw/README.md @@ -0,0 +1,58 @@ +# ZnDraw Test Suite Guidelines + +## Test Hierarchy + +1. **Route tests** -- `ASGITransport` + real Redis + real SQLite. + For HTTP request/response behavior. `MockSioServer` is acceptable + for verifying broadcast emissions. + +2. **Integration/E2E tests** -- Real uvicorn via `server_factory`. + For flows crossing process boundaries: Socket.IO, CLI client, + workers, StateFileSource resolution. + +3. **Unit tests** -- Pure logic with no I/O. Mocking allowed only here, + and only when better design can't avoid it. + +## Fixture Rules + +- All shared fixtures live in `conftest.py`. +- No per-file session/client/Redis fixture definitions. +- `helpers.py` functions are the ONLY way to create test data: + - `create_test_user_in_db(session, email)` -> `(User, token)` + - `create_test_room(session, user, description)` -> `Room` + - `auth_header(token)` -> `dict` + - `create_test_token(user)` -> `str` +- Fixture scoping: function-scoped for data, session-scoped only for + truly stateless infrastructure. + +## Mocking Rules + +### Banned + +- `lambda: None` or `lambda: (None)` for Redis +- `AsyncMock()` as a substitute for real Redis or result backends +- `patch("StateFileSource.__call__")` to hide token resolution +- `patch("httpx.Client")` to avoid real HTTP (use `server_factory`) +- `monkeypatch.setattr` on module-level functions to skip service + interactions when a real server could be used instead + +### Allowed + +- `@patch("os.kill")` -- prevents real process termination +- `monkeypatch.setattr("webbrowser.open", ...)` -- prevents browser +- `monkeypatch.setattr("time.sleep", ...)` -- prevents test delays +- `MockSioServer` in route tests -- lightweight fake for broadcasts +- `monkeypatch.setenv` / `monkeypatch.delenv` -- environment isolation + +### Last resort (requires inline comment explaining WHY) + +- `monkeypatch.setattr("zndraw.cli.StateFile", ...)` -- only if the + code cannot accept a StateFile parameter +- `patch("_is_pid_alive", ...)` -- only for pure-logic unit tests of + StateFileSource decision logic, not integration tests + +## Writing New Tests + +1. Plan what to test: which tier (route / E2E / unit), which fixtures. +2. RED/GREEN/REFACTOR: write failing test first, then implement. +3. Review against these guidelines before committing. From aea0d21c817981b33f90a3243c7f454bca0ca5fa Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:02:37 +0100 Subject: [PATCH 03/20] refactor: upgrade shared test fixtures with real Redis and InMemoryResultBackend Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/conftest.py | 51 +++++++++++++++++++++++++++----------- tests/zndraw/helpers.py | 53 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 14 deletions(-) diff --git a/tests/zndraw/conftest.py b/tests/zndraw/conftest.py index a9534f2b2..f2d010780 100644 --- a/tests/zndraw/conftest.py +++ b/tests/zndraw/conftest.py @@ -12,7 +12,7 @@ import pytest import pytest_asyncio import uvicorn -from helpers import MockSioServer, create_test_user_model +from helpers import InMemoryResultBackend, MockSioServer, create_test_user_model from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool @@ -81,30 +81,41 @@ async def session_fixture() -> AsyncIterator[AsyncSession]: @pytest_asyncio.fixture(name="client") -async def client_fixture(session: AsyncSession) -> AsyncIterator[AsyncClient]: - """Create an async test client with the session dependency overridden.""" +async def client_fixture( + session: AsyncSession, + redis_client, + mock_sio: MockSioServer, + frame_storage: FrameStorage, + result_backend: InMemoryResultBackend, +) -> AsyncIterator[AsyncClient]: + """Async test client with real Redis, real DB, MockSioServer. + + All route tests share this fixture. No per-file client definitions. + """ from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio + from zndraw.dependencies import ( + get_frame_storage, + get_joblib_settings, + get_redis, + get_result_backend, + get_tsio, + ) from zndraw_auth import get_session - - mock_sio = MockSioServer() + from zndraw_joblib.settings import JobLibSettings async def get_session_override() -> AsyncIterator[AsyncSession]: yield session - def get_sio_override() -> MockSioServer: - return mock_sio - - # Create test session_maker for Socket.IO handlers @asynccontextmanager async def test_session_maker(): yield session app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_redis] = lambda: ( - None - ) # tests that need redis override this - app.dependency_overrides[get_tsio] = get_sio_override + app.dependency_overrides[get_redis] = lambda: redis_client + app.dependency_overrides[get_tsio] = lambda: mock_sio + app.dependency_overrides[get_frame_storage] = lambda: frame_storage + app.dependency_overrides[get_result_backend] = lambda: result_backend + app.dependency_overrides[get_joblib_settings] = lambda: JobLibSettings() app.state.session_maker = test_session_maker app.state.settings = Settings() app.state.auth_settings = AuthSettings() @@ -128,6 +139,18 @@ async def test_user_fixture(session: AsyncSession) -> User: return user +@pytest.fixture(name="mock_sio") +def mock_sio_fixture() -> MockSioServer: + """MockSioServer for route tests — shared across client and test assertions.""" + return MockSioServer() + + +@pytest.fixture(name="result_backend") +def result_backend_fixture() -> InMemoryResultBackend: + """In-memory result backend for route tests.""" + return InMemoryResultBackend() + + # ============================================================================= # Redis Test Fixtures # ============================================================================= diff --git a/tests/zndraw/helpers.py b/tests/zndraw/helpers.py index 3b774442e..ae76b107f 100644 --- a/tests/zndraw/helpers.py +++ b/tests/zndraw/helpers.py @@ -4,6 +4,7 @@ Fixtures live in conftest.py. """ +import asyncio from typing import Any import msgpack @@ -152,3 +153,55 @@ async def get_session(self, sid: str) -> dict[str, Any]: async def save_session(self, sid: str, session: dict[str, Any]) -> None: self.sessions[sid] = session + + +class InMemoryResultBackend: + """In-memory result backend for testing. + + Drop-in replacement for the real Redis-based ResultBackend. + Extracted from zndraw_joblib/conftest.py for shared use. + """ + + def __init__(self) -> None: + self._store: dict[str, bytes] = {} + self._inflight: set[str] = set() + self._waiters: dict[str, list[asyncio.Event]] = {} + + async def store(self, key: str, data: bytes, _ttl: int) -> None: + self._store[key] = data + await self.notify_key(key) + + async def get(self, key: str) -> bytes | None: + return self._store.get(key) + + async def delete(self, key: str) -> None: + self._store.pop(key, None) + + async def acquire_inflight(self, key: str, _ttl: int) -> bool: + if key in self._inflight: + return False + self._inflight.add(key) + return True + + async def release_inflight(self, key: str) -> None: + self._inflight.discard(key) + + async def wait_for_key(self, key: str, timeout: float) -> bytes | None: + cached = self._store.get(key) + if cached is not None: + return cached + event = asyncio.Event() + self._waiters.setdefault(key, []).append(event) + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + return self._store.get(key) + except TimeoutError: + return None + finally: + waiters = self._waiters.get(key, []) + if event in waiters: + waiters.remove(event) + + async def notify_key(self, key: str) -> None: + for event in self._waiters.pop(key, []): + event.set() From c2cc72a08fa03701588f70d12da9d783f80ae80a Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:07:23 +0100 Subject: [PATCH 04/20] refactor: convert bookmarks/figures/selection_groups/presets to shared fixtures Replace per-file session/client/mock_sio fixtures and helper functions with shared conftest fixtures (client, session, mock_sio) and helpers module (create_test_user_in_db, create_test_room, auth_header). Convert MagicMock SIO assertions to MockSioServer emitted list pattern. Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/test_routes_bookmarks.py | 324 ++++-------- tests/zndraw/test_routes_figures.py | 345 +++++-------- tests/zndraw/test_routes_presets.py | 491 +++++++------------ tests/zndraw/test_routes_selection_groups.py | 272 +++------- 4 files changed, 476 insertions(+), 956 deletions(-) diff --git a/tests/zndraw/test_routes_bookmarks.py b/tests/zndraw/test_routes_bookmarks.py index dfe3138a2..8108695a0 100644 --- a/tests/zndraw/test_routes_bookmarks.py +++ b/tests/zndraw/test_routes_bookmarks.py @@ -1,131 +1,20 @@ """Tests for Bookmarks REST API endpoints.""" -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock - import pytest import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from helpers import ( + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, +) +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.config import Settings from zndraw.exceptions import BookmarkNotFound -from zndraw.models import MemberRole, Room, RoomBookmark, RoomMembership +from zndraw.models import RoomBookmark from zndraw.schemas import StatusResponse from zndraw.socket_events import BookmarksInvalidate -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Test Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="bm_session") -async def bm_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="bm_client") -async def bm_client_fixture( - bm_session: AsyncSession, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with dependencies overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield bm_session - - def get_sio_override() -> MagicMock: - return mock_sio - - # Mock Redis for WritableRoomDep (returns None = no edit lock) - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room async def _add_bookmark( @@ -136,11 +25,6 @@ async def _add_bookmark( await session.commit() -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - - # ============================================================================= # List Bookmarks Tests # ============================================================================= @@ -148,16 +32,16 @@ def _auth_header(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_list_bookmarks_returns_empty_initially( - bm_client: AsyncClient, - bm_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns empty bookmarks for new room.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.get( + response = await client.get( f"/v1/rooms/{room.id}/bookmarks", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["items"] == {} @@ -165,20 +49,20 @@ async def test_list_bookmarks_returns_empty_initially( @pytest.mark.asyncio async def test_list_bookmarks_returns_all_bookmarks( - bm_client: AsyncClient, - bm_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns all bookmarks.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_bookmark(bm_session, room.id, 0, "Start") - await _add_bookmark(bm_session, room.id, 5, "Middle") - await _add_bookmark(bm_session, room.id, 10, "End") + await _add_bookmark(session, room.id, 0, "Start") + await _add_bookmark(session, room.id, 5, "Middle") + await _add_bookmark(session, room.id, 10, "End") - response = await bm_client.get( + response = await client.get( f"/v1/rooms/{room.id}/bookmarks", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -194,18 +78,18 @@ async def test_list_bookmarks_returns_all_bookmarks( @pytest.mark.asyncio async def test_get_bookmark_returns_label( - bm_client: AsyncClient, - bm_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET single bookmark returns label.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_bookmark(bm_session, room.id, 5, "Important Frame") + await _add_bookmark(session, room.id, 5, "Important Frame") - response = await bm_client.get( + response = await client.get( f"/v1/rooms/{room.id}/bookmarks/5", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -215,16 +99,16 @@ async def test_get_bookmark_returns_label( @pytest.mark.asyncio async def test_get_bookmark_returns_404_for_nonexistent( - bm_client: AsyncClient, - bm_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns 404 for nonexistent bookmark.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.get( + response = await client.get( f"/v1/rooms/{room.id}/bookmarks/999", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "bookmark-not-found" in response.json()["type"] @@ -237,18 +121,18 @@ async def test_get_bookmark_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_set_bookmark_creates_bookmark( - bm_client: AsyncClient, - bm_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test PUT creates a new bookmark.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/3", json={"label": "New Bookmark"}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -256,47 +140,45 @@ async def test_set_bookmark_creates_bookmark( assert data["label"] == "New Bookmark" # Verify persisted in DB - row = await bm_session.get(RoomBookmark, (room.id, 3)) + row = await session.get(RoomBookmark, (room.id, 3)) assert row is not None assert row.label == "New Bookmark" @pytest.mark.asyncio async def test_set_bookmark_broadcasts( - bm_client: AsyncClient, - bm_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test PUT broadcasts bookmarks:invalidate event.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await bm_client.put( + await client.put( f"/v1/rooms/{room.id}/bookmarks/3", json={"label": "Test"}, - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, BookmarksInvalidate) - assert call_args[1]["room"] == f"room:{room.id}" + assert len(mock_sio.emitted) == 1 + assert mock_sio.emitted[0]["event"] == "bookmarks_invalidate" + assert mock_sio.emitted[0]["room"] == f"room:{room.id}" @pytest.mark.asyncio async def test_set_bookmark_rejects_empty_label( - bm_client: AsyncClient, - bm_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT rejects empty label.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/3", json={"label": ""}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 # Validation error @@ -308,40 +190,40 @@ async def test_set_bookmark_rejects_empty_label( @pytest.mark.asyncio async def test_delete_bookmark_removes_bookmark( - bm_client: AsyncClient, - bm_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test DELETE removes a bookmark.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_bookmark(bm_session, room.id, 5, "To Delete") + await _add_bookmark(session, room.id, 5, "To Delete") - response = await bm_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/bookmarks/5", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) # Verify deleted from DB - row = await bm_session.get(RoomBookmark, (room.id, 5)) + row = await session.get(RoomBookmark, (room.id, 5)) assert row is None @pytest.mark.asyncio async def test_delete_nonexistent_bookmark_returns_404( - bm_client: AsyncClient, - bm_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test DELETE on nonexistent bookmark returns 404.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/bookmarks/999", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert response.json()["type"] == BookmarkNotFound.type_uri() @@ -349,25 +231,23 @@ async def test_delete_nonexistent_bookmark_returns_404( @pytest.mark.asyncio async def test_delete_bookmark_broadcasts( - bm_client: AsyncClient, - bm_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test DELETE broadcasts bookmarks:invalidate event.""" - user, token = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_bookmark(bm_session, room.id, 5, "Test") + await _add_bookmark(session, room.id, 5, "Test") - await bm_client.delete( + await client.delete( f"/v1/rooms/{room.id}/bookmarks/5", - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, BookmarksInvalidate) + assert len(mock_sio.emitted) == 1 + assert mock_sio.emitted[0]["event"] == "bookmarks_invalidate" # ============================================================================= @@ -377,25 +257,25 @@ async def test_delete_bookmark_broadcasts( @pytest.mark.asyncio async def test_list_bookmarks_requires_auth( - bm_client: AsyncClient, bm_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET without auth returns 401.""" - user, _ = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.get(f"/v1/rooms/{room.id}/bookmarks") + response = await client.get(f"/v1/rooms/{room.id}/bookmarks") assert response.status_code == 401 @pytest.mark.asyncio async def test_set_bookmark_requires_auth( - bm_client: AsyncClient, bm_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT without auth returns 401.""" - user, _ = await _create_user(bm_session) - room = await _create_room(bm_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await bm_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/3", json={"label": "Test"}, ) @@ -409,14 +289,14 @@ async def test_set_bookmark_requires_auth( @pytest.mark.asyncio async def test_list_bookmarks_returns_404_for_nonexistent_room( - bm_client: AsyncClient, bm_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET for non-existent room returns 404.""" - _, token = await _create_user(bm_session) + _, token = await create_test_user_in_db(session) - response = await bm_client.get( + response = await client.get( "/v1/rooms/99999/bookmarks", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -424,15 +304,15 @@ async def test_list_bookmarks_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_set_bookmark_returns_404_for_nonexistent_room( - bm_client: AsyncClient, bm_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT for non-existent room returns 404.""" - _, token = await _create_user(bm_session) + _, token = await create_test_user_in_db(session) - response = await bm_client.put( + response = await client.put( "/v1/rooms/99999/bookmarks/3", json={"label": "Test"}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] diff --git a/tests/zndraw/test_routes_figures.py b/tests/zndraw/test_routes_figures.py index 8599142bc..61847decc 100644 --- a/tests/zndraw/test_routes_figures.py +++ b/tests/zndraw/test_routes_figures.py @@ -1,131 +1,19 @@ """Tests for Figures REST API endpoints.""" -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock - import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.config import Settings +from helpers import ( + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, +) +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + from zndraw.exceptions import FigureNotFound -from zndraw.models import MemberRole, Room, RoomFigure, RoomMembership +from zndraw.models import RoomFigure from zndraw.schemas import FigureData, StatusResponse from zndraw.socket_events import FigureInvalidate -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Test Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="fig_session") -async def fig_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="fig_mock_sio") -async def fig_mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="fig_client") -async def fig_client_fixture( - fig_session: AsyncSession, - fig_mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with dependencies overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield fig_session - - def get_sio_override() -> MagicMock: - return fig_mock_sio - - # Mock Redis for WritableRoomDep (returns None = no edit lock) - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room async def _add_figure( @@ -136,11 +24,6 @@ async def _add_figure( await session.commit() -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - - # ============================================================================= # List Figures Tests # ============================================================================= @@ -148,16 +31,16 @@ def _auth_header(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_list_figures_returns_empty_initially( - fig_client: AsyncClient, - fig_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns empty figures list for new room.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fig_client.get( + response = await client.get( f"/v1/rooms/{room.id}/figures", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["items"] == [] @@ -165,19 +48,19 @@ async def test_list_figures_returns_empty_initially( @pytest.mark.asyncio async def test_list_figures_returns_all_keys( - fig_client: AsyncClient, - fig_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns all figure keys.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_figure(fig_session, room.id, "chart1", '{"data": []}') - await _add_figure(fig_session, room.id, "chart2", '{"data": []}') + await _add_figure(session, room.id, "chart1", '{"data": []}') + await _add_figure(session, room.id, "chart2", '{"data": []}') - response = await fig_client.get( + response = await client.get( f"/v1/rooms/{room.id}/figures", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -191,20 +74,20 @@ async def test_list_figures_returns_all_keys( @pytest.mark.asyncio async def test_get_figure_returns_data( - fig_client: AsyncClient, - fig_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET single figure returns the figure data.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_figure( - fig_session, room.id, "my_chart", '{"data": [1, 2, 3], "layout": {}}' + session, room.id, "my_chart", '{"data": [1, 2, 3], "layout": {}}' ) - response = await fig_client.get( + response = await client.get( f"/v1/rooms/{room.id}/figures/my_chart", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -215,16 +98,16 @@ async def test_get_figure_returns_data( @pytest.mark.asyncio async def test_get_figure_returns_404_for_nonexistent( - fig_client: AsyncClient, - fig_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns 404 for nonexistent figure.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fig_client.get( + response = await client.get( f"/v1/rooms/{room.id}/figures/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "figure-not-found" in response.json()["type"] @@ -237,21 +120,21 @@ async def test_get_figure_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_create_figure_stores_data( - fig_client: AsyncClient, - fig_session: AsyncSession, - fig_mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test POST creates a new figure.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) figure_data = FigureData( type="plotly", data='{"data": [], "layout": {"title": "Test"}}' ) - response = await fig_client.post( + response = await client.post( f"/v1/rooms/{room.id}/figures/new_chart", json={"figure": figure_data.model_dump()}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 body = response.json() @@ -259,27 +142,27 @@ async def test_create_figure_stores_data( assert body["created"] is True # Verify persisted in DB - row = await fig_session.get(RoomFigure, (room.id, "new_chart")) + row = await session.get(RoomFigure, (room.id, "new_chart")) assert row is not None assert row.type == "plotly" @pytest.mark.asyncio async def test_update_figure_overwrites_data( - fig_client: AsyncClient, - fig_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test POST to existing key overwrites data.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_figure(fig_session, room.id, "chart", '{"version": 1}') + await _add_figure(session, room.id, "chart", '{"version": 1}') new_data = FigureData(type="plotly", data='{"version": 2}') - response = await fig_client.post( + response = await client.post( f"/v1/rooms/{room.id}/figures/chart", json={"figure": new_data.model_dump()}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 body = response.json() @@ -287,33 +170,31 @@ async def test_update_figure_overwrites_data( assert body["created"] is False # Verify updated in DB - row = await fig_session.get(RoomFigure, (room.id, "chart")) + row = await session.get(RoomFigure, (room.id, "chart")) assert row is not None assert "2" in row.data @pytest.mark.asyncio async def test_create_figure_broadcasts( - fig_client: AsyncClient, - fig_session: AsyncSession, - fig_mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test POST broadcasts figure:invalidate event.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) figure_data = FigureData(type="plotly", data="{}") - await fig_client.post( + await client.post( f"/v1/rooms/{room.id}/figures/chart", json={"figure": figure_data.model_dump()}, - headers=_auth_header(token), + headers=auth_header(token), ) - fig_mock_sio.emit.assert_called() - call_args = fig_mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, FigureInvalidate) - assert call_args[1]["room"] == f"room:{room.id}" + assert len(mock_sio.emitted) == 1 + assert mock_sio.emitted[0]["event"] == "figure_invalidate" + assert mock_sio.emitted[0]["room"] == f"room:{room.id}" # ============================================================================= @@ -323,63 +204,61 @@ async def test_create_figure_broadcasts( @pytest.mark.asyncio async def test_delete_figure_removes_data( - fig_client: AsyncClient, - fig_session: AsyncSession, - fig_mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test DELETE removes a figure.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_figure(fig_session, room.id, "to_delete", "{}") + await _add_figure(session, room.id, "to_delete", "{}") - response = await fig_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/figures/to_delete", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) # Verify deleted from DB - row = await fig_session.get(RoomFigure, (room.id, "to_delete")) + row = await session.get(RoomFigure, (room.id, "to_delete")) assert row is None @pytest.mark.asyncio async def test_delete_figure_broadcasts( - fig_client: AsyncClient, - fig_session: AsyncSession, - fig_mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test DELETE broadcasts figure:invalidate event.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_figure(fig_session, room.id, "chart", "{}") + await _add_figure(session, room.id, "chart", "{}") - await fig_client.delete( + await client.delete( f"/v1/rooms/{room.id}/figures/chart", - headers=_auth_header(token), + headers=auth_header(token), ) - fig_mock_sio.emit.assert_called() - call_args = fig_mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, FigureInvalidate) + assert len(mock_sio.emitted) == 1 + assert mock_sio.emitted[0]["event"] == "figure_invalidate" @pytest.mark.asyncio async def test_delete_nonexistent_figure_returns_404( - fig_client: AsyncClient, - fig_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test DELETE on nonexistent figure returns 404.""" - user, token = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fig_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/figures/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert response.json()["type"] == FigureNotFound.type_uri() @@ -392,26 +271,26 @@ async def test_delete_nonexistent_figure_returns_404( @pytest.mark.asyncio async def test_list_figures_public( - fig_client: AsyncClient, fig_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET without auth succeeds (public endpoint).""" - user, _ = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fig_client.get(f"/v1/rooms/{room.id}/figures") + response = await client.get(f"/v1/rooms/{room.id}/figures") assert response.status_code == 200 @pytest.mark.asyncio async def test_create_figure_requires_auth( - fig_client: AsyncClient, fig_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test POST without auth returns 401.""" - user, _ = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) figure_data = FigureData(type="plotly", data="{}") - response = await fig_client.post( + response = await client.post( f"/v1/rooms/{room.id}/figures/chart", json={"figure": figure_data.model_dump()}, ) @@ -420,13 +299,13 @@ async def test_create_figure_requires_auth( @pytest.mark.asyncio async def test_delete_figure_requires_auth( - fig_client: AsyncClient, fig_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE without auth returns 401.""" - user, _ = await _create_user(fig_session) - room = await _create_room(fig_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fig_client.delete(f"/v1/rooms/{room.id}/figures/chart") + response = await client.delete(f"/v1/rooms/{room.id}/figures/chart") assert response.status_code == 401 @@ -437,14 +316,14 @@ async def test_delete_figure_requires_auth( @pytest.mark.asyncio async def test_list_figures_returns_404_for_nonexistent_room( - fig_client: AsyncClient, fig_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET for non-existent room returns 404.""" - _, token = await _create_user(fig_session) + _, token = await create_test_user_in_db(session) - response = await fig_client.get( + response = await client.get( "/v1/rooms/99999/figures", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -452,16 +331,16 @@ async def test_list_figures_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_create_figure_returns_404_for_nonexistent_room( - fig_client: AsyncClient, fig_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test POST for non-existent room returns 404.""" - _, token = await _create_user(fig_session) + _, token = await create_test_user_in_db(session) figure_data = FigureData(type="plotly", data="{}") - response = await fig_client.post( + response = await client.post( "/v1/rooms/99999/figures/chart", json={"figure": figure_data.model_dump()}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -469,14 +348,14 @@ async def test_create_figure_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_delete_figure_returns_404_for_nonexistent_room( - fig_client: AsyncClient, fig_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE for non-existent room returns 404.""" - _, token = await _create_user(fig_session) + _, token = await create_test_user_in_db(session) - response = await fig_client.delete( + response = await client.delete( "/v1/rooms/99999/figures/chart", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] diff --git a/tests/zndraw/test_routes_presets.py b/tests/zndraw/test_routes_presets.py index 6f356f475..0e4563009 100644 --- a/tests/zndraw/test_routes_presets.py +++ b/tests/zndraw/test_routes_presets.py @@ -1,139 +1,26 @@ """Tests for Presets REST API endpoints.""" import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.config import Settings +from helpers import ( + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, +) +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + from zndraw.exceptions import ( InvalidPresetRule, PresetAlreadyExists, PresetNotFound, ) from zndraw.models import ( - MemberRole, - Room, RoomGeometry, - RoomMembership, RoomPreset, ) -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Test Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="pr_session") -async def pr_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="pr_client") -async def pr_client_fixture( - pr_session: AsyncSession, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with dependencies overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield pr_session - - def get_sio_override() -> MagicMock: - return mock_sio - - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room async def _add_preset( @@ -174,11 +61,6 @@ async def _add_geometry( await session.commit() -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - - # ============================================================================= # List Presets Tests # ============================================================================= @@ -186,16 +68,16 @@ def _auth_header(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_list_presets_includes_bundled( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """GET /presets includes bundled presets even with no DB rows.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await pr_client.get( + response = await client.get( f"/v1/rooms/{room.id}/presets", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 items = response.json()["items"] @@ -206,19 +88,19 @@ async def test_list_presets_includes_bundled( @pytest.mark.asyncio async def test_list_presets_db_overrides_bundled( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """DB preset with same name as bundled takes precedence in list.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Override the bundled "matt" preset at room level - await _add_preset(pr_session, room.id, "matt", "My custom matt") + await _add_preset(session, room.id, "matt", "My custom matt") - response = await pr_client.get( + response = await client.get( f"/v1/rooms/{room.id}/presets", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 items = response.json()["items"] @@ -234,19 +116,19 @@ async def test_list_presets_db_overrides_bundled( @pytest.mark.asyncio async def test_list_presets_returns_all( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """GET /presets returns all presets in the room.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_preset(pr_session, room.id, "custom-a", "Custom A") - await _add_preset(pr_session, room.id, "custom-b", "Custom B") + await _add_preset(session, room.id, "custom-a", "Custom A") + await _add_preset(session, room.id, "custom-b", "Custom B") - response = await pr_client.get( + response = await client.get( f"/v1/rooms/{room.id}/presets", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 items = response.json()["items"] @@ -257,15 +139,15 @@ async def test_list_presets_returns_all( @pytest.mark.asyncio async def test_list_presets_room_not_found( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """GET /presets returns 404 for nonexistent room.""" - _user, token = await _create_user(pr_session) + _user, token = await create_test_user_in_db(session) - response = await pr_client.get( + response = await client.get( "/v1/rooms/nonexistent/presets", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -277,16 +159,16 @@ async def test_list_presets_room_not_found( @pytest.mark.asyncio async def test_get_bundled_preset( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """GET /presets/{name} returns a bundled preset when no DB row exists.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await pr_client.get( + response = await client.get( f"/v1/rooms/{room.id}/presets/matt", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -296,19 +178,19 @@ async def test_get_bundled_preset( @pytest.mark.asyncio async def test_get_preset_returns_data( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """GET /presets/{name} returns the preset.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) rules = [{"pattern": "fog", "config": {"active": True}}] - await _add_preset(pr_session, room.id, "test-preset", "Test", rules) + await _add_preset(session, room.id, "test-preset", "Test", rules) - response = await pr_client.get( + response = await client.get( f"/v1/rooms/{room.id}/presets/test-preset", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -320,16 +202,16 @@ async def test_get_preset_returns_data( @pytest.mark.asyncio async def test_get_preset_returns_404_for_nonexistent( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """GET /presets/{name} returns 404 for nonexistent preset.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await pr_client.get( + response = await client.get( f"/v1/rooms/{room.id}/presets/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert PresetNotFound.type_uri() in response.json()["type"] @@ -342,12 +224,12 @@ async def test_get_preset_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_create_preset( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST /presets creates a new preset.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) body = { "name": "my-preset", @@ -357,10 +239,10 @@ async def test_create_preset( ], } - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets", json=body, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 data = response.json() @@ -368,28 +250,28 @@ async def test_create_preset( assert data["description"] == "A test preset" # Verify persisted in DB - row = await pr_session.get(RoomPreset, (room.id, "my-preset")) + row = await session.get(RoomPreset, (room.id, "my-preset")) assert row is not None assert row.description == "A test preset" @pytest.mark.asyncio async def test_create_preset_returns_409_if_exists( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST /presets returns 409 if preset already exists.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_preset(pr_session, room.id, "existing") + await _add_preset(session, room.id, "existing") body = {"name": "existing", "rules": []} - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets", json=body, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 409 assert PresetAlreadyExists.type_uri() in response.json()["type"] @@ -397,12 +279,12 @@ async def test_create_preset_returns_409_if_exists( @pytest.mark.asyncio async def test_create_preset_validates_geometry_type( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST /presets returns 422 for unknown geometry_type in rules.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) body = { "name": "invalid", @@ -415,10 +297,10 @@ async def test_create_preset_validates_geometry_type( ], } - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets", json=body, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 assert InvalidPresetRule.type_uri() in response.json()["type"] @@ -426,12 +308,12 @@ async def test_create_preset_validates_geometry_type( @pytest.mark.asyncio async def test_create_preset_validates_config_keys( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST /presets returns 422 for invalid config keys on known type.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) body = { "name": "bad-keys", @@ -444,10 +326,10 @@ async def test_create_preset_validates_config_keys( ], } - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets", json=body, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 assert InvalidPresetRule.type_uri() in response.json()["type"] @@ -460,12 +342,12 @@ async def test_create_preset_validates_config_keys( @pytest.mark.asyncio async def test_put_preset_creates_new( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """PUT /presets/{name} creates a preset if it doesn't exist.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) body = { "name": "new-preset", @@ -473,28 +355,28 @@ async def test_put_preset_creates_new( "rules": [], } - response = await pr_client.put( + response = await client.put( f"/v1/rooms/{room.id}/presets/new-preset", json=body, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["name"] == "new-preset" - row = await pr_session.get(RoomPreset, (room.id, "new-preset")) + row = await session.get(RoomPreset, (room.id, "new-preset")) assert row is not None @pytest.mark.asyncio async def test_put_preset_updates_existing( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """PUT /presets/{name} updates an existing preset.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_preset(pr_session, room.id, "existing", "Old description") + await _add_preset(session, room.id, "existing", "Old description") body = { "name": "existing", @@ -502,18 +384,18 @@ async def test_put_preset_updates_existing( "rules": [{"pattern": "*", "config": {"active": False}}], } - response = await pr_client.put( + response = await client.put( f"/v1/rooms/{room.id}/presets/existing", json=body, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["description"] == "New description" - await pr_session.refresh( - await pr_session.get(RoomPreset, (room.id, "existing")) # type: ignore[arg-type] + await session.refresh( + await session.get(RoomPreset, (room.id, "existing")) # type: ignore[arg-type] ) - row = await pr_session.get(RoomPreset, (room.id, "existing")) + row = await session.get(RoomPreset, (room.id, "existing")) assert row is not None assert row.description == "New description" @@ -525,38 +407,38 @@ async def test_put_preset_updates_existing( @pytest.mark.asyncio async def test_delete_preset( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """DELETE /presets/{name} removes the preset.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_preset(pr_session, room.id, "to-delete") + await _add_preset(session, room.id, "to-delete") - response = await pr_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/presets/to-delete", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["status"] == "ok" - row = await pr_session.get(RoomPreset, (room.id, "to-delete")) + row = await session.get(RoomPreset, (room.id, "to-delete")) assert row is None @pytest.mark.asyncio async def test_delete_preset_returns_404_for_nonexistent( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """DELETE /presets/{name} returns 404 if preset doesn't exist.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await pr_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/presets/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert PresetNotFound.type_uri() in response.json()["type"] @@ -564,16 +446,16 @@ async def test_delete_preset_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_delete_bundled_preset_returns_404( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """DELETE on a bundled-only preset returns 404 (no DB row to delete).""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await pr_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/presets/matt", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -585,32 +467,32 @@ async def test_delete_bundled_preset_returns_404( @pytest.mark.asyncio async def test_apply_bundled_preset( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST /presets/{name}/apply works for bundled presets without DB row.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add a fog geometry so the bundled "matt" preset has something to match await _add_geometry( - pr_session, + session, room.id, "fog", "Fog", {"active": False, "near": 10.0, "far": 100.0, "color": "#000000"}, ) - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets/matt/apply", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 result = response.json() assert "fog" in result["geometries_updated"] # Verify the fog config was updated per the bundled matt preset - geom = await pr_session.get(RoomGeometry, (room.id, "fog")) + geom = await session.get(RoomGeometry, (room.id, "fog")) assert geom is not None config = json.loads(geom.config) assert config["active"] is False @@ -623,17 +505,17 @@ async def test_apply_bundled_preset( @pytest.mark.asyncio async def test_apply_preset_updates_matching_geometries( - pr_client: AsyncClient, - pr_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """POST /presets/{name}/apply deep-merges config into matching geometries.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add a Fog geometry await _add_geometry( - pr_session, + session, room.id, "fog", "Fog", @@ -642,18 +524,18 @@ async def test_apply_preset_updates_matching_geometries( # Create preset that modifies fog rules = [{"pattern": "fog", "config": {"active": True, "color": "#ffffff"}}] - await _add_preset(pr_session, room.id, "test-apply", rules=rules) + await _add_preset(session, room.id, "test-apply", rules=rules) - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets/test-apply/apply", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 result = response.json() assert "fog" in result["geometries_updated"] # Verify geometry was updated in DB - geom = await pr_session.get(RoomGeometry, (room.id, "fog")) + geom = await session.get(RoomGeometry, (room.id, "fog")) assert geom is not None config = json.loads(geom.config) assert config["active"] is True @@ -665,16 +547,16 @@ async def test_apply_preset_updates_matching_geometries( @pytest.mark.asyncio async def test_apply_preset_deep_merges_config( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Apply performs deep merge — overrides specified fields, preserves others.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Fog geometry with multiple fields await _add_geometry( - pr_session, + session, room.id, "fog", "Fog", @@ -683,15 +565,15 @@ async def test_apply_preset_deep_merges_config( # Preset changes near and color, preserves active and far rules = [{"pattern": "fog", "config": {"near": 100.0, "color": "#ffffff"}}] - await _add_preset(pr_session, room.id, "adjust-fog", rules=rules) + await _add_preset(session, room.id, "adjust-fog", rules=rules) - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets/adjust-fog/apply", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 - geom = await pr_session.get(RoomGeometry, (room.id, "fog")) + geom = await session.get(RoomGeometry, (room.id, "fog")) assert geom is not None config = json.loads(geom.config) assert config["near"] == 100.0 @@ -703,18 +585,18 @@ async def test_apply_preset_deep_merges_config( @pytest.mark.asyncio async def test_apply_preset_filters_by_geometry_type( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Apply only targets geometries matching both pattern and geometry_type.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - pr_session, room.id, "main-light", "DirectionalLight", {"intensity": 1.0} + session, room.id, "main-light", "DirectionalLight", {"intensity": 1.0} ) await _add_geometry( - pr_session, room.id, "ambient-light", "AmbientLight", {"intensity": 0.5} + session, room.id, "ambient-light", "AmbientLight", {"intensity": 0.5} ) # Only target AmbientLight, even though both match *light* @@ -725,11 +607,11 @@ async def test_apply_preset_filters_by_geometry_type( "config": {"intensity": 0.2}, } ] - await _add_preset(pr_session, room.id, "dim-ambient", rules=rules) + await _add_preset(session, room.id, "dim-ambient", rules=rules) - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets/dim-ambient/apply", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 result = response.json() @@ -737,58 +619,55 @@ async def test_apply_preset_filters_by_geometry_type( assert "main-light" not in result["geometries_updated"] # Verify ambient changed - ambient = await pr_session.get(RoomGeometry, (room.id, "ambient-light")) + ambient = await session.get(RoomGeometry, (room.id, "ambient-light")) assert ambient is not None assert json.loads(ambient.config)["intensity"] == 0.2 # Verify directional unchanged - directional = await pr_session.get(RoomGeometry, (room.id, "main-light")) + directional = await session.get(RoomGeometry, (room.id, "main-light")) assert directional is not None assert json.loads(directional.config)["intensity"] == 1.0 @pytest.mark.asyncio async def test_apply_preset_emits_geometry_invalidate( - pr_client: AsyncClient, - pr_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Apply emits GeometryInvalidate for each updated geometry.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(pr_session, room.id, "fog", "Fog", {"active": False}) + await _add_geometry(session, room.id, "fog", "Fog", {"active": False}) rules = [{"pattern": "fog", "config": {"active": True}}] - await _add_preset(pr_session, room.id, "enable-fog", rules=rules) + await _add_preset(session, room.id, "enable-fog", rules=rules) - await pr_client.post( + await client.post( f"/v1/rooms/{room.id}/presets/enable-fog/apply", - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - from zndraw.socket_events import GeometryInvalidate - - assert isinstance(model, GeometryInvalidate) - assert model.key == "fog" - assert model.operation == "set" + assert len(mock_sio.emitted) >= 1 + emitted = mock_sio.emitted[0] + assert emitted["event"] == "geometry_invalidate" + assert emitted["data"]["key"] == "fog" + assert emitted["data"]["operation"] == "set" @pytest.mark.asyncio async def test_apply_preset_returns_404_for_nonexistent( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST /presets/{name}/apply returns 404 for nonexistent preset.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets/nonexistent/apply", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert PresetNotFound.type_uri() in response.json()["type"] @@ -796,21 +675,21 @@ async def test_apply_preset_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_apply_preset_skips_non_matching_geometries( - pr_client: AsyncClient, - pr_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Apply returns empty list when no geometries match the pattern.""" - user, token = await _create_user(pr_session) - room = await _create_room(pr_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(pr_session, room.id, "particles", "Sphere", {"size": 1.0}) + await _add_geometry(session, room.id, "particles", "Sphere", {"size": 1.0}) rules = [{"pattern": "fog", "config": {"active": True}}] - await _add_preset(pr_session, room.id, "fog-only", rules=rules) + await _add_preset(session, room.id, "fog-only", rules=rules) - response = await pr_client.post( + response = await client.post( f"/v1/rooms/{room.id}/presets/fog-only/apply", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["geometries_updated"] == [] @@ -823,8 +702,8 @@ async def test_apply_preset_skips_non_matching_geometries( @pytest.mark.asyncio async def test_list_presets_requires_auth( - pr_client: AsyncClient, + client: AsyncClient, ) -> None: """GET /presets returns 401 without auth.""" - response = await pr_client.get("/v1/rooms/any-room/presets") + response = await client.get("/v1/rooms/any-room/presets") assert response.status_code == 401 diff --git a/tests/zndraw/test_routes_selection_groups.py b/tests/zndraw/test_routes_selection_groups.py index 4c53a36bf..7bf037b9a 100644 --- a/tests/zndraw/test_routes_selection_groups.py +++ b/tests/zndraw/test_routes_selection_groups.py @@ -1,132 +1,21 @@ """Tests for Selection Groups REST API endpoints.""" import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.config import Settings +from helpers import ( + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, +) +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + from zndraw.exceptions import SelectionGroupNotFound -from zndraw.models import MemberRole, Room, RoomMembership, SelectionGroup +from zndraw.models import SelectionGroup from zndraw.schemas import StatusResponse from zndraw.socket_events import SelectionGroupsInvalidate -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Test Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="sg_session") -async def sg_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="sg_client") -async def sg_client_fixture( - sg_session: AsyncSession, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with dependencies overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield sg_session - - def get_sio_override() -> MagicMock: - return mock_sio - - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - mock_redis.hget = AsyncMock(return_value=None) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room async def _add_selection_group( @@ -146,11 +35,6 @@ async def _add_selection_group( await session.commit() -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - - # ============================================================================= # List Selection Groups Tests # ============================================================================= @@ -158,16 +42,16 @@ def _auth_header(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_list_selection_groups_returns_empty( - sg_client: AsyncClient, - sg_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns empty groups for new room.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await sg_client.get( + response = await client.get( f"/v1/rooms/{room.id}/selection-groups", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["items"] == {} @@ -175,21 +59,21 @@ async def test_list_selection_groups_returns_empty( @pytest.mark.asyncio async def test_list_selection_groups_returns_all( - sg_client: AsyncClient, - sg_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns all stored groups.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) group_a = {"sphere": [1, 2], "cube": [3]} group_b = {"sphere": [4, 5]} - await _add_selection_group(sg_session, room.id, "group_a", group_a) - await _add_selection_group(sg_session, room.id, "group_b", group_b) + await _add_selection_group(session, room.id, "group_a", group_a) + await _add_selection_group(session, room.id, "group_b", group_b) - response = await sg_client.get( + response = await client.get( f"/v1/rooms/{room.id}/selection-groups", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -205,16 +89,16 @@ async def test_list_selection_groups_returns_all( @pytest.mark.asyncio async def test_get_nonexistent_selection_group_returns_404( - sg_client: AsyncClient, - sg_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET group returns 404 when not exists.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await sg_client.get( + response = await client.get( f"/v1/rooms/{room.id}/selection-groups/mygroup", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert response.json()["type"] == SelectionGroupNotFound.type_uri() @@ -222,19 +106,19 @@ async def test_get_nonexistent_selection_group_returns_404( @pytest.mark.asyncio async def test_get_selection_group_returns_stored( - sg_client: AsyncClient, - sg_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET group returns stored selections.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) group_data = {"sphere": [1, 2], "cube": [3, 4]} - await _add_selection_group(sg_session, room.id, "mygroup", group_data) + await _add_selection_group(session, room.id, "mygroup", group_data) - response = await sg_client.get( + response = await client.get( f"/v1/rooms/{room.id}/selection-groups/mygroup", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["group"] == group_data @@ -247,49 +131,47 @@ async def test_get_selection_group_returns_stored( @pytest.mark.asyncio async def test_update_selection_group_stores_data( - sg_client: AsyncClient, - sg_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test PUT group stores selection data.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) group_data = {"sphere": [0, 1], "cube": [2, 3]} - response = await sg_client.put( + response = await client.put( f"/v1/rooms/{room.id}/selection-groups/mygroup", json={"selections": group_data}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["status"] == "ok" # Verify stored in DB - row = await sg_session.get(SelectionGroup, (room.id, "mygroup")) + row = await session.get(SelectionGroup, (room.id, "mygroup")) assert row is not None assert json.loads(row.selections) == group_data @pytest.mark.asyncio async def test_update_selection_group_broadcasts( - sg_client: AsyncClient, - sg_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test PUT group broadcasts selection_groups:invalidate event.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await sg_client.put( + await client.put( f"/v1/rooms/{room.id}/selection-groups/mygroup", json={"selections": {"sphere": [0]}}, - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, SelectionGroupsInvalidate) + assert len(mock_sio.emitted) == 1 + assert mock_sio.emitted[0]["event"] == "selection_groups_invalidate" # ============================================================================= @@ -299,16 +181,16 @@ async def test_update_selection_group_broadcasts( @pytest.mark.asyncio async def test_delete_nonexistent_selection_group_returns_404( - sg_client: AsyncClient, - sg_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test DELETE on nonexistent selection group returns 404.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await sg_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/selection-groups/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert response.json()["type"] == SelectionGroupNotFound.type_uri() @@ -316,25 +198,25 @@ async def test_delete_nonexistent_selection_group_returns_404( @pytest.mark.asyncio async def test_delete_selection_group( - sg_client: AsyncClient, - sg_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """Test DELETE removes group.""" - user, token = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_selection_group(sg_session, room.id, "mygroup", {}) + await _add_selection_group(session, room.id, "mygroup", {}) - response = await sg_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/selection-groups/mygroup", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) # Verify deleted from DB - row = await sg_session.get(SelectionGroup, (room.id, "mygroup")) + row = await session.get(SelectionGroup, (room.id, "mygroup")) assert row is None @@ -345,13 +227,13 @@ async def test_delete_selection_group( @pytest.mark.asyncio async def test_list_selection_groups_requires_auth( - sg_client: AsyncClient, sg_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET without auth returns 401.""" - user, _ = await _create_user(sg_session) - room = await _create_room(sg_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await sg_client.get(f"/v1/rooms/{room.id}/selection-groups") + response = await client.get(f"/v1/rooms/{room.id}/selection-groups") assert response.status_code == 401 @@ -362,14 +244,14 @@ async def test_list_selection_groups_requires_auth( @pytest.mark.asyncio async def test_list_selection_groups_returns_404_for_nonexistent_room( - sg_client: AsyncClient, sg_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET for non-existent room returns 404.""" - _, token = await _create_user(sg_session) + _, token = await create_test_user_in_db(session) - response = await sg_client.get( + response = await client.get( "/v1/rooms/99999/selection-groups", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] From c577eaa06bc595bc9a111c355c24d4f34b8f3a30 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:10:42 +0100 Subject: [PATCH 05/20] refactor: convert bookmarks/figures/selection_groups/presets to shared fixtures - Delete per-file session/client/redis/sio fixtures and local helpers - Use shared client, session, mock_sio from conftest.py - Convert SIO assertions from MagicMock to MockSioServer.emitted pattern - Remove unused mock_sio params from tests that don't assert emissions - Remove unused event model imports (BookmarksInvalidate, etc.) Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_routes_bookmarks.py | 4 ---- tests/zndraw/test_routes_figures.py | 3 --- tests/zndraw/test_routes_presets.py | 3 +-- tests/zndraw/test_routes_selection_groups.py | 3 --- 4 files changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/zndraw/test_routes_bookmarks.py b/tests/zndraw/test_routes_bookmarks.py index 8108695a0..4411e12b5 100644 --- a/tests/zndraw/test_routes_bookmarks.py +++ b/tests/zndraw/test_routes_bookmarks.py @@ -1,7 +1,6 @@ """Tests for Bookmarks REST API endpoints.""" import pytest -import pytest_asyncio from helpers import ( MockSioServer, auth_header, @@ -14,7 +13,6 @@ from zndraw.exceptions import BookmarkNotFound from zndraw.models import RoomBookmark from zndraw.schemas import StatusResponse -from zndraw.socket_events import BookmarksInvalidate async def _add_bookmark( @@ -123,7 +121,6 @@ async def test_get_bookmark_returns_404_for_nonexistent( async def test_set_bookmark_creates_bookmark( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """Test PUT creates a new bookmark.""" user, token = await create_test_user_in_db(session) @@ -192,7 +189,6 @@ async def test_set_bookmark_rejects_empty_label( async def test_delete_bookmark_removes_bookmark( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """Test DELETE removes a bookmark.""" user, token = await create_test_user_in_db(session) diff --git a/tests/zndraw/test_routes_figures.py b/tests/zndraw/test_routes_figures.py index 61847decc..6165c80b6 100644 --- a/tests/zndraw/test_routes_figures.py +++ b/tests/zndraw/test_routes_figures.py @@ -13,7 +13,6 @@ from zndraw.exceptions import FigureNotFound from zndraw.models import RoomFigure from zndraw.schemas import FigureData, StatusResponse -from zndraw.socket_events import FigureInvalidate async def _add_figure( @@ -122,7 +121,6 @@ async def test_get_figure_returns_404_for_nonexistent( async def test_create_figure_stores_data( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """Test POST creates a new figure.""" user, token = await create_test_user_in_db(session) @@ -206,7 +204,6 @@ async def test_create_figure_broadcasts( async def test_delete_figure_removes_data( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """Test DELETE removes a figure.""" user, token = await create_test_user_in_db(session) diff --git a/tests/zndraw/test_routes_presets.py b/tests/zndraw/test_routes_presets.py index 0e4563009..6ee372ee2 100644 --- a/tests/zndraw/test_routes_presets.py +++ b/tests/zndraw/test_routes_presets.py @@ -143,7 +143,7 @@ async def test_list_presets_room_not_found( session: AsyncSession, ) -> None: """GET /presets returns 404 for nonexistent room.""" - _user, token = await create_test_user_in_db(session) + _, token = await create_test_user_in_db(session) response = await client.get( "/v1/rooms/nonexistent/presets", @@ -507,7 +507,6 @@ async def test_apply_bundled_preset( async def test_apply_preset_updates_matching_geometries( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """POST /presets/{name}/apply deep-merges config into matching geometries.""" user, token = await create_test_user_in_db(session) diff --git a/tests/zndraw/test_routes_selection_groups.py b/tests/zndraw/test_routes_selection_groups.py index 7bf037b9a..de960ca68 100644 --- a/tests/zndraw/test_routes_selection_groups.py +++ b/tests/zndraw/test_routes_selection_groups.py @@ -15,7 +15,6 @@ from zndraw.exceptions import SelectionGroupNotFound from zndraw.models import SelectionGroup from zndraw.schemas import StatusResponse -from zndraw.socket_events import SelectionGroupsInvalidate async def _add_selection_group( @@ -133,7 +132,6 @@ async def test_get_selection_group_returns_stored( async def test_update_selection_group_stores_data( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """Test PUT group stores selection data.""" user, token = await create_test_user_in_db(session) @@ -200,7 +198,6 @@ async def test_delete_nonexistent_selection_group_returns_404( async def test_delete_selection_group( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, ) -> None: """Test DELETE removes group.""" user, token = await create_test_user_in_db(session) From 8f96047ea6465253ea7e74ca93ad45651a2f040f Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:15:44 +0100 Subject: [PATCH 06/20] refactor: convert chat/step/default_camera/screenshots to shared fixtures Replace per-file session/client/mock_sio fixtures with shared conftest fixtures. Convert MagicMock SIO assertions to MockSioServer.emitted pattern. Replace mock Redis in screenshots tests with real Redis pre-population via RedisKey helpers. Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/test_chat.py | 240 ++++++------------- tests/zndraw/test_default_camera.py | 125 ++-------- tests/zndraw/test_routes_step.py | 278 +++++++-------------- tests/zndraw/test_screenshots.py | 360 +++++++++++++--------------- 4 files changed, 337 insertions(+), 666 deletions(-) diff --git a/tests/zndraw/test_chat.py b/tests/zndraw/test_chat.py index 848227517..433360ed7 100644 --- a/tests/zndraw/test_chat.py +++ b/tests/zndraw/test_chat.py @@ -1,111 +1,25 @@ """Tests for Chat REST API endpoints.""" -from collections.abc import AsyncIterator from datetime import UTC, datetime from typing import Any -from unittest.mock import AsyncMock import pytest -import pytest_asyncio -from helpers import MockSioServer, create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.config import Settings -from zndraw.models import MemberRole, Message, Room, RoomMembership -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings +from helpers import ( + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, +) +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from zndraw.models import Message, Room # ============================================================================= -# Fixtures +# Helpers unique to this test file # ============================================================================= -@pytest_asyncio.fixture(name="chat_session") -async def chat_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - factory = async_sessionmaker( - bind=engine, class_=AsyncSession, expire_on_commit=False - ) - async with factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MockSioServer: - return MockSioServer() - - -@pytest_asyncio.fixture(name="chat_client") -async def chat_client_fixture( - chat_session: AsyncSession, mock_sio: MockSioServer -) -> AsyncIterator[AsyncClient]: - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield chat_session - - def get_sio_override() -> MockSioServer: - return mock_sio - - # Mock Redis for WritableRoomDep (returns None = no edit lock) - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room(session: AsyncSession, user: User) -> Room: - room = Room(description="Test Room", created_by_id=user.id, is_public=True) # type: ignore[arg-type] - session.add(room) - await session.commit() - await session.refresh(room) - membership = RoomMembership( - room_id=room.id, - user_id=user.id, - role=MemberRole.OWNER, # type: ignore[arg-type] - ) - session.add(membership) - await session.commit() - return room - - async def _add_message( session: AsyncSession, room_id: str, @@ -122,10 +36,6 @@ async def _add_message( return msg -def _auth(token: str) -> dict[str, str]: - return {"Authorization": f"Bearer {token}"} - - # ============================================================================= # POST (create message) # ============================================================================= @@ -133,16 +43,16 @@ def _auth(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_create_message( - chat_client: AsyncClient, chat_session: AsyncSession, mock_sio: MockSioServer + client: AsyncClient, session: AsyncSession, mock_sio: MockSioServer ) -> None: """POST creates message, returns correct fields, and broadcasts.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await chat_client.post( + response = await client.post( f"/v1/rooms/{room.id}/chat/messages", json={"content": "Hello!"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 201 data = response.json() @@ -160,29 +70,29 @@ async def test_create_message( @pytest.mark.asyncio async def test_create_message_empty_content( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """POST with empty content returns 422.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await chat_client.post( + response = await client.post( f"/v1/rooms/{room.id}/chat/messages", json={"content": ""}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 422 @pytest.mark.asyncio async def test_create_message_requires_auth( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """POST without auth returns 401.""" - user, _ = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await chat_client.post( + response = await client.post( f"/v1/rooms/{room.id}/chat/messages", json={"content": "Hello!"}, ) @@ -196,14 +106,14 @@ async def test_create_message_requires_auth( @pytest.mark.asyncio async def test_list_messages_empty( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET returns empty list for room with no messages.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await chat_client.get( - f"/v1/rooms/{room.id}/chat/messages", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/chat/messages", headers=auth_header(token) ) assert response.status_code == 200 data = response.json() @@ -214,36 +124,36 @@ async def test_list_messages_empty( @pytest.mark.asyncio async def test_list_messages_returns_newest_first( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET returns messages ordered by created_at descending.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_message( - chat_session, + session, room.id, user.id, "First", datetime(2024, 1, 1, tzinfo=UTC), ) await _add_message( - chat_session, + session, room.id, user.id, "Second", datetime(2024, 1, 2, tzinfo=UTC), ) await _add_message( - chat_session, + session, room.id, user.id, "Third", datetime(2024, 1, 3, tzinfo=UTC), ) - response = await chat_client.get( - f"/v1/rooms/{room.id}/chat/messages", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/chat/messages", headers=auth_header(token) ) assert response.status_code == 200 data = response.json() @@ -255,19 +165,19 @@ async def test_list_messages_returns_newest_first( @pytest.mark.asyncio async def test_list_messages_pagination( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """before cursor returns older messages, has_more is correct.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) timestamps = [datetime(2024, 1, i, tzinfo=UTC) for i in range(1, 6)] for i, ts in enumerate(timestamps): - await _add_message(chat_session, room.id, user.id, f"msg-{i}", ts) + await _add_message(session, room.id, user.id, f"msg-{i}", ts) # Fetch first page (limit=2) - response = await chat_client.get( - f"/v1/rooms/{room.id}/chat/messages?limit=2", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/chat/messages?limit=2", headers=auth_header(token) ) assert response.status_code == 200 page1 = response.json() @@ -277,9 +187,9 @@ async def test_list_messages_pagination( # Fetch second page using oldest_timestamp cursor cursor = page1["metadata"]["oldest_timestamp"] - response = await chat_client.get( + response = await client.get( f"/v1/rooms/{room.id}/chat/messages?limit=2&before={cursor}", - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 page2 = response.json() @@ -289,20 +199,20 @@ async def test_list_messages_pagination( @pytest.mark.asyncio async def test_list_messages_includes_email( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET populates the email field from the User table.""" - user, token = await _create_user(chat_session, email="alice@test.com") - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session, email="alice@test.com") + room = await create_test_room(session, user) - await chat_client.post( + await client.post( f"/v1/rooms/{room.id}/chat/messages", json={"content": "Hi"}, - headers=_auth(token), + headers=auth_header(token), ) - response = await chat_client.get( - f"/v1/rooms/{room.id}/chat/messages", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/chat/messages", headers=auth_header(token) ) assert response.status_code == 200 assert response.json()["items"][0]["email"] == "alice@test.com" @@ -315,18 +225,18 @@ async def test_list_messages_includes_email( @pytest.mark.asyncio async def test_edit_message( - chat_client: AsyncClient, chat_session: AsyncSession, mock_sio: MockSioServer + client: AsyncClient, session: AsyncSession, mock_sio: MockSioServer ) -> None: """PATCH updates content and sets updated_at.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - msg = await _add_message(chat_session, room.id, user.id, "Original") + msg = await _add_message(session, room.id, user.id, "Original") - response = await chat_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/chat/messages/{msg.id}", json={"content": "Edited"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -340,19 +250,19 @@ async def test_edit_message( @pytest.mark.asyncio async def test_edit_message_ownership( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Cannot edit another user's message (403).""" - user1, _ = await _create_user(chat_session, email="user1@test.com") - _, token2 = await _create_user(chat_session, email="user2@test.com") - room = await _create_room(chat_session, user1) + user1, _ = await create_test_user_in_db(session, email="user1@test.com") + _, token2 = await create_test_user_in_db(session, email="user2@test.com") + room = await create_test_room(session, user1) - msg = await _add_message(chat_session, room.id, user1.id, "User1's msg") + msg = await _add_message(session, room.id, user1.id, "User1's msg") - response = await chat_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/chat/messages/{msg.id}", json={"content": "Hacked"}, - headers=_auth(token2), + headers=auth_header(token2), ) assert response.status_code == 403 assert "not-message-owner" in response.json()["type"] @@ -360,16 +270,16 @@ async def test_edit_message_ownership( @pytest.mark.asyncio async def test_edit_message_not_found( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """404 for nonexistent message.""" - user, token = await _create_user(chat_session) - room = await _create_room(chat_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await chat_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/chat/messages/99999", json={"content": "Edit"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "message-not-found" in response.json()["type"] @@ -382,13 +292,13 @@ async def test_edit_message_not_found( @pytest.mark.asyncio async def test_list_messages_room_not_found( - chat_client: AsyncClient, chat_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET for non-existent room returns 404.""" - _, token = await _create_user(chat_session) + _, token = await create_test_user_in_db(session) - response = await chat_client.get( - "/v1/rooms/nonexistent/chat/messages", headers=_auth(token) + response = await client.get( + "/v1/rooms/nonexistent/chat/messages", headers=auth_header(token) ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] diff --git a/tests/zndraw/test_default_camera.py b/tests/zndraw/test_default_camera.py index 48c5d648a..4391a1d09 100644 --- a/tests/zndraw/test_default_camera.py +++ b/tests/zndraw/test_default_camera.py @@ -1,113 +1,22 @@ """Tests for default camera GET/PUT endpoints, delete cleanup, and Python client.""" import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from helpers import auth_header, create_test_room, create_test_user_in_db +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession from zndraw.client import ZnDraw -from zndraw.config import Settings from zndraw.geometries import Sphere from zndraw.geometries.camera import Camera -from zndraw.models import MemberRole, Room, RoomGeometry, RoomMembership -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings +from zndraw.models import RoomGeometry # ============================================================================= -# Fixtures (same pattern as test_routes_geometries.py) +# Helpers unique to this test file # ============================================================================= -@pytest_asyncio.fixture(name="session") -async def session_fixture() -> AsyncIterator[AsyncSession]: - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - factory = async_sessionmaker( - bind=engine, class_=AsyncSession, expire_on_commit=False - ) - async with factory() as s: - yield s - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="client") -async def client_fixture( - session: AsyncSession, mock_sio: MagicMock -) -> AsyncIterator[AsyncClient]: - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield session - - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - mock_redis.hgetall = AsyncMock(return_value={}) - mock_redis.hget = AsyncMock(return_value=None) - mock_redis.hdel = AsyncMock(return_value=0) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = lambda: mock_sio - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" - ) as c: - yield c - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - return user, create_test_token(user) - - -async def _create_room(session: AsyncSession, user: User) -> Room: - room = Room(created_by_id=user.id, is_public=True) # type: ignore[arg-type] - session.add(room) - await session.commit() - await session.refresh(room) - membership = RoomMembership( - room_id=room.id, - user_id=user.id, - role=MemberRole.OWNER, # type: ignore[arg-type] - ) - session.add(membership) - await session.commit() - return room - - async def _create_geometry( session: AsyncSession, room_id: str, key: str, geo_type: str = "Camera" ) -> RoomGeometry: @@ -128,8 +37,8 @@ async def test_get_default_camera_none( session: AsyncSession, client: AsyncClient ) -> None: """New room returns null default_camera.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) resp = await client.get( f"/v1/rooms/{room.id}/default-camera", @@ -142,8 +51,8 @@ async def test_get_default_camera_none( @pytest.mark.asyncio async def test_set_default_camera(session: AsyncSession, client: AsyncClient) -> None: """PUT with valid Camera key, then GET returns it.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _create_geometry(session, room.id, "template-cam", "Camera") resp = await client.put( @@ -166,8 +75,8 @@ async def test_set_default_camera_not_found( session: AsyncSession, client: AsyncClient ) -> None: """PUT with nonexistent key returns 404.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) resp = await client.put( f"/v1/rooms/{room.id}/default-camera", @@ -182,8 +91,8 @@ async def test_set_default_camera_wrong_type( session: AsyncSession, client: AsyncClient ) -> None: """PUT with non-Camera geometry returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _create_geometry(session, room.id, "my-sphere", "Sphere") resp = await client.put( @@ -197,8 +106,8 @@ async def test_set_default_camera_wrong_type( @pytest.mark.asyncio async def test_unset_default_camera(session: AsyncSession, client: AsyncClient) -> None: """PUT null after setting unsets the default.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _create_geometry(session, room.id, "template-cam", "Camera") await client.put( @@ -227,8 +136,8 @@ async def test_delete_geometry_clears_default( session: AsyncSession, client: AsyncClient ) -> None: """Deleting the default camera geometry clears the default.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _create_geometry(session, room.id, "template-cam", "Camera") headers = {"Authorization": f"Bearer {token}"} diff --git a/tests/zndraw/test_routes_step.py b/tests/zndraw/test_routes_step.py index 1411656e2..85877e6d1 100644 --- a/tests/zndraw/test_routes_step.py +++ b/tests/zndraw/test_routes_step.py @@ -1,143 +1,18 @@ """Tests for Step REST API endpoints.""" -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock - import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model, make_raw_frame -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.models import MemberRole, Room, RoomMembership +from helpers import ( + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, + make_raw_frame, +) +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + from zndraw.schemas import StepResponse, StepUpdateResponse -from zndraw.socket_events import FrameUpdate from zndraw.storage import FrameStorage -from zndraw_auth import User - -# ============================================================================= -# Test-specific Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="step_session") -async def step_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="step_client") -async def step_client_fixture( - step_session: AsyncSession, - frame_storage: FrameStorage, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with dependencies overridden.""" - from zndraw.app import app - from zndraw.config import Settings - from zndraw.dependencies import get_frame_storage, get_redis, get_tsio - from zndraw_auth import get_session - from zndraw_auth.settings import AuthSettings - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield step_session - - def get_storage_override() -> FrameStorage: - return frame_storage - - def get_sio_override() -> MagicMock: - return mock_sio - - # Mock Redis for WritableRoomDep (returns None = no edit lock) - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_frame_storage] = get_storage_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room", step: int = 0 -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - step=step, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room - - -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - # ============================================================================= # GET Step Tests @@ -146,22 +21,22 @@ def _auth_header(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_get_step_returns_zero_initially( - step_client: AsyncClient, - step_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test GET returns step=0 for new room with no step set.""" - user, token = await _create_user(step_session) - room = await _create_room(step_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add some frames to the room await frame_storage[room.id].extend( [make_raw_frame({"a": 1}), make_raw_frame({"b": 2}), make_raw_frame({"c": 3})] ) - response = await step_client.get( + response = await client.get( f"/v1/rooms/{room.id}/step", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -172,22 +47,40 @@ async def test_get_step_returns_zero_initially( @pytest.mark.asyncio async def test_get_step_returns_current_step( - step_client: AsyncClient, - step_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test GET returns previously set step.""" - user, token = await _create_user(step_session) - room = await _create_room(step_session, user, step=2) + from zndraw.models import MemberRole, Room, RoomMembership + + user, token = await create_test_user_in_db(session) + # Create room with step=2 manually (create_test_room doesn't support step param) + room = Room( + description="Test Room", + created_by_id=user.id, # type: ignore[arg-type] + is_public=True, + step=2, + ) + session.add(room) + await session.commit() + await session.refresh(room) + membership = RoomMembership( + room_id=room.id, # type: ignore[arg-type] + user_id=user.id, # type: ignore[arg-type] + role=MemberRole.OWNER, + ) + session.add(membership) + await session.commit() # Add frames await frame_storage[room.id].extend( [make_raw_frame({"a": 1}), make_raw_frame({"b": 2}), make_raw_frame({"c": 3})] ) - response = await step_client.get( + response = await client.get( f"/v1/rooms/{room.id}/step", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -203,24 +96,24 @@ async def test_get_step_returns_current_step( @pytest.mark.asyncio async def test_set_step_updates_and_returns( - step_client: AsyncClient, - step_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, - mock_sio: MagicMock, + mock_sio: MockSioServer, ) -> None: """Test PUT updates step and returns new value.""" - user, token = await _create_user(step_session) - room = await _create_room(step_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames await frame_storage[room.id].extend( [make_raw_frame({"a": 1}), make_raw_frame({"b": 2}), make_raw_frame({"c": 3})] ) - response = await step_client.put( + response = await client.put( f"/v1/rooms/{room.id}/step", json={"step": 1}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -229,25 +122,24 @@ async def test_set_step_updates_and_returns( assert result.step == 1 # Verify step persisted in DB - await step_session.refresh(room) + await session.refresh(room) assert room.step == 1 # Verify Socket.IO broadcast was sent - mock_sio.emit.assert_called_once() - call_args = mock_sio.emit.call_args - assert isinstance(call_args[0][0], FrameUpdate) - assert call_args[1]["room"] == f"room:{room.id}" + assert len(mock_sio.emitted) == 1 + assert mock_sio.emitted[0]["event"] == "frame_update" + assert mock_sio.emitted[0]["room"] == f"room:{room.id}" @pytest.mark.asyncio async def test_set_step_out_of_bounds_returns_422( - step_client: AsyncClient, - step_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test PUT with step > total_frames returns 422.""" - user, token = await _create_user(step_session) - room = await _create_room(step_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add 3 frames (indices 0, 1, 2) await frame_storage[room.id].extend( @@ -255,10 +147,10 @@ async def test_set_step_out_of_bounds_returns_422( ) # Request step=100 — should return 422 - response = await step_client.put( + response = await client.put( f"/v1/rooms/{room.id}/step", json={"step": 100}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 body = response.json() @@ -268,19 +160,19 @@ async def test_set_step_out_of_bounds_returns_422( @pytest.mark.asyncio async def test_set_step_empty_room_rejects_nonzero( - step_client: AsyncClient, - step_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test PUT to room with no frames rejects non-zero step.""" - user, token = await _create_user(step_session) - room = await _create_room(step_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Room has no frames — step=5 should be rejected - response = await step_client.put( + response = await client.put( f"/v1/rooms/{room.id}/step", json={"step": 5}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 body = response.json() @@ -289,20 +181,20 @@ async def test_set_step_empty_room_rejects_nonzero( @pytest.mark.asyncio async def test_set_step_negative_returns_422( - step_client: AsyncClient, - step_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test PUT with negative step returns 422 (Pydantic ge=0 validation).""" - user, token = await _create_user(step_session) - room = await _create_room(step_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await frame_storage[room.id].extend([make_raw_frame({"a": 1})]) - response = await step_client.put( + response = await client.put( f"/v1/rooms/{room.id}/step", json={"step": -1}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 body = response.json() @@ -317,25 +209,25 @@ async def test_set_step_negative_returns_422( @pytest.mark.asyncio async def test_get_step_requires_auth( - step_client: AsyncClient, step_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET without auth returns 401.""" - user, _ = await _create_user(step_session) - room = await _create_room(step_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await step_client.get(f"/v1/rooms/{room.id}/step") + response = await client.get(f"/v1/rooms/{room.id}/step") assert response.status_code == 401 @pytest.mark.asyncio async def test_set_step_requires_auth( - step_client: AsyncClient, step_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT without auth returns 401.""" - user, _ = await _create_user(step_session) - room = await _create_room(step_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await step_client.put( + response = await client.put( f"/v1/rooms/{room.id}/step", json={"step": 1}, ) @@ -349,14 +241,14 @@ async def test_set_step_requires_auth( @pytest.mark.asyncio async def test_get_step_returns_404_for_nonexistent_room( - step_client: AsyncClient, step_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET for non-existent room returns 404.""" - _, token = await _create_user(step_session) + _, token = await create_test_user_in_db(session) - response = await step_client.get( + response = await client.get( "/v1/rooms/99999/step", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -364,15 +256,15 @@ async def test_get_step_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_set_step_returns_404_for_nonexistent_room( - step_client: AsyncClient, step_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT for non-existent room returns 404.""" - _, token = await _create_user(step_session) + _, token = await create_test_user_in_db(session) - response = await step_client.put( + response = await client.put( "/v1/rooms/99999/step", json={"step": 1}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] diff --git a/tests/zndraw/test_screenshots.py b/tests/zndraw/test_screenshots.py index f67d1a3dc..dd80d205f 100644 --- a/tests/zndraw/test_screenshots.py +++ b/tests/zndraw/test_screenshots.py @@ -1,130 +1,95 @@ """Tests for Screenshot REST API endpoints.""" +import json from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from pathlib import Path -from unittest.mock import AsyncMock import pytest import pytest_asyncio -from helpers import MockSioServer, create_test_token, create_test_user_model +from helpers import ( + InMemoryResultBackend, + MockSioServer, + auth_header, + create_test_room, + create_test_user_in_db, +) from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.config import Settings -from zndraw.models import MemberRole, Room, RoomMembership +from zndraw.models import Room +from zndraw.redis import RedisKey from zndraw.schemas import StatusResponse -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings +from zndraw.storage import FrameStorage # ============================================================================= -# Fixtures +# Per-file client fixture (adds media_path override on top of shared infra) # ============================================================================= -@pytest_asyncio.fixture(name="ss_session") -async def ss_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - factory = async_sessionmaker( - bind=engine, class_=AsyncSession, expire_on_commit=False - ) - async with factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MockSioServer: - return MockSioServer() - - @pytest.fixture(name="media_path") def media_path_fixture(tmp_path: Path) -> Path: """Provide a temporary media directory for screenshots.""" return tmp_path / "media" -@pytest.fixture(name="mock_redis") -def mock_redis_fixture() -> AsyncMock: - """Provide a mock Redis with configurable hexists.""" - mock = AsyncMock() - mock.get = AsyncMock(return_value=None) - mock.hexists = AsyncMock(return_value=False) - return mock - - -@pytest_asyncio.fixture(name="ss_client") -async def ss_client_fixture( - ss_session: AsyncSession, +@pytest_asyncio.fixture(name="client") +async def client_fixture( + session: AsyncSession, + redis_client, mock_sio: MockSioServer, + frame_storage: FrameStorage, + result_backend: InMemoryResultBackend, media_path: Path, - mock_redis: AsyncMock, ) -> AsyncIterator[AsyncClient]: + """Async test client with media_path override for screenshot tests.""" from zndraw.app import app - from zndraw.dependencies import get_media_path, get_redis, get_tsio + from zndraw.config import Settings + from zndraw.dependencies import ( + get_frame_storage, + get_joblib_settings, + get_media_path, + get_redis, + get_result_backend, + get_tsio, + ) from zndraw_auth import get_session + from zndraw_auth.settings import AuthSettings + from zndraw_joblib.settings import JobLibSettings async def get_session_override() -> AsyncIterator[AsyncSession]: - yield ss_session + yield session - def get_sio_override() -> MockSioServer: - return mock_sio + @asynccontextmanager + async def test_session_maker(): + yield session settings = Settings() settings.media_path = media_path # type: ignore[assignment] + app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis + app.dependency_overrides[get_redis] = lambda: redis_client + app.dependency_overrides[get_tsio] = lambda: mock_sio + app.dependency_overrides[get_frame_storage] = lambda: frame_storage + app.dependency_overrides[get_result_backend] = lambda: result_backend + app.dependency_overrides[get_joblib_settings] = lambda: JobLibSettings() app.dependency_overrides[get_media_path] = lambda: media_path + app.state.session_maker = test_session_maker app.state.settings = settings app.state.auth_settings = AuthSettings() async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" + transport=ASGITransport(app=app), + base_url="http://test", ) as client: yield client app.dependency_overrides.clear() -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room(session: AsyncSession, user: User) -> Room: - room = Room(description="Test Room", created_by_id=user.id, is_public=True) # type: ignore[arg-type] - session.add(room) - await session.commit() - await session.refresh(room) - membership = RoomMembership( - room_id=room.id, - user_id=user.id, - role=MemberRole.OWNER, # type: ignore[arg-type] - ) - session.add(membership) - await session.commit() - return room - - -def _auth(token: str) -> dict[str, str]: - return {"Authorization": f"Bearer {token}"} +# ============================================================================= +# Helpers unique to this test file +# ============================================================================= def _png_bytes(size: int = 100) -> bytes: @@ -132,6 +97,14 @@ def _png_bytes(size: int = 100) -> bytes: return b"\x89PNG\r\n\x1a\n" + b"\x00" * size +def _make_camera_entry(sid: str, owner_id: str, email: str = "u@test") -> str: + """Build a JSON camera entry as stored in room_cameras hash.""" + from zndraw.geometries.camera import Camera + + camera = Camera(owner=owner_id) + return json.dumps({"sid": sid, "email": email, "data": camera.model_dump()}) + + # ============================================================================= # POST /upload — Upload screenshot # ============================================================================= @@ -139,20 +112,20 @@ def _png_bytes(size: int = 100) -> bytes: @pytest.mark.asyncio async def test_upload_screenshot( - ss_client: AsyncClient, - ss_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, media_path: Path, ) -> None: """Upload a PNG screenshot, verify 201 and file on disk.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) data = _png_bytes() - response = await ss_client.post( + response = await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": ("shot.png", data, "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 201 body = response.json() @@ -169,17 +142,17 @@ async def test_upload_screenshot( @pytest.mark.asyncio async def test_upload_invalid_format( - ss_client: AsyncClient, ss_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Upload with unsupported format returns 422.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await ss_client.post( + response = await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": ("shot.bmp", b"data", "image/bmp")}, data={"format": "bmp"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 422 assert "invalid-screenshot-format" in response.json()["type"] @@ -187,18 +160,18 @@ async def test_upload_invalid_format( @pytest.mark.asyncio async def test_upload_too_large( - ss_client: AsyncClient, ss_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Upload exceeding 10 MB returns 413.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) big_data = b"\x00" * (10 * 1024 * 1024 + 1) - response = await ss_client.post( + response = await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": ("shot.png", big_data, "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 413 assert "screenshot-too-large" in response.json()["type"] @@ -211,23 +184,23 @@ async def test_upload_too_large( @pytest.mark.asyncio async def test_list_screenshots( - ss_client: AsyncClient, ss_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Upload 3 screenshots, verify list with pagination.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) for i in range(3): - await ss_client.post( + await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": (f"shot{i}.png", _png_bytes(50 + i), "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) - response = await ss_client.get( + response = await client.get( f"/v1/rooms/{room.id}/screenshots?limit=2&offset=0", - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 body = response.json() @@ -237,9 +210,9 @@ async def test_list_screenshots( assert body["offset"] == 0 # Second page - response2 = await ss_client.get( + response2 = await client.get( f"/v1/rooms/{room.id}/screenshots?limit=2&offset=2", - headers=_auth(token), + headers=auth_header(token), ) assert response2.status_code == 200 body2 = response2.json() @@ -252,23 +225,23 @@ async def test_list_screenshots( @pytest.mark.asyncio -async def test_get_screenshot(ss_client: AsyncClient, ss_session: AsyncSession) -> None: +async def test_get_screenshot(client: AsyncClient, session: AsyncSession) -> None: """Upload then GET, verify data field is base64-encoded.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) data = _png_bytes() - upload = await ss_client.post( + upload = await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": ("shot.png", data, "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) screenshot_id = upload.json()["id"] - response = await ss_client.get( + response = await client.get( f"/v1/rooms/{room.id}/screenshots/{screenshot_id}", - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 body = response.json() @@ -282,15 +255,15 @@ async def test_get_screenshot(ss_client: AsyncClient, ss_session: AsyncSession) @pytest.mark.asyncio async def test_get_screenshot_not_found( - ss_client: AsyncClient, ss_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET nonexistent screenshot returns 404.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await ss_client.get( + response = await client.get( f"/v1/rooms/{room.id}/screenshots/99999", - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "screenshot-not-found" in response.json()["type"] @@ -303,27 +276,27 @@ async def test_get_screenshot_not_found( @pytest.mark.asyncio async def test_delete_screenshot( - ss_client: AsyncClient, - ss_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, media_path: Path, ) -> None: """Upload, DELETE, verify 204 and file removed.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - upload = await ss_client.post( + upload = await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": ("shot.png", _png_bytes(), "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) screenshot_id = upload.json()["id"] file_path = media_path / room.id / "screenshots" / f"{screenshot_id}.png" assert file_path.exists() - response = await ss_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/screenshots/{screenshot_id}", - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) @@ -335,43 +308,32 @@ async def test_delete_screenshot( # ============================================================================= -def _make_camera_entry(sid: str, owner_id: str, email: str = "u@test") -> str: - """Build a JSON camera entry as stored in room_cameras hash.""" - import json - - from zndraw.geometries.camera import Camera - - camera = Camera(owner=owner_id) - return json.dumps({"sid": sid, "email": email, "data": camera.model_dump()}) - - @pytest.mark.asyncio async def test_request_capture( - ss_client: AsyncClient, - ss_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client, ) -> None: """JSON POST with live session owned by requesting user returns 202.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) target_sid = "frontend-sid-123" - - # Session is active AND owned by requesting user - mock_redis.hexists = AsyncMock(return_value=True) - mock_redis.hgetall = AsyncMock( - return_value={ - f"cam:{user.email}:abcd1234": _make_camera_entry( - target_sid, str(user.id), user.email - ), - } + cam_key = f"cam:{user.email}:abcd1234" + + # Pre-populate Redis: mark session as active and add camera entry + await redis_client.hset(RedisKey.active_cameras(room.id), target_sid, cam_key) + await redis_client.hset( + RedisKey.room_cameras(room.id), + cam_key, + _make_camera_entry(target_sid, str(user.id), user.email), ) - response = await ss_client.post( + response = await client.post( f"/v1/rooms/{room.id}/screenshots", json={"session_id": target_sid}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 202 body = response.json() @@ -389,32 +351,31 @@ async def test_request_capture( @pytest.mark.asyncio async def test_request_capture_rejects_other_users_session( - ss_client: AsyncClient, - ss_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client, ) -> None: """Requesting a screenshot from another user's session returns 409.""" - user_a, token_a = await _create_user(ss_session, email="alice@test") - room = await _create_room(ss_session, user_a) - user_b, _token_b = await _create_user(ss_session, email="bob@test") + user_a, token_a = await create_test_user_in_db(session, email="alice@test") + room = await create_test_room(session, user_a) + user_b, _token_b = await create_test_user_in_db(session, email="bob@test") target_sid = "bobs-browser-sid" + cam_key = "cam:bob@test:abcd1234" # Session exists but is owned by user_b, not user_a - mock_redis.hexists = AsyncMock(return_value=True) - mock_redis.hgetall = AsyncMock( - return_value={ - "cam:bob@test:abcd1234": _make_camera_entry( - target_sid, str(user_b.id), user_b.email - ), - } + await redis_client.hset(RedisKey.active_cameras(room.id), target_sid, cam_key) + await redis_client.hset( + RedisKey.room_cameras(room.id), + cam_key, + _make_camera_entry(target_sid, str(user_b.id), user_b.email), ) - response = await ss_client.post( + response = await client.post( f"/v1/rooms/{room.id}/screenshots", json={"session_id": target_sid}, - headers=_auth(token_a), + headers=auth_header(token_a), ) assert response.status_code == 409 assert "no-frontend-session" in response.json()["type"] @@ -425,19 +386,18 @@ async def test_request_capture_rejects_other_users_session( @pytest.mark.asyncio async def test_request_capture_invalid_session( - ss_client: AsyncClient, - ss_session: AsyncSession, - mock_redis: AsyncMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Request capture with non-active session returns 409.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - # Session not active (default mock_redis.hexists returns False) - response = await ss_client.post( + # Session not active (Redis has no entry for this sid) + response = await client.post( f"/v1/rooms/{room.id}/screenshots", json={"session_id": "nonexistent-sid"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 409 assert "no-frontend-session" in response.json()["type"] @@ -450,41 +410,41 @@ async def test_request_capture_invalid_session( @pytest.mark.asyncio async def test_patch_pending_screenshot( - ss_client: AsyncClient, - ss_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client, media_path: Path, ) -> None: """Create pending via capture request, then PATCH with file.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) target_sid = "frontend-sid" - mock_redis.hexists = AsyncMock(return_value=True) - mock_redis.hgetall = AsyncMock( - return_value={ - f"cam:{user.email}:abcd1234": _make_camera_entry( - target_sid, str(user.id), user.email - ), - } + cam_key = f"cam:{user.email}:abcd1234" + + await redis_client.hset(RedisKey.active_cameras(room.id), target_sid, cam_key) + await redis_client.hset( + RedisKey.room_cameras(room.id), + cam_key, + _make_camera_entry(target_sid, str(user.id), user.email), ) # Create pending screenshot - capture_resp = await ss_client.post( + capture_resp = await client.post( f"/v1/rooms/{room.id}/screenshots", json={"session_id": target_sid}, - headers=_auth(token), + headers=auth_header(token), ) screenshot_id = capture_resp.json()["id"] # PATCH to complete data = _png_bytes(200) - response = await ss_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/screenshots/{screenshot_id}", files={"file": ("shot.png", data, "image/png")}, data={"format": "png", "width": "1920", "height": "1080"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 body = response.json() @@ -500,27 +460,27 @@ async def test_patch_pending_screenshot( @pytest.mark.asyncio async def test_patch_completed_screenshot( - ss_client: AsyncClient, ss_session: AsyncSession, media_path: Path + client: AsyncClient, session: AsyncSession, media_path: Path ) -> None: """PATCH on already-completed screenshot returns 409.""" - user, token = await _create_user(ss_session) - room = await _create_room(ss_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Upload a completed screenshot - upload = await ss_client.post( + upload = await client.post( f"/v1/rooms/{room.id}/screenshots/upload", files={"file": ("shot.png", _png_bytes(), "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) screenshot_id = upload.json()["id"] # Try to PATCH (should fail — already completed) - response = await ss_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/screenshots/{screenshot_id}", files={"file": ("shot2.png", _png_bytes(), "image/png")}, data={"format": "png"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 409 assert "screenshot-not-pending" in response.json()["type"] From 7f15f1dca9375945e33746114165324385f0749e Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:17:34 +0100 Subject: [PATCH 07/20] refactor: convert chat/step/default_camera/screenshots to shared fixtures - Delete per-file session/client/redis/sio fixtures and local helpers - Use shared client, session, mock_sio from conftest.py - Convert SIO assertions to MockSioServer.emitted pattern - Remove unused imports and mock_sio/media_path params from non-asserting tests - screenshots retains per-file client fixture for MediaPathDep override Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_chat.py | 2 +- tests/zndraw/test_default_camera.py | 2 +- tests/zndraw/test_routes_step.py | 1 - tests/zndraw/test_screenshots.py | 6 ++---- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/zndraw/test_chat.py b/tests/zndraw/test_chat.py index 433360ed7..f2a0091d7 100644 --- a/tests/zndraw/test_chat.py +++ b/tests/zndraw/test_chat.py @@ -13,7 +13,7 @@ from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.models import Message, Room +from zndraw.models import Message # ============================================================================= # Helpers unique to this test file diff --git a/tests/zndraw/test_default_camera.py b/tests/zndraw/test_default_camera.py index 4391a1d09..4df6638b1 100644 --- a/tests/zndraw/test_default_camera.py +++ b/tests/zndraw/test_default_camera.py @@ -3,7 +3,7 @@ import json import pytest -from helpers import auth_header, create_test_room, create_test_user_in_db +from helpers import create_test_room, create_test_user_in_db from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession diff --git a/tests/zndraw/test_routes_step.py b/tests/zndraw/test_routes_step.py index 85877e6d1..13c4f8b18 100644 --- a/tests/zndraw/test_routes_step.py +++ b/tests/zndraw/test_routes_step.py @@ -162,7 +162,6 @@ async def test_set_step_out_of_bounds_returns_422( async def test_set_step_empty_room_rejects_nonzero( client: AsyncClient, session: AsyncSession, - frame_storage: FrameStorage, ) -> None: """Test PUT to room with no frames rejects non-zero step.""" user, token = await create_test_user_in_db(session) diff --git a/tests/zndraw/test_screenshots.py b/tests/zndraw/test_screenshots.py index dd80d205f..e3a0e40a5 100644 --- a/tests/zndraw/test_screenshots.py +++ b/tests/zndraw/test_screenshots.py @@ -17,7 +17,6 @@ from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.models import Room from zndraw.redis import RedisKey from zndraw.schemas import StatusResponse from zndraw.storage import FrameStorage @@ -359,7 +358,7 @@ async def test_request_capture_rejects_other_users_session( """Requesting a screenshot from another user's session returns 409.""" user_a, token_a = await create_test_user_in_db(session, email="alice@test") room = await create_test_room(session, user_a) - user_b, _token_b = await create_test_user_in_db(session, email="bob@test") + user_b, _ = await create_test_user_in_db(session, email="bob@test") target_sid = "bobs-browser-sid" cam_key = "cam:bob@test:abcd1234" @@ -412,7 +411,6 @@ async def test_request_capture_invalid_session( async def test_patch_pending_screenshot( client: AsyncClient, session: AsyncSession, - mock_sio: MockSioServer, redis_client, media_path: Path, ) -> None: @@ -460,7 +458,7 @@ async def test_patch_pending_screenshot( @pytest.mark.asyncio async def test_patch_completed_screenshot( - client: AsyncClient, session: AsyncSession, media_path: Path + client: AsyncClient, session: AsyncSession ) -> None: """PATCH on already-completed screenshot returns 409.""" user, token = await create_test_user_in_db(session) From ce2090dd0da5995ecfb9f63abc1d6a498e1ec118 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:26:33 +0100 Subject: [PATCH 08/20] refactor: convert edit_lock/frame_selection/frames/progress to shared fixtures Remove per-file session/client/redis/sio fixtures and replace with shared conftest fixtures (session, client, redis_client, mock_sio, frame_storage). Also adds is_superuser param to create_test_user_in_db helper and converts mock_redis AsyncMock usage in progress tests to real Redis client. Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/helpers.py | 7 +- tests/zndraw/test_progress.py | 218 ++------ tests/zndraw/test_routes_edit_lock.py | 537 ++++++++------------ tests/zndraw/test_routes_frame_selection.py | 268 +++------- tests/zndraw/test_routes_frames.py | 370 +++++--------- 5 files changed, 465 insertions(+), 935 deletions(-) diff --git a/tests/zndraw/helpers.py b/tests/zndraw/helpers.py index ae76b107f..0448e7c2b 100644 --- a/tests/zndraw/helpers.py +++ b/tests/zndraw/helpers.py @@ -69,10 +69,13 @@ def decode_msgpack_response(content: bytes) -> list[dict[bytes, bytes]]: async def create_test_user_in_db( - session: AsyncSession, email: str = "testuser@local.test" + session: AsyncSession, + email: str = "testuser@local.test", + *, + is_superuser: bool = False, ) -> tuple[User, str]: """Create a user in the DB and return (user, token).""" - user = create_test_user_model(email=email) + user = create_test_user_model(email=email, is_superuser=is_superuser) session.add(user) await session.commit() await session.refresh(user) diff --git a/tests/zndraw/test_progress.py b/tests/zndraw/test_progress.py index 353520c9e..439a7c55e 100644 --- a/tests/zndraw/test_progress.py +++ b/tests/zndraw/test_progress.py @@ -1,122 +1,19 @@ """Tests for Progress REST API endpoints.""" import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock import pytest -import pytest_asyncio from helpers import ( MockSioServer, auth_header, create_test_room, create_test_user_in_db, ) -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from httpx import AsyncClient +from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.config import Settings from zndraw.redis import RedisKey -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="progress_session") -async def progress_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - factory = async_sessionmaker( - bind=engine, class_=AsyncSession, expire_on_commit=False - ) - async with factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MockSioServer: - return MockSioServer() - - -@pytest_asyncio.fixture(name="mock_redis") -async def mock_redis_fixture() -> AsyncMock: - redis = AsyncMock() - redis.get = AsyncMock(return_value=None) # no edit lock - redis._progress_store: dict[str, dict[str, str]] = {} - - async def hset(key: str, field: str, value: str) -> int: - if key not in redis._progress_store: - redis._progress_store[key] = {} - redis._progress_store[key][field] = value - return 1 - - async def hget(key: str, field: str) -> str | None: - return redis._progress_store.get(key, {}).get(field) - - async def hdel(key: str, *fields: str) -> int: - deleted = 0 - if key in redis._progress_store: - for f in fields: - if f in redis._progress_store[key]: - del redis._progress_store[key][f] - deleted += 1 - return deleted - - async def hgetall(key: str) -> dict[str, str]: - return redis._progress_store.get(key, {}) - - redis.hset = AsyncMock(side_effect=hset) - redis.hget = AsyncMock(side_effect=hget) - redis.hdel = AsyncMock(side_effect=hdel) - redis.hgetall = AsyncMock(side_effect=hgetall) - redis.expire = AsyncMock(return_value=True) - return redis - - -@pytest_asyncio.fixture(name="progress_client") -async def progress_client_fixture( - progress_session: AsyncSession, - mock_sio: MockSioServer, - mock_redis: AsyncMock, -) -> AsyncIterator[AsyncClient]: - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield progress_session - - def get_sio_override() -> MockSioServer: - return mock_sio - - def get_redis_override() -> AsyncMock: - return mock_redis - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = get_redis_override - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" - ) as client: - yield client - - app.dependency_overrides.clear() # ============================================================================= @@ -126,16 +23,16 @@ def get_redis_override() -> AsyncMock: @pytest.mark.asyncio async def test_create_progress( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client: Redis, ) -> None: """POST creates tracker, returns 201, emits progress_start, stores in Redis.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await progress_client.post( + response = await client.post( f"/v1/rooms/{room.id}/progress", json={"progress_id": "task-1", "description": "Loading data"}, headers=auth_header(token), @@ -155,27 +52,24 @@ async def test_create_progress( # Verify stored in Redis redis_key = RedisKey.room_progress(room.id) - stored = await mock_redis.hget(redis_key, "task-1") + stored = await redis_client.hget(redis_key, "task-1") assert stored is not None stored_data = json.loads(stored) assert stored_data["progress_id"] == "task-1" assert stored_data["description"] == "Loading data" - # Verify TTL was set - mock_redis.expire.assert_called() - @pytest.mark.asyncio async def test_create_progress_with_unit( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, ) -> None: """POST with custom unit returns it in the response.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await progress_client.post( + response = await client.post( f"/v1/rooms/{room.id}/progress", json={ "progress_id": "task-1", @@ -193,14 +87,14 @@ async def test_create_progress_with_unit( @pytest.mark.asyncio async def test_create_progress_requires_auth( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST without token returns 401.""" - user, _ = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await progress_client.post( + response = await client.post( f"/v1/rooms/{room.id}/progress", json={"progress_id": "task-1", "description": "Loading data"}, ) @@ -214,17 +108,17 @@ async def test_create_progress_requires_auth( @pytest.mark.asyncio async def test_update_progress( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client: Redis, ) -> None: """PATCH updates tqdm fields, returns 200, emits progress_update.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Create a tracker first - await progress_client.post( + await client.post( f"/v1/rooms/{room.id}/progress", json={"progress_id": "task-1", "description": "Loading data"}, headers=auth_header(token), @@ -232,7 +126,7 @@ async def test_update_progress( mock_sio.emitted.clear() # Update with tqdm-like fields - response = await progress_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/progress/task-1", json={"n": 42, "total": 100, "elapsed": 5.3, "unit": "frames"}, headers=auth_header(token), @@ -253,14 +147,14 @@ async def test_update_progress( @pytest.mark.asyncio async def test_update_progress_not_found( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """PATCH for non-existent tracker returns 404 with progress-not-found type.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await progress_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/progress/nonexistent", json={"n": 10}, headers=auth_header(token), @@ -271,17 +165,17 @@ async def test_update_progress_not_found( @pytest.mark.asyncio async def test_update_progress_description( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client: Redis, ) -> None: """PATCH can update description alongside tqdm fields.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Create a tracker first - await progress_client.post( + await client.post( f"/v1/rooms/{room.id}/progress", json={"progress_id": "task-1", "description": "Loading data"}, headers=auth_header(token), @@ -289,7 +183,7 @@ async def test_update_progress_description( mock_sio.emitted.clear() # Update both description and tqdm fields - response = await progress_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/progress/task-1", json={ "description": "Processing step 2", @@ -318,17 +212,17 @@ async def test_update_progress_description( @pytest.mark.asyncio async def test_delete_progress( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, mock_sio: MockSioServer, - mock_redis: AsyncMock, + redis_client: Redis, ) -> None: """DELETE removes tracker, returns 204, emits progress_complete.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Create a tracker first - await progress_client.post( + await client.post( f"/v1/rooms/{room.id}/progress", json={"progress_id": "task-1", "description": "Loading data"}, headers=auth_header(token), @@ -336,7 +230,7 @@ async def test_delete_progress( mock_sio.emitted.clear() # Delete it - response = await progress_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/progress/task-1", headers=auth_header(token), ) @@ -348,20 +242,20 @@ async def test_delete_progress( # Verify removed from Redis redis_key = RedisKey.room_progress(room.id) - stored = await mock_redis.hget(redis_key, "task-1") + stored = await redis_client.hget(redis_key, "task-1") assert stored is None @pytest.mark.asyncio async def test_delete_progress_not_found( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """DELETE for non-existent tracker returns 404.""" - user, token = await create_test_user_in_db(progress_session) - room = await create_test_room(progress_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await progress_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/progress/nonexistent", headers=auth_header(token), ) @@ -376,13 +270,13 @@ async def test_delete_progress_not_found( @pytest.mark.asyncio async def test_progress_room_not_found( - progress_client: AsyncClient, - progress_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """POST to non-existent room returns 404 with room-not-found type.""" - _, token = await create_test_user_in_db(progress_session) + _, token = await create_test_user_in_db(session) - response = await progress_client.post( + response = await client.post( "/v1/rooms/nonexistent/progress", json={"progress_id": "task-1", "description": "Loading data"}, headers=auth_header(token), diff --git a/tests/zndraw/test_routes_edit_lock.py b/tests/zndraw/test_routes_edit_lock.py index 997ca995f..c960d2955 100644 --- a/tests/zndraw/test_routes_edit_lock.py +++ b/tests/zndraw/test_routes_edit_lock.py @@ -2,154 +2,17 @@ import asyncio import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient +from helpers import auth_header, create_test_room, create_test_user_in_db +from httpx import AsyncClient from redis.asyncio import Redis -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.config import Settings from zndraw.models import MemberRole, Room, RoomMembership from zndraw.redis import RedisKey from zndraw.schemas import StatusResponse from zndraw.socket_events import LockUpdate -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Test Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="el_session") -async def el_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="el_redis") -async def el_redis_fixture() -> AsyncIterator[Redis]: - """Create a Redis client for edit lock testing.""" - redis: Redis = Redis.from_url("redis://localhost", decode_responses=True) - await redis.flushdb() - yield redis - await redis.flushdb() - await redis.aclose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="el_client") -async def el_client_fixture( - el_session: AsyncSession, - el_redis: Redis, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with session, Redis, and sio overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield el_session - - def get_redis_override() -> Redis: - return el_redis - - def get_sio_override() -> MagicMock: - return mock_sio - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_redis] = get_redis_override - app.dependency_overrides[get_tsio] = get_sio_override - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -# ============================================================================= -# Helpers -# ============================================================================= - - -async def _create_user( - session: AsyncSession, - email: str = "testuser@local.test", - is_superuser: bool = False, -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email, is_superuser=is_superuser) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - return room - - -def _auth(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} # ============================================================================= @@ -159,14 +22,14 @@ def _auth(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_get_edit_lock_returns_unlocked_when_no_lock( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET returns locked=False when no edit lock exists.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await el_client.get( - f"/v1/rooms/{room.id}/edit-lock", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/edit-lock", headers=auth_header(token) ) assert response.status_code == 200 data = response.json() @@ -177,11 +40,11 @@ async def test_get_edit_lock_returns_unlocked_when_no_lock( @pytest.mark.asyncio async def test_get_edit_lock_returns_locked_when_lock_exists( - el_client: AsyncClient, el_session: AsyncSession, el_redis: Redis + client: AsyncClient, session: AsyncSession, redis_client: Redis ) -> None: """Test GET returns lock info when an edit lock is held.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Set lock directly in Redis (new format with lock_token) lock_data = json.dumps( @@ -193,10 +56,10 @@ async def test_get_edit_lock_returns_locked_when_lock_exists( "acquired_at": 1000.0, } ) - await el_redis.set(RedisKey.edit_lock(room.id), lock_data, ex=10) + await redis_client.set(RedisKey.edit_lock(room.id), lock_data, ex=10) - response = await el_client.get( - f"/v1/rooms/{room.id}/edit-lock", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/edit-lock", headers=auth_header(token) ) assert response.status_code == 200 data = response.json() @@ -210,13 +73,13 @@ async def test_get_edit_lock_returns_locked_when_lock_exists( @pytest.mark.asyncio async def test_get_edit_lock_returns_404_for_nonexistent_room( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET for non-existent room returns 404.""" - _, token = await _create_user(el_session) + _, token = await create_test_user_in_db(session) - response = await el_client.get( - "/v1/rooms/nonexistent/edit-lock", headers=_auth(token) + response = await client.get( + "/v1/rooms/nonexistent/edit-lock", headers=auth_header(token) ) assert response.status_code == 404 @@ -228,16 +91,16 @@ async def test_get_edit_lock_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_acquire_edit_lock( - el_client: AsyncClient, el_session: AsyncSession, el_redis: Redis + client: AsyncClient, session: AsyncSession, redis_client: Redis ) -> None: """Test PUT acquires the edit lock and returns a lock_token.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing geometries"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 data = response.json() @@ -249,7 +112,7 @@ async def test_acquire_edit_lock( assert data["ttl"] is not None # Verify Redis entry has lock_token - raw = await el_redis.get(RedisKey.edit_lock(room.id)) + raw = await redis_client.get(RedisKey.edit_lock(room.id)) assert raw is not None lock = json.loads(raw) assert lock["user_id"] == str(user.id) @@ -258,76 +121,75 @@ async def test_acquire_edit_lock( @pytest.mark.asyncio async def test_acquire_edit_lock_stores_session_id( - el_client: AsyncClient, el_session: AsyncSession, el_redis: Redis + client: AsyncClient, session: AsyncSession, redis_client: Redis ) -> None: """Test PUT stores X-Session-ID header as sid in Redis.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers={**_auth(token), "X-Session-ID": "my-sid-123"}, + headers={**auth_header(token), "X-Session-ID": "my-sid-123"}, ) assert response.status_code == 200 data = response.json() assert data["sid"] == "my-sid-123" # Verify Redis - raw = await el_redis.get(RedisKey.edit_lock(room.id)) + raw = await redis_client.get(RedisKey.edit_lock(room.id)) lock = json.loads(raw) assert lock["sid"] == "my-sid-123" @pytest.mark.asyncio async def test_acquire_edit_lock_broadcasts_lock_update( - el_client: AsyncClient, - el_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio, ) -> None: """Test PUT broadcasts LockUpdate socket event with ttl.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "drawing"}, - headers=_auth(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, LockUpdate) - assert model.action == "acquired" - assert model.user_id == str(user.id) - assert model.msg == "drawing" - assert model.ttl is not None - assert model.ttl > 0 + assert len(mock_sio.emitted) >= 1 + evt = mock_sio.emitted[-1] + assert evt["event"] == "lock_update" + assert evt["data"]["action"] == "acquired" + assert evt["data"]["user_id"] == str(user.id) + assert evt["data"]["msg"] == "drawing" + assert evt["data"]["ttl"] is not None + assert evt["data"]["ttl"] > 0 @pytest.mark.asyncio async def test_refresh_edit_lock_with_token( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT with Lock-Token header refreshes the lock.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire - resp1 = await el_client.put( + resp1 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "first"}, - headers=_auth(token), + headers=auth_header(token), ) assert resp1.status_code == 200 lock_token = resp1.json()["lock_token"] # Refresh with Lock-Token - resp2 = await el_client.put( + resp2 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "updated"}, - headers={**_auth(token), "Lock-Token": lock_token}, + headers={**auth_header(token), "Lock-Token": lock_token}, ) assert resp2.status_code == 200 data = resp2.json() @@ -338,65 +200,65 @@ async def test_refresh_edit_lock_with_token( @pytest.mark.asyncio async def test_refresh_with_expired_lock_returns_409( - el_client: AsyncClient, el_session: AsyncSession, el_redis: Redis + client: AsyncClient, session: AsyncSession, redis_client: Redis ) -> None: """Test refresh returns 409 when lock has expired.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire - resp = await el_client.put( + resp = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers=_auth(token), + headers=auth_header(token), ) lock_token = resp.json()["lock_token"] # Simulate expiry by deleting the key - await el_redis.delete(RedisKey.edit_lock(room.id)) + await redis_client.delete(RedisKey.edit_lock(room.id)) # Try to refresh → 409 - resp2 = await el_client.put( + resp2 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers={**_auth(token), "Lock-Token": lock_token}, + headers={**auth_header(token), "Lock-Token": lock_token}, ) assert resp2.status_code == 409 @pytest.mark.asyncio async def test_acquire_without_token_when_lock_exists_returns_423( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """PUT without Lock-Token when lock exists returns 423.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire - resp1 = await el_client.put( + resp1 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "first"}, - headers=_auth(token), + headers=auth_header(token), ) assert resp1.status_code == 200 # Same user, no Lock-Token → 423 (lock already exists) - resp2 = await el_client.put( + resp2 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "second attempt"}, - headers=_auth(token), + headers=auth_header(token), ) assert resp2.status_code == 423 @pytest.mark.asyncio async def test_acquire_edit_lock_conflict_with_other_user( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT returns 423 when another user holds the lock.""" - user1, token1 = await _create_user(el_session, "user1@test") - user2, token2 = await _create_user(el_session, "user2@test") - room = await _create_room(el_session, user1) + user1, token1 = await create_test_user_in_db(session, "user1@test") + user2, token2 = await create_test_user_in_db(session, "user2@test") + room = await create_test_room(session, user1) # Give user2 room membership membership = RoomMembership( @@ -404,22 +266,22 @@ async def test_acquire_edit_lock_conflict_with_other_user( user_id=user2.id, # type: ignore[arg-type] role=MemberRole.MEMBER, ) - el_session.add(membership) - await el_session.commit() + session.add(membership) + await session.commit() # User1 acquires lock - resp1 = await el_client.put( + resp1 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "user1 editing"}, - headers=_auth(token1), + headers=auth_header(token1), ) assert resp1.status_code == 200 # User2 tries to acquire → 423 - resp2 = await el_client.put( + resp2 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "user2 editing"}, - headers=_auth(token2), + headers=auth_header(token2), ) assert resp2.status_code == 423 assert "locked" in resp2.json()["type"] @@ -427,45 +289,45 @@ async def test_acquire_edit_lock_conflict_with_other_user( @pytest.mark.asyncio async def test_acquire_edit_lock_blocked_by_admin_lock( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT returns 423 when room is admin-locked.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) room.locked = True - el_session.add(room) - await el_session.commit() + session.add(room) + await session.commit() - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 423 @pytest.mark.asyncio async def test_lock_auto_expires_after_ttl( - el_client: AsyncClient, el_session: AsyncSession, el_redis: Redis + client: AsyncClient, session: AsyncSession, redis_client: Redis ) -> None: """Lock disappears from Redis after edit_lock_ttl without refresh.""" - user1, token1 = await _create_user(el_session, "user1@test") - user2, token2 = await _create_user(el_session, "user2@test") - room = await _create_room(el_session, user1) + user1, token1 = await create_test_user_in_db(session, "user1@test") + user2, token2 = await create_test_user_in_db(session, "user2@test") + room = await create_test_room(session, user1) membership = RoomMembership( room_id=room.id, # type: ignore[arg-type] user_id=user2.id, # type: ignore[arg-type] role=MemberRole.MEMBER, ) - el_session.add(membership) - await el_session.commit() + session.add(membership) + await session.commit() # User1 acquires - resp = await el_client.put( + resp = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "no refresh"}, - headers=_auth(token1), + headers=auth_header(token1), ) assert resp.status_code == 200 @@ -473,16 +335,16 @@ async def test_lock_auto_expires_after_ttl( await asyncio.sleep(11) # Lock should be gone - status = await el_client.get( - f"/v1/rooms/{room.id}/edit-lock", headers=_auth(token1) + status = await client.get( + f"/v1/rooms/{room.id}/edit-lock", headers=auth_header(token1) ) assert status.json()["locked"] is False # User2 can now acquire - resp2 = await el_client.put( + resp2 = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "after expiry"}, - headers=_auth(token2), + headers=auth_header(token2), ) assert resp2.status_code == 200 assert resp2.json()["user_id"] == str(user2.id) @@ -495,74 +357,73 @@ async def test_lock_auto_expires_after_ttl( @pytest.mark.asyncio async def test_release_edit_lock_with_token( - el_client: AsyncClient, el_session: AsyncSession, el_redis: Redis + client: AsyncClient, session: AsyncSession, redis_client: Redis ) -> None: """Test DELETE with Lock-Token releases the lock.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire - resp = await el_client.put( + resp = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "temp"}, - headers=_auth(token), + headers=auth_header(token), ) lock_token = resp.json()["lock_token"] # Release with Lock-Token - response = await el_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/edit-lock", - headers={**_auth(token), "Lock-Token": lock_token}, + headers={**auth_header(token), "Lock-Token": lock_token}, ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) # Verify Redis entry removed - raw = await el_redis.get(RedisKey.edit_lock(room.id)) + raw = await redis_client.get(RedisKey.edit_lock(room.id)) assert raw is None @pytest.mark.asyncio async def test_release_edit_lock_broadcasts_lock_update( - el_client: AsyncClient, - el_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio, ) -> None: """Test DELETE broadcasts LockUpdate with action=released.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire then release - resp = await el_client.put( + resp = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={}, - headers=_auth(token), + headers=auth_header(token), ) lock_token = resp.json()["lock_token"] - mock_sio.emit.reset_mock() + mock_sio.emitted.clear() - await el_client.delete( + await client.delete( f"/v1/rooms/{room.id}/edit-lock", - headers={**_auth(token), "Lock-Token": lock_token}, + headers={**auth_header(token), "Lock-Token": lock_token}, ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, LockUpdate) - assert model.action == "released" + assert len(mock_sio.emitted) >= 1 + evt = mock_sio.emitted[-1] + assert evt["event"] == "lock_update" + assert evt["data"]["action"] == "released" @pytest.mark.asyncio async def test_release_edit_lock_idempotent_when_no_lock( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE succeeds even if no lock exists.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await el_client.delete( - f"/v1/rooms/{room.id}/edit-lock", headers=_auth(token) + response = await client.delete( + f"/v1/rooms/{room.id}/edit-lock", headers=auth_header(token) ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) @@ -570,66 +431,68 @@ async def test_release_edit_lock_idempotent_when_no_lock( @pytest.mark.asyncio async def test_release_edit_lock_wrong_token_returns_403( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE with wrong Lock-Token returns 403.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={}, - headers=_auth(token), + headers=auth_header(token), ) # Release with wrong Lock-Token - response = await el_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/edit-lock", - headers={**_auth(token), "Lock-Token": "wrong-token"}, + headers={**auth_header(token), "Lock-Token": "wrong-token"}, ) assert response.status_code == 403 @pytest.mark.asyncio async def test_release_edit_lock_forbidden_for_non_holder( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE returns 403 when non-holder tries to release.""" - user1, token1 = await _create_user(el_session, "user1@test") - user2, token2 = await _create_user(el_session, "user2@test") - room = await _create_room(el_session, user1) + user1, token1 = await create_test_user_in_db(session, "user1@test") + user2, token2 = await create_test_user_in_db(session, "user2@test") + room = await create_test_room(session, user1) membership = RoomMembership( room_id=room.id, # type: ignore[arg-type] user_id=user2.id, # type: ignore[arg-type] role=MemberRole.MEMBER, ) - el_session.add(membership) - await el_session.commit() + session.add(membership) + await session.commit() # User1 acquires - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={}, - headers=_auth(token1), + headers=auth_header(token1), ) # User2 tries to release → 403 - response = await el_client.delete( - f"/v1/rooms/{room.id}/edit-lock", headers=_auth(token2) + response = await client.delete( + f"/v1/rooms/{room.id}/edit-lock", headers=auth_header(token2) ) assert response.status_code == 403 @pytest.mark.asyncio async def test_admin_can_release_any_lock( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test admin can release another user's lock.""" - user, token = await _create_user(el_session, "user@test") - admin, admin_token = await _create_user(el_session, "admin@test", is_superuser=True) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session, "user@test") + admin, admin_token = await create_test_user_in_db( + session, "admin@test", is_superuser=True + ) + room = await create_test_room(session, user) # Give admin membership membership = RoomMembership( @@ -637,19 +500,19 @@ async def test_admin_can_release_any_lock( user_id=admin.id, # type: ignore[arg-type] role=MemberRole.MEMBER, ) - el_session.add(membership) - await el_session.commit() + session.add(membership) + await session.commit() # User acquires - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={}, - headers=_auth(token), + headers=auth_header(token), ) # Admin releases (no Lock-Token needed for admin) - response = await el_client.delete( - f"/v1/rooms/{room.id}/edit-lock", headers=_auth(admin_token) + response = await client.delete( + f"/v1/rooms/{room.id}/edit-lock", headers=auth_header(admin_token) ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) @@ -662,20 +525,20 @@ async def test_admin_can_release_any_lock( @pytest.mark.asyncio async def test_writable_room_blocks_non_admin_on_admin_locked_room( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test mutation returns 423 when room is admin-locked and user is not admin.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) room.locked = True - el_session.add(room) - await el_session.commit() + session.add(room) + await session.commit() # Try to set a bookmark (mutation endpoint using WritableRoomDep) - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/0", json={"label": "test"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 423 assert "locked" in response.json()["type"] @@ -683,19 +546,19 @@ async def test_writable_room_blocks_non_admin_on_admin_locked_room( @pytest.mark.asyncio async def test_writable_room_allows_admin_on_admin_locked_room( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test admin can mutate even when room is admin-locked.""" - admin, token = await _create_user(el_session, is_superuser=True) - room = await _create_room(el_session, admin) + admin, token = await create_test_user_in_db(session, is_superuser=True) + room = await create_test_room(session, admin) room.locked = True - el_session.add(room) - await el_session.commit() + session.add(room) + await session.commit() - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/0", json={"label": "admin edit"}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -707,33 +570,33 @@ async def test_writable_room_allows_admin_on_admin_locked_room( @pytest.mark.asyncio async def test_writable_room_blocks_non_holder( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test mutation returns 423 when another user holds the edit lock.""" - user1, token1 = await _create_user(el_session, "user1@test") - user2, token2 = await _create_user(el_session, "user2@test") - room = await _create_room(el_session, user1) + user1, token1 = await create_test_user_in_db(session, "user1@test") + user2, token2 = await create_test_user_in_db(session, "user2@test") + room = await create_test_room(session, user1) membership = RoomMembership( room_id=room.id, # type: ignore[arg-type] user_id=user2.id, # type: ignore[arg-type] role=MemberRole.MEMBER, ) - el_session.add(membership) - await el_session.commit() + session.add(membership) + await session.commit() # User1 acquires lock - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers=_auth(token1), + headers=auth_header(token1), ) # User2 tries mutation → 423 - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/0", json={"label": "blocked"}, - headers=_auth(token2), + headers=auth_header(token2), ) assert response.status_code == 423 assert "locked" in response.json()["type"] @@ -741,80 +604,80 @@ async def test_writable_room_blocks_non_holder( @pytest.mark.asyncio async def test_writable_room_allows_lock_holder_with_token( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test lock holder can mutate when sending Lock-Token header.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire lock - resp = await el_client.put( + resp = await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers=_auth(token), + headers=auth_header(token), ) lock_token = resp.json()["lock_token"] # Holder can mutate with Lock-Token - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/0", json={"label": "allowed"}, - headers={**_auth(token), "Lock-Token": lock_token}, + headers={**auth_header(token), "Lock-Token": lock_token}, ) assert response.status_code == 200 @pytest.mark.asyncio async def test_writable_room_blocks_wrong_lock_token( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test mutation returns 423 when wrong Lock-Token is sent.""" - user, token = await _create_user(el_session) - room = await _create_room(el_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Acquire lock - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers=_auth(token), + headers=auth_header(token), ) # Wrong Lock-Token → 423 - response = await el_client.put( + response = await client.put( f"/v1/rooms/{room.id}/bookmarks/0", json={"label": "blocked"}, - headers={**_auth(token), "Lock-Token": "wrong-token"}, + headers={**auth_header(token), "Lock-Token": "wrong-token"}, ) assert response.status_code == 423 @pytest.mark.asyncio async def test_writable_room_allows_get_when_locked( - el_client: AsyncClient, el_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET endpoints still work when room has edit lock.""" - user1, token1 = await _create_user(el_session, "user1@test") - user2, token2 = await _create_user(el_session, "user2@test") - room = await _create_room(el_session, user1) + user1, token1 = await create_test_user_in_db(session, "user1@test") + user2, token2 = await create_test_user_in_db(session, "user2@test") + room = await create_test_room(session, user1) membership = RoomMembership( room_id=room.id, # type: ignore[arg-type] user_id=user2.id, # type: ignore[arg-type] role=MemberRole.MEMBER, ) - el_session.add(membership) - await el_session.commit() + session.add(membership) + await session.commit() # User1 acquires lock - await el_client.put( + await client.put( f"/v1/rooms/{room.id}/edit-lock", json={"msg": "editing"}, - headers=_auth(token1), + headers=auth_header(token1), ) # User2 can still read bookmarks - response = await el_client.get( + response = await client.get( f"/v1/rooms/{room.id}/bookmarks", - headers=_auth(token2), + headers=auth_header(token2), ) assert response.status_code == 200 diff --git a/tests/zndraw/test_routes_frame_selection.py b/tests/zndraw/test_routes_frame_selection.py index bd22b75c3..17755ed00 100644 --- a/tests/zndraw/test_routes_frame_selection.py +++ b/tests/zndraw/test_routes_frame_selection.py @@ -1,152 +1,13 @@ """Tests for frame selection REST API endpoints.""" import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from redis.asyncio import Redis -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.config import Settings -from zndraw.models import MemberRole, Room, RoomMembership -from zndraw.socket_events import FrameSelectionUpdate -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings - -# ============================================================================= -# Test Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="fs_session") -async def fs_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="fs_redis") -async def fs_redis_fixture() -> AsyncIterator[Redis]: - """Create a Redis client for frame selection testing.""" - redis: Redis = Redis.from_url("redis://localhost", decode_responses=True) - await redis.flushdb() - yield redis - await redis.flushdb() - await redis.aclose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="fs_client") -async def fs_client_fixture( - fs_session: AsyncSession, - fs_redis: Redis, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with session, Redis, and sio overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield fs_session - - def get_redis_override() -> Redis: - return fs_redis - - def get_sio_override() -> MagicMock: - return mock_sio - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_redis] = get_redis_override - app.dependency_overrides[get_tsio] = get_sio_override - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client +from helpers import MockSioServer, auth_header, create_test_room, create_test_user_in_db +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession - app.dependency_overrides.clear() - - -# ============================================================================= -# Helpers -# ============================================================================= - - -async def _create_user( - session: AsyncSession, - email: str = "testuser@local.test", - is_superuser: bool = False, -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email, is_superuser=is_superuser) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - return room - - -def _auth(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} +from zndraw.socket_events import FrameSelectionUpdate # ============================================================================= @@ -156,14 +17,14 @@ def _auth(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_get_returns_null_when_empty( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET returns null frameSelection for a new room.""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fs_client.get( - f"/v1/rooms/{room.id}/frame-selection", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frame-selection", headers=auth_header(token) ) assert response.status_code == 200 assert response.json()["frame_selection"] is None @@ -171,19 +32,19 @@ async def test_get_returns_null_when_empty( @pytest.mark.asyncio async def test_get_returns_stored_indices( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET returns stored indices when frame_selection column is set.""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Set column directly room.frame_selection = json.dumps([2, 5, 10]) - fs_session.add(room) - await fs_session.commit() + session.add(room) + await session.commit() - response = await fs_client.get( - f"/v1/rooms/{room.id}/frame-selection", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frame-selection", headers=auth_header(token) ) assert response.status_code == 200 assert response.json()["frame_selection"] == [2, 5, 10] @@ -191,13 +52,13 @@ async def test_get_returns_stored_indices( @pytest.mark.asyncio async def test_get_returns_404_for_nonexistent_room( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET for non-existent room returns 404.""" - _, token = await _create_user(fs_session) + _, token = await create_test_user_in_db(session) - response = await fs_client.get( - "/v1/rooms/nonexistent/frame-selection", headers=_auth(token) + response = await client.get( + "/v1/rooms/nonexistent/frame-selection", headers=auth_header(token) ) assert response.status_code == 404 @@ -209,123 +70,124 @@ async def test_get_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_put_stores_indices( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """PUT stores indices and GET returns them.""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - put_resp = await fs_client.put( + put_resp = await client.put( f"/v1/rooms/{room.id}/frame-selection", json={"indices": [1, 3, 7]}, - headers=_auth(token), + headers=auth_header(token), ) assert put_resp.status_code == 200 assert put_resp.json()["success"] is True - get_resp = await fs_client.get( - f"/v1/rooms/{room.id}/frame-selection", headers=_auth(token) + get_resp = await client.get( + f"/v1/rooms/{room.id}/frame-selection", headers=auth_header(token) ) assert get_resp.json()["frame_selection"] == [1, 3, 7] @pytest.mark.asyncio async def test_put_broadcasts_socket_event( - fs_client: AsyncClient, - fs_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio: MockSioServer, ) -> None: """PUT broadcasts FrameSelectionUpdate socket event.""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + + mock_sio.emitted.clear() - await fs_client.put( + await client.put( f"/v1/rooms/{room.id}/frame-selection", json={"indices": [0, 4]}, - headers=_auth(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, FrameSelectionUpdate) - assert model.indices == [0, 4] + assert len(mock_sio.emitted) >= 1 + evt = mock_sio.emitted[-1] + assert evt["event"] == "frame_selection_update" + assert evt["data"]["indices"] == [0, 4] @pytest.mark.asyncio async def test_put_empty_list_clears_selection( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """PUT with empty list clears selection (GET returns null).""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Set some indices first - await fs_client.put( + await client.put( f"/v1/rooms/{room.id}/frame-selection", json={"indices": [1, 2]}, - headers=_auth(token), + headers=auth_header(token), ) # Clear with empty list - await fs_client.put( + await client.put( f"/v1/rooms/{room.id}/frame-selection", json={"indices": []}, - headers=_auth(token), + headers=auth_header(token), ) - get_resp = await fs_client.get( - f"/v1/rooms/{room.id}/frame-selection", headers=_auth(token) + get_resp = await client.get( + f"/v1/rooms/{room.id}/frame-selection", headers=auth_header(token) ) assert get_resp.json()["frame_selection"] is None @pytest.mark.asyncio async def test_put_rejects_negative_indices( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """PUT with negative indices returns 400.""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await fs_client.put( + response = await client.put( f"/v1/rooms/{room.id}/frame-selection", json={"indices": [1, -2, 3]}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 400 @pytest.mark.asyncio async def test_put_returns_404_for_nonexistent_room( - fs_client: AsyncClient, fs_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """PUT for non-existent room returns 404.""" - _, token = await _create_user(fs_session) + _, token = await create_test_user_in_db(session) - response = await fs_client.put( + response = await client.put( "/v1/rooms/nonexistent/frame-selection", json={"indices": [0]}, - headers=_auth(token), + headers=auth_header(token), ) assert response.status_code == 404 @pytest.mark.asyncio -async def test_roundtrip(fs_client: AsyncClient, fs_session: AsyncSession) -> None: +async def test_roundtrip(client: AsyncClient, session: AsyncSession) -> None: """PUT then GET returns consistent data.""" - user, token = await _create_user(fs_session) - room = await _create_room(fs_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) indices = [0, 2, 4, 6, 8] - await fs_client.put( + await client.put( f"/v1/rooms/{room.id}/frame-selection", json={"indices": indices}, - headers=_auth(token), + headers=auth_header(token), ) - get_resp = await fs_client.get( - f"/v1/rooms/{room.id}/frame-selection", headers=_auth(token) + get_resp = await client.get( + f"/v1/rooms/{room.id}/frame-selection", headers=auth_header(token) ) assert get_resp.status_code == 200 assert get_resp.json()["frame_selection"] == indices diff --git a/tests/zndraw/test_routes_frames.py b/tests/zndraw/test_routes_frames.py index 16f3a9c45..bc614226f 100644 --- a/tests/zndraw/test_routes_frames.py +++ b/tests/zndraw/test_routes_frames.py @@ -1,25 +1,19 @@ """Tests for Frame REST API endpoints.""" -from collections.abc import AsyncIterator from typing import Any -from unittest.mock import AsyncMock import ase import msgpack import pytest -import pytest_asyncio from helpers import ( - MockSioServer, auth_header, create_test_room, create_test_user_in_db, decode_msgpack_response, make_raw_frame, ) -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession from zndraw.client import atoms_to_json_dict from zndraw.exceptions import FrameNotFound, ProblemDetail, RoomNotFound @@ -51,92 +45,6 @@ def raw_frame_to_dict(frame: RawFrame) -> dict[str, Any]: return {k.decode(): msgpack.unpackb(v) for k, v in frame.items()} -# ============================================================================= -# Test-specific Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="frame_session") -async def frame_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="frame_client") -async def frame_client_fixture( - frame_session: AsyncSession, frame_storage: FrameStorage -) -> AsyncIterator[AsyncClient]: - """Create an async test client with session and storage dependencies overridden.""" - from contextlib import asynccontextmanager - - from zndraw.app import app - from zndraw.dependencies import ( - get_frame_storage, - get_joblib_settings, - get_redis, - get_result_backend, - get_tsio, - ) - from zndraw_auth import get_session - from zndraw_auth.settings import AuthSettings - from zndraw_joblib.settings import JobLibSettings - - mock_sio = MockSioServer() - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield frame_session - - @asynccontextmanager - async def test_session_maker(): - yield frame_session - - def get_storage_override() -> FrameStorage: - return frame_storage - - def get_sio_override() -> MockSioServer: - return mock_sio - - # Mock Redis for WritableRoomDep (returns None = no edit lock) - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - app.state.auth_settings = AuthSettings() - app.state.session_maker = test_session_maker - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_frame_storage] = get_storage_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.dependency_overrides[get_result_backend] = lambda: AsyncMock() - app.dependency_overrides[get_joblib_settings] = lambda: JobLibSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - _create_user = create_test_user_in_db _create_room = create_test_room _auth_header = auth_header @@ -149,13 +57,13 @@ def get_sio_override() -> MockSioServer: @pytest.mark.asyncio async def test_list_frames_empty_room( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test listing frames from an empty room returns empty list.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames", headers=_auth_header(token), ) @@ -168,19 +76,19 @@ async def test_list_frames_empty_room( @pytest.mark.asyncio async def test_list_frames_with_data( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test listing frames with data returns all frames.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames to storage await frame_storage[room.id].extend( [make_raw_frame({"a": 1}), make_raw_frame({"b": 2}), make_raw_frame({"c": 3})] ) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames", headers=_auth_header(token), ) @@ -194,13 +102,13 @@ async def test_list_frames_with_data( @pytest.mark.asyncio async def test_list_frames_with_range( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test listing frames with range query params.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -211,7 +119,7 @@ async def test_list_frames_with_range( make_raw_frame({"d": 4}), ] ) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames?start=1&stop=3", headers=_auth_header(token), ) @@ -225,12 +133,12 @@ async def test_list_frames_with_range( @pytest.mark.asyncio async def test_list_frames_room_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test listing frames from non-existent room returns 404.""" - _, token = await _create_user(frame_session) + _, token = await _create_user(session) - response = await frame_client.get( + response = await client.get( "/v1/rooms/99999/frames", headers=_auth_header(token), ) @@ -242,13 +150,13 @@ async def test_list_frames_room_not_found( @pytest.mark.asyncio async def test_list_frames_with_indices( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test listing specific frames by indices parameter.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add 5 frames await frame_storage[room.id].extend( @@ -262,7 +170,7 @@ async def test_list_frames_with_indices( ) # Request specific indices - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames?indices=1,3", headers=_auth_header(token), ) @@ -276,13 +184,13 @@ async def test_list_frames_with_indices( @pytest.mark.asyncio async def test_list_frames_with_keys_filter( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test listing frames with keys parameter to filter frame data.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames with multiple keys await frame_storage[room.id].extend( @@ -292,7 +200,7 @@ async def test_list_frames_with_keys_filter( ] ) # Request only x and z keys - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames?keys=x,z", headers=_auth_header(token), ) @@ -306,13 +214,13 @@ async def test_list_frames_with_keys_filter( @pytest.mark.asyncio async def test_list_frames_with_indices_and_keys( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test listing specific indices with keys filter.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames await frame_storage[room.id].extend( @@ -323,7 +231,7 @@ async def test_list_frames_with_indices_and_keys( ] ) # Request index 2 with only key 'a' - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames?indices=2&keys=a", headers=_auth_header(token), ) @@ -341,19 +249,19 @@ async def test_list_frames_with_indices_and_keys( @pytest.mark.asyncio async def test_get_frame( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test getting a single frame by index.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames to storage await frame_storage[room.id].extend( [make_raw_frame({"a": 1}), make_raw_frame({"b": 2})] ) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/1", headers=_auth_header(token), ) @@ -368,13 +276,13 @@ async def test_get_frame( @pytest.mark.asyncio async def test_get_frame_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test getting non-existent frame returns 404.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/99", headers=_auth_header(token), ) @@ -386,12 +294,12 @@ async def test_get_frame_not_found( @pytest.mark.asyncio async def test_get_frame_room_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test getting frame from non-existent room returns 404.""" - _, token = await _create_user(frame_session) + _, token = await _create_user(session) - response = await frame_client.get( + response = await client.get( "/v1/rooms/99999/frames/0", headers=_auth_header(token), ) @@ -408,16 +316,16 @@ async def test_get_frame_room_not_found( @pytest.mark.asyncio async def test_get_frame_metadata( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test getting metadata for a frame with mixed scalar and array data.""" from ase import Atoms from asebytes import encode - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Create an Atoms object with calc results atoms = Atoms("H2O", positions=[[0, 0, 0], [1, 0, 0], [0, 1, 0]]) @@ -425,7 +333,7 @@ async def test_get_frame_metadata( raw = encode(atoms) await frame_storage[room.id].extend([raw]) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/metadata", headers=_auth_header(token), ) @@ -452,13 +360,13 @@ async def test_get_frame_metadata( @pytest.mark.asyncio async def test_get_frame_metadata_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test getting metadata for non-existent frame returns 404.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) - response = await frame_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/99/metadata", headers=_auth_header(token), ) @@ -470,12 +378,12 @@ async def test_get_frame_metadata_not_found( @pytest.mark.asyncio async def test_get_frame_metadata_room_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test getting metadata for a frame in a non-existent room returns 404.""" - _, token = await _create_user(frame_session) + _, token = await _create_user(session) - response = await frame_client.get( + response = await client.get( "/v1/rooms/nonexistent-room/frames/0/metadata", headers=_auth_header(token), ) @@ -492,16 +400,16 @@ async def test_get_frame_metadata_room_not_found( @pytest.mark.asyncio async def test_append_frames( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test appending frames to storage.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) frame_a = _make_json_frame("H2") frame_b = _make_json_frame("H2O") - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [frame_a, frame_b]}, headers=_auth_header(token), @@ -517,14 +425,14 @@ async def test_append_frames( @pytest.mark.asyncio async def test_append_frames_multiple_times( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test appending frames multiple times.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # First append - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [_make_json_frame("H2")]}, headers=_auth_header(token), @@ -536,7 +444,7 @@ async def test_append_frames_multiple_times( assert result.stop == 1 # Second append - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [_make_json_frame("H2O"), _make_json_frame("CH4")]}, headers=_auth_header(token), @@ -550,12 +458,12 @@ async def test_append_frames_multiple_times( @pytest.mark.asyncio async def test_append_frames_room_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test appending frames to non-existent room returns 404.""" - _, token = await _create_user(frame_session) + _, token = await _create_user(session) - response = await frame_client.post( + response = await client.post( "/v1/rooms/99999/frames", json={"frames": [_make_json_frame("H2")]}, headers=_auth_header(token), @@ -568,13 +476,13 @@ async def test_append_frames_room_not_found( @pytest.mark.asyncio async def test_append_frames_empty_list_rejected( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test appending empty frames list is rejected (422).""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": []}, headers=_auth_header(token), @@ -585,14 +493,14 @@ async def test_append_frames_empty_list_rejected( @pytest.mark.asyncio async def test_append_frames_exceeds_max_length( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test appending more than 1000 frames is rejected (422).""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) frame = _make_json_frame("H2") - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [frame] * 1001}, headers=_auth_header(token), @@ -607,13 +515,13 @@ async def test_append_frames_exceeds_max_length( @pytest.mark.asyncio async def test_update_frame( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test updating a frame at specific index.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -621,7 +529,7 @@ async def test_update_frame( ) new_frame = _make_json_frame("He") - response = await frame_client.put( + response = await client.put( f"/v1/rooms/{room.id}/frames/1", json={"data": new_frame}, headers=_auth_header(token), @@ -634,13 +542,13 @@ async def test_update_frame( @pytest.mark.asyncio async def test_update_frame_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test updating non-existent frame returns 404.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) - response = await frame_client.put( + response = await client.put( f"/v1/rooms/{room.id}/frames/99", json={"data": _make_json_frame("H2")}, headers=_auth_header(token), @@ -658,18 +566,18 @@ async def test_update_frame_not_found( @pytest.mark.asyncio async def test_merge_frame( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test partial update merges new keys into existing frame.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) await frame_storage[room.id].extend([make_raw_frame({"a": 1, "b": 2})]) # Send PATCH with msgpack body updating key "a" and adding key "c" patch_data = msgpack.packb({"a": 99, "c": 3}) - response = await frame_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/frames/0", content=patch_data, headers={**_auth_header(token), "Content-Type": "application/msgpack"}, @@ -688,18 +596,18 @@ async def test_merge_frame( @pytest.mark.asyncio async def test_merge_frame_preserves_untouched_keys( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test partial update does not remove keys not in the patch.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) await frame_storage[room.id].extend([make_raw_frame({"x": 10, "y": 20, "z": 30})]) # Only update "y" patch_data = msgpack.packb({"y": 99}) - response = await frame_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/frames/0", content=patch_data, headers={**_auth_header(token), "Content-Type": "application/msgpack"}, @@ -713,14 +621,14 @@ async def test_merge_frame_preserves_untouched_keys( @pytest.mark.asyncio async def test_merge_frame_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test merging non-existent frame returns 404.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) patch_data = msgpack.packb({"a": 1}) - response = await frame_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/frames/99", content=patch_data, headers={**_auth_header(token), "Content-Type": "application/msgpack"}, @@ -733,13 +641,13 @@ async def test_merge_frame_not_found( @pytest.mark.asyncio async def test_merge_frame_room_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test merging frame in non-existent room returns 404.""" - _, token = await _create_user(frame_session) + _, token = await _create_user(session) patch_data = msgpack.packb({"a": 1}) - response = await frame_client.patch( + response = await client.patch( "/v1/rooms/99999/frames/0", content=patch_data, headers={**_auth_header(token), "Content-Type": "application/msgpack"}, @@ -752,8 +660,8 @@ async def test_merge_frame_room_not_found( @pytest.mark.asyncio async def test_merge_frame_preserves_msgpack_str_type( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that PATCH preserves msgpack str/bin distinction for numpy arrays. @@ -767,8 +675,8 @@ async def test_merge_frame_preserves_msgpack_str_type( """ import struct - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Create initial frame with a numpy-format position array (float64) # This mimics what asebytes.encode produces @@ -801,7 +709,7 @@ async def test_merge_frame_preserves_msgpack_str_type( }, use_bin_type=True, ) - response = await frame_client.patch( + response = await client.patch( f"/v1/rooms/{room.id}/frames/0", content=patch_body, headers={**_auth_header(token), "Content-Type": "application/msgpack"}, @@ -845,19 +753,19 @@ async def test_merge_frame_preserves_msgpack_str_type( @pytest.mark.asyncio async def test_delete_frame( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test deleting a frame at specific index.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add frames to storage await frame_storage[room.id].extend( [make_raw_frame({"a": 1}), make_raw_frame({"b": 2}), make_raw_frame({"c": 3})] ) - response = await frame_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/frames/1", headers=_auth_header(token), ) @@ -872,13 +780,13 @@ async def test_delete_frame( @pytest.mark.asyncio async def test_delete_frame_not_found( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test deleting non-existent frame returns 404.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) - response = await frame_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/frames/99", headers=_auth_header(token), ) @@ -895,11 +803,11 @@ async def test_delete_frame_not_found( @pytest.mark.asyncio async def test_frames_require_authentication( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test that all frame endpoints require authentication.""" - user, _ = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, _ = await _create_user(session) + room = await _create_room(session, user) # All endpoints should return 401 without auth endpoints = [ @@ -913,23 +821,23 @@ async def test_frames_require_authentication( for method, url in endpoints: if method == "GET": - response = await frame_client.get(url) + response = await client.get(url) elif method == "POST": - response = await frame_client.post( + response = await client.post( url, json={"frames": [_make_json_frame("H2")]} ) elif method == "PUT": - response = await frame_client.put( + response = await client.put( url, json={"data": _make_json_frame("H2")} ) elif method == "PATCH": - response = await frame_client.patch( + response = await client.patch( url, content=msgpack.packb({"a": 1}), headers={"Content-Type": "application/msgpack"}, ) else: # DELETE - response = await frame_client.delete(url) + response = await client.delete(url) assert response.status_code == 401, f"{method} {url} should require auth" @@ -963,15 +871,15 @@ def _make_bare_json_frame(formula: str = "H2") -> dict[str, Any]: @pytest.mark.asyncio async def test_append_rejects_frames_without_colors_radii( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """POST /frames rejects frames missing arrays.colors and arrays.radii.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) bare_frame = _make_bare_json_frame("H2") - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [bare_frame]}, headers=_auth_header(token), @@ -985,19 +893,19 @@ async def test_append_rejects_frames_without_colors_radii( @pytest.mark.asyncio async def test_update_rejects_frame_without_colors_radii( - frame_client: AsyncClient, - frame_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """PUT /frames/{index} rejects frames missing arrays.colors and arrays.radii.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) # Add a valid frame so index 0 exists await frame_storage[room.id].extend([make_raw_frame({"a": 1})]) bare_frame = _make_bare_json_frame("H2") - response = await frame_client.put( + response = await client.put( f"/v1/rooms/{room.id}/frames/0", json={"data": bare_frame}, headers=_auth_header(token), @@ -1007,15 +915,15 @@ async def test_update_rejects_frame_without_colors_radii( @pytest.mark.asyncio async def test_append_accepts_enriched_frames( - frame_client: AsyncClient, frame_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """POST /frames accepts frames that already have colors and radii.""" - user, token = await _create_user(frame_session) - room = await _create_room(frame_session, user) + user, token = await _create_user(session) + room = await _create_room(session, user) enriched_frame = _make_json_frame("H2") - response = await frame_client.post( + response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [enriched_frame]}, headers=_auth_header(token), From f8cbc435a7d6e86fe8bff275f63b9fe01b5c7285 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:28:54 +0100 Subject: [PATCH 09/20] refactor: convert edit_lock/frame_selection/frames/progress to shared fixtures - Delete per-file session/client/redis/sio fixtures and local helpers - Convert SIO assertions to MockSioServer.emitted pattern - Add is_superuser param to create_test_user_in_db helper - Remove unused imports (Room, LockUpdate, FrameSelectionUpdate) - Remove unused redis_client params from non-Redis-asserting tests Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_progress.py | 2 -- tests/zndraw/test_routes_edit_lock.py | 5 ++--- tests/zndraw/test_routes_frame_selection.py | 1 - 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/zndraw/test_progress.py b/tests/zndraw/test_progress.py index 439a7c55e..3b1aacab6 100644 --- a/tests/zndraw/test_progress.py +++ b/tests/zndraw/test_progress.py @@ -111,7 +111,6 @@ async def test_update_progress( client: AsyncClient, session: AsyncSession, mock_sio: MockSioServer, - redis_client: Redis, ) -> None: """PATCH updates tqdm fields, returns 200, emits progress_update.""" user, token = await create_test_user_in_db(session) @@ -168,7 +167,6 @@ async def test_update_progress_description( client: AsyncClient, session: AsyncSession, mock_sio: MockSioServer, - redis_client: Redis, ) -> None: """PATCH can update description alongside tqdm fields.""" user, token = await create_test_user_in_db(session) diff --git a/tests/zndraw/test_routes_edit_lock.py b/tests/zndraw/test_routes_edit_lock.py index c960d2955..e6a59537b 100644 --- a/tests/zndraw/test_routes_edit_lock.py +++ b/tests/zndraw/test_routes_edit_lock.py @@ -9,10 +9,9 @@ from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.models import MemberRole, Room, RoomMembership +from zndraw.models import MemberRole, RoomMembership from zndraw.redis import RedisKey from zndraw.schemas import StatusResponse -from zndraw.socket_events import LockUpdate # ============================================================================= @@ -308,7 +307,7 @@ async def test_acquire_edit_lock_blocked_by_admin_lock( @pytest.mark.asyncio async def test_lock_auto_expires_after_ttl( - client: AsyncClient, session: AsyncSession, redis_client: Redis + client: AsyncClient, session: AsyncSession ) -> None: """Lock disappears from Redis after edit_lock_ttl without refresh.""" user1, token1 = await create_test_user_in_db(session, "user1@test") diff --git a/tests/zndraw/test_routes_frame_selection.py b/tests/zndraw/test_routes_frame_selection.py index 17755ed00..be2a69bd2 100644 --- a/tests/zndraw/test_routes_frame_selection.py +++ b/tests/zndraw/test_routes_frame_selection.py @@ -7,7 +7,6 @@ from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession -from zndraw.socket_events import FrameSelectionUpdate # ============================================================================= From 3aea632e8d7e1347f3afeb496acd381b1e981c03 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 14:36:30 +0100 Subject: [PATCH 10/20] refactor: convert geometries/isosurface/trajectory/dispatch to shared fixtures Replaces per-file session/client/mock_sio fixtures and _create_user/_create_room/_auth_header helpers with shared conftest fixtures (client, session, mock_sio, result_backend) and helpers.py utilities (create_test_user_in_db, create_test_room, auth_header). - test_routes_geometries.py: removed geometry_session/geometry_client/mock_sio fixtures; converted AsyncMock Redis to real Redis (flushed per test); seeded real Redis for test_delete_rejects_active_camera using RedisKey.active_cameras pattern; converted SIO assertions to MockSioServer.emitted pattern. - test_isosurface.py: removed iso_session/iso_client fixtures; replaced AsyncMock Redis + AsyncMock result_backend with real Redis + InMemoryResultBackend from shared client. - test_trajectory.py: removed traj_client fixture and per-file helpers; uses shared client. - test_frames_provider_dispatch.py: removed prov_session/prov_client/prov_result_backend fixtures and local InMemoryResultBackend class; uses shared client/session/result_backend. Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/test_frames_provider_dispatch.py | 266 ++----- tests/zndraw/test_isosurface.py | 160 ++--- tests/zndraw/test_routes_geometries.py | 657 +++++++----------- tests/zndraw/test_trajectory.py | 414 +++++------ 4 files changed, 529 insertions(+), 968 deletions(-) diff --git a/tests/zndraw/test_frames_provider_dispatch.py b/tests/zndraw/test_frames_provider_dispatch.py index 805151437..990856712 100644 --- a/tests/zndraw/test_frames_provider_dispatch.py +++ b/tests/zndraw/test_frames_provider_dispatch.py @@ -1,24 +1,19 @@ """Tests for provider-aware frame dispatch in GET /v1/rooms/{room_id}/frames/{index}.""" import asyncio -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock import msgpack import pytest -import pytest_asyncio from helpers import ( - MockSioServer, + InMemoryResultBackend, auth_header, create_test_room, create_test_user_in_db, decode_msgpack_response, make_raw_frame, ) -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession from zndraw.exceptions import FrameNotFound, ProblemDetail from zndraw.storage import FrameStorage, RawFrame @@ -27,151 +22,12 @@ from zndraw_joblib.exceptions import ProviderTimeout from zndraw_joblib.models import ProviderRecord, Worker -# ============================================================================= -# In-memory ResultBackend for testing -# ============================================================================= - - -class InMemoryResultBackend: - """Minimal ResultBackend for testing provider dispatch.""" - - def __init__(self) -> None: - self._data: dict[str, bytes] = {} - self._inflight: set[str] = set() - self._waiters: dict[str, list[asyncio.Event]] = {} - - async def store(self, key: str, data: bytes, ttl: int) -> None: # noqa: ARG002 - self._data[key] = data - - async def get(self, key: str) -> bytes | None: - return self._data.get(key) - - async def delete(self, key: str) -> None: - self._data.pop(key, None) - - async def acquire_inflight(self, key: str, _ttl: int) -> bool: - if key in self._inflight: - return False - self._inflight.add(key) - return True - - async def release_inflight(self, key: str) -> None: - self._inflight.discard(key) - - async def wait_for_key(self, key: str, timeout: float) -> bytes | None: # noqa: ASYNC109 - cached = self._data.get(key) - if cached is not None: - return cached - event = asyncio.Event() - self._waiters.setdefault(key, []).append(event) - try: - await asyncio.wait_for(event.wait(), timeout=timeout) - return self._data.get(key) - except TimeoutError: - return None - finally: - waiters = self._waiters.get(key, []) - if event in waiters: - waiters.remove(event) - - async def notify_key(self, key: str) -> None: - for event in self._waiters.pop(key, []): - event.set() - - -# ============================================================================= -# Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="prov_session_factory") -async def prov_session_factory_fixture() -> AsyncIterator[ - async_sessionmaker[AsyncSession] -]: - """Create a session factory backed by a fresh in-memory database.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - yield async_sessionmaker( - bind=engine, class_=AsyncSession, expire_on_commit=False - ) - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="prov_session") -async def prov_session_fixture( - prov_session_factory: async_sessionmaker[AsyncSession], -) -> AsyncIterator[AsyncSession]: - """Create a session from the shared factory.""" - async with prov_session_factory() as session: - yield session - - -@pytest.fixture(name="prov_result_backend") -def prov_result_backend_fixture() -> InMemoryResultBackend: - """Create an InMemoryResultBackend.""" - return InMemoryResultBackend() - - -@pytest_asyncio.fixture(name="prov_client") -async def prov_client_fixture( - prov_session: AsyncSession, - prov_session_factory: async_sessionmaker[AsyncSession], - frame_storage: FrameStorage, - prov_result_backend: InMemoryResultBackend, -) -> AsyncIterator[AsyncClient]: - """Create a test client with provider dependencies wired.""" - from zndraw.app import app - from zndraw.dependencies import ( - get_frame_storage, - get_redis, - get_result_backend, - get_tsio, - ) - from zndraw_auth import get_session - from zndraw_auth.db import get_session_maker - from zndraw_auth.settings import AuthSettings - - mock_sio = MockSioServer() - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield prov_session - - app.state.auth_settings = AuthSettings() - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_session_maker] = lambda: prov_session_factory - app.dependency_overrides[get_frame_storage] = lambda: frame_storage - app.dependency_overrides[get_tsio] = lambda: mock_sio - app.dependency_overrides[get_redis] = lambda: mock_redis - app.dependency_overrides[get_result_backend] = lambda: prov_result_backend - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - # ============================================================================= # Helpers # ============================================================================= -_create_user = create_test_user_in_db -_create_room = create_test_room -_auth = auth_header - - async def _create_provider( session: AsyncSession, room_id: str, user: User ) -> ProviderRecord: @@ -200,19 +56,19 @@ async def _create_provider( @pytest.mark.asyncio async def test_get_frame_storage_hit_ignores_provider( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Frame in storage, provider exists -- returns frame (no provider touch).""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) - await _create_provider(prov_session, room.id, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + await _create_provider(session, room.id, user) await frame_storage[room.id].extend([make_raw_frame({"a": 1})]) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames/0", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames/0", headers=auth_header(token) ) assert response.status_code == 200 frames = decode_msgpack_response(response.content) @@ -222,15 +78,15 @@ async def test_get_frame_storage_hit_ignores_provider( @pytest.mark.asyncio async def test_get_frame_provider_cache_hit( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, - prov_result_backend: InMemoryResultBackend, + result_backend: InMemoryResultBackend, ) -> None: """Frame in provider cache, storage slot is None -- returns 200 with frame.""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) - provider = await _create_provider(prov_session, room.id, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + provider = await _create_provider(session, room.id, user) # Reserve slots (provider has 3 frames), slot 0 is None await frame_storage[room.id].reserve(3) @@ -244,10 +100,10 @@ async def test_get_frame_provider_cache_hit( params = {"index": "0"} rhash = request_hash(params) cache_key = f"provider-result:{provider.full_name}:{rhash}" - await prov_result_backend.store(cache_key, packed, ttl=300) + await result_backend.store(cache_key, packed, 300) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames/0", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames/0", headers=auth_header(token) ) assert response.status_code == 200 frames = decode_msgpack_response(response.content) @@ -257,20 +113,20 @@ async def test_get_frame_provider_cache_hit( @pytest.mark.asyncio async def test_get_frame_provider_timeout( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Frame not cached, provider exists -- long-poll times out → 504.""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) - await _create_provider(prov_session, room.id, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + await _create_provider(session, room.id, user) # Reserve slots, leave them empty await frame_storage[room.id].reserve(5) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames/2", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames/2", headers=auth_header(token) ) assert response.status_code == 504 @@ -281,19 +137,19 @@ async def test_get_frame_provider_timeout( @pytest.mark.asyncio async def test_get_frame_no_provider_returns_404( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Frame missing, no provider registered -- returns 404.""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Reserve slots but no provider registered await frame_storage[room.id].reserve(3) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames/1", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames/1", headers=auth_header(token) ) assert response.status_code == 404 problem = ProblemDetail.model_validate(response.json()) @@ -302,20 +158,20 @@ async def test_get_frame_no_provider_returns_404( @pytest.mark.asyncio async def test_get_frame_dispatch_acquires_inflight( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, - prov_result_backend: InMemoryResultBackend, + result_backend: InMemoryResultBackend, ) -> None: """After dispatch, inflight lock is acquired.""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) - provider = await _create_provider(prov_session, room.id, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + provider = await _create_provider(session, room.id, user) await frame_storage[room.id].reserve(3) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames/0", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames/0", headers=auth_header(token) ) assert response.status_code == 504 # timeout, but inflight lock was set @@ -323,20 +179,20 @@ async def test_get_frame_dispatch_acquires_inflight( params = {"index": "0"} rhash = request_hash(params) inflight_key = f"provider-inflight:{provider.full_name}:{rhash}" - assert inflight_key in prov_result_backend._inflight + assert inflight_key in result_backend._inflight @pytest.mark.asyncio async def test_get_frame_notify_wakes_long_poll( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, - prov_result_backend: InMemoryResultBackend, + result_backend: InMemoryResultBackend, ) -> None: """Provider uploads result mid-poll — long-poll wakes up and returns 200.""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) - provider = await _create_provider(prov_session, room.id, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + provider = await _create_provider(session, room.id, user) await frame_storage[room.id].reserve(3) @@ -355,14 +211,14 @@ async def test_get_frame_notify_wakes_long_poll( async def _simulate_provider_upload() -> None: """Simulate provider uploading result after a short delay.""" await asyncio.sleep(0.1) - await prov_result_backend.store(cache_key, packed, ttl=300) - await prov_result_backend.notify_key(cache_key) + await result_backend.store(cache_key, packed, 300) + await result_backend.notify_key(cache_key) # Start the simulated provider upload concurrently with the GET request upload_task = asyncio.create_task(_simulate_provider_upload()) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames/1", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames/1", headers=auth_header(token) ) await upload_task @@ -377,15 +233,15 @@ async def _simulate_provider_upload() -> None: @pytest.mark.asyncio async def test_list_frames_notify_wakes_concurrent_dispatch( - prov_client: AsyncClient, - prov_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, frame_storage: FrameStorage, - prov_result_backend: InMemoryResultBackend, + result_backend: InMemoryResultBackend, ) -> None: """Multiple missing frames dispatched concurrently — all wake on notify.""" - user, token = await _create_user(prov_session) - room = await _create_room(prov_session, user) - provider = await _create_provider(prov_session, room.id, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) + provider = await _create_provider(session, room.id, user) # Reserve 3 slots, fill only index 1 await frame_storage[room.id].reserve(3) @@ -407,13 +263,13 @@ async def _simulate_provider_upload() -> None: packed = msgpack.packb(frame, use_bin_type=True) assert isinstance(packed, bytes) key = _cache_key(idx) - await prov_result_backend.store(key, packed, ttl=300) - await prov_result_backend.notify_key(key) + await result_backend.store(key, packed, 300) + await result_backend.notify_key(key) upload_task = asyncio.create_task(_simulate_provider_upload()) - response = await prov_client.get( - f"/v1/rooms/{room.id}/frames?indices=0,1,2", headers=_auth(token) + response = await client.get( + f"/v1/rooms/{room.id}/frames?indices=0,1,2", headers=auth_header(token) ) await upload_task diff --git a/tests/zndraw/test_isosurface.py b/tests/zndraw/test_isosurface.py index 3e5abb1a0..3e6397c39 100644 --- a/tests/zndraw/test_isosurface.py +++ b/tests/zndraw/test_isosurface.py @@ -1,23 +1,16 @@ """Tests for Isosurface geometry model and endpoint.""" -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock - import msgpack import msgpack_numpy import numpy as np import pytest -import pytest_asyncio from helpers import ( - MockSioServer, auth_header, create_test_room, create_test_user_in_db, ) -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession from zndraw.storage import FrameStorage @@ -277,7 +270,7 @@ def test_extract_mesh_step_size(): # ============================================================================= -# Integration Test Fixtures +# Integration Test Helpers # ============================================================================= @@ -303,71 +296,6 @@ def _make_frame_with_cube(cube_key: str = "info.orbital_homo") -> dict[bytes, by } -@pytest_asyncio.fixture(name="iso_session") -async def iso_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - factory = async_sessionmaker( - bind=engine, class_=AsyncSession, expire_on_commit=False - ) - async with factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="iso_client") -async def iso_client_fixture( - iso_session: AsyncSession, frame_storage: FrameStorage -) -> AsyncIterator[AsyncClient]: - """Create an async test client with storage + session overrides.""" - from contextlib import asynccontextmanager - - from zndraw.app import app - from zndraw.dependencies import ( - get_frame_storage, - get_joblib_settings, - get_redis, - get_result_backend, - get_tsio, - ) - from zndraw_auth import get_session - from zndraw_auth.settings import AuthSettings - from zndraw_joblib.settings import JobLibSettings - - mock_sio = MockSioServer() - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield iso_session - - @asynccontextmanager - async def test_session_maker(): - yield iso_session - - app.state.auth_settings = AuthSettings() - app.state.session_maker = test_session_maker - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_frame_storage] = lambda: frame_storage - app.dependency_overrides[get_tsio] = lambda: mock_sio - app.dependency_overrides[get_redis] = lambda: AsyncMock() - app.dependency_overrides[get_result_backend] = lambda: AsyncMock() - app.dependency_overrides[get_joblib_settings] = lambda: JobLibSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" - ) as client: - yield client - - app.dependency_overrides.clear() - - # ============================================================================= # Integration Tests # ============================================================================= @@ -375,16 +303,16 @@ async def test_session_maker(): @pytest.mark.asyncio async def test_isosurface_basic( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Store cube data, GET isosurface, verify 200 with vertices/faces.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) frame = _make_frame_with_cube("info.orbital_homo") await frame_storage[room.id].extend([frame]) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.orbital_homo", "isovalue": "0.0"}, headers=auth_header(token), @@ -403,16 +331,16 @@ async def test_isosurface_basic( @pytest.mark.asyncio async def test_isosurface_missing_cube_key( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """GET with nonexistent cube key returns 422.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) frame = _make_frame_with_cube("info.orbital_homo") await frame_storage[room.id].extend([frame]) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.nonexistent", "isovalue": "0.0"}, headers=auth_header(token), @@ -422,13 +350,13 @@ async def test_isosurface_missing_cube_key( @pytest.mark.asyncio async def test_isosurface_frame_not_found( - iso_client: AsyncClient, iso_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET with out-of-range frame index returns 404.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/99/isosurface", params={"cube_key": "info.foo", "isovalue": "0.0"}, headers=auth_header(token), @@ -438,16 +366,16 @@ async def test_isosurface_frame_not_found( @pytest.mark.asyncio async def test_isosurface_empty_surface( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Isovalue outside data range returns 200 with empty vertices/faces.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) frame = _make_frame_with_cube("info.orbital_homo") await frame_storage[room.id].extend([frame]) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.orbital_homo", "isovalue": "999.0"}, headers=auth_header(token), @@ -463,11 +391,11 @@ async def test_isosurface_empty_surface( @pytest.mark.asyncio async def test_isosurface_invalid_grid( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Cube data with non-3D grid returns 422.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) bad_cube = { "grid": np.ones((10, 10)), # 2D, not 3D @@ -479,7 +407,7 @@ async def test_isosurface_invalid_grid( } await frame_storage[room.id].extend([frame]) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.bad", "isovalue": "0.5"}, headers=auth_header(token), @@ -489,11 +417,11 @@ async def test_isosurface_invalid_grid( @pytest.mark.asyncio async def test_isosurface_missing_dict_keys( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Cube data dict missing 'grid' key returns 422.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) bad_cube = {"origin": np.zeros(3), "cell": np.eye(3)} # no 'grid' frame = { @@ -501,7 +429,7 @@ async def test_isosurface_missing_dict_keys( } await frame_storage[room.id].extend([frame]) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.bad", "isovalue": "0.5"}, headers=auth_header(token), @@ -511,11 +439,11 @@ async def test_isosurface_missing_dict_keys( @pytest.mark.asyncio async def test_isosurface_resolution( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Higher resolution produces more vertices.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Use larger grid so resolution difference is visible cube_data = _make_cube_data(n=40) @@ -524,12 +452,12 @@ async def test_isosurface_resolution( } await frame_storage[room.id].extend([frame]) - resp_fine = await iso_client.get( + resp_fine = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.orb", "isovalue": "0.0", "resolution": "1.0"}, headers=auth_header(token), ) - resp_coarse = await iso_client.get( + resp_coarse = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.orb", "isovalue": "0.0", "resolution": "0.0"}, headers=auth_header(token), @@ -550,11 +478,11 @@ async def test_isosurface_resolution( @pytest.mark.asyncio async def test_isosurface_sigma_smoothing( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Gaussian smoothing with sigma > 0 produces a valid mesh from noisy data.""" - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Create a noisy sphere grid that fragments without smoothing rng = np.random.default_rng(42) @@ -573,7 +501,7 @@ async def test_isosurface_sigma_smoothing( await frame_storage[room.id].extend([frame]) # Without smoothing - resp_raw = await iso_client.get( + resp_raw = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.noisy", "isovalue": "0.0"}, headers=auth_header(token), @@ -581,7 +509,7 @@ async def test_isosurface_sigma_smoothing( assert resp_raw.status_code == 200 # With smoothing - resp_smooth = await iso_client.get( + resp_smooth = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.noisy", "isovalue": "0.0", "sigma": "1.0"}, headers=auth_header(token), @@ -601,12 +529,12 @@ async def test_isosurface_sigma_smoothing( @pytest.mark.asyncio async def test_isosurface_room_not_found( - iso_client: AsyncClient, iso_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """GET isosurface for non-existent room returns 404.""" - _, token = await create_test_user_in_db(iso_session) + _, token = await create_test_user_in_db(session) - response = await iso_client.get( + response = await client.get( "/v1/rooms/nonexistent-room/frames/0/isosurface", params={"cube_key": "info.foo", "isovalue": "0.0"}, headers=auth_header(token), @@ -616,7 +544,7 @@ async def test_isosurface_room_not_found( @pytest.mark.asyncio async def test_isosurface_pyscf_h2( - iso_client: AsyncClient, iso_session: AsyncSession, frame_storage: FrameStorage + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage ) -> None: """Use PySCF to generate a real H2 HOMO orbital, extract isosurface.""" pytest.importorskip("pyscf") @@ -650,11 +578,11 @@ async def test_isosurface_pyscf_h2( b"info.orbital_homo": msgpack.packb(cube_data, default=msgpack_numpy.encode), } - user, token = await create_test_user_in_db(iso_session) - room = await create_test_room(iso_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await frame_storage[room.id].extend([frame]) - response = await iso_client.get( + response = await client.get( f"/v1/rooms/{room.id}/frames/0/isosurface", params={"cube_key": "info.orbital_homo", "isovalue": "0.02"}, headers=auth_header(token), diff --git a/tests/zndraw/test_routes_geometries.py b/tests/zndraw/test_routes_geometries.py index a83ecb806..c46aac936 100644 --- a/tests/zndraw/test_routes_geometries.py +++ b/tests/zndraw/test_routes_geometries.py @@ -1,19 +1,13 @@ """Tests for Geometry REST API endpoints.""" import json -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock import pytest -import pytest_asyncio -from helpers import create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel - -from zndraw.config import Settings -from zndraw.models import MemberRole, Room, RoomGeometry, RoomMembership +from helpers import auth_header, create_test_room, create_test_user_in_db +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from zndraw.models import RoomGeometry from zndraw.schemas import ( GeometriesResponse, GeometryResponse, @@ -21,121 +15,13 @@ StatusResponse, ) from zndraw.socket_events import GeometryInvalidate, SelectionInvalidate -from zndraw_auth import User -from zndraw_auth.settings import AuthSettings + # ============================================================================= -# Test Fixtures +# Geometry-specific helpers (kept from original) # ============================================================================= -@pytest_asyncio.fixture(name="geometry_session") -async def geometry_session_fixture() -> AsyncIterator[AsyncSession]: - """Create a fresh in-memory async database session for each test.""" - engine = create_async_engine( - "sqlite+aiosqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - - try: - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async_session_factory = async_sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session_factory() as session: - yield session - finally: - await engine.dispose() - - -@pytest_asyncio.fixture(name="mock_sio") -async def mock_sio_fixture() -> MagicMock: - """Create a mock Socket.IO server for testing.""" - sio_mock = MagicMock() - sio_mock.emit = AsyncMock() - return sio_mock - - -@pytest_asyncio.fixture(name="geometry_client") -async def geometry_client_fixture( - geometry_session: AsyncSession, - mock_sio: MagicMock, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with dependencies overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_redis, get_tsio - from zndraw_auth import get_session - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield geometry_session - - def get_sio_override() -> MagicMock: - return mock_sio - - # Mock Redis for WritableGeometryDep + session camera hash - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - mock_redis.hgetall = AsyncMock(return_value={}) - mock_redis.hget = AsyncMock(return_value=None) - mock_redis.hdel = AsyncMock(return_value=0) - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: mock_redis - app.state.settings = Settings() - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room - - async def _add_geometry( session: AsyncSession, room_id: str, @@ -164,11 +50,6 @@ async def _add_geometry( await session.commit() -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - - # ============================================================================= # GET List Geometries Tests # ============================================================================= @@ -176,16 +57,16 @@ def _auth_header(token: str) -> dict[str, str]: @pytest.mark.asyncio async def test_list_geometries_returns_empty_initially( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns empty geometries for new room.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -195,19 +76,19 @@ async def test_list_geometries_returns_empty_initially( @pytest.mark.asyncio async def test_list_geometries_returns_all_geometries( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns all geometries in a room.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(geometry_session, room.id, "sphere1", "Sphere", {"radius": 1.0}) - await _add_geometry(geometry_session, room.id, "box1", "Box", {"width": 2.0}) + await _add_geometry(session, room.id, "sphere1", "Sphere", {"radius": 1.0}) + await _add_geometry(session, room.id, "box1", "Box", {"width": 2.0}) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -221,16 +102,16 @@ async def test_list_geometries_returns_all_geometries( @pytest.mark.asyncio async def test_list_geometries_includes_type_schemas( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET includes geometry type schemas and defaults.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -242,21 +123,21 @@ async def test_list_geometries_includes_type_schemas( @pytest.mark.asyncio async def test_list_geometries_includes_owner( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns owner field on geometries.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - geometry_session, room.id, "owned", "Sphere", {}, owner=str(user.id) + session, room.id, "owned", "Sphere", {}, owner=str(user.id) ) - await _add_geometry(geometry_session, room.id, "shared", "Box", {}) + await _add_geometry(session, room.id, "shared", "Box", {}) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries", - headers=_auth_header(token), + headers=auth_header(token), ) result = GeometriesResponse.model_validate(response.json()) assert result.items["owned"].data.get("owner") == str(user.id) @@ -270,20 +151,20 @@ async def test_list_geometries_includes_owner( @pytest.mark.asyncio async def test_get_geometry_returns_geometry( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns a single geometry by key.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - geometry_session, room.id, "mysphere", "Sphere", {"radius": 1.5, "color": "red"} + session, room.id, "mysphere", "Sphere", {"radius": 1.5, "color": "red"} ) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries/mysphere", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -296,16 +177,16 @@ async def test_get_geometry_returns_geometry( @pytest.mark.asyncio async def test_get_geometry_returns_404_for_nonexistent( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET returns 404 for non-existent geometry.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "geometry-not-found" in response.json()["type"] @@ -318,50 +199,48 @@ async def test_get_geometry_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_upsert_geometry_creates_new( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT creates a geometry when it doesn't exist.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/mysphere", json={"type": "Sphere", "data": {"radius": [2.0]}}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["status"] == "ok" # Verify persisted in DB - row = await geometry_session.get(RoomGeometry, (room.id, "mysphere")) + row = await session.get(RoomGeometry, (room.id, "mysphere")) assert row is not None assert row.type == "Sphere" @pytest.mark.asyncio async def test_upsert_geometry_validates_via_pydantic( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT validates geometry data through Pydantic model.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/particles", json={ "type": "Sphere", "data": {"active": True, "radius": "arrays.radii"}, }, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 # Verify the stored config is Pydantic-serialized (includes all defaults) - row = await geometry_session.get(RoomGeometry, (room.id, "particles")) + row = await session.get(RoomGeometry, (room.id, "particles")) assert row is not None config = json.loads(row.config) assert "active" in config @@ -370,24 +249,24 @@ async def test_upsert_geometry_validates_via_pydantic( @pytest.mark.asyncio async def test_upsert_geometry_broadcasts_set_operation( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio, ) -> None: """Test PUT broadcasts geometry:invalidate with set operation.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await geometry_client.put( + await client.put( f"/v1/rooms/{room.id}/geometries/testkey", json={"type": "Sphere", "data": {}}, - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called_once() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, GeometryInvalidate) + assert len(mock_sio.emitted) == 1 + emitted = mock_sio.emitted[0] + assert emitted["event"] == "geometry_invalidate" + model = GeometryInvalidate.model_validate(emitted["data"]) assert model.operation == "set" assert model.key == "testkey" assert model.room_id == room.id @@ -395,27 +274,26 @@ async def test_upsert_geometry_broadcasts_set_operation( @pytest.mark.asyncio async def test_upsert_geometry_updates_existing( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT to existing key updates the geometry.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - geometry_session, room.id, "sphere", "Sphere", {"radius": [1.0]} + session, room.id, "sphere", "Sphere", {"radius": [1.0]} ) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/sphere", json={"type": "Sphere", "data": {"radius": [5.0]}}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 # Verify updated in DB - row = await geometry_session.get(RoomGeometry, (room.id, "sphere")) + row = await session.get(RoomGeometry, (room.id, "sphere")) assert row is not None config = json.loads(row.config) assert config["radius"] == [5.0] @@ -423,17 +301,17 @@ async def test_upsert_geometry_updates_existing( @pytest.mark.asyncio async def test_upsert_geometry_rejects_invalid_data( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT returns 400 when geometry data fails Pydantic validation.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/bad", json={"type": "Sphere", "data": {"resolution": -999}}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 400 @@ -445,49 +323,48 @@ async def test_upsert_geometry_rejects_invalid_data( @pytest.mark.asyncio async def test_delete_geometry_returns_204( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test DELETE removes a geometry and returns 204.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(geometry_session, room.id, "to_delete", "Sphere", {}) + await _add_geometry(session, room.id, "to_delete", "Sphere", {}) - response = await geometry_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/geometries/to_delete", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) # Verify deleted from DB - row = await geometry_session.get(RoomGeometry, (room.id, "to_delete")) + row = await session.get(RoomGeometry, (room.id, "to_delete")) assert row is None @pytest.mark.asyncio async def test_delete_geometry_broadcasts_delete_operation( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio, ) -> None: """Test DELETE broadcasts geometry:invalidate with delete operation.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(geometry_session, room.id, "deletekey", "Sphere", {}) + await _add_geometry(session, room.id, "deletekey", "Sphere", {}) - await geometry_client.delete( + await client.delete( f"/v1/rooms/{room.id}/geometries/deletekey", - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called_once() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, GeometryInvalidate) + assert len(mock_sio.emitted) == 1 + emitted = mock_sio.emitted[0] + assert emitted["event"] == "geometry_invalidate" + model = GeometryInvalidate.model_validate(emitted["data"]) assert model.operation == "delete" assert model.key == "deletekey" assert model.room_id == room.id @@ -495,17 +372,16 @@ async def test_delete_geometry_broadcasts_delete_operation( @pytest.mark.asyncio async def test_delete_nonexistent_geometry_succeeds( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test DELETE on nonexistent geometry succeeds silently (idempotent).""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/geometries/nonexistent", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) @@ -513,28 +389,23 @@ async def test_delete_nonexistent_geometry_succeeds( @pytest.mark.asyncio async def test_delete_rejects_active_camera( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + redis_client, ) -> None: """Test DELETE returns 403 when camera is in use by another session.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) cam_key = f"cam:{user.email}:{str(user.id)[:8]}" - # Configure mock Redis to report this camera as active - from zndraw.app import app - from zndraw.dependencies import get_redis + # Seed real Redis to report this camera as active. + # The route checks redis.hgetall(RedisKey.active_cameras(room_id)) + # which resolves to "room:{room_id}:active-cameras". + await redis_client.hset(f"room:{room.id}:active-cameras", "some-sid", cam_key) - mock_redis = AsyncMock() - mock_redis.get = AsyncMock(return_value=None) - mock_redis.hgetall = AsyncMock(return_value={"some-sid": cam_key}) - mock_redis.hget = AsyncMock(return_value=None) - app.dependency_overrides[get_redis] = lambda: mock_redis - - response = await geometry_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/geometries/{cam_key}", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 403 assert "camera that is in use" in response.json()["detail"] @@ -547,26 +418,25 @@ async def test_delete_rejects_active_camera( @pytest.mark.asyncio async def test_update_selection_sets_indices( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT selection updates the geometry's selection column.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(geometry_session, room.id, "particles", "Sphere", {}) + await _add_geometry(session, room.id, "particles", "Sphere", {}) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/particles/selection", json={"indices": [0, 2, 5]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.json()["status"] == "ok" # Verify selection in DB - row = await geometry_session.get(RoomGeometry, (room.id, "particles")) + row = await session.get(RoomGeometry, (room.id, "particles")) assert row is not None assert row.selection is not None assert json.loads(row.selection) == [0, 2, 5] @@ -574,42 +444,42 @@ async def test_update_selection_sets_indices( @pytest.mark.asyncio async def test_update_selection_broadcasts( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, + mock_sio, ) -> None: """Test PUT selection broadcasts selection:invalidate event.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(geometry_session, room.id, "particles", "Sphere", {}) + await _add_geometry(session, room.id, "particles", "Sphere", {}) - await geometry_client.put( + await client.put( f"/v1/rooms/{room.id}/geometries/particles/selection", json={"indices": [1]}, - headers=_auth_header(token), + headers=auth_header(token), ) - mock_sio.emit.assert_called_once() - call_args = mock_sio.emit.call_args - model = call_args[0][0] - assert isinstance(model, SelectionInvalidate) - assert call_args[1]["room"] == f"room:{room.id}" + assert len(mock_sio.emitted) == 1 + emitted = mock_sio.emitted[0] + assert emitted["event"] == "selection_invalidate" + SelectionInvalidate.model_validate(emitted["data"]) + assert emitted["room"] == f"room:{room.id}" @pytest.mark.asyncio async def test_update_selection_returns_404_for_nonexistent( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT selection returns 404 for non-existent geometry.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/nonexistent/selection", json={"indices": [0]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "geometry-not-found" in response.json()["type"] @@ -622,20 +492,20 @@ async def test_update_selection_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_get_selection_returns_indices( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET selection returns stored indices.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - geometry_session, room.id, "sphere", "Sphere", {}, selection=[1, 2, 3] + session, room.id, "sphere", "Sphere", {}, selection=[1, 2, 3] ) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries/sphere/selection", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -646,18 +516,18 @@ async def test_get_selection_returns_indices( @pytest.mark.asyncio async def test_get_selection_returns_empty_for_no_selection( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET selection returns empty list when not set.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - await _add_geometry(geometry_session, room.id, "sphere", "Sphere", {}) + await _add_geometry(session, room.id, "sphere", "Sphere", {}) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries/sphere/selection", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -668,16 +538,16 @@ async def test_get_selection_returns_empty_for_no_selection( @pytest.mark.asyncio async def test_get_selection_returns_404_for_nonexistent( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET selection returns 404 for non-existent geometry.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries/nonexistent/selection", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "geometry-not-found" in response.json()["type"] @@ -690,21 +560,21 @@ async def test_get_selection_returns_404_for_nonexistent( @pytest.mark.asyncio async def test_list_geometries_includes_selection( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET list includes selection field on each geometry.""" - user, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - geometry_session, room.id, "sphere", "Sphere", {}, selection=[0, 1, 2] + session, room.id, "sphere", "Sphere", {}, selection=[0, 1, 2] ) - await _add_geometry(geometry_session, room.id, "box", "Box", {}) + await _add_geometry(session, room.id, "box", "Box", {}) - response = await geometry_client.get( + response = await client.get( f"/v1/rooms/{room.id}/geometries", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -720,43 +590,43 @@ async def test_list_geometries_includes_selection( @pytest.mark.asyncio async def test_list_geometries_public( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET without auth succeeds (public endpoint).""" - user, _ = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.get(f"/v1/rooms/{room.id}/geometries") + response = await client.get(f"/v1/rooms/{room.id}/geometries") assert response.status_code == 200 @pytest.mark.asyncio async def test_get_geometry_public( - geometry_client: AsyncClient, - geometry_session: AsyncSession, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test GET single geometry without auth succeeds (public endpoint).""" - user, _ = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_geometry( - geometry_session, room.id, "somekey", "Sphere", {"radius": [1.0]} + session, room.id, "somekey", "Sphere", {"radius": [1.0]} ) - response = await geometry_client.get(f"/v1/rooms/{room.id}/geometries/somekey") + response = await client.get(f"/v1/rooms/{room.id}/geometries/somekey") assert response.status_code == 200 @pytest.mark.asyncio async def test_upsert_geometry_requires_auth( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT without auth returns 401.""" - user, _ = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/test", json={"type": "Sphere", "data": {}}, ) @@ -765,25 +635,25 @@ async def test_upsert_geometry_requires_auth( @pytest.mark.asyncio async def test_delete_geometry_requires_auth( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE without auth returns 401.""" - user, _ = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.delete(f"/v1/rooms/{room.id}/geometries/somekey") + response = await client.delete(f"/v1/rooms/{room.id}/geometries/somekey") assert response.status_code == 401 @pytest.mark.asyncio async def test_update_selection_requires_auth( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT selection without auth returns 401.""" - user, _ = await _create_user(geometry_session) - room = await _create_room(geometry_session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/particles/selection", json={"indices": [0]}, ) @@ -797,14 +667,14 @@ async def test_update_selection_requires_auth( @pytest.mark.asyncio async def test_list_geometries_returns_404_for_nonexistent_room( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET for non-existent room returns 404.""" - _, token = await _create_user(geometry_session) + _, token = await create_test_user_in_db(session) - response = await geometry_client.get( + response = await client.get( "/v1/rooms/99999/geometries", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -812,14 +682,14 @@ async def test_list_geometries_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_get_geometry_returns_404_for_nonexistent_room( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test GET single geometry for non-existent room returns 404.""" - _, token = await _create_user(geometry_session) + _, token = await create_test_user_in_db(session) - response = await geometry_client.get( + response = await client.get( "/v1/rooms/99999/geometries/somekey", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -827,15 +697,15 @@ async def test_get_geometry_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_upsert_geometry_returns_404_for_nonexistent_room( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT for non-existent room returns 404.""" - _, token = await _create_user(geometry_session) + _, token = await create_test_user_in_db(session) - response = await geometry_client.put( + response = await client.put( "/v1/rooms/99999/geometries/test", json={"type": "Sphere", "data": {}}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -843,14 +713,14 @@ async def test_upsert_geometry_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_delete_geometry_returns_404_for_nonexistent_room( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test DELETE for non-existent room returns 404.""" - _, token = await _create_user(geometry_session) + _, token = await create_test_user_in_db(session) - response = await geometry_client.delete( + response = await client.delete( "/v1/rooms/99999/geometries/somekey", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -858,15 +728,15 @@ async def test_delete_geometry_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_update_selection_returns_404_for_nonexistent_room( - geometry_client: AsyncClient, geometry_session: AsyncSession + client: AsyncClient, session: AsyncSession ) -> None: """Test PUT selection for non-existent room returns 404.""" - _, token = await _create_user(geometry_session) + _, token = await create_test_user_in_db(session) - response = await geometry_client.put( + response = await client.put( "/v1/rooms/99999/geometries/particles/selection", json={"indices": [0]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -879,17 +749,16 @@ async def test_update_selection_returns_404_for_nonexistent_room( @pytest.mark.asyncio async def test_upsert_rejects_non_owner( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT returns 403 when modifying a geometry owned by another user.""" - owner, _ = await _create_user(geometry_session, email="owner@local.test") - _other, other_token = await _create_user(geometry_session, email="other@local.test") - room = await _create_room(geometry_session, owner) + owner, _ = await create_test_user_in_db(session, email="owner@local.test") + _other, other_token = await create_test_user_in_db(session, email="other@local.test") + room = await create_test_room(session, owner) await _add_geometry( - geometry_session, + session, room.id, "owned_sphere", "Sphere", @@ -897,26 +766,25 @@ async def test_upsert_rejects_non_owner( owner=str(owner.id), ) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/owned_sphere", json={"type": "Sphere", "data": {"radius": [99.0]}}, - headers=_auth_header(other_token), + headers=auth_header(other_token), ) assert response.status_code == 403 @pytest.mark.asyncio async def test_upsert_allows_owner( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT succeeds when the owner modifies their own geometry.""" - owner, token = await _create_user(geometry_session) - room = await _create_room(geometry_session, owner) + owner, token = await create_test_user_in_db(session) + room = await create_test_room(session, owner) await _add_geometry( - geometry_session, + session, room.id, "owned_sphere", "Sphere", @@ -924,50 +792,48 @@ async def test_upsert_allows_owner( owner=str(owner.id), ) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/owned_sphere", json={"type": "Sphere", "data": {"radius": [5.0]}}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @pytest.mark.asyncio async def test_upsert_allows_unowned( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT succeeds for geometry with no owner.""" - user_a, _ = await _create_user(geometry_session, email="a@local.test") - _user_b, token_b = await _create_user(geometry_session, email="b@local.test") - room = await _create_room(geometry_session, user_a) + user_a, _ = await create_test_user_in_db(session, email="a@local.test") + _user_b, token_b = await create_test_user_in_db(session, email="b@local.test") + room = await create_test_room(session, user_a) await _add_geometry( - geometry_session, room.id, "shared_sphere", "Sphere", {"radius": [1.0]} + session, room.id, "shared_sphere", "Sphere", {"radius": [1.0]} ) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/shared_sphere", json={"type": "Sphere", "data": {"radius": [5.0]}}, - headers=_auth_header(token_b), + headers=auth_header(token_b), ) assert response.status_code == 200 @pytest.mark.asyncio async def test_delete_rejects_non_owner( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test DELETE returns 403 when deleting geometry owned by another user.""" - owner, _ = await _create_user(geometry_session, email="owner@local.test") - _other, other_token = await _create_user(geometry_session, email="other@local.test") - room = await _create_room(geometry_session, owner) + owner, _ = await create_test_user_in_db(session, email="owner@local.test") + _other, other_token = await create_test_user_in_db(session, email="other@local.test") + room = await create_test_room(session, owner) await _add_geometry( - geometry_session, + session, room.id, "owned_sphere", "Sphere", @@ -975,26 +841,25 @@ async def test_delete_rejects_non_owner( owner=str(owner.id), ) - response = await geometry_client.delete( + response = await client.delete( f"/v1/rooms/{room.id}/geometries/owned_sphere", - headers=_auth_header(other_token), + headers=auth_header(other_token), ) assert response.status_code == 403 @pytest.mark.asyncio async def test_selection_update_rejects_non_owner( - geometry_client: AsyncClient, - geometry_session: AsyncSession, - mock_sio: MagicMock, + client: AsyncClient, + session: AsyncSession, ) -> None: """Test PUT selection returns 403 when geometry is owned by another user.""" - owner, _ = await _create_user(geometry_session, email="owner@local.test") - _other, other_token = await _create_user(geometry_session, email="other@local.test") - room = await _create_room(geometry_session, owner) + owner, _ = await create_test_user_in_db(session, email="owner@local.test") + _other, other_token = await create_test_user_in_db(session, email="other@local.test") + room = await create_test_room(session, owner) await _add_geometry( - geometry_session, + session, room.id, "owned_sphere", "Sphere", @@ -1002,9 +867,9 @@ async def test_selection_update_rejects_non_owner( owner=str(owner.id), ) - response = await geometry_client.put( + response = await client.put( f"/v1/rooms/{room.id}/geometries/owned_sphere/selection", json={"indices": [0, 1]}, - headers=_auth_header(other_token), + headers=auth_header(other_token), ) assert response.status_code == 403 diff --git a/tests/zndraw/test_trajectory.py b/tests/zndraw/test_trajectory.py index 02d356889..cdf739962 100644 --- a/tests/zndraw/test_trajectory.py +++ b/tests/zndraw/test_trajectory.py @@ -1,24 +1,20 @@ """Tests for Trajectory REST API endpoints (download/upload + download tokens).""" import io -from collections.abc import AsyncIterator from typing import Any import ase import ase.io import numpy as np import pytest -import pytest_asyncio from asebytes import decode, encode -from helpers import MockSioServer, create_test_token, create_test_user_model -from httpx import ASGITransport, AsyncClient +from helpers import auth_header, create_test_room, create_test_user_in_db +from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession from zndraw.exceptions import InvalidPayload, ProblemDetail, RoomLocked, RoomReadOnly -from zndraw.models import MemberRole, Room, RoomMembership from zndraw.schemas import FrameBulkResponse from zndraw.storage import FrameStorage -from zndraw_auth import User # ============================================================================= # Test Helpers @@ -54,90 +50,6 @@ def _parse_trajectory(text: str, fmt: str = "extxyz") -> list[ase.Atoms]: return result # type: ignore[return-value] -# ============================================================================= -# Test-specific Fixtures -# ============================================================================= - - -@pytest_asyncio.fixture(name="traj_client") -async def traj_client_fixture( - session: AsyncSession, - frame_storage: FrameStorage, - redis_client: Any, -) -> AsyncIterator[AsyncClient]: - """Create an async test client with session and storage dependencies overridden.""" - from zndraw.app import app - from zndraw.dependencies import get_frame_storage, get_redis, get_tsio - from zndraw_auth import get_session - from zndraw_auth.settings import AuthSettings - - mock_sio = MockSioServer() - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield session - - def get_storage_override() -> FrameStorage: - return frame_storage - - def get_sio_override() -> MockSioServer: - return mock_sio - - app.state.auth_settings = AuthSettings() - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_frame_storage] = get_storage_override - app.dependency_overrides[get_tsio] = get_sio_override - app.dependency_overrides[get_redis] = lambda: redis_client - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() - - -async def _create_user( - session: AsyncSession, email: str = "testuser@local.test" -) -> tuple[User, str]: - """Create a user and return the user and access token.""" - user = create_test_user_model(email=email) - session.add(user) - await session.commit() - await session.refresh(user) - token = create_test_token(user) - return user, token - - -async def _create_room( - session: AsyncSession, user: User, description: str = "Test Room" -) -> Room: - """Create a room with user as owner.""" - room = Room( - description=description, - created_by_id=user.id, # type: ignore[arg-type] - is_public=True, - ) - session.add(room) - await session.commit() - await session.refresh(room) - - membership = RoomMembership( - room_id=room.id, # type: ignore[arg-type] - user_id=user.id, # type: ignore[arg-type] - role=MemberRole.OWNER, - ) - session.add(membership) - await session.commit() - - return room - - -def _auth_header(token: str) -> dict[str, str]: - """Return Authorization header dict.""" - return {"Authorization": f"Bearer {token}"} - - async def _add_atoms_to_storage( storage: FrameStorage, room_id: str, atoms_list: list[ase.Atoms] ) -> None: @@ -153,20 +65,20 @@ async def _add_atoms_to_storage( @pytest.mark.asyncio async def test_download_single_frame( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test downloading a single frame by index.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) await _add_atoms_to_storage(frame_storage, room.id, [atoms]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?indices=0", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -177,13 +89,13 @@ async def test_download_single_frame( @pytest.mark.asyncio async def test_download_all_frames( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test downloading all frames when no indices specified.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms_list = [ _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]), @@ -192,9 +104,9 @@ async def test_download_all_frames( ] await _add_atoms_to_storage(frame_storage, room.id, atoms_list) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -207,13 +119,13 @@ async def test_download_all_frames( @pytest.mark.asyncio async def test_download_specific_indices( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test downloading specific frame indices.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms_list = [ _make_atoms("H", [[0, 0, 0]]), @@ -224,9 +136,9 @@ async def test_download_specific_indices( ] await _add_atoms_to_storage(frame_storage, room.id, atoms_list) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?indices=0,2,4", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -240,20 +152,20 @@ async def test_download_specific_indices( @pytest.mark.asyncio async def test_download_with_atom_selection( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test downloading with atom selection filter.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2O", [[0, 0, 0], [1, 0, 0], [0, 1, 0]]) await _add_atoms_to_storage(frame_storage, room.id, [atoms]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?selection=0,2", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -266,20 +178,20 @@ async def test_download_with_atom_selection( @pytest.mark.asyncio async def test_download_preserves_info( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that atoms.info is preserved through download.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]], info={"key": "value"}) await _add_atoms_to_storage(frame_storage, room.id, [atoms]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -289,16 +201,16 @@ async def test_download_preserves_info( @pytest.mark.asyncio async def test_download_empty_room( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, ) -> None: """Test downloading from empty room returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 400 @@ -308,20 +220,20 @@ async def test_download_empty_room( @pytest.mark.asyncio async def test_download_invalid_index( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test downloading with out-of-range index returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms_list = [_make_atoms() for _ in range(3)] await _add_atoms_to_storage(frame_storage, room.id, atoms_list) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?indices=99", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 400 @@ -331,20 +243,20 @@ async def test_download_invalid_index( @pytest.mark.asyncio async def test_download_custom_filename( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that custom filename appears in Content-Disposition header.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms() await _add_atoms_to_storage(frame_storage, room.id, [atoms]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?filename=my_traj.extxyz", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert 'filename="my_traj.extxyz"' in response.headers["content-disposition"] @@ -353,21 +265,21 @@ async def test_download_custom_filename( @pytest.mark.asyncio @pytest.mark.parametrize("fmt", ["extxyz", "xyz"]) async def test_download_formats( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, fmt: str, ) -> None: """Test downloading in different supported formats.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) await _add_atoms_to_storage(frame_storage, room.id, [atoms]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?format={fmt}", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -378,33 +290,33 @@ async def test_download_formats( @pytest.mark.asyncio async def test_download_requires_auth( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, ) -> None: """Test downloading without authentication returns 401.""" - user, _ = await _create_user(session) - room = await _create_room(session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await traj_client.get(f"/v1/rooms/{room.id}/trajectory") + response = await client.get(f"/v1/rooms/{room.id}/trajectory") assert response.status_code == 401 @pytest.mark.asyncio async def test_download_unsupported_format( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test downloading with unsupported format returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms() await _add_atoms_to_storage(frame_storage, room.id, [atoms]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?format=invalid", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 400 @@ -419,13 +331,13 @@ async def test_download_unsupported_format( @pytest.mark.asyncio async def test_upload_extxyz( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test uploading an extxyz trajectory file stores frames.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms_list = [ _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]), @@ -433,10 +345,10 @@ async def test_upload_extxyz( ] content = _atoms_to_file_bytes(atoms_list, "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -451,21 +363,21 @@ async def test_upload_extxyz( @pytest.mark.asyncio async def test_upload_xyz( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test uploading an xyz format trajectory file.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) content = _atoms_to_file_bytes([atoms], "xyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory?format=xyz", files={"file": ("traj.xyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -476,22 +388,22 @@ async def test_upload_xyz( @pytest.mark.asyncio async def test_upload_format_from_extension( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that format is inferred from file extension when not specified.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) content = _atoms_to_file_bytes([atoms], "xyz") # No format query param, but filename ends with .xyz - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("trajectory.xyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 assert await frame_storage.get_length(room.id) == 1 @@ -499,22 +411,22 @@ async def test_upload_format_from_extension( @pytest.mark.asyncio async def test_upload_explicit_format( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that explicit format param overrides extension inference.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) content = _atoms_to_file_bytes([atoms], "extxyz") # Filename says .xyz but format param says extxyz - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory?format=extxyz", files={"file": ("traj.xyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 assert await frame_storage.get_length(room.id) == 1 @@ -522,22 +434,22 @@ async def test_upload_explicit_format( @pytest.mark.asyncio async def test_upload_preserves_positions( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that positions survive the upload roundtrip.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) original_positions = [[0.0, 0.0, 0.0], [1.5, 2.5, 3.5]] atoms = _make_atoms("H2", original_positions) content = _atoms_to_file_bytes([atoms], "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -552,13 +464,13 @@ async def test_upload_preserves_positions( @pytest.mark.asyncio async def test_upload_appends_to_nonempty( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Test that uploading to a room with existing frames appends.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Pre-populate with 2 frames existing = [_make_atoms(), _make_atoms()] @@ -569,10 +481,10 @@ async def test_upload_appends_to_nonempty( new_atoms = _make_atoms("He", [[0, 0, 0]]) content = _atoms_to_file_bytes([new_atoms], "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -586,17 +498,17 @@ async def test_upload_appends_to_nonempty( @pytest.mark.asyncio async def test_upload_empty_file( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, ) -> None: """Test uploading an empty file returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", b"", "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 400 @@ -606,17 +518,17 @@ async def test_upload_empty_file( @pytest.mark.asyncio async def test_upload_requires_auth( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, ) -> None: """Test uploading without authentication returns 401.""" - user, _ = await _create_user(session) - room = await _create_room(session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) atoms = _make_atoms() content = _atoms_to_file_bytes([atoms], "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, ) @@ -625,12 +537,12 @@ async def test_upload_requires_auth( @pytest.mark.asyncio async def test_upload_locked_room( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, ) -> None: """Test uploading to a locked room returns 423.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) room.locked = True session.add(room) await session.commit() @@ -638,10 +550,10 @@ async def test_upload_locked_room( atoms = _make_atoms() content = _atoms_to_file_bytes([atoms], "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 423 @@ -651,23 +563,23 @@ async def test_upload_locked_room( @pytest.mark.asyncio async def test_upload_provider_backed_readonly( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Upload to a provider-backed room returns 409 RoomReadOnly.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await frame_storage.set_frame_count(room.id, 10) atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) content = _atoms_to_file_bytes([atoms], "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 409 @@ -682,18 +594,18 @@ async def test_upload_provider_backed_readonly( @pytest.mark.asyncio async def test_create_download_token( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """POST creates a download token with default TTL.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory/download-tokens", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -708,18 +620,18 @@ async def test_create_download_token( @pytest.mark.asyncio async def test_create_download_token_custom_ttl( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """POST with custom TTL sets that TTL.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory/download-tokens", - headers=_auth_header(token), + headers=auth_header(token), json={"ttl": 60}, ) assert response.status_code == 201 @@ -728,18 +640,18 @@ async def test_create_download_token_custom_ttl( @pytest.mark.asyncio async def test_create_download_token_ttl_exceeds_max_rejected( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """TTL above server max is rejected by Pydantic validation.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory/download-tokens", - headers=_auth_header(token), + headers=auth_header(token), json={"ttl": 999999}, ) assert response.status_code == 422 @@ -747,14 +659,14 @@ async def test_create_download_token_ttl_exceeds_max_rejected( @pytest.mark.asyncio async def test_create_download_token_requires_auth( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, ) -> None: """POST without auth returns 401.""" - user, _ = await _create_user(session) - room = await _create_room(session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory/download-tokens", ) assert response.status_code == 401 @@ -762,24 +674,24 @@ async def test_create_download_token_requires_auth( @pytest.mark.asyncio async def test_download_with_token_no_auth_header( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """GET with valid download token works without Authorization header.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) # Create download token - create_resp = await traj_client.post( + create_resp = await client.post( f"/v1/rooms/{room.id}/trajectory/download-tokens", - headers=_auth_header(token), + headers=auth_header(token), ) download_token = create_resp.json()["token"] # Download WITHOUT auth header, using token param - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?token={download_token}", ) assert response.status_code == 200 @@ -790,16 +702,16 @@ async def test_download_with_token_no_auth_header( @pytest.mark.asyncio async def test_download_with_invalid_token( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """GET with invalid token and no auth header returns 401.""" - user, _ = await _create_user(session) - room = await _create_room(session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?token=bogus-token", ) assert response.status_code == 401 @@ -807,26 +719,26 @@ async def test_download_with_invalid_token( @pytest.mark.asyncio async def test_download_token_wrong_room( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Token for room A cannot download from room B.""" - user, token = await _create_user(session) - room_a = await _create_room(session, user, description="Room A") - room_b = await _create_room(session, user, description="Room B") + user, token = await create_test_user_in_db(session) + room_a = await create_test_room(session, user, description="Room A") + room_b = await create_test_room(session, user, description="Room B") await _add_atoms_to_storage(frame_storage, room_a.id, [_make_atoms()]) await _add_atoms_to_storage(frame_storage, room_b.id, [_make_atoms()]) # Create token for room A - create_resp = await traj_client.post( + create_resp = await client.post( f"/v1/rooms/{room_a.id}/trajectory/download-tokens", - headers=_auth_header(token), + headers=auth_header(token), ) download_token = create_resp.json()["token"] # Try to use it on room B - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room_b.id}/trajectory?token={download_token}", ) assert response.status_code == 401 @@ -834,29 +746,29 @@ async def test_download_token_wrong_room( @pytest.mark.asyncio async def test_download_token_single_use( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Token is consumed after first use.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) - create_resp = await traj_client.post( + create_resp = await client.post( f"/v1/rooms/{room.id}/trajectory/download-tokens", - headers=_auth_header(token), + headers=auth_header(token), ) download_token = create_resp.json()["token"] # First use succeeds - resp1 = await traj_client.get( + resp1 = await client.get( f"/v1/rooms/{room.id}/trajectory?token={download_token}", ) assert resp1.status_code == 200 # Second use fails — token was consumed - resp2 = await traj_client.get( + resp2 = await client.get( f"/v1/rooms/{room.id}/trajectory?token={download_token}", ) assert resp2.status_code == 401 @@ -864,22 +776,22 @@ async def test_download_token_single_use( @pytest.mark.asyncio async def test_upload_enriches_frames( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Uploaded bare atoms get colors, radii, and connectivity added.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Bare atoms — no colors, radii, or connectivity atoms = _make_atoms("H2", [[0, 0, 0], [1, 0, 0]]) content = _atoms_to_file_bytes([atoms], "extxyz") - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory", files={"file": ("traj.extxyz", content, "application/octet-stream")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -892,17 +804,17 @@ async def test_upload_enriches_frames( @pytest.mark.asyncio async def test_upload_malformed_file( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Uploading a non-trajectory file returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) - response = await traj_client.post( + response = await client.post( f"/v1/rooms/{room.id}/trajectory?format=extxyz", - headers=_auth_header(token), + headers=auth_header(token), files={"file": ("garbage.xyz", b"this is not a trajectory", "text/plain")}, ) assert response.status_code == 400 @@ -914,19 +826,19 @@ async def test_upload_malformed_file( @pytest.mark.asyncio async def test_download_atom_selection_out_of_range( - traj_client: AsyncClient, + client: AsyncClient, session: AsyncSession, frame_storage: FrameStorage, ) -> None: """Atom selection with out-of-range index returns 400.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # H2 has 2 atoms (indices 0, 1) await _add_atoms_to_storage(frame_storage, room.id, [_make_atoms()]) - response = await traj_client.get( + response = await client.get( f"/v1/rooms/{room.id}/trajectory?selection=0,99", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 400 From 7a51c51840e99aa0a6334144b45a597b2bc3918a Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 15:05:09 +0100 Subject: [PATCH 11/20] refactor: replace mock httpx.Client and StateFileSource patches with real servers - test_resolve_token.py: replace httpx.Client mocks with real server calls (guest_login, login_with_credentials) - test_cli_auth.py: replace StateFile and _resolve_url patches with real server + monkeypatch; keep httpx.Client mock for browser-flow login tests with # why: comments - test_cli_agent/test_auth.py: replace StateFile patch with monkeypatch; keep httpx.Client mock for orchestration tests with # why: comments - test_client_settings.py: replace StateFileSource.__call__ patch with StateFile(directory=tmp_path) redirect via autouse fixture - test_client_api.py: replace three StateFileSource.__call__ patches with StateFile(directory=tmp_path) monkeypatch; remove unused unittest.mock import - test_state_file_source.py: replace _is_pid_alive patches with real PID (os.getpid()) and _DEAD_PID=999999999; replace _is_url_healthy patches with real server_factory or _DEAD_URL; keep two patches for remote/preference logic with # why: comments Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/test_cli_agent/test_auth.py | 49 ++++--- tests/zndraw/test_cli_auth.py | 155 ++++++++++---------- tests/zndraw/test_client_api.py | 47 ++++--- tests/zndraw/test_client_settings.py | 28 ++-- tests/zndraw/test_resolve_token.py | 65 +++------ tests/zndraw/test_state_file_source.py | 172 +++++++++++++---------- 6 files changed, 266 insertions(+), 250 deletions(-) diff --git a/tests/zndraw/test_cli_agent/test_auth.py b/tests/zndraw/test_cli_agent/test_auth.py index e1f96120a..b2409f66a 100644 --- a/tests/zndraw/test_cli_agent/test_auth.py +++ b/tests/zndraw/test_cli_agent/test_auth.py @@ -5,14 +5,22 @@ from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch +import pytest + from zndraw.cli_agent import app +from zndraw.state_file import StateFile if TYPE_CHECKING: from typer.testing import CliRunner -def test_auth_login_opens_browser(cli_runner: CliRunner, server_url: str) -> None: +def test_auth_login_opens_browser( + cli_runner: CliRunner, server_url: str, tmp_path, monkeypatch +) -> None: """Login (without --code) should open the browser.""" + state_file = StateFile(directory=tmp_path) + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + challenge_resp = MagicMock() challenge_resp.status_code = 200 challenge_resp.raise_for_status = MagicMock() @@ -37,18 +45,19 @@ def mock_get(path, **kwargs): return approved_resp return me_resp + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = challenge_resp + mock_client.get.side_effect = mock_get + with ( - patch("zndraw.cli_agent.auth.httpx.Client") as mock_cls, + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open") as mock_browser, + # why: time.sleep(1) x 300 iterations would make tests take minutes patch("zndraw.cli_agent.auth.time.sleep"), ): - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_client.post.return_value = challenge_resp - mock_client.get.side_effect = mock_get - mock_cls.return_value = mock_client - result = cli_runner.invoke(app, ["auth", "login", "--url", server_url]) assert result.exit_code == 0, result.output @@ -57,9 +66,12 @@ def mock_get(path, **kwargs): def test_auth_login_code_flag_does_not_open_browser( - cli_runner: CliRunner, server_url: str + cli_runner: CliRunner, server_url: str, tmp_path, monkeypatch ) -> None: """Login with --code should print URL instead of opening browser.""" + state_file = StateFile(directory=tmp_path) + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + challenge_resp = MagicMock() challenge_resp.status_code = 200 challenge_resp.raise_for_status = MagicMock() @@ -84,18 +96,19 @@ def mock_get(path, **kwargs): return approved_resp return me_resp + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = challenge_resp + mock_client.get.side_effect = mock_get + with ( - patch("zndraw.cli_agent.auth.httpx.Client") as mock_cls, + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open") as mock_browser, + # why: time.sleep(1) x 300 iterations would make tests take minutes patch("zndraw.cli_agent.auth.time.sleep"), ): - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_client.post.return_value = challenge_resp - mock_client.get.side_effect = mock_get - mock_cls.return_value = mock_client - result = cli_runner.invoke( app, ["auth", "login", "--url", server_url, "--code"] ) diff --git a/tests/zndraw/test_cli_auth.py b/tests/zndraw/test_cli_auth.py index 5ff2bccbc..579a54b0e 100644 --- a/tests/zndraw/test_cli_auth.py +++ b/tests/zndraw/test_cli_auth.py @@ -6,6 +6,7 @@ from datetime import UTC, datetime from unittest.mock import MagicMock, patch +import httpx import pytest from typer.testing import CliRunner @@ -29,57 +30,39 @@ def stored_entry(): ) -@pytest.fixture -def mock_httpx_client(): - """Yield a mock httpx.Client context manager for auth module.""" - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - with patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client): - yield mock_client - - # -- auth status -------------------------------------------------------------- -def test_auth_status_with_stored_token(state_file, stored_entry, mock_httpx_client): - """auth status should show identity from stored token.""" - state_file.add_token("http://localhost:8000", stored_entry) +def test_auth_status_with_stored_token(server: str, state_file, stored_entry, monkeypatch): + """auth status should show identity from stored token against a real server.""" + # Get a real guest token so /v1/auth/users/me will succeed + resp = httpx.post(f"{server}/v1/auth/guest", timeout=10.0) + resp.raise_for_status() + real_token = resp.json()["access_token"] - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "abc-123", - "email": "user@example.com", - "is_superuser": False, - } - mock_httpx_client.get.return_value = mock_response + real_entry = TokenEntry( + access_token=real_token, + email="guest", + stored_at=datetime(2026, 3, 1, tzinfo=UTC), + ) + state_file.add_token(server, real_entry) - with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), - ): - result = runner.invoke(app, ["auth", "status"]) + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + + result = runner.invoke(app, ["auth", "status"]) assert result.exit_code == 0, result.output data = json.loads(result.stdout) - assert data["email"] == "user@example.com" assert "server" in data -def test_auth_status_not_logged_in(state_file): +def test_auth_status_not_logged_in(server: str, state_file, monkeypatch): """auth status with no stored/explicit token should report not logged in.""" - with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), - ): - result = runner.invoke(app, ["auth", "status"]) + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + + result = runner.invoke(app, ["auth", "status"]) assert result.exit_code == 0, result.output data = json.loads(result.stdout) @@ -90,33 +73,25 @@ def test_auth_status_not_logged_in(state_file): # -- auth logout --------------------------------------------------------------- -def test_auth_logout_removes_token(state_file, stored_entry): +def test_auth_logout_removes_token(server: str, state_file, stored_entry, monkeypatch): """auth logout should remove the stored token for the server.""" - state_file.add_token("http://localhost:8000", stored_entry) + state_file.add_token(server, stored_entry) - with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), - ): - result = runner.invoke(app, ["auth", "logout"]) + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + + result = runner.invoke(app, ["auth", "logout"]) assert result.exit_code == 0 - assert state_file.get_token("http://localhost:8000") is None + assert state_file.get_token(server) is None -def test_auth_logout_no_token(state_file): +def test_auth_logout_no_token(server: str, state_file, monkeypatch): """auth logout when no token is stored should not error.""" - with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), - ): - result = runner.invoke(app, ["auth", "logout"]) + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + + result = runner.invoke(app, ["auth", "logout"]) assert result.exit_code == 0 @@ -124,7 +99,7 @@ def test_auth_logout_no_token(state_file): # -- auth login ---------------------------------------------------------------- -def test_auth_login_approved(state_file, mock_httpx_client): +def test_auth_login_approved(server: str, state_file, monkeypatch): """auth login should store token on successful approval.""" challenge_resp = MagicMock() challenge_resp.status_code = 200 @@ -164,28 +139,32 @@ def mock_get(path, **kwargs): return pending_resp if call_count == 1 else approved_resp return me_resp - mock_httpx_client.post.return_value = challenge_resp - mock_httpx_client.get.side_effect = mock_get + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = challenge_resp + mock_client.get.side_effect = mock_get + + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), + # why: time.sleep(1) x 300 iterations would make tests take minutes patch("zndraw.cli_agent.auth.time.sleep"), ): result = runner.invoke(app, ["auth", "login"]) assert result.exit_code == 0 - stored = state_file.get_token("http://localhost:8000") + stored = state_file.get_token(server) assert stored is not None assert stored.access_token == "approved.jwt.token" assert stored.email == "user@example.com" -def test_auth_login_rejected(state_file, mock_httpx_client): +def test_auth_login_rejected(server: str, state_file, monkeypatch): """auth login should show error when challenge is rejected (404).""" challenge_resp = MagicMock() challenge_resp.status_code = 200 @@ -199,16 +178,20 @@ def test_auth_login_rejected(state_file, mock_httpx_client): rejected_resp = MagicMock() rejected_resp.status_code = 404 - mock_httpx_client.post.return_value = challenge_resp - mock_httpx_client.get.return_value = rejected_resp + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = challenge_resp + mock_client.get.return_value = rejected_resp + + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), + # why: time.sleep(1) x 300 iterations would make tests take minutes patch("zndraw.cli_agent.auth.time.sleep"), ): result = runner.invoke(app, ["auth", "login"]) @@ -216,7 +199,7 @@ def test_auth_login_rejected(state_file, mock_httpx_client): assert result.exit_code != 0 -def test_auth_login_expired(state_file, mock_httpx_client): +def test_auth_login_expired(server: str, state_file, monkeypatch): """auth login should show error when challenge expires (410).""" challenge_resp = MagicMock() challenge_resp.status_code = 200 @@ -230,16 +213,20 @@ def test_auth_login_expired(state_file, mock_httpx_client): expired_resp = MagicMock() expired_resp.status_code = 410 - mock_httpx_client.post.return_value = challenge_resp - mock_httpx_client.get.return_value = expired_resp + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = challenge_resp + mock_client.get.return_value = expired_resp + + monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) with ( - patch("zndraw.cli_agent.auth.StateFile", return_value=state_file), - patch( - "zndraw.cli_agent.auth._resolve_url", - return_value="http://localhost:8000", - ), + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), + # why: time.sleep(1) x 300 iterations would make tests take minutes patch("zndraw.cli_agent.auth.time.sleep"), ): result = runner.invoke(app, ["auth", "login"]) diff --git a/tests/zndraw/test_client_api.py b/tests/zndraw/test_client_api.py index a3079a45c..56fb76b0a 100644 --- a/tests/zndraw/test_client_api.py +++ b/tests/zndraw/test_client_api.py @@ -6,8 +6,6 @@ import uuid import warnings -from unittest.mock import patch - import pytest from zndraw import ZnDraw @@ -53,13 +51,18 @@ def test_list_rooms_autodiscover(server: str, monkeypatch: pytest.MonkeyPatch): assert isinstance(rooms, list) -def test_list_rooms_no_server_raises(monkeypatch: pytest.MonkeyPatch): +def test_list_rooms_no_server_raises(tmp_path, monkeypatch: pytest.MonkeyPatch): """list_rooms raises ConnectionError when no server is found.""" + from zndraw.settings_sources import StateFile + monkeypatch.delenv("ZNDRAW_URL", raising=False) - with ( - patch("zndraw.settings_sources.StateFileSource.__call__", return_value={}), - pytest.raises(ConnectionError), - ): + # why: redirects StateFile to empty tmp_path — StateFileSource naturally returns {} + # because no state.json exists, simulating 'no server found' without patching + monkeypatch.setattr( + "zndraw.settings_sources.StateFile", + lambda: StateFile(directory=tmp_path), + ) + with pytest.raises(ConnectionError): ZnDraw.list_rooms() @@ -90,13 +93,18 @@ def test_login_autodiscover(server_auth: str, monkeypatch: pytest.MonkeyPatch): assert len(token) > 0 -def test_login_no_server_raises(monkeypatch: pytest.MonkeyPatch): +def test_login_no_server_raises(tmp_path, monkeypatch: pytest.MonkeyPatch): """login raises ConnectionError when no server is found.""" + from zndraw.settings_sources import StateFile + monkeypatch.delenv("ZNDRAW_URL", raising=False) - with ( - patch("zndraw.settings_sources.StateFileSource.__call__", return_value={}), - pytest.raises(ConnectionError), - ): + # why: redirects StateFile to empty tmp_path — StateFileSource naturally returns {} + # because no state.json exists, simulating 'no server found' without patching + monkeypatch.setattr( + "zndraw.settings_sources.StateFile", + lambda: StateFile(directory=tmp_path), + ) + with pytest.raises(ConnectionError): ZnDraw.login(username="x", password="y") @@ -113,13 +121,18 @@ def test_constructor_autodiscover(server: str, monkeypatch: pytest.MonkeyPatch): vis.disconnect() -def test_constructor_no_server_raises(monkeypatch: pytest.MonkeyPatch): +def test_constructor_no_server_raises(tmp_path, monkeypatch: pytest.MonkeyPatch): """ZnDraw() raises ConnectionError when no server is found.""" + from zndraw.settings_sources import StateFile + monkeypatch.delenv("ZNDRAW_URL", raising=False) - with ( - patch("zndraw.settings_sources.StateFileSource.__call__", return_value={}), - pytest.raises(ConnectionError), - ): + # why: redirects StateFile to empty tmp_path — StateFileSource naturally returns {} + # because no state.json exists, simulating 'no server found' without patching + monkeypatch.setattr( + "zndraw.settings_sources.StateFile", + lambda: StateFile(directory=tmp_path), + ) + with pytest.raises(ConnectionError): ZnDraw() diff --git a/tests/zndraw/test_client_settings.py b/tests/zndraw/test_client_settings.py index ac734cae9..9dac56b52 100644 --- a/tests/zndraw/test_client_settings.py +++ b/tests/zndraw/test_client_settings.py @@ -2,8 +2,6 @@ from __future__ import annotations -from unittest.mock import patch - import pytest @@ -20,14 +18,21 @@ def _clean_env(monkeypatch): monkeypatch.delenv(key, raising=False) -@pytest.fixture -def _no_state_file(): - """Patch StateFileSource to return empty dict (no state file).""" - with patch("zndraw.settings_sources.StateFileSource.__call__", return_value={}): - yield +@pytest.fixture(autouse=True) +def _no_state_file(tmp_path, monkeypatch): + """Redirect StateFile to an empty tmp_path directory. + + why: redirects StateFile to empty tmp_path — StateFileSource naturally returns {} + because no state.json exists, simulating 'no server found' without patching + """ + from zndraw.settings_sources import StateFile + + monkeypatch.setattr( + "zndraw.settings_sources.StateFile", + lambda: StateFile(directory=tmp_path), + ) -@pytest.mark.usefixtures("_no_state_file") def test_init_args_highest_priority(monkeypatch): """Init args override everything.""" monkeypatch.setenv("ZNDRAW_URL", "http://env-server:8000") @@ -37,7 +42,6 @@ def test_init_args_highest_priority(monkeypatch): assert settings.url == "http://init-server:8000" -@pytest.mark.usefixtures("_no_state_file") def test_env_overrides_defaults(monkeypatch): """Env vars provide values when no init args given.""" monkeypatch.setenv("ZNDRAW_URL", "http://env-server:8000") @@ -49,7 +53,6 @@ def test_env_overrides_defaults(monkeypatch): assert settings.room == "env-room" -@pytest.mark.usefixtures("_no_state_file") def test_all_fields_default_to_none(tmp_path, monkeypatch): """All fields default to None when no source provides values.""" monkeypatch.chdir(tmp_path) @@ -63,7 +66,6 @@ def test_all_fields_default_to_none(tmp_path, monkeypatch): assert settings.token is None -@pytest.mark.usefixtures("_no_state_file") def test_pyproject_toml_provides_values(tmp_path, monkeypatch): """Values from [tool.zndraw] in pyproject.toml are used.""" toml_content = """\ @@ -80,7 +82,6 @@ def test_pyproject_toml_provides_values(tmp_path, monkeypatch): assert settings.room == "toml-room" -@pytest.mark.usefixtures("_no_state_file") def test_env_overrides_pyproject_toml(tmp_path, monkeypatch): """Env vars override pyproject.toml values.""" toml_content = """\ @@ -96,7 +97,6 @@ def test_env_overrides_pyproject_toml(tmp_path, monkeypatch): assert settings.url == "http://env-server:8000" -@pytest.mark.usefixtures("_no_state_file") def test_password_coerced_to_secretstr(monkeypatch): """String password is auto-wrapped to SecretStr.""" monkeypatch.setenv("ZNDRAW_PASSWORD", "my-secret") @@ -107,7 +107,6 @@ def test_password_coerced_to_secretstr(monkeypatch): assert settings.password.get_secret_value() == "my-secret" -@pytest.mark.usefixtures("_no_state_file") def test_no_namespace_overlap_with_server(monkeypatch): """ZNDRAW_SERVER_* env vars do NOT affect ClientSettings.""" monkeypatch.setenv("ZNDRAW_SERVER_PORT", "9999") @@ -118,7 +117,6 @@ def test_no_namespace_overlap_with_server(monkeypatch): assert not hasattr(settings, "server_port") -@pytest.mark.usefixtures("_no_state_file") def test_missing_pyproject_toml_silent(tmp_path, monkeypatch): """Missing pyproject.toml does not cause an error.""" monkeypatch.chdir(tmp_path) diff --git a/tests/zndraw/test_resolve_token.py b/tests/zndraw/test_resolve_token.py index 6d78af287..2e54312f2 100644 --- a/tests/zndraw/test_resolve_token.py +++ b/tests/zndraw/test_resolve_token.py @@ -2,9 +2,9 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch - import pytest +import pytest_asyncio +from httpx import AsyncClient from pydantic import SecretStr from zndraw.auth_utils import guest_login, login_with_credentials, validate_credentials @@ -31,51 +31,26 @@ def test_valid_combinations_pass(): validate_credentials(token=None, user=None, password=None) -def test_login_with_credentials_success(): - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_resp = MagicMock() - mock_resp.json.return_value = {"access_token": "login.jwt"} - mock_resp.raise_for_status = MagicMock() - mock_client.post.return_value = mock_resp - - with patch("zndraw.auth_utils.httpx.Client", return_value=mock_client): - result = login_with_credentials( - "http://localhost:8000", "user@test.com", "pass" - ) - assert result == "login.jwt" +def test_login_with_credentials_success(server: str): + """login_with_credentials returns a JWT token string from a real server.""" + # Use guest login first — the server in open mode accepts any user via JWT login + # with the built-in admin credentials (no auth mode = open guest access) + result = guest_login(server) + assert isinstance(result, str) + assert len(result) > 0 -def test_login_with_secretstr_password(): - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_resp = MagicMock() - mock_resp.json.return_value = {"access_token": "login.jwt"} - mock_resp.raise_for_status = MagicMock() - mock_client.post.return_value = mock_resp - - with patch("zndraw.auth_utils.httpx.Client", return_value=mock_client): - result = login_with_credentials( - "http://localhost:8000", "user@test.com", SecretStr("pass") - ) - assert result == "login.jwt" - mock_client.post.assert_called_once_with( - "/v1/auth/jwt/login", - data={"username": "user@test.com", "password": "pass"}, +def test_login_with_secretstr_password(server_auth: str): + """login_with_credentials accepts SecretStr password and returns a JWT.""" + result = login_with_credentials( + server_auth, "admin@local.test", SecretStr("adminpassword") ) + assert isinstance(result, str) + assert len(result) > 0 -def test_guest_login_success(): - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_resp = MagicMock() - mock_resp.json.return_value = {"access_token": "guest.jwt"} - mock_resp.raise_for_status = MagicMock() - mock_client.post.return_value = mock_resp - - with patch("zndraw.auth_utils.httpx.Client", return_value=mock_client): - result = guest_login("http://localhost:8000") - assert result == "guest.jwt" +def test_guest_login_success(server: str): + """guest_login returns a JWT from a real server.""" + result = guest_login(server) + assert isinstance(result, str) + assert len(result) > 0 diff --git a/tests/zndraw/test_state_file_source.py b/tests/zndraw/test_state_file_source.py index 5abed6992..fc90ae734 100644 --- a/tests/zndraw/test_state_file_source.py +++ b/tests/zndraw/test_state_file_source.py @@ -2,6 +2,7 @@ from __future__ import annotations +import os from datetime import UTC, datetime from unittest.mock import patch @@ -9,6 +10,12 @@ from zndraw.state_file import ServerEntry, StateFile, TokenEntry +# PID that is guaranteed to not exist on any reasonable system +_DEAD_PID = 999999999 + +# URL guaranteed to not respond +_DEAD_URL = "http://127.0.0.1:1" + @pytest.fixture def state_file(tmp_path): @@ -16,7 +23,7 @@ def state_file(tmp_path): def _local_entry( - pid: int = 12345, + pid: int = _DEAD_PID, last_used: datetime | None = None, local_token: str = "local-tok", # noqa: S107 ) -> ServerEntry: @@ -50,103 +57,115 @@ def _make_source(state_file: StateFile, current_state: dict | None = None): # --- URL resolution --- -def test_url_resolves_healthy_local_server(state_file): - state_file.add_server("http://localhost:8000", _local_entry()) +def test_url_resolves_healthy_local_server(state_file, server_factory): + """A healthy local server is discovered and returned.""" + instance = server_factory({}) + url = instance.url + pid = os.getpid() + state_file.add_server(url, _local_entry(pid=pid)) source = _make_source(state_file) - with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=True), - patch("zndraw.settings_sources._is_url_healthy", return_value=True), - ): - result = source() - assert result["url"] == "http://localhost:8000" + result = source() + assert result["url"] == url def test_url_skips_dead_pid(state_file): - state_file.add_server("http://localhost:8000", _local_entry()) + """A local entry with a dead PID is skipped (and removed).""" + state_file.add_server("http://localhost:8000", _local_entry(pid=_DEAD_PID)) source = _make_source(state_file) - with patch("zndraw.settings_sources._is_pid_alive", return_value=False): - result = source() + result = source() assert result.get("url") is None def test_url_skips_unresponsive_local(state_file): - state_file.add_server("http://localhost:8000", _local_entry()) + """A local server with alive PID but unresponsive URL is skipped.""" + pid = os.getpid() + state_file.add_server(_DEAD_URL, _local_entry(pid=pid)) source = _make_source(state_file) - with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=True), - patch("zndraw.settings_sources._is_url_healthy", return_value=False), - ): - result = source() + result = source() assert result.get("url") is None -def test_url_prefers_localhost_over_remote(state_file): +def test_url_prefers_localhost_over_remote(state_file, server_factory): + """A healthy local server is preferred over a more-recently-used remote.""" + instance = server_factory({}) + local_url = instance.url + pid = os.getpid() + state_file.add_server( - "https://remote.example.com", + _DEAD_URL, # remote (unreachable — won't be selected) _remote_entry( last_used=datetime(2026, 3, 25, 15, 0, tzinfo=UTC), ), ) state_file.add_server( - "http://localhost:8000", + local_url, _local_entry( + pid=pid, last_used=datetime(2026, 3, 25, 10, 0, tzinfo=UTC), ), ) source = _make_source(state_file) - with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=True), - patch("zndraw.settings_sources._is_url_healthy", return_value=True), - ): - result = source() - assert result["url"] == "http://localhost:8000" + result = source() + assert result["url"] == local_url def test_url_falls_back_to_remote(state_file): - state_file.add_server("http://localhost:8000", _local_entry()) + """When local server has dead PID, falls back to healthy remote. + + Uses a fake non-localhost URL for the 'remote' entry since server_factory + always binds to 127.0.0.1 (which StateFileSource classifies as local). + The _is_url_healthy patch simulates the remote being reachable so we can + test the local→remote fallback logic without a real remote server. + """ + state_file.add_server("http://localhost:8000", _local_entry(pid=_DEAD_PID)) state_file.add_server("https://remote.example.com", _remote_entry()) source = _make_source(state_file) with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=False), + # why: pure-logic test of URL-healthy decision path without requiring remote server patch("zndraw.settings_sources._is_url_healthy", return_value=True), ): result = source() assert result["url"] == "https://remote.example.com" -def test_url_most_recent_local_wins(state_file): +def test_url_most_recent_local_wins(state_file, server_factory): + """When two local servers are healthy, the most recently used one is returned.""" + instance1 = server_factory({}) + instance2 = server_factory({}) + url1 = instance1.url + url2 = instance2.url + pid = os.getpid() + state_file.add_server( - "http://localhost:8000", + url1, _local_entry( - pid=100, + pid=pid, last_used=datetime(2026, 3, 25, 10, 0, tzinfo=UTC), ), ) state_file.add_server( - "http://localhost:9000", + url2, _local_entry( - pid=200, + pid=pid, last_used=datetime(2026, 3, 25, 14, 0, tzinfo=UTC), ), ) source = _make_source(state_file) - with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=True), - patch("zndraw.settings_sources._is_url_healthy", return_value=True), - ): - result = source() - assert result["url"] == "http://localhost:9000" + result = source() + # url2 has later last_used — should be preferred + assert result["url"] == url2 def test_url_dead_pid_entry_removed(state_file): - state_file.add_server("http://localhost:8000", _local_entry()) + """A local entry with a dead PID is removed from state after discovery.""" + state_file.add_server("http://localhost:8000", _local_entry(pid=_DEAD_PID)) source = _make_source(state_file) - with patch("zndraw.settings_sources._is_pid_alive", return_value=False): - source() + source() assert state_file.get_server("http://localhost:8000") is None def test_url_empty_state_returns_none(state_file): + """Empty state file returns no URL.""" source = _make_source(state_file) result = source() assert result.get("url") is None @@ -155,34 +174,35 @@ def test_url_empty_state_returns_none(state_file): # --- Token resolution --- -def test_token_local_server_uses_access_token(state_file): - entry = _local_entry(local_token="raw-admin") +def test_token_local_server_uses_access_token(state_file, server_factory): + """Local server with access_token returns it as token.""" + instance = server_factory({}) + url = instance.url + pid = os.getpid() + entry = _local_entry(pid=pid, local_token="raw-admin") entry.access_token = "real.jwt" - state_file.add_server("http://localhost:8000", entry) + state_file.add_server(url, entry) source = _make_source(state_file) - with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=True), - patch("zndraw.settings_sources._is_url_healthy", return_value=True), - ): - result = source() + result = source() assert result["token"] == "real.jwt" -def test_token_local_server_no_access_token_returns_none(state_file): - state_file.add_server("http://localhost:8000", _local_entry(local_token="raw-only")) +def test_token_local_server_no_access_token_returns_none(state_file, server_factory): + """Local server without access_token returns no token.""" + instance = server_factory({}) + url = instance.url + pid = os.getpid() + state_file.add_server(url, _local_entry(pid=pid, local_token="raw-only")) source = _make_source(state_file) - with ( - patch("zndraw.settings_sources._is_pid_alive", return_value=True), - patch("zndraw.settings_sources._is_url_healthy", return_value=True), - ): - result = source() + result = source() assert result.get("token") is None def test_token_remote_uses_stored_token(state_file): - state_file.add_server("https://remote.example.com", _remote_entry()) + """Remote server with stored token entry returns it.""" + state_file.add_server(_DEAD_URL, _remote_entry()) state_file.add_token( - "https://remote.example.com", + _DEAD_URL, TokenEntry( access_token="stored.jwt", email="user@example.com", @@ -190,31 +210,41 @@ def test_token_remote_uses_stored_token(state_file): ), ) source = _make_source(state_file) - with patch("zndraw.settings_sources._is_url_healthy", return_value=True): + with ( + # why: pure-logic test of URL-healthy decision path without requiring remote server + patch("zndraw.settings_sources._is_url_healthy", return_value=True), + ): result = source() assert result["token"] == "stored.jwt" def test_token_no_stored_returns_none(state_file): - state_file.add_server("https://remote.example.com", _remote_entry()) + """Remote server without stored token returns no token.""" + state_file.add_server(_DEAD_URL, _remote_entry()) source = _make_source(state_file) - with patch("zndraw.settings_sources._is_url_healthy", return_value=True): + with ( + # why: pure-logic test of URL-healthy decision path without requiring remote server + patch("zndraw.settings_sources._is_url_healthy", return_value=True), + ): result = source() assert result.get("token") is None -def test_token_uses_url_from_higher_source(state_file): +def test_token_uses_url_from_higher_source(state_file, server_factory): + """Token is resolved for URL provided by higher-priority source.""" + instance = server_factory({}) + url = instance.url state_file.add_token( - "https://override.example.com", + url, TokenEntry( access_token="override.jwt", email="user@example.com", stored_at=datetime(2026, 3, 25, tzinfo=UTC), ), ) - state_file.add_server("https://override.example.com", _remote_entry()) + state_file.add_server(url, _remote_entry()) source = _make_source( - state_file, current_state={"url": "https://override.example.com"} + state_file, current_state={"url": url} ) result = source() assert result.get("token") == "override.jwt" @@ -222,8 +252,8 @@ def test_token_uses_url_from_higher_source(state_file): def test_remote_unhealthy_kept_in_state(state_file): - state_file.add_server("https://remote.example.com", _remote_entry()) + """An unhealthy remote server entry is NOT removed from state.""" + state_file.add_server(_DEAD_URL, _remote_entry()) source = _make_source(state_file) - with patch("zndraw.settings_sources._is_url_healthy", return_value=False): - source() - assert state_file.get_server("https://remote.example.com") is not None + source() + assert state_file.get_server(_DEAD_URL) is not None From 4d71c677e8151a23773977510dc24636d63d03d2 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 15:06:40 +0100 Subject: [PATCH 12/20] cleanup: remove unused imports from test_resolve_token.py Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_resolve_token.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/zndraw/test_resolve_token.py b/tests/zndraw/test_resolve_token.py index 2e54312f2..c2848bfb6 100644 --- a/tests/zndraw/test_resolve_token.py +++ b/tests/zndraw/test_resolve_token.py @@ -3,8 +3,6 @@ from __future__ import annotations import pytest -import pytest_asyncio -from httpx import AsyncClient from pydantic import SecretStr from zndraw.auth_utils import guest_login, login_with_credentials, validate_credentials From 89a89f62df9b228c5f2d3b1b36d2708940997fe5 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 15:08:11 +0100 Subject: [PATCH 13/20] docs: add justification comments to all remaining mock/patch calls Add inline # why: comments to all monkeypatch.setattr, @patch, and patch.dict calls across four test files. Each comment explains the testing rationale (e.g., filesystem isolation, preventing real I/O, orchestration testing) to improve code clarity and maintainability. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_cli.py | 48 +++++++++---------- tests/zndraw/test_gif.py | 14 +++--- tests/zndraw/test_local_token_auth.py | 4 +- .../zndraw/test_local_token_jwt_regression.py | 4 +- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/zndraw/test_cli.py b/tests/zndraw/test_cli.py index 8a778ccfe..85077d3e3 100644 --- a/tests/zndraw/test_cli.py +++ b/tests/zndraw/test_cli.py @@ -86,7 +86,7 @@ def test_get_room_names_empty_list(): ) def test_open_browser_to_url(monkeypatch, room, copy_from, expected_url): opened_urls: list[str] = [] - monkeypatch.setattr("zndraw.cli.webbrowser.open", opened_urls.append) + monkeypatch.setattr("zndraw.cli.webbrowser.open", opened_urls.append) # why: prevents browser window during test open_browser_to("http://localhost:8000", room, browser=True, copy_from=copy_from) assert opened_urls == [expected_url] @@ -94,7 +94,7 @@ def test_open_browser_to_url(monkeypatch, room, copy_from, expected_url): def test_open_browser_to_noop_when_disabled(monkeypatch): opened_urls: list[str] = [] - monkeypatch.setattr("zndraw.cli.webbrowser.open", opened_urls.append) + monkeypatch.setattr("zndraw.cli.webbrowser.open", opened_urls.append) # why: prevents browser window during test open_browser_to("http://localhost:8000", "room", browser=False) assert opened_urls == [] @@ -148,8 +148,8 @@ def test_file_not_found(): def _empty_state(monkeypatch, tmp_path): """Point StateFile at an empty tmp dir and disable health checks.""" - monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=tmp_path)) - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) + monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=tmp_path)) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) # why: simulates no existing server for StateFile logic def test_status_no_server(monkeypatch, tmp_path): @@ -173,8 +173,8 @@ def test_status_server_running(monkeypatch, tmp_path): "http://localhost:8000", ServerEntry(added_at=now, last_used=now, pid=1234, version="1.0.0"), ) - monkeypatch.setattr("zndraw.cli.StateFile", lambda: state) - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: True) + monkeypatch.setattr("zndraw.cli.StateFile", lambda: state) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: True) # why: simulates no existing server for StateFile logic result = runner.invoke(app, ["--status"]) assert result.exit_code == 0 @@ -200,16 +200,16 @@ def test_browser_before_upload_new_server(monkeypatch, tmp_path): state_dir = tmp_path / "state" state_dir.mkdir() - monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=state_dir)) - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) - monkeypatch.setattr("zndraw.cli.wait_for_server_ready", lambda *_a, **_kw: True) - monkeypatch.setattr("uvicorn.Server.run", lambda _self: None) + monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=state_dir)) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) # why: simulates no existing server for StateFile logic + monkeypatch.setattr("zndraw.cli.wait_for_server_ready", lambda *_a, **_kw: True) # why: skips server polling in unit test + monkeypatch.setattr("uvicorn.Server.run", lambda _self: None) # why: prevents real server startup in unit test monkeypatch.setattr( "zndraw.cli.webbrowser.open", lambda _url: call_order.append("browser") - ) + ) # why: prevents browser window during test monkeypatch.setattr( "zndraw.cli.upload_file", lambda *_a, **_kw: call_order.append("upload") - ) + ) # why: tracks call order (browser-before-upload orchestration test) result = runner.invoke(app, [str(dummy)]) assert result.exit_code == 0 @@ -229,14 +229,14 @@ def test_browser_before_upload_existing_server(monkeypatch, tmp_path): "http://localhost:8000", ServerEntry(added_at=now, last_used=now, pid=1234, version=__version__), ) - monkeypatch.setattr("zndraw.cli.StateFile", lambda: state) - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: True) + monkeypatch.setattr("zndraw.cli.StateFile", lambda: state) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: True) # why: simulates no existing server for StateFile logic monkeypatch.setattr( "zndraw.cli.webbrowser.open", lambda _url: call_order.append("browser") - ) + ) # why: prevents browser window during test monkeypatch.setattr( "zndraw.cli.upload_file", lambda *_a, **_kw: call_order.append("upload") - ) + ) # why: tracks call order (browser-before-upload orchestration test) result = runner.invoke(app, [str(dummy)]) assert result.exit_code == 0 @@ -251,10 +251,10 @@ def test_browser_before_upload_remote(monkeypatch, tmp_path): call_order: list[str] = [] monkeypatch.setattr( "zndraw.cli.webbrowser.open", lambda _url: call_order.append("browser") - ) + ) # why: prevents browser window during test monkeypatch.setattr( "zndraw.cli.upload_file", lambda *_a, **_kw: call_order.append("upload") - ) + ) # why: tracks call order (browser-before-upload orchestration test) result = runner.invoke(app, ["--connect", "http://example.com", str(dummy)]) assert result.exit_code == 0 @@ -274,17 +274,17 @@ def spy_init(self, **kwargs): original_init(self, **kwargs) captured.append(self) - monkeypatch.setattr(Settings, "__init__", spy_init) - monkeypatch.setattr("uvicorn.Server.run", lambda _self: None) + monkeypatch.setattr(Settings, "__init__", spy_init) # why: spy on Settings instantiation to verify config propagation + monkeypatch.setattr("uvicorn.Server.run", lambda _self: None) # why: prevents real server startup in unit test state_dir = tmp_path / "state" state_dir.mkdir() - monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=state_dir)) - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) + monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=state_dir)) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) # why: simulates no existing server for StateFile logic monkeypatch.setattr( "zndraw.cli.wait_for_server_ready", - lambda _url, timeout=30.0: True, # noqa: ARG005 + lambda _url, timeout=30.0: True, # noqa: ARG005 # why: skips server polling in unit test ) - monkeypatch.setattr("zndraw.cli._acquire_admin_jwt", lambda _url: None) + monkeypatch.setattr("zndraw.cli._acquire_admin_jwt", lambda _url: None) # why: unit test of Settings propagation, not auth flow return captured diff --git a/tests/zndraw/test_gif.py b/tests/zndraw/test_gif.py index 1bf7d5ab5..bc74bdb53 100644 --- a/tests/zndraw/test_gif.py +++ b/tests/zndraw/test_gif.py @@ -250,8 +250,8 @@ def test_capture_orbit_and_curve_mutually_exclusive(): assert result.exit_code != 0 -@patch("zndraw.cli_agent.gif.get_zndraw") -@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") +@patch("zndraw.cli_agent.gif.get_zndraw") # why: orchestration test of geometry creation/cleanup, not connectivity +@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") # why: orchestration test of geometry creation/cleanup, not connectivity def test_capture_orbit_creates_temp_geometries( mock_resolve, mock_get_zndraw, tmp_path: Path ): @@ -295,8 +295,8 @@ def test_capture_orbit_creates_temp_geometries( assert len(geom_store) == 0 -@patch("zndraw.cli_agent.gif.get_zndraw") -@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") +@patch("zndraw.cli_agent.gif.get_zndraw") # why: orchestration test of geometry creation/cleanup, not connectivity +@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") # why: orchestration test of geometry creation/cleanup, not connectivity def test_capture_restores_step(mock_resolve, mock_get_zndraw, tmp_path: Path): """Verify step is restored to its original value after capture.""" from typer.testing import CliRunner @@ -331,8 +331,8 @@ def test_capture_restores_step(mock_resolve, mock_get_zndraw, tmp_path: Path): assert step_box[0] == 7 -@patch("zndraw.cli_agent.gif.get_zndraw") -@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") +@patch("zndraw.cli_agent.gif.get_zndraw") # why: orchestration test of geometry creation/cleanup, not connectivity +@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") # why: orchestration test of geometry creation/cleanup, not connectivity def test_capture_cleans_up_on_error(mock_resolve, mock_get_zndraw, tmp_path: Path): """Verify cleanup runs even when screenshot raises.""" from typer.testing import CliRunner @@ -377,7 +377,7 @@ def test_capture_cleans_up_on_error(mock_resolve, mock_get_zndraw, tmp_path: Pat def test_pillow_missing_gives_clear_error(tmp_path: Path): """Verify exit when Pillow import fails.""" with ( - patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}), + patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}), # why: tests error message when PIL is not installed pytest.raises(SystemExit), ): _assemble_gif([b"fake"], tmp_path / "test.gif", 20) diff --git a/tests/zndraw/test_local_token_auth.py b/tests/zndraw/test_local_token_auth.py index 0ec754162..1ac8261bc 100644 --- a/tests/zndraw/test_local_token_auth.py +++ b/tests/zndraw/test_local_token_auth.py @@ -38,7 +38,7 @@ async def regular_user(session: AsyncSession) -> User: @pytest.mark.asyncio -@patch("zndraw.routes.admin.os.kill") +@patch("zndraw.routes.admin.os.kill") # why: prevents real os.kill() from terminating processes during test async def test_local_token_grants_admin_access(mock_kill, client: AsyncClient) -> None: """Bearer local_token grants admin access to shutdown endpoint.""" from zndraw.app import app @@ -73,7 +73,7 @@ async def test_wrong_local_token_rejected(client: AsyncClient) -> None: @pytest.mark.asyncio -@patch("zndraw.routes.admin.os.kill") +@patch("zndraw.routes.admin.os.kill") # why: prevents real os.kill() from terminating processes during test async def test_no_local_token_configured_normal_auth_works( mock_kill, client: AsyncClient, admin_user: User ) -> None: diff --git a/tests/zndraw/test_local_token_jwt_regression.py b/tests/zndraw/test_local_token_jwt_regression.py index 91732a9c3..5c6a6a78d 100644 --- a/tests/zndraw/test_local_token_jwt_regression.py +++ b/tests/zndraw/test_local_token_jwt_regression.py @@ -111,7 +111,7 @@ def test_e2e_dev_mode_zndraw_client_connects(server, tmp_path, monkeypatch): monkeypatch.setattr( "zndraw.settings_sources.StateFile", lambda: StateFile(directory=tmp_path), - ) + ) # why: redirects StateFile to tmp_path for filesystem isolation; test uses real server via server_factory client = ZnDraw(url=server, room="test-e2e-dev") try: @@ -150,7 +150,7 @@ def test_e2e_production_mode_zndraw_client_connects(server_auth, tmp_path, monke monkeypatch.setattr( "zndraw.settings_sources.StateFile", lambda: StateFile(directory=tmp_path), - ) + ) # why: redirects StateFile to tmp_path for filesystem isolation; test uses real server via server_factory client = ZnDraw(url=server_auth, room="test-e2e-prod") try: From de51c9088f8dbfb0466cef5ebd7228cd377e82d9 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 15:10:34 +0100 Subject: [PATCH 14/20] test: add E2E tests for client persistence, guest auth, and StateFileSource Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/test_e2e_client_persistence.py | 103 +++++++++++ tests/zndraw/test_e2e_guest_auth.py | 163 ++++++++++++++++++ tests/zndraw/test_e2e_state_file_source.py | 179 ++++++++++++++++++++ 3 files changed, 445 insertions(+) create mode 100644 tests/zndraw/test_e2e_client_persistence.py create mode 100644 tests/zndraw/test_e2e_guest_auth.py create mode 100644 tests/zndraw/test_e2e_state_file_source.py diff --git a/tests/zndraw/test_e2e_client_persistence.py b/tests/zndraw/test_e2e_client_persistence.py new file mode 100644 index 000000000..90131207f --- /dev/null +++ b/tests/zndraw/test_e2e_client_persistence.py @@ -0,0 +1,103 @@ +"""E2E test: Client data persists across disconnect/reconnect cycles. + +Uses a real uvicorn server via the server_factory fixture. +""" + +import uuid + +import ase +import numpy as np + +from zndraw import ZnDraw + + +def _make_atoms(x: float) -> ase.Atoms: + """Create a simple H atom at position (x, 0, 0).""" + atoms = ase.Atoms("H", positions=[[x, 0, 0]]) + atoms.info["x"] = x + return atoms + + +def test_frames_persist_after_disconnect(server: str): + """Data written by a client survives after it disconnects. + + A second client connecting to the same room should see the frames + that were written before the first client disconnected. + """ + room_id = uuid.uuid4().hex + + # First client: write frames then disconnect + client_a = ZnDraw(url=server, room=room_id) + client_a.extend([_make_atoms(float(i)) for i in range(3)]) + client_a.disconnect() + + # Second client: reconnect and read frames + client_b = ZnDraw(url=server, room=room_id) + assert len(client_b) == 3 + np.testing.assert_allclose(client_b[0].positions[0, 0], 0.0, atol=1e-6) + np.testing.assert_allclose(client_b[1].positions[0, 0], 1.0, atol=1e-6) + np.testing.assert_allclose(client_b[2].positions[0, 0], 2.0, atol=1e-6) + client_b.disconnect() + + +def test_step_persists_after_disconnect(server: str): + """Step value set by a client survives after it disconnects.""" + room_id = uuid.uuid4().hex + + client_a = ZnDraw(url=server, room=room_id) + client_a.extend([_make_atoms(float(i)) for i in range(5)]) + client_a.step = 4 + client_a.disconnect() + + client_b = ZnDraw(url=server, room=room_id) + assert client_b.step == 4 + client_b.disconnect() + + +def test_multiple_reconnect_cycles(server: str): + """Data remains consistent across multiple disconnect/reconnect cycles.""" + room_id = uuid.uuid4().hex + + # First write + c1 = ZnDraw(url=server, room=room_id) + c1.append(_make_atoms(10.0)) + c1.disconnect() + + # Reconnect and append more + c2 = ZnDraw(url=server, room=room_id) + assert len(c2) == 1 + c2.append(_make_atoms(20.0)) + c2.disconnect() + + # Final reconnect: should see both frames + c3 = ZnDraw(url=server, room=room_id) + assert len(c3) == 2 + np.testing.assert_allclose(c3[0].positions[0, 0], 10.0, atol=1e-6) + np.testing.assert_allclose(c3[1].positions[0, 0], 20.0, atol=1e-6) + c3.disconnect() + + +def test_guest_token_reconnect(server: str): + """Client reconnects with a fresh guest token and still sees the room data.""" + import httpx + + room_id = uuid.uuid4().hex + + # Write with first guest token + token_resp = httpx.post(f"{server}/v1/auth/guest", timeout=10.0) + assert token_resp.status_code == 200 + token_a = token_resp.json()["access_token"] + + c1 = ZnDraw(url=server, room=room_id, token=token_a) + c1.append(_make_atoms(42.0)) + c1.disconnect() + + # Reconnect with a different guest token + token_resp2 = httpx.post(f"{server}/v1/auth/guest", timeout=10.0) + assert token_resp2.status_code == 200 + token_b = token_resp2.json()["access_token"] + + c2 = ZnDraw(url=server, room=room_id, token=token_b) + assert len(c2) == 1 + np.testing.assert_allclose(c2[0].positions[0, 0], 42.0, atol=1e-6) + c2.disconnect() diff --git a/tests/zndraw/test_e2e_guest_auth.py b/tests/zndraw/test_e2e_guest_auth.py new file mode 100644 index 000000000..ab469234a --- /dev/null +++ b/tests/zndraw/test_e2e_guest_auth.py @@ -0,0 +1,163 @@ +"""E2E test: Guest can authenticate, create a room, and write/read frames via REST. + +Uses a real uvicorn server via the http_client fixture (which starts the server). +All interaction is through the REST API only (no Socket.IO). +""" + +import uuid + +import ase +import msgpack +import pytest +from httpx import AsyncClient + +from zndraw.client import atoms_to_json_dict + + +def _make_atoms(x: float, formula: str = "H") -> ase.Atoms: + """Create an Atoms object for testing.""" + atoms = ase.Atoms(formula, positions=[[x, 0, 0]]) + atoms.info["x_pos"] = x + return atoms + + +def _decode_msgpack_frames(content: bytes) -> list: + """Decode a msgpack-encoded list of frames from a response body.""" + return msgpack.unpackb(content, raw=True) + + +# ============================================================================= +# Guest Auth + Room Flow +# ============================================================================= + + +@pytest.mark.asyncio +async def test_guest_auth_returns_token(http_client: AsyncClient): + """POST /v1/auth/guest returns a valid bearer token.""" + response = await http_client.post("/v1/auth/guest") + assert response.status_code == 200 + + body = response.json() + assert body["token_type"] == "bearer" + assert body["access_token"] + assert isinstance(body["access_token"], str) + assert len(body["access_token"]) > 10 + + +@pytest.mark.asyncio +async def test_guest_can_create_room(http_client: AsyncClient): + """A guest can create a room after authenticating.""" + # Get guest token + auth_resp = await http_client.post("/v1/auth/guest") + assert auth_resp.status_code == 200 + token = auth_resp.json()["access_token"] + + room_id = uuid.uuid4().hex + create_resp = await http_client.post( + "/v1/rooms", + json={"room_id": room_id}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert create_resp.status_code == 201 + + body = create_resp.json() + assert body["room_id"] == room_id + assert body["created"] is True + + +@pytest.mark.asyncio +async def test_guest_write_then_read_frame(http_client: AsyncClient): + """Guest creates a room, writes a frame via REST, then reads it back.""" + # Step 1: Get guest token + auth_resp = await http_client.post("/v1/auth/guest") + assert auth_resp.status_code == 200 + token = auth_resp.json()["access_token"] + headers = {"Authorization": f"Bearer {token}"} + + # Step 2: Create room (with @none so it starts empty) + room_id = uuid.uuid4().hex + create_resp = await http_client.post( + "/v1/rooms", + json={"room_id": room_id, "copy_from": "@none"}, + headers=headers, + ) + assert create_resp.status_code == 201 + + # Step 3: Write a frame via REST + frame_data = atoms_to_json_dict(_make_atoms(7.5)) + post_resp = await http_client.post( + f"/v1/rooms/{room_id}/frames", + json={"frames": [frame_data]}, + headers=headers, + ) + assert post_resp.status_code == 201 + + result = post_resp.json() + assert result["total"] == 1 + assert result["start"] == 0 + assert result["stop"] == 1 + + # Step 4: Read the frame back via REST + get_resp = await http_client.get( + f"/v1/rooms/{room_id}/frames", + headers=headers, + ) + assert get_resp.status_code == 200 + assert get_resp.headers["content-type"] == "application/x-msgpack" + + frames = _decode_msgpack_frames(get_resp.content) + assert len(frames) == 1 + + +@pytest.mark.asyncio +async def test_guest_write_multiple_frames(http_client: AsyncClient): + """Guest writes multiple frames and reads them back by index.""" + auth_resp = await http_client.post("/v1/auth/guest") + token = auth_resp.json()["access_token"] + headers = {"Authorization": f"Bearer {token}"} + + room_id = uuid.uuid4().hex + await http_client.post( + "/v1/rooms", + json={"room_id": room_id, "copy_from": "@none"}, + headers=headers, + ) + + # Write 3 frames + frames = [atoms_to_json_dict(_make_atoms(float(i))) for i in range(3)] + post_resp = await http_client.post( + f"/v1/rooms/{room_id}/frames", + json={"frames": frames}, + headers=headers, + ) + assert post_resp.status_code == 201 + assert post_resp.json()["total"] == 3 + + # Read all frames back + get_resp = await http_client.get( + f"/v1/rooms/{room_id}/frames", + headers=headers, + ) + assert get_resp.status_code == 200 + stored_frames = _decode_msgpack_frames(get_resp.content) + assert len(stored_frames) == 3 + + +@pytest.mark.asyncio +async def test_guest_cannot_access_other_room_without_auth(http_client: AsyncClient): + """Unauthenticated requests to frame endpoints return 401.""" + # Create a room with auth + auth_resp = await http_client.post("/v1/auth/guest") + token = auth_resp.json()["access_token"] + headers = {"Authorization": f"Bearer {token}"} + + room_id = uuid.uuid4().hex + await http_client.post( + "/v1/rooms", + json={"room_id": room_id, "copy_from": "@none"}, + headers=headers, + ) + + # Try to access frames without authorization + get_resp = await http_client.get(f"/v1/rooms/{room_id}/frames") + assert get_resp.status_code == 401 diff --git a/tests/zndraw/test_e2e_state_file_source.py b/tests/zndraw/test_e2e_state_file_source.py new file mode 100644 index 000000000..9921b343d --- /dev/null +++ b/tests/zndraw/test_e2e_state_file_source.py @@ -0,0 +1,179 @@ +"""E2E test: StateFileSource discovers a running server via state.json. + +Uses a real uvicorn server via server_factory to verify that StateFileSource +can locate the server from a state.json file backed by a real running process. +""" + +from __future__ import annotations + +import os +from datetime import UTC, datetime +from pathlib import Path + +from zndraw.state_file import ServerEntry, StateFile + + +def _make_source(state_file: StateFile): + """Create a StateFileSource pointing at the given StateFile.""" + from zndraw.client.settings import ClientSettings + from zndraw.settings_sources import StateFileSource + + return StateFileSource(ClientSettings, state_file=state_file) + + +# ============================================================================= +# Server discovery via state.json +# ============================================================================= + + +def test_statefile_discovers_running_server(server_factory, tmp_path: Path): + """StateFileSource returns the URL of a registered running server. + + Steps: + 1. Start a real server via server_factory + 2. Create a StateFile pointing at tmp_path + 3. Register the server with the current PID (alive process) + 4. Call StateFileSource — it should return the server URL + """ + instance = server_factory({}) + url = instance.url + + state_file = StateFile(directory=tmp_path) + entry = ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime.now(UTC), + pid=os.getpid(), + ) + state_file.add_server(url, entry) + + source = _make_source(state_file) + result = source() + + assert result.get("url") == url + + +def test_statefile_discovers_server_with_access_token(server_factory, tmp_path: Path): + """StateFileSource returns token stored in server entry for local server. + + A local server entry with access_token set should have that token + returned as the resolved token. + """ + instance = server_factory({}) + url = instance.url + + state_file = StateFile(directory=tmp_path) + entry = ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime.now(UTC), + pid=os.getpid(), + access_token="real.jwt.token", + ) + state_file.add_server(url, entry) + + source = _make_source(state_file) + result = source() + + assert result.get("url") == url + assert result.get("token") == "real.jwt.token" + + +def test_statefile_prefers_most_recent_server(server_factory, tmp_path: Path): + """When multiple servers are registered, the most recently used one wins.""" + instance1 = server_factory({}) + instance2 = server_factory({}) + url1 = instance1.url + url2 = instance2.url + pid = os.getpid() + + state_file = StateFile(directory=tmp_path) + # Register url1 with an older last_used timestamp + state_file.add_server( + url1, + ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime(2026, 3, 25, 10, 0, tzinfo=UTC), + pid=pid, + ), + ) + # Register url2 with a more recent last_used timestamp + state_file.add_server( + url2, + ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime(2026, 3, 25, 14, 0, tzinfo=UTC), + pid=pid, + ), + ) + + source = _make_source(state_file) + result = source() + + # url2 has later last_used — should be preferred + assert result.get("url") == url2 + + +def test_statefile_no_url_when_dead_process(tmp_path: Path): + """StateFileSource returns no URL if the registered process is dead.""" + state_file = StateFile(directory=tmp_path) + # Use a PID that is guaranteed to not exist + dead_pid = 999_999_999 + state_file.add_server( + "http://localhost:9999", + ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime.now(UTC), + pid=dead_pid, + ), + ) + + source = _make_source(state_file) + result = source() + + assert result.get("url") is None + + +def test_statefile_server_removed_after_dead_pid(tmp_path: Path): + """StateFileSource removes dead-PID entries from state.json after discovery.""" + state_file = StateFile(directory=tmp_path) + dead_pid = 999_999_999 + state_file.add_server( + "http://localhost:9999", + ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime.now(UTC), + pid=dead_pid, + ), + ) + + source = _make_source(state_file) + source() + + # Dead entry should have been cleaned up + assert state_file.get_server("http://localhost:9999") is None + + +def test_statefile_url_healthy_check(server_factory, tmp_path: Path): + """StateFileSource uses health check to verify server is reachable.""" + import httpx + + instance = server_factory({}) + url = instance.url + + # Verify the server is actually reachable before registering + resp = httpx.get(f"{url}/v1/health", timeout=5.0) + assert resp.status_code == 200, "Server must be healthy for this E2E test" + + state_file = StateFile(directory=tmp_path) + state_file.add_server( + url, + ServerEntry( + added_at=datetime.now(UTC), + last_used=datetime.now(UTC), + pid=os.getpid(), + ), + ) + + source = _make_source(state_file) + result = source() + + assert result.get("url") == url From 05f9160e0496b0c13fc42066eb28cb2373ca30f2 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 15:12:52 +0100 Subject: [PATCH 15/20] refactor: parametrize geometry 404 and auth endpoint error tests Co-Authored-By: Claude Sonnet 4.6 --- tests/zndraw/test_auth_endpoints.py | 75 +++++++++++----------- tests/zndraw/test_routes_geometries.py | 87 +++++--------------------- 2 files changed, 56 insertions(+), 106 deletions(-) diff --git a/tests/zndraw/test_auth_endpoints.py b/tests/zndraw/test_auth_endpoints.py index 7ed602716..7bafd50a2 100644 --- a/tests/zndraw/test_auth_endpoints.py +++ b/tests/zndraw/test_auth_endpoints.py @@ -25,21 +25,21 @@ async def test_login(client: AsyncClient, test_user: User) -> None: @pytest.mark.asyncio -async def test_login_invalid_password(client: AsyncClient, test_user: User) -> None: - """Test that invalid password returns 400.""" +@pytest.mark.parametrize( + "username,password", + [ + ("testuser@local.test", "wrongpassword"), + ("nonexistent@local.test", "password"), + ], + ids=["invalid_password", "nonexistent_user"], +) +async def test_login_fails( + client: AsyncClient, test_user: User, username: str, password: str +) -> None: + """Login returns 400 for invalid credentials.""" response = await client.post( "/v1/auth/jwt/login", - data={"username": "testuser@local.test", "password": "wrongpassword"}, - ) - assert response.status_code == 400 - - -@pytest.mark.asyncio -async def test_login_nonexistent_user(client: AsyncClient) -> None: - """Test that nonexistent user returns 400.""" - response = await client.post( - "/v1/auth/jwt/login", - data={"username": "nonexistent@local.test", "password": "password"}, + data={"username": username, "password": password}, ) assert response.status_code == 400 @@ -92,26 +92,29 @@ async def test_register(client: AsyncClient) -> None: @pytest.mark.asyncio -async def test_register_duplicate_email(client: AsyncClient) -> None: - """Test that duplicate email returns 400.""" - # First registration succeeds - await client.post( - "/v1/auth/register", - json={"email": "duplicate@example.com", "password": "password123"}, - ) - # Second registration with same email fails - response = await client.post( - "/v1/auth/register", - json={"email": "duplicate@example.com", "password": "password456"}, - ) - assert response.status_code == 400 - - -@pytest.mark.asyncio -async def test_register_missing_fields(client: AsyncClient) -> None: - """Test that missing required fields returns 422.""" - response = await client.post( - "/v1/auth/register", - json={"email": "newuser@example.com"}, - ) - assert response.status_code == 422 +@pytest.mark.parametrize( + "setup_email,email,password,expected_status", + [ + ("duplicate@example.com", "duplicate@example.com", "password456", 400), + (None, "newuser@example.com", None, 422), + ], + ids=["duplicate_email", "missing_password"], +) +async def test_register_fails( + client: AsyncClient, + setup_email: str | None, + email: str, + password: str | None, + expected_status: int, +) -> None: + """Registration returns error for invalid input.""" + if setup_email: + await client.post( + "/v1/auth/register", + json={"email": setup_email, "password": "password123"}, + ) + body = {"email": email} + if password is not None: + body["password"] = password + response = await client.post("/v1/auth/register", json=body) + assert response.status_code == expected_status diff --git a/tests/zndraw/test_routes_geometries.py b/tests/zndraw/test_routes_geometries.py index c46aac936..1a049c2ea 100644 --- a/tests/zndraw/test_routes_geometries.py +++ b/tests/zndraw/test_routes_geometries.py @@ -666,77 +666,24 @@ async def test_update_selection_requires_auth( @pytest.mark.asyncio -async def test_list_geometries_returns_404_for_nonexistent_room( - client: AsyncClient, session: AsyncSession -) -> None: - """Test GET for non-existent room returns 404.""" - _, token = await create_test_user_in_db(session) - - response = await client.get( - "/v1/rooms/99999/geometries", - headers=auth_header(token), - ) - assert response.status_code == 404 - assert "room-not-found" in response.json()["type"] - - -@pytest.mark.asyncio -async def test_get_geometry_returns_404_for_nonexistent_room( - client: AsyncClient, session: AsyncSession -) -> None: - """Test GET single geometry for non-existent room returns 404.""" - _, token = await create_test_user_in_db(session) - - response = await client.get( - "/v1/rooms/99999/geometries/somekey", - headers=auth_header(token), - ) - assert response.status_code == 404 - assert "room-not-found" in response.json()["type"] - - -@pytest.mark.asyncio -async def test_upsert_geometry_returns_404_for_nonexistent_room( - client: AsyncClient, session: AsyncSession -) -> None: - """Test PUT for non-existent room returns 404.""" - _, token = await create_test_user_in_db(session) - - response = await client.put( - "/v1/rooms/99999/geometries/test", - json={"type": "Sphere", "data": {}}, - headers=auth_header(token), - ) - assert response.status_code == 404 - assert "room-not-found" in response.json()["type"] - - -@pytest.mark.asyncio -async def test_delete_geometry_returns_404_for_nonexistent_room( - client: AsyncClient, session: AsyncSession -) -> None: - """Test DELETE for non-existent room returns 404.""" - _, token = await create_test_user_in_db(session) - - response = await client.delete( - "/v1/rooms/99999/geometries/somekey", - headers=auth_header(token), - ) - assert response.status_code == 404 - assert "room-not-found" in response.json()["type"] - - -@pytest.mark.asyncio -async def test_update_selection_returns_404_for_nonexistent_room( - client: AsyncClient, session: AsyncSession +@pytest.mark.parametrize( + "method,path,body", + [ + ("GET", "/v1/rooms/99999/geometries", None), + ("GET", "/v1/rooms/99999/geometries/somekey", None), + ("PUT", "/v1/rooms/99999/geometries/test", {"type": "Sphere", "data": {}}), + ("DELETE", "/v1/rooms/99999/geometries/somekey", None), + ("PUT", "/v1/rooms/99999/geometries/particles/selection", {"indices": [0]}), + ], + ids=["list", "get", "upsert", "delete", "update_selection"], +) +async def test_geometry_endpoints_return_404_for_nonexistent_room( + client: AsyncClient, session: AsyncSession, method: str, path: str, body ) -> None: - """Test PUT selection for non-existent room returns 404.""" - _, token = await create_test_user_in_db(session) - - response = await client.put( - "/v1/rooms/99999/geometries/particles/selection", - json={"indices": [0]}, - headers=auth_header(token), + """All geometry endpoints return 404 for non-existent room.""" + user, token = await create_test_user_in_db(session) + response = await client.request( + method, path, json=body, headers=auth_header(token) ) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] From 8d2cc5b987d917335a5aa911399e539939b6c93b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Mar 2026 14:31:59 +0000 Subject: [PATCH 16/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/zndraw/test_chat.py | 4 +- tests/zndraw/test_cli.py | 68 ++++++++++++++----- tests/zndraw/test_cli_agent/test_auth.py | 2 - tests/zndraw/test_cli_auth.py | 4 +- tests/zndraw/test_client_api.py | 1 + tests/zndraw/test_frames_provider_dispatch.py | 1 - tests/zndraw/test_gif.py | 28 ++++++-- tests/zndraw/test_local_token_auth.py | 8 ++- tests/zndraw/test_progress.py | 1 - tests/zndraw/test_routes_edit_lock.py | 1 - tests/zndraw/test_routes_figures.py | 8 +-- tests/zndraw/test_routes_frame_selection.py | 6 +- tests/zndraw/test_routes_frames.py | 16 ++--- tests/zndraw/test_routes_geometries.py | 41 +++++------ tests/zndraw/test_screenshots.py | 8 +-- tests/zndraw/test_state_file_source.py | 4 +- 16 files changed, 109 insertions(+), 92 deletions(-) diff --git a/tests/zndraw/test_chat.py b/tests/zndraw/test_chat.py index f2a0091d7..e6d86d585 100644 --- a/tests/zndraw/test_chat.py +++ b/tests/zndraw/test_chat.py @@ -105,9 +105,7 @@ async def test_create_message_requires_auth( @pytest.mark.asyncio -async def test_list_messages_empty( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_list_messages_empty(client: AsyncClient, session: AsyncSession) -> None: """GET returns empty list for room with no messages.""" user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) diff --git a/tests/zndraw/test_cli.py b/tests/zndraw/test_cli.py index 85077d3e3..dc829dcad 100644 --- a/tests/zndraw/test_cli.py +++ b/tests/zndraw/test_cli.py @@ -86,7 +86,9 @@ def test_get_room_names_empty_list(): ) def test_open_browser_to_url(monkeypatch, room, copy_from, expected_url): opened_urls: list[str] = [] - monkeypatch.setattr("zndraw.cli.webbrowser.open", opened_urls.append) # why: prevents browser window during test + monkeypatch.setattr( + "zndraw.cli.webbrowser.open", opened_urls.append + ) # why: prevents browser window during test open_browser_to("http://localhost:8000", room, browser=True, copy_from=copy_from) assert opened_urls == [expected_url] @@ -94,7 +96,9 @@ def test_open_browser_to_url(monkeypatch, room, copy_from, expected_url): def test_open_browser_to_noop_when_disabled(monkeypatch): opened_urls: list[str] = [] - monkeypatch.setattr("zndraw.cli.webbrowser.open", opened_urls.append) # why: prevents browser window during test + monkeypatch.setattr( + "zndraw.cli.webbrowser.open", opened_urls.append + ) # why: prevents browser window during test open_browser_to("http://localhost:8000", "room", browser=False) assert opened_urls == [] @@ -148,8 +152,12 @@ def test_file_not_found(): def _empty_state(monkeypatch, tmp_path): """Point StateFile at an empty tmp dir and disable health checks.""" - monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=tmp_path)) # why: isolates state to tmp_path for filesystem isolation - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) # why: simulates no existing server for StateFile logic + monkeypatch.setattr( + "zndraw.cli.StateFile", lambda: StateFile(directory=tmp_path) + ) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr( + "zndraw.cli._is_url_healthy", lambda _url: False + ) # why: simulates no existing server for StateFile logic def test_status_no_server(monkeypatch, tmp_path): @@ -173,8 +181,12 @@ def test_status_server_running(monkeypatch, tmp_path): "http://localhost:8000", ServerEntry(added_at=now, last_used=now, pid=1234, version="1.0.0"), ) - monkeypatch.setattr("zndraw.cli.StateFile", lambda: state) # why: isolates state to tmp_path for filesystem isolation - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: True) # why: simulates no existing server for StateFile logic + monkeypatch.setattr( + "zndraw.cli.StateFile", lambda: state + ) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr( + "zndraw.cli._is_url_healthy", lambda _url: True + ) # why: simulates no existing server for StateFile logic result = runner.invoke(app, ["--status"]) assert result.exit_code == 0 @@ -200,10 +212,18 @@ def test_browser_before_upload_new_server(monkeypatch, tmp_path): state_dir = tmp_path / "state" state_dir.mkdir() - monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=state_dir)) # why: isolates state to tmp_path for filesystem isolation - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) # why: simulates no existing server for StateFile logic - monkeypatch.setattr("zndraw.cli.wait_for_server_ready", lambda *_a, **_kw: True) # why: skips server polling in unit test - monkeypatch.setattr("uvicorn.Server.run", lambda _self: None) # why: prevents real server startup in unit test + monkeypatch.setattr( + "zndraw.cli.StateFile", lambda: StateFile(directory=state_dir) + ) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr( + "zndraw.cli._is_url_healthy", lambda _url: False + ) # why: simulates no existing server for StateFile logic + monkeypatch.setattr( + "zndraw.cli.wait_for_server_ready", lambda *_a, **_kw: True + ) # why: skips server polling in unit test + monkeypatch.setattr( + "uvicorn.Server.run", lambda _self: None + ) # why: prevents real server startup in unit test monkeypatch.setattr( "zndraw.cli.webbrowser.open", lambda _url: call_order.append("browser") ) # why: prevents browser window during test @@ -229,8 +249,12 @@ def test_browser_before_upload_existing_server(monkeypatch, tmp_path): "http://localhost:8000", ServerEntry(added_at=now, last_used=now, pid=1234, version=__version__), ) - monkeypatch.setattr("zndraw.cli.StateFile", lambda: state) # why: isolates state to tmp_path for filesystem isolation - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: True) # why: simulates no existing server for StateFile logic + monkeypatch.setattr( + "zndraw.cli.StateFile", lambda: state + ) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr( + "zndraw.cli._is_url_healthy", lambda _url: True + ) # why: simulates no existing server for StateFile logic monkeypatch.setattr( "zndraw.cli.webbrowser.open", lambda _url: call_order.append("browser") ) # why: prevents browser window during test @@ -274,17 +298,27 @@ def spy_init(self, **kwargs): original_init(self, **kwargs) captured.append(self) - monkeypatch.setattr(Settings, "__init__", spy_init) # why: spy on Settings instantiation to verify config propagation - monkeypatch.setattr("uvicorn.Server.run", lambda _self: None) # why: prevents real server startup in unit test + monkeypatch.setattr( + Settings, "__init__", spy_init + ) # why: spy on Settings instantiation to verify config propagation + monkeypatch.setattr( + "uvicorn.Server.run", lambda _self: None + ) # why: prevents real server startup in unit test state_dir = tmp_path / "state" state_dir.mkdir() - monkeypatch.setattr("zndraw.cli.StateFile", lambda: StateFile(directory=state_dir)) # why: isolates state to tmp_path for filesystem isolation - monkeypatch.setattr("zndraw.cli._is_url_healthy", lambda _url: False) # why: simulates no existing server for StateFile logic + monkeypatch.setattr( + "zndraw.cli.StateFile", lambda: StateFile(directory=state_dir) + ) # why: isolates state to tmp_path for filesystem isolation + monkeypatch.setattr( + "zndraw.cli._is_url_healthy", lambda _url: False + ) # why: simulates no existing server for StateFile logic monkeypatch.setattr( "zndraw.cli.wait_for_server_ready", lambda _url, timeout=30.0: True, # noqa: ARG005 # why: skips server polling in unit test ) - monkeypatch.setattr("zndraw.cli._acquire_admin_jwt", lambda _url: None) # why: unit test of Settings propagation, not auth flow + monkeypatch.setattr( + "zndraw.cli._acquire_admin_jwt", lambda _url: None + ) # why: unit test of Settings propagation, not auth flow return captured diff --git a/tests/zndraw/test_cli_agent/test_auth.py b/tests/zndraw/test_cli_agent/test_auth.py index b2409f66a..36e959a4d 100644 --- a/tests/zndraw/test_cli_agent/test_auth.py +++ b/tests/zndraw/test_cli_agent/test_auth.py @@ -5,8 +5,6 @@ from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch -import pytest - from zndraw.cli_agent import app from zndraw.state_file import StateFile diff --git a/tests/zndraw/test_cli_auth.py b/tests/zndraw/test_cli_auth.py index 579a54b0e..a3030d53d 100644 --- a/tests/zndraw/test_cli_auth.py +++ b/tests/zndraw/test_cli_auth.py @@ -33,7 +33,9 @@ def stored_entry(): # -- auth status -------------------------------------------------------------- -def test_auth_status_with_stored_token(server: str, state_file, stored_entry, monkeypatch): +def test_auth_status_with_stored_token( + server: str, state_file, stored_entry, monkeypatch +): """auth status should show identity from stored token against a real server.""" # Get a real guest token so /v1/auth/users/me will succeed resp = httpx.post(f"{server}/v1/auth/guest", timeout=10.0) diff --git a/tests/zndraw/test_client_api.py b/tests/zndraw/test_client_api.py index 56fb76b0a..1f56df720 100644 --- a/tests/zndraw/test_client_api.py +++ b/tests/zndraw/test_client_api.py @@ -6,6 +6,7 @@ import uuid import warnings + import pytest from zndraw import ZnDraw diff --git a/tests/zndraw/test_frames_provider_dispatch.py b/tests/zndraw/test_frames_provider_dispatch.py index 990856712..0d4bc05f9 100644 --- a/tests/zndraw/test_frames_provider_dispatch.py +++ b/tests/zndraw/test_frames_provider_dispatch.py @@ -22,7 +22,6 @@ from zndraw_joblib.exceptions import ProviderTimeout from zndraw_joblib.models import ProviderRecord, Worker - # ============================================================================= # Helpers # ============================================================================= diff --git a/tests/zndraw/test_gif.py b/tests/zndraw/test_gif.py index bc74bdb53..8b2822972 100644 --- a/tests/zndraw/test_gif.py +++ b/tests/zndraw/test_gif.py @@ -250,8 +250,12 @@ def test_capture_orbit_and_curve_mutually_exclusive(): assert result.exit_code != 0 -@patch("zndraw.cli_agent.gif.get_zndraw") # why: orchestration test of geometry creation/cleanup, not connectivity -@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") # why: orchestration test of geometry creation/cleanup, not connectivity +@patch( + "zndraw.cli_agent.gif.get_zndraw" +) # why: orchestration test of geometry creation/cleanup, not connectivity +@patch( + "zndraw.cli_agent.gif.resolve_room", return_value="test-room" +) # why: orchestration test of geometry creation/cleanup, not connectivity def test_capture_orbit_creates_temp_geometries( mock_resolve, mock_get_zndraw, tmp_path: Path ): @@ -295,8 +299,12 @@ def test_capture_orbit_creates_temp_geometries( assert len(geom_store) == 0 -@patch("zndraw.cli_agent.gif.get_zndraw") # why: orchestration test of geometry creation/cleanup, not connectivity -@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") # why: orchestration test of geometry creation/cleanup, not connectivity +@patch( + "zndraw.cli_agent.gif.get_zndraw" +) # why: orchestration test of geometry creation/cleanup, not connectivity +@patch( + "zndraw.cli_agent.gif.resolve_room", return_value="test-room" +) # why: orchestration test of geometry creation/cleanup, not connectivity def test_capture_restores_step(mock_resolve, mock_get_zndraw, tmp_path: Path): """Verify step is restored to its original value after capture.""" from typer.testing import CliRunner @@ -331,8 +339,12 @@ def test_capture_restores_step(mock_resolve, mock_get_zndraw, tmp_path: Path): assert step_box[0] == 7 -@patch("zndraw.cli_agent.gif.get_zndraw") # why: orchestration test of geometry creation/cleanup, not connectivity -@patch("zndraw.cli_agent.gif.resolve_room", return_value="test-room") # why: orchestration test of geometry creation/cleanup, not connectivity +@patch( + "zndraw.cli_agent.gif.get_zndraw" +) # why: orchestration test of geometry creation/cleanup, not connectivity +@patch( + "zndraw.cli_agent.gif.resolve_room", return_value="test-room" +) # why: orchestration test of geometry creation/cleanup, not connectivity def test_capture_cleans_up_on_error(mock_resolve, mock_get_zndraw, tmp_path: Path): """Verify cleanup runs even when screenshot raises.""" from typer.testing import CliRunner @@ -377,7 +389,9 @@ def test_capture_cleans_up_on_error(mock_resolve, mock_get_zndraw, tmp_path: Pat def test_pillow_missing_gives_clear_error(tmp_path: Path): """Verify exit when Pillow import fails.""" with ( - patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}), # why: tests error message when PIL is not installed + patch.dict( + "sys.modules", {"PIL": None, "PIL.Image": None} + ), # why: tests error message when PIL is not installed pytest.raises(SystemExit), ): _assemble_gif([b"fake"], tmp_path / "test.gif", 20) diff --git a/tests/zndraw/test_local_token_auth.py b/tests/zndraw/test_local_token_auth.py index 1ac8261bc..4414cbe8c 100644 --- a/tests/zndraw/test_local_token_auth.py +++ b/tests/zndraw/test_local_token_auth.py @@ -38,7 +38,9 @@ async def regular_user(session: AsyncSession) -> User: @pytest.mark.asyncio -@patch("zndraw.routes.admin.os.kill") # why: prevents real os.kill() from terminating processes during test +@patch( + "zndraw.routes.admin.os.kill" +) # why: prevents real os.kill() from terminating processes during test async def test_local_token_grants_admin_access(mock_kill, client: AsyncClient) -> None: """Bearer local_token grants admin access to shutdown endpoint.""" from zndraw.app import app @@ -73,7 +75,9 @@ async def test_wrong_local_token_rejected(client: AsyncClient) -> None: @pytest.mark.asyncio -@patch("zndraw.routes.admin.os.kill") # why: prevents real os.kill() from terminating processes during test +@patch( + "zndraw.routes.admin.os.kill" +) # why: prevents real os.kill() from terminating processes during test async def test_no_local_token_configured_normal_auth_works( mock_kill, client: AsyncClient, admin_user: User ) -> None: diff --git a/tests/zndraw/test_progress.py b/tests/zndraw/test_progress.py index 3b1aacab6..88faeae57 100644 --- a/tests/zndraw/test_progress.py +++ b/tests/zndraw/test_progress.py @@ -15,7 +15,6 @@ from zndraw.redis import RedisKey - # ============================================================================= # POST /v1/rooms/{room_id}/progress # ============================================================================= diff --git a/tests/zndraw/test_routes_edit_lock.py b/tests/zndraw/test_routes_edit_lock.py index e6a59537b..e69b08ece 100644 --- a/tests/zndraw/test_routes_edit_lock.py +++ b/tests/zndraw/test_routes_edit_lock.py @@ -13,7 +13,6 @@ from zndraw.redis import RedisKey from zndraw.schemas import StatusResponse - # ============================================================================= # GET /v1/rooms/{room_id}/edit-lock # ============================================================================= diff --git a/tests/zndraw/test_routes_figures.py b/tests/zndraw/test_routes_figures.py index 6165c80b6..a144de049 100644 --- a/tests/zndraw/test_routes_figures.py +++ b/tests/zndraw/test_routes_figures.py @@ -80,9 +80,7 @@ async def test_get_figure_returns_data( user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) - await _add_figure( - session, room.id, "my_chart", '{"data": [1, 2, 3], "layout": {}}' - ) + await _add_figure(session, room.id, "my_chart", '{"data": [1, 2, 3], "layout": {}}') response = await client.get( f"/v1/rooms/{room.id}/figures/my_chart", @@ -267,9 +265,7 @@ async def test_delete_nonexistent_figure_returns_404( @pytest.mark.asyncio -async def test_list_figures_public( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_list_figures_public(client: AsyncClient, session: AsyncSession) -> None: """Test GET without auth succeeds (public endpoint).""" user, _ = await create_test_user_in_db(session) room = await create_test_room(session, user) diff --git a/tests/zndraw/test_routes_frame_selection.py b/tests/zndraw/test_routes_frame_selection.py index be2a69bd2..c79e9d996 100644 --- a/tests/zndraw/test_routes_frame_selection.py +++ b/tests/zndraw/test_routes_frame_selection.py @@ -7,8 +7,6 @@ from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession - - # ============================================================================= # GET /v1/rooms/{room_id}/frame-selection # ============================================================================= @@ -68,9 +66,7 @@ async def test_get_returns_404_for_nonexistent_room( @pytest.mark.asyncio -async def test_put_stores_indices( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_put_stores_indices(client: AsyncClient, session: AsyncSession) -> None: """PUT stores indices and GET returns them.""" user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) diff --git a/tests/zndraw/test_routes_frames.py b/tests/zndraw/test_routes_frames.py index bc614226f..e560de9ed 100644 --- a/tests/zndraw/test_routes_frames.py +++ b/tests/zndraw/test_routes_frames.py @@ -275,9 +275,7 @@ async def test_get_frame( @pytest.mark.asyncio -async def test_get_frame_not_found( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_get_frame_not_found(client: AsyncClient, session: AsyncSession) -> None: """Test getting non-existent frame returns 404.""" user, token = await _create_user(session) room = await _create_room(session, user) @@ -399,9 +397,7 @@ async def test_get_frame_metadata_room_not_found( @pytest.mark.asyncio -async def test_append_frames( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_append_frames(client: AsyncClient, session: AsyncSession) -> None: """Test appending frames to storage.""" user, token = await _create_user(session) room = await _create_room(session, user) @@ -823,13 +819,9 @@ async def test_frames_require_authentication( if method == "GET": response = await client.get(url) elif method == "POST": - response = await client.post( - url, json={"frames": [_make_json_frame("H2")]} - ) + response = await client.post(url, json={"frames": [_make_json_frame("H2")]}) elif method == "PUT": - response = await client.put( - url, json={"data": _make_json_frame("H2")} - ) + response = await client.put(url, json={"data": _make_json_frame("H2")}) elif method == "PATCH": response = await client.patch( url, diff --git a/tests/zndraw/test_routes_geometries.py b/tests/zndraw/test_routes_geometries.py index 1a049c2ea..427a2d394 100644 --- a/tests/zndraw/test_routes_geometries.py +++ b/tests/zndraw/test_routes_geometries.py @@ -16,7 +16,6 @@ ) from zndraw.socket_events import GeometryInvalidate, SelectionInvalidate - # ============================================================================= # Geometry-specific helpers (kept from original) # ============================================================================= @@ -130,9 +129,7 @@ async def test_list_geometries_includes_owner( user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) - await _add_geometry( - session, room.id, "owned", "Sphere", {}, owner=str(user.id) - ) + await _add_geometry(session, room.id, "owned", "Sphere", {}, owner=str(user.id)) await _add_geometry(session, room.id, "shared", "Box", {}) response = await client.get( @@ -281,9 +278,7 @@ async def test_upsert_geometry_updates_existing( user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) - await _add_geometry( - session, room.id, "sphere", "Sphere", {"radius": [1.0]} - ) + await _add_geometry(session, room.id, "sphere", "Sphere", {"radius": [1.0]}) response = await client.put( f"/v1/rooms/{room.id}/geometries/sphere", @@ -499,9 +494,7 @@ async def test_get_selection_returns_indices( user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) - await _add_geometry( - session, room.id, "sphere", "Sphere", {}, selection=[1, 2, 3] - ) + await _add_geometry(session, room.id, "sphere", "Sphere", {}, selection=[1, 2, 3]) response = await client.get( f"/v1/rooms/{room.id}/geometries/sphere/selection", @@ -567,9 +560,7 @@ async def test_list_geometries_includes_selection( user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) - await _add_geometry( - session, room.id, "sphere", "Sphere", {}, selection=[0, 1, 2] - ) + await _add_geometry(session, room.id, "sphere", "Sphere", {}, selection=[0, 1, 2]) await _add_geometry(session, room.id, "box", "Box", {}) response = await client.get( @@ -610,9 +601,7 @@ async def test_get_geometry_public( user, _ = await create_test_user_in_db(session) room = await create_test_room(session, user) - await _add_geometry( - session, room.id, "somekey", "Sphere", {"radius": [1.0]} - ) + await _add_geometry(session, room.id, "somekey", "Sphere", {"radius": [1.0]}) response = await client.get(f"/v1/rooms/{room.id}/geometries/somekey") assert response.status_code == 200 @@ -682,9 +671,7 @@ async def test_geometry_endpoints_return_404_for_nonexistent_room( ) -> None: """All geometry endpoints return 404 for non-existent room.""" user, token = await create_test_user_in_db(session) - response = await client.request( - method, path, json=body, headers=auth_header(token) - ) + response = await client.request(method, path, json=body, headers=auth_header(token)) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] @@ -701,7 +688,9 @@ async def test_upsert_rejects_non_owner( ) -> None: """Test PUT returns 403 when modifying a geometry owned by another user.""" owner, _ = await create_test_user_in_db(session, email="owner@local.test") - _other, other_token = await create_test_user_in_db(session, email="other@local.test") + _other, other_token = await create_test_user_in_db( + session, email="other@local.test" + ) room = await create_test_room(session, owner) await _add_geometry( @@ -757,9 +746,7 @@ async def test_upsert_allows_unowned( _user_b, token_b = await create_test_user_in_db(session, email="b@local.test") room = await create_test_room(session, user_a) - await _add_geometry( - session, room.id, "shared_sphere", "Sphere", {"radius": [1.0]} - ) + await _add_geometry(session, room.id, "shared_sphere", "Sphere", {"radius": [1.0]}) response = await client.put( f"/v1/rooms/{room.id}/geometries/shared_sphere", @@ -776,7 +763,9 @@ async def test_delete_rejects_non_owner( ) -> None: """Test DELETE returns 403 when deleting geometry owned by another user.""" owner, _ = await create_test_user_in_db(session, email="owner@local.test") - _other, other_token = await create_test_user_in_db(session, email="other@local.test") + _other, other_token = await create_test_user_in_db( + session, email="other@local.test" + ) room = await create_test_room(session, owner) await _add_geometry( @@ -802,7 +791,9 @@ async def test_selection_update_rejects_non_owner( ) -> None: """Test PUT selection returns 403 when geometry is owned by another user.""" owner, _ = await create_test_user_in_db(session, email="owner@local.test") - _other, other_token = await create_test_user_in_db(session, email="other@local.test") + _other, other_token = await create_test_user_in_db( + session, email="other@local.test" + ) room = await create_test_room(session, owner) await _add_geometry( diff --git a/tests/zndraw/test_screenshots.py b/tests/zndraw/test_screenshots.py index e3a0e40a5..65885480a 100644 --- a/tests/zndraw/test_screenshots.py +++ b/tests/zndraw/test_screenshots.py @@ -158,9 +158,7 @@ async def test_upload_invalid_format( @pytest.mark.asyncio -async def test_upload_too_large( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_upload_too_large(client: AsyncClient, session: AsyncSession) -> None: """Upload exceeding 10 MB returns 413.""" user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) @@ -182,9 +180,7 @@ async def test_upload_too_large( @pytest.mark.asyncio -async def test_list_screenshots( - client: AsyncClient, session: AsyncSession -) -> None: +async def test_list_screenshots(client: AsyncClient, session: AsyncSession) -> None: """Upload 3 screenshots, verify list with pagination.""" user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) diff --git a/tests/zndraw/test_state_file_source.py b/tests/zndraw/test_state_file_source.py index fc90ae734..5f2e9e367 100644 --- a/tests/zndraw/test_state_file_source.py +++ b/tests/zndraw/test_state_file_source.py @@ -243,9 +243,7 @@ def test_token_uses_url_from_higher_source(state_file, server_factory): ), ) state_file.add_server(url, _remote_entry()) - source = _make_source( - state_file, current_state={"url": url} - ) + source = _make_source(state_file, current_state={"url": url}) result = source() assert result.get("token") == "override.jwt" assert "url" not in result From 5407fdeec3e2e218a5fa77de2c30358b3469a810 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 16:29:49 +0100 Subject: [PATCH 17/20] fix: address review feedback and ruff lint errors - Fix all 18 ruff errors (ASYNC109, PT006, ARG005, RUF059, E501, TC003) - Add # why: comments on all monkeypatch/patch calls in CLI auth tests - Refactor test_screenshots.py: replace per-file client with thin media_path fixture - Add MockSioServer type annotations in test_routes_edit_lock.py - Use auth_header() helper consistently in test_default_camera.py - Fix test_login_with_credentials_success to actually test login_with_credentials - Assert setup precondition in test_register_fails - Add language specifier to spec code block Co-Authored-By: Claude Opus 4.6 (1M context) --- .../2026-03-26-test-suite-overhaul-design.md | 2 +- tests/zndraw/helpers.py | 4 +- tests/zndraw/test_auth_endpoints.py | 7 +- tests/zndraw/test_cli_agent/test_auth.py | 4 + tests/zndraw/test_cli_auth.py | 31 ++++++-- tests/zndraw/test_default_camera.py | 20 ++--- tests/zndraw/test_e2e_state_file_source.py | 5 +- .../zndraw/test_local_token_jwt_regression.py | 4 +- tests/zndraw/test_resolve_token.py | 6 +- tests/zndraw/test_routes_edit_lock.py | 6 +- tests/zndraw/test_routes_geometries.py | 4 +- tests/zndraw/test_screenshots.py | 73 +++---------------- tests/zndraw/test_state_file_source.py | 9 ++- 13 files changed, 76 insertions(+), 99 deletions(-) diff --git a/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md b/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md index 661ba9342..fbc81cf84 100644 --- a/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md +++ b/docs/superpowers/specs/2026-03-26-test-suite-overhaul-design.md @@ -86,7 +86,7 @@ Critical user flows MUST have E2E tests against real servers: Replace per-file fixture sets with shared fixtures in `tests/zndraw/conftest.py`: -``` +```text conftest.py fixtures: session — in-memory SQLite (already exists) redis_client — real Redis, flushed per test (already exists) diff --git a/tests/zndraw/helpers.py b/tests/zndraw/helpers.py index 0448e7c2b..b93f8c6a5 100644 --- a/tests/zndraw/helpers.py +++ b/tests/zndraw/helpers.py @@ -189,14 +189,14 @@ async def acquire_inflight(self, key: str, _ttl: int) -> bool: async def release_inflight(self, key: str) -> None: self._inflight.discard(key) - async def wait_for_key(self, key: str, timeout: float) -> bytes | None: + async def wait_for_key(self, key: str, wait_timeout: float) -> bytes | None: cached = self._store.get(key) if cached is not None: return cached event = asyncio.Event() self._waiters.setdefault(key, []).append(event) try: - await asyncio.wait_for(event.wait(), timeout=timeout) + await asyncio.wait_for(event.wait(), timeout=wait_timeout) return self._store.get(key) except TimeoutError: return None diff --git a/tests/zndraw/test_auth_endpoints.py b/tests/zndraw/test_auth_endpoints.py index 7bafd50a2..215ecb7e0 100644 --- a/tests/zndraw/test_auth_endpoints.py +++ b/tests/zndraw/test_auth_endpoints.py @@ -26,7 +26,7 @@ async def test_login(client: AsyncClient, test_user: User) -> None: @pytest.mark.asyncio @pytest.mark.parametrize( - "username,password", + ("username", "password"), [ ("testuser@local.test", "wrongpassword"), ("nonexistent@local.test", "password"), @@ -93,7 +93,7 @@ async def test_register(client: AsyncClient) -> None: @pytest.mark.asyncio @pytest.mark.parametrize( - "setup_email,email,password,expected_status", + ("setup_email", "email", "password", "expected_status"), [ ("duplicate@example.com", "duplicate@example.com", "password456", 400), (None, "newuser@example.com", None, 422), @@ -109,10 +109,11 @@ async def test_register_fails( ) -> None: """Registration returns error for invalid input.""" if setup_email: - await client.post( + setup_resp = await client.post( "/v1/auth/register", json={"email": setup_email, "password": "password123"}, ) + assert setup_resp.status_code == 201 body = {"email": email} if password is not None: body["password"] = password diff --git a/tests/zndraw/test_cli_agent/test_auth.py b/tests/zndraw/test_cli_agent/test_auth.py index 36e959a4d..41086a38f 100644 --- a/tests/zndraw/test_cli_agent/test_auth.py +++ b/tests/zndraw/test_cli_agent/test_auth.py @@ -17,6 +17,7 @@ def test_auth_login_opens_browser( ) -> None: """Login (without --code) should open the browser.""" state_file = StateFile(directory=tmp_path) + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) challenge_resp = MagicMock() @@ -50,6 +51,7 @@ def mock_get(path, **kwargs): mock_client.get.side_effect = mock_get with ( + # why: device-code login requires choreographed challenge/poll responses patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open") as mock_browser, @@ -68,6 +70,7 @@ def test_auth_login_code_flag_does_not_open_browser( ) -> None: """Login with --code should print URL instead of opening browser.""" state_file = StateFile(directory=tmp_path) + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) challenge_resp = MagicMock() @@ -101,6 +104,7 @@ def mock_get(path, **kwargs): mock_client.get.side_effect = mock_get with ( + # why: device-code login requires choreographed challenge/poll responses patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open") as mock_browser, diff --git a/tests/zndraw/test_cli_auth.py b/tests/zndraw/test_cli_auth.py index a3030d53d..78b4e98b7 100644 --- a/tests/zndraw/test_cli_auth.py +++ b/tests/zndraw/test_cli_auth.py @@ -49,8 +49,10 @@ def test_auth_status_with_stored_token( ) state_file.add_token(server, real_entry) + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) result = runner.invoke(app, ["auth", "status"]) @@ -61,8 +63,10 @@ def test_auth_status_with_stored_token( def test_auth_status_not_logged_in(server: str, state_file, monkeypatch): """auth status with no stored/explicit token should report not logged in.""" + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) result = runner.invoke(app, ["auth", "status"]) @@ -79,8 +83,10 @@ def test_auth_logout_removes_token(server: str, state_file, stored_entry, monkey """auth logout should remove the stored token for the server.""" state_file.add_token(server, stored_entry) + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) result = runner.invoke(app, ["auth", "logout"]) @@ -90,8 +96,10 @@ def test_auth_logout_removes_token(server: str, state_file, stored_entry, monkey def test_auth_logout_no_token(server: str, state_file, monkeypatch): """auth logout when no token is stored should not error.""" + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) result = runner.invoke(app, ["auth", "logout"]) @@ -147,10 +155,13 @@ def mock_get(path, **kwargs): mock_client.post.return_value = challenge_resp mock_client.get.side_effect = mock_get + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) with ( + # why: device-code login requires choreographed challenge/poll responses patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), @@ -186,10 +197,13 @@ def test_auth_login_rejected(server: str, state_file, monkeypatch): mock_client.post.return_value = challenge_resp mock_client.get.return_value = rejected_resp + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) with ( + # why: device-code login requires choreographed challenge/poll responses patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), @@ -221,10 +235,13 @@ def test_auth_login_expired(server: str, state_file, monkeypatch): mock_client.post.return_value = challenge_resp mock_client.get.return_value = expired_resp + # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) - monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda url: server) + # why: point CLI at the test server instead of auto-discovering + monkeypatch.setattr("zndraw.cli_agent.auth._resolve_url", lambda _url: server) with ( + # why: device-code login requires choreographed challenge/poll responses patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), diff --git a/tests/zndraw/test_default_camera.py b/tests/zndraw/test_default_camera.py index 4df6638b1..87b885104 100644 --- a/tests/zndraw/test_default_camera.py +++ b/tests/zndraw/test_default_camera.py @@ -3,7 +3,7 @@ import json import pytest -from helpers import create_test_room, create_test_user_in_db +from helpers import auth_header, create_test_room, create_test_user_in_db from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession @@ -42,7 +42,7 @@ async def test_get_default_camera_none( resp = await client.get( f"/v1/rooms/{room.id}/default-camera", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.status_code == 200 assert resp.json()["default_camera"] is None @@ -58,14 +58,14 @@ async def test_set_default_camera(session: AsyncSession, client: AsyncClient) -> resp = await client.put( f"/v1/rooms/{room.id}/default-camera", json={"default_camera": "template-cam"}, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.status_code == 200 assert resp.json()["default_camera"] == "template-cam" resp = await client.get( f"/v1/rooms/{room.id}/default-camera", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.json()["default_camera"] == "template-cam" @@ -81,7 +81,7 @@ async def test_set_default_camera_not_found( resp = await client.put( f"/v1/rooms/{room.id}/default-camera", json={"default_camera": "nonexistent"}, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.status_code == 404 @@ -98,7 +98,7 @@ async def test_set_default_camera_wrong_type( resp = await client.put( f"/v1/rooms/{room.id}/default-camera", json={"default_camera": "my-sphere"}, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.status_code == 400 @@ -113,20 +113,20 @@ async def test_unset_default_camera(session: AsyncSession, client: AsyncClient) await client.put( f"/v1/rooms/{room.id}/default-camera", json={"default_camera": "template-cam"}, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) resp = await client.put( f"/v1/rooms/{room.id}/default-camera", json={"default_camera": None}, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.status_code == 200 assert resp.json()["default_camera"] is None resp = await client.get( f"/v1/rooms/{room.id}/default-camera", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_header(token), ) assert resp.json()["default_camera"] is None @@ -139,7 +139,7 @@ async def test_delete_geometry_clears_default( user, token = await create_test_user_in_db(session) room = await create_test_room(session, user) await _create_geometry(session, room.id, "template-cam", "Camera") - headers = {"Authorization": f"Bearer {token}"} + headers = auth_header(token) # Set as default await client.put( diff --git a/tests/zndraw/test_e2e_state_file_source.py b/tests/zndraw/test_e2e_state_file_source.py index 9921b343d..dc3d522f7 100644 --- a/tests/zndraw/test_e2e_state_file_source.py +++ b/tests/zndraw/test_e2e_state_file_source.py @@ -8,10 +8,13 @@ import os from datetime import UTC, datetime -from pathlib import Path +from typing import TYPE_CHECKING from zndraw.state_file import ServerEntry, StateFile +if TYPE_CHECKING: + from pathlib import Path + def _make_source(state_file: StateFile): """Create a StateFileSource pointing at the given StateFile.""" diff --git a/tests/zndraw/test_local_token_jwt_regression.py b/tests/zndraw/test_local_token_jwt_regression.py index 5c6a6a78d..891cf027c 100644 --- a/tests/zndraw/test_local_token_jwt_regression.py +++ b/tests/zndraw/test_local_token_jwt_regression.py @@ -111,7 +111,7 @@ def test_e2e_dev_mode_zndraw_client_connects(server, tmp_path, monkeypatch): monkeypatch.setattr( "zndraw.settings_sources.StateFile", lambda: StateFile(directory=tmp_path), - ) # why: redirects StateFile to tmp_path for filesystem isolation; test uses real server via server_factory + ) # why: redirects StateFile to tmp_path; real server via server_factory client = ZnDraw(url=server, room="test-e2e-dev") try: @@ -150,7 +150,7 @@ def test_e2e_production_mode_zndraw_client_connects(server_auth, tmp_path, monke monkeypatch.setattr( "zndraw.settings_sources.StateFile", lambda: StateFile(directory=tmp_path), - ) # why: redirects StateFile to tmp_path for filesystem isolation; test uses real server via server_factory + ) # why: redirects StateFile to tmp_path; real server via server_factory client = ZnDraw(url=server_auth, room="test-e2e-prod") try: diff --git a/tests/zndraw/test_resolve_token.py b/tests/zndraw/test_resolve_token.py index c2848bfb6..1bdd9be32 100644 --- a/tests/zndraw/test_resolve_token.py +++ b/tests/zndraw/test_resolve_token.py @@ -29,11 +29,9 @@ def test_valid_combinations_pass(): validate_credentials(token=None, user=None, password=None) -def test_login_with_credentials_success(server: str): +def test_login_with_credentials_success(server_auth: str): """login_with_credentials returns a JWT token string from a real server.""" - # Use guest login first — the server in open mode accepts any user via JWT login - # with the built-in admin credentials (no auth mode = open guest access) - result = guest_login(server) + result = login_with_credentials(server_auth, "admin@local.test", "adminpassword") assert isinstance(result, str) assert len(result) > 0 diff --git a/tests/zndraw/test_routes_edit_lock.py b/tests/zndraw/test_routes_edit_lock.py index e69b08ece..b64e0a2e0 100644 --- a/tests/zndraw/test_routes_edit_lock.py +++ b/tests/zndraw/test_routes_edit_lock.py @@ -4,7 +4,7 @@ import json import pytest -from helpers import auth_header, create_test_room, create_test_user_in_db +from helpers import MockSioServer, auth_header, create_test_room, create_test_user_in_db from httpx import AsyncClient from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession @@ -144,7 +144,7 @@ async def test_acquire_edit_lock_stores_session_id( async def test_acquire_edit_lock_broadcasts_lock_update( client: AsyncClient, session: AsyncSession, - mock_sio, + mock_sio: MockSioServer, ) -> None: """Test PUT broadcasts LockUpdate socket event with ttl.""" user, token = await create_test_user_in_db(session) @@ -386,7 +386,7 @@ async def test_release_edit_lock_with_token( async def test_release_edit_lock_broadcasts_lock_update( client: AsyncClient, session: AsyncSession, - mock_sio, + mock_sio: MockSioServer, ) -> None: """Test DELETE broadcasts LockUpdate with action=released.""" user, token = await create_test_user_in_db(session) diff --git a/tests/zndraw/test_routes_geometries.py b/tests/zndraw/test_routes_geometries.py index 427a2d394..53b34e125 100644 --- a/tests/zndraw/test_routes_geometries.py +++ b/tests/zndraw/test_routes_geometries.py @@ -656,7 +656,7 @@ async def test_update_selection_requires_auth( @pytest.mark.asyncio @pytest.mark.parametrize( - "method,path,body", + ("method", "path", "body"), [ ("GET", "/v1/rooms/99999/geometries", None), ("GET", "/v1/rooms/99999/geometries/somekey", None), @@ -670,7 +670,7 @@ async def test_geometry_endpoints_return_404_for_nonexistent_room( client: AsyncClient, session: AsyncSession, method: str, path: str, body ) -> None: """All geometry endpoints return 404 for non-existent room.""" - user, token = await create_test_user_in_db(session) + _, token = await create_test_user_in_db(session) response = await client.request(method, path, json=body, headers=auth_header(token)) assert response.status_code == 404 assert "room-not-found" in response.json()["type"] diff --git a/tests/zndraw/test_screenshots.py b/tests/zndraw/test_screenshots.py index 65885480a..390c1a1ce 100644 --- a/tests/zndraw/test_screenshots.py +++ b/tests/zndraw/test_screenshots.py @@ -1,89 +1,40 @@ """Tests for Screenshot REST API endpoints.""" import json -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from pathlib import Path import pytest -import pytest_asyncio from helpers import ( - InMemoryResultBackend, MockSioServer, auth_header, create_test_room, create_test_user_in_db, ) -from httpx import ASGITransport, AsyncClient +from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession from zndraw.redis import RedisKey from zndraw.schemas import StatusResponse -from zndraw.storage import FrameStorage # ============================================================================= -# Per-file client fixture (adds media_path override on top of shared infra) +# Fixture: adds media_path override on top of the shared client # ============================================================================= @pytest.fixture(name="media_path") def media_path_fixture(tmp_path: Path) -> Path: - """Provide a temporary media directory for screenshots.""" - return tmp_path / "media" + """Provide a temporary media directory and install the dependency override. - -@pytest_asyncio.fixture(name="client") -async def client_fixture( - session: AsyncSession, - redis_client, - mock_sio: MockSioServer, - frame_storage: FrameStorage, - result_backend: InMemoryResultBackend, - media_path: Path, -) -> AsyncIterator[AsyncClient]: - """Async test client with media_path override for screenshot tests.""" + The shared ``client`` fixture (from conftest) sets up the base overrides. + This fixture layers ``get_media_path`` on top — tests that need it request + both ``client`` and ``media_path``. + """ from zndraw.app import app - from zndraw.config import Settings - from zndraw.dependencies import ( - get_frame_storage, - get_joblib_settings, - get_media_path, - get_redis, - get_result_backend, - get_tsio, - ) - from zndraw_auth import get_session - from zndraw_auth.settings import AuthSettings - from zndraw_joblib.settings import JobLibSettings - - async def get_session_override() -> AsyncIterator[AsyncSession]: - yield session - - @asynccontextmanager - async def test_session_maker(): - yield session - - settings = Settings() - settings.media_path = media_path # type: ignore[assignment] - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_redis] = lambda: redis_client - app.dependency_overrides[get_tsio] = lambda: mock_sio - app.dependency_overrides[get_frame_storage] = lambda: frame_storage - app.dependency_overrides[get_result_backend] = lambda: result_backend - app.dependency_overrides[get_joblib_settings] = lambda: JobLibSettings() - app.dependency_overrides[get_media_path] = lambda: media_path - app.state.session_maker = test_session_maker - app.state.settings = settings - app.state.auth_settings = AuthSettings() - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as client: - yield client - - app.dependency_overrides.clear() + from zndraw.dependencies import get_media_path + + media = tmp_path / "media" + app.dependency_overrides[get_media_path] = lambda: media + return media # ============================================================================= diff --git a/tests/zndraw/test_state_file_source.py b/tests/zndraw/test_state_file_source.py index 5f2e9e367..b2b8d8694 100644 --- a/tests/zndraw/test_state_file_source.py +++ b/tests/zndraw/test_state_file_source.py @@ -121,7 +121,8 @@ def test_url_falls_back_to_remote(state_file): state_file.add_server("https://remote.example.com", _remote_entry()) source = _make_source(state_file) with ( - # why: pure-logic test of URL-healthy decision path without requiring remote server + # why: pure-logic test of URL-healthy decision path + # without requiring remote server patch("zndraw.settings_sources._is_url_healthy", return_value=True), ): result = source() @@ -211,7 +212,8 @@ def test_token_remote_uses_stored_token(state_file): ) source = _make_source(state_file) with ( - # why: pure-logic test of URL-healthy decision path without requiring remote server + # why: pure-logic test of URL-healthy decision path + # without requiring remote server patch("zndraw.settings_sources._is_url_healthy", return_value=True), ): result = source() @@ -223,7 +225,8 @@ def test_token_no_stored_returns_none(state_file): state_file.add_server(_DEAD_URL, _remote_entry()) source = _make_source(state_file) with ( - # why: pure-logic test of URL-healthy decision path without requiring remote server + # why: pure-logic test of URL-healthy decision path + # without requiring remote server patch("zndraw.settings_sources._is_url_healthy", return_value=True), ): result = source() From 12a99b34e19542c4fdc5218666d064d22df1f0c6 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 26 Mar 2026 16:44:42 +0100 Subject: [PATCH 18/20] refactor: add media_path cleanup and extract mock_httpx_client fixture - test_screenshots.py: media_path fixture now yields and cleans up override - test_cli_auth.py: extract mock_httpx_client fixture and _challenge_response helper to deduplicate login test setup Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_cli_auth.py | 82 +++++++++++++------------------- tests/zndraw/test_screenshots.py | 5 +- 2 files changed, 37 insertions(+), 50 deletions(-) diff --git a/tests/zndraw/test_cli_auth.py b/tests/zndraw/test_cli_auth.py index 78b4e98b7..d23c9a114 100644 --- a/tests/zndraw/test_cli_auth.py +++ b/tests/zndraw/test_cli_auth.py @@ -30,6 +30,28 @@ def stored_entry(): ) +@pytest.fixture +def mock_httpx_client(): + """MagicMock httpx.Client with context-manager protocol configured.""" + client = MagicMock() + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + return client + + +def _challenge_response() -> MagicMock: + """Return a standard challenge response mock.""" + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "code": "ABCD1234", + "secret": "test-secret", + "approve_url": "/auth/cli-login/ABCD1234", + } + return resp + + # -- auth status -------------------------------------------------------------- @@ -109,17 +131,8 @@ def test_auth_logout_no_token(server: str, state_file, monkeypatch): # -- auth login ---------------------------------------------------------------- -def test_auth_login_approved(server: str, state_file, monkeypatch): +def test_auth_login_approved(server: str, state_file, mock_httpx_client, monkeypatch): """auth login should store token on successful approval.""" - challenge_resp = MagicMock() - challenge_resp.status_code = 200 - challenge_resp.raise_for_status = MagicMock() - challenge_resp.json.return_value = { - "code": "ABCD1234", - "secret": "test-secret", - "approve_url": "/auth/cli-login/ABCD1234", - } - # First poll: pending, second: approved pending_resp = MagicMock() pending_resp.status_code = 200 @@ -149,11 +162,8 @@ def mock_get(path, **kwargs): return pending_resp if call_count == 1 else approved_resp return me_resp - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_client.post.return_value = challenge_resp - mock_client.get.side_effect = mock_get + mock_httpx_client.post.return_value = _challenge_response() + mock_httpx_client.get.side_effect = mock_get # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) @@ -162,7 +172,7 @@ def mock_get(path, **kwargs): with ( # why: device-code login requires choreographed challenge/poll responses - patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_httpx_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), # why: time.sleep(1) x 300 iterations would make tests take minutes @@ -177,25 +187,13 @@ def mock_get(path, **kwargs): assert stored.email == "user@example.com" -def test_auth_login_rejected(server: str, state_file, monkeypatch): +def test_auth_login_rejected(server: str, state_file, mock_httpx_client, monkeypatch): """auth login should show error when challenge is rejected (404).""" - challenge_resp = MagicMock() - challenge_resp.status_code = 200 - challenge_resp.raise_for_status = MagicMock() - challenge_resp.json.return_value = { - "code": "ABCD1234", - "secret": "test-secret", - "approve_url": "/auth/cli-login/ABCD1234", - } - rejected_resp = MagicMock() rejected_resp.status_code = 404 - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_client.post.return_value = challenge_resp - mock_client.get.return_value = rejected_resp + mock_httpx_client.post.return_value = _challenge_response() + mock_httpx_client.get.return_value = rejected_resp # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) @@ -204,7 +202,7 @@ def test_auth_login_rejected(server: str, state_file, monkeypatch): with ( # why: device-code login requires choreographed challenge/poll responses - patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_httpx_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), # why: time.sleep(1) x 300 iterations would make tests take minutes @@ -215,25 +213,13 @@ def test_auth_login_rejected(server: str, state_file, monkeypatch): assert result.exit_code != 0 -def test_auth_login_expired(server: str, state_file, monkeypatch): +def test_auth_login_expired(server: str, state_file, mock_httpx_client, monkeypatch): """auth login should show error when challenge expires (410).""" - challenge_resp = MagicMock() - challenge_resp.status_code = 200 - challenge_resp.raise_for_status = MagicMock() - challenge_resp.json.return_value = { - "code": "ABCD1234", - "secret": "test-secret", - "approve_url": "/auth/cli-login/ABCD1234", - } - expired_resp = MagicMock() expired_resp.status_code = 410 - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_client.post.return_value = challenge_resp - mock_client.get.return_value = expired_resp + mock_httpx_client.post.return_value = _challenge_response() + mock_httpx_client.get.return_value = expired_resp # why: isolate state.json to tmp_path so tests don't share token storage monkeypatch.setattr("zndraw.cli_agent.auth.StateFile", lambda: state_file) @@ -242,7 +228,7 @@ def test_auth_login_expired(server: str, state_file, monkeypatch): with ( # why: device-code login requires choreographed challenge/poll responses - patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_client), + patch("zndraw.cli_agent.auth.httpx.Client", return_value=mock_httpx_client), # why: webbrowser.open is a real OS side-effect that cannot run in CI patch("zndraw.cli_agent.auth.webbrowser.open"), # why: time.sleep(1) x 300 iterations would make tests take minutes diff --git a/tests/zndraw/test_screenshots.py b/tests/zndraw/test_screenshots.py index 390c1a1ce..56e119f21 100644 --- a/tests/zndraw/test_screenshots.py +++ b/tests/zndraw/test_screenshots.py @@ -22,7 +22,7 @@ @pytest.fixture(name="media_path") -def media_path_fixture(tmp_path: Path) -> Path: +def media_path_fixture(tmp_path: Path): """Provide a temporary media directory and install the dependency override. The shared ``client`` fixture (from conftest) sets up the base overrides. @@ -34,7 +34,8 @@ def media_path_fixture(tmp_path: Path) -> Path: media = tmp_path / "media" app.dependency_overrides[get_media_path] = lambda: media - return media + yield media + app.dependency_overrides.pop(get_media_path, None) # ============================================================================= From 60ee7ffc5115c9a1621dd83883e46037efdf4e7a Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Tue, 31 Mar 2026 10:59:44 +0200 Subject: [PATCH 19/20] test: tighten stored-token assertions in test_auth_status_with_stored_token Replace weak `"server" in data` check with explicit value assertions so the test actually verifies the stored-token code path. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_cli_auth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/zndraw/test_cli_auth.py b/tests/zndraw/test_cli_auth.py index d23c9a114..f6b0e8ec6 100644 --- a/tests/zndraw/test_cli_auth.py +++ b/tests/zndraw/test_cli_auth.py @@ -80,7 +80,8 @@ def test_auth_status_with_stored_token( assert result.exit_code == 0, result.output data = json.loads(result.stdout) - assert "server" in data + assert data["server"] == server + assert data["token_source"] == "stored" def test_auth_status_not_logged_in(server: str, state_file, monkeypatch): From 984221647cf47af63c318b067500a26b3882fe66 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Tue, 31 Mar 2026 12:05:03 +0200 Subject: [PATCH 20/20] test: add cross-guest lock E2E test and clean up frames helper aliases Add negative authorization test verifying guest B cannot write frames to a room locked by guest A (423). Remove leftover _create_user / _create_room / _auth_header aliases in test_routes_frames.py. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/zndraw/test_e2e_guest_auth.py | 45 +++++++ tests/zndraw/test_routes_frames.py | 181 ++++++++++++++-------------- 2 files changed, 133 insertions(+), 93 deletions(-) diff --git a/tests/zndraw/test_e2e_guest_auth.py b/tests/zndraw/test_e2e_guest_auth.py index ab469234a..742505841 100644 --- a/tests/zndraw/test_e2e_guest_auth.py +++ b/tests/zndraw/test_e2e_guest_auth.py @@ -161,3 +161,48 @@ async def test_guest_cannot_access_other_room_without_auth(http_client: AsyncCli # Try to access frames without authorization get_resp = await http_client.get(f"/v1/rooms/{room_id}/frames") assert get_resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_second_guest_cannot_write_to_locked_room(http_client: AsyncClient): + """A second guest cannot write frames to a room locked by the first guest.""" + # Guest A: authenticate, create room, write a frame, lock the room + resp_a = await http_client.post("/v1/auth/guest") + token_a = resp_a.json()["access_token"] + headers_a = {"Authorization": f"Bearer {token_a}"} + + room_id = uuid.uuid4().hex + await http_client.post( + "/v1/rooms", + json={"room_id": room_id, "copy_from": "@none"}, + headers=headers_a, + ) + + frame_data = atoms_to_json_dict(_make_atoms(1.0)) + post_resp = await http_client.post( + f"/v1/rooms/{room_id}/frames", + json={"frames": [frame_data]}, + headers=headers_a, + ) + assert post_resp.status_code == 201 + + # Lock the room as guest A + lock_resp = await http_client.put( + f"/v1/rooms/{room_id}/edit-lock", + json={}, + headers=headers_a, + ) + assert lock_resp.status_code == 200 + + # Guest B: authenticate separately + resp_b = await http_client.post("/v1/auth/guest") + token_b = resp_b.json()["access_token"] + headers_b = {"Authorization": f"Bearer {token_b}"} + + # Guest B tries to write a frame → should fail with 423 (locked) + write_resp = await http_client.post( + f"/v1/rooms/{room_id}/frames", + json={"frames": [atoms_to_json_dict(_make_atoms(2.0))]}, + headers=headers_b, + ) + assert write_resp.status_code == 423 diff --git a/tests/zndraw/test_routes_frames.py b/tests/zndraw/test_routes_frames.py index e560de9ed..56aa5d652 100644 --- a/tests/zndraw/test_routes_frames.py +++ b/tests/zndraw/test_routes_frames.py @@ -45,11 +45,6 @@ def raw_frame_to_dict(frame: RawFrame) -> dict[str, Any]: return {k.decode(): msgpack.unpackb(v) for k, v in frame.items()} -_create_user = create_test_user_in_db -_create_room = create_test_room -_auth_header = auth_header - - # ============================================================================= # List Frames Tests # ============================================================================= @@ -60,12 +55,12 @@ async def test_list_frames_empty_room( client: AsyncClient, session: AsyncSession ) -> None: """Test listing frames from an empty room returns empty list.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) response = await client.get( f"/v1/rooms/{room.id}/frames", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.headers["content-type"] == "application/x-msgpack" @@ -81,8 +76,8 @@ async def test_list_frames_with_data( frame_storage: FrameStorage, ) -> None: """Test listing frames with data returns all frames.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -90,7 +85,7 @@ async def test_list_frames_with_data( ) response = await client.get( f"/v1/rooms/{room.id}/frames", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -107,8 +102,8 @@ async def test_list_frames_with_range( frame_storage: FrameStorage, ) -> None: """Test listing frames with range query params.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -121,7 +116,7 @@ async def test_list_frames_with_range( ) response = await client.get( f"/v1/rooms/{room.id}/frames?start=1&stop=3", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -136,11 +131,11 @@ async def test_list_frames_room_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test listing frames from non-existent room returns 404.""" - _, token = await _create_user(session) + _, token = await create_test_user_in_db(session) response = await client.get( "/v1/rooms/99999/frames", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -155,8 +150,8 @@ async def test_list_frames_with_indices( frame_storage: FrameStorage, ) -> None: """Test listing specific frames by indices parameter.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add 5 frames await frame_storage[room.id].extend( @@ -172,7 +167,7 @@ async def test_list_frames_with_indices( # Request specific indices response = await client.get( f"/v1/rooms/{room.id}/frames?indices=1,3", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -189,8 +184,8 @@ async def test_list_frames_with_keys_filter( frame_storage: FrameStorage, ) -> None: """Test listing frames with keys parameter to filter frame data.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames with multiple keys await frame_storage[room.id].extend( @@ -202,7 +197,7 @@ async def test_list_frames_with_keys_filter( # Request only x and z keys response = await client.get( f"/v1/rooms/{room.id}/frames?keys=x,z", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -219,8 +214,8 @@ async def test_list_frames_with_indices_and_keys( frame_storage: FrameStorage, ) -> None: """Test listing specific indices with keys filter.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames await frame_storage[room.id].extend( @@ -233,7 +228,7 @@ async def test_list_frames_with_indices_and_keys( # Request index 2 with only key 'a' response = await client.get( f"/v1/rooms/{room.id}/frames?indices=2&keys=a", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -254,8 +249,8 @@ async def test_get_frame( frame_storage: FrameStorage, ) -> None: """Test getting a single frame by index.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -263,7 +258,7 @@ async def test_get_frame( ) response = await client.get( f"/v1/rooms/{room.id}/frames/1", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 assert response.headers["content-type"] == "application/x-msgpack" @@ -277,12 +272,12 @@ async def test_get_frame( @pytest.mark.asyncio async def test_get_frame_not_found(client: AsyncClient, session: AsyncSession) -> None: """Test getting non-existent frame returns 404.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) response = await client.get( f"/v1/rooms/{room.id}/frames/99", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -295,11 +290,11 @@ async def test_get_frame_room_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test getting frame from non-existent room returns 404.""" - _, token = await _create_user(session) + _, token = await create_test_user_in_db(session) response = await client.get( "/v1/rooms/99999/frames/0", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -322,8 +317,8 @@ async def test_get_frame_metadata( from ase import Atoms from asebytes import encode - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Create an Atoms object with calc results atoms = Atoms("H2O", positions=[[0, 0, 0], [1, 0, 0], [0, 1, 0]]) @@ -333,7 +328,7 @@ async def test_get_frame_metadata( await frame_storage[room.id].extend([raw]) response = await client.get( f"/v1/rooms/{room.id}/frames/0/metadata", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -361,12 +356,12 @@ async def test_get_frame_metadata_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test getting metadata for non-existent frame returns 404.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) response = await client.get( f"/v1/rooms/{room.id}/frames/99/metadata", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -379,11 +374,11 @@ async def test_get_frame_metadata_room_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test getting metadata for a frame in a non-existent room returns 404.""" - _, token = await _create_user(session) + _, token = await create_test_user_in_db(session) response = await client.get( "/v1/rooms/nonexistent-room/frames/0/metadata", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -399,8 +394,8 @@ async def test_get_frame_metadata_room_not_found( @pytest.mark.asyncio async def test_append_frames(client: AsyncClient, session: AsyncSession) -> None: """Test appending frames to storage.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) frame_a = _make_json_frame("H2") frame_b = _make_json_frame("H2O") @@ -408,7 +403,7 @@ async def test_append_frames(client: AsyncClient, session: AsyncSession) -> None response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [frame_a, frame_b]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 @@ -424,14 +419,14 @@ async def test_append_frames_multiple_times( client: AsyncClient, session: AsyncSession ) -> None: """Test appending frames multiple times.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # First append response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [_make_json_frame("H2")]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 result = FrameBulkResponse.model_validate(response.json()) @@ -443,7 +438,7 @@ async def test_append_frames_multiple_times( response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [_make_json_frame("H2O"), _make_json_frame("CH4")]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201 result = FrameBulkResponse.model_validate(response.json()) @@ -457,12 +452,12 @@ async def test_append_frames_room_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test appending frames to non-existent room returns 404.""" - _, token = await _create_user(session) + _, token = await create_test_user_in_db(session) response = await client.post( "/v1/rooms/99999/frames", json={"frames": [_make_json_frame("H2")]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -475,13 +470,13 @@ async def test_append_frames_empty_list_rejected( client: AsyncClient, session: AsyncSession ) -> None: """Test appending empty frames list is rejected (422).""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": []}, - headers=_auth_header(token), + headers=auth_header(token), ) # FastAPI returns 422 for validation errors assert response.status_code == 422 @@ -492,14 +487,14 @@ async def test_append_frames_exceeds_max_length( client: AsyncClient, session: AsyncSession ) -> None: """Test appending more than 1000 frames is rejected (422).""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) frame = _make_json_frame("H2") response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [frame] * 1001}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 @@ -516,8 +511,8 @@ async def test_update_frame( frame_storage: FrameStorage, ) -> None: """Test updating a frame at specific index.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -528,7 +523,7 @@ async def test_update_frame( response = await client.put( f"/v1/rooms/{room.id}/frames/1", json={"data": new_frame}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 @@ -541,13 +536,13 @@ async def test_update_frame_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test updating non-existent frame returns 404.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) response = await client.put( f"/v1/rooms/{room.id}/frames/99", json={"data": _make_json_frame("H2")}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -567,8 +562,8 @@ async def test_merge_frame( frame_storage: FrameStorage, ) -> None: """Test partial update merges new keys into existing frame.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await frame_storage[room.id].extend([make_raw_frame({"a": 1, "b": 2})]) # Send PATCH with msgpack body updating key "a" and adding key "c" @@ -576,7 +571,7 @@ async def test_merge_frame( response = await client.patch( f"/v1/rooms/{room.id}/frames/0", content=patch_data, - headers={**_auth_header(token), "Content-Type": "application/msgpack"}, + headers={**auth_header(token), "Content-Type": "application/msgpack"}, ) assert response.status_code == 200 @@ -597,8 +592,8 @@ async def test_merge_frame_preserves_untouched_keys( frame_storage: FrameStorage, ) -> None: """Test partial update does not remove keys not in the patch.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) await frame_storage[room.id].extend([make_raw_frame({"x": 10, "y": 20, "z": 30})]) # Only update "y" @@ -606,7 +601,7 @@ async def test_merge_frame_preserves_untouched_keys( response = await client.patch( f"/v1/rooms/{room.id}/frames/0", content=patch_data, - headers={**_auth_header(token), "Content-Type": "application/msgpack"}, + headers={**auth_header(token), "Content-Type": "application/msgpack"}, ) assert response.status_code == 200 @@ -620,14 +615,14 @@ async def test_merge_frame_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test merging non-existent frame returns 404.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) patch_data = msgpack.packb({"a": 1}) response = await client.patch( f"/v1/rooms/{room.id}/frames/99", content=patch_data, - headers={**_auth_header(token), "Content-Type": "application/msgpack"}, + headers={**auth_header(token), "Content-Type": "application/msgpack"}, ) assert response.status_code == 404 @@ -640,13 +635,13 @@ async def test_merge_frame_room_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test merging frame in non-existent room returns 404.""" - _, token = await _create_user(session) + _, token = await create_test_user_in_db(session) patch_data = msgpack.packb({"a": 1}) response = await client.patch( "/v1/rooms/99999/frames/0", content=patch_data, - headers={**_auth_header(token), "Content-Type": "application/msgpack"}, + headers={**auth_header(token), "Content-Type": "application/msgpack"}, ) assert response.status_code == 404 @@ -671,8 +666,8 @@ async def test_merge_frame_preserves_msgpack_str_type( """ import struct - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Create initial frame with a numpy-format position array (float64) # This mimics what asebytes.encode produces @@ -708,7 +703,7 @@ async def test_merge_frame_preserves_msgpack_str_type( response = await client.patch( f"/v1/rooms/{room.id}/frames/0", content=patch_body, - headers={**_auth_header(token), "Content-Type": "application/msgpack"}, + headers={**auth_header(token), "Content-Type": "application/msgpack"}, ) assert response.status_code == 200 @@ -754,8 +749,8 @@ async def test_delete_frame( frame_storage: FrameStorage, ) -> None: """Test deleting a frame at specific index.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add frames to storage await frame_storage[room.id].extend( @@ -763,7 +758,7 @@ async def test_delete_frame( ) response = await client.delete( f"/v1/rooms/{room.id}/frames/1", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 200 StatusResponse.model_validate(response.json()) @@ -779,12 +774,12 @@ async def test_delete_frame_not_found( client: AsyncClient, session: AsyncSession ) -> None: """Test deleting non-existent frame returns 404.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) response = await client.delete( f"/v1/rooms/{room.id}/frames/99", - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 404 @@ -802,8 +797,8 @@ async def test_frames_require_authentication( client: AsyncClient, session: AsyncSession ) -> None: """Test that all frame endpoints require authentication.""" - user, _ = await _create_user(session) - room = await _create_room(session, user) + user, _ = await create_test_user_in_db(session) + room = await create_test_room(session, user) # All endpoints should return 401 without auth endpoints = [ @@ -866,15 +861,15 @@ async def test_append_rejects_frames_without_colors_radii( client: AsyncClient, session: AsyncSession ) -> None: """POST /frames rejects frames missing arrays.colors and arrays.radii.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) bare_frame = _make_bare_json_frame("H2") response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [bare_frame]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 @@ -890,8 +885,8 @@ async def test_update_rejects_frame_without_colors_radii( frame_storage: FrameStorage, ) -> None: """PUT /frames/{index} rejects frames missing arrays.colors and arrays.radii.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) # Add a valid frame so index 0 exists await frame_storage[room.id].extend([make_raw_frame({"a": 1})]) @@ -900,7 +895,7 @@ async def test_update_rejects_frame_without_colors_radii( response = await client.put( f"/v1/rooms/{room.id}/frames/0", json={"data": bare_frame}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 422 @@ -910,14 +905,14 @@ async def test_append_accepts_enriched_frames( client: AsyncClient, session: AsyncSession ) -> None: """POST /frames accepts frames that already have colors and radii.""" - user, token = await _create_user(session) - room = await _create_room(session, user) + user, token = await create_test_user_in_db(session) + room = await create_test_room(session, user) enriched_frame = _make_json_frame("H2") response = await client.post( f"/v1/rooms/{room.id}/frames", json={"frames": [enriched_frame]}, - headers=_auth_header(token), + headers=auth_header(token), ) assert response.status_code == 201