diff --git a/.env.example b/.env.example index a80d8522..dcb75a17 100644 --- a/.env.example +++ b/.env.example @@ -114,5 +114,15 @@ AGENT_ENABLE_STREAMING=true # Batch runner (PRP-33) — cap on scope expansion (pairs × model_configs). BATCH_MAX_SCOPE_EXPANSION=1000 +# Batch runner concurrency (PRP-34) +# Hard upper bound on concurrent in-flight batch_job_item executions across +# all active batches on this host. Effective per-batch parallelism is +# min(batch_job.max_parallel, this). Requires uvicorn restart to apply. +BATCH_GLOBAL_MAX_PARALLEL=4 +# Max seconds DELETE /batch/{batch_id} waits for in-flight children to drain +# before returning RFC 7807 504. sklearn / LightGBM fits are uncancellable +# mid-call, so a long fit can stall the drain. +BATCH_CANCEL_DRAIN_TIMEOUT_SECONDS=30 + # Frontend (Vite) VITE_API_BASE_URL=http://localhost:8123 diff --git a/.github/workflows/dependency-check.yml b/.github/workflows/dependency-check.yml index 486fa4c8..42cf1407 100644 --- a/.github/workflows/dependency-check.yml +++ b/.github/workflows/dependency-check.yml @@ -76,7 +76,7 @@ jobs: - name: Upload SARIF to GitHub Security if: always() - uses: github/codeql-action/upload-sarif@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4.35.5 + uses: github/codeql-action/upload-sarif@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: sarif_file: audit-results.sarif category: dependency-vulnerability-scan diff --git a/PRPs/PRP-34-batch-parallel-execution.md b/PRPs/PRP-34-batch-parallel-execution.md new file mode 100644 index 00000000..df10f11f --- /dev/null +++ b/PRPs/PRP-34-batch-parallel-execution.md @@ -0,0 +1,994 @@ +name: "PRP-34 — Batch Parallel Execution (Semaphore + TaskGroup + cooperative cancellation)" +description: | + Activates the three forward-compat columns PRP-33 shipped on `batch_job` + (`max_parallel`, `running_items`, `cancelled_items`) by rewiring + `BatchService.submit` through a new `app/features/batch/runner.py`. The + runner is a single `asyncio.Semaphore(effective_parallel)` inside an + `asyncio.TaskGroup`; each child opens its own `AsyncSession`, writes the + same pinned five-key metrics JSONB the MVP produces, and observes a + cooperative `asyncio.Event` so `DELETE /batch/{batch_id}` cancels what + hasn't started and gracefully drains what has. **No new Alembic + migration** — every column the runner writes already exists on + `batch_job` per PRP-33 (`app/features/batch/models.py:136-139`). + +**Tracking issue:** (to be opened — see "Pre-flight" below) +**Source INITIAL:** `PRPs/INITIAL/INITIAL-batch-parallel-execution.md` (480 lines, refreshed in PR #283) +**Source feature doc:** `docs/optional-features/06-portfolio-forecasting-batch-runner.md` § Full Version → "Parallel execution controls" +**Depends on:** PRP-33 batch-runner MVP (merged in PR #281). All forward-compat schema is in place. +**Blocks:** none — sibling Full-Version PRPs (priority queue, export-and-retry, champion-and-heatmap) are independent. +**Successor PRPs:** none scheduled. + +--- + +## Goal + +The new module `app/features/batch/runner.py` becomes the **only** code path +that executes `batch_job_item` rows. `BatchService.submit` stops calling +`_pick_next` + `_execute_item` in a sequential loop and instead hands its +expanded item list to `runner.run_batch(...)`, which: + +1. Computes `effective_parallel = min(batch_job.max_parallel, settings.batch_global_max_parallel)`. +2. Wraps an `asyncio.Semaphore(effective_parallel)` inside an + `asyncio.TaskGroup`, creating one task per item. +3. In each child: skip-if-cancelled, acquire the semaphore, open a fresh + `AsyncSession`, increment `batch_job.running_items`, delegate to + `JobService.create_job` (lazy import), write the pinned metrics, decrement + `running_items` in `finally`. +4. On `DELETE /batch/{batch_id}`: set a per-batch `asyncio.Event`, cancel each + tracked `Task`, await drain with a bounded `Settings.batch_cancel_drain_timeout_seconds` + (default 30s), then settle the parent. + +A new `DELETE /batch/{batch_id}` endpoint surfaces cancellation; the existing +`POST /batch/forecasting` accepts `max_parallel` (already in the schema) and +returns a `BatchSubmitResponse` whose `running_items` and `cancelled_items` +now reflect real work. The `frontend/src/pages/visualize/batch.tsx` placeholder +gains a max-parallel `Slider` (a new shadcn primitive) on the submit form and +a "Cancel batch" `Button` + `AlertDialog` on the progress card. + +## Why + +- **Without a cap, the first 50-pair batch wedges the laptop.** The MVP runs + items serially, so it is safe today but operationally slow. The naive next + step is `asyncio.gather(*tasks)` over the items — repo precedent in + `app/features/demo/pipeline.py:419`. At N=500 children that exhausts the + SQLAlchemy `pool_size=5, max_overflow=10` pool (verified default — see + `PRPs/ai_docs/asyncio-taskgroup-cancellation.md`), and a per-child sklearn + fit blows out RAM. The Semaphore + TaskGroup primitive gives bounded + concurrency + structured cancellation with **zero new dependencies**. +- **Operator needs a stop button.** A 200-item batch that's misconfigured + (wrong date range, wrong model) currently has to run to completion. A + cooperative cancel that stops what hasn't started and bounds the drain of + what has is the difference between "useful tool" and "footgun". +- **Activates schema PRP-33 already shipped.** `batch_job.max_parallel`, + `running_items`, `cancelled_items` are real columns; the MVP just doesn't + write the last two. This PRP makes them live. No new migration needed. + +## What + +### User-visible behaviour + +```bash +# Submit with explicit per-batch parallelism (clamped by global cap). +curl -X POST http://localhost:8123/batch/forecasting \ + -H "Content-Type: application/json" \ + -d '{ + "operation": "backtest", + "scope": {"kind": "manual", "store_ids": [1,2,3,4,5], "product_ids": [1,2,3,4,5]}, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-06-30", + "max_parallel": 8 + }' +# → 202 Accepted with BatchSubmitResponse including running_items, cancelled_items +# settle to consistent end-state per child outcomes. + +# Cancel an in-flight batch. +curl -X DELETE http://localhost:8123/batch/{batch_id} +# → 200 BatchSubmitResponse with status='cancelled' once drain settles. +# → 404 application/problem+json on unknown batch_id. +# → 409 application/problem+json on already-terminal batch. +# → 504 application/problem+json if drain exceeds Settings.batch_cancel_drain_timeout_seconds. + +# Frontend: /visualize/batch shows a max_parallel slider on submit, a live +# `running_items` chip on the parent card, and a "Cancel batch" button that +# pops an AlertDialog confirmation. +``` + +### Success Criteria + +- [ ] `grep -rn "asyncio.gather" app/features/batch/` returns no production-code match (allowed in tests for synthetic concurrent-peak checks). +- [ ] `app/features/batch/runner.py` exists and is the only code path that schedules `batch_job_item` execution. `BatchService.submit` calls `runner.run_batch(...)` instead of looping `_pick_next` / `_execute_item`. +- [ ] `Settings.batch_global_max_parallel` defaults to `4`; `Settings.batch_cancel_drain_timeout_seconds` defaults to `30`; `.env.example` lists both placeholders. +- [ ] `POST /batch/forecasting` echoes `max_parallel` and the new `effective_max_parallel` field in `BatchSubmitResponse`; `running_items` and `cancelled_items` are accurate at every observation moment. +- [ ] `DELETE /batch/{batch_id}` returns 200 + cancelled parent on success, 404 on missing, 409 on terminal, 504 on drain timeout — all RFC 7807. +- [ ] The 8 unit tests + 3 integration tests + 2 chaos tests in § Test Plan pass. The semaphore-cap regression test (`test_semaphore_caps_concurrency`) **would have caught an unbounded `gather`** and is the load-bearing spec. +- [ ] `frontend/src/pages/visualize/batch.tsx` renders a max-parallel `Slider` (new shadcn) and a "Cancel batch" `Button` + `AlertDialog`. The submit form sends `max_parallel`; the progress card shows `running_items` live. +- [ ] All five validation gate commands green locally and in CI: + ```bash + uv run ruff check . && uv run ruff format --check . + uv run mypy app/ && uv run pyright app/ + uv run pytest -v -m "not integration" + docker compose up -d && uv run pytest -v -m integration + cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run + ``` +- [ ] No Alembic migration added (the three columns are already on the table). +- [ ] No managed-cloud / Celery / Redis dependency introduced. +- [ ] `.claude/rules/commit-format.md` already lists `batch` in the scope allow-list — commits use `feat(batch): ...` / `feat(batch,ui): ...` / `feat(batch,api): ...` referencing the tracking issue. + +## All Needed Context + +### Documentation & References + +```yaml +# MUST READ — load these before implementing +- file: PRPs/INITIAL/INITIAL-batch-parallel-execution.md + why: 480-line source spec. The Data Model Delta, API Delta, and Risks tables are authoritative. The pseudocode is illustrative — see "Known Gotchas" below for the corrections. + +- file: PRPs/ai_docs/asyncio-taskgroup-cancellation.md + why: Verified asyncio semantics on Python 3.12.13. Captures the three working cancellation mechanisms, the broken `tg.cancel_scope.cancel()` claim in the INITIAL, the ContextVar inheritance into child tasks (so `request_id` propagates automatically), and the SQLAlchemy pool default math. Re-runnable verification commands inline. **Read end-to-end before touching runner.py.** + +- file: PRPs/PRP-33-batch-runner-mvp.md + why: The MVP PRP. Phase 0 § "Cross-Slice Coordination Matrix" pinned every forward-compat column, the partial picker index predicate, the `FOR UPDATE SKIP LOCKED` invariant. PRP-34 must NOT break any of those — the same tests still pass. + +- file: app/features/batch/models.py + why: The three columns this PRP activates live at lines 136-139 (`running_items`, `cancelled_items`, `max_parallel`). The state machines at lines 90-110 already accept the `running → completed/failed/cancelled` and `pending → cancelled` transitions — no model change required. + +- file: app/features/batch/service.py + why: The existing sequential picker loop at lines 152-157 is what this PRP replaces. The lazy `JobService` import pattern at lines 255-257 is what each child reuses. The `_settle` helper at lines 350-391 already aggregates by status — PRP-34 leaves it as-is (it already counts cancelled_items). + +- file: app/features/batch/schemas.py + why: `BatchSubmitRequest.max_parallel` is at line 135 (already validated `ge=1, le=64`). `BatchSubmitResponse` at lines 164-181 already carries `running_items` and `cancelled_items` — PRP-34 adds ONE field, `effective_max_parallel`, to the response. + +- file: app/features/batch/routes.py + why: Three endpoints exist (POST forecasting, GET {id}, GET {id}/items). PRP-34 adds `DELETE /batch/{batch_id}` mirroring the GET shape, 404/409/504 problem+json on the failure paths. + +- file: app/features/batch/tests/conftest.py + why: `db_session` cleans rows where `batch_id.like("test%")` — every new integration test must prefix its batch_id with `test` (use the seed fixtures, they already produce uuid-hex batch_ids; cleanup keys on prefix, so call the helper that overrides batch_id when needed). + +- file: app/features/batch/tests/test_routes_integration.py + why: Reference integration shape — `ASGITransport(app=app)` client fixture, structlog event capture via `structlog.testing.capture_logs`, real Postgres + the `BATCH-` store/product seed fixtures. + +- file: app/features/batch/tests/test_service.py + why: Reference unit-test shape — `_make_job_response` builds a synthetic `JobResponse`; `AsyncMock` proves no DB call lands; the `pytestmark = pytest.mark.integration` is absent here so tests run under `-m "not integration"`. + +- file: app/features/demo/service.py + why: Module-level `asyncio.Lock` + `PipelineBusyError` is the prior art for "one X at a time in the process". The runner's `_ACTIVE_BATCHES` dict + `CancelHandle` follow the same vibe but allow multiple batches; one cancel handle per batch. + +- file: app/features/demo/pipeline.py + why: Lines 418-420 show the `asyncio.gather(...)` pattern this PRP must not import into the batch slice. At N=3 it's fine; at N=500 it's the headline risk. + +- file: app/core/database.py + why: `get_engine()` and `get_session_maker()`. Each child opens its own session via a fresh `async with session_maker() as session:` block. Do NOT use the request's `Depends(get_db)` session inside the runner — that session belongs to the HTTP request lifecycle, not the child task. + +- file: app/core/config.py + why: Where the two new Settings fields land. Mirror the `batch_max_scope_expansion` placement (line 122). `@lru_cache` on `get_settings()` (line 212) — env-var override needs a uvicorn restart, document on the field. + +- file: app/core/exceptions.py + why: `NotFoundError` (404), `ConflictError` (409). 504 needs a new exception or a direct `problem_response(status=504, ...)` call — see Task 6 below. RFC 7807 envelope is automatic via the registered handler. + +- file: app/core/problem_details.py + why: `problem_response(status, title, detail, error_code)` for the 504 path. The `ERROR_TYPES` dict at line 26 catalogues canonical error codes — `GATEWAY_TIMEOUT` may need adding. + +- file: app/core/logging.py + why: `request_id_ctx: ContextVar` is read by the `add_request_id` structlog processor. TaskGroup children inherit the request's contextvars, so child events auto-carry the parent request_id (verified — see ai_doc). + +- file: app/core/tests/test_strict_mode_policy.py + why: AST-walker invariant. PRP-34 does NOT add any new `date | datetime | time | UUID | Decimal` field to a `ConfigDict(strict=True)` model (the existing `start_date`/`end_date` already carry `Field(strict=False, ...)`). Still — running the policy linter is in the validation gates. + +- file: app/features/jobs/service.py + why: `JobService.create_job` at lines 150-191 is the delegation target each runner child invokes. The MVP runner already calls it (lines 285 of batch/service.py); the runner moves the same call into a child coroutine. + +- file: app/features/forecasting/service.py + why: Lines 786-787 — the lazy-import precedent. Every cross-slice import inside the runner stays lazy (in-method) to avoid the alembic cold-boot cycle documented in memory `[[computed-field-cross-slice-cycle]]`. + +- file: frontend/src/pages/visualize/batch.tsx + why: 213-line placeholder ALREADY EXISTS (created by PRP-33). PRP-34 modifies this file to add the slider, the cancel button, and the running_items chip. The page route is already wired in `App.tsx`. + +- file: frontend/src/hooks/use-batches.ts + why: `useSubmitBatch`, `useBatch` (polls 2s while pending/running), `useBatchItems`. PRP-34 adds `useCancelBatch` mutation. + +- file: frontend/src/types/api.ts + why: Lines 336-360 — `BatchSubmitRequest` already has `max_parallel`; `BatchSubmitResponse` already has `running_items`/`cancelled_items`. PRP-34 adds `effective_max_parallel?: number` to the response type. + +- file: .claude/rules/shadcn-ui.md + why: ALL shadcn component work goes through the `shadcn` skill + `mcp__shadcn__*` MCP tools. The slider is added via `pnpm dlx shadcn@latest add slider` from `frontend/`, not by hand-writing the file. + +- url: https://docs.python.org/3.12/library/asyncio-task.html#asyncio.TaskGroup + section: TaskGroup — full signature and exception-group semantics + critical: TaskGroup's ONLY public surface is `create_task`. No `cancel_scope`, no `cancel()`, no `tasks`. Cancel by holding the Task references `create_task` returns and calling `task.cancel()` on each. + +- url: https://docs.python.org/3.12/library/asyncio-sync.html#asyncio.Semaphore + section: Semaphore — async context manager + critical: `async with sem:` acquires on entry, releases on exit (including the exception path). Wrap WORK, not task scheduling. + +- url: https://peps.python.org/pep-0654/ + section: PEP 654 — except* syntax + critical: Use `except* asyncio.CancelledError:` to catch the `ExceptionGroup` TaskGroup re-raises when its children were cancelled. Plain `except asyncio.CancelledError:` would not catch the group. + +- url: https://docs.sqlalchemy.org/en/20/core/pooling.html#sqlalchemy.pool.QueuePool + section: QueuePool — pool_size + max_overflow + timeout + critical: Default `pool_size=5`, `max_overflow=10`, `timeout=30s` on `create_async_engine`. Verified by direct probe (see ai_doc § "SQLAlchemy async pool defaults"). `batch_global_max_parallel ≤ 12` keeps headroom for the HTTP request + cancel endpoint + settle session. +``` + +### Current Codebase tree (relevant slices only) + +``` +app/ +├── core/ +│ ├── config.py # Settings model — add two fields here +│ ├── database.py # get_session_maker() — each runner child opens its own +│ ├── exceptions.py # NotFoundError, ConflictError, ValidationError +│ ├── logging.py # request_id_ctx ContextVar — inherited by TaskGroup children +│ ├── middleware.py # RequestIdMiddleware +│ └── problem_details.py # problem_response() helper +├── features/ +│ ├── batch/ +│ │ ├── __init__.py +│ │ ├── models.py # max_parallel, running_items, cancelled_items ALREADY HERE +│ │ ├── schemas.py # BatchSubmitRequest.max_parallel ALREADY HERE +│ │ ├── service.py # sequential picker loop — REPLACE with runner call +│ │ ├── routes.py # POST/GET endpoints — ADD DELETE here +│ │ └── tests/ +│ │ ├── conftest.py # BATCH-* seed fixtures, ASGITransport client +│ │ ├── test_routes_integration.py +│ │ └── test_service.py +│ ├── demo/ +│ │ ├── service.py # asyncio.Lock single-flight prior art +│ │ └── pipeline.py:418-420 # asyncio.gather anti-pattern (don't copy into batch) +│ └── jobs/ +│ └── service.py # JobService.create_job (delegation target) +└── main.py # routers wired here; no change needed for new DELETE +frontend/src/ +├── components/ui/ +│ ├── slider.tsx # ABSENT — add via shadcn MCP +│ └── (alert-dialog, button, card, progress, badge, table — ALL present) +├── hooks/ +│ └── use-batches.ts # ADD useCancelBatch here +├── pages/visualize/ +│ └── batch.tsx # 213-line placeholder — extend with slider+cancel +└── types/ + └── api.ts # add effective_max_parallel to BatchSubmitResponse +PRPs/ai_docs/ +└── asyncio-taskgroup-cancellation.md # NEW — captured for this PRP +``` + +### Desired Codebase tree (delta after PRP-34) + +``` +app/features/batch/ +├── runner.py # NEW — Semaphore + TaskGroup + CancelHandle registry +├── service.py # MODIFIED — submit() delegates to runner.run_batch() +├── routes.py # MODIFIED — add DELETE /batch/{batch_id} +├── schemas.py # MODIFIED — add effective_max_parallel to response +└── tests/ + ├── test_runner.py # NEW — unit + integration (8 cases) + ├── test_routes_cancel.py # NEW — DELETE endpoint (3 cases) + └── test_runner_chaos.py # NEW — orphan-state regression (2 cases) +app/core/ +├── config.py # MODIFIED — +batch_global_max_parallel, +batch_cancel_drain_timeout_seconds +├── exceptions.py # MODIFIED — +GatewayTimeoutError (504) and ERROR_TYPES key +└── problem_details.py # MODIFIED — +ERROR_TYPES["GATEWAY_TIMEOUT"] +frontend/src/ +├── components/ui/slider.tsx # NEW — added via shadcn MCP +├── hooks/use-batches.ts # MODIFIED — +useCancelBatch +├── pages/visualize/batch.tsx # MODIFIED — slider on submit, cancel button on progress card +└── types/api.ts # MODIFIED — +effective_max_parallel?: number on BatchSubmitResponse +.env.example # MODIFIED — +BATCH_GLOBAL_MAX_PARALLEL=4, +BATCH_CANCEL_DRAIN_TIMEOUT_SECONDS=30 +PRPs/ai_docs/asyncio-taskgroup-cancellation.md # NEW +``` + +### Known Gotchas of our codebase & Library Quirks + +```python +# CRITICAL: asyncio.TaskGroup has NO .cancel_scope on stdlib Python 3.12. +# The INITIAL's pseudocode references tg.cancel_scope.cancel() — that's anyio +# API, not stdlib. Cancel by holding Task refs and calling task.cancel() on each. +# +# Verify: uv run python -c "import asyncio; print(dir(asyncio.TaskGroup))" +# Expected: only 'create_task' as a public method. +# +# See PRPs/ai_docs/asyncio-taskgroup-cancellation.md for the verified pattern. + +# CRITICAL: catch the ExceptionGroup, not bare CancelledError. +# After a TaskGroup body cancels children, exceptions surface as an +# ExceptionGroup (PEP 654). Use `except* asyncio.CancelledError:` — plain +# `except asyncio.CancelledError:` will NOT catch the group. +# +# Verify: uv run python -c " +# import asyncio +# async def c(): await asyncio.sleep(10) +# async def m(): +# try: +# async with asyncio.TaskGroup() as tg: +# t = tg.create_task(c()); await asyncio.sleep(0.01); t.cancel() +# except* asyncio.CancelledError as eg: +# print('caught', len(eg.exceptions)) +# asyncio.run(m()) +# " +# Expected: 'caught 1' + +# CRITICAL: each child opens its OWN AsyncSession. +# The HTTP-request session is bound to the request's lifecycle; reusing it +# across N concurrent tasks corrupts identity-map state and serialises work. +# Pattern: `async with session_maker() as session:` inside the child, after +# acquiring the semaphore. Verified default pool: pool_size=5, max_overflow=10. +# +# Verify: uv run python -c " +# from sqlalchemy.ext.asyncio import create_async_engine +# e = create_async_engine('postgresql+asyncpg://x:x@h:5433/x') +# print(e.pool.size(), e.pool._max_overflow, e.pool._timeout) +# " +# Expected: 5 10 30.0 + +# CRITICAL: Semaphore wraps the work, not the task creation. +# Pattern that DEFEATS the cap: +# async with sem: # acquired in the runner, not in the child +# tg.create_task(child()) # tg.create_task is fast — sem releases instantly +# Correct pattern: +# async def child(item): +# async with sem: # acquired by the child, inside its own body +# ... # the actual work +# for item in items: +# tg.create_task(child(item)) + +# CRITICAL: ContextVar inheritance — request_id propagates AUTOMATICALLY. +# Tasks created with asyncio.create_task (and tg.create_task) inherit the +# current contextvars.Context (CPython 3.7+ documented). So +# `app.core.logging.request_id_ctx` flows from the POST handler into every +# TaskGroup child — no explicit `bind_contextvars` needed. The +# `batch.item_started`/`batch.item_completed` log lines auto-correlate to the +# parent request's X-Request-ID. + +# CRITICAL: do NOT iterate asyncio.all_tasks() to find children to cancel. +# The INITIAL's cancel_batch sketch scans all_tasks() and matches on +# task.get_name().startswith(f"batch:{batch_id}:"). Brittle: collisions across +# concurrent batches, cancels unrelated request handlers, breaks silently if +# `name=` is dropped in a refactor. Keep the Task references in the +# CancelHandle.tasks list when create_task returns them; cancel via that list. + +# GOTCHA: sklearn/LightGBM fits are SYNC C code — uncancellable mid-fit. +# A child that's already inside JobService.create_job's training call will +# NOT observe CancelledError until the fit returns. That's acceptable — the +# runner times out the drain via batch_cancel_drain_timeout_seconds (default +# 30s) and surfaces 504 if the operator wants to bail. Document in the +# DELETE route docstring + a tooltip on the frontend Cancel button. + +# GOTCHA: BatchService.submit currently runs the picker LOOP inside the same +# request handler — the response only returns after every item completes. +# Today this is a feature (the response is the settled parent). PRP-34 +# preserves it: the runner is awaited inside submit(), the response is still +# the settled parent. The DELETE endpoint is what gives operators a parallel +# control channel — it works because `runner.run_batch` registers a +# CancelHandle that's discoverable from any other request handler. + +# GOTCHA: integration-test cleanup keys on batch_id LIKE 'test%'. +# `app/features/batch/tests/conftest.py:db_session` deletes only batches +# whose ID starts with `test`. The MVP submit() generates uuid hex +# batch_ids — the conftest already comments on this. For runner tests that +# explicitly set batch_id, prefix it with `test`. For tests that go through +# the public submit endpoint, the cleanup also wipes data_platform rows +# created by the seed fixtures (`BATCH-%` codes), which transitively +# cascades to batch_job_item via FK ON DELETE CASCADE. + +# GOTCHA: shadcn MUST be driven through the MCP, not hand-written. +# `.claude/rules/shadcn-ui.md` is explicit: invoke the `shadcn` skill, use +# `mcp__shadcn__get_add_command_for_items` to get the exact `pnpm dlx` +# command, run it from frontend/, then audit. The slider primitive is one +# of shadcn's standard New York components — it lives at +# @/components/ui/slider after install. +``` + +## Implementation Blueprint + +### Data models and structure + +**No new ORM models.** The runner reads and writes existing columns on +`batch_job` (lines 136-139 of `app/features/batch/models.py`). No new schema +migration. + +**Two new Settings fields** (`app/core/config.py`): + +```python +# Batch Runner Concurrency (PRP-34) +batch_global_max_parallel: int = Field( + default=4, + ge=1, + le=64, + description=( + "Hard upper bound on concurrent in-flight batch_job_item executions " + "across all active batches on this host. Sized for the docker-compose " + "Postgres pool (pool_size=5, max_overflow=10). Effective per-batch " + "parallelism is min(batch_job.max_parallel, this). Env override: " + "BATCH_GLOBAL_MAX_PARALLEL=8 — requires uvicorn restart." + ), +) +batch_cancel_drain_timeout_seconds: int = Field( + default=30, + ge=1, + le=600, + description=( + "Max seconds DELETE /batch/{batch_id} waits for in-flight children " + "to settle before returning RFC 7807 504. In-flight sklearn/LightGBM " + "fits are uncancellable mid-call, so a long fit can stall the drain." + ), +) +``` + +**One new response field** (`app/features/batch/schemas.py:BatchSubmitResponse`): + +```python +effective_max_parallel: int = 0 # min(req.max_parallel, settings.batch_global_max_parallel) +``` + +(Default `0` for backward compatibility when reading old rows; the runner +always sets it. `from_attributes=True` is already on the model_config, so +`BatchSubmitResponse.model_validate(batch)` needs the field to materialise +from a runtime computation — see Task 4 below.) + +### Tasks (dependency-ordered) + +```yaml +Task 1 — Capture verified asyncio mechanics +CREATE PRPs/ai_docs/asyncio-taskgroup-cancellation.md: + - MIRROR pattern from: PRPs/ai_docs/exogenous-regressor-forecasting.md (single-topic deep-dive) + - CONTENT: TaskGroup API surface (only create_task), the three working cancel + mechanisms (per-task cancel + raise-inside + cooperative Event), the + INITIAL's broken cancel_scope claim, ContextVar inheritance proof, + SQLAlchemy pool default math, sklearn-fit-uncancellable note. Include + `uv run python -c "..."` verification commands for each claim. + - STATUS: this file is already authored — verify line count > 100 and that + each verification command actually runs. + +Task 2 — Settings + .env.example +MODIFY app/core/config.py: + - FIND pattern: "batch_max_scope_expansion: int = 1000" + - INJECT after line: two new Field(...) entries per § "Data models and structure" above. + - PRESERVE existing alphabetic-by-section ordering — these go in the "Batch runner" block. +MODIFY .env.example: + - FIND pattern: "BATCH_MAX_SCOPE_EXPANSION=1000" + - INJECT after line: + BATCH_GLOBAL_MAX_PARALLEL=4 + BATCH_CANCEL_DRAIN_TIMEOUT_SECONDS=30 + - PRESERVE block comment style above the new lines. + +Task 3 — Exception class for 504 drain timeout +MODIFY app/core/problem_details.py: + - FIND pattern: "ERROR_TYPES: dict[str, str]" (the canonical-codes dict) + - INJECT new key: "GATEWAY_TIMEOUT": "https://forecastlabai.dev/problems/gateway-timeout" +MODIFY app/core/exceptions.py: + - FIND pattern: "class UnprocessableEntityError(ForecastLabError):" + - INJECT after the full UnprocessableEntityError class: + class GatewayTimeoutError(ForecastLabError): + """Raised when a bounded drain (e.g., batch cancellation) exceeds its + configured budget. Surfaces RFC 7807 504. + + Distinct from a 408 client-timeout: the client didn't time out, the + server's own internal drain budget did. + """ + error_type_uri: str = ERROR_TYPES["GATEWAY_TIMEOUT"] + def __init__( + self, + message: str = "Operation drain exceeded budget", + details: dict[str, Any] | None = None, + ) -> None: + super().__init__( + message=message, + code="GATEWAY_TIMEOUT", + status_code=504, + details=details, + ) + +Task 4 — Add effective_max_parallel to BatchSubmitResponse +MODIFY app/features/batch/schemas.py: + - FIND pattern: "class BatchSubmitResponse(BaseModel):" + - In the body, INJECT after `cancelled_items: int`: + effective_max_parallel: int = Field( + default=0, + ge=0, + description=( + "min(max_parallel, settings.batch_global_max_parallel) actually applied " + "by the runner. 0 means 'not yet set' for legacy rows; the runner " + "always populates it on submit." + ), + ) + - PRESERVE ConfigDict(from_attributes=True). Add max_parallel mirror if not already present. + +Task 5 — Create the runner module +CREATE app/features/batch/runner.py: + - MIRROR pattern from: app/features/demo/service.py (module-level lock as registry) + - CONTENT: module-level `_ACTIVE_BATCHES: dict[str, CancelHandle]`, CancelHandle + dataclass holding the asyncio.Event + list[asyncio.Task], `run_batch()` + coroutine implementing Semaphore + TaskGroup + per-child fresh AsyncSession + + cooperative cancel + bounded drain, `cancel_batch()` setter for + DELETE-side. NO sibling-slice imports at module scope — lazy in-method only. + - GOTCHA: each child opens its own session via get_session_maker(); never + reuse the parent runner's session for child work. + - GOTCHA: keep Task refs in CancelHandle.tasks; never asyncio.all_tasks() + name-prefix scan. + +Task 6 — Rewire BatchService.submit through the runner +MODIFY app/features/batch/service.py: + - FIND pattern: the `while True: next_item = await self._pick_next(...)` block (lines ~152-157) + - REPLACE the loop with a single `await runner.run_batch(...)` call passing + the list of inserted items and the effective_max_parallel value. + - KEEP _pick_next and _execute_item on the class (test_picker_query_uses_skip_locked + still asserts the SQL; future PRPs may use the picker for multi-worker mode). + - GOTCHA: BatchService.submit still computes effective_parallel itself so it + can write it onto the parent record before the runner starts. The runner + re-computes it defensively (so a direct caller can't bypass). + - PRESERVE the existing logger.info("batch.created", ...) and the `_settle` call after the runner returns. + +Task 7 — DELETE /batch/{batch_id} endpoint +MODIFY app/features/batch/routes.py: + - FIND pattern: the last decorated route in the file (`list_batch_items`). + - APPEND a new route: + @router.delete("/{batch_id}", response_model=BatchSubmitResponse, ...) + async def cancel_batch(batch_id: str, db: AsyncSession = Depends(get_db)) -> BatchSubmitResponse: ... + - LOGIC: + 1. Service.get(batch_id) → None ⇒ raise NotFoundError. + 2. If parent.status in {completed, failed, partial, cancelled} ⇒ ConflictError. + 3. Call runner.cancel_batch(batch_id) — returns True iff registered. + - Returns False ⇒ Conflict (already settled or never registered). + 4. Await drain with asyncio.wait_for(handle.completed_event.wait(), timeout=settings.batch_cancel_drain_timeout_seconds). + - TimeoutError ⇒ raise GatewayTimeoutError(message=f"Drain exceeded {timeout}s; parent settle pending.") + 5. Re-load the parent (post-settle) and return BatchSubmitResponse. + - LOG events: batch.cancel_requested, batch.cancel_drain_timeout, batch.cancelled. + +Task 8 — Unit tests +CREATE app/features/batch/tests/test_runner.py: + - SCAFFOLD per app/features/batch/tests/test_service.py — async def + pytest-asyncio auto-mode. + - TESTS: + test_semaphore_caps_concurrency + - 5 fake child coroutines each await asyncio.sleep(0.05) inside the runner with max_parallel=2 + - Use a shared list to record start/finish events; observed concurrent peak == 2. + - THIS IS THE LOAD-BEARING REGRESSION TEST FOR UNBOUNDED-FAN-OUT. + test_settings_global_cap_clamps_max_parallel + - max_parallel=32, settings.batch_global_max_parallel=4 ⇒ peak ≤ 4. + test_child_failure_does_not_abort_siblings + - One of 5 children raises RuntimeError; other 4 reach completion. + - The runner does NOT propagate the failure to the TaskGroup; each + child's _execute body catches Exception and writes status=failed. + test_cancel_pending_child_marks_cancelled_without_running + - max_parallel=1, 3 items. After first starts, cancel event fires. + - Assert items 2 and 3 transition pending → cancelled, never opened a session. + test_cancel_running_child_propagates_cancelled_error + - One child sleeps 1s; cancel after 0.05s. Child observes CancelledError, finally block writes cancelled. + +Task 9 — Cancel-endpoint route tests +CREATE app/features/batch/tests/test_routes_cancel.py: + - SCAFFOLD per app/features/batch/tests/test_routes_integration.py — ASGITransport client. + - TESTS: + test_delete_404_unknown_batch — RFC 7807 404; problem+json content-type. + test_delete_409_terminal_batch — submit + wait for settle, then DELETE → 409. + test_delete_504_drain_timeout — patch Settings(batch_cancel_drain_timeout_seconds=0), DELETE returns 504 problem+json. + +Task 10 — Chaos tests (integration) +CREATE app/features/batch/tests/test_runner_chaos.py: + - SCAFFOLD per app/features/batch/tests/test_routes_integration.py (pytestmark = pytest.mark.integration). + - TESTS: + test_cancel_mid_flight_does_not_orphan_running_items + - Submit a 4-item batch with max_parallel=2 + fake-slow children, cancel mid-run. + - SELECT COUNT(*) FROM batch_job_item WHERE batch_id=? AND status='running' → 0. + test_parent_status_progresses_as_children_complete + - 6 items, max_parallel=2 — sample batch_job.running_items at intervals; assert running_items ≤ 2 throughout. + - Final state: status='completed', completed_items=6. + +Task 11 — Wire frontend slider + cancel UX +MODIFY frontend (driven by the shadcn skill — invoke `Skill: shadcn-ui`): + 1. Add the slider primitive: `pnpm dlx shadcn@latest add slider` from frontend/. + 2. Verify file exists: frontend/src/components/ui/slider.tsx. +MODIFY frontend/src/types/api.ts: + - FIND pattern: "export interface BatchSubmitResponse {" + - INJECT after `cancelled_items: number`: + effective_max_parallel?: number +MODIFY frontend/src/hooks/use-batches.ts: + - APPEND a new hook useCancelBatch: + export function useCancelBatch() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (batchId: string) => + api(`/batch/${batchId}`, { method: 'DELETE' }), + onSuccess: (data) => { + queryClient.setQueryData(['batch', data.batch_id], data) + void queryClient.invalidateQueries({ queryKey: ['batch'] }) + }, + }) + } +MODIFY frontend/src/pages/visualize/batch.tsx: + - Submit form gains a max_parallel Slider (min=1, max=8 (default) capped by a future server-reported global, step=1, default=4). + - Tooltip on the slider explains the runtime clamp ("Effective parallelism = min(this, server global cap)"). + - Progress card adds a `running_items` chip via the existing Badge primitive. + - Add a "Cancel batch" Button that opens AlertDialog (mirror frontend/src/pages/explorer/job-detail.tsx:50 useCancelJob pattern). + - The cancel button is disabled when `status ∈ {completed, failed, partial, cancelled}`. + +Task 12 — Frontend tests +MODIFY frontend (tests live colocated next to source): + - ADD frontend/src/hooks/use-batches.test.ts (mirror frontend/src/hooks/use-demo-pipeline.test.ts) — assert useCancelBatch issues a DELETE and invalidates the right cache keys. + - SCOPE: don't add a full Playwright test — visual verification happens via the webapp-testing skill per .claude/rules/ui-design.md. + +Task 13 — Validation gates +RUN locally (matches .github/workflows/ci.yml expectations): + uv run ruff check . && uv run ruff format --check . + uv run mypy app/ && uv run pyright app/ + uv run pytest -v -m "not integration" app/features/batch/ + docker compose up -d + uv run alembic upgrade head # MUST be a no-op (no new migration) + uv run pytest -v -m integration app/features/batch/ + cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run +``` + +### Per-task pseudocode (high-information-density) + +```python +# ---------------------------------------------------------------- Task 5: runner.py + +# NOTE: lazy in-method cross-slice imports break the alembic cold-boot cycle — +# matches the forecasting/service.py:786-787 precedent. + +from __future__ import annotations +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field + +from app.core.config import get_settings +from app.core.database import get_session_maker +from app.core.logging import get_logger + +logger = get_logger(__name__) + +# Module-level registry — single-process scope (matches single-host vision). +# A future ADR would move this to Redis if multi-process arrives. +_ACTIVE_BATCHES: dict[str, "CancelHandle"] = {} + + +@dataclass +class CancelHandle: + """Cancel signal + task refs for an in-flight batch. Created by run_batch, + looked up by cancel_batch, removed in the run_batch finally.""" + cancel_event: asyncio.Event = field(default_factory=asyncio.Event) + completed_event: asyncio.Event = field(default_factory=asyncio.Event) + tasks: list[asyncio.Task[None]] = field(default_factory=list) + + +async def run_batch( + batch_id: str, + item_ids: list[str], # item_ids of pending children + max_parallel: int, + execute_item: Callable[[str], Awaitable[None]], # one-arg coroutine: itemize, run, settle row +) -> int: + """Execute one batch through a bounded TaskGroup. Returns effective_parallel. + + `execute_item` is the per-item closure (passed in from BatchService so the + runner stays decoupled from JobService). It MUST open its own AsyncSession, + write final per-row status, and emit lifecycle log events. + """ + settings = get_settings() + effective = min(max_parallel, settings.batch_global_max_parallel) + sem = asyncio.Semaphore(effective) + handle = CancelHandle() + _ACTIVE_BATCHES[batch_id] = handle + + logger.info("batch.runner_start", batch_id=batch_id, + total_items=len(item_ids), max_parallel=max_parallel, + effective_max_parallel=effective) + + async def _child(item_id: str) -> None: + # FAST-CANCEL BEFORE acquire — skips not-yet-started work cleanly. + if handle.cancel_event.is_set(): + await _mark_cancelled_skipped(item_id) + return + async with sem: + if handle.cancel_event.is_set(): + await _mark_cancelled_skipped(item_id) + return + await _bump_running(batch_id, +1) + try: + await execute_item(item_id) # may raise; OK + except asyncio.CancelledError: + # Cooperative drain — child observed cancel mid-run. + await _mark_cancelled_running(item_id) + raise # let TaskGroup see it + finally: + await _bump_running(batch_id, -1) + + try: + async with asyncio.TaskGroup() as tg: + for iid in item_ids: + t = tg.create_task(_child(iid), name=f"batch:{batch_id}:{iid}") + handle.tasks.append(t) + except* asyncio.CancelledError: + # PEP 654 — caught the group of cancelled children. The runner's + # parent-state settle is the caller's responsibility (BatchService._settle + # already aggregates by status — works for cancelled the same way). + pass + finally: + handle.completed_event.set() + _ACTIVE_BATCHES.pop(batch_id, None) + + return effective + + +def cancel_batch(batch_id: str) -> bool: + """Signal cancel for an in-flight batch. Returns False if not registered.""" + handle = _ACTIVE_BATCHES.get(batch_id) + if handle is None: + return False + handle.cancel_event.set() + for t in handle.tasks: + if not t.done(): + t.cancel() + logger.info("batch.cancel_requested", batch_id=batch_id, n_tasks=len(handle.tasks)) + return True + + +async def await_drain(batch_id: str, timeout_seconds: float) -> bool: + """Block until the runner's completed_event fires or timeout elapses. + + Returns True on clean drain, False on timeout. Returns True immediately + if the batch is no longer registered (race-free). + """ + handle = _ACTIVE_BATCHES.get(batch_id) + if handle is None: + return True + try: + await asyncio.wait_for(handle.completed_event.wait(), timeout=timeout_seconds) + return True + except TimeoutError: + return False + + +# Helpers _bump_running, _mark_cancelled_skipped, _mark_cancelled_running each +# open a fresh AsyncSession via get_session_maker() and commit a single UPDATE. +# They do NOT call BatchService — that would close a cycle. Implemented inline +# with raw SQLAlchemy update() statements scoped to the relevant batch_job/ +# batch_job_item row. + + +# ---------------------------------------------------------------- Task 6: service.py rewire + +# In BatchService.submit, after computing `triples` and inserting parent + N children: + +# Build a one-arg closure that BatchService passes to the runner. The closure +# wraps the existing _execute_item logic so the lazy JobService import stays +# on the BatchService side (not the runner), preserving the runner's +# zero-cross-slice-import invariant. +async def _exec_one(item_id: str) -> None: + async with session_maker() as session: + item = (await session.execute( + select(BatchJobItem).where(BatchJobItem.item_id == item_id) + )).scalar_one() + await self._execute_item(session, item) + +effective = await runner.run_batch( + batch_id=batch.batch_id, + item_ids=[i.item_id for i in inserted_items], + max_parallel=batch.max_parallel, + execute_item=_exec_one, +) +batch.effective_max_parallel_runtime = effective # ← if we make this a real attribute +# Alternatively (preferred): write effective into batch.result_summary['effective_max_parallel'] +# and resolve it in the Pydantic .model_validate path so we don't need a column. + + +# ---------------------------------------------------------------- Task 7: DELETE route + +@router.delete("/{batch_id}", response_model=BatchSubmitResponse, ...) +async def cancel_batch(batch_id: str, db: AsyncSession = Depends(get_db)) -> BatchSubmitResponse: + settings = get_settings() + service = BatchService() + parent = await service.get(db, batch_id) + if parent is None: + raise NotFoundError(message=f"Batch not found: {batch_id}", details={"batch_id": batch_id}) + if parent.status in {BatchStatus.COMPLETED, BatchStatus.FAILED, + BatchStatus.PARTIAL, BatchStatus.CANCELLED}: + raise ConflictError(message=f"Batch already terminal: {parent.status.value}", + details={"batch_id": batch_id, "status": parent.status.value}) + + if not runner.cancel_batch(batch_id): + # Race: settled between get() and cancel(). Treat as 409 (already done). + raise ConflictError(message="Batch settled before cancel could fire", + details={"batch_id": batch_id}) + + drained = await runner.await_drain(batch_id, settings.batch_cancel_drain_timeout_seconds) + if not drained: + raise GatewayTimeoutError( + message=f"Drain exceeded {settings.batch_cancel_drain_timeout_seconds}s", + details={"batch_id": batch_id}) + + # Re-load post-settle parent and return. + final = await service.get(db, batch_id) + assert final is not None # the parent row never deletes + return final +``` + +### Integration Points + +```yaml +CONFIG: + - add to: app/core/config.py (Settings model) + - pattern: "batch_global_max_parallel: int = Field(default=4, ge=1, le=64, ...)" + - pattern: "batch_cancel_drain_timeout_seconds: int = Field(default=30, ge=1, le=600, ...)" + - env file: .env.example gains BATCH_GLOBAL_MAX_PARALLEL=4 and BATCH_CANCEL_DRAIN_TIMEOUT_SECONDS=30 + +ERROR TAXONOMY: + - add to: app/core/problem_details.py ERROR_TYPES dict + - key: "GATEWAY_TIMEOUT" → "https://forecastlabai.dev/problems/gateway-timeout" + - add to: app/core/exceptions.py + - pattern: class GatewayTimeoutError(ForecastLabError): status_code=504, code="GATEWAY_TIMEOUT" + +ROUTES: + - add to: app/features/batch/routes.py + - pattern: '@router.delete("/{batch_id}", response_model=BatchSubmitResponse)' + - wiring: existing batch_router is already in app/main.py:142 — no main.py change. + +FRONTEND TYPES: + - add to: frontend/src/types/api.ts BatchSubmitResponse + - field: effective_max_parallel?: number + +FRONTEND HOOKS: + - add to: frontend/src/hooks/use-batches.ts + - export: useCancelBatch (mutation; on success: setQueryData + invalidateQueries) + +FRONTEND UI (driven by shadcn skill — do not hand-write): + - shadcn add: slider (creates frontend/src/components/ui/slider.tsx) + - modify: frontend/src/pages/visualize/batch.tsx (slider on submit form, AlertDialog cancel on progress card) + +NO DATABASE MIGRATION: + - The three columns (max_parallel, running_items, cancelled_items) already + exist per PRP-33's alembic/versions/c1d2e3f40512_create_batch_tables.py. + - `uv run alembic upgrade head` MUST be a no-op after this PR merges. + - `uv run alembic check` MUST detect no schema drift. + +LOGGING: + - new structlog events (request_id propagates via ContextVar inheritance): + batch.runner_start, batch.runner_complete, + batch.cancel_requested, batch.cancel_drained, batch.cancel_drain_timeout, + batch.item_cancelled (in addition to existing batch.item_started/completed/failed) + - per .claude/rules/security-patterns.md: log IDs + counts, NEVER full payloads. +``` + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +# Run FIRST — fix any errors before proceeding. +uv run ruff check . --fix +uv run ruff format . +uv run mypy app/ +uv run pyright app/ +# Expected: zero errors. The strict-mode policy linter (test_strict_mode_policy.py) +# does NOT need changes — no new date/datetime/UUID/Decimal fields are added. +``` + +### Level 2: Unit Tests + +```bash +uv run pytest -v -m "not integration" app/features/batch/ + +# Expected new pass list: +# app/features/batch/tests/test_runner.py::test_semaphore_caps_concurrency PASSED +# app/features/batch/tests/test_runner.py::test_settings_global_cap_clamps_max_parallel PASSED +# app/features/batch/tests/test_runner.py::test_child_failure_does_not_abort_siblings PASSED +# app/features/batch/tests/test_runner.py::test_cancel_pending_child_marks_cancelled_without_running PASSED +# app/features/batch/tests/test_runner.py::test_cancel_running_child_propagates_cancelled_error PASSED +# app/features/batch/tests/test_routes_cancel.py::test_delete_404_unknown_batch PASSED +# app/features/batch/tests/test_routes_cancel.py::test_delete_409_terminal_batch PASSED +# app/features/batch/tests/test_routes_cancel.py::test_delete_504_drain_timeout PASSED +# Plus EVERY existing batch unit test still passes (test_metrics_jsonb_shape_pinned, +# test_picker_query_uses_skip_locked, test_expand_scope_manual_cartesian, ...). +``` + +### Level 3: Integration Tests + +```bash +docker compose up -d +uv run alembic upgrade head # must be a no-op — verify with: +uv run alembic check # expected: "No new upgrade operations detected." + +uv run pytest -v -m integration app/features/batch/ + +# Expected new pass list: +# test_runner.py::test_parent_status_progresses_as_children_complete PASSED +# test_runner.py::test_db_connection_pool_not_exhausted PASSED +# test_routes_cancel.py::test_delete_cancels_in_flight_children_against_real_db PASSED +# test_runner_chaos.py::test_cancel_mid_flight_does_not_orphan_running_items PASSED +# test_runner_chaos.py::test_cancel_during_db_commit_keeps_invariants PASSED +# Plus EVERY existing batch integration test still passes (test_submit_batch_happy_path, +# test_submit_batch_partial_failure, test_scope_over_cap_returns_422, +# test_get_items_sort_by_allow_list, test_get_batch_404, +# test_migration_partial_index_present, test_service_emits_lifecycle_events). +``` + +### Level 4: End-to-end smoke + +```bash +# Start the backend + UI. +uv run uvicorn app.main:app --reload --port 8123 & +cd frontend && ./node_modules/.bin/vite --host 0.0.0.0 & + +# Submit a small batch with max_parallel=3. +curl -s -X POST http://localhost:8123/batch/forecasting \ + -H "Content-Type: application/json" \ + -d '{ + "operation": "backtest", + "scope": {"kind": "manual", "store_ids": [1], "product_ids": [1,2,3]}, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + "max_parallel": 3 + }' | jq '{batch_id, status, total_items, completed_items, running_items, effective_max_parallel}' +# Expected: status="completed", completed_items=3, effective_max_parallel=3 + +# Test cancel against a longer-running batch — open a 20-item batch in one +# terminal, DELETE from another; expect 200 with status="cancelled". + +# UI dogfood per .claude/rules/ui-design.md: +# Use the webapp-testing skill to drive /visualize/batch in a browser, +# move the slider, submit, watch running_items chip update, click cancel. +``` + +### Level 5: Frontend gates + +```bash +cd frontend +pnpm tsc --noEmit # must be clean +pnpm lint # must be clean +pnpm test --run # vitest — must pass including new use-batches.test.ts +``` + +## Final Validation Checklist + +- [ ] `grep -rn "asyncio.gather" app/features/batch/` returns no production-code line. +- [ ] `grep -rn "tg.cancel_scope" app/features/batch/` returns no match (the INITIAL's broken hint). +- [ ] `grep -rn "asyncio.all_tasks" app/features/batch/runner.py` returns no match (broken cancel mechanism). +- [ ] `grep -rn "from app.features.jobs" app/features/batch/runner.py` returns NO match (cross-slice imports stay lazy and in BatchService, not in the runner). +- [ ] `uv run alembic upgrade head` is a no-op after merge; `uv run alembic check` reports no drift. +- [ ] `Settings.batch_global_max_parallel` and `Settings.batch_cancel_drain_timeout_seconds` are documented in `.env.example`. +- [ ] `BatchSubmitResponse.effective_max_parallel` is non-zero on every freshly-submitted batch. +- [ ] `frontend/src/components/ui/slider.tsx` exists and was added via the shadcn MCP (not hand-written). +- [ ] `frontend/src/pages/visualize/batch.tsx` renders the slider + cancel button — verified visually via the webapp-testing skill. +- [ ] CHANGELOG.md gains a release-please-eligible `feat(batch):` entry (the merge commit subject drives the next pre-1.0 PATCH bump). +- [ ] No new dependency in `pyproject.toml` (single-host vision intact). +- [ ] All five validation-gate commands listed above pass locally. + +--- + +## Anti-Patterns to Avoid + +- ❌ `asyncio.gather(*tasks)` for fanning out child work. The point of this PRP is to make it impossible. +- ❌ `tg.cancel_scope.cancel()` — that attribute does not exist on stdlib `asyncio.TaskGroup`; it's an anyio API. +- ❌ Plain `except asyncio.CancelledError:` after `async with asyncio.TaskGroup():` — TaskGroup wraps in an ExceptionGroup; use `except*`. +- ❌ Iterating `asyncio.all_tasks()` to find children to cancel — keep Task refs in `CancelHandle.tasks`. +- ❌ Reusing the request's `AsyncSession` across children — open a fresh session per child via `get_session_maker()`. +- ❌ A new Alembic migration to add `max_parallel`/`running_items`/`cancelled_items` — they already exist (PRP-33). +- ❌ Hand-writing `slider.tsx` — drive it through `pnpm dlx shadcn@latest add slider` per `.claude/rules/shadcn-ui.md`. +- ❌ Importing a sibling slice's service at module scope inside `runner.py` — every cross-slice call stays lazy + in-method. +- ❌ Logging full request/response payloads in `batch.runner_*` events — IDs and counts only, per `.claude/rules/security-patterns.md`. +- ❌ Pushing parallelism above `batch_global_max_parallel + a few headroom slots` without also bumping `app/core/database.py:get_engine()` pool sizing — they go together. +- ❌ Adding a Celery / Redis / Arq queue — `product-vision.md` forbids it ("Not cloud-locked", "single-host deployable"). +- ❌ A new "wipe everything" path on batches — `product-vision.md` § "Not a destructive tool". +- ❌ Adding an `AGENT_REQUIRE_APPROVAL` entry for batch cancellation — cancel is not a state-mutating agent tool here; the surface remains operator-driven over HTTP. + +--- + +## Pre-flight + +Before kicking off implementation: + +1. Open a tracking issue with title `feat(batch): activate max_parallel + cooperative cancellation (PRP-34)` referencing `PRPs/INITIAL/INITIAL-batch-parallel-execution.md` and `PRPs/PRP-34-batch-parallel-execution.md`. Branch will be `feat/batch-parallel-execution` off `dev` (per `.claude/rules/branch-naming.md`). +2. Confirm the verified asyncio behaviour by running the snippets at the top of `PRPs/ai_docs/asyncio-taskgroup-cancellation.md` — if Python or SQLAlchemy has been upgraded since this PRP was written, refresh the doc's claims first. +3. Audit `git log -- app/features/batch/` since PR #281 — verify no subsequent PR moved any of the forward-compat columns or weakened the picker invariant. + +--- + +## Confidence Score + +**8 / 10** for one-pass implementation success. + +The 2-point deduction: + +- (-1) Concurrent integration tests against a single `docker-compose` Postgres can race on the seed-fixture rows; the chaos test in particular is sensitive to scheduler ordering. Likely 1-2 test iterations to stabilise polling timeouts and event-set checkpoints. +- (-1) The frontend Slider integration depends on a clean run of the shadcn MCP install pipeline from the executor's environment. If the MCP isn't authenticated or the registry is mismatched, the fallback is `pnpm dlx shadcn@latest add slider` which works but requires manual `components.json` checks. + +The rest is bounded by the precedent in PRP-33 + the verified asyncio doc + the existing batch slice's test patterns. diff --git a/PRPs/ai_docs/asyncio-taskgroup-cancellation.md b/PRPs/ai_docs/asyncio-taskgroup-cancellation.md new file mode 100644 index 00000000..fecdd8f2 --- /dev/null +++ b/PRPs/ai_docs/asyncio-taskgroup-cancellation.md @@ -0,0 +1,209 @@ +# asyncio.TaskGroup + Semaphore + Cooperative Cancellation (Python 3.12) + +> Captured for PRP-34. Verified against Python 3.12.13 (the project's pinned +> interpreter, per `pyproject.toml requires-python = ">=3.12"`). + +## What the stdlib actually offers + +`asyncio.TaskGroup` is the Python 3.11+ structured-concurrency primitive +([docs](https://docs.python.org/3.12/library/asyncio-task.html#asyncio.TaskGroup)). +The **only** public surface is `create_task(coro, *, name=None, context=None)`. +There is no `cancel_scope`, no `cancel()`, no `tasks` property — verified: + +```bash +uv run python -c "import asyncio; print([m for m in dir(asyncio.TaskGroup) if not m.startswith('_')])" +# → ['create_task'] +``` + +The INITIAL's pseudocode calls `tg.cancel_scope.cancel()`. **That attribute does +not exist** — that's anyio API, not stdlib asyncio. Don't import anyio for +this; instead, retain the tasks `create_task` returns and cancel them. + +## How to cancel TaskGroup-managed children + +Three mutually compatible mechanisms work on stdlib `TaskGroup`: + +1. **Per-task cancel.** `tg.create_task(coro)` returns a `Task`; keep a list + and call `task.cancel()` on each. The TaskGroup awaits the `CancelledError`s + and re-raises them inside an `ExceptionGroup`, caught with `except*`. + +2. **Raise inside the `async with`.** Any exception raised from the body of + `async with asyncio.TaskGroup() as tg:` cancels all currently running + children, then re-raises as an `ExceptionGroup`. Use this when the + *runner itself* needs to abort (e.g., total drain timeout). + +3. **Cooperative event check.** Each child polls an `asyncio.Event` at safe + yield points (before semaphore acquire, after acquire, in a `finally` + block) and returns early. Use this to skip work that hasn't started. + +Combine #1 (cancel running children) + #3 (skip not-yet-started children). +Verified working pattern: + +```python +import asyncio + +async def child(i: int, sem: asyncio.Semaphore, cancel: asyncio.Event) -> None: + if cancel.is_set(): + return # fast-skip BEFORE semaphore + async with sem: + if cancel.is_set(): + return # fast-skip AFTER semaphore + try: + await asyncio.sleep(10) # the actual work + except asyncio.CancelledError: + # finalise child state here (write 'cancelled' to DB), then re-raise + raise + +async def main() -> None: + sem = asyncio.Semaphore(2) + cancel = asyncio.Event() + try: + async with asyncio.TaskGroup() as tg: + tasks = [ + tg.create_task(child(i, sem, cancel), name=f"batch:b1:item{i}") + for i in range(4) + ] + await asyncio.sleep(0.05) + cancel.set() + for t in tasks: + t.cancel() + except* asyncio.CancelledError: # PEP 654 except* + pass # observed — drain complete +``` + +**Verification command** (re-runnable on library upgrade): + +```bash +uv run python -c " +import asyncio +async def child(i, sem, cancel): + if cancel.is_set(): return + async with sem: + if cancel.is_set(): return + try: await asyncio.sleep(10) + except asyncio.CancelledError: + print(f'child {i} cancelled cleanly'); raise +async def main(): + sem, cancel = asyncio.Semaphore(2), asyncio.Event() + try: + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(child(i, sem, cancel), name=f'item{i}') for i in range(4)] + await asyncio.sleep(0.05) + cancel.set() + for t in tasks: t.cancel() + except* asyncio.CancelledError: print('drained') +asyncio.run(main()) +" +# → 'child 0 cancelled cleanly', 'child 1 cancelled cleanly', 'drained' +# Children 2 & 3 never acquired the semaphore so they never printed. +``` + +## Why NOT `asyncio.all_tasks()` + name-prefix scan + +The INITIAL's `cancel_batch` sketch iterates `asyncio.all_tasks()` and matches +on `task.get_name().startswith(f"batch:{batch_id}:")`. This is brittle: + +- `get_name()` defaults to `Task-N` when `name=` is not passed; any + refactor that drops `name=` silently breaks the cancel path. +- It scans the entire event loop's task set, which on a uvicorn process + includes other request handlers, WebSocket pumps, and background tasks. +- Two concurrent batches with name collisions would cancel each other. + +**Correct pattern:** store the `Task` references in the `CancelHandle` when +they're created, cancel via that list. Never reach into `asyncio.all_tasks()`. + +## Semaphore composition with AsyncSession + +`asyncio.Semaphore` is a context manager — `async with sem:` acquires on +entry, releases on exit (including the exception path). The Semaphore must +wrap the **work** (everything that opens a DB session and runs the model), +not the `tg.create_task(...)` call: + +```python +# WRONG — semaphore around scheduling, not work. Defeats the cap. +async with sem: + tg.create_task(child()) + +# RIGHT — child acquires the semaphore inside its own body. +async def child(): + async with sem: + async with session_maker() as session: + await do_work(session) + +for item in items: + tg.create_task(child()) +``` + +Each child opens its OWN `AsyncSession` via `get_session_maker()` inside the +semaphore-acquired block — never shares the runner's session. This is the +only way the SQLAlchemy connection pool survives `max_parallel > 1`. + +## ContextVar inheritance into TaskGroup children + +The project binds the per-request `X-Request-ID` to a `ContextVar` +(`app/core/logging.py:13` — `request_id_ctx: ContextVar[str | None]`). This +is consumed by structlog via the `add_request_id` processor. + +**`asyncio.Task` created with `create_task()` inherits the current +`contextvars.Context`** (CPython implementation detail since 3.7; documented +behaviour). That means every `batch.item_*` log line emitted from a TaskGroup +child gets the same `request_id` as the parent `POST /batch/forecasting` +request automatically. No explicit `structlog.contextvars.bind_contextvars` +plumbing needed inside the runner. + +Verification: + +```bash +uv run python -c " +import asyncio, contextvars +v = contextvars.ContextVar('v', default=None) +async def child(): print('child sees:', v.get()) +async def main(): + v.set('outer') + async with asyncio.TaskGroup() as tg: + tg.create_task(child()) +asyncio.run(main()) +" +# → child sees: outer +``` + +## SQLAlchemy async pool defaults that bound the design + +```bash +uv run python -c " +from sqlalchemy.ext.asyncio import create_async_engine +e = create_async_engine('postgresql+asyncpg://x:x@h:5433/x') +print('size:', e.pool.size(), 'overflow:', e.pool._max_overflow, 'timeout:', e.pool._timeout) +" +# → size: 5 overflow: 10 timeout: 30.0 +``` + +So the upper bound on parallel children that won't trigger +`sqlalchemy.exc.TimeoutError` (the QueuePool exhaustion symptom) is +`pool_size + max_overflow - reserved_for_other_requests`. With one +in-flight `POST /batch/forecasting` request (1 session) + one `cancel` +endpoint (potentially 1 session) + the runner's settle-parent session (1), +`batch_global_max_parallel ≤ 12` is the safe headroom. The PRP's +`Settings.batch_global_max_parallel = 4` default is well inside this. + +If a future iteration raises the global cap above 12, bump the engine's +`pool_size`/`max_overflow` in `app/core/database.py:get_engine()` in the +same PR — these go together. + +## sklearn / LightGBM ignore CancelledError mid-fit + +The model-training call inside `JobService.create_job` lands in a sync C +extension (`sklearn.ensemble`/`lightgbm.Booster.update`). `asyncio` cannot +preempt sync C code. Cancellation semantics: + +- **Pending children** (haven't reached `sem.acquire()`): observe the + cancel event, transition to `cancelled`, no work executed. +- **In-flight children** (mid-fit): the cancel signal is **deferred** until + the C call returns. The `try/finally` block then writes `cancelled` if + the task observed `CancelledError`, or `completed`/`failed` if the work + finished first. **Either is acceptable** — the parent's status reflects + the actual per-child outcomes; the invariant is "no row left in + `running` state after settle". + +This is why the PRP needs `Settings.batch_cancel_drain_timeout_seconds` +(default 30s) and a 504 surface — an in-flight fit can stall the drain. diff --git a/app/core/config.py b/app/core/config.py index 2ac65061..27614159 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -121,6 +121,18 @@ class Settings(BaseSettings): # Batch runner (PRP-33) — cap on scope expansion (pairs x model_configs). batch_max_scope_expansion: int = 1000 + # Batch runner concurrency (PRP-34) — hard upper bound on concurrent + # in-flight batch_job_item executions across all active batches on this + # host. Sized for the docker-compose Postgres pool (pool_size=5, + # max_overflow=10). Effective per-batch parallelism is + # min(batch_job.max_parallel, this). Env override: + # BATCH_GLOBAL_MAX_PARALLEL=8 — requires uvicorn restart. + batch_global_max_parallel: int = 4 + # Max seconds DELETE /batch/{batch_id} waits for in-flight children to + # settle before returning RFC 7807 504. In-flight sklearn/LightGBM fits + # are uncancellable mid-call, so a long fit can stall the drain. + batch_cancel_drain_timeout_seconds: int = 30 + # RAG Embedding Configuration rag_embedding_provider: Literal["openai", "ollama"] = "openai" openai_api_key: str = "" diff --git a/app/core/exceptions.py b/app/core/exceptions.py index ef1561fb..fd5e2b36 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -199,6 +199,30 @@ def __init__( ) +class GatewayTimeoutError(ForecastLabError): + """504 — server's own internal drain or upstream wait exceeded its budget. + + Use when a bounded server-side wait (e.g., ``DELETE /batch/{batch_id}`` + draining in-flight children) exceeds its configured budget. Distinct from + a 408 client-timeout: the client did not time out, the server's own + drain budget did. The PRP-34 batch cancel path is the canonical caller. + """ + + error_type_uri: str = ERROR_TYPES["GATEWAY_TIMEOUT"] + + def __init__( + self, + message: str = "Operation drain exceeded budget", + details: dict[str, Any] | None = None, + ) -> None: + super().__init__( + message=message, + code="GATEWAY_TIMEOUT", + status_code=504, + details=details, + ) + + # ============================================================================= # Exception Handlers (RFC 7807) # ============================================================================= diff --git a/app/core/problem_details.py b/app/core/problem_details.py index 2a24ecb7..7de6e462 100644 --- a/app/core/problem_details.py +++ b/app/core/problem_details.py @@ -35,6 +35,7 @@ "INTERNAL_ERROR": f"{ERROR_TYPE_BASE}/internal", "BAD_REQUEST": f"{ERROR_TYPE_BASE}/bad-request", "SERVICE_UNAVAILABLE": f"{ERROR_TYPE_BASE}/service-unavailable", + "GATEWAY_TIMEOUT": f"{ERROR_TYPE_BASE}/gateway-timeout", } diff --git a/app/features/batch/models.py b/app/features/batch/models.py index fdbc9b1a..a5f47649 100644 --- a/app/features/batch/models.py +++ b/app/features/batch/models.py @@ -103,13 +103,30 @@ class BatchItemStatus(str, Enum): VALID_BATCH_ITEM_TRANSITIONS: dict[BatchItemStatus, set[BatchItemStatus]] = { BatchItemStatus.PENDING: {BatchItemStatus.RUNNING, BatchItemStatus.CANCELLED}, - BatchItemStatus.RUNNING: {BatchItemStatus.COMPLETED, BatchItemStatus.FAILED}, + # PRP-34 added ``RUNNING → CANCELLED`` so a child that observes + # ``CancelledError`` mid-flight can write its terminal state truthfully — + # the PRP-33 MVP runner never wrote ``CANCELLED`` so the path was unused. + BatchItemStatus.RUNNING: { + BatchItemStatus.COMPLETED, + BatchItemStatus.FAILED, + BatchItemStatus.CANCELLED, + }, BatchItemStatus.COMPLETED: set(), BatchItemStatus.FAILED: set(), BatchItemStatus.CANCELLED: set(), } +# Derived from the state machine: a status with no out-edges is terminal. The +# PRP-34 ``DELETE /batch/{batch_id}`` route reads this constant — keeping the +# definition next to the dict guarantees any future state-machine edit (e.g., +# a re-open path) updates the cancel surface at the same time. +TERMINAL_BATCH_STATES: frozenset[BatchStatus] = frozenset( + status for status, out_edges in VALID_BATCH_TRANSITIONS.items() if not out_edges +) +"""Statuses a parent batch cannot transition out of. The cancel-route 409 set.""" + + class BatchJob(TimestampMixin, Base): """Parent batch record — one row per submission. diff --git a/app/features/batch/routes.py b/app/features/batch/routes.py index aeee169d..5ba11963 100644 --- a/app/features/batch/routes.py +++ b/app/features/batch/routes.py @@ -13,9 +13,12 @@ from fastapi import APIRouter, Depends, Query, status from sqlalchemy.ext.asyncio import AsyncSession +from app.core.config import get_settings from app.core.database import get_db -from app.core.exceptions import NotFoundError +from app.core.exceptions import ConflictError, GatewayTimeoutError, NotFoundError from app.core.logging import get_logger +from app.features.batch import runner +from app.features.batch.models import TERMINAL_BATCH_STATES from app.features.batch.schemas import ( BatchItemListResponse, BatchSubmitRequest, @@ -69,6 +72,82 @@ async def get_batch( return result +@router.delete( + "/{batch_id}", + response_model=BatchSubmitResponse, + summary="Cancel an in-flight batch (cooperative drain)", + description=( + "Cancel an in-flight batch (PRP-34). Pending children skip execution; " + "running children observe ``asyncio.CancelledError`` at the next safe " + "yield point — sklearn / LightGBM fits are uncancellable mid-call, so " + "an in-flight fit may stall the drain (504 surfaces that). Returns:\n\n" + "- ``200`` settled parent on clean drain\n" + "- ``404`` RFC 7807 if the batch does not exist\n" + "- ``409`` RFC 7807 if the batch is already terminal\n" + "- ``504`` RFC 7807 if the drain exceeds " + "``Settings.batch_cancel_drain_timeout_seconds``" + ), +) +async def cancel_batch_route( + batch_id: str, + db: AsyncSession = Depends(get_db), +) -> BatchSubmitResponse: + """Cancel an in-flight batch and return its settled parent record.""" + service = BatchService() + parent = await service.get(db=db, batch_id=batch_id) + if parent is None: + raise NotFoundError( + message=f"Batch not found: {batch_id}", + details={"batch_id": batch_id}, + ) + if parent.status in TERMINAL_BATCH_STATES: + raise ConflictError( + message=f"Batch already terminal: {parent.status.value}", + details={"batch_id": batch_id, "status": parent.status.value}, + ) + + fired = runner.cancel_batch(batch_id) + if not fired: + # Race: the submit handler's ``_settle`` committed and + # ``mark_completed`` removed the registry handle between our + # ``service.get`` above and ``cancel_batch`` here. The parent is now + # terminal in DB but we still raise 409 so the operator's intent + # ("I want this stopped") gets a truthful answer. + raise ConflictError( + message="Batch settled before cancel could fire", + details={"batch_id": batch_id}, + ) + + settings = get_settings() + drained = await runner.await_drain( + batch_id=batch_id, + timeout_seconds=float(settings.batch_cancel_drain_timeout_seconds), + ) + if not drained: + raise GatewayTimeoutError( + message=( + f"Drain exceeded {settings.batch_cancel_drain_timeout_seconds}s; " + "parent settle still pending. In-flight sklearn / LightGBM fits " + "are uncancellable mid-call — retry once the fit completes." + ), + details={ + "batch_id": batch_id, + "drain_timeout_seconds": settings.batch_cancel_drain_timeout_seconds, + }, + ) + + final = await service.get(db=db, batch_id=batch_id) + if final is None: + # Defensive — ``batch_job`` rows are never deleted, so this branch + # should be unreachable. Surface as 404 if it ever happens. + raise NotFoundError( + message=f"Batch not found after drain: {batch_id}", + details={"batch_id": batch_id}, + ) + logger.info("batch.cancelled", batch_id=batch_id, status=final.status.value) + return final + + @router.get( "/{batch_id}/items", response_model=BatchItemListResponse, diff --git a/app/features/batch/runner.py b/app/features/batch/runner.py new file mode 100644 index 00000000..14727d83 --- /dev/null +++ b/app/features/batch/runner.py @@ -0,0 +1,381 @@ +"""Bounded-concurrency batch runner (PRP-34). + +Activates the three forward-compat columns PRP-33 shipped on ``batch_job`` +(``max_parallel``, ``running_items``, ``cancelled_items``). The runner is a +single :class:`asyncio.Semaphore` inside an :class:`asyncio.TaskGroup` that +fans out one task per ``batch_job_item``; each child opens its own +``AsyncSession`` and observes a cooperative :class:`asyncio.Event` so +``DELETE /batch/{batch_id}`` cancels what hasn't started and gracefully +drains what has. + +The asyncio mechanics (the three working cancel mechanisms, the +``except* asyncio.CancelledError`` PEP-654 catch shape, the ``ContextVar`` +inheritance into TaskGroup children) are documented end-to-end in +``PRPs/ai_docs/asyncio-taskgroup-cancellation.md``. + +Cross-slice rule: this module imports from ``app.features.batch.models`` +(same slice) and ``app.core.*`` only — no cross-slice imports, even lazy. +The per-child execute callable supplied by ``BatchService`` is the seam +that keeps ``app.features.jobs`` reachable without an import here. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from sqlalchemy import select, update + +from app.core.logging import get_logger +from app.features.batch.models import ( + BatchItemStatus, + BatchJob, + BatchJobItem, +) + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + + +logger = get_logger(__name__) + + +@dataclass +class CancelHandle: + """Cancel signal + Task refs + completion event for an in-flight batch. + + Created by :func:`run_batch`, looked up by :func:`cancel_batch`, removed + from :data:`_ACTIVE_BATCHES` and signalled by the runner's caller via + :func:`mark_completed` *after* the parent's settle has committed — so + ``DELETE /batch/{batch_id}`` never observes the parent mid-settle. + + Attributes: + cancel_event: Set to signal cooperative drain. + completed_event: Set by ``mark_completed`` after parent settle commits. + tasks: ``asyncio.Task`` refs returned by ``tg.create_task`` — cancel + target. Never use :func:`asyncio.all_tasks` to find these; that + mechanism is brittle across concurrent batches (see the ai_doc). + """ + + cancel_event: asyncio.Event = field(default_factory=asyncio.Event) + completed_event: asyncio.Event = field(default_factory=asyncio.Event) + tasks: list[asyncio.Task[None]] = field(default_factory=list) + + +# Module-level registry — single-process scope (matches the single-host vision +# of ``.claude/rules/product-vision.md``). If a future ADR moves this to a +# shared store (Redis, Postgres advisory locks) that is the entry point. +_ACTIVE_BATCHES: dict[str, CancelHandle] = {} + + +async def run_batch( + *, + batch_id: str, + item_ids: list[str], + max_parallel: int, + global_max_parallel: int, + session_maker: async_sessionmaker[AsyncSession], + execute_item: Callable[[str], Awaitable[None]], +) -> int: + """Execute one batch through a bounded :class:`asyncio.TaskGroup`. + + Args: + batch_id: ``batch_job.batch_id`` — registry key + log correlator. + item_ids: pending ``batch_job_item.item_id`` values, in submit order. + max_parallel: per-batch cap declared in ``batch_job.max_parallel``. + global_max_parallel: host-wide cap from + :attr:`app.core.config.Settings.batch_global_max_parallel`. + session_maker: shared ``async_sessionmaker`` — each child opens + **one** ``AsyncSession`` from this maker and reuses it for the + DB writes the runner emits (state transitions + ``running_items`` + counter). The caller-supplied ``execute_item`` opens its own + session from the same maker (the runner does not pass its + session in — keeps the contract symmetric with + ``BatchService._execute_item``). + execute_item: one-arg coroutine the caller provides. Opens its own + ``AsyncSession`` from the same ``session_maker`` and runs the + batch item's work (e.g., delegating to ``JobService``). + + Returns: + ``effective_max_parallel = min(max_parallel, global_max_parallel)``. + + Notes: + - Caller MUST call :func:`mark_completed` after the parent's settle + commits, even on the exception path. The runner deliberately does + NOT pop the registry entry itself — that would let ``DELETE``'s + drain race the settle commit. + - Cancellation does NOT propagate out: ``except* asyncio.CancelledError`` + absorbs the ``ExceptionGroup`` so the caller can settle the parent + to its observed terminal state. + """ + effective = min(max_parallel, global_max_parallel) + sem = asyncio.Semaphore(effective) + handle = _ACTIVE_BATCHES.setdefault(batch_id, CancelHandle()) + + logger.info( + "batch.runner_start", + batch_id=batch_id, + total_items=len(item_ids), + max_parallel=max_parallel, + effective_max_parallel=effective, + ) + + async def _child(item_id: str) -> None: + # One ``AsyncSession`` per child — used for every DB write the runner + # emits below. Each helper commits its own UPDATE on this session so + # individual state transitions are visible to concurrent observers + # (``running_items`` counter is observable to DELETE handlers, etc.). + async with session_maker() as session: + # FAST-CANCEL before semaphore acquire — skips not-yet-started + # work cleanly (the cancel_event check is sync; no await window + # for a late ``task.cancel()`` to interrupt). + if handle.cancel_event.is_set(): + await _mark_cancelled_skipped(session, item_id) + return + + # ``acquired`` tracks whether we entered the semaphore-guarded + # body. When ``task.cancel()`` fires while we are *waiting* on + # the semaphore, ``async with sem:`` raises ``CancelledError`` + # before the inner re-check runs; the outer except below routes + # the item to ``mark_cancelled_skipped`` so the cancel surface + # is consistent. + acquired = False + try: + async with sem: + acquired = True + # Re-check after acquire — a sibling may have signalled + # cancel while we waited on the semaphore. + if handle.cancel_event.is_set(): + await _mark_cancelled_skipped(session, item_id) + return + await _bump_running(session, batch_id, +1) + try: + await execute_item(item_id) + except asyncio.CancelledError: + # ``execute_item`` catches ``Exception`` but NOT + # ``BaseException``; ``CancelledError`` (BaseException + # in 3.8+) bubbles up cleanly. Persist the cancelled + # terminal state before re-raising so the TaskGroup + # absorbs the cancel. + await _mark_cancelled_running(session, item_id) + raise + except Exception: + # Defensive: ``execute_item`` should have persisted + # its own failure; if it didn't, mark FAILED so the + # parent settle aggregates correctly. Do NOT + # re-raise — that would tear down sibling children + # (test_child_failure_does_not_abort_siblings). + logger.exception( + "batch.runner_unexpected_child_error", + batch_id=batch_id, + item_id=item_id, + ) + await _mark_failed_unexpected(session, item_id) + finally: + await _bump_running(session, batch_id, -1) + except asyncio.CancelledError: + if not acquired: + # Cancel reached us before we entered the semaphore body + # — never started work, never bumped running_items. + await _mark_cancelled_skipped(session, item_id) + raise + + try: + async with asyncio.TaskGroup() as tg: + for iid in item_ids: + # ``name=`` lets the operator inspect tasks in a debugger; we + # do NOT rely on the name for cancellation (we hold Task refs). + task = tg.create_task(_child(iid), name=f"batch:{batch_id}:{iid}") + handle.tasks.append(task) + except* asyncio.CancelledError: + # Clean ``task.cancel()`` calls are absorbed by TaskGroup, so this + # branch is defensive — it fires only if the *parent* coroutine + # (the POST handler) is cancelled. Either way, the per-child + # ``finally`` already wrote the terminal state. + logger.info("batch.runner_cancelled_exception_group", batch_id=batch_id) + + logger.info( + "batch.runner_complete", + batch_id=batch_id, + cancel_requested=handle.cancel_event.is_set(), + ) + return effective + + +def cancel_batch(batch_id: str) -> bool: + """Signal cooperative cancel for an in-flight batch. + + Sets ``cancel_event`` (skips pending children) and calls ``task.cancel()`` + on every tracked child (interrupts running children at the next yield). + + Returns: + ``True`` if the batch was registered (cancel signal fired); + ``False`` if no handle exists (race: batch settled before cancel). + """ + handle = _ACTIVE_BATCHES.get(batch_id) + if handle is None: + return False + handle.cancel_event.set() + cancelled_count = 0 + for task in handle.tasks: + if not task.done(): + task.cancel() + cancelled_count += 1 + logger.info( + "batch.cancel_requested", + batch_id=batch_id, + n_tasks_tracked=len(handle.tasks), + n_tasks_cancelled=cancelled_count, + ) + return True + + +async def await_drain(batch_id: str, timeout_seconds: float) -> bool: + """Block until the batch's parent settle commits, or timeout elapses. + + Args: + batch_id: ``batch_job.batch_id``. + timeout_seconds: max seconds to wait. + + Returns: + ``True`` on clean drain (or if the batch was never registered); + ``False`` on timeout. + """ + handle = _ACTIVE_BATCHES.get(batch_id) + if handle is None: + # Already drained (or never registered) — DELETE handler reads as + # "no need to wait, fetch the settled parent now". + return True + try: + await asyncio.wait_for( + handle.completed_event.wait(), + timeout=timeout_seconds, + ) + return True + except TimeoutError: + # ``asyncio.TimeoutError`` is aliased to the built-in ``TimeoutError`` + # since Python 3.11 (PEP 678 / asyncio docs). The project pins + # Python >= 3.12, so this catch IS the asyncio.wait_for timeout. + logger.warning( + "batch.cancel_drain_timeout", + batch_id=batch_id, + timeout_seconds=timeout_seconds, + ) + return False + + +def mark_completed(batch_id: str) -> None: + """Signal that the batch's parent settle has committed. + + Must be called by ``BatchService.submit`` after its ``_settle`` commits + (including on the failure path) so any concurrent ``DELETE`` drain + unblocks. Idempotent: a missing handle is a no-op. + """ + handle = _ACTIVE_BATCHES.pop(batch_id, None) + if handle is None: + return + handle.completed_event.set() + + +# --------------------------------------------------------------------- helpers +# Each helper accepts an already-open ``AsyncSession`` (one per child; +# managed by ``_child``) and commits its single UPDATE on that session. They +# do NOT call ``BatchService`` (would close an import cycle) and they do not +# raise on missing rows (a race where the parent was deleted is survivable +# — log + move on). + + +async def _bump_running( + session: AsyncSession, + batch_id: str, + delta: int, +) -> None: + """Atomically bump ``batch_job.running_items`` by ``delta`` (±1).""" + await session.execute( + update(BatchJob) + .where(BatchJob.batch_id == batch_id) + .values(running_items=BatchJob.running_items + delta) + ) + await session.commit() + + +async def _mark_cancelled_skipped( + session: AsyncSession, + item_id: str, +) -> None: + """Mark a not-yet-started item as cancelled (pending → cancelled).""" + now = datetime.now(UTC) + await session.execute( + update(BatchJobItem) + .where(BatchJobItem.item_id == item_id) + .values( + status=BatchItemStatus.CANCELLED.value, + completed_at=now, + ) + ) + await session.commit() + + +async def _mark_cancelled_running( + session: AsyncSession, + item_id: str, +) -> None: + """Mark a running item as cancelled (running → cancelled). + + Runs inside the child's ``except asyncio.CancelledError`` block, so + ``execute_item`` has already set the item to ``RUNNING`` and bumped the + parent's ``running_items`` counter. The decrement happens in the + surrounding ``finally`` block. + """ + now = datetime.now(UTC) + row = ( + await session.execute( + select(BatchJobItem.started_at).where(BatchJobItem.item_id == item_id) + ) + ).first() + started_at = row[0] if row is not None else None + duration_ms = int((now - started_at).total_seconds() * 1000) if started_at is not None else None + await session.execute( + update(BatchJobItem) + .where(BatchJobItem.item_id == item_id) + .values( + status=BatchItemStatus.CANCELLED.value, + completed_at=now, + duration_ms=duration_ms, + ) + ) + await session.commit() + + +async def _mark_failed_unexpected( + session: AsyncSession, + item_id: str, +) -> None: + """Defensive: mark an item ``failed`` when ``execute_item`` raised an + uncaught exception (its own ``except Exception`` should normally absorb). + """ + now = datetime.now(UTC) + await session.execute( + update(BatchJobItem) + .where(BatchJobItem.item_id == item_id) + .values( + status=BatchItemStatus.FAILED.value, + completed_at=now, + error_message="Runner caught unexpected exception (see structlog)", + error_type="UnexpectedRunnerError", + ) + ) + await session.commit() + + +__all__ = [ + "_ACTIVE_BATCHES", + "CancelHandle", + "await_drain", + "cancel_batch", + "mark_completed", + "run_batch", +] diff --git a/app/features/batch/schemas.py b/app/features/batch/schemas.py index 100240ba..a9d6ec97 100644 --- a/app/features/batch/schemas.py +++ b/app/features/batch/schemas.py @@ -18,7 +18,7 @@ from enum import Enum from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator from app.features.batch.models import BatchItemStatus, BatchOperation, BatchStatus @@ -162,7 +162,14 @@ class BatchItemResponse(BaseModel): class BatchSubmitResponse(BaseModel): - """Parent batch record — returned by submit + GET /batch/{id}.""" + """Parent batch record — returned by submit + GET /batch/{id}. + + ``effective_max_parallel`` is a :func:`computed_field` resolved from + ``result_summary["effective_max_parallel"]`` — the PRP-34 runner writes + that key on every batch it executes. Legacy batches (pre-PRP-34) and any + batch where the key is missing return ``0``. Storing the value in JSONB + (rather than a real column) means PRP-34 ships with NO Alembic migration. + """ model_config = ConfigDict(from_attributes=True) @@ -174,12 +181,29 @@ class BatchSubmitResponse(BaseModel): failed_items: int running_items: int cancelled_items: int + max_parallel: int started_at: datetime | None completed_at: datetime | None result_summary: dict[str, Any] | None created_at: datetime updated_at: datetime + @computed_field # type: ignore[prop-decorator] + @property + def effective_max_parallel(self) -> int: + """Resolved ``min(max_parallel, settings.batch_global_max_parallel)``. + + The PRP-34 runner persists this under ``result_summary`` on settle. + Returns ``0`` for legacy pre-PRP-34 rows or where the key is missing. + """ + if self.result_summary is None: + return 0 + value = self.result_summary.get("effective_max_parallel", 0) + try: + return int(value) + except (TypeError, ValueError): + return 0 + class BatchItemListResponse(BaseModel): """Paginated item listing — GET /batch/{id}/items.""" diff --git a/app/features/batch/service.py b/app/features/batch/service.py index ccad4ec7..ec657389 100644 --- a/app/features/batch/service.py +++ b/app/features/batch/service.py @@ -26,8 +26,10 @@ from sqlalchemy.orm import InstrumentedAttribute from app.core.config import get_settings +from app.core.database import get_session_maker from app.core.exceptions import ValidationError from app.core.logging import get_logger +from app.features.batch import runner from app.features.batch.models import ( BatchItemStatus, BatchJob, @@ -118,6 +120,7 @@ async def submit(self, db: AsyncSession, req: BatchSubmitRequest) -> BatchSubmit max_parallel=req.max_parallel, ) db.add(batch) + inserted_items: list[BatchJobItem] = [] for store_id, product_id, mc in triples: item = BatchJobItem( item_id=uuid.uuid4().hex, @@ -130,6 +133,7 @@ async def submit(self, db: AsyncSession, req: BatchSubmitRequest) -> BatchSubmit params=self._frozen_item_params(req, store_id, product_id, mc), ) db.add(item) + inserted_items.append(item) await db.commit() await db.refresh(batch) @@ -145,19 +149,50 @@ async def submit(self, db: AsyncSession, req: BatchSubmitRequest) -> BatchSubmit batch.started_at = datetime.now(UTC) await db.commit() - # 3. Loop the picker until no PENDING item remains. The explicit - # ``BatchJobItem | None`` annotation prevents mypy from re-narrowing - # ``item`` to ``BatchJobItem`` on the second iteration after the - # first ``if item is None: break`` branch. - while True: - next_item: BatchJobItem | None = await self._pick_next(db, batch.batch_id) - if next_item is None: - break - await self._execute_item(db, next_item) - - # 4. Settle the parent. - await self._settle(db, batch) - await db.refresh(batch) + # 3. Run the batch through the PRP-34 bounded-concurrency runner. A + # SHARED session_maker is created once and passed to every child so + # the SQLAlchemy connection pool is shared (pool_size=5, + # max_overflow=10 by default — sufficient headroom for + # batch_global_max_parallel ≤ 12; see + # PRPs/ai_docs/asyncio-taskgroup-cancellation.md § "SQLAlchemy async + # pool defaults"). + session_maker_local = get_session_maker() + + async def _exec_one(item_id: str) -> None: + """Per-child: open own session, fetch item, delegate to _execute_item. + + Each child owns its own ``AsyncSession`` — never share the + request's ``db`` here, that would corrupt SQLAlchemy's + identity map across concurrent children. + """ + async with session_maker_local() as child_session: + item = ( + await child_session.execute( + select(BatchJobItem).where(BatchJobItem.item_id == item_id) + ) + ).scalar_one() + await self._execute_item(child_session, item) + + effective_max_parallel = 0 + try: + effective_max_parallel = await runner.run_batch( + batch_id=batch.batch_id, + item_ids=[it.item_id for it in inserted_items], + max_parallel=req.max_parallel, + global_max_parallel=self.settings.batch_global_max_parallel, + session_maker=session_maker_local, + execute_item=_exec_one, + ) + finally: + # 4. Settle the parent regardless of how the runner exited. Pass + # ``effective_max_parallel`` so the JSONB result_summary carries + # the value PRP-34's BatchSubmitResponse.effective_max_parallel + # computed_field reads at serialisation time. + await self._settle(db, batch, effective_max_parallel=effective_max_parallel) + await db.refresh(batch) + # Unblock any DELETE /batch/{batch_id} waiter — must happen AFTER + # settle so the drained handler observes the settled parent. + runner.mark_completed(batch.batch_id) logger.info( "batch.completed", @@ -165,6 +200,8 @@ async def submit(self, db: AsyncSession, req: BatchSubmitRequest) -> BatchSubmit status=batch.status, completed_items=batch.completed_items, failed_items=batch.failed_items, + cancelled_items=batch.cancelled_items, + effective_max_parallel=effective_max_parallel, ) return BatchSubmitResponse.model_validate(batch) @@ -347,13 +384,24 @@ def _shape_metrics(self, job: JobResponse) -> dict[str, Any] | None: "sample_size": sample_size, } - async def _settle(self, db: AsyncSession, batch: BatchJob) -> None: + async def _settle( + self, + db: AsyncSession, + batch: BatchJob, + effective_max_parallel: int = 0, + ) -> None: """Aggregate per-status counts and settle the parent. - all COMPLETED → ``completed`` - all FAILED → ``failed`` - - mixed (>=1 of each) → ``partial`` + - all CANCELLED (≥1 cancel, no other terminal) → ``cancelled`` + - mixed COMPLETED + FAILED → ``partial`` - 0 items (degenerate empty batch) → ``completed`` (vacuous) + + ``effective_max_parallel`` is stored in the JSONB ``result_summary`` + so the PRP-34 ``BatchSubmitResponse.effective_max_parallel`` + computed_field can resolve it at response-time without an Alembic + migration. """ stmt = ( select(BatchJobItem.status, func.count()) @@ -367,16 +415,23 @@ async def _settle(self, db: AsyncSession, batch: BatchJob) -> None: failed = counts.get(BatchItemStatus.FAILED.value, 0) cancelled = counts.get(BatchItemStatus.CANCELLED.value, 0) - if completed > 0 and failed == 0: + if cancelled > 0 and completed == 0 and failed == 0: + # PRP-34: a cancel that fired before any sibling completed + # settles the parent to CANCELLED so the operator's intent is + # preserved. + final = BatchStatus.CANCELLED + elif completed > 0 and failed == 0 and cancelled == 0: final = BatchStatus.COMPLETED - elif failed > 0 and completed == 0: + elif failed > 0 and completed == 0 and cancelled == 0: final = BatchStatus.FAILED - elif completed > 0 and failed > 0: + elif completed > 0 or failed > 0: + # Mixed terminals: any of (completed, failed, cancelled) > 0 + # with COMPLETED or FAILED also present → PARTIAL. final = BatchStatus.PARTIAL else: - # No completed + no failed: empty batch or all-cancelled. Treat - # as ``completed`` (vacuous) — the integration test asserts on - # completed_items=N, not on status when items=0. + # No completed + no failed + no cancelled: empty batch. Treat as + # ``completed`` (vacuous) — preserves the PRP-33 invariant the + # integration test asserts on. final = BatchStatus.COMPLETED batch.status = final.value @@ -387,6 +442,7 @@ async def _settle(self, db: AsyncSession, batch: BatchJob) -> None: batch.result_summary = { "by_status": counts, "final_status": final.value, + "effective_max_parallel": effective_max_parallel, } await db.commit() diff --git a/app/features/batch/tests/test_models.py b/app/features/batch/tests/test_models.py index 9e38994a..c5fce45f 100644 --- a/app/features/batch/tests/test_models.py +++ b/app/features/batch/tests/test_models.py @@ -50,9 +50,11 @@ def test_valid_transitions_dict_item() -> None: BatchItemStatus.RUNNING, BatchItemStatus.CANCELLED, } + # PRP-34 added ``RUNNING → CANCELLED`` for the cooperative-cancel path. assert VALID_BATCH_ITEM_TRANSITIONS[BatchItemStatus.RUNNING] == { BatchItemStatus.COMPLETED, BatchItemStatus.FAILED, + BatchItemStatus.CANCELLED, } for terminal in ( BatchItemStatus.COMPLETED, diff --git a/app/features/batch/tests/test_routes_cancel.py b/app/features/batch/tests/test_routes_cancel.py new file mode 100644 index 00000000..07d9ac33 --- /dev/null +++ b/app/features/batch/tests/test_routes_cancel.py @@ -0,0 +1,158 @@ +"""Integration tests for ``DELETE /batch/{batch_id}`` (PRP-34). + +ASGITransport-backed — same pattern as ``test_routes_integration.py``. +Marked ``integration`` because they query the real docker-compose Postgres +via the FastAPI ``get_db`` dependency. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.features.batch import runner +from app.features.batch.models import BatchJob, BatchStatus +from app.features.data_platform.models import Product, Store + +pytestmark = pytest.mark.integration + + +async def test_delete_404_unknown_batch(client: AsyncClient) -> None: + """Unknown batch_id → RFC 7807 404 problem+json.""" + resp = await client.delete("/batch/does-not-exist-prp34") + assert resp.status_code == 404 + assert resp.headers["content-type"].startswith("application/problem+json") + body = resp.json() + assert body["status"] == 404 + assert body["code"] == "NOT_FOUND" + + +async def test_delete_409_terminal_batch( + client: AsyncClient, + sample_store: Store, + sample_products_3: list[Product], + sample_sales_120: list[Any], +) -> None: + """A successfully-completed batch is terminal — DELETE returns RFC 7807 409. + + Submits a 3-pair naive backtest; the run completes synchronously inside + ``POST /batch/forecasting``. The subsequent DELETE finds the parent in + ``completed`` (terminal) and the runner registry empty. + """ + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [sample_store.id], + "product_ids": [p.id for p in sample_products_3], + }, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + } + submit = await client.post("/batch/forecasting", json=payload) + assert submit.status_code == 202, submit.text + batch_id = submit.json()["batch_id"] + assert submit.json()["status"] == "completed" + + resp = await client.delete(f"/batch/{batch_id}") + assert resp.status_code == 409 + assert resp.headers["content-type"].startswith("application/problem+json") + body = resp.json() + assert body["status"] == 409 + assert body["code"] == "CONFLICT" + + +async def test_delete_200_clean_drain( + client: AsyncClient, + db_session: AsyncSession, +) -> None: + """Happy-path DELETE: registered handle, drain succeeds immediately, 200. + + Seeds a ``running`` parent row and pre-fires the registry handle's + ``completed_event`` so the route's ``runner.await_drain`` returns + ``True`` without waiting — the same observable shape as + ``BatchService.submit`` finishing settle and calling ``mark_completed`` + a microsecond before the DELETE handler's drain check. Verifies the + route then reloads the parent and serialises a 200 ``BatchSubmitResponse``. + + The genuine *mid-flight* cancel-and-drain path is covered end-to-end + by ``test_runner_chaos.test_cancel_mid_flight_does_not_orphan_running_items``. + """ + batch = BatchJob( + batch_id="test_200_drain", + operation="backtest", + scope={"kind": "manual"}, + model_configs=[], + status=BatchStatus.RUNNING.value, + total_items=0, + params={}, + max_parallel=4, + ) + db_session.add(batch) + await db_session.commit() + + handle = runner.CancelHandle() + handle.completed_event.set() # drain returns True immediately + runner._ACTIVE_BATCHES["test_200_drain"] = handle + + try: + resp = await client.delete("/batch/test_200_drain") + assert resp.status_code == 200, resp.text + assert resp.headers["content-type"].startswith("application/json") + body = resp.json() + assert body["batch_id"] == "test_200_drain" + assert body["max_parallel"] == 4 + # The parent was never run through _settle, so it stays ``running`` + # — the route's contract is "return the current parent record after + # drain", not "force settle"; settle is the submit handler's job. + assert body["status"] == "running" + finally: + runner._ACTIVE_BATCHES.pop("test_200_drain", None) + + +async def test_delete_504_drain_timeout( + client: AsyncClient, + db_session: AsyncSession, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A stuck registry handle + 0-second drain timeout → RFC 7807 504. + + Sets up an in-DB ``running`` parent row + a registry handle whose + ``completed_event`` never fires. With ``batch_cancel_drain_timeout_seconds=0``, + ``runner.await_drain`` raises ``TimeoutError`` and the route surfaces + :class:`app.core.exceptions.GatewayTimeoutError` as a 504. + """ + batch = BatchJob( + batch_id="test_504_drain", + operation="backtest", + scope={"kind": "manual"}, + model_configs=[], + status=BatchStatus.RUNNING.value, + total_items=0, + params={}, + max_parallel=4, + ) + db_session.add(batch) + await db_session.commit() + + handle = runner.CancelHandle() + runner._ACTIVE_BATCHES["test_504_drain"] = handle + + settings = get_settings() + monkeypatch.setattr(settings, "batch_cancel_drain_timeout_seconds", 0) + + try: + resp = await client.delete("/batch/test_504_drain") + assert resp.status_code == 504, resp.text + assert resp.headers["content-type"].startswith("application/problem+json") + body = resp.json() + assert body["status"] == 504 + assert body["code"] == "GATEWAY_TIMEOUT" + assert "Drain exceeded" in body["detail"] + finally: + runner._ACTIVE_BATCHES.pop("test_504_drain", None) diff --git a/app/features/batch/tests/test_runner.py b/app/features/batch/tests/test_runner.py new file mode 100644 index 00000000..a60ca410 --- /dev/null +++ b/app/features/batch/tests/test_runner.py @@ -0,0 +1,298 @@ +"""Unit tests for the PRP-34 bounded-concurrency batch runner. + +The runner's DB helpers (``_bump_running``, ``_mark_cancelled_skipped``, +``_mark_cancelled_running``, ``_mark_failed_unexpected``) are monkeypatched +to awaitable no-ops so the asyncio orchestration can be exercised without +docker-compose. The DB invariants those helpers guard (no orphaned +``running`` rows, ``running_items`` counter bounded by +``effective_max_parallel``) are covered in the integration chaos suite +(``test_runner_chaos.py``). +""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from app.features.batch import runner + + +@pytest.fixture(autouse=True) +def _clear_registry() -> Any: + """Each test starts with a clean ``_ACTIVE_BATCHES`` registry.""" + runner._ACTIVE_BATCHES.clear() + yield + runner._ACTIVE_BATCHES.clear() + + +@pytest.fixture +def patch_db_helpers(monkeypatch: pytest.MonkeyPatch) -> dict[str, list[Any]]: + """Replace runner DB helpers with awaitable no-ops. + + Returns a call-tracker dict the test can read to assert which helper + fired for which item — the only contract this fixture cares about. + The helpers now accept an ``AsyncSession`` (refactor per code review) + but the unit tests don't exercise SQL — the session arg is ignored. + """ + calls: dict[str, list[Any]] = { + "bump_running": [], + "mark_cancelled_skipped": [], + "mark_cancelled_running": [], + "mark_failed_unexpected": [], + } + + async def _bump_running(_session: Any, batch_id: str, delta: int) -> None: + calls["bump_running"].append((batch_id, delta)) + + async def _mark_cancelled_skipped(_session: Any, item_id: str) -> None: + calls["mark_cancelled_skipped"].append(item_id) + + async def _mark_cancelled_running(_session: Any, item_id: str) -> None: + calls["mark_cancelled_running"].append(item_id) + + async def _mark_failed_unexpected(_session: Any, item_id: str) -> None: + calls["mark_failed_unexpected"].append(item_id) + + monkeypatch.setattr(runner, "_bump_running", _bump_running) + monkeypatch.setattr(runner, "_mark_cancelled_skipped", _mark_cancelled_skipped) + monkeypatch.setattr(runner, "_mark_cancelled_running", _mark_cancelled_running) + monkeypatch.setattr(runner, "_mark_failed_unexpected", _mark_failed_unexpected) + return calls + + +def _fake_session_maker() -> Any: + """An ``async_sessionmaker``-shaped callable that yields a Mock session. + + ``runner.run_batch`` calls ``session_maker()`` and uses the result as an + async context manager (``async with session_maker() as session:``). The + Mock session is never touched because the patched helpers ignore it. + """ + + @asynccontextmanager + async def _ctx() -> Any: + yield AsyncMock() + + def _maker() -> Any: + return _ctx() + + return cast(Any, _maker) + + +# ---------------------------------------------------------------- semaphore + + +async def test_semaphore_caps_concurrency(patch_db_helpers: dict[str, list[Any]]) -> None: + """5 children with max_parallel=2 — observed concurrent peak == 2. + + LOAD-BEARING regression for the unbounded-fan-out failure mode. If a + future refactor replaces the Semaphore with ``asyncio.gather`` or + pushes the ``async with sem:`` outside the child, this test fires. + """ + in_flight = 0 + peak = 0 + + async def child(_item_id: str) -> None: + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + try: + await asyncio.sleep(0.02) + finally: + in_flight -= 1 + + effective = await runner.run_batch( + batch_id="b_sem", + item_ids=[f"i{i}" for i in range(5)], + max_parallel=2, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_item=child, + ) + runner.mark_completed("b_sem") + + assert effective == 2 + assert peak == 2, f"observed peak {peak}, expected exactly 2" + # Every child bumped running ±1 — net zero. + bump_deltas = [delta for (_, delta) in patch_db_helpers["bump_running"]] + assert sum(bump_deltas) == 0 + assert len(bump_deltas) == 10 # 5 children x (start + finish) + + +async def test_settings_global_cap_clamps_max_parallel( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """max_parallel=32 clamped by global_max_parallel=4 → effective=4, peak ≤ 4.""" + in_flight = 0 + peak = 0 + + async def child(_item_id: str) -> None: + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + try: + await asyncio.sleep(0.02) + finally: + in_flight -= 1 + + effective = await runner.run_batch( + batch_id="b_cap", + item_ids=[f"i{i}" for i in range(8)], + max_parallel=32, + global_max_parallel=4, + session_maker=_fake_session_maker(), + execute_item=child, + ) + runner.mark_completed("b_cap") + + assert effective == 4 + assert peak <= 4, f"observed peak {peak} exceeded global cap of 4" + + +# ---------------------------------------------------- per-child failure isolation + + +async def test_child_failure_does_not_abort_siblings( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """One of 5 children raises RuntimeError; other 4 reach completion. + + The runner's child-level ``except Exception`` (defensive) catches the + error and marks the item failed; the TaskGroup never sees the exception + so siblings continue. Without this guard, TaskGroup cancels siblings + on the first failure. + """ + completed: list[str] = [] + + async def child(item_id: str) -> None: + if item_id == "i2": + raise RuntimeError("synthetic failure") + await asyncio.sleep(0.01) + completed.append(item_id) + + effective = await runner.run_batch( + batch_id="b_fail", + item_ids=[f"i{i}" for i in range(5)], + max_parallel=5, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_item=child, + ) + runner.mark_completed("b_fail") + + assert effective == 5 + assert sorted(completed) == ["i0", "i1", "i3", "i4"] + assert patch_db_helpers["mark_failed_unexpected"] == ["i2"] + + +# --------------------------------------------------------------- cancel paths + + +async def test_cancel_pending_child_marks_cancelled_without_running( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """max_parallel=1, 3 items. After i0 starts, cancel — i1/i2 skip the work.""" + started: list[str] = [] + + async def child(item_id: str) -> None: + started.append(item_id) + try: + await asyncio.sleep(0.5) + except asyncio.CancelledError: + raise + + task = asyncio.create_task( + runner.run_batch( + batch_id="b_pending", + item_ids=["i0", "i1", "i2"], + max_parallel=1, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_item=child, + ) + ) + await asyncio.sleep(0.05) # let i0 acquire the semaphore + start work + fired = runner.cancel_batch("b_pending") + await task + runner.mark_completed("b_pending") + + assert fired is True + # i0 was running when cancelled → mark_cancelled_running path. + assert patch_db_helpers["mark_cancelled_running"] == ["i0"] + # i1, i2 never acquired the semaphore → mark_cancelled_skipped path. + assert set(patch_db_helpers["mark_cancelled_skipped"]) == {"i1", "i2"} + # i0 was the only one that even entered child. + assert started == ["i0"] + + +async def test_cancel_running_child_propagates_cancelled_error( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """A running child observes CancelledError; finally block writes cancelled.""" + cancelled_in_child: list[str] = [] + + async def child(item_id: str) -> None: + try: + await asyncio.sleep(1.0) + except asyncio.CancelledError: + cancelled_in_child.append(item_id) + raise + + task = asyncio.create_task( + runner.run_batch( + batch_id="b_running", + item_ids=["i0"], + max_parallel=1, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_item=child, + ) + ) + await asyncio.sleep(0.05) + runner.cancel_batch("b_running") + await task + runner.mark_completed("b_running") + + assert cancelled_in_child == ["i0"] + assert patch_db_helpers["mark_cancelled_running"] == ["i0"] + + +# ------------------------------------------------------------- registry hygiene + + +async def test_mark_completed_unblocks_await_drain( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """``mark_completed`` removes the registry entry and fires the event.""" + runner._ACTIVE_BATCHES["bx"] = runner.CancelHandle() + + # Start a drain waiter + drain_task = asyncio.create_task(runner.await_drain("bx", timeout_seconds=1.0)) + await asyncio.sleep(0.01) + runner.mark_completed("bx") + drained = await drain_task + + assert drained is True + assert "bx" not in runner._ACTIVE_BATCHES + + +async def test_cancel_batch_returns_false_when_unregistered() -> None: + """``cancel_batch`` on an unregistered batch returns False (race-safe).""" + fired = runner.cancel_batch("does-not-exist") + assert fired is False + + +async def test_await_drain_returns_true_when_unregistered() -> None: + """``await_drain`` on an unregistered batch returns True immediately.""" + drained = await runner.await_drain("does-not-exist", timeout_seconds=0.0) + assert drained is True + + +async def test_await_drain_times_out_on_stuck_handle() -> None: + """``await_drain`` returns False when ``completed_event`` never fires.""" + runner._ACTIVE_BATCHES["b_stuck"] = runner.CancelHandle() + drained = await runner.await_drain("b_stuck", timeout_seconds=0.05) + assert drained is False diff --git a/app/features/batch/tests/test_runner_chaos.py b/app/features/batch/tests/test_runner_chaos.py new file mode 100644 index 00000000..ff627c6b --- /dev/null +++ b/app/features/batch/tests/test_runner_chaos.py @@ -0,0 +1,199 @@ +"""Chaos / orphan-state regression tests for the PRP-34 runner. + +These tests bypass the HTTP layer and drive ``runner.run_batch`` directly +with a synthetic ``execute_item`` callable, so they can exercise mid-flight +cancellation without depending on the timing of a real backtest. They run +against the real docker-compose Postgres so the DB invariants the runner +guards (no orphaned ``running`` rows, parent ``running_items=0`` after +drain) are verified end-to-end. +""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import UTC, datetime +from typing import Any + +import pytest +from httpx import AsyncClient +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from app.core.config import get_settings +from app.features.batch import runner +from app.features.batch.models import ( + BatchItemStatus, + BatchJob, + BatchJobItem, + BatchStatus, +) +from app.features.data_platform.models import Product, Store + +pytestmark = pytest.mark.integration + + +async def _seed_synthetic_batch( + db_session: AsyncSession, + *, + n_items: int, + max_parallel: int, + batch_id_prefix: str = "test_chaos", +) -> tuple[str, list[str]]: + """Insert a parent + N pending items directly (bypass scope expansion).""" + bid = f"{batch_id_prefix}_{uuid.uuid4().hex[:8]}" + batch = BatchJob( + batch_id=bid, + operation="backtest", + scope={"kind": "manual"}, + model_configs=[], + status=BatchStatus.RUNNING.value, + total_items=n_items, + params={}, + max_parallel=max_parallel, + ) + db_session.add(batch) + item_ids: list[str] = [] + for i in range(n_items): + iid = f"{bid}_i{i}" + item_ids.append(iid) + db_session.add( + BatchJobItem( + item_id=iid, + batch_id=bid, + store_id=1, + product_id=1, + model_type="naive", + status=BatchItemStatus.PENDING.value, + params={}, + ) + ) + await db_session.commit() + return bid, item_ids + + +async def test_cancel_mid_flight_does_not_orphan_running_items( + db_session: AsyncSession, +) -> None: + """A cancel mid-flight leaves no ``batch_job_item`` in RUNNING state. + + 4-item batch, max_parallel=2, slow synthetic children. After cancel: + - no items in ``running`` status + - ``batch_job.running_items`` is 0 + """ + bid, item_ids = await _seed_synthetic_batch(db_session, n_items=4, max_parallel=2) + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async def slow_exec(item_id: str) -> None: + async with session_maker() as s: + await s.execute( + update(BatchJobItem) + .where(BatchJobItem.item_id == item_id) + .values( + status=BatchItemStatus.RUNNING.value, + started_at=datetime.now(UTC), + ) + ) + await s.commit() + try: + await asyncio.sleep(2.0) + except asyncio.CancelledError: + raise + async with session_maker() as s: + await s.execute( + update(BatchJobItem) + .where(BatchJobItem.item_id == item_id) + .values( + status=BatchItemStatus.COMPLETED.value, + completed_at=datetime.now(UTC), + ) + ) + await s.commit() + + task = asyncio.create_task( + runner.run_batch( + batch_id=bid, + item_ids=item_ids, + max_parallel=2, + global_max_parallel=10, + session_maker=session_maker, + execute_item=slow_exec, + ) + ) + # Let the 2 max_parallel children acquire the semaphore + start work. + await asyncio.sleep(0.15) + fired = runner.cancel_batch(bid) + assert fired is True + await task + runner.mark_completed(bid) + await engine.dispose() + + # Verify no item left in RUNNING state. + rows = ( + (await db_session.execute(select(BatchJobItem).where(BatchJobItem.batch_id == bid))) + .scalars() + .all() + ) + statuses = [r.status for r in rows] + assert BatchItemStatus.RUNNING.value not in statuses, ( + f"orphaned RUNNING item(s) after cancel: {statuses}" + ) + # Every item should now be either cancelled or (rarely, if the cancel + # raced the completion update) completed; nothing else. + allowed = { + BatchItemStatus.CANCELLED.value, + BatchItemStatus.COMPLETED.value, + } + assert set(statuses) <= allowed, f"unexpected statuses: {statuses}" + + # Parent's running_items must be 0 post-drain. + parent = ( + await db_session.execute(select(BatchJob).where(BatchJob.batch_id == bid)) + ).scalar_one() + assert parent.running_items == 0 + # Cleanup for the conftest LIKE 'test%' DELETE. + await db_session.execute(delete(BatchJobItem).where(BatchJobItem.batch_id == bid)) + await db_session.execute(delete(BatchJob).where(BatchJob.batch_id == bid)) + await db_session.commit() + + +async def test_parent_status_progresses_as_children_complete( + client: AsyncClient, + sample_store: Store, + sample_products_3: list[Product], + sample_sales_120: list[Any], +) -> None: + """A 3-pair max_parallel=2 batch settles with running_items=0 + effective_max_parallel=2. + + Verifies the BatchService → runner → settle integration writes the + expected JSONB key the PRP-34 ``BatchSubmitResponse.effective_max_parallel`` + computed field resolves at response time. + """ + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [sample_store.id], + "product_ids": [p.id for p in sample_products_3], + }, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + "max_parallel": 2, + } + resp = await client.post("/batch/forecasting", json=payload) + assert resp.status_code == 202, resp.text + body = resp.json() + + assert body["status"] == "completed" + assert body["completed_items"] == 3 + assert body["running_items"] == 0 + assert body["effective_max_parallel"] == 2 + # The JSONB summary itself should carry the key too. + assert body["result_summary"]["effective_max_parallel"] == 2 diff --git a/frontend/package.json b/frontend/package.json index ab581121..3b66c9ed 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -23,6 +23,7 @@ "@radix-ui/react-scroll-area": "^1.2.10", "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-separator": "^1.1.8", + "@radix-ui/react-slider": "^1.3.6", "@radix-ui/react-slot": "^1.2.4", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-tooltip": "^1.2.8", @@ -63,6 +64,8 @@ "vitest": "^4.1.6" }, "pnpm": { - "onlyBuiltDependencies": ["esbuild"] + "onlyBuiltDependencies": [ + "esbuild" + ] } } diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 10705189..93b5677d 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -44,6 +44,9 @@ importers: '@radix-ui/react-separator': specifier: ^1.1.8 version: 1.1.8(@types/react-dom@19.2.3(@types/react@19.2.10))(@types/react@19.2.10)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@radix-ui/react-slider': + specifier: ^1.3.6 + version: 1.3.6(@types/react-dom@19.2.3(@types/react@19.2.10))(@types/react@19.2.10)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@radix-ui/react-slot': specifier: ^1.2.4 version: 1.2.4(@types/react@19.2.10)(react@19.2.4) @@ -910,6 +913,19 @@ packages: '@types/react-dom': optional: true + '@radix-ui/react-slider@1.3.6': + resolution: {integrity: sha512-JPYb1GuM1bxfjMRlNLE+BcmBC8onfCi60Blk7OBqi2MLTFdS+8401U4uFjnwkOr49BLmXxLC6JHkvAsx5OJvHw==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + '@radix-ui/react-slot@1.2.3': resolution: {integrity: sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==} peerDependencies: @@ -3270,6 +3286,25 @@ snapshots: '@types/react': 19.2.10 '@types/react-dom': 19.2.3(@types/react@19.2.10) + '@radix-ui/react-slider@1.3.6(@types/react-dom@19.2.3(@types/react@19.2.10))(@types/react@19.2.10)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + dependencies: + '@radix-ui/number': 1.1.1 + '@radix-ui/primitive': 1.1.3 + '@radix-ui/react-collection': 1.1.7(@types/react-dom@19.2.3(@types/react@19.2.10))(@types/react@19.2.10)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@radix-ui/react-compose-refs': 1.1.2(@types/react@19.2.10)(react@19.2.4) + '@radix-ui/react-context': 1.1.2(@types/react@19.2.10)(react@19.2.4) + '@radix-ui/react-direction': 1.1.1(@types/react@19.2.10)(react@19.2.4) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@19.2.3(@types/react@19.2.10))(@types/react@19.2.10)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@radix-ui/react-use-controllable-state': 1.2.2(@types/react@19.2.10)(react@19.2.4) + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@19.2.10)(react@19.2.4) + '@radix-ui/react-use-previous': 1.1.1(@types/react@19.2.10)(react@19.2.4) + '@radix-ui/react-use-size': 1.1.1(@types/react@19.2.10)(react@19.2.4) + react: 19.2.4 + react-dom: 19.2.4(react@19.2.4) + optionalDependencies: + '@types/react': 19.2.10 + '@types/react-dom': 19.2.3(@types/react@19.2.10) + '@radix-ui/react-slot@1.2.3(@types/react@19.2.10)(react@19.2.4)': dependencies: '@radix-ui/react-compose-refs': 1.1.2(@types/react@19.2.10)(react@19.2.4) diff --git a/frontend/src/components/ui/slider.tsx b/frontend/src/components/ui/slider.tsx new file mode 100644 index 00000000..013c938b --- /dev/null +++ b/frontend/src/components/ui/slider.tsx @@ -0,0 +1,61 @@ +import * as SliderPrimitive from "@radix-ui/react-slider" +import * as React from "react" + +import { cn } from "@/lib/utils" + +function Slider({ + className, + defaultValue, + value, + min = 0, + max = 100, + ...props +}: React.ComponentProps) { + const _values = React.useMemo( + () => + Array.isArray(value) + ? value + : Array.isArray(defaultValue) + ? defaultValue + : [min, max], + [value, defaultValue, min, max] + ) + + return ( + + + + + {Array.from({ length: _values.length }, (_, index) => ( + + ))} + + ) +} + +export { Slider } diff --git a/frontend/src/hooks/use-batches.test.ts b/frontend/src/hooks/use-batches.test.ts new file mode 100644 index 00000000..385a3b11 --- /dev/null +++ b/frontend/src/hooks/use-batches.test.ts @@ -0,0 +1,117 @@ +/** + * Unit tests for use-batches hooks (PRP-34 ``useCancelBatch``). + * + * Stubs ``fetch`` to assert the hook issues a DELETE and updates the + * TanStack Query cache; no real backend is exercised. + */ +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { createElement, type ReactNode } from 'react' + +import { useCancelBatch } from './use-batches' +import type { BatchSubmitResponse } from '@/types/api' + +function makeSettledBatch(batch_id: string): BatchSubmitResponse { + const now = '2026-05-25T00:00:00Z' + return { + batch_id, + operation: 'backtest', + status: 'cancelled', + total_items: 4, + completed_items: 0, + failed_items: 0, + running_items: 0, + cancelled_items: 4, + max_parallel: 4, + effective_max_parallel: 4, + started_at: now, + completed_at: now, + result_summary: { effective_max_parallel: 4 }, + created_at: now, + updated_at: now, + } +} + +function makeWrapper(client: QueryClient) { + return function Wrapper({ children }: { children: ReactNode }) { + return createElement( + QueryClientProvider, + { client }, + children, + ) + } +} + +afterEach(() => { + vi.unstubAllGlobals() +}) + +describe('useCancelBatch', () => { + it('issues a DELETE to /batch/{batchId} and updates the cache', async () => { + const settled = makeSettledBatch('batch_abc') + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify(settled), { + status: 200, + headers: { 'content-type': 'application/json' }, + }), + ) + vi.stubGlobal('fetch', fetchMock) + + const client = new QueryClient({ + defaultOptions: { queries: { retry: false } }, + }) + const { result } = renderHook(() => useCancelBatch(), { + wrapper: makeWrapper(client), + }) + + await act(async () => { + result.current.mutate('batch_abc') + }) + + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + + // Verifies the URL + HTTP method. + expect(fetchMock).toHaveBeenCalledTimes(1) + const call = fetchMock.mock.calls[0]! + expect(call[0]).toContain('/batch/batch_abc') + expect((call[1] as RequestInit).method).toBe('DELETE') + + // Mutation success writes the settled parent into the per-batch cache + // — ``useBatch(batchId)`` will read it immediately. + expect(client.getQueryData(['batch', 'batch_abc'])).toEqual(settled) + }) + + it('surfaces an RFC 7807 problem+json failure as ApiError on the mutation', async () => { + const problem = { + type: '/errors/conflict', + title: 'Conflict', + status: 409, + detail: 'Batch already terminal: completed', + code: 'CONFLICT', + } + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue( + new Response(JSON.stringify(problem), { + status: 409, + headers: { 'content-type': 'application/problem+json' }, + }), + ), + ) + + const client = new QueryClient({ + defaultOptions: { queries: { retry: false } }, + }) + const { result } = renderHook(() => useCancelBatch(), { + wrapper: makeWrapper(client), + }) + + await act(async () => { + result.current.mutate('batch_terminal') + }) + + await waitFor(() => expect(result.current.isError).toBe(true)) + expect(String(result.current.error)).toContain('Batch already terminal') + }) +}) diff --git a/frontend/src/hooks/use-batches.ts b/frontend/src/hooks/use-batches.ts index 66cd63b9..c98d8dc7 100644 --- a/frontend/src/hooks/use-batches.ts +++ b/frontend/src/hooks/use-batches.ts @@ -24,6 +24,21 @@ export function useSubmitBatch() { }) } +// Cancel an in-flight batch (PRP-34). Server-side semantics — 200 settled +// parent on clean drain, 404 if unknown, 409 if already terminal, 504 if +// the drain exceeded ``Settings.batch_cancel_drain_timeout_seconds``. +export function useCancelBatch() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (batchId: string) => + api(`/batch/${batchId}`, { method: 'DELETE' }), + onSuccess: (data) => { + queryClient.setQueryData(['batch', data.batch_id], data) + void queryClient.invalidateQueries({ queryKey: ['batch'] }) + }, + }) +} + // Get a batch's parent record. Polls every 2s while the run is in-flight; // stops polling once the parent settles to a terminal state. export function useBatch(batchId: string | null, enabled = true) { diff --git a/frontend/src/pages/visualize/batch.tsx b/frontend/src/pages/visualize/batch.tsx index 82b682af..1e921f60 100644 --- a/frontend/src/pages/visualize/batch.tsx +++ b/frontend/src/pages/visualize/batch.tsx @@ -1,13 +1,16 @@ /** - * Batch Runner — placeholder page (PRP-33 MVP). + * Batch Runner — PRP-34 (bounded concurrency + cooperative cancel). * - * Polls the parent batch status while in-flight and renders an items table. - * Per PRP narrowing: NO slider, NO cancel button, NO retry, NO heatmap, NO - * promotion panel — each downstream PRP owns one of those surfaces. + * Extends the PRP-33 placeholder with: + * - a max-parallel ``Slider`` on the submit form (PRP-34: activates + * ``batch_job.max_parallel`` — runtime-clamped server-side by + * ``Settings.batch_global_max_parallel``); + * - a ``running_items`` chip on the parent progress card; + * - a "Cancel batch" ``Button`` + confirmation ``AlertDialog`` that fires + * ``DELETE /batch/{batch_id}``. * - * MVP UX: a tiny submit form (manual scope only) + the live items table. - * The form is intentionally minimal — the agent / curl is the canonical - * driver in MVP; this page exists so the work is visible. + * Per PRP narrowing: still NO retry, NO heatmap, NO promotion panel — those + * are owned by their respective downstream PRPs. */ import { useState } from 'react' @@ -15,6 +18,18 @@ import { useState } from 'react' import { ErrorDisplay } from '@/components/common/error-display' import { LoadingState } from '@/components/common/loading-state' import { StatusBadge } from '@/components/common/status-badge' +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from '@/components/ui/alert-dialog' +import { Badge } from '@/components/ui/badge' import { Button } from '@/components/ui/button' import { Card, @@ -24,6 +39,7 @@ import { CardTitle, } from '@/components/ui/card' import { Input } from '@/components/ui/input' +import { Slider } from '@/components/ui/slider' import { Table, TableBody, @@ -32,8 +48,16 @@ import { TableHeader, TableRow, } from '@/components/ui/table' -import { useBatch, useBatchItems, useSubmitBatch } from '@/hooks/use-batches' -import type { BatchSubmitRequest } from '@/types/api' +import { + useBatch, + useBatchItems, + useCancelBatch, + useSubmitBatch, +} from '@/hooks/use-batches' +import { + TERMINAL_BATCH_STATES, + type BatchSubmitRequest, +} from '@/types/api' export default function BatchRunnerPage() { // Last-submitted batch the page tracks. null = nothing yet. @@ -45,11 +69,19 @@ export default function BatchRunnerPage() { const [productIds, setProductIds] = useState('1,2,3') const [startDate, setStartDate] = useState('2024-01-01') const [endDate, setEndDate] = useState('2024-04-29') + // PRP-34: per-batch parallelism request (server runtime-clamps by the + // global cap). Default matches the server's default of 4. + const [maxParallel, setMaxParallel] = useState(4) const submit = useSubmitBatch() + const cancel = useCancelBatch() const batch = useBatch(batchId) const items = useBatchItems({ batchId, pageSize: 50 }) + const isTerminal = batch.data + ? TERMINAL_BATCH_STATES.has(batch.data.status) + : false + function handleSubmit(e: React.FormEvent) { e.preventDefault() const parseIds = (s: string) => @@ -68,21 +100,26 @@ export default function BatchRunnerPage() { model_configs: [{ model_type: 'naive', params: {} }], start_date: startDate, end_date: endDate, + max_parallel: maxParallel, } submit.mutate(payload, { onSuccess: (data) => setBatchId(data.batch_id), }) } + function handleCancel() { + if (!batchId) return + cancel.mutate(batchId) + } + return (
-

Batch Runner (MVP)

+

Batch Runner

Submit a portfolio batch and watch its (store, product) items - execute sequentially. This is the PRP-33 placeholder — the - downstream PRPs add cancel, retry, priority, and the - champion/heatmap surface. + execute under a bounded concurrency cap. Cancel in-flight batches + cooperatively from the progress card.

@@ -91,11 +128,11 @@ export default function BatchRunnerPage() { Submit a manual backtest batch Comma-separated IDs; the runner fans out the cartesian product - and backtests each pair using the naive baseline. + and backtests each pair under the naive baseline. -
+ +
+
+ + Max parallel: {maxParallel} + + + effective = min(this, server cap) + +
+ { + const next = values[0] + if (typeof next === 'number') setMaxParallel(next) + }} + min={1} + max={8} + step={1} + aria-label="Max parallel" + /> +
+ + + + Cancel this batch? + + Pending items will be skipped; running items observe + the cancel at the next safe yield point. In-flight + model fits are uncancellable mid-call, so a long fit + may stall the drain. + + + + Keep running + + Cancel batch + + + + + )} +
+ {cancel.isError && ( +
+ +
)} @@ -183,7 +299,7 @@ export default function BatchRunnerPage() { - {item.metrics?.wape != null + {typeof item.metrics?.wape === 'number' ? item.metrics.wape.toFixed(3) : '—'} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 97a1798f..84a0f684 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -298,6 +298,17 @@ export type BatchStatus = | 'partial' | 'cancelled' +// Terminal parent statuses — mirrors ``TERMINAL_BATCH_STATES`` derived from +// the backend ``VALID_BATCH_TRANSITIONS`` state machine. Any UI that needs +// to gate on "this batch is settled" reads from here so the API and UI +// definitions cannot drift. +export const TERMINAL_BATCH_STATES: ReadonlySet = new Set([ + 'completed', + 'failed', + 'partial', + 'cancelled', +]) + export type BatchItemStatus = | 'pending' | 'running' @@ -352,6 +363,10 @@ export interface BatchSubmitResponse { failed_items: number running_items: number cancelled_items: number + max_parallel: number + // PRP-34: resolved server-side from + // result_summary.effective_max_parallel — 0 for legacy pre-PRP-34 rows. + effective_max_parallel: number started_at: string | null completed_at: string | null result_summary: Record | null diff --git a/uv.lock b/uv.lock index deb2cc36..121735b5 100644 --- a/uv.lock +++ b/uv.lock @@ -821,7 +821,7 @@ wheels = [ [[package]] name = "forecastlabai" -version = "0.2.17" +version = "0.2.18" source = { editable = "." } dependencies = [ { name = "alembic" },