From a3b54165b24da1aa2616108b2cc5e5c23d12953e Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 19:35:06 +0200 Subject: [PATCH 01/16] feat: build orchestration core and eval tooling --- .env.example | 12 + .github/workflows/ci.yml | 2 +- .gitignore | 4 + README.md | 46 +++- apps/api/README.md | 19 +- configs/models/catalog.yaml | 39 +++ configs/routing/policies.yaml | 4 +- data/state/.gitignore | 2 + evals/results/.gitignore | 2 + evals/scenarios/routing-basics.yaml | 22 ++ pyproject.toml | 12 + scripts/bootstrap/setup-local.ps1 | 1 + src/lai/__init__.py | 12 +- src/lai/api/__init__.py | 3 + src/lai/api/app.py | 112 ++++++++ src/lai/application.py | 63 +++++ src/lai/artifacts.py | 41 +++ src/lai/cli.py | 271 +++++++++++++++++- src/lai/config.py | 72 +++++ src/lai/domain.py | 317 +++++++++++++++++++++ src/lai/errors.py | 10 + src/lai/evals.py | 167 +++++++++++ src/lai/jobs/__init__.py | 2 +- src/lai/jobs/service.py | 239 ++++++++++++++++ src/lai/jobs/store.py | 216 +++++++++++++++ src/lai/layout.py | 10 + src/lai/providers/__init__.py | 5 +- src/lai/providers/base.py | 93 +++++++ src/lai/providers/implementations.py | 335 ++++++++++++++++++++++ src/lai/providers/registry.py | 48 ++++ src/lai/routing/__init__.py | 4 +- src/lai/routing/engine.py | 351 ++++++++++++++++++++++++ src/lai/routing/heuristics.py | 75 +++++ src/lai/serialization.py | 24 ++ src/lai/settings.py | 77 +++++- src/lai/system.py | 67 +++++ tests/__init__.py | 1 + tests/conftest.py | 12 + tests/helpers.py | 57 ++++ tests/integration/test_orchestration.py | 64 +++++ tests/live/test_local_smoke.py | 74 +++++ tests/live/test_provider_smoke.py | 92 +++++++ tests/unit/test_api.py | 13 + tests/unit/test_config.py | 13 + tests/unit/test_evals.py | 29 ++ tests/unit/test_layout.py | 2 +- tests/unit/test_routing.py | 60 ++++ 47 files changed, 3175 insertions(+), 21 deletions(-) create mode 100644 data/state/.gitignore create mode 100644 evals/results/.gitignore create mode 100644 evals/scenarios/routing-basics.yaml create mode 100644 src/lai/api/__init__.py create mode 100644 src/lai/api/app.py create mode 100644 src/lai/application.py create mode 100644 src/lai/artifacts.py create mode 100644 src/lai/config.py create mode 100644 src/lai/domain.py create mode 100644 src/lai/errors.py create mode 100644 src/lai/evals.py create mode 100644 src/lai/jobs/service.py create mode 100644 src/lai/jobs/store.py create mode 100644 src/lai/providers/base.py create mode 100644 src/lai/providers/implementations.py create mode 100644 src/lai/providers/registry.py create mode 100644 src/lai/routing/engine.py create mode 100644 src/lai/routing/heuristics.py create mode 100644 src/lai/serialization.py create mode 100644 src/lai/system.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/helpers.py create mode 100644 tests/integration/test_orchestration.py create mode 100644 tests/live/test_local_smoke.py create mode 100644 tests/live/test_provider_smoke.py create mode 100644 tests/unit/test_api.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_evals.py create mode 100644 tests/unit/test_routing.py diff --git a/.env.example b/.env.example index 8986f00..0415ee7 100644 --- a/.env.example +++ b/.env.example @@ -1,12 +1,24 @@ LAI_ENV=local LAI_HF_TOKEN= +LAI_OPENAI_API_KEY= +LAI_ANTHROPIC_API_KEY= +LAI_GEMINI_API_KEY= LAI_MODEL_CATALOG=configs/models/catalog.yaml LAI_ROUTING_POLICY=configs/routing/policies.yaml LAI_PROMPT_ROOT=configs/prompts 
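+# Directory paths below are resolved relative to the repository root.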
LAI_HUGGINGFACE_CACHE_DIR=data/cache/huggingface LAI_AIRLLM_SHARDS_DIR=data/models/airllm-shards +LAI_RAW_MODELS_DIR=data/models/raw LAI_ARTIFACTS_DIR=data/artifacts +LAI_STATE_DIR=data/state +LAI_DATABASE_PATH=data/state/lai.db LAI_LOGS_DIR=logs LAI_ALLOW_OVERNIGHT_JOBS=true LAI_ENABLE_GPU=true LAI_MAX_ROUTER_TOKENS=2048 +LAI_DEFAULT_TIMEOUT_SECONDS=120 +LAI_DEFAULT_MAX_OUTPUT_TOKENS=1024 +LAI_DEFAULT_TEMPERATURE=0.2 +LAI_QUEUE_POLL_INTERVAL_SECONDS=5 +LAI_WORKER_IDLE_SLEEP_SECONDS=2.0 +LAI_MAX_RETRY_ATTEMPTS=1 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0bfdf76..b68fa21 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: - name: Install package run: | python -m pip install --upgrade pip - python -m pip install -e .[dev] + python -m pip install -e .[dev,api] - name: Ruff run: ruff check . diff --git a/.gitignore b/.gitignore index 61aa09e..2e9020a 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,9 @@ build/ data/cache/ data/models/ data/artifacts/ +data/state/* +!data/state/.gitignore +evals/results/* +!evals/results/.gitignore logs/* !logs/.gitignore diff --git a/README.md b/README.md index 9cf8da1..7eeb2da 100644 --- a/README.md +++ b/README.md @@ -78,18 +78,48 @@ cd LAI py -3.11 -m venv .venv .venv\Scripts\Activate.ps1 python -m pip install --upgrade pip -python -m pip install -e .[dev] +python -m pip install -e .[dev,api] Copy-Item .env.example .env python -m lai.cli doctor ``` -To add the large-model runtime later: +To add the local heavy-model and provider backends later: ```powershell -python -m pip install -e .[dev,api] -python -m pip install airllm +python -m pip install -e .[local,providers] +``` + +## Current commands + +```powershell +python -m lai.cli doctor +python -m lai.cli models list +python -m lai.cli models check +python -m lai.cli route explain "Summarize this note." +python -m lai.cli run "Create a detailed implementation strategy." +python -m lai.cli jobs list +python -m lai.cli worker run --once +python -m lai.cli eval route --no-save ``` +## Current API + +After installing the `api` extra: + +```powershell +uvicorn lai.api.app:create_api --factory --reload +``` + +Available endpoints: + +- `GET /health` +- `GET /models` +- `POST /route/explain` +- `POST /jobs` +- `GET /jobs` +- `GET /jobs/{job_id}` +- `POST /jobs/{job_id}/cancel` + ## Initial GitHub rules encoded in this repo - Pull request template and issue forms for consistent planning. @@ -100,10 +130,10 @@ python -m pip install airllm ## Near-term priorities -1. Implement the model registry and routing engine under `src/lai/`. -2. Add the first AirLLM runtime adapter and smoke-test workflows. -3. Introduce an API surface in `apps/api`. -4. Add evaluation scenarios that compare small-model routing against large-model final execution. +1. Add live provider smoke tests behind credentials and optional extras. +2. Harden the AirLLM local runtime path with real workstation validation. +3. Expand eval scenarios and richer reviewer/final-output refinement. +4. Add the web dashboard on top of the persisted job and artifact store. ## References diff --git a/apps/api/README.md b/apps/api/README.md index a58fdc9..fa2d909 100644 --- a/apps/api/README.md +++ b/apps/api/README.md @@ -1,3 +1,20 @@ # API App -This folder is reserved for the future control-plane API. It will expose request submission, job status, artifact retrieval, and model availability endpoints. 
+The first API surface lives in `src/lai/api/app.py` and exposes the same orchestration core as the CLI.
+
+Current endpoints:
+
+- `GET /health`
+- `GET /models`
+- `POST /route/explain`
+- `POST /jobs`
+- `GET /jobs`
+- `GET /jobs/{job_id}`
+- `POST /jobs/{job_id}/cancel`
+
+Run it after installing the `api` extra:
+
+```powershell
+python -m pip install -e .[dev,api]
+uvicorn lai.api.app:create_api --factory --reload
+```
diff --git a/configs/models/catalog.yaml b/configs/models/catalog.yaml
index 0eaa696..7c1b248 100644
--- a/configs/models/catalog.yaml
+++ b/configs/models/catalog.yaml
@@ -52,3 +52,42 @@ models:
       allow_layer_sharding: true
       allow_prefetching: true
       allow_cpu_fallback: true
+
+  - id: openai-general
+    role: executor
+    runtime: openai
+    model: gpt-5.4-mini
+    context_window: 128000
+    capabilities:
+      - planning
+      - summarization
+      - critique
+      - validation
+      - deep-reasoning
+      - long-form-generation
+
+  - id: anthropic-general
+    role: executor
+    runtime: anthropic
+    model: claude-sonnet-4-20250514
+    context_window: 200000
+    capabilities:
+      - planning
+      - summarization
+      - critique
+      - validation
+      - deep-reasoning
+      - long-form-generation
+
+  - id: gemini-general
+    role: executor
+    runtime: gemini
+    model: gemini-2.5-flash
+    context_window: 1000000
+    capabilities:
+      - planning
+      - summarization
+      - critique
+      - validation
+      - deep-reasoning
+      - long-form-generation
diff --git a/configs/routing/policies.yaml b/configs/routing/policies.yaml
index 73f6749..be9d2a8 100644
--- a/configs/routing/policies.yaml
+++ b/configs/routing/policies.yaml
@@ -42,5 +42,5 @@ tiers:
 fallbacks:
   when_gpu_unavailable:
     planner_model_id: router-small
-    executor_model_id: verifier-medium
-    reviewer_model_id: verifier-medium
+    executor_model_id: openai-general
+    reviewer_model_id: openai-general
diff --git a/data/state/.gitignore b/data/state/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/data/state/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/evals/results/.gitignore b/evals/results/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/evals/results/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/evals/scenarios/routing-basics.yaml b/evals/scenarios/routing-basics.yaml
new file mode 100644
index 0000000..56e86ad
--- /dev/null
+++ b/evals/scenarios/routing-basics.yaml
@@ -0,0 +1,22 @@
+version: 1
+scenarios:
+  - id: instant-summary
+    description: Short prompts should route to the instant tier.
+    prompt: Summarize this short note in two bullets.
+    expected_tier: instant
+
+  - id: deep-work-architecture
+    description: Comprehensive architecture requests should route to deep-work.
+    prompt: Create a comprehensive advanced architecture and complete implementation strategy.
+    expected_tier: deep-work
+
+  - id: provider-fallback
+    description: If the preferred local heavy model is unavailable, routing should fall back to a hosted executor.
+    prompt: Produce a detailed long-form research answer with strong quality.
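+    # unavailable_providers marks those backends as down for this scenario; routing must then pick an executor from expected_executor_in.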
+ expected_tier: deep-work + unavailable_providers: + - airllm + expected_executor_in: + - openai-general + - anthropic-general + - gemini-general diff --git a/pyproject.toml b/pyproject.toml index 28044d6..6156707 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,18 @@ api = [ "fastapi>=0.116,<1.0", "uvicorn[standard]>=0.35,<1.0", ] +local = [ + "accelerate>=1.9,<2.0", + "airllm>=2.0,<3.0", + "huggingface-hub>=0.34,<1.0", + "torch>=2.7,<3.0", + "transformers>=4.54,<5.0", +] +providers = [ + "anthropic>=0.60,<1.0", + "google-genai>=1.30,<2.0", + "openai>=2.0,<3.0", +] dev = [ "httpx>=0.28,<1.0", "pytest>=8.3,<9.0", diff --git a/scripts/bootstrap/setup-local.ps1 b/scripts/bootstrap/setup-local.ps1 index 4103920..cfdafcc 100644 --- a/scripts/bootstrap/setup-local.ps1 +++ b/scripts/bootstrap/setup-local.ps1 @@ -5,6 +5,7 @@ $directories = @( "data/models/airllm-shards", "data/models/raw", "data/artifacts", + "data/state", "logs" ) diff --git a/src/lai/__init__.py b/src/lai/__init__.py index 0f0a87b..601a203 100644 --- a/src/lai/__init__.py +++ b/src/lai/__init__.py @@ -1,5 +1,15 @@ """LAI core package.""" +from .config import AppConfig, load_app_config +from .domain import ExecutionRequest, JobStatus, QueueMode, RoutingDecision from .settings import Settings -__all__ = ["Settings"] +__all__ = [ + "AppConfig", + "ExecutionRequest", + "JobStatus", + "QueueMode", + "RoutingDecision", + "Settings", + "load_app_config", +] diff --git a/src/lai/api/__init__.py b/src/lai/api/__init__.py new file mode 100644 index 0000000..d0c81a0 --- /dev/null +++ b/src/lai/api/__init__.py @@ -0,0 +1,3 @@ +from .app import create_api + +__all__ = ["create_api"] diff --git a/src/lai/api/app.py b/src/lai/api/app.py new file mode 100644 index 0000000..5780bd6 --- /dev/null +++ b/src/lai/api/app.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field + +from ..application import create_application +from ..domain import ExecutionRequest, QueueMode + + +class JobCreatePayload(BaseModel): + user_prompt: str + system_prompt: str | None = None + queue_mode: QueueMode = QueueMode.AUTO + model_override: str | None = None + provider_override: str | None = None + temperature: float | None = Field(default=None, ge=0) + max_output_tokens: int | None = Field(default=None, ge=1) + timeout_seconds: int | None = Field(default=None, ge=1) + reviewer_enabled: bool | None = None + + +def create_api() -> FastAPI: + api = FastAPI(title="LAI API", version="0.1.0") + + @api.get("/health") + def health() -> dict[str, object]: + application = create_application() + return { + "status": "ok", + "environment": application.settings.environment, + "database": str(application.settings.resolved_database_path), + "model_count": len(application.config.model_catalog.models), + } + + @api.get("/models") + def list_models() -> dict[str, object]: + application = create_application() + healthchecks = application.provider_registry.model_healthchecks(application.config) + return { + "models": [ + { + "id": model.id, + "role": model.role, + "runtime": model.runtime, + "model_ref": model.model_ref, + "capabilities": model.capabilities, + "health": healthchecks[model.id].model_dump(), + } + for model in application.config.model_catalog.models + ] + } + + @api.post("/route/explain") + def explain_route(payload: JobCreatePayload) -> dict[str, object]: + application = create_application() + request = _build_execution_request(application, payload) + return 
application.routing_engine.route(request).model_dump() + + @api.post("/jobs") + def create_job(payload: JobCreatePayload) -> dict[str, object]: + application = create_application() + request = _build_execution_request(application, payload) + return application.orchestration.submit_request(request).model_dump() + + @api.get("/jobs") + def list_jobs(limit: int = 20) -> dict[str, object]: + application = create_application() + return {"jobs": [job.model_dump() for job in application.job_store.list_jobs(limit=limit)]} + + @api.get("/jobs/{job_id}") + def get_job(job_id: str) -> dict[str, object]: + application = create_application() + job = application.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return job.model_dump() + + @api.post("/jobs/{job_id}/cancel") + def cancel_job(job_id: str) -> dict[str, object]: + application = create_application() + if not application.job_store.cancel_job(job_id): + raise HTTPException(status_code=404, detail="Job not found or not cancelable") + job = application.job_store.get_job(job_id) + return {"job": job.model_dump() if job else None} + + return api + + +def _build_execution_request(application, payload: JobCreatePayload) -> ExecutionRequest: + return ExecutionRequest( + system_prompt=payload.system_prompt, + user_prompt=payload.user_prompt, + queue_mode=payload.queue_mode, + model_override=payload.model_override, + provider_override=payload.provider_override, + reviewer_enabled=payload.reviewer_enabled, + temperature=( + payload.temperature + if payload.temperature is not None + else application.settings.default_temperature + ), + max_output_tokens=( + payload.max_output_tokens + if payload.max_output_tokens is not None + else application.settings.default_max_output_tokens + ), + timeout_seconds=( + payload.timeout_seconds + if payload.timeout_seconds is not None + else application.settings.default_timeout_seconds + ), + ) diff --git a/src/lai/application.py b/src/lai/application.py new file mode 100644 index 0000000..237c656 --- /dev/null +++ b/src/lai/application.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping + +from .artifacts import ArtifactManager +from .config import AppConfig, load_app_config +from .jobs.service import OrchestrationService +from .jobs.store import JobStore +from .layout import ensure_runtime_directories +from .providers import Provider, ProviderRegistry +from .routing import RoutingEngine +from .settings import Settings +from .system import collect_system_snapshot + + +@dataclass +class LAIApplication: + settings: Settings + config: AppConfig + provider_registry: ProviderRegistry + routing_engine: RoutingEngine + job_store: JobStore + artifacts: ArtifactManager + orchestration: OrchestrationService + + +def create_application( + settings: Settings | None = None, + providers: Mapping[str, Provider] | None = None, +) -> LAIApplication: + settings = settings or Settings() + ensure_runtime_directories(settings.root_dir) + config = load_app_config(settings.resolved_model_catalog, settings.resolved_routing_policy) + system_snapshot = collect_system_snapshot( + settings.resolved_huggingface_cache_dir, + settings.resolved_airllm_shards_dir, + settings.resolved_artifacts_dir, + settings.resolved_state_dir, + enable_gpu=settings.enable_gpu, + ) + registry = ProviderRegistry(settings, system_snapshot, providers=providers) + job_store = JobStore(settings.resolved_database_path) + job_store.initialize() + 
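+    # The artifact manager writes per-job files under the artifacts directory and records each one in the job store.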
artifacts = ArtifactManager(settings.resolved_artifacts_dir, job_store) + routing_engine = RoutingEngine(settings, config, registry, system_snapshot) + orchestration = OrchestrationService( + settings=settings, + config=config, + provider_registry=registry, + routing_engine=routing_engine, + job_store=job_store, + artifacts=artifacts, + ) + return LAIApplication( + settings=settings, + config=config, + provider_registry=registry, + routing_engine=routing_engine, + job_store=job_store, + artifacts=artifacts, + orchestration=orchestration, + ) diff --git a/src/lai/artifacts.py b/src/lai/artifacts.py new file mode 100644 index 0000000..275e008 --- /dev/null +++ b/src/lai/artifacts.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pathlib import Path +from uuid import uuid4 + +from .domain import ArtifactRecord +from .jobs.store import JobStore +from .serialization import dumps_pretty + + +class ArtifactManager: + def __init__(self, artifacts_root: Path, job_store: JobStore) -> None: + self.artifacts_root = artifacts_root + self.job_store = job_store + self.artifacts_root.mkdir(parents=True, exist_ok=True) + + def write_json( + self, job_id: str, artifact_type: str, filename: str, payload: object + ) -> ArtifactRecord: + return self._write(job_id, artifact_type, filename, dumps_pretty(payload)) + + def write_text( + self, job_id: str, artifact_type: str, filename: str, payload: str + ) -> ArtifactRecord: + return self._write(job_id, artifact_type, filename, payload) + + def _write( + self, job_id: str, artifact_type: str, filename: str, payload: str + ) -> ArtifactRecord: + job_dir = self.artifacts_root / job_id + job_dir.mkdir(parents=True, exist_ok=True) + path = job_dir / filename + path.write_text(payload, encoding="utf-8") + artifact = ArtifactRecord( + id=str(uuid4()), + job_id=job_id, + artifact_type=artifact_type, + relative_path=str(path.relative_to(self.artifacts_root)), + ) + self.job_store.add_artifact(artifact) + return artifact diff --git a/src/lai/cli.py b/src/lai/cli.py index 49fb7d4..3694bb2 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -1,13 +1,30 @@ +from __future__ import annotations + from pathlib import Path import typer from rich.console import Console +from rich.panel import Panel from rich.table import Table +from .application import create_application +from .domain import ExecutionRequest, QueueMode +from .evals import run_route_eval_suite, save_route_eval_result from .layout import runtime_directories from .settings import Settings +from .system import collect_system_snapshot -app = typer.Typer(help="Utilities for the LAI repository scaffold.", no_args_is_help=True) +app = typer.Typer(help="Utilities for the LAI orchestration platform.", no_args_is_help=True) +models_app = typer.Typer(help="Inspect model definitions and provider readiness.") +route_app = typer.Typer(help="Explain routing decisions without executing.") +jobs_app = typer.Typer(help="Inspect and manage persistent jobs.") +worker_app = typer.Typer(help="Run the local job worker.") +eval_app = typer.Typer(help="Run evaluation suites against routing behavior.") +app.add_typer(models_app, name="models") +app.add_typer(route_app, name="route") +app.add_typer(jobs_app, name="jobs") +app.add_typer(worker_app, name="worker") +app.add_typer(eval_app, name="eval") console = Console() @@ -18,8 +35,12 @@ def main() -> None: @app.command() def doctor() -> None: - """Print the expected repository configuration and runtime layout.""" + """Print repository, config, and system readiness.""" settings 
= Settings() + snapshot = collect_system_snapshot( + *runtime_directories(settings.root_dir), enable_gpu=settings.enable_gpu + ) + application = create_application(settings=settings) table = Table(title="LAI Workspace Doctor") table.add_column("Item") @@ -29,6 +50,10 @@ def doctor() -> None: table.add_row("Root", str(settings.root_dir)) table.add_row("Model catalog", _status_line(settings.resolved_model_catalog)) table.add_row("Routing policy", _status_line(settings.resolved_routing_policy)) + table.add_row("Database", _status_line(settings.resolved_database_path)) + table.add_row("GPU available", "yes" if snapshot.has_gpu else "no") + table.add_row("GPU name", snapshot.gpu_name or "n/a") + table.add_row("Catalog models", str(len(application.config.model_catalog.models))) for path in runtime_directories(settings.root_dir): table.add_row("Runtime path", _status_line(path)) @@ -36,6 +61,248 @@ def doctor() -> None: console.print(table) +@models_app.command("list") +def list_models() -> None: + """List catalog models and their main properties.""" + application = create_application() + table = Table(title="LAI Models") + table.add_column("Id") + table.add_column("Role") + table.add_column("Runtime") + table.add_column("Model Ref") + table.add_column("Capabilities") + + for model in application.config.model_catalog.models: + table.add_row( + model.id, + model.role, + model.runtime, + model.model_ref, + ", ".join(model.capabilities), + ) + console.print(table) + + +@models_app.command("check") +def check_models() -> None: + """Check provider availability and local constraints for each model.""" + application = create_application() + healthchecks = application.provider_registry.model_healthchecks(application.config) + + table = Table(title="LAI Model Health") + table.add_column("Id") + table.add_column("Provider") + table.add_column("Available") + table.add_column("Healthy") + table.add_column("Details") + + for model in application.config.model_catalog.models: + health = healthchecks[model.id] + table.add_row( + model.id, + model.provider_id, + "yes" if health.available else "no", + "yes" if health.healthy else "no", + "; ".join(health.reasons) if health.reasons else "ready", + ) + console.print(table) + + +@route_app.command("explain") +def explain_route( + prompt: str = typer.Argument(..., help="The user request to classify and route."), + system_prompt: str | None = typer.Option(None, help="Optional system prompt."), + model_override: str | None = typer.Option(None, help="Force a specific model id."), + provider_override: str | None = typer.Option(None, help="Force a specific provider id."), + no_review: bool = typer.Option(False, "--no-review", help="Disable the reviewer stage."), +) -> None: + """Explain how LAI would route a request.""" + application = create_application() + request = ExecutionRequest( + system_prompt=system_prompt, + user_prompt=prompt, + model_override=model_override, + provider_override=provider_override, + reviewer_enabled=False if no_review else None, + ) + decision = application.routing_engine.route(request) + + summary = Table(title="Routing Decision") + summary.add_column("Field") + summary.add_column("Value") + summary.add_row("Tier", decision.matched_tier_id) + summary.add_row("Planner", decision.planner_model_id or "none") + summary.add_row("Executor", decision.executor_model_id) + summary.add_row("Reviewer", decision.reviewer_model_id or "none") + summary.add_row("Queue recommended", "yes" if decision.queue_recommended else "no") + summary.add_row("Fallbacks", 
", ".join(decision.fallback_model_ids) or "none") + console.print(summary) + + reasons = Table(title="Routing Reasons") + reasons.add_column("Stage") + reasons.add_column("Reason") + for reason in decision.reasons: + reasons.add_row(reason.stage, reason.message) + console.print(reasons) + + +@app.command() +def run( + prompt: str = typer.Argument(..., help="The user request to execute."), + system_prompt: str | None = typer.Option(None, help="Optional system prompt."), + queue_mode: QueueMode = typer.Option( + QueueMode.AUTO, help="Run inline, queued, or let LAI decide." + ), + model_override: str | None = typer.Option(None, help="Force a specific model id."), + provider_override: str | None = typer.Option(None, help="Force a specific provider id."), + temperature: float | None = typer.Option(None, help="Override generation temperature."), + max_output_tokens: int | None = typer.Option(None, help="Override output token budget."), + timeout_seconds: int | None = typer.Option(None, help="Override request timeout."), + no_review: bool = typer.Option(False, "--no-review", help="Disable the reviewer stage."), +) -> None: + """Submit a request for inline or queued execution.""" + application = create_application() + request = ExecutionRequest( + system_prompt=system_prompt, + user_prompt=prompt, + queue_mode=queue_mode, + model_override=model_override, + provider_override=provider_override, + reviewer_enabled=False if no_review else None, + temperature=temperature + if temperature is not None + else application.settings.default_temperature, + max_output_tokens=( + max_output_tokens + if max_output_tokens is not None + else application.settings.default_max_output_tokens + ), + timeout_seconds=( + timeout_seconds + if timeout_seconds is not None + else application.settings.default_timeout_seconds + ), + ) + job = application.orchestration.submit_request(request) + + if job.queue_mode == QueueMode.QUEUED and job.status == "queued": + console.print(Panel.fit(f"Queued job {job.id}", title="LAI Run")) + return + + console.print( + Panel.fit(job.result.text if job.result else "No output produced.", title=f"Job {job.id}") + ) + + +@jobs_app.command("list") +def list_jobs(limit: int = typer.Option(20, min=1, max=100, help="Maximum jobs to list.")) -> None: + """List persisted jobs.""" + application = create_application() + jobs = application.job_store.list_jobs(limit=limit) + table = Table(title="LAI Jobs") + table.add_column("Id") + table.add_column("Status") + table.add_column("Queue Mode") + table.add_column("Created") + table.add_column("Executor") + + for job in jobs: + executor = job.route_decision.executor_model_id if job.route_decision else "n/a" + table.add_row(job.id, job.status, job.queue_mode, job.created_at.isoformat(), executor) + console.print(table) + + +@jobs_app.command("show") +def show_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: + """Show a job and its latest output.""" + application = create_application() + job = application.job_store.get_job(job_id) + if job is None: + raise typer.Exit(code=1) + + summary = Table(title=f"Job {job.id}") + summary.add_column("Field") + summary.add_column("Value") + summary.add_row("Status", job.status) + summary.add_row("Queue mode", job.queue_mode) + summary.add_row("Attempts", str(job.attempts)) + summary.add_row("Tier", job.route_decision.matched_tier_id if job.route_decision else "n/a") + summary.add_row( + "Executor", job.route_decision.executor_model_id if job.route_decision else "n/a" + ) + 
+    summary.add_row("Artifacts", str(len(job.artifacts)))
+    console.print(summary)
+
+    if job.result:
+        console.print(Panel.fit(job.result.text, title="Latest Output"))
+    if job.error:
+        console.print(Panel.fit(job.error.message, title="Error"))
+
+
+@jobs_app.command("cancel")
+def cancel_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None:
+    """Cancel a queued or running job."""
+    application = create_application()
+    if application.job_store.cancel_job(job_id):
+        console.print(f"Canceled job {job_id}")
+        return
+    console.print(f"Job {job_id} not found or not cancelable.")
+    raise typer.Exit(code=1)
+
+
+@worker_app.command("run")
+def run_worker(
+    once: bool = typer.Option(False, help="Process at most one queued job and then exit."),
+) -> None:
+    """Run the local worker for queued jobs."""
+    application = create_application()
+    processed = application.orchestration.run_worker(once=once)
+    console.print(f"Processed {processed} queued job(s).")
+
+
+@eval_app.command("route")
+def eval_route(
+    scenario_file: Path = typer.Option(
+        Path("evals/scenarios/routing-basics.yaml"),
+        exists=True,
+        file_okay=True,
+        dir_okay=False,
+        help="YAML file describing route evaluation scenarios.",
+    ),
+    save: bool = typer.Option(
+        True,
+        "--save/--no-save",
+        help="Save the evaluation result under evals/results.",
+    ),
+) -> None:
+    """Run the route evaluation suite and report pass/fail."""
+    settings = Settings()
+    result = run_route_eval_suite(settings, settings.root_dir / scenario_file)
+
+    table = Table(title="LAI Route Eval")
+    table.add_column("Scenario")
+    table.add_column("Status")
+    table.add_column("Expected Tier")
+    table.add_column("Actual Tier")
+    table.add_column("Executor")
+
+    for case in result.cases:
+        table.add_row(
+            case.scenario_id,
+            "PASS" if case.passed else "FAIL",
+            case.expected_tier,
+            case.actual_tier,
+            case.actual_executor,
+        )
+    console.print(table)
+
+    if save:
+        result_path = save_route_eval_result(settings.root_dir / "evals/results", result)
+        console.print(f"Saved evaluation result to {result_path}")
+
+    if not result.passed:
+        raise typer.Exit(code=1)
+
+
 def _status_line(path: Path) -> str:
     suffix = "present" if path.exists() else "missing"
     return f"{path} ({suffix})"
diff --git a/src/lai/config.py b/src/lai/config.py
new file mode 100644
index 0000000..5b7b552
--- /dev/null
+++ b/src/lai/config.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import yaml
+from pydantic import BaseModel
+
+from .domain import ModelCatalog, RoutingPolicy
+
+
+class AppConfig(BaseModel):
+    model_catalog: ModelCatalog
+    routing_policy: RoutingPolicy
+
+
+def _load_yaml_file(path: Path) -> dict[str, Any]:
+    if not path.exists():
+        raise FileNotFoundError(f"Configuration file not found: {path}")
+
+    with path.open("r", encoding="utf-8") as handle:
+        data = yaml.safe_load(handle) or {}
+    if not isinstance(data, dict):
+        raise ValueError(f"Configuration file must contain a mapping at the top level: {path}")
+    return data
+
+
+def load_model_catalog(path: Path) -> ModelCatalog:
+    return ModelCatalog.model_validate(_load_yaml_file(path))
+
+
+def load_routing_policy(path: Path) -> RoutingPolicy:
+    return RoutingPolicy.model_validate(_load_yaml_file(path))
+
+
+def load_app_config(model_catalog_path: Path, routing_policy_path: Path) -> AppConfig:
+    config = AppConfig(
+        model_catalog=load_model_catalog(model_catalog_path),
+        routing_policy=load_routing_policy(routing_policy_path),
+    )
+    validate_config_references(config)
+    return config
+
+
+def
validate_config_references(config: AppConfig) -> None: + model_ids = {model.id for model in config.model_catalog.models} + + if config.routing_policy.router.model_id not in model_ids: + router_model_id = config.routing_policy.router.model_id + raise ValueError( + f"Router model id {router_model_id!r} is not present in the catalog." + ) + + for tier in config.routing_policy.tiers: + _assert_model_reference(model_ids, tier.executor_model_id, f"tier {tier.id} executor") + if tier.planner_model_id: + _assert_model_reference(model_ids, tier.planner_model_id, f"tier {tier.id} planner") + if tier.reviewer_model_id: + _assert_model_reference(model_ids, tier.reviewer_model_id, f"tier {tier.id} reviewer") + + fallback = config.routing_policy.fallbacks.when_gpu_unavailable + if fallback: + _assert_model_reference(model_ids, fallback.executor_model_id, "GPU fallback executor") + if fallback.planner_model_id: + _assert_model_reference(model_ids, fallback.planner_model_id, "GPU fallback planner") + if fallback.reviewer_model_id: + _assert_model_reference(model_ids, fallback.reviewer_model_id, "GPU fallback reviewer") + + +def _assert_model_reference(model_ids: set[str], model_id: str, context: str) -> None: + if model_id not in model_ids: + raise ValueError(f"{context} references unknown model id {model_id!r}.") diff --git a/src/lai/domain.py b/src/lai/domain.py new file mode 100644 index 0000000..2f9ad7a --- /dev/null +++ b/src/lai/domain.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + + +def utcnow() -> datetime: + return datetime.now(tz=timezone.utc) + + +class LAIModel(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + use_enum_values=True, + extra="forbid", + ) + + +class ModelRole(str, Enum): + ROUTER = "router" + REVIEWER = "reviewer" + EXECUTOR = "executor" + PLANNER = "planner" + + +class ModelRuntime(str, Enum): + TRANSFORMERS = "transformers" + AIRLLM = "airllm" + OPENAI = "openai" + ANTHROPIC = "anthropic" + GEMINI = "gemini" + + +class QueueMode(str, Enum): + AUTO = "auto" + INLINE = "inline" + QUEUED = "queued" + + +class JobStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELED = "canceled" + + +class FinishReason(str, Enum): + STOP = "stop" + LENGTH = "length" + ERROR = "error" + CANCELED = "canceled" + UNKNOWN = "unknown" + + +class ModelHardwareSpec(LAIModel): + preferred_device: str = "cpu" + minimum_vram_gb: int | None = None + recommended_vram_gb: int | None = None + recommended_ram_gb: int | None = None + expected_disk_gb: int | None = None + + +class ModelRuntimeHints(LAIModel): + allow_layer_sharding: bool = False + allow_prefetching: bool = False + allow_cpu_fallback: bool = False + + +class ModelSpec(LAIModel): + id: str + role: ModelRole + runtime: ModelRuntime + model_ref: str = Field(validation_alias=AliasChoices("model_ref", "repo_id", "model")) + context_window: int = 8192 + capabilities: list[str] = Field(default_factory=list) + hardware: ModelHardwareSpec = Field(default_factory=ModelHardwareSpec) + runtime_hints: ModelRuntimeHints = Field(default_factory=ModelRuntimeHints) + enabled: bool = True + description: str | None = None + + @property + def provider_id(self) -> str: + if isinstance(self.runtime, Enum): + return self.runtime.value + return str(self.runtime) + + def supports_capabilities(self, 
required_capabilities: list[str]) -> bool: + return set(required_capabilities).issubset(set(self.capabilities)) + + +class CatalogDefaults(LAIModel): + allow_gated_models: bool = True + require_disk_headroom_gb: int = 50 + require_healthcheck_before_enable: bool = True + + +class ModelCatalog(LAIModel): + version: int + defaults: CatalogDefaults = Field(default_factory=CatalogDefaults) + models: list[ModelSpec] + + @model_validator(mode="after") + def validate_unique_model_ids(self) -> "ModelCatalog": + model_ids = [model.id for model in self.models] + duplicates = {model_id for model_id in model_ids if model_ids.count(model_id) > 1} + if duplicates: + raise ValueError(f"Duplicate model ids in catalog: {sorted(duplicates)}") + return self + + def get_model(self, model_id: str) -> ModelSpec: + for model in self.models: + if model.id == model_id: + return model + raise KeyError(f"Unknown model id: {model_id}") + + def enabled_models(self) -> list[ModelSpec]: + return [model for model in self.models if model.enabled] + + +class RouterConfig(LAIModel): + model_id: str + max_input_tokens: int = 2048 + classify_dimensions: list[str] = Field(default_factory=list) + use_model_router: bool = True + + +class RoutingTierMatch(LAIModel): + complexity: str | None = None + urgency: str | None = None + cost_sensitivity: str | None = None + safety_risk: str | None = None + expected_duration: str | None = None + + def score(self, context: "RouteContext") -> int: + score = 0 + if self.complexity and self.complexity == context.complexity: + score += 1 + if self.urgency and self.urgency == context.urgency: + score += 1 + if self.cost_sensitivity and self.cost_sensitivity == context.cost_sensitivity: + score += 1 + if self.safety_risk and self.safety_risk == context.safety_risk: + score += 1 + if self.expected_duration and self.expected_duration == context.expected_duration: + score += 1 + return score + + +class StageSelection(LAIModel): + planner_model_id: str | None = None + executor_model_id: str + reviewer_model_id: str | None = None + + +class RoutingTier(LAIModel): + id: str + description: str + match: RoutingTierMatch + planner_model_id: str | None = None + executor_model_id: str + reviewer_model_id: str | None = None + allow_overnight: bool = False + reviewer_enabled: bool | None = None + + @property + def resolved_reviewer_enabled(self) -> bool: + if self.reviewer_enabled is not None: + return self.reviewer_enabled + return self.id != "instant" + + @property + def should_plan(self) -> bool: + return self.id in {"standard", "deep-work"} + + +class RoutingFallbacks(LAIModel): + when_gpu_unavailable: StageSelection | None = None + + +class RoutingPolicy(LAIModel): + version: int + router: RouterConfig + tiers: list[RoutingTier] + fallbacks: RoutingFallbacks = Field(default_factory=RoutingFallbacks) + + @model_validator(mode="after") + def validate_unique_tier_ids(self) -> "RoutingPolicy": + tier_ids = [tier.id for tier in self.tiers] + duplicates = {tier_id for tier_id in tier_ids if tier_ids.count(tier_id) > 1} + if duplicates: + raise ValueError(f"Duplicate routing tier ids: {sorted(duplicates)}") + return self + + def get_tier(self, tier_id: str) -> RoutingTier: + for tier in self.tiers: + if tier.id == tier_id: + return tier + raise KeyError(f"Unknown routing tier id: {tier_id}") + + +class RouteContext(LAIModel): + prompt_length: int + complexity: str + urgency: str + cost_sensitivity: str + safety_risk: str + expected_duration: str + requires_high_quality: bool = False + overnight_requested: bool 
= False + required_capabilities: list[str] = Field(default_factory=list) + model_override: str | None = None + provider_override: str | None = None + + +class RouteReason(LAIModel): + stage: str + message: str + + +class RoutingDecision(LAIModel): + matched_tier_id: str + planner_model_id: str | None = None + executor_model_id: str + reviewer_model_id: str | None = None + fallback_model_ids: list[str] = Field(default_factory=list) + reasons: list[RouteReason] = Field(default_factory=list) + queue_recommended: bool = False + should_plan: bool = False + should_review: bool = False + context: RouteContext + + +class ExecutionRequest(LAIModel): + system_prompt: str | None = None + user_prompt: str + metadata: dict[str, Any] = Field(default_factory=dict) + temperature: float = 0.2 + max_output_tokens: int = 1024 + timeout_seconds: int = 120 + model_override: str | None = None + provider_override: str | None = None + queue_mode: QueueMode = QueueMode.AUTO + allow_overnight: bool = True + reviewer_enabled: bool | None = None + required_capabilities: list[str] = Field(default_factory=list) + source: str = "cli" + + +class UsageStats(LAIModel): + prompt_tokens: int | None = None + completion_tokens: int | None = None + total_tokens: int | None = None + + +class ProviderRequest(LAIModel): + system_prompt: str | None = None + user_prompt: str + metadata: dict[str, Any] = Field(default_factory=dict) + temperature: float = 0.2 + max_output_tokens: int = 1024 + timeout_seconds: int = 120 + model_override: str | None = None + + +class ProviderHealth(LAIModel): + provider_id: str + available: bool + healthy: bool + reasons: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ExecutionResult(LAIModel): + text: str + finish_reason: FinishReason = FinishReason.UNKNOWN + provider_id: str + model_id: str + duration_seconds: float + usage: UsageStats | None = None + raw: dict[str, Any] = Field(default_factory=dict) + stage: str = "executor" + + +class JobError(LAIModel): + error_type: str + message: str + retryable: bool = False + + +class ArtifactRecord(LAIModel): + id: str + job_id: str + artifact_type: str + relative_path: str + created_at: datetime = Field(default_factory=utcnow) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class JobRecord(LAIModel): + id: str + status: JobStatus + created_at: datetime = Field(default_factory=utcnow) + updated_at: datetime = Field(default_factory=utcnow) + queued_at: datetime | None = None + started_at: datetime | None = None + finished_at: datetime | None = None + queue_mode: QueueMode = QueueMode.AUTO + request: ExecutionRequest + route_decision: RoutingDecision | None = None + result: ExecutionResult | None = None + error: JobError | None = None + attempts: int = 0 + artifacts: list[ArtifactRecord] = Field(default_factory=list) diff --git a/src/lai/errors.py b/src/lai/errors.py new file mode 100644 index 0000000..5a74e37 --- /dev/null +++ b/src/lai/errors.py @@ -0,0 +1,10 @@ +class LAIError(Exception): + """Base exception for LAI.""" + + +class RetryableProviderError(LAIError): + """Raised when a provider failure can be retried safely.""" + + +class ConfigurationError(LAIError): + """Raised when LAI configuration is invalid.""" diff --git a/src/lai/evals.py b/src/lai/evals.py new file mode 100644 index 0000000..361b2f4 --- /dev/null +++ b/src/lai/evals.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any 
+ +import yaml +from pydantic import BaseModel, Field + +from .application import create_application +from .domain import ( + ExecutionRequest, + ExecutionResult, + FinishReason, + ModelRuntime, + ModelSpec, + ProviderHealth, + ProviderRequest, +) +from .providers.base import Provider +from .settings import Settings +from .system import collect_system_snapshot + + +def utcnow_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat() + + +class RouteEvalScenario(BaseModel): + id: str + description: str + prompt: str + expected_tier: str + expected_executor_in: list[str] = Field(default_factory=list) + unavailable_providers: list[str] = Field(default_factory=list) + + +class RouteEvalSuite(BaseModel): + version: int + scenarios: list[RouteEvalScenario] + + +class RouteEvalCaseResult(BaseModel): + scenario_id: str + passed: bool + expected_tier: str + actual_tier: str + expected_executor_in: list[str] = Field(default_factory=list) + actual_executor: str + reasons: list[str] = Field(default_factory=list) + + +class RouteEvalResult(BaseModel): + executed_at: str + scenario_file: str + passed: bool + total: int + passed_count: int + failed_count: int + cases: list[RouteEvalCaseResult] + + +def load_route_eval_suite(path: Path) -> RouteEvalSuite: + with path.open("r", encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + return RouteEvalSuite.model_validate(data) + + +def run_route_eval_suite(settings: Settings, scenario_file: Path) -> RouteEvalResult: + suite = load_route_eval_suite(scenario_file) + cases: list[RouteEvalCaseResult] = [] + + for scenario in suite.scenarios: + application = create_application( + settings=settings, + providers=_eval_providers(settings, scenario.unavailable_providers), + ) + decision = application.routing_engine.route( + ExecutionRequest( + user_prompt=scenario.prompt, + ) + ) + passed = decision.matched_tier_id == scenario.expected_tier + if scenario.expected_executor_in: + passed = passed and decision.executor_model_id in scenario.expected_executor_in + + cases.append( + RouteEvalCaseResult( + scenario_id=scenario.id, + passed=passed, + expected_tier=scenario.expected_tier, + actual_tier=decision.matched_tier_id, + expected_executor_in=scenario.expected_executor_in, + actual_executor=decision.executor_model_id, + reasons=[f"{reason.stage}: {reason.message}" for reason in decision.reasons], + ) + ) + + passed_count = sum(1 for case in cases if case.passed) + return RouteEvalResult( + executed_at=utcnow_iso(), + scenario_file=str(scenario_file), + passed=passed_count == len(cases), + total=len(cases), + passed_count=passed_count, + failed_count=len(cases) - passed_count, + cases=cases, + ) + + +def save_route_eval_result(results_dir: Path, result: RouteEvalResult) -> Path: + results_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = results_dir / f"route-eval-{timestamp}.json" + path.write_text(result.model_dump_json(indent=2), encoding="utf-8") + return path + + +class EvalProvider(Provider): + def __init__( + self, + settings: Settings, + provider_id: str, + *, + unavailable_providers: list[str] | None = None, + ) -> None: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=settings.enable_gpu) + super().__init__(settings, snapshot) + self.provider_id = provider_id + self.unavailable_providers = set(unavailable_providers or []) + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + if self.provider_id in self.unavailable_providers: + return 
ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=[f"{self.provider_id} marked unavailable for this evaluation scenario"], + ) + return ProviderHealth(provider_id=self.provider_id, available=True, healthy=True) + + def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResult: + if model.role == "router": + text = '{"tier_id": "standard"}' + else: + text = f"eval:{model.id}" + return ExecutionResult( + text=text, + finish_reason=FinishReason.STOP, + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=0.0, + ) + + def describe_capabilities(self) -> dict[str, Any]: + return {"evaluation_provider": True} + + +def _eval_providers(settings: Settings, unavailable_providers: list[str]) -> dict[str, Provider]: + providers: dict[str, Provider] = {} + for runtime in ModelRuntime: + providers[runtime.value] = EvalProvider( + settings, + runtime.value, + unavailable_providers=unavailable_providers, + ) + return providers diff --git a/src/lai/jobs/__init__.py b/src/lai/jobs/__init__.py index eeaf73a..96b0b6f 100644 --- a/src/lai/jobs/__init__.py +++ b/src/lai/jobs/__init__.py @@ -1 +1 @@ -"""Job orchestration primitives live here.""" +"""Job persistence and orchestration modules.""" diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py new file mode 100644 index 0000000..4d51015 --- /dev/null +++ b/src/lai/jobs/service.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import time +from enum import Enum +from uuid import uuid4 + +from ..artifacts import ArtifactManager +from ..config import AppConfig +from ..domain import ( + ExecutionRequest, + JobError, + JobRecord, + JobStatus, + ProviderRequest, + QueueMode, + RoutingDecision, + utcnow, +) +from ..errors import RetryableProviderError +from ..providers import ProviderRegistry +from ..routing import RoutingEngine +from ..settings import Settings +from .store import JobStore + + +class OrchestrationService: + def __init__( + self, + settings: Settings, + config: AppConfig, + provider_registry: ProviderRegistry, + routing_engine: RoutingEngine, + job_store: JobStore, + artifacts: ArtifactManager, + ) -> None: + self.settings = settings + self.config = config + self.provider_registry = provider_registry + self.routing_engine = routing_engine + self.job_store = job_store + self.artifacts = artifacts + + def submit_request(self, request: ExecutionRequest) -> JobRecord: + route_decision = self.routing_engine.route(request) + now = utcnow() + queue_mode = self._resolved_queue_mode(request.queue_mode, route_decision) + job = JobRecord( + id=str(uuid4()), + status=JobStatus.QUEUED, + created_at=now, + updated_at=now, + queued_at=now, + queue_mode=queue_mode, + request=request, + route_decision=route_decision, + ) + self.job_store.save_job(job) + self.artifacts.write_json( + job.id, "normalized-request", "request.json", request.model_dump() + ) + self.artifacts.write_json( + job.id, "route-decision", "route.json", route_decision.model_dump() + ) + + if queue_mode == QueueMode.INLINE: + return self.execute_job(job.id) + return self.job_store.get_job(job.id) or job + + def execute_job(self, job_id: str) -> JobRecord: + job = self.job_store.get_job(job_id) + if job is None: + raise KeyError(f"Unknown job id {job_id!r}.") + if job.status == JobStatus.CANCELED: + return job + + if job.status != JobStatus.RUNNING: + job.attempts += 1 + job.status = JobStatus.RUNNING + job.started_at = job.started_at or utcnow() + job.updated_at = utcnow() + 
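+            # Persist the running state before execution; failures below are recorded on the job, and retryable errors requeue until max_retry_attempts is exhausted.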
self.job_store.save_job(job) + + try: + final_text = self._run_pipeline(job) + if job.result: + job.result = job.result.model_copy(update={"text": final_text}) + job.status = JobStatus.SUCCEEDED + job.finished_at = utcnow() + job.updated_at = utcnow() + self.job_store.save_job(job) + self.artifacts.write_text(job.id, "final-output", "final_output.txt", final_text) + return self.job_store.get_job(job.id) or job + except Exception as exc: + retryable = _is_retryable_exception(exc) + job.error = JobError( + error_type=type(exc).__name__, message=str(exc), retryable=retryable + ) + job.updated_at = utcnow() + if retryable and job.attempts <= self.settings.max_retry_attempts: + job.status = JobStatus.QUEUED + self.job_store.save_job(job) + else: + job.status = JobStatus.FAILED + job.finished_at = utcnow() + self.job_store.save_job(job) + self.artifacts.write_text(job.id, "error", "error.txt", str(exc)) + return self.job_store.get_job(job.id) or job + + def run_worker(self, once: bool = False) -> int: + processed = 0 + self.job_store.requeue_running_jobs() + while True: + job = self.job_store.claim_next_queued_job() + if job is None: + if once: + return processed + time.sleep(self.settings.worker_idle_sleep_seconds) + continue + self.execute_job(job.id) + processed += 1 + if once: + return processed + + def _run_pipeline(self, job: JobRecord) -> str: + assert job.route_decision is not None + + planning_output = "" + if job.route_decision.should_plan and job.route_decision.planner_model_id: + planning_output = self._run_stage( + job=job, + stage="planner", + model_id=job.route_decision.planner_model_id, + system_prompt=job.request.system_prompt + or ( + "Create a brief internal execution plan. " + "Keep it concise and implementation-focused." + ), + user_prompt=job.request.user_prompt, + ) + + executor_prompt = job.request.user_prompt + if planning_output: + executor_prompt = ( + "Internal execution plan:\n" + f"{planning_output}\n\n" + "Follow that plan while answering the original request.\n\n" + f"{job.request.user_prompt}" + ) + + executor_output = self._run_stage( + job=job, + stage="executor", + model_id=job.route_decision.executor_model_id, + system_prompt=job.request.system_prompt, + user_prompt=executor_prompt, + ) + + if job.route_decision.should_review and job.route_decision.reviewer_model_id: + reviewer_output = self._run_stage( + job=job, + stage="reviewer", + model_id=job.route_decision.reviewer_model_id, + system_prompt=( + "Review the answer for correctness, completeness, " + "and structure. Keep notes concise." 
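+                # Reviewer output is saved as an artifact only; it does not yet change the final answer.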
+ ), + user_prompt=( + "Original request:\n" + f"{job.request.user_prompt}\n\n" + "Candidate answer:\n" + f"{executor_output}" + ), + ) + self.artifacts.write_text( + job.id, "reviewer-notes", "reviewer_notes.txt", reviewer_output + ) + + return executor_output + + def _run_stage( + self, + *, + job: JobRecord, + stage: str, + model_id: str, + system_prompt: str | None, + user_prompt: str, + ) -> str: + model = self.config.model_catalog.get_model(model_id) + provider = self.provider_registry.provider_for_model(model) + health = self.provider_registry.healthcheck(model) + if not health.available: + reason = "; ".join(health.reasons) or "unavailable" + raise RuntimeError( + f"Model {model.id!r} is not executable: {reason}" + ) + request = ProviderRequest( + system_prompt=system_prompt, + user_prompt=user_prompt, + metadata={"job_id": job.id, "stage": stage}, + temperature=job.request.temperature, + max_output_tokens=job.request.max_output_tokens, + timeout_seconds=job.request.timeout_seconds, + ) + self.artifacts.write_json( + job.id, f"{stage}-request", f"{stage}_request.json", request.model_dump() + ) + result = provider.generate(model, request) + result.stage = stage + self.artifacts.write_json( + job.id, f"{stage}-response", f"{stage}_response.json", result.model_dump() + ) + if stage == "executor": + job.result = result + job.updated_at = utcnow() + self.job_store.save_job(job) + return result.text + + @staticmethod + def _resolved_queue_mode(queue_mode: QueueMode, route_decision: RoutingDecision) -> QueueMode: + queue_mode_value = _enum_value(queue_mode) + if queue_mode_value == "auto": + return QueueMode.QUEUED if route_decision.queue_recommended else QueueMode.INLINE + return QueueMode(queue_mode_value) + + +def _is_retryable_exception(exc: Exception) -> bool: + if isinstance(exc, RetryableProviderError): + return True + name = type(exc).__name__.lower() + module = type(exc).__module__.lower() + retry_markers = ("timeout", "connection", "rate", "tempor") + return any(marker in name or marker in module for marker in retry_markers) + + +def _enum_value(value: object) -> str: + if isinstance(value, Enum): + return str(value.value) + return str(value) diff --git a/src/lai/jobs/store.py b/src/lai/jobs/store.py new file mode 100644 index 0000000..67feade --- /dev/null +++ b/src/lai/jobs/store.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +from ..domain import ArtifactRecord, JobRecord, JobStatus, utcnow + + +class JobStore: + def __init__(self, database_path: Path) -> None: + self.database_path = database_path + self.database_path.parent.mkdir(parents=True, exist_ok=True) + + def initialize(self) -> None: + with self._connect() as connection: + connection.executescript( + """ + CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + queued_at TEXT, + started_at TEXT, + finished_at TEXT, + queue_mode TEXT NOT NULL, + request_json TEXT NOT NULL, + route_json TEXT, + result_json TEXT, + error_json TEXT, + attempts INTEGER NOT NULL DEFAULT 0 + ); + + CREATE TABLE IF NOT EXISTS artifacts ( + id TEXT PRIMARY KEY, + job_id TEXT NOT NULL, + artifact_type TEXT NOT NULL, + relative_path TEXT NOT NULL, + created_at TEXT NOT NULL, + metadata_json TEXT NOT NULL, + FOREIGN KEY(job_id) REFERENCES jobs(id) + ); + """ + ) + + def save_job(self, job: JobRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT INTO jobs ( + 
id, status, created_at, updated_at, queued_at, started_at, finished_at, + queue_mode, request_json, route_json, result_json, error_json, attempts + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + status=excluded.status, + updated_at=excluded.updated_at, + queued_at=excluded.queued_at, + started_at=excluded.started_at, + finished_at=excluded.finished_at, + queue_mode=excluded.queue_mode, + request_json=excluded.request_json, + route_json=excluded.route_json, + result_json=excluded.result_json, + error_json=excluded.error_json, + attempts=excluded.attempts + """, + ( + job.id, + job.status, + job.created_at.isoformat(), + job.updated_at.isoformat(), + job.queued_at.isoformat() if job.queued_at else None, + job.started_at.isoformat() if job.started_at else None, + job.finished_at.isoformat() if job.finished_at else None, + job.queue_mode, + job.request.model_dump_json(), + job.route_decision.model_dump_json() if job.route_decision else None, + job.result.model_dump_json() if job.result else None, + job.error.model_dump_json() if job.error else None, + job.attempts, + ), + ) + + def add_artifact(self, artifact: ArtifactRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT OR REPLACE INTO artifacts ( + id, job_id, artifact_type, relative_path, created_at, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ( + artifact.id, + artifact.job_id, + artifact.artifact_type, + artifact.relative_path, + artifact.created_at.isoformat(), + json.dumps(artifact.metadata), + ), + ) + + def get_job(self, job_id: str) -> JobRecord | None: + with self._connect() as connection: + row = connection.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() + if row is None: + return None + return self._deserialize_job(connection, row) + + def list_jobs(self, limit: int = 20) -> list[JobRecord]: + with self._connect() as connection: + rows = connection.execute( + "SELECT * FROM jobs ORDER BY created_at DESC LIMIT ?", + (limit,), + ).fetchall() + return [self._deserialize_job(connection, row) for row in rows] + + def cancel_job(self, job_id: str) -> bool: + with self._connect() as connection: + result = connection.execute( + """ + UPDATE jobs + SET status = ?, updated_at = ?, finished_at = ? + WHERE id = ? AND status IN (?, ?) + """, + ( + JobStatus.CANCELED, + utcnow().isoformat(), + utcnow().isoformat(), + job_id, + JobStatus.QUEUED, + JobStatus.RUNNING, + ), + ) + return result.rowcount > 0 + + def claim_next_queued_job(self) -> JobRecord | None: + with self._connect() as connection: + connection.execute("BEGIN IMMEDIATE") + row = connection.execute( + "SELECT * FROM jobs WHERE status = ? ORDER BY created_at ASC LIMIT 1", + (JobStatus.QUEUED,), + ).fetchone() + if row is None: + connection.commit() + return None + now = utcnow().isoformat() + connection.execute( + """ + UPDATE jobs + SET status = ?, + updated_at = ?, + started_at = COALESCE(started_at, ?), + attempts = attempts + 1 + WHERE id = ? + """, + (JobStatus.RUNNING, now, now, row["id"]), + ) + updated = connection.execute("SELECT * FROM jobs WHERE id = ?", (row["id"],)).fetchone() + connection.commit() + return self._deserialize_job(connection, updated) + + def requeue_running_jobs(self) -> int: + with self._connect() as connection: + result = connection.execute( + "UPDATE jobs SET status = ?, updated_at = ? 
WHERE status = ?", + (JobStatus.QUEUED, utcnow().isoformat(), JobStatus.RUNNING), + ) + return result.rowcount + + def _deserialize_job(self, connection: sqlite3.Connection, row: sqlite3.Row) -> JobRecord: + artifacts = self._artifacts_for_job(connection, row["id"]) + return JobRecord.model_validate( + { + "id": row["id"], + "status": row["status"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "queued_at": row["queued_at"], + "started_at": row["started_at"], + "finished_at": row["finished_at"], + "queue_mode": row["queue_mode"], + "request": json.loads(row["request_json"]), + "route_decision": json.loads(row["route_json"]) if row["route_json"] else None, + "result": json.loads(row["result_json"]) if row["result_json"] else None, + "error": json.loads(row["error_json"]) if row["error_json"] else None, + "attempts": row["attempts"], + "artifacts": artifacts, + } + ) + + def _artifacts_for_job( + self, connection: sqlite3.Connection, job_id: str + ) -> list[ArtifactRecord]: + rows = connection.execute( + "SELECT * FROM artifacts WHERE job_id = ? ORDER BY created_at ASC", + (job_id,), + ).fetchall() + return [ + ArtifactRecord.model_validate( + { + "id": row["id"], + "job_id": row["job_id"], + "artifact_type": row["artifact_type"], + "relative_path": row["relative_path"], + "created_at": row["created_at"], + "metadata": json.loads(row["metadata_json"]), + } + ) + for row in rows + ] + + def _connect(self) -> sqlite3.Connection: + connection = sqlite3.connect(self.database_path) + connection.row_factory = sqlite3.Row + return connection diff --git a/src/lai/layout.py b/src/lai/layout.py index 86a8b8f..f18f223 100644 --- a/src/lai/layout.py +++ b/src/lai/layout.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path RUNTIME_DIRECTORIES = ( @@ -5,6 +7,7 @@ Path("data/models/airllm-shards"), Path("data/models/raw"), Path("data/artifacts"), + Path("data/state"), Path("logs"), ) @@ -12,3 +15,10 @@ def runtime_directories(root_dir: Path) -> list[Path]: """Resolve the runtime directories relative to the repository root.""" return [root_dir / path for path in RUNTIME_DIRECTORIES] + + +def ensure_runtime_directories(root_dir: Path) -> list[Path]: + resolved = runtime_directories(root_dir) + for path in resolved: + path.mkdir(parents=True, exist_ok=True) + return resolved diff --git a/src/lai/providers/__init__.py b/src/lai/providers/__init__.py index 42d72b4..a5001c2 100644 --- a/src/lai/providers/__init__.py +++ b/src/lai/providers/__init__.py @@ -1 +1,4 @@ -"""Runtime provider adapters live here.""" +from .base import Provider +from .registry import ProviderRegistry + +__all__ = ["Provider", "ProviderRegistry"] diff --git a/src/lai/providers/base.py b/src/lai/providers/base.py new file mode 100644 index 0000000..a154656 --- /dev/null +++ b/src/lai/providers/base.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from time import perf_counter +from typing import Any + +from ..domain import ExecutionResult, FinishReason, ModelSpec, ProviderHealth, ProviderRequest +from ..settings import Settings +from ..system import SystemSnapshot + + +def serialize_raw_object(value: Any) -> dict[str, Any]: + if hasattr(value, "model_dump"): + try: + return value.model_dump() + except Exception: + pass + if hasattr(value, "dict"): + try: + return value.dict() + except Exception: + pass + return {"repr": repr(value)} + + +class Provider(ABC): + provider_id: str + + def __init__(self, settings: Settings, system_snapshot: 
SystemSnapshot) -> None: + self.settings = settings + self.system_snapshot = system_snapshot + + @abstractmethod + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + raise NotImplementedError + + def available(self, model: ModelSpec) -> ProviderHealth: + return self.healthcheck(model) + + @abstractmethod + def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResult: + raise NotImplementedError + + @abstractmethod + def describe_capabilities(self) -> dict[str, Any]: + raise NotImplementedError + + def _result( + self, + *, + text: str, + provider_id: str, + model_id: str, + duration_seconds: float, + finish_reason: FinishReason = FinishReason.STOP, + usage: dict[str, Any] | None = None, + raw: dict[str, Any] | None = None, + stage: str = "executor", + ) -> ExecutionResult: + return ExecutionResult( + text=text, + provider_id=provider_id, + model_id=model_id, + duration_seconds=duration_seconds, + finish_reason=finish_reason, + usage=usage, + raw=raw or {}, + stage=stage, + ) + + def _missing_dependency_health(self, provider_id: str, dependency_name: str) -> ProviderHealth: + return ProviderHealth( + provider_id=provider_id, + available=False, + healthy=False, + reasons=[f"Missing optional dependency {dependency_name!r}."], + ) + + def _credential_health(self, provider_id: str, credential_name: str) -> ProviderHealth: + return ProviderHealth( + provider_id=provider_id, + available=False, + healthy=False, + reasons=[f"Missing credential {credential_name!r}."], + ) + + def _ok_health(self, provider_id: str, **metadata: Any) -> ProviderHealth: + return ProviderHealth( + provider_id=provider_id, available=True, healthy=True, metadata=metadata + ) + + def _timer(self) -> float: + return perf_counter() diff --git a/src/lai/providers/implementations.py b/src/lai/providers/implementations.py new file mode 100644 index 0000000..3d984e7 --- /dev/null +++ b/src/lai/providers/implementations.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +from time import perf_counter +from typing import Any + +from ..domain import ( + FinishReason, + ModelRuntime, + ModelSpec, + ProviderHealth, + ProviderRequest, + UsageStats, +) +from ..settings import Settings +from ..system import available_disk_gb +from .base import Provider, serialize_raw_object + + +class TransformersProvider(Provider): + provider_id = ModelRuntime.TRANSFORMERS.value + + def __init__(self, settings: Settings, system_snapshot) -> None: # type: ignore[override] + super().__init__(settings, system_snapshot) + self._pipeline_cache: dict[str, Any] = {} + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + import transformers # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "transformers") + + if model.model_ref.startswith("meta-llama/") and not self.settings.huggingface_token_value: + return self._credential_health(self.provider_id, "LAI_HF_TOKEN") + + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return { + "supports_local_generation": True, + "supports_streaming": False, + "requires_network_once": True, + } + + def generate(self, model: ModelSpec, request: ProviderRequest): + from transformers import pipeline + + started = perf_counter() + pipe = self._pipeline_cache.get(model.id) + if pipe is None: + pipe = pipeline( + "text-generation", + model=model.model_ref, + tokenizer=model.model_ref, + device_map="auto" if self.system_snapshot.has_gpu else None, + 
trust_remote_code=True, + model_kwargs={ + "token": self.settings.huggingface_token_value, + "torch_dtype": "auto", + }, + ) + self._pipeline_cache[model.id] = pipe + + prompt = _render_prompt(request.system_prompt, request.user_prompt) + result = pipe( + prompt, + max_new_tokens=request.max_output_tokens, + temperature=request.temperature, + return_full_text=False, + ) + text = result[0]["generated_text"].strip() if result else "" + duration = perf_counter() - started + return self._result( + text=text, + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + raw={"result_count": len(result)}, + ) + + +class AirLLMProvider(Provider): + provider_id = ModelRuntime.AIRLLM.value + + def __init__(self, settings: Settings, system_snapshot) -> None: # type: ignore[override] + super().__init__(settings, system_snapshot) + self._model_cache: dict[str, Any] = {} + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + from airllm import AutoModel # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "airllm") + + if not self.settings.huggingface_token_value: + return self._credential_health(self.provider_id, "LAI_HF_TOKEN") + + free_disk = available_disk_gb(self.settings.resolved_airllm_shards_dir) + expected_disk = model.hardware.expected_disk_gb or 0 + if expected_disk and free_disk < expected_disk: + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=[ + "Insufficient disk for AirLLM shards. " + f"Need about {expected_disk} GB, found {free_disk:.2f} GB." + ], + metadata={"free_disk_gb": free_disk}, + ) + + if not self.system_snapshot.has_gpu and not model.runtime_hints.allow_cpu_fallback: + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=["GPU unavailable and CPU fallback is disabled for this model."], + ) + + return self._ok_health( + self.provider_id, + model_ref=model.model_ref, + free_disk_gb=free_disk, + gpu_name=self.system_snapshot.gpu_name, + ) + + def describe_capabilities(self) -> dict[str, Any]: + return { + "supports_local_generation": True, + "supports_streaming": False, + "supports_layer_sharding": True, + "supports_cpu_fallback": True, + } + + def generate(self, model: ModelSpec, request: ProviderRequest): + from airllm import AutoModel + + started = perf_counter() + loaded_model = self._model_cache.get(model.id) + if loaded_model is None: + loaded_model = AutoModel.from_pretrained( + model.model_ref, + hf_token=self.settings.huggingface_token_value, + layer_shards_saving_path=str(self.settings.resolved_airllm_shards_dir / model.id), + ) + self._model_cache[model.id] = loaded_model + + prompt = _render_prompt(request.system_prompt, request.user_prompt) + tokenized = loaded_model.tokenizer( + [prompt], + return_tensors="pt", + return_attention_mask=False, + truncation=True, + max_length=min(model.context_window, self.settings.max_router_tokens * 2), + padding=False, + ) + + input_ids = tokenized["input_ids"] + if self.system_snapshot.has_gpu: + input_ids = input_ids.cuda() + + generation = loaded_model.generate( + input_ids, + max_new_tokens=request.max_output_tokens, + use_cache=True, + return_dict_in_generate=True, + ) + decoded = loaded_model.tokenizer.decode(generation.sequences[0], skip_special_tokens=True) + text = _strip_prompt_prefix(decoded, prompt) + duration = perf_counter() - started + return self._result( + text=text.strip(), + 
provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + raw={"generated_tokens": len(generation.sequences[0])}, + ) + + +class OpenAIProvider(Provider): + provider_id = ModelRuntime.OPENAI.value + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + import openai # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "openai") + if not self.settings.openai_api_key_value: + return self._credential_health(self.provider_id, "LAI_OPENAI_API_KEY") + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return {"supports_remote_generation": True, "supports_streaming": False} + + def generate(self, model: ModelSpec, request: ProviderRequest): + from openai import OpenAI + + started = perf_counter() + client = OpenAI(api_key=self.settings.openai_api_key_value, timeout=request.timeout_seconds) + response = client.responses.create( + model=request.model_override or model.model_ref, + instructions=request.system_prompt or None, + input=request.user_prompt, + max_output_tokens=request.max_output_tokens, + temperature=request.temperature, + ) + duration = perf_counter() - started + usage = getattr(response, "usage", None) + usage_stats = ( + UsageStats( + prompt_tokens=getattr(usage, "input_tokens", None), + completion_tokens=getattr(usage, "output_tokens", None), + total_tokens=getattr(usage, "total_tokens", None), + ) + if usage + else None + ) + return self._result( + text=(getattr(response, "output_text", "") or "").strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + usage=usage_stats.model_dump() if usage_stats else None, + raw=serialize_raw_object(response), + ) + + +class AnthropicProvider(Provider): + provider_id = ModelRuntime.ANTHROPIC.value + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + import anthropic # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "anthropic") + if not self.settings.anthropic_api_key_value: + return self._credential_health(self.provider_id, "LAI_ANTHROPIC_API_KEY") + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return {"supports_remote_generation": True, "supports_streaming": False} + + def generate(self, model: ModelSpec, request: ProviderRequest): + from anthropic import Anthropic + + started = perf_counter() + client = Anthropic( + api_key=self.settings.anthropic_api_key_value, timeout=request.timeout_seconds + ) + response = client.messages.create( + model=request.model_override or model.model_ref, + system=request.system_prompt or "", + messages=[{"role": "user", "content": request.user_prompt}], + max_tokens=request.max_output_tokens, + temperature=request.temperature, + ) + text_parts = [ + block.text for block in response.content if getattr(block, "type", None) == "text" + ] + usage = getattr(response, "usage", None) + usage_stats = ( + UsageStats( + prompt_tokens=getattr(usage, "input_tokens", None), + completion_tokens=getattr(usage, "output_tokens", None), + total_tokens=None, + ) + if usage + else None + ) + duration = perf_counter() - started + return self._result( + text="".join(text_parts).strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + usage=usage_stats.model_dump() if 
usage_stats else None, + raw=serialize_raw_object(response), + ) + + +class GeminiProvider(Provider): + provider_id = ModelRuntime.GEMINI.value + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + from google import genai # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "google-genai") + if not self.settings.gemini_api_key_value: + return self._credential_health(self.provider_id, "LAI_GEMINI_API_KEY") + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return {"supports_remote_generation": True, "supports_streaming": False} + + def generate(self, model: ModelSpec, request: ProviderRequest): + from google import genai + + started = perf_counter() + client = genai.Client(api_key=self.settings.gemini_api_key_value) + config: dict[str, Any] = { + "temperature": request.temperature, + "max_output_tokens": request.max_output_tokens, + } + if request.system_prompt: + config["system_instruction"] = request.system_prompt + response = client.models.generate_content( + model=request.model_override or model.model_ref, + contents=request.user_prompt, + config=config, + ) + duration = perf_counter() - started + return self._result( + text=(getattr(response, "text", "") or "").strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + raw=serialize_raw_object(response), + ) + + +def _render_prompt(system_prompt: str | None, user_prompt: str) -> str: + if system_prompt: + return f"System:\n{system_prompt}\n\nUser:\n{user_prompt}\n\nAssistant:\n" + return user_prompt + + +def _strip_prompt_prefix(decoded_text: str, prompt: str) -> str: + if decoded_text.startswith(prompt): + return decoded_text[len(prompt) :] + return decoded_text diff --git a/src/lai/providers/registry.py b/src/lai/providers/registry.py new file mode 100644 index 0000000..456b43f --- /dev/null +++ b/src/lai/providers/registry.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Mapping + +from ..config import AppConfig +from ..domain import ModelRuntime, ModelSpec, ProviderHealth +from ..settings import Settings +from ..system import SystemSnapshot +from .base import Provider +from .implementations import ( + AirLLMProvider, + AnthropicProvider, + GeminiProvider, + OpenAIProvider, + TransformersProvider, +) + + +class ProviderRegistry: + def __init__( + self, + settings: Settings, + system_snapshot: SystemSnapshot, + providers: Mapping[str, Provider] | None = None, + ) -> None: + self.settings = settings + self.system_snapshot = system_snapshot + self._providers = dict(providers or {}) + if not self._providers: + self._providers = { + ModelRuntime.TRANSFORMERS.value: TransformersProvider(settings, system_snapshot), + ModelRuntime.AIRLLM.value: AirLLMProvider(settings, system_snapshot), + ModelRuntime.OPENAI.value: OpenAIProvider(settings, system_snapshot), + ModelRuntime.ANTHROPIC.value: AnthropicProvider(settings, system_snapshot), + ModelRuntime.GEMINI.value: GeminiProvider(settings, system_snapshot), + } + + def get(self, provider_id: str) -> Provider: + return self._providers[provider_id] + + def provider_for_model(self, model: ModelSpec) -> Provider: + return self.get(model.provider_id) + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + return self.provider_for_model(model).healthcheck(model) + + def model_healthchecks(self, config: AppConfig) -> dict[str, ProviderHealth]: + return 
{model.id: self.healthcheck(model) for model in config.model_catalog.models} diff --git a/src/lai/routing/__init__.py b/src/lai/routing/__init__.py index 967ec34..3898fb8 100644 --- a/src/lai/routing/__init__.py +++ b/src/lai/routing/__init__.py @@ -1 +1,3 @@ -"""Routing and policy evaluation lives here.""" +from .engine import RoutingEngine + +__all__ = ["RoutingEngine"] diff --git a/src/lai/routing/engine.py b/src/lai/routing/engine.py new file mode 100644 index 0000000..c2b919e --- /dev/null +++ b/src/lai/routing/engine.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import json +from enum import Enum +from typing import Iterable + +from ..config import AppConfig +from ..domain import ( + ExecutionRequest, + ModelRuntime, + ModelSpec, + ProviderRequest, + RouteReason, + RoutingDecision, + RoutingTier, +) +from ..providers import ProviderRegistry +from ..settings import Settings +from ..system import SystemSnapshot +from .heuristics import build_route_context + + +class RoutingEngine: + def __init__( + self, + settings: Settings, + config: AppConfig, + providers: ProviderRegistry, + system_snapshot: SystemSnapshot, + ) -> None: + self.settings = settings + self.config = config + self.providers = providers + self.system_snapshot = system_snapshot + + def route(self, request: ExecutionRequest) -> RoutingDecision: + reasons: list[RouteReason] = [] + context = build_route_context(request) + + candidate_models = [model for model in self.config.model_catalog.enabled_models()] + reasons.append( + RouteReason(stage="constraints", message="Started from enabled catalog models.") + ) + + if request.provider_override: + candidate_models = [ + model + for model in candidate_models + if model.provider_id == request.provider_override + ] + reasons.append( + RouteReason( + stage="constraints", + message=f"Applied provider override {request.provider_override!r}.", + ) + ) + + if request.model_override: + model = self.config.model_catalog.get_model(request.model_override) + candidate_models = [model] + reasons.append( + RouteReason( + stage="constraints", + message=f"Applied model override {request.model_override!r}.", + ) + ) + + if not self.settings.allow_overnight_jobs and context.overnight_requested: + reasons.append( + RouteReason( + stage="constraints", + message=( + "Overnight execution requested but disabled by settings; " + "deep-work routing will be avoided." 
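+                        # informational only at this point; _resolve_tier below performs the actual swap to a non-overnight tier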
+ ), + ) + ) + + available_models = self._availability_filter(candidate_models, reasons) + tier = self._resolve_tier(request, context, reasons) + + planner_model = self._resolve_stage_model( + stage="planner", + preferred_model_id=tier.planner_model_id or self.config.routing_policy.router.model_id, + candidate_models=available_models, + required_capabilities=["planning"], + reasons=reasons, + ) + executor_capabilities = list(context.required_capabilities) + executor_model = self._resolve_stage_model( + stage="executor", + preferred_model_id=request.model_override or tier.executor_model_id, + candidate_models=available_models, + required_capabilities=executor_capabilities, + reasons=reasons, + ) + + reviewer_model_id: str | None = None + should_review = ( + request.reviewer_enabled + if request.reviewer_enabled is not None + else tier.resolved_reviewer_enabled + ) + if should_review: + reviewer_model = self._resolve_stage_model( + stage="reviewer", + preferred_model_id=tier.reviewer_model_id, + candidate_models=available_models, + required_capabilities=["critique", "validation"], + reasons=reasons, + ) + reviewer_model_id = reviewer_model.id if reviewer_model else None + if reviewer_model_id is None: + should_review = False + reasons.append( + RouteReason( + stage="fallback", + message=( + "Reviewer stage disabled because no available " + "reviewer-capable model was found." + ), + ) + ) + + fallback_model_ids = self._fallback_chain( + preferred_model_id=executor_model.id, + candidate_models=available_models, + required_capabilities=executor_capabilities, + ) + queue_mode = _enum_value(request.queue_mode) + queue_recommended = queue_mode == "queued" or ( + queue_mode == "auto" + and (tier.allow_overnight or context.expected_duration == "long") + ) + + return RoutingDecision( + matched_tier_id=tier.id, + planner_model_id=planner_model.id if planner_model else None, + executor_model_id=executor_model.id, + reviewer_model_id=reviewer_model_id, + fallback_model_ids=fallback_model_ids, + reasons=reasons, + queue_recommended=queue_recommended, + should_plan=tier.should_plan and planner_model is not None, + should_review=should_review, + context=context, + ) + + def _availability_filter( + self, + candidate_models: list[ModelSpec], + reasons: list[RouteReason], + ) -> list[ModelSpec]: + available: list[ModelSpec] = [] + for model in candidate_models: + if model.runtime == ModelRuntime.AIRLLM and not self.system_snapshot.has_gpu: + fallback = self.config.routing_policy.fallbacks.when_gpu_unavailable + if fallback and fallback.executor_model_id != model.id: + reasons.append( + RouteReason( + stage="availability", + message=f"Skipped {model.id!r} because GPU is unavailable.", + ) + ) + continue + + health = self.providers.healthcheck(model) + if health.available: + available.append(model) + continue + reasons.append( + RouteReason( + stage="availability", + message=f"Skipped {model.id!r}: {'; '.join(health.reasons)}", + ) + ) + if not available and candidate_models: + reasons.append( + RouteReason( + stage="availability", + message=( + "No models passed provider availability checks. " + "Continuing with configured models for route explanation only." 
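+                        # explain-only degradation: route() still returns a decision, but execution would fail its provider healthcheck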
+ ), + ) + ) + return candidate_models + if not available: + raise RuntimeError("No models remain after provider and environment filtering.") + return available + + def _resolve_tier( + self, + request: ExecutionRequest, + context, + reasons: list[RouteReason], + ) -> RoutingTier: + scored_tiers = sorted( + self.config.routing_policy.tiers, + key=lambda tier: (tier.match.score(context), tier.allow_overnight, tier.id), + reverse=True, + ) + best_tier = scored_tiers[0] + reasons.append( + RouteReason( + stage="heuristics", + message=( + f"Heuristics classified complexity={context.complexity}, " + f"expected_duration={context.expected_duration}, urgency={context.urgency}." + ), + ) + ) + + if ( + self.config.routing_policy.router.use_model_router + and best_tier.match.score(context) < 2 + and not request.model_override + ): + llm_suggestion = self._router_model_suggestion(request) + if llm_suggestion: + suggested_tier = next( + ( + tier + for tier in self.config.routing_policy.tiers + if tier.id == llm_suggestion + ), + None, + ) + if suggested_tier is not None: + best_tier = suggested_tier + reasons.append( + RouteReason( + stage="model-router", + message=f"Router model suggested tier {suggested_tier.id!r}.", + ) + ) + + if best_tier.allow_overnight and not self.settings.allow_overnight_jobs: + non_overnight = [ + tier for tier in self.config.routing_policy.tiers if not tier.allow_overnight + ] + if non_overnight: + best_tier = non_overnight[0] + reasons.append( + RouteReason( + stage="constraints", + message=( + f"Selected fallback tier {best_tier.id!r} " + "because overnight jobs are disabled." + ), + ) + ) + + reasons.append(RouteReason(stage="tier", message=f"Matched routing tier {best_tier.id!r}.")) + return best_tier + + def _resolve_stage_model( + self, + *, + stage: str, + preferred_model_id: str | None, + candidate_models: list[ModelSpec], + required_capabilities: list[str], + reasons: list[RouteReason], + ) -> ModelSpec | None: + if preferred_model_id: + preferred = next( + (model for model in candidate_models if model.id == preferred_model_id), None + ) + if preferred and preferred.supports_capabilities(required_capabilities): + reasons.append( + RouteReason( + stage=stage, + message=f"Selected preferred {stage} model {preferred.id!r}.", + ) + ) + return preferred + + for model in candidate_models: + if model.supports_capabilities(required_capabilities): + reasons.append( + RouteReason( + stage="fallback", + message=f"Selected {stage} fallback model {model.id!r}.", + ) + ) + return model + + if stage == "reviewer": + return None + raise RuntimeError( + f"No available {stage} model can satisfy capabilities {required_capabilities!r}." 
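+            # the reviewer stage is the only optional one (handled above); a missing planner or executor is fatal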
+ ) + + def _fallback_chain( + self, + *, + preferred_model_id: str, + candidate_models: Iterable[ModelSpec], + required_capabilities: list[str], + ) -> list[str]: + fallback_ids: list[str] = [] + for model in candidate_models: + if model.id == preferred_model_id: + continue + if model.supports_capabilities(required_capabilities): + fallback_ids.append(model.id) + return fallback_ids + + def _router_model_suggestion(self, request: ExecutionRequest) -> str | None: + router_model = self.config.model_catalog.get_model( + self.config.routing_policy.router.model_id + ) + health = self.providers.healthcheck(router_model) + if not health.available: + return None + + provider = self.providers.provider_for_model(router_model) + prompt = ( + "Classify the request into one routing tier and return JSON only with a single key " + '"tier_id" whose value is one of: instant, standard, deep-work.\n\n' + f"Request:\n{request.user_prompt}" + ) + try: + response = provider.generate( + router_model, + ProviderRequest( + system_prompt="You are a strict router. Return JSON only.", + user_prompt=prompt, + temperature=0, + max_output_tokens=64, + timeout_seconds=min(request.timeout_seconds, 30), + ), + ) + except Exception: + return None + + try: + payload = json.loads(response.text) + except json.JSONDecodeError: + return None + tier_id = payload.get("tier_id") + if tier_id in {"instant", "standard", "deep-work"}: + return str(tier_id) + return None + + +def _enum_value(value: object) -> str: + if isinstance(value, Enum): + return str(value.value) + return str(value) diff --git a/src/lai/routing/heuristics.py b/src/lai/routing/heuristics.py new file mode 100644 index 0000000..e0044d2 --- /dev/null +++ b/src/lai/routing/heuristics.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from ..domain import ExecutionRequest, RouteContext + +DEEP_KEYWORDS = ( + "advanced", + "comprehensive", + "research", + "full", + "complete", + "detailed", + "analyze", + "architecture", + "plan", + "strategy", +) +URGENT_KEYWORDS = ("urgent", "asap", "immediately", "quick", "fast", "today") +HIGH_RISK_KEYWORDS = ("security", "medical", "legal", "finance", "exploit", "breach") +OVERNIGHT_KEYWORDS = ("overnight", "whole night", "background", "long-running") + + +def build_route_context(request: ExecutionRequest) -> RouteContext: + lowered = request.user_prompt.lower() + prompt_length = len(request.user_prompt) + + complexity = "low" + if prompt_length > 1800 or any(keyword in lowered for keyword in DEEP_KEYWORDS): + complexity = "high" + elif prompt_length > 600: + complexity = "medium" + + urgency = "low" + if any(keyword in lowered for keyword in URGENT_KEYWORDS): + urgency = "high" + elif "soon" in lowered: + urgency = "medium" + + cost_sensitivity = "low" + if any(keyword in lowered for keyword in ("cheap", "low cost", "fast", "small model")): + cost_sensitivity = "high" + elif "balanced" in lowered: + cost_sensitivity = "medium" + + safety_risk = "low" + if any(keyword in lowered for keyword in HIGH_RISK_KEYWORDS): + safety_risk = "high" + elif any(keyword in lowered for keyword in ("privacy", "compliance", "policy")): + safety_risk = "medium" + + expected_duration = "short" + overnight_requested = any(keyword in lowered for keyword in OVERNIGHT_KEYWORDS) + if overnight_requested or complexity == "high": + expected_duration = "long" + elif complexity == "medium": + expected_duration = "medium" + + required_capabilities = list(request.required_capabilities) + if complexity == "high": + 
required_capabilities.extend(["deep-reasoning", "long-form-generation"]) + elif "summary" in lowered or "summarize" in lowered: + required_capabilities.append("summarization") + + return RouteContext( + prompt_length=prompt_length, + complexity=complexity, + urgency=urgency, + cost_sensitivity=cost_sensitivity, + safety_risk=safety_risk, + expected_duration=expected_duration, + requires_high_quality=complexity == "high", + overnight_requested=overnight_requested, + required_capabilities=sorted(set(required_capabilities)), + model_override=request.model_override, + provider_override=request.provider_override, + ) diff --git a/src/lai/serialization.py b/src/lai/serialization.py new file mode 100644 index 0000000..6a3288e --- /dev/null +++ b/src/lai/serialization.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + + +def to_jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): to_jsonable(item) for key, item in value.items()} + if isinstance(value, list): + return [to_jsonable(item) for item in value] + if isinstance(value, tuple): + return [to_jsonable(item) for item in value] + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + return value + + +def dumps_pretty(data: Any) -> str: + return json.dumps(to_jsonable(data), indent=2, sort_keys=True, ensure_ascii=True) diff --git a/src/lai/settings.py b/src/lai/settings.py index 24ea9f2..a21274e 100644 --- a/src/lai/settings.py +++ b/src/lai/settings.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from functools import cached_property from pathlib import Path -from pydantic import Field +from pydantic import Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict @@ -10,7 +12,7 @@ def _default_root_dir() -> Path: class Settings(BaseSettings): - """Application settings shared across the future platform services.""" + """Application settings shared across the platform services.""" model_config = SettingsConfigDict( env_file=".env", @@ -21,17 +23,36 @@ class Settings(BaseSettings): environment: str = "local" project_name: str = "LAI" root_dir: Path = Field(default_factory=_default_root_dir) + model_catalog: Path = Path("configs/models/catalog.yaml") routing_policy: Path = Path("configs/routing/policies.yaml") prompt_root: Path = Path("configs/prompts") + huggingface_cache_dir: Path = Path("data/cache/huggingface") airllm_shards_dir: Path = Path("data/models/airllm-shards") + raw_models_dir: Path = Path("data/models/raw") artifacts_dir: Path = Path("data/artifacts") + state_dir: Path = Path("data/state") + database_path: Path = Path("data/state/lai.db") logs_dir: Path = Path("logs") + allow_overnight_jobs: bool = True enable_gpu: bool = True max_router_tokens: int = 2048 + default_timeout_seconds: int = 120 + default_max_output_tokens: int = 1024 + default_temperature: float = 0.2 + queue_poll_interval_seconds: int = 5 + worker_idle_sleep_seconds: float = 2.0 + max_retry_attempts: int = 1 + stale_running_job_timeout_seconds: int = 900 + + hf_token: SecretStr | None = None + openai_api_key: SecretStr | None = None + anthropic_api_key: SecretStr | None = None + gemini_api_key: SecretStr | None = None + @cached_property def resolved_model_catalog(self) -> Path: return self.root_dir / self.model_catalog @@ -39,3 +60,55 @@ def resolved_model_catalog(self) -> Path: @cached_property def resolved_routing_policy(self) -> Path: return 
self.root_dir / self.routing_policy + + @cached_property + def resolved_prompt_root(self) -> Path: + return self.root_dir / self.prompt_root + + @cached_property + def resolved_huggingface_cache_dir(self) -> Path: + return self.root_dir / self.huggingface_cache_dir + + @cached_property + def resolved_airllm_shards_dir(self) -> Path: + return self.root_dir / self.airllm_shards_dir + + @cached_property + def resolved_raw_models_dir(self) -> Path: + return self.root_dir / self.raw_models_dir + + @cached_property + def resolved_artifacts_dir(self) -> Path: + return self.root_dir / self.artifacts_dir + + @cached_property + def resolved_state_dir(self) -> Path: + return self.root_dir / self.state_dir + + @cached_property + def resolved_database_path(self) -> Path: + return self.root_dir / self.database_path + + @cached_property + def resolved_logs_dir(self) -> Path: + return self.root_dir / self.logs_dir + + @staticmethod + def _secret_value(secret: SecretStr | None) -> str | None: + return secret.get_secret_value() if secret else None + + @property + def huggingface_token_value(self) -> str | None: + return self._secret_value(self.hf_token) + + @property + def openai_api_key_value(self) -> str | None: + return self._secret_value(self.openai_api_key) + + @property + def anthropic_api_key_value(self) -> str | None: + return self._secret_value(self.anthropic_api_key) + + @property + def gemini_api_key_value(self) -> str | None: + return self._secret_value(self.gemini_api_key) diff --git a/src/lai/system.py b/src/lai/system.py new file mode 100644 index 0000000..cf1b1d4 --- /dev/null +++ b/src/lai/system.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import shutil +import subprocess +from pathlib import Path + +from pydantic import BaseModel, Field + + +class SystemSnapshot(BaseModel): + has_gpu: bool + gpu_name: str | None = None + free_disk_gb: dict[str, float] = Field(default_factory=dict) + + +def collect_system_snapshot(*paths: Path, enable_gpu: bool = True) -> SystemSnapshot: + free_disk_gb: dict[str, float] = {} + for path in paths: + free_disk_gb[str(path)] = round(available_disk_gb(path), 2) + has_gpu, gpu_name = detect_gpu(enable_gpu=enable_gpu) + return SystemSnapshot(has_gpu=has_gpu, gpu_name=gpu_name, free_disk_gb=free_disk_gb) + + +def detect_gpu(enable_gpu: bool = True) -> tuple[bool, str | None]: + if not enable_gpu: + return False, None + + try: + import torch + + if torch.cuda.is_available(): + return True, str(torch.cuda.get_device_name(0)) + except Exception: + pass + + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + check=False, + capture_output=True, + text=True, + ) + except FileNotFoundError: + return False, None + + if result.returncode != 0: + return False, None + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + if not lines: + return False, None + return True, lines[0] + + +def available_disk_gb(path: Path) -> float: + target = _nearest_existing_path(path) + usage = shutil.disk_usage(target) + return usage.free / (1024**3) + + +def _nearest_existing_path(path: Path) -> Path: + candidate = path + while not candidate.exists(): + if candidate.parent == candidate: + break + candidate = candidate.parent + return candidate diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..60fee60 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Test package for LAI. 
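Taken together, `Settings`, `layout.ensure_runtime_directories`, and `system.collect_system_snapshot` form the bootstrap path the tests below exercise. A minimal sketch of that composition (illustrative, not part of the patch; it assumes the package is installed and the defaults from `.env.example`):

```python
from lai.layout import ensure_runtime_directories
from lai.settings import Settings
from lai.system import collect_system_snapshot

# Settings() pulls values from .env / environment variables via pydantic-settings.
settings = Settings()

# Create data/, logs/, and friends under the repo root so disk probes measure real paths.
ensure_runtime_directories(settings.root_dir)

# Probe GPU availability and free disk for the shard and artifact directories;
# provider healthchecks consume this snapshot (for example has_gpu in AirLLM's check).
snapshot = collect_system_snapshot(
    settings.resolved_airllm_shards_dir,
    settings.resolved_artifacts_dir,
    enable_gpu=settings.enable_gpu,
)
print(snapshot.has_gpu, snapshot.free_disk_gb)
```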
diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..94b7195 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +@pytest.fixture() +def repo_root() -> Path: + return REPO_ROOT diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..72fefb7 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import Any + +from lai.domain import ExecutionResult, FinishReason, ModelSpec, ProviderHealth, ProviderRequest +from lai.providers.base import Provider +from lai.settings import Settings +from lai.system import SystemSnapshot + + +class FakeProvider(Provider): + def __init__( + self, + settings: Settings, + system_snapshot: SystemSnapshot, + provider_id: str, + *, + available: bool = True, + ) -> None: + super().__init__(settings, system_snapshot) + self.provider_id = provider_id + self._available = available + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + if self._available: + return ProviderHealth(provider_id=self.provider_id, available=True, healthy=True) + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=[f"{self.provider_id} unavailable in test harness"], + ) + + def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResult: + stage = str(request.metadata.get("stage", "executor")) + if stage == "planner": + text = "plan: inspect, execute, verify" + elif stage == "reviewer": + text = "review: output looks good" + elif stage == "executor": + text = f"executed by {model.id}: {request.user_prompt}" + elif model.role == "router": + text = '{"tier_id": "standard"}' + else: + text = f"generated by {model.id}: {request.user_prompt}" + return ExecutionResult( + text=text, + finish_reason=FinishReason.STOP, + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=0.01, + raw={"stage": stage}, + stage=stage, + ) + + def describe_capabilities(self) -> dict[str, Any]: + return {"test_provider": True} diff --git a/tests/integration/test_orchestration.py b/tests/integration/test_orchestration.py new file mode 100644 index 0000000..ca1127b --- /dev/null +++ b/tests/integration/test_orchestration.py @@ -0,0 +1,64 @@ +from pathlib import Path + +from lai.application import create_application +from lai.domain import ExecutionRequest, JobStatus, QueueMode +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _make_application(tmp_path, repo_root): + settings = Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + enable_gpu=True, + ) + snapshot = collect_system_snapshot( + settings.root_dir, + enable_gpu=False, + ) + providers = { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + return create_application(settings=settings, providers=providers) + + +def test_inline_execution_persists_artifacts(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + + job = app.orchestration.submit_request( + ExecutionRequest( + 
user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + assert job.status == JobStatus.SUCCEEDED + assert job.result is not None + final_output = Path(app.settings.resolved_artifacts_dir) / job.id / "final_output.txt" + assert final_output.exists() + assert "executed by" in final_output.read_text(encoding="utf-8") + + +def test_queued_job_survives_restart(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Create a comprehensive advanced architecture for this system.", + queue_mode=QueueMode.QUEUED, + ) + ) + + restarted = _make_application(tmp_path, repo_root) + processed = restarted.orchestration.run_worker(once=True) + persisted = restarted.job_store.get_job(job.id) + + assert processed == 1 + assert persisted is not None + assert persisted.status == JobStatus.SUCCEEDED + assert persisted.result is not None diff --git a/tests/live/test_local_smoke.py b/tests/live/test_local_smoke.py new file mode 100644 index 0000000..b4522c8 --- /dev/null +++ b/tests/live/test_local_smoke.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import os + +import pytest +from pydantic import SecretStr + +from lai.domain import ModelSpec, ProviderRequest +from lai.providers.implementations import AirLLMProvider, TransformersProvider +from lai.settings import Settings +from lai.system import collect_system_snapshot + + +def _settings(tmp_path, **overrides) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=tmp_path / "unused-models.yaml", + routing_policy=tmp_path / "unused-routing.yaml", + **overrides, + ) + + +def test_transformers_live_smoke(tmp_path) -> None: + if os.getenv("LAI_RUN_LOCAL_SMOKE", "").lower() != "true": + pytest.skip("Set LAI_RUN_LOCAL_SMOKE=true to run local model smoke tests.") + pytest.importorskip("transformers") + + settings = _settings(tmp_path) + provider = TransformersProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "transformers-smoke", + "role": "executor", + "runtime": "transformers", + "model": "sshleifer/tiny-gpt2", + "capabilities": ["summarization"], + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY.", max_output_tokens=12), + ) + assert result.text + + +def test_airllm_manual_smoke(tmp_path) -> None: + if os.getenv("LAI_RUN_AIRLLM_SMOKE", "").lower() != "true": + pytest.skip("Set LAI_RUN_AIRLLM_SMOKE=true to run the AirLLM smoke test.") + pytest.importorskip("airllm") + + hf_token = os.getenv("LAI_HF_TOKEN") + if not hf_token: + pytest.skip("LAI_HF_TOKEN is required for the AirLLM smoke test.") + + model_name = os.getenv("LAI_AIRLLM_SMOKE_MODEL", "meta-llama/Llama-3.1-8B-Instruct") + settings = _settings(tmp_path, hf_token=SecretStr(hf_token)) + provider = AirLLMProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=True)) + model = ModelSpec.model_validate( + { + "id": "airllm-smoke", + "role": "executor", + "runtime": "airllm", + "repo_id": model_name, + "hardware": {"expected_disk_gb": 1}, + "runtime_hints": {"allow_cpu_fallback": True, "allow_layer_sharding": True}, + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY.", max_output_tokens=12), + ) + assert result.text diff --git a/tests/live/test_provider_smoke.py b/tests/live/test_provider_smoke.py new file mode 100644 index 0000000..dcd34c0 --- /dev/null +++ b/tests/live/test_provider_smoke.py @@ 
-0,0 +1,92 @@ +from __future__ import annotations + +import os + +import pytest +from pydantic import SecretStr + +from lai.domain import ModelSpec, ProviderRequest +from lai.providers.implementations import AnthropicProvider, GeminiProvider, OpenAIProvider +from lai.settings import Settings +from lai.system import collect_system_snapshot + + +def _settings(tmp_path, **overrides) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=tmp_path / "unused-models.yaml", + routing_policy=tmp_path / "unused-routing.yaml", + **overrides, + ) + + +def test_openai_live_smoke(tmp_path) -> None: + api_key = os.getenv("LAI_OPENAI_API_KEY") + if not api_key: + pytest.skip("LAI_OPENAI_API_KEY is not configured.") + pytest.importorskip("openai") + + settings = _settings(tmp_path, openai_api_key=SecretStr(api_key)) + provider = OpenAIProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "openai-smoke", + "role": "executor", + "runtime": "openai", + "model": "gpt-5.4-mini", + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY and nothing else.", max_output_tokens=32), + ) + assert "READY" in result.text.upper() + + +def test_anthropic_live_smoke(tmp_path) -> None: + api_key = os.getenv("LAI_ANTHROPIC_API_KEY") + if not api_key: + pytest.skip("LAI_ANTHROPIC_API_KEY is not configured.") + pytest.importorskip("anthropic") + + settings = _settings(tmp_path, anthropic_api_key=SecretStr(api_key)) + provider = AnthropicProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "anthropic-smoke", + "role": "executor", + "runtime": "anthropic", + "model": "claude-sonnet-4-20250514", + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY and nothing else.", max_output_tokens=32), + ) + assert "READY" in result.text.upper() + + +def test_gemini_live_smoke(tmp_path) -> None: + api_key = os.getenv("LAI_GEMINI_API_KEY") + if not api_key: + pytest.skip("LAI_GEMINI_API_KEY is not configured.") + pytest.importorskip("google.genai") + + settings = _settings(tmp_path, gemini_api_key=SecretStr(api_key)) + provider = GeminiProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "gemini-smoke", + "role": "executor", + "runtime": "gemini", + "model": "gemini-2.5-flash", + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY and nothing else.", max_output_tokens=32), + ) + assert "READY" in result.text.upper() diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py new file mode 100644 index 0000000..99e5f0e --- /dev/null +++ b/tests/unit/test_api.py @@ -0,0 +1,13 @@ +from lai.api import create_api + + +def test_api_exposes_expected_routes() -> None: + app = create_api() + routes = {route.path for route in app.routes} + + assert "/health" in routes + assert "/models" in routes + assert "/route/explain" in routes + assert "/jobs" in routes + assert "/jobs/{job_id}" in routes + assert "/jobs/{job_id}/cancel" in routes diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..23b6515 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,13 @@ +from lai.config import load_app_config + + +def test_app_config_loads_and_validates(repo_root) -> None: + config = load_app_config( + repo_root / "configs/models/catalog.yaml", + repo_root / 
"configs/routing/policies.yaml", + ) + + assert config.model_catalog.version == 1 + assert len(config.model_catalog.models) == 6 + assert config.routing_policy.router.model_id == "router-small" + assert config.routing_policy.fallbacks.when_gpu_unavailable is not None diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py new file mode 100644 index 0000000..4648bdb --- /dev/null +++ b/tests/unit/test_evals.py @@ -0,0 +1,29 @@ +from lai.evals import load_route_eval_suite, run_route_eval_suite +from lai.settings import Settings + + +def test_route_eval_suite_loads(repo_root) -> None: + suite = load_route_eval_suite(repo_root / "evals/scenarios/routing-basics.yaml") + + assert suite.version == 1 + assert len(suite.scenarios) == 3 + assert suite.scenarios[2].unavailable_providers == ["airllm"] + + +def test_route_eval_suite_passes(repo_root, tmp_path) -> None: + settings = Settings( + root_dir=repo_root, + database_path=tmp_path / "lai.db", + state_dir=tmp_path, + artifacts_dir=tmp_path / "artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "hf-cache", + airllm_shards_dir=tmp_path / "airllm", + raw_models_dir=tmp_path / "raw", + ) + + result = run_route_eval_suite(settings, repo_root / "evals/scenarios/routing-basics.yaml") + + assert result.passed is True + assert result.total == 3 + assert result.failed_count == 0 diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py index c5dc44f..e83afa2 100644 --- a/tests/unit/test_layout.py +++ b/tests/unit/test_layout.py @@ -7,7 +7,7 @@ def test_runtime_directories_are_repo_relative() -> None: paths = runtime_directories(settings.root_dir) - assert len(paths) == 5 + assert len(paths) == 6 assert all(str(path).startswith(str(settings.root_dir)) for path in paths) diff --git a/tests/unit/test_routing.py b/tests/unit/test_routing.py new file mode 100644 index 0000000..d8e5c25 --- /dev/null +++ b/tests/unit/test_routing.py @@ -0,0 +1,60 @@ +from lai.application import create_application +from lai.domain import ExecutionRequest +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _make_application(tmp_path, repo_root, *, airllm_available: bool = True): + settings = Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + enable_gpu=True, + ) + snapshot = collect_system_snapshot( + settings.root_dir, + enable_gpu=False, + ) + providers = { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm", available=airllm_available), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + return create_application(settings=settings, providers=providers) + + +def test_routing_prefers_deep_work_for_complex_prompt(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + + decision = app.routing_engine.route( + ExecutionRequest( + user_prompt=( + "Create a comprehensive advanced architecture and complete implementation " + "strategy for a long-running research platform." 
+ ) + ) + ) + + assert decision.matched_tier_id == "deep-work" + assert decision.executor_model_id == "execution-large" + assert decision.queue_recommended is True + + +def test_routing_falls_back_when_airllm_unavailable(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root, airllm_available=False) + + decision = app.routing_engine.route( + ExecutionRequest( + user_prompt=( + "Create a comprehensive advanced architecture and complete implementation " + "strategy for a long-running research platform." + ) + ) + ) + + assert decision.matched_tier_id == "deep-work" + assert decision.executor_model_id != "execution-large" + assert decision.executor_model_id in {"openai-general", "anthropic-general", "gemini-general"} From 9bc5dc86cb84817c45c6cfa8dbb5305c7a11b2eb Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 19:40:55 +0200 Subject: [PATCH 02/16] feat: add dashboard for jobs and routing --- README.md | 7 +- apps/web/README.md | 13 +- pyproject.toml | 5 + src/lai/api/app.py | 65 +++-- src/lai/api/static/dashboard.css | 435 +++++++++++++++++++++++++++++++ src/lai/api/static/dashboard.js | 273 +++++++++++++++++++ src/lai/api/static/index.html | 140 ++++++++++ tests/unit/test_api.py | 32 +++ 8 files changed, 947 insertions(+), 23 deletions(-) create mode 100644 src/lai/api/static/dashboard.css create mode 100644 src/lai/api/static/dashboard.js create mode 100644 src/lai/api/static/index.html diff --git a/README.md b/README.md index 7eeb2da..b1d152c 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,8 @@ uvicorn lai.api.app:create_api --factory --reload Available endpoints: +- `GET /` +- `GET /dashboard` - `GET /health` - `GET /models` - `POST /route/explain` @@ -120,6 +122,9 @@ Available endpoints: - `GET /jobs/{job_id}` - `POST /jobs/{job_id}/cancel` +The API now also serves a read-mostly dashboard at `/dashboard` with live model health, +route explanation, and recent job inspection. + ## Initial GitHub rules encoded in this repo - Pull request template and issue forms for consistent planning. @@ -133,7 +138,7 @@ Available endpoints: 1. Add live provider smoke tests behind credentials and optional extras. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Add the web dashboard on top of the persisted job and artifact store. +4. Deepen the dashboard with artifact browsing, job replay, and richer traces. ## References diff --git a/apps/web/README.md b/apps/web/README.md index 3a3f94f..9f55cb2 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -1,3 +1,14 @@ # Web App -This folder is reserved for the future dashboard that will surface routing decisions, queue state, artifacts, and model health. +The first dashboard is now served directly by the FastAPI app at `/dashboard`. + +Current dashboard capabilities: + +- live health and catalog summary +- route explanation form +- job submission and recent queue inspection +- model health cards +- job output inspector + +The implementation intentionally stays lightweight for now by using static assets served from +the API package instead of a separate frontend build pipeline. 
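For reviewers who want to poke the new surface by hand after starting `uvicorn lai.api.app:create_api --factory`, a minimal stdlib sketch (illustrative only; the host and port are an assumption based on uvicorn's defaults):

```python
import json
import urllib.request

BASE = "http://127.0.0.1:8000"  # assumption: uvicorn's default bind

def get(path: str) -> dict:
    # /health and /models are defined in src/lai/api/app.py and return JSON bodies.
    with urllib.request.urlopen(BASE + path) as response:
        return json.load(response)

print(get("/health")["status"])  # "ok" once settings, catalog, and database resolve
models = get("/models")["models"]
print(len(models), "models,", sum(m["health"]["available"] for m in models), "available")
```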
diff --git a/pyproject.toml b/pyproject.toml index 6156707..5a4c5cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,3 +48,8 @@ lai = "lai.cli:app" [tool.hatch.build.targets.wheel] packages = ["src/lai"] + +[tool.hatch.build] +include = [ + "src/lai/api/static/**", +] diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 5780bd6..2f7bf69 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -1,10 +1,15 @@ from __future__ import annotations +from pathlib import Path + from fastapi import FastAPI, HTTPException +from fastapi.responses import FileResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field from ..application import create_application from ..domain import ExecutionRequest, QueueMode +from ..settings import Settings class JobCreatePayload(BaseModel): @@ -19,23 +24,41 @@ class JobCreatePayload(BaseModel): reviewer_enabled: bool | None = None -def create_api() -> FastAPI: +def create_api(settings: Settings | None = None) -> FastAPI: + settings = settings or Settings() api = FastAPI(title="LAI API", version="0.1.0") + static_dir = Path(__file__).resolve().parent / "static" + api.mount( + "/dashboard/assets", + StaticFiles(directory=static_dir), + name="dashboard-assets", + ) + + def application(): + return create_application(settings=settings) + + @api.get("/", include_in_schema=False) + def root() -> RedirectResponse: + return RedirectResponse(url="/dashboard", status_code=307) + + @api.get("/dashboard", include_in_schema=False) + def dashboard() -> FileResponse: + return FileResponse(static_dir / "index.html") @api.get("/health") def health() -> dict[str, object]: - application = create_application() + app_state = application() return { "status": "ok", - "environment": application.settings.environment, - "database": str(application.settings.resolved_database_path), - "model_count": len(application.config.model_catalog.models), + "environment": app_state.settings.environment, + "database": str(app_state.settings.resolved_database_path), + "model_count": len(app_state.config.model_catalog.models), } @api.get("/models") def list_models() -> dict[str, object]: - application = create_application() - healthchecks = application.provider_registry.model_healthchecks(application.config) + app_state = application() + healthchecks = app_state.provider_registry.model_healthchecks(app_state.config) return { "models": [ { @@ -46,41 +69,41 @@ def list_models() -> dict[str, object]: "capabilities": model.capabilities, "health": healthchecks[model.id].model_dump(), } - for model in application.config.model_catalog.models + for model in app_state.config.model_catalog.models ] } @api.post("/route/explain") def explain_route(payload: JobCreatePayload) -> dict[str, object]: - application = create_application() - request = _build_execution_request(application, payload) - return application.routing_engine.route(request).model_dump() + app_state = application() + request = _build_execution_request(app_state, payload) + return app_state.routing_engine.route(request).model_dump() @api.post("/jobs") def create_job(payload: JobCreatePayload) -> dict[str, object]: - application = create_application() - request = _build_execution_request(application, payload) - return application.orchestration.submit_request(request).model_dump() + app_state = application() + request = _build_execution_request(app_state, payload) + return app_state.orchestration.submit_request(request).model_dump() @api.get("/jobs") def list_jobs(limit: int = 20) -> dict[str, 
object]: - application = create_application() - return {"jobs": [job.model_dump() for job in application.job_store.list_jobs(limit=limit)]} + app_state = application() + return {"jobs": [job.model_dump() for job in app_state.job_store.list_jobs(limit=limit)]} @api.get("/jobs/{job_id}") def get_job(job_id: str) -> dict[str, object]: - application = create_application() - job = application.job_store.get_job(job_id) + app_state = application() + job = app_state.job_store.get_job(job_id) if job is None: raise HTTPException(status_code=404, detail="Job not found") return job.model_dump() @api.post("/jobs/{job_id}/cancel") def cancel_job(job_id: str) -> dict[str, object]: - application = create_application() - if not application.job_store.cancel_job(job_id): + app_state = application() + if not app_state.job_store.cancel_job(job_id): raise HTTPException(status_code=404, detail="Job not found or not cancelable") - job = application.job_store.get_job(job_id) + job = app_state.job_store.get_job(job_id) return {"job": job.model_dump() if job else None} return api diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css new file mode 100644 index 0000000..f78e82b --- /dev/null +++ b/src/lai/api/static/dashboard.css @@ -0,0 +1,435 @@ +:root { + --bg: #f7f1e5; + --panel: rgba(255, 250, 240, 0.9); + --panel-strong: rgba(255, 248, 232, 0.96); + --line: rgba(37, 56, 68, 0.12); + --ink: #1e2c35; + --muted: #5a6a74; + --accent: #0d6b61; + --accent-soft: #d6efe8; + --accent-warm: #c96d39; + --accent-deep: #213f64; + --danger: #9f3b2c; + --shadow: 0 20px 60px rgba(33, 63, 100, 0.12); +} + +* { + box-sizing: border-box; +} + +body { + margin: 0; + min-height: 100vh; + background: + radial-gradient(circle at top left, rgba(13, 107, 97, 0.18), transparent 28%), + radial-gradient(circle at 80% 18%, rgba(201, 109, 57, 0.16), transparent 22%), + linear-gradient(180deg, #f8f2e7 0%, #efe4d0 100%); + color: var(--ink); + font-family: "Space Grotesk", "Aptos", sans-serif; + overflow-x: hidden; +} + +.orb { + position: fixed; + border-radius: 999px; + filter: blur(18px); + pointer-events: none; + z-index: 0; +} + +.orb-a { + width: 18rem; + height: 18rem; + background: rgba(13, 107, 97, 0.12); + top: 3rem; + right: -4rem; +} + +.orb-b { + width: 14rem; + height: 14rem; + background: rgba(201, 109, 57, 0.16); + bottom: 8rem; + left: -3rem; +} + +.grid-lines { + position: fixed; + inset: 0; + background-image: + linear-gradient(rgba(33, 63, 100, 0.04) 1px, transparent 1px), + linear-gradient(90deg, rgba(33, 63, 100, 0.04) 1px, transparent 1px); + background-size: 2rem 2rem; + mask-image: linear-gradient(180deg, rgba(0, 0, 0, 0.45), transparent 85%); + pointer-events: none; + z-index: 0; +} + +.shell { + position: relative; + z-index: 1; + width: min(1360px, calc(100vw - 2rem)); + margin: 0 auto; + padding: 2rem 0 3rem; + display: grid; + gap: 1.25rem; +} + +.panel { + background: var(--panel); + border: 1px solid var(--line); + border-radius: 1.4rem; + box-shadow: var(--shadow); + backdrop-filter: blur(12px); +} + +.hero { + display: grid; + grid-template-columns: minmax(0, 1.5fr) minmax(18rem, 0.9fr); + gap: 1.25rem; + padding: 1.6rem; +} + +.hero-copy h1, +.section-head h2 { + margin: 0; + letter-spacing: -0.04em; +} + +.hero-copy h1 { + font-size: clamp(2.4rem, 5vw, 4.8rem); + line-height: 0.96; + max-width: 11ch; +} + +.lede { + max-width: 58ch; + margin: 1rem 0 0; + color: var(--muted); + font-size: 1rem; +} + +.eyebrow { + margin: 0 0 0.55rem; + color: var(--accent-deep); + font-family: "IBM 
Plex Mono", monospace; + font-size: 0.74rem; + letter-spacing: 0.16em; + text-transform: uppercase; +} + +.hero-actions, +.action-row { + display: flex; + flex-wrap: wrap; + gap: 0.75rem; + margin-top: 1.2rem; +} + +.button { + border: 0; + border-radius: 999px; + padding: 0.82rem 1.1rem; + font: inherit; + font-weight: 700; + cursor: pointer; + transition: transform 180ms ease, box-shadow 180ms ease, background 180ms ease; + text-decoration: none; +} + +.button:hover { + transform: translateY(-1px); +} + +.primary { + background: linear-gradient(135deg, var(--accent), #0f7e73); + color: white; + box-shadow: 0 16px 32px rgba(13, 107, 97, 0.18); +} + +.secondary { + background: linear-gradient(135deg, var(--accent-warm), #df8755); + color: white; +} + +.ghost { + background: rgba(255, 255, 255, 0.55); + color: var(--ink); + border: 1px solid rgba(33, 63, 100, 0.12); +} + +.hero-metrics { + display: grid; + gap: 0.85rem; +} + +.metric-card { + background: var(--panel-strong); + border: 1px solid rgba(33, 63, 100, 0.12); + border-radius: 1rem; + padding: 1rem; +} + +.metric-label { + display: block; + color: var(--muted); + font-family: "IBM Plex Mono", monospace; + font-size: 0.74rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.metric-card strong { + display: block; + margin-top: 0.45rem; + font-size: 1.4rem; +} + +.route-grid, +.dashboard-grid { + display: grid; + gap: 1.25rem; +} + +.route-grid { + grid-template-columns: minmax(0, 1.1fr) minmax(18rem, 0.9fr); +} + +.dashboard-grid { + grid-template-columns: minmax(0, 1fr) minmax(0, 1fr); +} + +.route-panel, +.route-output, +.jobs-panel, +.detail-panel, +.models-panel { + padding: 1.35rem; +} + +.section-head { + display: flex; + flex-direction: column; + gap: 0.2rem; + margin-bottom: 1rem; +} + +.route-form, +.models-grid, +.jobs-list, +.job-detail, +.reason-list, +.stat-block { + min-height: 8rem; +} + +.route-form label { + display: block; + margin-bottom: 0.9rem; +} + +.route-form span { + display: block; + margin-bottom: 0.35rem; + font-size: 0.92rem; + font-weight: 700; +} + +textarea, +input, +select { + width: 100%; + border: 1px solid rgba(33, 63, 100, 0.14); + background: rgba(255, 255, 255, 0.7); + color: var(--ink); + border-radius: 1rem; + padding: 0.9rem 1rem; + font: inherit; +} + +textarea { + resize: vertical; + min-height: 6rem; +} + +.form-row { + display: grid; + grid-template-columns: minmax(0, 0.45fr) minmax(0, 0.55fr); + gap: 0.8rem; +} + +.stat-block, +.job-detail, +.empty { + display: grid; + gap: 0.75rem; + align-content: start; + color: var(--muted); +} + +.route-chip-row, +.job-meta { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.chip { + display: inline-flex; + align-items: center; + gap: 0.3rem; + padding: 0.45rem 0.7rem; + border-radius: 999px; + background: var(--accent-soft); + color: var(--accent); + font-family: "IBM Plex Mono", monospace; + font-size: 0.78rem; +} + +.chip.warn { + background: rgba(201, 109, 57, 0.14); + color: var(--accent-warm); +} + +.chip.danger { + background: rgba(159, 59, 44, 0.14); + color: var(--danger); +} + +.reason-item, +.job-card, +.model-card { + border: 1px solid rgba(33, 63, 100, 0.11); + background: rgba(255, 255, 255, 0.64); + border-radius: 1rem; +} + +.reason-item { + padding: 0.9rem 1rem; +} + +.reason-stage { + color: var(--accent-deep); + font-family: "IBM Plex Mono", monospace; + font-size: 0.76rem; + letter-spacing: 0.05em; + text-transform: uppercase; +} + +.reason-text { + margin-top: 0.35rem; + color: var(--ink); +} + +.jobs-list 
{ + display: grid; + gap: 0.75rem; +} + +.job-card { + padding: 0.95rem 1rem; + cursor: pointer; + transition: transform 160ms ease, border-color 160ms ease, background 160ms ease; +} + +.job-card:hover, +.job-card.active { + transform: translateY(-1px); + border-color: rgba(13, 107, 97, 0.3); + background: rgba(255, 255, 255, 0.88); +} + +.job-card-header, +.model-card-header { + display: flex; + justify-content: space-between; + gap: 1rem; + align-items: start; +} + +.job-id, +.model-runtime, +.mono { + font-family: "IBM Plex Mono", monospace; +} + +.job-status { + font-size: 0.78rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.job-status.succeeded { + color: var(--accent); +} + +.job-status.failed, +.job-status.canceled { + color: var(--danger); +} + +.job-status.running, +.job-status.queued { + color: var(--accent-warm); +} + +.job-output { + padding: 1rem; + border-radius: 1rem; + background: rgba(30, 44, 53, 0.92); + color: #eef4f2; + font-family: "IBM Plex Mono", monospace; + font-size: 0.84rem; + line-height: 1.55; + white-space: pre-wrap; + overflow-wrap: anywhere; +} + +.models-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 0.85rem; +} + +.model-card { + padding: 1rem; +} + +.model-card ul { + margin: 0.75rem 0 0; + padding-left: 1.1rem; + color: var(--muted); +} + +.fade-in { + animation: rise-in 260ms ease; +} + +@keyframes rise-in { + from { + opacity: 0; + transform: translateY(8px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +@media (max-width: 960px) { + .hero, + .route-grid, + .dashboard-grid, + .form-row { + grid-template-columns: 1fr; + } + + .hero-copy h1 { + max-width: none; + } + + .shell { + width: min(100vw - 1rem, 1360px); + padding: 1rem 0 2rem; + } +} diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js new file mode 100644 index 0000000..0d7e789 --- /dev/null +++ b/src/lai/api/static/dashboard.js @@ -0,0 +1,273 @@ +const state = { + jobs: [], + models: [], + selectedJobId: null, +}; + +const elements = { + environment: document.getElementById("metric-environment"), + modelCount: document.getElementById("metric-model-count"), + jobCount: document.getElementById("metric-job-count"), + queueState: document.getElementById("metric-queue-state"), + jobsList: document.getElementById("jobs-list"), + modelsGrid: document.getElementById("models-grid"), + routeSummary: document.getElementById("route-summary"), + routeReasons: document.getElementById("route-reasons"), + jobDetail: document.getElementById("job-detail"), + routeForm: document.getElementById("route-form"), + refreshAll: document.getElementById("refresh-all"), + submitJob: document.getElementById("submit-job"), +}; + +async function fetchJson(url, options = undefined) { + const response = await fetch(url, options); + if (!response.ok) { + const body = await response.text(); + throw new Error(body || `Request failed with ${response.status}`); + } + return response.json(); +} + +function payloadFromForm() { + return { + user_prompt: document.getElementById("user-prompt").value.trim(), + system_prompt: document.getElementById("system-prompt").value.trim() || null, + queue_mode: document.getElementById("queue-mode").value, + model_override: document.getElementById("model-override").value.trim() || null, + }; +} + +function ensurePrompt(payload) { + if (!payload.user_prompt) { + throw new Error("Add a user prompt first."); + } +} + +function chip(label, value, extraClass = "") { + const className = 
extraClass ? `chip ${extraClass}` : "chip";
+  return `<span class="${className}">${label}: ${escapeHtml(String(value))}</span>`;
+}
+
+function escapeHtml(value) {
+  return value
+    .replaceAll("&", "&amp;")
+    .replaceAll("<", "&lt;")
+    .replaceAll(">", "&gt;")
+    .replaceAll('"', "&quot;")
+    .replaceAll("'", "&#39;");
+}
+
+async function loadOverview() {
+  const [health, modelsResponse, jobsResponse] = await Promise.all([
+    fetchJson("/health"),
+    fetchJson("/models"),
+    fetchJson("/jobs?limit=12"),
+  ]);
+
+  state.models = modelsResponse.models;
+  state.jobs = jobsResponse.jobs;
+
+  elements.environment.textContent = health.environment;
+  elements.modelCount.textContent = String(health.model_count);
+  elements.jobCount.textContent = String(state.jobs.length);
+  elements.queueState.textContent = summarizeQueue(state.jobs);
+
+  renderModels();
+  renderJobs();
+
+  if (state.selectedJobId) {
+    await selectJob(state.selectedJobId);
+  } else if (state.jobs.length > 0) {
+    await selectJob(state.jobs[0].id);
+  }
+}
+
+function summarizeQueue(jobs) {
+  if (jobs.length === 0) {
+    return "idle";
+  }
+  const counts = jobs.reduce(
+    (accumulator, job) => {
+      accumulator[job.status] = (accumulator[job.status] || 0) + 1;
+      return accumulator;
+    },
+    {},
+  );
+  return Object.entries(counts)
+    .map(([key, value]) => `${key}:${value}`)
+    .join(" | ");
+}
+
+function renderModels() {
+  if (state.models.length === 0) {
+    elements.modelsGrid.className = "models-grid empty";
+    elements.modelsGrid.textContent = "No models loaded.";
+    return;
+  }
+
+  elements.modelsGrid.className = "models-grid fade-in";
+  elements.modelsGrid.innerHTML = state.models
+    .map((model) => {
+      const health = model.health;
+      const healthClass = health.available && health.healthy ? "" : "warn";
+      const details = health.reasons && health.reasons.length > 0 ? health.reasons : ["ready"];
+      return `
+        <article class="model-card">
+          <div class="model-card-header">
+            <div>
+              <strong>${escapeHtml(model.id)}</strong>
+              <div class="model-runtime">${escapeHtml(model.runtime)}</div>
+            </div>
+            ${chip("health", health.available ? "ready" : "blocked", healthClass)}
+          </div>
+          <p class="mono">${escapeHtml(model.model_ref)}</p>
+          <ul>
+            ${details.map((detail) => `<li>${escapeHtml(detail)}</li>`).join("")}
+          </ul>
+        </article>
+ `; + }) + .join(""); +} + +function renderJobs() { + if (state.jobs.length === 0) { + elements.jobsList.className = "jobs-list empty"; + elements.jobsList.textContent = "No persisted jobs yet."; + return; + } + + elements.jobsList.className = "jobs-list fade-in"; + elements.jobsList.innerHTML = state.jobs + .map((job) => { + const activeClass = job.id === state.selectedJobId ? " active" : ""; + return ` +
+        <article class="job-card${activeClass}" data-job-id="${escapeHtml(job.id)}">
+          <div class="job-card-header">
+            <div>
+              <div class="job-id">${escapeHtml(job.id.slice(0, 8))}</div>
+              <div>${escapeHtml(job.route_decision?.executor_model_id || "n/a")}</div>
+            </div>
+            <span class="job-status ${escapeHtml(job.status)}">${escapeHtml(job.status)}</span>
+          </div>
+          <p>${escapeHtml(job.request.user_prompt.slice(0, 120))}</p>
+        </article>
+ `; + }) + .join(""); + + elements.jobsList.querySelectorAll("[data-job-id]").forEach((card) => { + card.addEventListener("click", () => { + void selectJob(card.getAttribute("data-job-id")); + }); + }); +} + +async function selectJob(jobId) { + if (!jobId) { + return; + } + + state.selectedJobId = jobId; + renderJobs(); + + const job = await fetchJson(`/jobs/${jobId}`); + const chips = [ + chip("status", job.status, job.status === "failed" ? "danger" : job.status === "succeeded" ? "" : "warn"), + chip("queue", job.queue_mode), + chip("tier", job.route_decision?.matched_tier_id || "n/a"), + chip("executor", job.route_decision?.executor_model_id || "n/a"), + ].join(""); + + const output = job.result?.text + ? `
<div class="job-output">${escapeHtml(job.result.text)}</div>`
+    : `<p class="empty">No final output stored for this job yet.</p>`;
+
+  const error = job.error?.message
+    ? `<div class="chip danger">error: ${escapeHtml(job.error.message)}</div>`
+    : "";
+
+  elements.jobDetail.className = "job-detail fade-in";
+  elements.jobDetail.innerHTML = `
<div class="route-chip-row">${chips}</div>
+    ${error}
+    ${output}
+  `;
+}
+
+function renderRouteDecision(decision) {
+  elements.routeSummary.className = "stat-block fade-in";
+  elements.routeSummary.innerHTML = `
+    <div class="route-chip-row">
+      ${chip("tier", decision.matched_tier_id)}
+      ${chip("planner", decision.planner_model_id || "none")}
+      ${chip("executor", decision.executor_model_id)}
+      ${chip("reviewer", decision.reviewer_model_id || "none")}
+      ${chip("queue", decision.queue_recommended ? "recommended" : "not needed", decision.queue_recommended ? "warn" : "")}
+    </div>
+    <p class="mono">fallbacks: ${escapeHtml((decision.fallback_model_ids || []).join(", ") || "none")}</p>
+  `;
+
+  elements.routeReasons.innerHTML = (decision.reasons || [])
+    .map(
+      (reason) => `
+        <div class="reason-item">
+          <div class="reason-stage">${escapeHtml(reason.stage)}</div>
+          <div class="reason-text">${escapeHtml(reason.message)}</div>
+        </div>
+ `, + ) + .join(""); +} + +async function explainRoute(event) { + event.preventDefault(); + const payload = payloadFromForm(); + ensurePrompt(payload); + const decision = await fetchJson("/route/explain", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + renderRouteDecision(decision); +} + +async function submitJob() { + const payload = payloadFromForm(); + ensurePrompt(payload); + const job = await fetchJson("/jobs", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + renderRouteDecision(job.route_decision); + await loadOverview(); + await selectJob(job.id); +} + +function bindEvents() { + elements.routeForm.addEventListener("submit", (event) => { + void explainRoute(event).catch(handleError); + }); + elements.submitJob.addEventListener("click", () => { + void submitJob().catch(handleError); + }); + elements.refreshAll.addEventListener("click", () => { + void loadOverview().catch(handleError); + }); +} + +function handleError(error) { + elements.routeSummary.className = "stat-block fade-in"; + elements.routeSummary.innerHTML = `
<div class="chip danger">error</div>
+    <p>${escapeHtml(error.message || String(error))}</p>
`; +} + +async function boot() { + bindEvents(); + await loadOverview(); + setInterval(() => { + void loadOverview().catch(handleError); + }, 8000); +} + +void boot().catch(handleError); diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html new file mode 100644 index 0000000..911aa3f --- /dev/null +++ b/src/lai/api/static/index.html @@ -0,0 +1,140 @@ + + + + + + LAI Dashboard + + + + + + +
+
+
+ +
+
+
+

LAI CONTROL ROOM

+

Route fast. Think deep. Watch the queue move.

+

+ This dashboard sits on top of the local-first orchestration core. It exposes + model health, route reasoning, persistent jobs, and the current system stance + without needing a separate frontend stack. +

+
+ Open API Docs + +
+
+
+
+ Environment + ... +
+
+ Catalog Models + ... +
+
+ Jobs Tracked + ... +
+
+ Queue State + ... +
+
+
+ +
+
+
+

Route Lab

+

Explain or launch a request

+
+
+ + +
+ + +
+
+ + +
+
+
+ +
+
+

Decision Trace

+

Routing output

+
+
No route computed yet.
+
+
+
+ +
+
+
+

Queue

+

Recent jobs

+
+
No jobs yet.
+
+ +
+
+

Inspector

+

Job detail

+
+
Select a job to inspect its output.
+
+
+ +
+
+

Runtime Matrix

+

Model health

+
+
Loading model health...
+
+
+ + + + diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 99e5f0e..8a49b9d 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,13 +1,45 @@ +from fastapi.testclient import TestClient + from lai.api import create_api +from lai.settings import Settings def test_api_exposes_expected_routes() -> None: app = create_api() routes = {route.path for route in app.routes} + assert "/" in routes + assert "/dashboard" in routes assert "/health" in routes assert "/models" in routes assert "/route/explain" in routes assert "/jobs" in routes assert "/jobs/{job_id}" in routes assert "/jobs/{job_id}/cancel" in routes + + +def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: + settings = Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + client = TestClient(create_api(settings=settings)) + + root_response = client.get("/", follow_redirects=False) + dashboard_response = client.get("/dashboard") + asset_response = client.get("/dashboard/assets/dashboard.js") + + assert root_response.status_code == 307 + assert root_response.headers["location"] == "/dashboard" + assert dashboard_response.status_code == 200 + assert "LAI CONTROL ROOM" in dashboard_response.text + assert asset_response.status_code == 200 + assert "async function loadOverview" in asset_response.text From 114e95eab094036c0bc50c95683664e3a0ca7e0a Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 19:44:07 +0200 Subject: [PATCH 03/16] feat: add artifact browsing to dashboard --- README.md | 4 +- apps/web/README.md | 1 + src/lai/api/app.py | 32 ++++++++++ src/lai/api/static/dashboard.css | 28 +++++++++ src/lai/api/static/dashboard.js | 100 ++++++++++++++++++++++++++++++- src/lai/api/static/index.html | 2 +- src/lai/artifacts.py | 14 +++++ tests/unit/test_api.py | 69 +++++++++++++++++---- 8 files changed, 234 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index b1d152c..0718650 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ Available endpoints: - `POST /jobs/{job_id}/cancel` The API now also serves a read-mostly dashboard at `/dashboard` with live model health, -route explanation, and recent job inspection. +route explanation, recent job inspection, and artifact/trace browsing. ## Initial GitHub rules encoded in this repo @@ -138,7 +138,7 @@ route explanation, and recent job inspection. 1. Add live provider smoke tests behind credentials and optional extras. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Deepen the dashboard with artifact browsing, job replay, and richer traces. +4. Deepen the dashboard with job replay, richer traces, and live worker controls. 
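
For a quick check of the new artifact endpoints, something like the following works from PowerShell. This is a minimal sketch, assuming the API is running at uvicorn's default `http://127.0.0.1:8000` and that at least one job has already been persisted:

```powershell
# Sketch: assumes the API is already running on the default uvicorn address.
$base = "http://127.0.0.1:8000"

# Take the most recent persisted job, then list and fetch its artifacts.
$job = (Invoke-RestMethod "$base/jobs?limit=1").jobs[0]
$artifacts = (Invoke-RestMethod "$base/jobs/$($job.id)/artifacts").artifacts
$detail = Invoke-RestMethod "$base/jobs/$($job.id)/artifacts/$($artifacts[0].id)"

$detail.content_type   # "application/json" or "text/plain"
$detail.content        # the persisted artifact payload
```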
## References diff --git a/apps/web/README.md b/apps/web/README.md index 9f55cb2..5b1acec 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -7,6 +7,7 @@ Current dashboard capabilities: - live health and catalog summary - route explanation form - job submission and recent queue inspection +- artifact and trace browsing for persisted jobs - model health cards - job output inspector diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 2f7bf69..178f80d 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -98,6 +98,38 @@ def get_job(job_id: str) -> dict[str, object]: raise HTTPException(status_code=404, detail="Job not found") return job.model_dump() + @api.get("/jobs/{job_id}/artifacts") + def list_job_artifacts(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return {"artifacts": [artifact.model_dump() for artifact in job.artifacts]} + + @api.get("/jobs/{job_id}/artifacts/{artifact_id}") + def get_job_artifact(job_id: str, artifact_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + + artifact = next((item for item in job.artifacts if item.id == artifact_id), None) + if artifact is None: + raise HTTPException(status_code=404, detail="Artifact not found") + + try: + content = app_state.artifacts.read_text(artifact) + except FileNotFoundError as exc: + raise HTTPException(status_code=404, detail="Artifact file missing") from exc + except ValueError as exc: + raise HTTPException(status_code=400, detail="Artifact path is invalid") from exc + + return { + "artifact": artifact.model_dump(), + "content_type": app_state.artifacts.content_type(artifact), + "content": content, + } + @api.post("/jobs/{job_id}/cancel") def cancel_job(job_id: str) -> dict[str, object]: app_state = application() diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index f78e82b..255a789 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -385,6 +385,34 @@ textarea { overflow-wrap: anywhere; } +.artifact-strip { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.artifact-button { + border: 1px solid rgba(33, 63, 100, 0.14); + background: rgba(255, 255, 255, 0.72); + color: var(--ink); + border-radius: 999px; + padding: 0.58rem 0.8rem; + font: inherit; + font-size: 0.84rem; + cursor: pointer; +} + +.artifact-button.active { + background: var(--accent-soft); + border-color: rgba(13, 107, 97, 0.26); + color: var(--accent); +} + +.stack { + display: grid; + gap: 0.75rem; +} + .models-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index 0d7e789..c3e88bc 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -2,6 +2,7 @@ const state = { jobs: [], models: [], selectedJobId: null, + selectedArtifactId: null, }; const elements = { @@ -169,6 +170,7 @@ async function selectJob(jobId) { } state.selectedJobId = jobId; + state.selectedArtifactId = null; renderJobs(); const job = await fetchJson(`/jobs/${jobId}`); @@ -187,11 +189,107 @@ async function selectJob(jobId) { ? `
<div class="chip danger">error: ${escapeHtml(job.error.message)}</div>`
     : "";
+
+  const reasons = (job.route_decision?.reasons || [])
+    .map(
+      (reason) => `
+        <div class="reason-item">
+          <div class="reason-stage">${escapeHtml(reason.stage)}</div>
+          <div class="reason-text">${escapeHtml(reason.message)}</div>
+        </div>
+ `, + ) + .join(""); + elements.jobDetail.className = "job-detail fade-in"; elements.jobDetail.innerHTML = `
<div class="route-chip-row">${chips}</div>
     ${error}
-    ${output}
+    <div class="stack">
+      <div>
+        <p class="eyebrow">Final output</p>
+        ${output}
+      </div>
+      <div>
+        <p class="eyebrow">Route trace</p>
+        <div class="reason-list">${reasons || '<p class="empty">No route reasons recorded.</p>'}</div>
+      </div>
+      <div>
+        <p class="eyebrow">Artifacts</p>
+        <div id="artifact-strip" class="artifact-strip">Loading artifacts...</div>
+        <div id="artifact-preview" class="empty">Select an artifact to inspect it.</div>
+      </div>
+    </div>
   `;
+
+  await loadArtifacts(job.id, job.artifacts || []);
 }
+
+async function loadArtifacts(jobId, jobArtifacts) {
+  const response =
+    jobArtifacts.length > 0 ? { artifacts: jobArtifacts } : await fetchJson(`/jobs/${jobId}/artifacts`);
+  const artifacts = response.artifacts || [];
+  const strip = document.getElementById("artifact-strip");
+  const preview = document.getElementById("artifact-preview");
+  if (!strip || !preview) {
+    return;
+  }
+
+  if (artifacts.length === 0) {
+    strip.className = "artifact-strip empty";
+    strip.textContent = "No artifacts recorded for this job.";
+    preview.className = "empty";
+    preview.textContent = "No artifact preview available.";
+    return;
+  }
+
+  strip.className = "artifact-strip fade-in";
+  strip.innerHTML = artifacts
+    .map((artifact) => {
+      const activeClass = artifact.id === state.selectedArtifactId ? " active" : "";
+      return `
+        <button class="artifact-button${activeClass}" data-artifact-id="${escapeHtml(artifact.id)}" type="button">
+          ${escapeHtml(artifact.artifact_type)}
+        </button>
+      `;
+    })
+    .join("");
+
+  strip.querySelectorAll("[data-artifact-id]").forEach((button) => {
+    button.addEventListener("click", () => {
+      void selectArtifact(jobId, button.getAttribute("data-artifact-id"), artifacts).catch(handleError);
+    });
+  });
+
+  const preferredArtifact =
+    artifacts.find((artifact) => artifact.artifact_type === "final-output") || artifacts[0];
+  await selectArtifact(jobId, preferredArtifact.id, artifacts);
+}
+
+async function selectArtifact(jobId, artifactId, artifacts) {
+  if (!artifactId) {
+    return;
+  }
+
+  state.selectedArtifactId = artifactId;
+  const strip = document.getElementById("artifact-strip");
+  const preview = document.getElementById("artifact-preview");
+  if (!strip || !preview) {
+    return;
+  }
+
+  strip.querySelectorAll("[data-artifact-id]").forEach((button) => {
+    button.classList.toggle("active", button.getAttribute("data-artifact-id") === artifactId);
+  });
+
+  const response = await fetchJson(`/jobs/${jobId}/artifacts/${artifactId}`);
+  const artifact = response.artifact;
+  const artifactMeta = artifacts.find((item) => item.id === artifactId) || artifact;
+  preview.className = "stack fade-in";
+  preview.innerHTML = `
+    <div class="route-chip-row">
+      ${chip("type", artifactMeta.artifact_type)}
+      ${chip("path", artifactMeta.relative_path)}
+      ${chip("format", response.content_type === "application/json" ? "json" : "text")}
+    </div>
+    <pre class="job-output">${escapeHtml(response.content)}</pre>
`; } diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html index 911aa3f..65661df 100644 --- a/src/lai/api/static/index.html +++ b/src/lai/api/static/index.html @@ -122,7 +122,7 @@

Recent jobs

Inspector

Job detail

-
Select a job to inspect its output.
+
Select a job to inspect its traces.
diff --git a/src/lai/artifacts.py b/src/lai/artifacts.py index 275e008..14d127a 100644 --- a/src/lai/artifacts.py +++ b/src/lai/artifacts.py @@ -24,6 +24,20 @@ def write_text( ) -> ArtifactRecord: return self._write(job_id, artifact_type, filename, payload) + def resolve_path(self, artifact: ArtifactRecord) -> Path: + target = (self.artifacts_root / artifact.relative_path).resolve() + root = self.artifacts_root.resolve() + target.relative_to(root) + return target + + def read_text(self, artifact: ArtifactRecord) -> str: + return self.resolve_path(artifact).read_text(encoding="utf-8") + + def content_type(self, artifact: ArtifactRecord) -> str: + if artifact.relative_path.endswith(".json"): + return "application/json" + return "text/plain" + def _write( self, job_id: str, artifact_type: str, filename: str, payload: str ) -> ArtifactRecord: diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 8a49b9d..d63f400 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,7 +1,26 @@ from fastapi.testclient import TestClient from lai.api import create_api +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) def test_api_exposes_expected_routes() -> None: @@ -15,22 +34,13 @@ def test_api_exposes_expected_routes() -> None: assert "/route/explain" in routes assert "/jobs" in routes assert "/jobs/{job_id}" in routes + assert "/jobs/{job_id}/artifacts" in routes + assert "/jobs/{job_id}/artifacts/{artifact_id}" in routes assert "/jobs/{job_id}/cancel" in routes def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: - settings = Settings( - root_dir=tmp_path, - model_catalog=repo_root / "configs/models/catalog.yaml", - routing_policy=repo_root / "configs/routing/policies.yaml", - database_path=tmp_path / "data/state/lai.db", - state_dir=tmp_path / "data/state", - artifacts_dir=tmp_path / "data/artifacts", - logs_dir=tmp_path / "logs", - huggingface_cache_dir=tmp_path / "data/cache/huggingface", - airllm_shards_dir=tmp_path / "data/models/airllm-shards", - raw_models_dir=tmp_path / "data/models/raw", - ) + settings = _test_settings(repo_root, tmp_path) client = TestClient(create_api(settings=settings)) root_response = client.get("/", follow_redirects=False) @@ -43,3 +53,38 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert "LAI CONTROL ROOM" in dashboard_response.text assert asset_response.status_code == 200 assert "async function loadOverview" in asset_response.text + + +def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + providers = { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + 
"openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + seeded_app = create_application(settings=settings, providers=providers) + job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings)) + artifacts_response = client.get(f"/jobs/{job.id}/artifacts") + assert artifacts_response.status_code == 200 + artifacts = artifacts_response.json()["artifacts"] + assert artifacts + + final_artifact = next( + artifact for artifact in artifacts if artifact["artifact_type"] == "final-output" + ) + artifact_response = client.get(f"/jobs/{job.id}/artifacts/{final_artifact['id']}") + + assert artifact_response.status_code == 200 + payload = artifact_response.json() + assert payload["content_type"] == "text/plain" + assert "executed by" in payload["content"] From c02a788c66734358679285ac7c9c257a78d6996e Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 19:53:09 +0200 Subject: [PATCH 04/16] feat: add queue controls and job replay --- README.md | 11 ++- apps/web/README.md | 2 + src/lai/api/app.py | 38 +++++++++- src/lai/api/static/dashboard.css | 40 +++++++++++ src/lai/api/static/dashboard.js | 94 ++++++++++++++++++++++++- src/lai/api/static/index.html | 15 +++- src/lai/cli.py | 45 +++++++++++- src/lai/jobs/service.py | 43 ++++++++++- tests/integration/test_orchestration.py | 43 +++++++++++ tests/unit/test_api.py | 82 ++++++++++++++++++--- 10 files changed, 394 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 0718650..ea648a8 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,9 @@ python -m lai.cli models check python -m lai.cli route explain "Summarize this note." python -m lai.cli run "Create a detailed implementation strategy." python -m lai.cli jobs list +python -m lai.cli jobs replay --queue-mode queued python -m lai.cli worker run --once +python -m lai.cli worker run --max-jobs 3 python -m lai.cli eval route --no-save ``` @@ -120,10 +122,15 @@ Available endpoints: - `POST /jobs` - `GET /jobs` - `GET /jobs/{job_id}` +- `POST /jobs/{job_id}/replay` +- `GET /jobs/{job_id}/artifacts` +- `GET /jobs/{job_id}/artifacts/{artifact_id}` - `POST /jobs/{job_id}/cancel` +- `POST /worker/run` The API now also serves a read-mostly dashboard at `/dashboard` with live model health, -route explanation, recent job inspection, and artifact/trace browsing. +route explanation, recent job inspection, artifact/trace browsing, replay actions, +and bounded queue worker controls. ## Initial GitHub rules encoded in this repo @@ -138,7 +145,7 @@ route explanation, recent job inspection, and artifact/trace browsing. 1. Add live provider smoke tests behind credentials and optional extras. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Deepen the dashboard with job replay, richer traces, and live worker controls. +4. Deepen the dashboard with richer stage telemetry and stronger background worker automation. 
## References diff --git a/apps/web/README.md b/apps/web/README.md index 5b1acec..aad1557 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -8,6 +8,8 @@ Current dashboard capabilities: - route explanation form - job submission and recent queue inspection - artifact and trace browsing for persisted jobs +- job replay controls for inline and queued reruns +- bounded live worker controls for processing queued jobs - model health cards - job output inspector diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 178f80d..fe1e24d 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from typing import Mapping from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse, RedirectResponse @@ -9,6 +10,7 @@ from ..application import create_application from ..domain import ExecutionRequest, QueueMode +from ..providers import Provider from ..settings import Settings @@ -24,7 +26,19 @@ class JobCreatePayload(BaseModel): reviewer_enabled: bool | None = None -def create_api(settings: Settings | None = None) -> FastAPI: +class JobReplayPayload(BaseModel): + queue_mode: QueueMode | None = None + + +class WorkerRunPayload(BaseModel): + max_jobs: int = Field(default=1, ge=1, le=25) + resume_running: bool = False + + +def create_api( + settings: Settings | None = None, + providers: Mapping[str, Provider] | None = None, +) -> FastAPI: settings = settings or Settings() api = FastAPI(title="LAI API", version="0.1.0") static_dir = Path(__file__).resolve().parent / "static" @@ -35,7 +49,7 @@ def create_api(settings: Settings | None = None) -> FastAPI: ) def application(): - return create_application(settings=settings) + return create_application(settings=settings, providers=providers) @api.get("/", include_in_schema=False) def root() -> RedirectResponse: @@ -85,6 +99,17 @@ def create_job(payload: JobCreatePayload) -> dict[str, object]: request = _build_execution_request(app_state, payload) return app_state.orchestration.submit_request(request).model_dump() + @api.post("/jobs/{job_id}/replay") + def replay_job(job_id: str, payload: JobReplayPayload) -> dict[str, object]: + app_state = application() + try: + return app_state.orchestration.replay_job( + job_id, + queue_mode=payload.queue_mode, + ).model_dump() + except KeyError as exc: + raise HTTPException(status_code=404, detail="Job not found") from exc + @api.get("/jobs") def list_jobs(limit: int = 20) -> dict[str, object]: app_state = application() @@ -138,6 +163,15 @@ def cancel_job(job_id: str) -> dict[str, object]: job = app_state.job_store.get_job(job_id) return {"job": job.model_dump() if job else None} + @api.post("/worker/run") + def run_worker(payload: WorkerRunPayload) -> dict[str, object]: + app_state = application() + jobs = app_state.orchestration.run_worker_batch( + max_jobs=payload.max_jobs, + requeue_running=payload.resume_running, + ) + return {"processed": len(jobs), "jobs": [job.model_dump() for job in jobs]} + return api diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index 255a789..a76f502 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -158,6 +158,11 @@ body { border: 1px solid rgba(33, 63, 100, 0.12); } +.small { + padding: 0.58rem 0.9rem; + font-size: 0.84rem; +} + .hero-metrics { display: grid; gap: 0.85rem; @@ -214,6 +219,19 @@ body { margin-bottom: 1rem; } +.split-head { + flex-direction: row; + justify-content: space-between; + align-items: 
start; + gap: 1rem; +} + +.mini-actions { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + .route-form, .models-grid, .jobs-list, @@ -274,6 +292,15 @@ textarea { gap: 0.55rem; } +.inline-status { + margin-bottom: 0.95rem; + min-height: 2.6rem; + padding: 0.8rem 0.95rem; + border-radius: 1rem; + border: 1px dashed rgba(33, 63, 100, 0.14); + background: rgba(255, 255, 255, 0.5); +} + .chip { display: inline-flex; align-items: center; @@ -286,6 +313,11 @@ textarea { font-size: 0.78rem; } +.chip.ready { + background: rgba(13, 107, 97, 0.18); + color: var(--accent); +} + .chip.warn { background: rgba(201, 109, 57, 0.14); color: var(--accent-warm); @@ -413,6 +445,10 @@ textarea { gap: 0.75rem; } +.compact-actions { + margin: 0; +} + .models-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); @@ -460,4 +496,8 @@ textarea { width: min(100vw - 1rem, 1360px); padding: 1rem 0 2rem; } + + .split-head { + flex-direction: column; + } } diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index c3e88bc..7da5d3b 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -18,6 +18,9 @@ const elements = { routeForm: document.getElementById("route-form"), refreshAll: document.getElementById("refresh-all"), submitJob: document.getElementById("submit-job"), + runNextJob: document.getElementById("run-next-job"), + runBatchJobs: document.getElementById("run-batch-jobs"), + queueActionStatus: document.getElementById("queue-action-status"), }; async function fetchJson(url, options = undefined) { @@ -49,6 +52,15 @@ function chip(label, value, extraClass = "") { return `${label}: ${escapeHtml(String(value))}`; } +function setQueueActionStatus(message, tone = "") { + if (!elements.queueActionStatus) { + return; + } + const toneChip = tone ? `${chip("worker", tone, tone)} ` : ""; + elements.queueActionStatus.className = "inline-status fade-in"; + elements.queueActionStatus.innerHTML = `${toneChip}${escapeHtml(message)}`; +} + function escapeHtml(value) { return value .replaceAll("&", "&") @@ -175,7 +187,11 @@ async function selectJob(jobId) { const job = await fetchJson(`/jobs/${jobId}`); const chips = [ - chip("status", job.status, job.status === "failed" ? "danger" : job.status === "succeeded" ? "" : "warn"), + chip( + "status", + job.status, + job.status === "failed" ? "danger" : job.status === "succeeded" ? "" : "warn", + ), chip("queue", job.queue_mode), chip("tier", job.route_decision?.matched_tier_id || "n/a"), chip("executor", job.route_decision?.executor_model_id || "n/a"), @@ -188,6 +204,10 @@ async function selectJob(jobId) { const error = job.error?.message ? `
<div class="chip danger">error: ${escapeHtml(job.error.message)}</div>`
     : "";
+  const cancelAction =
+    job.status === "queued" || job.status === "running"
+      ? '<button id="cancel-selected-job" class="button ghost small" type="button">Cancel job</button>'
+      : "";
 
   const reasons = (job.route_decision?.reasons || [])
     .map(
       (reason) => `
@@ -203,6 +223,11 @@ async function selectJob(jobId) {
   elements.jobDetail.className = "job-detail fade-in";
   elements.jobDetail.innerHTML = `
<div class="route-chip-row">${chips}</div>
+    <div class="mini-actions">
+      <button id="replay-inline-job" class="button ghost small" type="button">Replay inline</button>
+      <button id="replay-queued-job" class="button ghost small" type="button">Replay queued</button>
+      ${cancelAction}
+    </div>
${error}

Final output

@@ -219,6 +244,16 @@ async function selectJob(jobId) {
`; + document.getElementById("replay-inline-job")?.addEventListener("click", () => { + void replayJob(job.id, "inline").catch(handleQueueError); + }); + document.getElementById("replay-queued-job")?.addEventListener("click", () => { + void replayJob(job.id, "queued").catch(handleQueueError); + }); + document.getElementById("cancel-selected-job")?.addEventListener("click", () => { + void cancelJob(job.id).catch(handleQueueError); + }); + await loadArtifacts(job.id, job.artifacts || []); } @@ -343,6 +378,53 @@ async function submitJob() { await selectJob(job.id); } +async function replayJob(jobId, queueMode) { + const job = await fetchJson(`/jobs/${jobId}/replay`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ queue_mode: queueMode }), + }); + setQueueActionStatus( + `Created replay job ${job.id.slice(0, 8)} in ${job.queue_mode} mode.`, + "ready", + ); + await loadOverview(); + await selectJob(job.id); +} + +async function cancelJob(jobId) { + await fetchJson(`/jobs/${jobId}/cancel`, { method: "POST" }); + setQueueActionStatus(`Canceled job ${jobId.slice(0, 8)}.`, "warn"); + await loadOverview(); + await selectJob(jobId); +} + +async function runWorker(maxJobs) { + const selectedJobId = state.selectedJobId; + const response = await fetchJson("/worker/run", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ max_jobs: maxJobs }), + }); + const processed = response.processed || 0; + const latestProcessed = + response.jobs && response.jobs.length > 0 ? response.jobs[response.jobs.length - 1] : null; + setQueueActionStatus( + processed > 0 + ? `Processed ${processed} queued job(s).` + : "No queued jobs were ready to run.", + processed > 0 ? "ready" : "warn", + ); + await loadOverview(); + if (latestProcessed?.id) { + await selectJob(latestProcessed.id); + return; + } + if (selectedJobId) { + await selectJob(selectedJobId); + } +} + function bindEvents() { elements.routeForm.addEventListener("submit", (event) => { void explainRoute(event).catch(handleError); @@ -353,6 +435,12 @@ function bindEvents() { elements.refreshAll.addEventListener("click", () => { void loadOverview().catch(handleError); }); + elements.runNextJob.addEventListener("click", () => { + void runWorker(1).catch(handleQueueError); + }); + elements.runBatchJobs.addEventListener("click", () => { + void runWorker(3).catch(handleQueueError); + }); } function handleError(error) { @@ -360,6 +448,10 @@ function handleError(error) { elements.routeSummary.innerHTML = `
<div class="chip danger">error</div>
   <p>${escapeHtml(error.message || String(error))}</p>
`; } +function handleQueueError(error) { + setQueueActionStatus(error.message || String(error), "danger"); +} + async function boot() { bindEvents(); await loadOverview(); diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html index 65661df..672d080 100644 --- a/src/lai/api/static/index.html +++ b/src/lai/api/static/index.html @@ -110,9 +110,18 @@

Routing output

-
-

Queue

-

Recent jobs

+
+
+

Queue

+

Recent jobs

+
+
+ + +
+
+
+ Worker idle. Use the queue controls to process queued jobs.
No jobs yet.
diff --git a/src/lai/cli.py b/src/lai/cli.py index 3694bb2..9458a28 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -249,13 +249,56 @@ def cancel_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> Non raise typer.Exit(code=1) +@jobs_app.command("replay") +def replay_job( + job_id: str = typer.Argument(..., help="Source job identifier."), + queue_mode: QueueMode | None = typer.Option( + None, help="Override the replay queue mode for the new job." + ), +) -> None: + """Replay a persisted request as a new job.""" + application = create_application() + try: + job = application.orchestration.replay_job(job_id, queue_mode=queue_mode) + except KeyError: + raise typer.Exit(code=1) from None + + if job.queue_mode == QueueMode.QUEUED and job.status == "queued": + console.print(Panel.fit(f"Queued replay job {job.id}", title="LAI Replay")) + return + + console.print( + Panel.fit(job.result.text if job.result else "No output produced.", title=f"Job {job.id}") + ) + + @worker_app.command("run") def run_worker( once: bool = typer.Option(False, help="Process at most one queued job and then exit."), + max_jobs: int | None = typer.Option( + None, + min=1, + help="Process at most this many queued jobs and then exit.", + ), ) -> None: """Run the local worker for queued jobs.""" application = create_application() - processed = application.orchestration.run_worker(once=once) + if once: + processed = len( + application.orchestration.run_worker_batch( + max_jobs=1, + requeue_running=True, + ) + ) + elif max_jobs is not None: + processed = len( + application.orchestration.run_worker_batch( + max_jobs=max_jobs, + requeue_running=True, + ) + ) + else: + processed = application.orchestration.run_worker() console.print(f"Processed {processed} queued job(s).") diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 4d51015..11448e9 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -66,6 +66,22 @@ def submit_request(self, request: ExecutionRequest) -> JobRecord: return self.execute_job(job.id) return self.job_store.get_job(job.id) or job + def replay_job(self, job_id: str, *, queue_mode: QueueMode | None = None) -> JobRecord: + job = self.job_store.get_job(job_id) + if job is None: + raise KeyError(f"Unknown job id {job_id!r}.") + + metadata = dict(job.request.metadata) + metadata["replayed_from_job_id"] = job.id + replay_request = job.request.model_copy( + update={ + "metadata": metadata, + "queue_mode": queue_mode or job.request.queue_mode, + "source": "replay", + } + ) + return self.submit_request(replay_request) + def execute_job(self, job_id: str) -> JobRecord: job = self.job_store.get_job(job_id) if job is None: @@ -106,10 +122,35 @@ def execute_job(self, job_id: str) -> JobRecord: self.artifacts.write_text(job.id, "error", "error.txt", str(exc)) return self.job_store.get_job(job.id) or job - def run_worker(self, once: bool = False) -> int: + def run_worker_batch( + self, + *, + max_jobs: int = 1, + requeue_running: bool = False, + ) -> list[JobRecord]: + if max_jobs < 1: + raise ValueError("max_jobs must be at least 1.") + + if requeue_running: + self.job_store.requeue_running_jobs() + + processed: list[JobRecord] = [] + for _ in range(max_jobs): + job = self.job_store.claim_next_queued_job() + if job is None: + break + processed.append(self.execute_job(job.id)) + return processed + + def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: + if max_jobs is not None and max_jobs < 1: + raise ValueError("max_jobs must be at least 1 when 
provided.") + processed = 0 self.job_store.requeue_running_jobs() while True: + if max_jobs is not None and processed >= max_jobs: + return processed job = self.job_store.claim_next_queued_job() if job is None: if once: diff --git a/tests/integration/test_orchestration.py b/tests/integration/test_orchestration.py index ca1127b..a8eacfd 100644 --- a/tests/integration/test_orchestration.py +++ b/tests/integration/test_orchestration.py @@ -62,3 +62,46 @@ def test_queued_job_survives_restart(tmp_path, repo_root) -> None: assert persisted is not None assert persisted.status == JobStatus.SUCCEEDED assert persisted.result is not None + + +def test_replay_job_creates_a_new_persisted_request(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + source_job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a structured review of the platform design.", + queue_mode=QueueMode.INLINE, + ) + ) + + replayed_job = app.orchestration.replay_job(source_job.id, queue_mode=QueueMode.QUEUED) + + assert replayed_job.id != source_job.id + assert replayed_job.status == JobStatus.QUEUED + assert replayed_job.request.user_prompt == source_job.request.user_prompt + assert replayed_job.request.metadata["replayed_from_job_id"] == source_job.id + + +def test_worker_batch_processes_only_requested_number_of_jobs(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + first_job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="First queued system design pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + second_job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Second queued system design pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + + processed = app.orchestration.run_worker_batch(max_jobs=1) + refreshed_first = app.job_store.get_job(first_job.id) + refreshed_second = app.job_store.get_job(second_job.id) + + assert len(processed) == 1 + assert refreshed_first is not None + assert refreshed_second is not None + assert [refreshed_first.status, refreshed_second.status].count(JobStatus.SUCCEEDED) == 1 + assert [refreshed_first.status, refreshed_second.status].count(JobStatus.QUEUED) == 1 diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index d63f400..cc7c049 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -23,6 +23,17 @@ def _test_settings(repo_root, tmp_path) -> Settings: ) +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + def test_api_exposes_expected_routes() -> None: app = create_api() routes = {route.path for route in app.routes} @@ -34,9 +45,11 @@ def test_api_exposes_expected_routes() -> None: assert "/route/explain" in routes assert "/jobs" in routes assert "/jobs/{job_id}" in routes + assert "/jobs/{job_id}/replay" in routes assert "/jobs/{job_id}/artifacts" in routes assert "/jobs/{job_id}/artifacts/{artifact_id}" in routes assert "/jobs/{job_id}/cancel" in routes + assert "/worker/run" in routes def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: @@ -57,14 +70,7 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> 
None: def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) - snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) - providers = { - "transformers": FakeProvider(settings, snapshot, "transformers"), - "airllm": FakeProvider(settings, snapshot, "airllm"), - "openai": FakeProvider(settings, snapshot, "openai"), - "anthropic": FakeProvider(settings, snapshot, "anthropic"), - "gemini": FakeProvider(settings, snapshot, "gemini"), - } + providers = _test_providers(settings) seeded_app = create_application(settings=settings, providers=providers) job = seeded_app.orchestration.submit_request( ExecutionRequest( @@ -73,7 +79,7 @@ def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> No ) ) - client = TestClient(create_api(settings=settings)) + client = TestClient(create_api(settings=settings, providers=providers)) artifacts_response = client.get(f"/jobs/{job.id}/artifacts") assert artifacts_response.status_code == 200 artifacts = artifacts_response.json()["artifacts"] @@ -88,3 +94,61 @@ def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> No payload = artifact_response.json() assert payload["content_type"] == "text/plain" assert "executed by" in payload["content"] + + +def test_job_replay_route_creates_new_job(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + source_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Draft an internal plan for the queue system.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + replay_response = client.post( + f"/jobs/{source_job.id}/replay", + json={"queue_mode": "queued"}, + ) + + assert replay_response.status_code == 200 + replayed_job = replay_response.json() + assert replayed_job["id"] != source_job.id + assert replayed_job["status"] == "queued" + assert replayed_job["request"]["metadata"]["replayed_from_job_id"] == source_job.id + + +def test_worker_run_route_processes_bounded_batch(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + first_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a long-form architecture review.", + queue_mode=QueueMode.QUEUED, + ) + ) + second_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare another queued execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + run_response = client.post("/worker/run", json={"max_jobs": 1}) + + assert run_response.status_code == 200 + payload = run_response.json() + assert payload["processed"] == 1 + assert len(payload["jobs"]) == 1 + assert payload["jobs"][0]["status"] == "succeeded" + + refreshed_first = seeded_app.job_store.get_job(first_job.id) + refreshed_second = seeded_app.job_store.get_job(second_job.id) + assert refreshed_first is not None + assert refreshed_second is not None + assert [refreshed_first.status, refreshed_second.status].count("succeeded") == 1 + assert [refreshed_first.status, refreshed_second.status].count("queued") == 1 From 75190f8e22a0258c0dcd9e3611adf9a0cf846e27 Mon Sep 17 00:00:00 2001 From: 
N3uralCreativity Date: Tue, 7 Apr 2026 19:58:58 +0200 Subject: [PATCH 05/16] feat: add persisted stage telemetry --- README.md | 7 +- apps/web/README.md | 1 + src/lai/api/app.py | 14 ++- src/lai/api/static/dashboard.css | 35 +++++- src/lai/api/static/dashboard.js | 54 +++++++++ src/lai/cli.py | 19 +++- src/lai/domain.py | 13 +++ src/lai/jobs/service.py | 145 +++++++++++++++++++++++- src/lai/jobs/store.py | 63 +++++++++- tests/integration/test_orchestration.py | 13 +++ tests/unit/test_api.py | 24 ++++ 11 files changed, 377 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index ea648a8..6961c5d 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ Available endpoints: - `POST /jobs` - `GET /jobs` - `GET /jobs/{job_id}` +- `GET /jobs/{job_id}/timeline` - `POST /jobs/{job_id}/replay` - `GET /jobs/{job_id}/artifacts` - `GET /jobs/{job_id}/artifacts/{artifact_id}` @@ -129,8 +130,8 @@ Available endpoints: - `POST /worker/run` The API now also serves a read-mostly dashboard at `/dashboard` with live model health, -route explanation, recent job inspection, artifact/trace browsing, replay actions, -and bounded queue worker controls. +route explanation, recent job inspection, stage telemetry, artifact/trace browsing, +replay actions, and bounded queue worker controls. ## Initial GitHub rules encoded in this repo @@ -145,7 +146,7 @@ and bounded queue worker controls. 1. Add live provider smoke tests behind credentials and optional extras. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Deepen the dashboard with richer stage telemetry and stronger background worker automation. +4. Strengthen background worker automation and live job monitoring for longer-running executions. 
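
A short sketch for reading the new timeline endpoint (assuming the default local address and an existing job), which prints one line per persisted stage event:

```powershell
# Sketch: inspect the persisted stage telemetry for the most recent job.
$base = "http://127.0.0.1:8000"
$job = (Invoke-RestMethod "$base/jobs?limit=1").jobs[0]
$events = (Invoke-RestMethod "$base/jobs/$($job.id)/timeline").stage_events

# Each event carries created_at, stage, event_type, and message.
$events | ForEach-Object {
  "{0} [{1}/{2}] {3}" -f $_.created_at, $_.stage, $_.event_type, $_.message
}
```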
## References diff --git a/apps/web/README.md b/apps/web/README.md index aad1557..dc051ad 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -7,6 +7,7 @@ Current dashboard capabilities: - live health and catalog summary - route explanation form - job submission and recent queue inspection +- persisted stage telemetry timeline for planner, executor, and reviewer flow - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns - bounded live worker controls for processing queued jobs diff --git a/src/lai/api/app.py b/src/lai/api/app.py index fe1e24d..9034f11 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -123,6 +123,14 @@ def get_job(job_id: str) -> dict[str, object]: raise HTTPException(status_code=404, detail="Job not found") return job.model_dump() + @api.get("/jobs/{job_id}/timeline") + def get_job_timeline(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return {"stage_events": [event.model_dump() for event in job.stage_events]} + @api.get("/jobs/{job_id}/artifacts") def list_job_artifacts(job_id: str) -> dict[str, object]: app_state = application() @@ -158,10 +166,10 @@ def get_job_artifact(job_id: str, artifact_id: str) -> dict[str, object]: @api.post("/jobs/{job_id}/cancel") def cancel_job(job_id: str) -> dict[str, object]: app_state = application() - if not app_state.job_store.cancel_job(job_id): + job = app_state.orchestration.cancel_job(job_id) + if job is None: raise HTTPException(status_code=404, detail="Job not found or not cancelable") - job = app_state.job_store.get_job(job_id) - return {"job": job.model_dump() if job else None} + return {"job": job.model_dump()} @api.post("/worker/run") def run_worker(payload: WorkerRunPayload) -> dict[str, object]: diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index a76f502..b4adaac 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -330,7 +330,8 @@ textarea { .reason-item, .job-card, -.model-card { +.model-card, +.timeline-event { border: 1px solid rgba(33, 63, 100, 0.11); background: rgba(255, 255, 255, 0.64); border-radius: 1rem; @@ -358,6 +359,11 @@ textarea { gap: 0.75rem; } +.timeline-list { + display: grid; + gap: 0.75rem; +} + .job-card { padding: 0.95rem 1rem; cursor: pointer; @@ -449,6 +455,29 @@ textarea { margin: 0; } +.timeline-event { + padding: 0.95rem 1rem; +} + +.timeline-event-head { + display: flex; + justify-content: space-between; + gap: 0.9rem; + align-items: start; +} + +.timeline-time { + color: var(--muted); + font-family: "IBM Plex Mono", monospace; + font-size: 0.76rem; + white-space: nowrap; +} + +.timeline-message { + margin-top: 0.55rem; + color: var(--ink); +} + .models-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); @@ -500,4 +529,8 @@ textarea { .split-head { flex-direction: column; } + + .timeline-event-head { + flex-direction: column; + } } diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index 7da5d3b..f74a307 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -237,6 +237,10 @@ async function selectJob(jobId) {

      <h3>Route trace</h3>
      <div>
        ${reasons || '<p class="empty">No route reasons recorded.</p>'}
      </div>
+
+     <h3>Stage timeline</h3>
+     <div id="timeline-list" class="timeline-list empty">
+       Loading stage telemetry...
+     </div>
      <h3>Artifacts</h3>
      <div>
        Loading artifacts...
      </div>
@@ -254,9 +258,59 @@ async function selectJob(jobId) { void cancelJob(job.id).catch(handleQueueError); }); + await loadTimeline(job.id, job.stage_events || []); await loadArtifacts(job.id, job.artifacts || []); } +async function loadTimeline(jobId, stageEvents) { + const response = + stageEvents.length > 0 + ? { stage_events: stageEvents } + : await fetchJson(`/jobs/${jobId}/timeline`); + const events = response.stage_events || []; + const timeline = document.getElementById("timeline-list"); + if (!timeline) { + return; + } + + if (events.length === 0) { + timeline.className = "timeline-list empty"; + timeline.textContent = "No stage telemetry recorded for this job yet."; + return; + } + + timeline.className = "timeline-list fade-in"; + timeline.innerHTML = events + .map((event) => { + const tone = + event.event_type === "failed" || event.event_type === "blocked" + ? "danger" + : event.event_type === "running" || event.event_type === "started" + ? "warn" + : "ready"; + const detailChips = [ + chip("stage", event.stage), + chip("event", event.event_type, tone), + ]; + if (event.model_id) { + detailChips.push(chip("model", event.model_id)); + } + if (event.provider_id) { + detailChips.push(chip("provider", event.provider_id)); + } + return ` +
+        <div class="timeline-event">
+          <div class="timeline-event-head">
+            <div>${detailChips.join("")}</div>
+            <div class="timeline-time">${escapeHtml(new Date(event.created_at).toLocaleString())}</div>
+          </div>
+          <div class="timeline-message">${escapeHtml(event.message)}</div>
+        </div>
+ `; + }) + .join(""); +} + async function loadArtifacts(jobId, jobArtifacts) { const response = jobArtifacts.length > 0 ? { artifacts: jobArtifacts } : await fetchJson(`/jobs/${jobId}/artifacts`); diff --git a/src/lai/cli.py b/src/lai/cli.py index 9458a28..5f119b1 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -231,19 +231,36 @@ def show_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: "Executor", job.route_decision.executor_model_id if job.route_decision else "n/a" ) summary.add_row("Artifacts", str(len(job.artifacts))) + summary.add_row("Stage events", str(len(job.stage_events))) console.print(summary) if job.result: console.print(Panel.fit(job.result.text, title="Latest Output")) if job.error: console.print(Panel.fit(job.error.message, title="Error")) + if job.stage_events: + timeline = Table(title="Stage Timeline") + timeline.add_column("When") + timeline.add_column("Stage") + timeline.add_column("Event") + timeline.add_column("Model") + timeline.add_column("Message") + for event in job.stage_events: + timeline.add_row( + event.created_at.isoformat(timespec="seconds"), + event.stage, + event.event_type, + event.model_id or "n/a", + event.message, + ) + console.print(timeline) @jobs_app.command("cancel") def cancel_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: """Cancel a queued or running job.""" application = create_application() - if application.job_store.cancel_job(job_id): + if application.orchestration.cancel_job(job_id): console.print(f"Canceled job {job_id}") return raise typer.Exit(code=1) diff --git a/src/lai/domain.py b/src/lai/domain.py index 2f9ad7a..ffb986c 100644 --- a/src/lai/domain.py +++ b/src/lai/domain.py @@ -300,6 +300,18 @@ class ArtifactRecord(LAIModel): metadata: dict[str, Any] = Field(default_factory=dict) +class StageEventRecord(LAIModel): + id: str + job_id: str + stage: str + event_type: str + message: str + created_at: datetime = Field(default_factory=utcnow) + model_id: str | None = None + provider_id: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + class JobRecord(LAIModel): id: str status: JobStatus @@ -315,3 +327,4 @@ class JobRecord(LAIModel): error: JobError | None = None attempts: int = 0 artifacts: list[ArtifactRecord] = Field(default_factory=list) + stage_events: list[StageEventRecord] = Field(default_factory=list) diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 11448e9..26e7375 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -14,6 +14,7 @@ ProviderRequest, QueueMode, RoutingDecision, + StageEventRecord, utcnow, ) from ..errors import RetryableProviderError @@ -61,6 +62,32 @@ def submit_request(self, request: ExecutionRequest) -> JobRecord: self.artifacts.write_json( job.id, "route-decision", "route.json", route_decision.model_dump() ) + self._record_event( + job_id=job.id, + stage="routing", + event_type="decision", + message=( + f"Matched tier {route_decision.matched_tier_id} with executor " + f"{route_decision.executor_model_id}." + ), + metadata={ + "planner_model_id": route_decision.planner_model_id, + "reviewer_model_id": route_decision.reviewer_model_id, + "queue_recommended": route_decision.queue_recommended, + "fallback_model_ids": route_decision.fallback_model_ids, + }, + ) + self._record_event( + job_id=job.id, + stage="job", + event_type="queued" if queue_mode == QueueMode.QUEUED else "inline-dispatch", + message=( + "Queued for worker execution." 
+ if queue_mode == QueueMode.QUEUED + else "Scheduled for immediate inline execution." + ), + metadata={"queue_mode": queue_mode}, + ) if queue_mode == QueueMode.INLINE: return self.execute_job(job.id) @@ -80,7 +107,26 @@ def replay_job(self, job_id: str, *, queue_mode: QueueMode | None = None) -> Job "source": "replay", } ) - return self.submit_request(replay_request) + replayed_job = self.submit_request(replay_request) + self._record_event( + job_id=replayed_job.id, + stage="job", + event_type="replayed", + message=f"Replayed from job {job.id}.", + metadata={"source_job_id": job.id}, + ) + return self.job_store.get_job(replayed_job.id) or replayed_job + + def cancel_job(self, job_id: str) -> JobRecord | None: + if not self.job_store.cancel_job(job_id): + return None + self._record_event( + job_id=job_id, + stage="job", + event_type="canceled", + message="Job canceled by user action.", + ) + return self.job_store.get_job(job_id) def execute_job(self, job_id: str) -> JobRecord: job = self.job_store.get_job(job_id) @@ -95,6 +141,13 @@ def execute_job(self, job_id: str) -> JobRecord: job.started_at = job.started_at or utcnow() job.updated_at = utcnow() self.job_store.save_job(job) + self._record_event( + job_id=job.id, + stage="job", + event_type="running", + message="Job execution started.", + metadata={"attempt": job.attempts}, + ) try: final_text = self._run_pipeline(job) @@ -105,6 +158,13 @@ def execute_job(self, job_id: str) -> JobRecord: job.updated_at = utcnow() self.job_store.save_job(job) self.artifacts.write_text(job.id, "final-output", "final_output.txt", final_text) + self._record_event( + job_id=job.id, + stage="job", + event_type="succeeded", + message="Job completed successfully.", + metadata={"artifacts": len(job.artifacts) + 1}, + ) return self.job_store.get_job(job.id) or job except Exception as exc: retryable = _is_retryable_exception(exc) @@ -115,10 +175,24 @@ def execute_job(self, job_id: str) -> JobRecord: if retryable and job.attempts <= self.settings.max_retry_attempts: job.status = JobStatus.QUEUED self.job_store.save_job(job) + self._record_event( + job_id=job.id, + stage="job", + event_type="retry-queued", + message="Retryable failure encountered. 
Job returned to queue.", + metadata={"error_type": type(exc).__name__}, + ) else: job.status = JobStatus.FAILED job.finished_at = utcnow() self.job_store.save_job(job) + self._record_event( + job_id=job.id, + stage="job", + event_type="failed", + message=f"Job failed with {type(exc).__name__}.", + metadata={"error_type": type(exc).__name__}, + ) self.artifacts.write_text(job.id, "error", "error.txt", str(exc)) return self.job_store.get_job(job.id) or job @@ -229,9 +303,27 @@ def _run_stage( ) -> str: model = self.config.model_catalog.get_model(model_id) provider = self.provider_registry.provider_for_model(model) + provider_id = getattr(provider, "provider_id", model.provider_id) + self._record_event( + job_id=job.id, + stage=stage, + event_type="started", + message=f"{stage.title()} stage started with model {model.id}.", + model_id=model.id, + provider_id=provider_id, + ) health = self.provider_registry.healthcheck(model) if not health.available: reason = "; ".join(health.reasons) or "unavailable" + self._record_event( + job_id=job.id, + stage=stage, + event_type="blocked", + message=f"{stage.title()} stage blocked: {reason}", + model_id=model.id, + provider_id=provider_id, + metadata={"health_reasons": health.reasons}, + ) raise RuntimeError( f"Model {model.id!r} is not executable: {reason}" ) @@ -246,17 +338,66 @@ def _run_stage( self.artifacts.write_json( job.id, f"{stage}-request", f"{stage}_request.json", request.model_dump() ) - result = provider.generate(model, request) + try: + result = provider.generate(model, request) + except Exception as exc: + self._record_event( + job_id=job.id, + stage=stage, + event_type="failed", + message=f"{stage.title()} stage failed with {type(exc).__name__}.", + model_id=model.id, + provider_id=provider_id, + metadata={"error_type": type(exc).__name__}, + ) + raise result.stage = stage self.artifacts.write_json( job.id, f"{stage}-response", f"{stage}_response.json", result.model_dump() ) + self._record_event( + job_id=job.id, + stage=stage, + event_type="completed", + message=f"{stage.title()} stage completed.", + model_id=model.id, + provider_id=provider_id, + metadata={ + "duration_seconds": result.duration_seconds, + "finish_reason": result.finish_reason, + "usage": result.usage.model_dump() if result.usage else None, + }, + ) if stage == "executor": job.result = result job.updated_at = utcnow() self.job_store.save_job(job) return result.text + def _record_event( + self, + *, + job_id: str, + stage: str, + event_type: str, + message: str, + model_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, object] | None = None, + ) -> StageEventRecord: + event = StageEventRecord( + id=str(uuid4()), + job_id=job_id, + stage=stage, + event_type=event_type, + message=message, + model_id=model_id, + provider_id=provider_id, + metadata=metadata or {}, + ) + self.job_store.add_stage_event(event) + return event + @staticmethod def _resolved_queue_mode(queue_mode: QueueMode, route_decision: RoutingDecision) -> QueueMode: queue_mode_value = _enum_value(queue_mode) diff --git a/src/lai/jobs/store.py b/src/lai/jobs/store.py index 67feade..6c0faba 100644 --- a/src/lai/jobs/store.py +++ b/src/lai/jobs/store.py @@ -4,7 +4,7 @@ import sqlite3 from pathlib import Path -from ..domain import ArtifactRecord, JobRecord, JobStatus, utcnow +from ..domain import ArtifactRecord, JobRecord, JobStatus, StageEventRecord, utcnow class JobStore: @@ -41,6 +41,19 @@ def initialize(self) -> None: metadata_json TEXT NOT NULL, FOREIGN KEY(job_id) REFERENCES 
jobs(id) ); + + CREATE TABLE IF NOT EXISTS stage_events ( + id TEXT PRIMARY KEY, + job_id TEXT NOT NULL, + stage TEXT NOT NULL, + event_type TEXT NOT NULL, + message TEXT NOT NULL, + created_at TEXT NOT NULL, + model_id TEXT, + provider_id TEXT, + metadata_json TEXT NOT NULL, + FOREIGN KEY(job_id) REFERENCES jobs(id) + ); """ ) @@ -100,6 +113,28 @@ def add_artifact(self, artifact: ArtifactRecord) -> None: ), ) + def add_stage_event(self, event: StageEventRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT OR REPLACE INTO stage_events ( + id, job_id, stage, event_type, message, created_at, + model_id, provider_id, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + event.id, + event.job_id, + event.stage, + event.event_type, + event.message, + event.created_at.isoformat(), + event.model_id, + event.provider_id, + json.dumps(event.metadata), + ), + ) + def get_job(self, job_id: str) -> JobRecord | None: with self._connect() as connection: row = connection.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() @@ -170,6 +205,7 @@ def requeue_running_jobs(self) -> int: def _deserialize_job(self, connection: sqlite3.Connection, row: sqlite3.Row) -> JobRecord: artifacts = self._artifacts_for_job(connection, row["id"]) + stage_events = self._stage_events_for_job(connection, row["id"]) return JobRecord.model_validate( { "id": row["id"], @@ -186,6 +222,7 @@ def _deserialize_job(self, connection: sqlite3.Connection, row: sqlite3.Row) -> "error": json.loads(row["error_json"]) if row["error_json"] else None, "attempts": row["attempts"], "artifacts": artifacts, + "stage_events": stage_events, } ) @@ -210,6 +247,30 @@ def _artifacts_for_job( for row in rows ] + def _stage_events_for_job( + self, connection: sqlite3.Connection, job_id: str + ) -> list[StageEventRecord]: + rows = connection.execute( + "SELECT * FROM stage_events WHERE job_id = ? 
ORDER BY created_at ASC", + (job_id,), + ).fetchall() + return [ + StageEventRecord.model_validate( + { + "id": row["id"], + "job_id": row["job_id"], + "stage": row["stage"], + "event_type": row["event_type"], + "message": row["message"], + "created_at": row["created_at"], + "model_id": row["model_id"], + "provider_id": row["provider_id"], + "metadata": json.loads(row["metadata_json"]), + } + ) + for row in rows + ] + def _connect(self) -> sqlite3.Connection: connection = sqlite3.connect(self.database_path) connection.row_factory = sqlite3.Row diff --git a/tests/integration/test_orchestration.py b/tests/integration/test_orchestration.py index a8eacfd..8e76901 100644 --- a/tests/integration/test_orchestration.py +++ b/tests/integration/test_orchestration.py @@ -43,6 +43,15 @@ def test_inline_execution_persists_artifacts(tmp_path, repo_root) -> None: final_output = Path(app.settings.resolved_artifacts_dir) / job.id / "final_output.txt" assert final_output.exists() assert "executed by" in final_output.read_text(encoding="utf-8") + assert job.stage_events + assert any( + event.stage == "executor" and event.event_type == "completed" + for event in job.stage_events + ) + assert any( + event.stage == "job" and event.event_type == "succeeded" + for event in job.stage_events + ) def test_queued_job_survives_restart(tmp_path, repo_root) -> None: @@ -79,6 +88,10 @@ def test_replay_job_creates_a_new_persisted_request(tmp_path, repo_root) -> None assert replayed_job.status == JobStatus.QUEUED assert replayed_job.request.user_prompt == source_job.request.user_prompt assert replayed_job.request.metadata["replayed_from_job_id"] == source_job.id + assert any( + event.stage == "job" and event.event_type == "replayed" + for event in replayed_job.stage_events + ) def test_worker_batch_processes_only_requested_number_of_jobs(tmp_path, repo_root) -> None: diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index cc7c049..9609fd8 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -45,6 +45,7 @@ def test_api_exposes_expected_routes() -> None: assert "/route/explain" in routes assert "/jobs" in routes assert "/jobs/{job_id}" in routes + assert "/jobs/{job_id}/timeline" in routes assert "/jobs/{job_id}/replay" in routes assert "/jobs/{job_id}/artifacts" in routes assert "/jobs/{job_id}/artifacts/{artifact_id}" in routes @@ -96,6 +97,29 @@ def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> No assert "executed by" in payload["content"] +def test_job_timeline_route_returns_stage_events(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + timeline_response = client.get(f"/jobs/{job.id}/timeline") + + assert timeline_response.status_code == 200 + payload = timeline_response.json() + assert payload["stage_events"] + assert any( + event["stage"] == "executor" and event["event_type"] == "completed" + for event in payload["stage_events"] + ) + + def test_job_replay_route_creates_new_job(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) providers = _test_providers(settings) From 499661225215d088e83a8945b08d03cedb611704 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 
2026 20:06:08 +0200 Subject: [PATCH 06/16] feat: add worker monitoring and queue drain --- README.md | 8 +- apps/web/README.md | 3 +- src/lai/api/app.py | 27 ++- src/lai/api/static/dashboard.css | 10 + src/lai/api/static/dashboard.js | 82 +++++++- src/lai/api/static/index.html | 6 + src/lai/cli.py | 43 ++++ src/lai/domain.py | 22 ++ src/lai/jobs/service.py | 261 +++++++++++++++++++++++- src/lai/jobs/store.py | 94 ++++++++- tests/integration/test_orchestration.py | 25 +++ tests/unit/test_api.py | 38 ++++ 12 files changed, 600 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 6961c5d..fd5f9be 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,8 @@ python -m lai.cli jobs list python -m lai.cli jobs replay --queue-mode queued python -m lai.cli worker run --once python -m lai.cli worker run --max-jobs 3 +python -m lai.cli worker run --until-idle --max-idle-cycles 1 +python -m lai.cli worker status python -m lai.cli eval route --no-save ``` @@ -118,6 +120,7 @@ Available endpoints: - `GET /dashboard` - `GET /health` - `GET /models` +- `GET /worker/status` - `POST /route/explain` - `POST /jobs` - `GET /jobs` @@ -131,7 +134,8 @@ Available endpoints: The API now also serves a read-mostly dashboard at `/dashboard` with live model health, route explanation, recent job inspection, stage telemetry, artifact/trace browsing, -replay actions, and bounded queue worker controls. +replay actions, persisted worker monitoring, and bounded queue worker controls including +an until-idle drain path. ## Initial GitHub rules encoded in this repo @@ -146,7 +150,7 @@ replay actions, and bounded queue worker controls. 1. Add live provider smoke tests behind credentials and optional extras. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Strengthen background worker automation and live job monitoring for longer-running executions. +4. Add a dedicated long-running local worker service or daemon entrypoint for always-on queue processing. 
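As an illustration of the until-idle drain path (a minimal sketch, not part of this patch; it assumes a local API at `http://127.0.0.1:8000` and the `requests` package):

```python
# Drain queued jobs until the worker reports idle via POST /worker/run.
import requests

payload = {"until_idle": True, "max_idle_cycles": 1, "max_jobs": None}
response = requests.post("http://127.0.0.1:8000/worker/run", json=payload, timeout=600)
response.raise_for_status()

body = response.json()
# The response reports how many jobs were processed plus the persisted worker state.
print(body["processed"], body["worker"]["status"], body["worker"]["queued_jobs"])
```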
## References diff --git a/apps/web/README.md b/apps/web/README.md index dc051ad..f55c577 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -10,7 +10,8 @@ Current dashboard capabilities: - persisted stage telemetry timeline for planner, executor, and reviewer flow - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns -- bounded live worker controls for processing queued jobs +- persisted worker monitoring with heartbeat, current job, and queue depth +- bounded live worker controls for processing queued jobs, including queue drain until idle - model health cards - job output inspector diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 9034f11..269ae84 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -31,8 +31,10 @@ class JobReplayPayload(BaseModel): class WorkerRunPayload(BaseModel): - max_jobs: int = Field(default=1, ge=1, le=25) + max_jobs: int | None = Field(default=None, ge=1, le=25) resume_running: bool = False + until_idle: bool = False + max_idle_cycles: int = Field(default=2, ge=1, le=20) def create_api( @@ -87,6 +89,11 @@ def list_models() -> dict[str, object]: ] } + @api.get("/worker/status") + def worker_status() -> dict[str, object]: + app_state = application() + return app_state.orchestration.get_worker_state().model_dump() + @api.post("/route/explain") def explain_route(payload: JobCreatePayload) -> dict[str, object]: app_state = application() @@ -174,11 +181,25 @@ def cancel_job(job_id: str) -> dict[str, object]: @api.post("/worker/run") def run_worker(payload: WorkerRunPayload) -> dict[str, object]: app_state = application() + if payload.until_idle: + processed = app_state.orchestration.run_worker_until_idle( + max_idle_cycles=payload.max_idle_cycles, + max_jobs=payload.max_jobs, + requeue_running=payload.resume_running, + ) + worker_state = app_state.orchestration.get_worker_state() + return {"processed": processed, "jobs": [], "worker": worker_state.model_dump()} + jobs = app_state.orchestration.run_worker_batch( - max_jobs=payload.max_jobs, + max_jobs=payload.max_jobs or 1, requeue_running=payload.resume_running, ) - return {"processed": len(jobs), "jobs": [job.model_dump() for job in jobs]} + worker_state = app_state.orchestration.get_worker_state() + return { + "processed": len(jobs), + "jobs": [job.model_dump() for job in jobs], + "worker": worker_state.model_dump(), + } return api diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index b4adaac..d80c4f4 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -301,6 +301,16 @@ textarea { background: rgba(255, 255, 255, 0.5); } +.worker-summary { + display: grid; + gap: 0.7rem; + margin-bottom: 0.95rem; + padding: 0.9rem 1rem; + border-radius: 1rem; + border: 1px solid rgba(33, 63, 100, 0.12); + background: rgba(255, 255, 255, 0.58); +} + .chip { display: inline-flex; align-items: center; diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index f74a307..8c1880d 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -1,6 +1,7 @@ const state = { jobs: [], models: [], + worker: null, selectedJobId: null, selectedArtifactId: null, }; @@ -10,6 +11,7 @@ const elements = { modelCount: document.getElementById("metric-model-count"), jobCount: document.getElementById("metric-job-count"), queueState: document.getElementById("metric-queue-state"), + workerState: document.getElementById("metric-worker-state"), jobsList: 
document.getElementById("jobs-list"), modelsGrid: document.getElementById("models-grid"), routeSummary: document.getElementById("route-summary"), @@ -20,7 +22,9 @@ const elements = { submitJob: document.getElementById("submit-job"), runNextJob: document.getElementById("run-next-job"), runBatchJobs: document.getElementById("run-batch-jobs"), + drainQueue: document.getElementById("drain-queue"), queueActionStatus: document.getElementById("queue-action-status"), + workerSummary: document.getElementById("worker-summary"), }; async function fetchJson(url, options = undefined) { @@ -71,19 +75,22 @@ function escapeHtml(value) { } async function loadOverview() { - const [health, modelsResponse, jobsResponse] = await Promise.all([ + const [health, modelsResponse, jobsResponse, workerResponse] = await Promise.all([ fetchJson("/health"), fetchJson("/models"), fetchJson("/jobs?limit=12"), + fetchJson("/worker/status"), ]); state.models = modelsResponse.models; state.jobs = jobsResponse.jobs; + state.worker = workerResponse; elements.environment.textContent = health.environment; elements.modelCount.textContent = String(health.model_count); elements.jobCount.textContent = String(state.jobs.length); elements.queueState.textContent = summarizeQueue(state.jobs); + renderWorkerMonitor(); renderModels(); renderJobs(); @@ -95,6 +102,59 @@ async function loadOverview() { } } +function renderWorkerMonitor() { + if (!state.worker) { + elements.workerState.textContent = "unknown"; + if (elements.workerSummary) { + elements.workerSummary.className = "worker-summary empty"; + elements.workerSummary.textContent = "Worker state unavailable."; + } + return; + } + + const worker = state.worker; + const tone = + worker.status === "error" + ? "danger" + : worker.status === "running" + ? "warn" + : "ready"; + + elements.workerState.textContent = `${worker.status} / ${worker.mode}`; + if (!elements.workerSummary) { + return; + } + + const chips = [ + chip("status", worker.status, tone), + chip("mode", worker.mode), + chip("queued", worker.queued_jobs), + chip("processed", worker.processed_jobs), + ]; + if (worker.current_job_id) { + chips.push(chip("current", worker.current_job_id.slice(0, 8), tone)); + } + if (worker.last_job_id) { + chips.push(chip("last", worker.last_job_id.slice(0, 8))); + } + + const heartbeat = worker.heartbeat_at + ? new Date(worker.heartbeat_at).toLocaleString() + : "n/a"; + const startedAt = worker.started_at ? new Date(worker.started_at).toLocaleString() : "n/a"; + const phase = worker.metadata?.phase ? `phase: ${worker.metadata.phase}` : "phase: n/a"; + const error = worker.last_error + ? `
+        error: ${escapeHtml(worker.last_error)}`
+      : "";
+
+  elements.workerSummary.className = "worker-summary fade-in";
+  elements.workerSummary.innerHTML = `
+    ${chips.join("")}
+    heartbeat: ${escapeHtml(heartbeat)} | started: ${escapeHtml(startedAt)} | ${escapeHtml(phase)}
+ ${error} + `; +} + function summarizeQueue(jobs) { if (jobs.length === 0) { return "idle"; @@ -453,20 +513,29 @@ async function cancelJob(jobId) { await selectJob(jobId); } -async function runWorker(maxJobs) { +async function runWorker(maxJobs, untilIdle = false) { const selectedJobId = state.selectedJobId; + const body = untilIdle + ? { until_idle: true, max_idle_cycles: 2, max_jobs: maxJobs ?? null } + : { max_jobs: maxJobs }; const response = await fetchJson("/worker/run", { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ max_jobs: maxJobs }), + body: JSON.stringify(body), }); const processed = response.processed || 0; const latestProcessed = response.jobs && response.jobs.length > 0 ? response.jobs[response.jobs.length - 1] : null; + state.worker = response.worker || state.worker; + renderWorkerMonitor(); setQueueActionStatus( processed > 0 - ? `Processed ${processed} queued job(s).` - : "No queued jobs were ready to run.", + ? untilIdle + ? `Drained ${processed} queued job(s) until the worker went idle.` + : `Processed ${processed} queued job(s).` + : untilIdle + ? "Queue drain completed with no queued jobs left to process." + : "No queued jobs were ready to run.", processed > 0 ? "ready" : "warn", ); await loadOverview(); @@ -495,6 +564,9 @@ function bindEvents() { elements.runBatchJobs.addEventListener("click", () => { void runWorker(3).catch(handleQueueError); }); + elements.drainQueue.addEventListener("click", () => { + void runWorker(null, true).catch(handleQueueError); + }); } function handleError(error) { diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html index 672d080..0199d98 100644 --- a/src/lai/api/static/index.html +++ b/src/lai/api/static/index.html @@ -49,6 +49,10 @@

      <p>Route fast. Think deep. Watch the queue move.</p>
      <div class="metric">
        <span>Queue State</span>
        <strong id="metric-queue-state">...</strong>
      </div>
+     <div class="metric">
+       <span>Worker</span>
+       <strong id="metric-worker-state">...</strong>
+     </div>
@@ -118,11 +122,13 @@
      <h2>Recent jobs</h2>
+     <button id="drain-queue" type="button" title="Worker idle. Use the queue controls to process queued jobs.">Drain queue until idle</button>
+     <div id="worker-summary" class="worker-summary empty">Loading worker monitor...</div>
      <div id="jobs-list" class="empty">No jobs yet.</div>
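To make the worker monitor payload concrete, a short polling sketch (illustrative only; it assumes a local API at `http://127.0.0.1:8000` and the `requests` package):

```python
# Poll GET /worker/status, the same payload the dashboard monitor renders.
import time

import requests

for _ in range(3):
    worker = requests.get("http://127.0.0.1:8000/worker/status", timeout=10).json()
    print(worker["status"], worker["mode"], "queued:", worker["queued_jobs"])
    time.sleep(5)  # brief pause between polls
```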
diff --git a/src/lai/cli.py b/src/lai/cli.py index 5f119b1..3174f8c 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -297,6 +297,16 @@ def run_worker( min=1, help="Process at most this many queued jobs and then exit.", ), + until_idle: bool = typer.Option( + False, + "--until-idle", + help="Keep polling until the queue stays idle for the configured idle cycles.", + ), + max_idle_cycles: int = typer.Option( + 2, + min=1, + help="Number of empty polls before an until-idle run stops.", + ), ) -> None: """Run the local worker for queued jobs.""" application = create_application() @@ -307,6 +317,12 @@ def run_worker( requeue_running=True, ) ) + elif until_idle: + processed = application.orchestration.run_worker_until_idle( + max_idle_cycles=max_idle_cycles, + max_jobs=max_jobs, + requeue_running=True, + ) elif max_jobs is not None: processed = len( application.orchestration.run_worker_batch( @@ -319,6 +335,33 @@ def run_worker( console.print(f"Processed {processed} queued job(s).") +@worker_app.command("status") +def worker_status() -> None: + """Show the persisted local worker state.""" + application = create_application() + worker = application.orchestration.get_worker_state() + + table = Table(title=f"Worker {worker.worker_id}") + table.add_column("Field") + table.add_column("Value") + table.add_row("Status", worker.status) + table.add_row("Mode", worker.mode) + table.add_row("Current job", worker.current_job_id or "n/a") + table.add_row("Last job", worker.last_job_id or "n/a") + table.add_row("Processed jobs", str(worker.processed_jobs)) + table.add_row("Queued jobs", str(worker.queued_jobs)) + table.add_row( + "Heartbeat", + worker.heartbeat_at.isoformat(timespec="seconds"), + ) + table.add_row( + "Started at", + worker.started_at.isoformat(timespec="seconds") if worker.started_at else "n/a", + ) + table.add_row("Last error", worker.last_error or "none") + console.print(table) + + @eval_app.command("route") def eval_route( scenario_file: Path = typer.Option( diff --git a/src/lai/domain.py b/src/lai/domain.py index ffb986c..83a20e8 100644 --- a/src/lai/domain.py +++ b/src/lai/domain.py @@ -48,6 +48,13 @@ class JobStatus(str, Enum): CANCELED = "canceled" +class WorkerStatus(str, Enum): + IDLE = "idle" + RUNNING = "running" + STOPPED = "stopped" + ERROR = "error" + + class FinishReason(str, Enum): STOP = "stop" LENGTH = "length" @@ -312,6 +319,21 @@ class StageEventRecord(LAIModel): metadata: dict[str, Any] = Field(default_factory=dict) +class WorkerStateRecord(LAIModel): + worker_id: str + status: WorkerStatus = WorkerStatus.IDLE + mode: str = "batch" + current_job_id: str | None = None + last_job_id: str | None = None + processed_jobs: int = 0 + queued_jobs: int = 0 + started_at: datetime | None = None + heartbeat_at: datetime = Field(default_factory=utcnow) + updated_at: datetime = Field(default_factory=utcnow) + last_error: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + class JobRecord(LAIModel): id: str status: JobStatus diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 26e7375..ad4e140 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -15,6 +15,8 @@ QueueMode, RoutingDecision, StageEventRecord, + WorkerStateRecord, + WorkerStatus, utcnow, ) from ..errors import RetryableProviderError @@ -25,6 +27,8 @@ class OrchestrationService: + default_worker_id = "local-worker" + def __init__( self, settings: Settings, @@ -117,6 +121,19 @@ def replay_job(self, job_id: str, *, queue_mode: QueueMode | None = None) -> Job ) 
return self.job_store.get_job(replayed_job.id) or replayed_job + def get_worker_state(self, worker_id: str | None = None) -> WorkerStateRecord: + resolved_worker_id = worker_id or self.default_worker_id + worker_state = self.job_store.get_worker_state(resolved_worker_id) + if worker_state is not None: + return worker_state + worker_state = WorkerStateRecord( + worker_id=resolved_worker_id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode="batch", + ) + self.job_store.save_worker_state(worker_state) + return worker_state + def cancel_job(self, job_id: str) -> JobRecord | None: if not self.job_store.cancel_job(job_id): return None @@ -201,19 +218,172 @@ def run_worker_batch( *, max_jobs: int = 1, requeue_running: bool = False, + worker_id: str | None = None, ) -> list[JobRecord]: if max_jobs < 1: raise ValueError("max_jobs must be at least 1.") + resolved_worker_id = worker_id or self.default_worker_id if requeue_running: - self.job_store.requeue_running_jobs() + requeued = self.job_store.requeue_running_jobs() + else: + requeued = 0 + + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + mode="batch", + current_job_id=None, + processed_jobs=0, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=utcnow(), + last_error=None, + metadata={"max_jobs": max_jobs, "requeued_running_jobs": requeued}, + ) processed: list[JobRecord] = [] - for _ in range(max_jobs): - job = self.job_store.claim_next_queued_job() - if job is None: - break - processed.append(self.execute_job(job.id)) + try: + for _ in range(max_jobs): + job = self.job_store.claim_next_queued_job() + if job is None: + break + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + ) + processed_job = self.execute_job(job.id) + processed.append(processed_job) + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + last_job_id=processed_job.id, + processed_jobs=len(processed), + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + ) + except Exception as exc: + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.ERROR, + current_job_id=None, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + last_error=f"{type(exc).__name__}: {exc}", + ) + raise + + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=len(processed), + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"max_jobs": max_jobs, "phase": "complete"}, + ) + return processed + + def run_worker_until_idle( + self, + *, + max_idle_cycles: int = 2, + max_jobs: int | None = None, + requeue_running: bool = True, + worker_id: str | None = None, + ) -> int: + if max_idle_cycles < 1: + raise ValueError("max_idle_cycles must be at least 1.") + if max_jobs is not None and max_jobs < 1: + raise ValueError("max_jobs must be at least 1 when provided.") + + resolved_worker_id = worker_id or self.default_worker_id + if requeue_running: + requeued = self.job_store.requeue_running_jobs() + else: + requeued = 0 + + processed = 0 + idle_cycles = 0 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + mode="until-idle", + current_job_id=None, + processed_jobs=0, + 
queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=utcnow(), + last_error=None, + metadata={ + "max_idle_cycles": max_idle_cycles, + "max_jobs": max_jobs, + "requeued_running_jobs": requeued, + "phase": "draining", + }, + ) + + try: + while True: + if max_jobs is not None and processed >= max_jobs: + break + + job = self.job_store.claim_next_queued_job() + if job is None: + idle_cycles += 1 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + queued_jobs=0, + processed_jobs=processed, + metadata={ + "max_idle_cycles": max_idle_cycles, + "idle_cycles": idle_cycles, + "phase": "waiting", + }, + ) + if idle_cycles >= max_idle_cycles: + break + time.sleep(self.settings.worker_idle_sleep_seconds) + continue + + idle_cycles = 0 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "processing"}, + ) + processed_job = self.execute_job(job.id) + processed += 1 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + last_job_id=processed_job.id, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "draining"}, + ) + except Exception as exc: + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.ERROR, + current_job_id=None, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + processed_jobs=processed, + last_error=f"{type(exc).__name__}: {exc}", + ) + raise + + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "idle", "idle_cycles": idle_cycles}, + ) return processed def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: @@ -221,20 +391,81 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: raise ValueError("max_jobs must be at least 1 when provided.") processed = 0 - self.job_store.requeue_running_jobs() + requeued = self.job_store.requeue_running_jobs() + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + mode="continuous", + current_job_id=None, + processed_jobs=0, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=utcnow(), + last_error=None, + metadata={"requeued_running_jobs": requeued, "phase": "monitoring"}, + ) while True: if max_jobs is not None and processed >= max_jobs: + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "idle"}, + ) return processed job = self.job_store.claim_next_queued_job() if job is None: if once: + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=0, + metadata={"phase": "idle"}, + ) return processed + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + processed_jobs=processed, + queued_jobs=0, + metadata={"phase": "waiting"}, + ) time.sleep(self.settings.worker_idle_sleep_seconds) continue + 
self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job.id, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "processing"}, + ) self.execute_job(job.id) processed += 1 if once: + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + last_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "idle"}, + ) return processed + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + processed_jobs=processed, + last_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + metadata={"phase": "monitoring"}, + ) def _run_pipeline(self, job: JobRecord) -> str: assert job.route_decision is not None @@ -398,6 +629,22 @@ def _record_event( self.job_store.add_stage_event(event) return event + def _update_worker_state( + self, + *, + worker_id: str, + **updates: object, + ) -> WorkerStateRecord: + current = self.get_worker_state(worker_id) + payload = current.model_dump() + payload.update(updates) + now = utcnow() + payload["heartbeat_at"] = now + payload["updated_at"] = now + worker_state = WorkerStateRecord.model_validate(payload) + self.job_store.save_worker_state(worker_state) + return worker_state + @staticmethod def _resolved_queue_mode(queue_mode: QueueMode, route_decision: RoutingDecision) -> QueueMode: queue_mode_value = _enum_value(queue_mode) diff --git a/src/lai/jobs/store.py b/src/lai/jobs/store.py index 6c0faba..0b330d5 100644 --- a/src/lai/jobs/store.py +++ b/src/lai/jobs/store.py @@ -4,7 +4,14 @@ import sqlite3 from pathlib import Path -from ..domain import ArtifactRecord, JobRecord, JobStatus, StageEventRecord, utcnow +from ..domain import ( + ArtifactRecord, + JobRecord, + JobStatus, + StageEventRecord, + WorkerStateRecord, + utcnow, +) class JobStore: @@ -54,6 +61,21 @@ def initialize(self) -> None: metadata_json TEXT NOT NULL, FOREIGN KEY(job_id) REFERENCES jobs(id) ); + + CREATE TABLE IF NOT EXISTS worker_state ( + worker_id TEXT PRIMARY KEY, + status TEXT NOT NULL, + mode TEXT NOT NULL, + current_job_id TEXT, + last_job_id TEXT, + processed_jobs INTEGER NOT NULL, + queued_jobs INTEGER NOT NULL, + started_at TEXT, + heartbeat_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + last_error TEXT, + metadata_json TEXT NOT NULL + ); """ ) @@ -135,6 +157,68 @@ def add_stage_event(self, event: StageEventRecord) -> None: ), ) + def save_worker_state(self, worker_state: WorkerStateRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT INTO worker_state ( + worker_id, status, mode, current_job_id, last_job_id, processed_jobs, + queued_jobs, started_at, heartbeat_at, updated_at, last_error, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(worker_id) DO UPDATE SET + status=excluded.status, + mode=excluded.mode, + current_job_id=excluded.current_job_id, + last_job_id=excluded.last_job_id, + processed_jobs=excluded.processed_jobs, + queued_jobs=excluded.queued_jobs, + started_at=excluded.started_at, + heartbeat_at=excluded.heartbeat_at, + updated_at=excluded.updated_at, + last_error=excluded.last_error, + metadata_json=excluded.metadata_json + """, + ( + worker_state.worker_id, + worker_state.status, + worker_state.mode, + worker_state.current_job_id, + worker_state.last_job_id, + worker_state.processed_jobs, + worker_state.queued_jobs, + worker_state.started_at.isoformat() if worker_state.started_at else None, + worker_state.heartbeat_at.isoformat(), + worker_state.updated_at.isoformat(), + worker_state.last_error, + json.dumps(worker_state.metadata), + ), + ) + + def get_worker_state(self, worker_id: str = "local-worker") -> WorkerStateRecord | None: + with self._connect() as connection: + row = connection.execute( + "SELECT * FROM worker_state WHERE worker_id = ?", + (worker_id,), + ).fetchone() + if row is None: + return None + return WorkerStateRecord.model_validate( + { + "worker_id": row["worker_id"], + "status": row["status"], + "mode": row["mode"], + "current_job_id": row["current_job_id"], + "last_job_id": row["last_job_id"], + "processed_jobs": row["processed_jobs"], + "queued_jobs": row["queued_jobs"], + "started_at": row["started_at"], + "heartbeat_at": row["heartbeat_at"], + "updated_at": row["updated_at"], + "last_error": row["last_error"], + "metadata": json.loads(row["metadata_json"]), + } + ) + def get_job(self, job_id: str) -> JobRecord | None: with self._connect() as connection: row = connection.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() @@ -150,6 +234,14 @@ def list_jobs(self, limit: int = 20) -> list[JobRecord]: ).fetchall() return [self._deserialize_job(connection, row) for row in rows] + def count_jobs_by_status(self, status: JobStatus) -> int: + with self._connect() as connection: + row = connection.execute( + "SELECT COUNT(*) AS total FROM jobs WHERE status = ?", + (status,), + ).fetchone() + return int(row["total"]) if row else 0 + def cancel_job(self, job_id: str) -> bool: with self._connect() as connection: result = connection.execute( diff --git a/tests/integration/test_orchestration.py b/tests/integration/test_orchestration.py index 8e76901..9bb4cd0 100644 --- a/tests/integration/test_orchestration.py +++ b/tests/integration/test_orchestration.py @@ -118,3 +118,28 @@ def test_worker_batch_processes_only_requested_number_of_jobs(tmp_path, repo_roo assert refreshed_second is not None assert [refreshed_first.status, refreshed_second.status].count(JobStatus.SUCCEEDED) == 1 assert [refreshed_first.status, refreshed_second.status].count(JobStatus.QUEUED) == 1 + + +def test_worker_until_idle_drains_queue_and_updates_worker_state(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="First overnight architecture request.", + queue_mode=QueueMode.QUEUED, + ) + ) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Second overnight architecture request.", + queue_mode=QueueMode.QUEUED, + ) + ) + + processed = app.orchestration.run_worker_until_idle(max_idle_cycles=1) + worker = app.orchestration.get_worker_state() + + assert processed == 2 + assert worker.status == "idle" + assert worker.mode == "until-idle" + assert worker.processed_jobs == 2 + assert 
worker.queued_jobs == 0 diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 9609fd8..f2e37c4 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -42,6 +42,7 @@ def test_api_exposes_expected_routes() -> None: assert "/dashboard" in routes assert "/health" in routes assert "/models" in routes + assert "/worker/status" in routes assert "/route/explain" in routes assert "/jobs" in routes assert "/jobs/{job_id}" in routes @@ -120,6 +121,41 @@ def test_job_timeline_route_returns_stage_events(repo_root, tmp_path) -> None: ) +def test_worker_status_and_until_idle_routes(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a queued architecture pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare another queued architecture pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + status_before = client.get("/worker/status") + assert status_before.status_code == 200 + assert status_before.json()["queued_jobs"] == 2 + + run_response = client.post( + "/worker/run", + json={"until_idle": True, "max_idle_cycles": 1}, + ) + + assert run_response.status_code == 200 + payload = run_response.json() + assert payload["processed"] == 2 + assert payload["worker"]["status"] == "idle" + assert payload["worker"]["mode"] == "until-idle" + assert payload["worker"]["queued_jobs"] == 0 + + def test_job_replay_route_creates_new_job(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) providers = _test_providers(settings) @@ -169,6 +205,8 @@ def test_worker_run_route_processes_bounded_batch(repo_root, tmp_path) -> None: assert payload["processed"] == 1 assert len(payload["jobs"]) == 1 assert payload["jobs"][0]["status"] == "succeeded" + assert payload["worker"]["status"] == "idle" + assert payload["worker"]["processed_jobs"] == 1 refreshed_first = seeded_app.job_store.get_job(first_job.id) refreshed_second = seeded_app.job_store.get_job(second_job.id) From 0644430444129fa0bea28c130b5cda86224134bd Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 20:14:44 +0200 Subject: [PATCH 07/16] feat: add dedicated worker service entrypoint --- README.md | 28 +++- apps/web/README.md | 1 + apps/worker/README.md | 21 ++- docs/setup/workstation.md | 21 +++ pyproject.toml | 1 + src/lai/__init__.py | 3 +- src/lai/api/app.py | 11 +- src/lai/api/static/dashboard.js | 5 + src/lai/cli.py | 68 +++++++++ src/lai/jobs/service.py | 51 +++++-- src/lai/settings.py | 16 ++ src/lai/worker/__init__.py | 17 +++ src/lai/worker/cli.py | 109 ++++++++++++++ src/lai/worker/service.py | 233 ++++++++++++++++++++++++++++++ tests/unit/test_api.py | 3 + tests/unit/test_layout.py | 9 ++ tests/unit/test_worker_service.py | 94 ++++++++++++ 17 files changed, 676 insertions(+), 15 deletions(-) create mode 100644 src/lai/worker/__init__.py create mode 100644 src/lai/worker/cli.py create mode 100644 src/lai/worker/service.py create mode 100644 tests/unit/test_worker_service.py diff --git a/README.md b/README.md index fd5f9be..df342a6 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,10 @@ python -m lai.cli worker run --once python -m lai.cli worker run --max-jobs 3 python -m lai.cli worker run --until-idle --max-idle-cycles 
1 python -m lai.cli worker status +python -m lai.cli worker serve --poll-interval 10 +python -m lai.cli worker stop +lai-worker run --poll-interval 10 +lai-worker stop python -m lai.cli eval route --no-save ``` @@ -137,6 +141,28 @@ route explanation, recent job inspection, stage telemetry, artifact/trace browsi replay actions, persisted worker monitoring, and bounded queue worker controls including an until-idle drain path. +## Dedicated worker service + +For always-on local queue processing, run the dedicated worker service instead of manually +triggering bounded worker batches: + +```powershell +lai-worker run --poll-interval 10 +``` + +The service: + +- acquires a lock at `data/state/worker-service.lock` +- watches for a graceful stop signal at `data/state/worker-service.stop` +- writes service logs to `logs/worker-service.log` +- keeps the persisted worker state fresh for the CLI, API, and dashboard + +To stop it gracefully: + +```powershell +lai-worker stop +``` + ## Initial GitHub rules encoded in this repo - Pull request template and issue forms for consistent planning. @@ -150,7 +176,7 @@ an until-idle drain path. 1. Add live provider smoke tests behind credentials and optional extras. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Add a dedicated long-running local worker service or daemon entrypoint for always-on queue processing. +4. Add richer live provider execution visibility and workstation runbooks for real large-model jobs. ## References diff --git a/apps/web/README.md b/apps/web/README.md index f55c577..9d506f7 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -11,6 +11,7 @@ Current dashboard capabilities: - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns - persisted worker monitoring with heartbeat, current job, and queue depth +- service-aware worker monitoring with daemon lock and stop-signal visibility - bounded live worker controls for processing queued jobs, including queue drain until idle - model health cards - job output inspector diff --git a/apps/worker/README.md b/apps/worker/README.md index 1c41b7f..8a4edc2 100644 --- a/apps/worker/README.md +++ b/apps/worker/README.md @@ -1,3 +1,22 @@ # Worker App -This folder is reserved for long-running local or remote workers that execute model jobs, including AirLLM-backed heavy inference tasks. +This area now maps to the dedicated local worker service surface used for always-on queue +processing. + +Primary entrypoints: + +- `lai-worker run --poll-interval 10` +- `lai-worker stop` +- `python -m lai.cli worker serve --poll-interval 10` +- `python -m lai.cli worker stop` + +Service runtime contract: + +- lock file: `data/state/worker-service.lock` +- stop signal: `data/state/worker-service.stop` +- service log: `logs/worker-service.log` +- persisted worker state: `data/state/lai.db` + +The worker service repeatedly drains queued jobs until idle, sleeps for the configured poll +interval, then checks again. This keeps the platform ready for overnight work without requiring +manual `worker run` commands. 
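To make the runtime contract above concrete, here is a sketch of the lock and stop-signal handshake (an illustration of the described contract, not the actual `WorkerServiceHost` implementation):

```python
# Illustrative poll loop honoring the worker-service lock and stop files.
import time
from pathlib import Path

lock_path = Path("data/state/worker-service.lock")
stop_path = Path("data/state/worker-service.stop")

lock_path.parent.mkdir(parents=True, exist_ok=True)
lock_path.write_text("worker-service", encoding="utf-8")  # placeholder lock content
try:
    while not stop_path.exists():
        # A real cycle drains queued jobs until idle before sleeping.
        time.sleep(10.0)  # poll interval, matching --poll-interval 10
finally:
    stop_path.unlink(missing_ok=True)  # consume the stop signal
    lock_path.unlink(missing_ok=True)  # release the lock
```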
diff --git a/docs/setup/workstation.md b/docs/setup/workstation.md index 5a507f5..f939382 100644 --- a/docs/setup/workstation.md +++ b/docs/setup/workstation.md @@ -29,3 +29,24 @@ Add AirLLM only when you are ready to test the large-model path: ```powershell python -m pip install airllm ``` + +## Always-on local worker + +Once the environment is ready, you can keep queued jobs processing in the background with the +dedicated worker service: + +```powershell +lai-worker run --poll-interval 10 +``` + +Useful local paths: + +- lock file: `data/state/worker-service.lock` +- stop signal: `data/state/worker-service.stop` +- log file: `logs/worker-service.log` + +To stop the worker cleanly from another terminal: + +```powershell +lai-worker stop +``` diff --git a/pyproject.toml b/pyproject.toml index 5a4c5cc..8a4ff25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ [project.scripts] lai = "lai.cli:app" +lai-worker = "lai.worker.cli:app" [tool.hatch.build.targets.wheel] packages = ["src/lai"] diff --git a/src/lai/__init__.py b/src/lai/__init__.py index 601a203..c27714a 100644 --- a/src/lai/__init__.py +++ b/src/lai/__init__.py @@ -1,7 +1,7 @@ """LAI core package.""" from .config import AppConfig, load_app_config -from .domain import ExecutionRequest, JobStatus, QueueMode, RoutingDecision +from .domain import ExecutionRequest, JobStatus, QueueMode, RoutingDecision, WorkerStatus from .settings import Settings __all__ = [ @@ -11,5 +11,6 @@ "QueueMode", "RoutingDecision", "Settings", + "WorkerStatus", "load_app_config", ] diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 269ae84..2e9617f 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -92,7 +92,16 @@ def list_models() -> dict[str, object]: @api.get("/worker/status") def worker_status() -> dict[str, object]: app_state = application() - return app_state.orchestration.get_worker_state().model_dump() + worker = app_state.orchestration.get_worker_state() + payload = worker.model_dump() + payload["service_lock_present"] = ( + app_state.settings.resolved_worker_service_lock_path.exists() + ) + payload["stop_signal_present"] = ( + app_state.settings.resolved_worker_service_stop_path.exists() + ) + payload["log_file"] = str(app_state.settings.resolved_worker_service_log_path) + return payload @api.post("/route/explain") def explain_route(payload: JobCreatePayload) -> dict[str, object]: diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index 8c1880d..aaf9800 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -130,6 +130,7 @@ function renderWorkerMonitor() { chip("mode", worker.mode), chip("queued", worker.queued_jobs), chip("processed", worker.processed_jobs), + chip("service", worker.service_lock_present ? "locked" : "inactive", worker.service_lock_present ? "warn" : ""), ]; if (worker.current_job_id) { chips.push(chip("current", worker.current_job_id.slice(0, 8), tone)); @@ -143,6 +144,8 @@ function renderWorkerMonitor() { : "n/a"; const startedAt = worker.started_at ? new Date(worker.started_at).toLocaleString() : "n/a"; const phase = worker.metadata?.phase ? `phase: ${worker.metadata.phase}` : "phase: n/a"; + const logFile = worker.log_file ? `log: ${worker.log_file}` : "log: n/a"; + const stopSignal = worker.stop_signal_present ? '
stop signal: present' : "";
     const error = worker.last_error
       ? `error: ${escapeHtml(worker.last_error)}`
       : "";
@@ -151,6 +154,8 @@ function renderWorkerMonitor() {
   elements.workerSummary.innerHTML = `
     ${chips.join("")}
     heartbeat: ${escapeHtml(heartbeat)} | started: ${escapeHtml(startedAt)} | ${escapeHtml(phase)}
+    ${escapeHtml(logFile)}
+ ${stopSignal} ${error} `; } diff --git a/src/lai/cli.py b/src/lai/cli.py index 3174f8c..e08735d 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -13,6 +13,7 @@ from .layout import runtime_directories from .settings import Settings from .system import collect_system_snapshot +from .worker.service import WorkerServiceConfig, WorkerServiceHost, request_worker_service_stop app = typer.Typer(help="Utilities for the LAI orchestration platform.", no_args_is_help=True) models_app = typer.Typer(help="Inspect model definitions and provider readiness.") @@ -358,10 +359,77 @@ def worker_status() -> None: "Started at", worker.started_at.isoformat(timespec="seconds") if worker.started_at else "n/a", ) + table.add_row( + "Service lock", + "present" if application.settings.resolved_worker_service_lock_path.exists() else "missing", + ) + table.add_row( + "Stop signal", + "present" if application.settings.resolved_worker_service_stop_path.exists() else "missing", + ) + table.add_row("Service log", str(application.settings.resolved_worker_service_log_path)) table.add_row("Last error", worker.last_error or "none") console.print(table) +@worker_app.command("serve") +def serve_worker( + poll_interval: float | None = typer.Option( + None, + min=0.1, + help="Seconds to wait between daemon polling cycles.", + ), + max_idle_cycles: int = typer.Option( + 1, + min=1, + help="Number of empty polls before a daemon cycle becomes idle.", + ), + max_jobs_per_cycle: int | None = typer.Option( + None, + min=1, + help="Optional cap for jobs processed in one daemon cycle.", + ), + max_cycles: int | None = typer.Option( + None, + min=1, + help="Optional cap for daemon cycles. Useful for supervised runs.", + ), + worker_id: str = typer.Option( + "local-worker", + help="Persisted worker id for the dedicated service.", + ), + replace_lock: bool = typer.Option( + False, + "--replace-lock", + help="Replace an existing worker-service lock if you are certain it is stale.", + ), +) -> None: + """Run the dedicated long-lived local worker service.""" + settings = Settings() + config = WorkerServiceConfig( + worker_id=worker_id, + poll_interval_seconds=( + poll_interval + if poll_interval is not None + else settings.worker_service_poll_interval_seconds + ), + max_idle_cycles_per_drain=max_idle_cycles, + max_jobs_per_cycle=max_jobs_per_cycle, + max_cycles=max_cycles, + replace_existing_lock=replace_lock, + ) + host = WorkerServiceHost(settings=settings, config=config) + processed = host.run() + console.print(f"Worker service exited after processing {processed} job(s).") + + +@worker_app.command("stop") +def stop_worker_service() -> None: + """Request a graceful stop for the dedicated worker service.""" + path = request_worker_service_stop(Settings()) + console.print(f"Requested worker service stop via {path}") + + @eval_app.command("route") def eval_route( scenario_file: Path = typer.Option( diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index ad4e140..18c1e54 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -134,6 +134,13 @@ def get_worker_state(self, worker_id: str | None = None) -> WorkerStateRecord: self.job_store.save_worker_state(worker_state) return worker_state + def update_worker_state( + self, + worker_id: str | None = None, + **updates: object, + ) -> WorkerStateRecord: + return self._update_worker_state(worker_id=worker_id or self.default_worker_id, **updates) + def cancel_job(self, job_id: str) -> JobRecord | None: if not self.job_store.cancel_job(job_id): return None @@ -219,6 +226,7 
@@ def run_worker_batch( max_jobs: int = 1, requeue_running: bool = False, worker_id: str | None = None, + mode: str = "batch", ) -> list[JobRecord]: if max_jobs < 1: raise ValueError("max_jobs must be at least 1.") @@ -232,7 +240,7 @@ def run_worker_batch( self._update_worker_state( worker_id=resolved_worker_id, status=WorkerStatus.RUNNING, - mode="batch", + mode=mode, current_job_id=None, processed_jobs=0, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), @@ -251,6 +259,7 @@ def run_worker_batch( worker_id=resolved_worker_id, status=WorkerStatus.RUNNING, current_job_id=job.id, + mode=mode, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), ) processed_job = self.execute_job(job.id) @@ -262,6 +271,7 @@ def run_worker_batch( last_job_id=processed_job.id, processed_jobs=len(processed), queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, ) except Exception as exc: self._update_worker_state( @@ -270,6 +280,7 @@ def run_worker_batch( current_job_id=None, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), last_error=f"{type(exc).__name__}: {exc}", + mode=mode, ) raise @@ -279,6 +290,7 @@ def run_worker_batch( current_job_id=None, processed_jobs=len(processed), queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"max_jobs": max_jobs, "phase": "complete"}, ) return processed @@ -290,6 +302,7 @@ def run_worker_until_idle( max_jobs: int | None = None, requeue_running: bool = True, worker_id: str | None = None, + mode: str = "until-idle", ) -> int: if max_idle_cycles < 1: raise ValueError("max_idle_cycles must be at least 1.") @@ -307,7 +320,7 @@ def run_worker_until_idle( self._update_worker_state( worker_id=resolved_worker_id, status=WorkerStatus.RUNNING, - mode="until-idle", + mode=mode, current_job_id=None, processed_jobs=0, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), @@ -335,6 +348,7 @@ def run_worker_until_idle( current_job_id=None, queued_jobs=0, processed_jobs=processed, + mode=mode, metadata={ "max_idle_cycles": max_idle_cycles, "idle_cycles": idle_cycles, @@ -352,6 +366,7 @@ def run_worker_until_idle( status=WorkerStatus.RUNNING, current_job_id=job.id, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"phase": "processing"}, ) processed_job = self.execute_job(job.id) @@ -363,6 +378,7 @@ def run_worker_until_idle( last_job_id=processed_job.id, processed_jobs=processed, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"phase": "draining"}, ) except Exception as exc: @@ -373,6 +389,7 @@ def run_worker_until_idle( queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), processed_jobs=processed, last_error=f"{type(exc).__name__}: {exc}", + mode=mode, ) raise @@ -382,11 +399,17 @@ def run_worker_until_idle( current_job_id=None, processed_jobs=processed, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"phase": "idle", "idle_cycles": idle_cycles}, ) return processed - def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: + def run_worker( + self, + once: bool = False, + max_jobs: int | None = None, + mode: str = "continuous", + ) -> int: if max_jobs is not None and max_jobs < 1: raise ValueError("max_jobs must be at least 1 when provided.") @@ -395,7 +418,7 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: self._update_worker_state( worker_id=self.default_worker_id, 
status=WorkerStatus.RUNNING, - mode="continuous", + mode=mode, current_job_id=None, processed_jobs=0, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), @@ -406,13 +429,14 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: while True: if max_jobs is not None and processed >= max_jobs: self._update_worker_state( - worker_id=self.default_worker_id, - status=WorkerStatus.IDLE, - current_job_id=None, - processed_jobs=processed, - queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), - metadata={"phase": "idle"}, - ) + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "idle"}, + ) return processed job = self.job_store.claim_next_queued_job() if job is None: @@ -423,6 +447,7 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: current_job_id=None, processed_jobs=processed, queued_jobs=0, + mode=mode, metadata={"phase": "idle"}, ) return processed @@ -432,6 +457,7 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: current_job_id=None, processed_jobs=processed, queued_jobs=0, + mode=mode, metadata={"phase": "waiting"}, ) time.sleep(self.settings.worker_idle_sleep_seconds) @@ -442,6 +468,7 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: current_job_id=job.id, processed_jobs=processed, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"phase": "processing"}, ) self.execute_job(job.id) @@ -454,6 +481,7 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: processed_jobs=processed, last_job_id=job.id, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"phase": "idle"}, ) return processed @@ -464,6 +492,7 @@ def run_worker(self, once: bool = False, max_jobs: int | None = None) -> int: processed_jobs=processed, last_job_id=job.id, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, metadata={"phase": "monitoring"}, ) diff --git a/src/lai/settings.py b/src/lai/settings.py index a21274e..059ed41 100644 --- a/src/lai/settings.py +++ b/src/lai/settings.py @@ -35,6 +35,9 @@ class Settings(BaseSettings): state_dir: Path = Path("data/state") database_path: Path = Path("data/state/lai.db") logs_dir: Path = Path("logs") + worker_service_lock_path: Path = Path("data/state/worker-service.lock") + worker_service_stop_path: Path = Path("data/state/worker-service.stop") + worker_service_log_path: Path = Path("logs/worker-service.log") allow_overnight_jobs: bool = True enable_gpu: bool = True @@ -45,6 +48,7 @@ class Settings(BaseSettings): default_temperature: float = 0.2 queue_poll_interval_seconds: int = 5 worker_idle_sleep_seconds: float = 2.0 + worker_service_poll_interval_seconds: float = 10.0 max_retry_attempts: int = 1 stale_running_job_timeout_seconds: int = 900 @@ -93,6 +97,18 @@ def resolved_database_path(self) -> Path: def resolved_logs_dir(self) -> Path: return self.root_dir / self.logs_dir + @cached_property + def resolved_worker_service_lock_path(self) -> Path: + return self.root_dir / self.worker_service_lock_path + + @cached_property + def resolved_worker_service_stop_path(self) -> Path: + return self.root_dir / self.worker_service_stop_path + + @cached_property + def resolved_worker_service_log_path(self) -> Path: + return self.root_dir / 
self.worker_service_log_path + @staticmethod def _secret_value(secret: SecretStr | None) -> str | None: return secret.get_secret_value() if secret else None diff --git a/src/lai/worker/__init__.py b/src/lai/worker/__init__.py new file mode 100644 index 0000000..0cb44e3 --- /dev/null +++ b/src/lai/worker/__init__.py @@ -0,0 +1,17 @@ +"""Dedicated worker service entrypoints and helpers.""" + +from .service import ( + WorkerServiceConfig, + WorkerServiceHost, + clear_worker_service_stop, + request_worker_service_stop, + worker_service_stop_requested, +) + +__all__ = [ + "WorkerServiceConfig", + "WorkerServiceHost", + "clear_worker_service_stop", + "request_worker_service_stop", + "worker_service_stop_requested", +] diff --git a/src/lai/worker/cli.py b/src/lai/worker/cli.py new file mode 100644 index 0000000..c5a5f0f --- /dev/null +++ b/src/lai/worker/cli.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import typer +from rich.console import Console +from rich.table import Table + +from ..application import create_application +from ..settings import Settings +from .service import WorkerServiceConfig, WorkerServiceHost, request_worker_service_stop + +app = typer.Typer(help="Dedicated local worker service entrypoint.", no_args_is_help=True) +console = Console() + + +@app.command("run") +def run_service( + poll_interval: float | None = typer.Option( + None, + min=0.1, + help="Seconds to wait between daemon polling cycles.", + ), + max_idle_cycles: int = typer.Option( + 1, + min=1, + help="Number of empty polls before a drain cycle becomes idle.", + ), + max_jobs_per_cycle: int | None = typer.Option( + None, + min=1, + help="Optional cap for jobs processed in one drain cycle.", + ), + max_cycles: int | None = typer.Option( + None, + min=1, + help="Optional cap for daemon cycles. 
Useful for tests or supervised runs.", + ), + worker_id: str = typer.Option( + "local-worker", + help="Persisted worker id used for service status and ownership.", + ), + replace_lock: bool = typer.Option( + False, + "--replace-lock", + help="Replace an existing lock file if you are certain the prior worker is gone.", + ), +) -> None: + """Run the dedicated long-lived local worker service.""" + settings = Settings() + config = WorkerServiceConfig( + worker_id=worker_id, + poll_interval_seconds=( + poll_interval + if poll_interval is not None + else settings.worker_service_poll_interval_seconds + ), + max_idle_cycles_per_drain=max_idle_cycles, + max_jobs_per_cycle=max_jobs_per_cycle, + max_cycles=max_cycles, + replace_existing_lock=replace_lock, + ) + host = WorkerServiceHost(settings=settings, config=config) + processed = host.run() + console.print(f"Worker service exited after processing {processed} job(s).") + + +@app.command("stop") +def stop_service() -> None: + """Request a graceful stop for the dedicated worker service.""" + path = request_worker_service_stop(Settings()) + console.print(f"Requested worker service stop via {path}") + + +@app.command("status") +def service_status( + worker_id: str = typer.Option( + "local-worker", + help="Persisted worker id to inspect.", + ), +) -> None: + """Show dedicated worker service status, lock, and stop-signal state.""" + settings = Settings() + application = create_application(settings=settings) + worker = application.orchestration.get_worker_state(worker_id) + + table = Table(title=f"Worker Service {worker_id}") + table.add_column("Field") + table.add_column("Value") + table.add_row("Status", worker.status) + table.add_row("Mode", worker.mode) + table.add_row("Current job", worker.current_job_id or "n/a") + table.add_row("Last job", worker.last_job_id or "n/a") + table.add_row("Processed jobs", str(worker.processed_jobs)) + table.add_row("Queued jobs", str(worker.queued_jobs)) + table.add_row("Heartbeat", worker.heartbeat_at.isoformat(timespec="seconds")) + table.add_row( + "Lock file", + "present" if settings.resolved_worker_service_lock_path.exists() else "missing", + ) + table.add_row( + "Stop signal", + "present" if settings.resolved_worker_service_stop_path.exists() else "missing", + ) + table.add_row("Log file", str(settings.resolved_worker_service_log_path)) + table.add_row("Last error", worker.last_error or "none") + console.print(table) + + +if __name__ == "__main__": + app() diff --git a/src/lai/worker/service.py b/src/lai/worker/service.py new file mode 100644 index 0000000..8e1deeb --- /dev/null +++ b/src/lai/worker/service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import json +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Mapping + +from ..application import LAIApplication, create_application +from ..domain import JobStatus, WorkerStatus +from ..layout import ensure_runtime_directories +from ..providers import Provider +from ..settings import Settings + + +@dataclass +class WorkerServiceConfig: + worker_id: str = "local-worker" + poll_interval_seconds: float = 10.0 + max_idle_cycles_per_drain: int = 1 + max_jobs_per_cycle: int | None = None + max_cycles: int | None = None + replace_existing_lock: bool = False + + +def request_worker_service_stop(settings: Settings | None = None) -> Path: + settings = settings or Settings() + stop_path = settings.resolved_worker_service_stop_path + stop_path.parent.mkdir(parents=True, 
exist_ok=True) + stop_path.write_text("stop requested\n", encoding="utf-8") + return stop_path + + +def clear_worker_service_stop(settings: Settings | None = None) -> None: + settings = settings or Settings() + settings.resolved_worker_service_stop_path.unlink(missing_ok=True) + + +def worker_service_stop_requested(settings: Settings | None = None) -> bool: + settings = settings or Settings() + return settings.resolved_worker_service_stop_path.exists() + + +class WorkerServiceLock: + def __init__(self, lock_path: Path, *, replace_existing: bool = False) -> None: + self.lock_path = lock_path + self.replace_existing = replace_existing + self.acquired = False + + def __enter__(self) -> "WorkerServiceLock": + self.acquire() + return self + + def __exit__(self, *_args: object) -> None: + self.release() + + def acquire(self) -> None: + self.lock_path.parent.mkdir(parents=True, exist_ok=True) + if self.replace_existing and self.lock_path.exists(): + self.lock_path.unlink() + try: + descriptor = os.open( + str(self.lock_path), + os.O_CREAT | os.O_EXCL | os.O_WRONLY, + ) + except FileExistsError as exc: + raise RuntimeError( + f"Worker service lock already exists at {self.lock_path}. " + "Use replace_existing_lock only when you are certain the prior worker is gone." + ) from exc + + payload = { + "pid": os.getpid(), + "created_at": time.time(), + } + with os.fdopen(descriptor, "w", encoding="utf-8") as handle: + json.dump(payload, handle) + self.acquired = True + + def release(self) -> None: + if self.acquired: + self.lock_path.unlink(missing_ok=True) + self.acquired = False + + +class WorkerServiceHost: + def __init__( + self, + *, + settings: Settings | None = None, + providers: Mapping[str, Provider] | None = None, + config: WorkerServiceConfig | None = None, + application_factory: Callable[[], LAIApplication] | None = None, + sleep_fn: Callable[[float], None] = time.sleep, + logger: logging.Logger | None = None, + ) -> None: + self.settings = settings or Settings() + ensure_runtime_directories(self.settings.root_dir) + self.providers = providers + self.config = config or WorkerServiceConfig( + poll_interval_seconds=self.settings.worker_service_poll_interval_seconds + ) + self.application_factory = application_factory or ( + lambda: create_application(settings=self.settings, providers=self.providers) + ) + self.sleep_fn = sleep_fn + self.logger = logger or build_worker_service_logger( + self.settings.resolved_worker_service_log_path + ) + + def run(self) -> int: + total_processed = 0 + cycles = 0 + clear_worker_service_stop(self.settings) + + with WorkerServiceLock( + self.settings.resolved_worker_service_lock_path, + replace_existing=self.config.replace_existing_lock, + ): + self.logger.info("LAI worker service started for worker_id=%s", self.config.worker_id) + self._mark_service_state( + status=WorkerStatus.IDLE, + phase="starting", + processed_jobs=0, + metadata={"service_cycles": 0}, + ) + try: + while True: + if worker_service_stop_requested(self.settings): + self.logger.info("Worker service stop signal received.") + self._mark_service_state( + status=WorkerStatus.STOPPED, + phase="stopped", + processed_jobs=total_processed, + metadata={"service_cycles": cycles}, + ) + break + + app = self.application_factory() + processed = app.orchestration.run_worker_until_idle( + max_idle_cycles=self.config.max_idle_cycles_per_drain, + max_jobs=self.config.max_jobs_per_cycle, + requeue_running=(cycles == 0), + worker_id=self.config.worker_id, + mode="daemon", + ) + total_processed += processed + cycles 
+= 1 + self.logger.info( + "Worker service cycle %s processed %s job(s), total=%s", + cycles, + processed, + total_processed, + ) + self._mark_service_state( + status=WorkerStatus.IDLE, + phase="sleeping", + processed_jobs=total_processed, + metadata={ + "service_cycles": cycles, + "last_cycle_processed": processed, + }, + ) + + if self.config.max_cycles is not None and cycles >= self.config.max_cycles: + self.logger.info( + "Worker service reached max_cycles=%s", + self.config.max_cycles, + ) + break + + self.sleep_fn(self.config.poll_interval_seconds) + except Exception as exc: + self.logger.exception("Worker service failed.") + self._mark_service_state( + status=WorkerStatus.ERROR, + phase="error", + processed_jobs=total_processed, + last_error=f"{type(exc).__name__}: {exc}", + metadata={"service_cycles": cycles}, + ) + raise + finally: + clear_worker_service_stop(self.settings) + + self.logger.info("LAI worker service exited after processing %s job(s).", total_processed) + return total_processed + + def _mark_service_state( + self, + *, + status: WorkerStatus, + phase: str, + processed_jobs: int, + metadata: dict[str, object] | None = None, + last_error: str | None = None, + ) -> None: + app = self.application_factory() + current_state = app.orchestration.get_worker_state(self.config.worker_id) + app.orchestration.update_worker_state( + worker_id=self.config.worker_id, + status=status, + mode="daemon", + current_job_id=None, + processed_jobs=processed_jobs, + queued_jobs=app.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=current_state.started_at or current_state.heartbeat_at, + last_error=last_error, + metadata={"phase": phase, **(metadata or {})}, + ) + + +def build_worker_service_logger(log_path: Path) -> logging.Logger: + log_path.parent.mkdir(parents=True, exist_ok=True) + logger = logging.getLogger("lai.worker.service") + logger.setLevel(logging.INFO) + logger.propagate = False + + handler_paths = { + getattr(handler, "baseFilename", None) + for handler in logger.handlers + if hasattr(handler, "baseFilename") + } + if str(log_path) not in handler_paths: + handler = logging.FileHandler(log_path, encoding="utf-8") + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(message)s") + ) + logger.addHandler(handler) + + return logger diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index f2e37c4..00daa45 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -142,6 +142,9 @@ def test_worker_status_and_until_idle_routes(repo_root, tmp_path) -> None: status_before = client.get("/worker/status") assert status_before.status_code == 200 assert status_before.json()["queued_jobs"] == 2 + assert "service_lock_present" in status_before.json() + assert "stop_signal_present" in status_before.json() + assert "log_file" in status_before.json() run_response = client.post( "/worker/run", diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py index e83afa2..fe2a39a 100644 --- a/tests/unit/test_layout.py +++ b/tests/unit/test_layout.py @@ -16,3 +16,12 @@ def test_resolved_config_paths_are_under_repo_root() -> None: assert settings.resolved_model_catalog == settings.root_dir / "configs/models/catalog.yaml" assert settings.resolved_routing_policy == settings.root_dir / "configs/routing/policies.yaml" + assert settings.resolved_worker_service_lock_path == ( + settings.root_dir / "data/state/worker-service.lock" + ) + assert settings.resolved_worker_service_stop_path == ( + settings.root_dir / "data/state/worker-service.stop" + ) + 
assert settings.resolved_worker_service_log_path == ( + settings.root_dir / "logs/worker-service.log" + ) diff --git a/tests/unit/test_worker_service.py b/tests/unit/test_worker_service.py new file mode 100644 index 0000000..1852e44 --- /dev/null +++ b/tests/unit/test_worker_service.py @@ -0,0 +1,94 @@ +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode +from lai.settings import Settings +from lai.system import collect_system_snapshot +from lai.worker.service import ( + WorkerServiceConfig, + WorkerServiceHost, + request_worker_service_stop, +) +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_worker_service_processes_queue_and_releases_lock(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a queued architecture run for the daemon.", + queue_mode=QueueMode.QUEUED, + ) + ) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig(max_cycles=1, poll_interval_seconds=0.01), + sleep_fn=lambda _seconds: None, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert not settings.resolved_worker_service_lock_path.exists() + assert worker.mode == "daemon" + assert worker.status == "idle" + assert worker.processed_jobs == 1 + + +def test_worker_service_stop_signal_exits_gracefully(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a queued architecture run for the daemon.", + queue_mode=QueueMode.QUEUED, + ) + ) + + def _sleep_and_request_stop(_seconds: float) -> None: + request_worker_service_stop(settings) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig(poll_interval_seconds=0.01, max_idle_cycles_per_drain=1), + sleep_fn=_sleep_and_request_stop, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert worker.status == "stopped" + assert worker.mode == "daemon" + assert not 
settings.resolved_worker_service_stop_path.exists() From 13faa4245efa4294a8f3d666a5aa63330cc7b221 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 20:20:41 +0200 Subject: [PATCH 08/16] feat: add provider smoke diagnostics --- README.md | 25 +++- docs/setup/workstation.md | 20 +++ src/lai/cli.py | 104 ++++++++++++++++ src/lai/smoke.py | 254 ++++++++++++++++++++++++++++++++++++++ tests/unit/test_smoke.py | 87 +++++++++++++ 5 files changed, 489 insertions(+), 1 deletion(-) create mode 100644 src/lai/smoke.py create mode 100644 tests/unit/test_smoke.py diff --git a/README.md b/README.md index df342a6..be480f3 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,10 @@ python -m lai.cli worker run --until-idle --max-idle-cycles 1 python -m lai.cli worker status python -m lai.cli worker serve --poll-interval 10 python -m lai.cli worker stop +python -m lai.cli smoke providers +python -m lai.cli smoke providers --live +python -m lai.cli smoke local --live --include-airllm +python -m lai.cli smoke latest lai-worker run --poll-interval 10 lai-worker stop python -m lai.cli eval route --no-save @@ -163,6 +167,25 @@ To stop it gracefully: lai-worker stop ``` +## Provider smoke diagnostics + +Use readiness-only smoke checks to verify credentials, optional dependencies, and provider health +without making live requests: + +```powershell +python -m lai.cli smoke providers +python -m lai.cli smoke local +``` + +Use live mode only when you want a real tiny prompt executed: + +```powershell +python -m lai.cli smoke providers --live +python -m lai.cli smoke local --live --include-airllm +``` + +Smoke results are saved under `evals/results/smoke/` by default. + ## Initial GitHub rules encoded in this repo - Pull request template and issue forms for consistent planning. @@ -173,7 +196,7 @@ lai-worker stop ## Near-term priorities -1. Add live provider smoke tests behind credentials and optional extras. +1. Surface latest smoke diagnostics and run summaries through the API/dashboard. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. 4. Add richer live provider execution visibility and workstation runbooks for real large-model jobs. diff --git a/docs/setup/workstation.md b/docs/setup/workstation.md index f939382..64f920f 100644 --- a/docs/setup/workstation.md +++ b/docs/setup/workstation.md @@ -50,3 +50,23 @@ To stop the worker cleanly from another terminal: ```powershell lai-worker stop ``` + +## Provider smoke checks + +Readiness-only smoke checks are safe to run on a fresh workstation because they verify +credentials, packages, and health without sending live prompts: + +```powershell +python -m lai.cli smoke providers +python -m lai.cli smoke local +``` + +Live smoke mode sends a tiny prompt and should be used only when you explicitly want a real +provider call: + +```powershell +python -m lai.cli smoke providers --live +python -m lai.cli smoke local --live --include-airllm +``` + +Saved smoke reports land under `evals/results/smoke/`. 
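The same sweep is scriptable through the `lai.smoke` module this patch introduces. A minimal sketch, assuming it runs from the repo root with a populated `.env` so `Settings()` resolves the default catalog and policy paths:

```python
# Readiness-only sweep driven from Python instead of the CLI; uses only the
# helpers added in this patch. No live prompts are sent in this mode.
from lai.settings import Settings
from lai.smoke import run_smoke_suite, save_smoke_result

settings = Settings()
result = run_smoke_suite(settings, suite_id="providers", live=False)
for check in result.checks:
    print(f"{check.provider_id}: {check.status} ({check.message})")

# Persist the report next to the CLI-produced ones.
path = save_smoke_result(settings.root_dir / "evals/results/smoke", result)
print(f"Saved {path}")
```

This mirrors what `_run_smoke_command` does in `src/lai/cli.py`, minus the Rich table rendering and exit-code handling.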
diff --git a/src/lai/cli.py b/src/lai/cli.py index e08735d..d495d2c 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from typing import Literal import typer from rich.console import Console @@ -12,6 +13,7 @@ from .evals import run_route_eval_suite, save_route_eval_result from .layout import runtime_directories from .settings import Settings +from .smoke import latest_smoke_result, run_smoke_suite, save_smoke_result from .system import collect_system_snapshot from .worker.service import WorkerServiceConfig, WorkerServiceHost, request_worker_service_stop @@ -21,11 +23,13 @@ jobs_app = typer.Typer(help="Inspect and manage persistent jobs.") worker_app = typer.Typer(help="Run the local job worker.") eval_app = typer.Typer(help="Run evaluation suites against routing behavior.") +smoke_app = typer.Typer(help="Run provider readiness and smoke diagnostics.") app.add_typer(models_app, name="models") app.add_typer(route_app, name="route") app.add_typer(jobs_app, name="jobs") app.add_typer(worker_app, name="worker") app.add_typer(eval_app, name="eval") +app.add_typer(smoke_app, name="smoke") console = Console() @@ -474,10 +478,110 @@ def eval_route( raise typer.Exit(code=1) +@smoke_app.command("providers") +def smoke_providers( + live: bool = typer.Option( + False, + "--live", + help="Execute tiny live prompts instead of readiness-only checks.", + ), + save: bool = typer.Option( + True, + "--save/--no-save", + help="Save the smoke result under evals/results/smoke.", + ), +) -> None: + """Run readiness or live smoke checks for remote providers.""" + _run_smoke_command( + suite_id="providers", + live=live, + save=save, + ) + + +@smoke_app.command("local") +def smoke_local( + live: bool = typer.Option( + False, + "--live", + help="Execute tiny live prompts instead of readiness-only checks.", + ), + include_airllm: bool = typer.Option( + False, + "--include-airllm", + help="Include the heavier AirLLM smoke target alongside transformers.", + ), + save: bool = typer.Option( + True, + "--save/--no-save", + help="Save the smoke result under evals/results/smoke.", + ), +) -> None: + """Run readiness or live smoke checks for local providers.""" + _run_smoke_command( + suite_id="local", + live=live, + include_airllm=include_airllm, + save=save, + ) + + +@smoke_app.command("latest") +def smoke_latest() -> None: + """Show the latest saved smoke result path.""" + settings = Settings() + path = latest_smoke_result(settings.root_dir / "evals/results/smoke") + if path is None: + raise typer.Exit(code=1) + console.print(path) + + def _status_line(path: Path) -> str: suffix = "present" if path.exists() else "missing" return f"{path} ({suffix})" +def _run_smoke_command( + *, + suite_id: Literal["providers", "local"], + live: bool, + save: bool, + include_airllm: bool = False, +) -> None: + settings = Settings() + result = run_smoke_suite( + settings, + suite_id=suite_id, + live=live, + include_airllm=include_airllm, + ) + + table = Table(title=f"LAI Smoke {suite_id}") + table.add_column("Provider") + table.add_column("Status") + table.add_column("Executed") + table.add_column("Model") + table.add_column("Duration") + table.add_column("Message") + + for check in result.checks: + table.add_row( + check.provider_id, + check.status, + "yes" if check.executed else "no", + check.model_ref, + f"{check.duration_seconds:.2f}s" if check.duration_seconds is not None else "n/a", + check.message, + ) + console.print(table) + + if save: + path = 
save_smoke_result(settings.root_dir / "evals/results/smoke", result) + console.print(f"Saved smoke result to {path}") + + if not result.passed: + raise typer.Exit(code=1) + + if __name__ == "__main__": app() diff --git a/src/lai/smoke.py b/src/lai/smoke.py new file mode 100644 index 0000000..58bb5e4 --- /dev/null +++ b/src/lai/smoke.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal, Mapping + +from pydantic import BaseModel, Field + +from .domain import ModelSpec, ProviderRequest +from .providers import Provider, ProviderRegistry +from .settings import Settings +from .system import collect_system_snapshot + +SMOKE_PROMPT = "Reply with READY and nothing else." + + +class SmokeCheckResult(BaseModel): + provider_id: str + runtime: str + model_ref: str + mode: Literal["readiness", "live"] + status: Literal["ready", "passed", "blocked", "failed", "skipped"] + available: bool + healthy: bool + executed: bool = False + success: bool = False + message: str + reasons: list[str] = Field(default_factory=list) + capabilities: dict[str, Any] = Field(default_factory=dict) + duration_seconds: float | None = None + output_preview: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class SmokeSuiteResult(BaseModel): + suite_id: str + mode: Literal["readiness", "live"] + executed_at: str + prompt: str + passed: bool + total: int + success_count: int + blocked_count: int + failed_count: int + checks: list[SmokeCheckResult] + + +def run_smoke_suite( + settings: Settings, + *, + suite_id: Literal["providers", "local"], + live: bool = False, + include_airllm: bool = False, + prompt: str = SMOKE_PROMPT, + providers: Mapping[str, Provider] | None = None, +) -> SmokeSuiteResult: + snapshot = collect_system_snapshot( + settings.root_dir, + settings.resolved_huggingface_cache_dir, + settings.resolved_airllm_shards_dir, + enable_gpu=settings.enable_gpu, + ) + registry = ProviderRegistry(settings, snapshot, providers=providers) + checks: list[SmokeCheckResult] = [] + + for model in _suite_models(suite_id, include_airllm=include_airllm): + provider = registry.get(model.provider_id) + health = provider.healthcheck(model) + capabilities = provider.describe_capabilities() + mode = "live" if live else "readiness" + + if not live: + ready = health.available and health.healthy + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="ready" if ready else "blocked", + available=health.available, + healthy=health.healthy, + success=ready, + message=( + "Provider is ready for live smoke execution." + if ready + else "; ".join(health.reasons) or "Provider is blocked." 
+ ), + reasons=health.reasons, + capabilities=capabilities, + metadata={"provider_metadata": health.metadata}, + ) + ) + continue + + if not health.available: + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="blocked", + available=health.available, + healthy=health.healthy, + message="Live smoke blocked by provider healthcheck.", + reasons=health.reasons, + capabilities=capabilities, + metadata={"provider_metadata": health.metadata}, + ) + ) + continue + + try: + result = provider.generate( + model, + ProviderRequest( + user_prompt=prompt, + max_output_tokens=32, + temperature=0, + timeout_seconds=min(settings.default_timeout_seconds, 60), + ), + ) + except Exception as exc: + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="failed", + available=health.available, + healthy=health.healthy, + executed=True, + success=False, + message=f"Live smoke failed with {type(exc).__name__}.", + reasons=[str(exc)], + capabilities=capabilities, + metadata={"provider_metadata": health.metadata}, + ) + ) + continue + + preview = (result.text or "").strip() + preview = preview[:160] if preview else None + success = bool((result.text or "").strip()) + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="passed" if success else "failed", + available=health.available, + healthy=health.healthy, + executed=True, + success=success, + message=( + f"Live smoke completed in {result.duration_seconds:.2f}s." + if success + else "Live smoke returned an empty response." + ), + reasons=health.reasons, + capabilities=capabilities, + duration_seconds=result.duration_seconds, + output_preview=preview, + metadata={ + "provider_metadata": health.metadata, + "finish_reason": result.finish_reason, + "usage": result.usage.model_dump() if result.usage else None, + "sentinel_detected": "READY" in (result.text or "").upper(), + }, + ) + ) + + success_count = sum(1 for check in checks if check.success) + blocked_count = sum(1 for check in checks if check.status == "blocked") + failed_count = sum(1 for check in checks if check.status == "failed") + return SmokeSuiteResult( + suite_id=suite_id, + mode="live" if live else "readiness", + executed_at=_utcnow_iso(), + prompt=prompt, + passed=blocked_count == 0 and failed_count == 0 and success_count == len(checks), + total=len(checks), + success_count=success_count, + blocked_count=blocked_count, + failed_count=failed_count, + checks=checks, + ) + + +def save_smoke_result(results_dir: Path, result: SmokeSuiteResult) -> Path: + results_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = results_dir / f"smoke-{result.suite_id}-{result.mode}-{timestamp}.json" + path.write_text(result.model_dump_json(indent=2), encoding="utf-8") + return path + + +def latest_smoke_result(results_dir: Path) -> Path | None: + if not results_dir.exists(): + return None + candidates = sorted(results_dir.glob("smoke-*.json"), reverse=True) + return candidates[0] if candidates else None + + +def _suite_models( + suite_id: Literal["providers", "local"], + *, + include_airllm: bool, +) -> list[ModelSpec]: + if suite_id == "providers": + return [ + _build_model("openai", "openai", "gpt-5.4-mini"), + _build_model("anthropic", "anthropic", "claude-sonnet-4-20250514"), + 
_build_model("gemini", "gemini", "gemini-2.5-flash"), + ] + + models = [ + _build_model("transformers", "transformers", "sshleifer/tiny-gpt2"), + ] + if include_airllm: + models.append( + ModelSpec.model_validate( + { + "id": "airllm-smoke", + "role": "executor", + "runtime": "airllm", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "hardware": {"expected_disk_gb": 1}, + "runtime_hints": { + "allow_cpu_fallback": True, + "allow_layer_sharding": True, + }, + } + ) + ) + return models + + +def _build_model(provider_id: str, runtime: str, model_ref: str) -> ModelSpec: + return ModelSpec.model_validate( + { + "id": f"{provider_id}-smoke", + "role": "executor", + "runtime": runtime, + "model": model_ref, + } + ) + + +def _utcnow_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat() diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py new file mode 100644 index 0000000..d72452d --- /dev/null +++ b/tests/unit/test_smoke.py @@ -0,0 +1,87 @@ +from lai.settings import Settings +from lai.smoke import latest_smoke_result, run_smoke_suite, save_smoke_result +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_provider_smoke_readiness_with_fake_providers(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + result = run_smoke_suite( + settings, + suite_id="providers", + providers=_test_providers(settings), + ) + + assert result.passed + assert result.mode == "readiness" + assert result.total == 3 + assert all(check.status == "ready" for check in result.checks) + assert all(not check.executed for check in result.checks) + + +def test_provider_smoke_live_executes_and_saves_result(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + result = run_smoke_suite( + settings, + suite_id="providers", + live=True, + providers=_test_providers(settings), + ) + path = save_smoke_result(tmp_path / "evals/results/smoke", result) + + assert result.passed + assert result.mode == "live" + assert all(check.executed for check in result.checks) + assert path.exists() + assert latest_smoke_result(tmp_path / "evals/results/smoke") == path + + +def test_local_smoke_excludes_airllm_by_default_and_can_include_it(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + + default_result = run_smoke_suite( + settings, + suite_id="local", + providers=providers, + ) + expanded_result = run_smoke_suite( + settings, + suite_id="local", + 
include_airllm=True, + providers=providers, + ) + + assert default_result.total == 1 + assert default_result.checks[0].provider_id == "transformers" + assert expanded_result.total == 2 + assert {check.provider_id for check in expanded_result.checks} == { + "transformers", + "airllm", + } From 8fefcafdabd55f6526af79132ca8a3f99f7fcbcd Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 20:25:48 +0200 Subject: [PATCH 09/16] feat: add smoke diagnostics to api and dashboard --- README.md | 9 +- apps/web/README.md | 1 + src/lai/api/app.py | 62 +++++++++++- src/lai/api/static/dashboard.css | 45 ++++++++- src/lai/api/static/dashboard.js | 156 ++++++++++++++++++++++++++++++- src/lai/api/static/index.html | 27 ++++++ src/lai/smoke.py | 25 ++++- tests/unit/test_api.py | 30 ++++++ 8 files changed, 345 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index be480f3..8d5520c 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,9 @@ Available endpoints: - `GET /dashboard` - `GET /health` - `GET /models` +- `GET /smoke/latest` +- `GET /smoke/results` +- `POST /smoke/run` - `GET /worker/status` - `POST /route/explain` - `POST /jobs` @@ -142,8 +145,8 @@ Available endpoints: The API now also serves a read-mostly dashboard at `/dashboard` with live model health, route explanation, recent job inspection, stage telemetry, artifact/trace browsing, -replay actions, persisted worker monitoring, and bounded queue worker controls including -an until-idle drain path. +replay actions, persisted worker monitoring, saved smoke diagnostics, and bounded queue +worker controls including an until-idle drain path. ## Dedicated worker service @@ -196,7 +199,7 @@ Smoke results are saved under `evals/results/smoke/` by default. ## Near-term priorities -1. Surface latest smoke diagnostics and run summaries through the API/dashboard. +1. Add deeper dashboard drill-down for saved smoke history and provider run artifacts. 2. Harden the AirLLM local runtime path with real workstation validation. 3. Expand eval scenarios and richer reviewer/final-output refinement. 4. Add richer live provider execution visibility and workstation runbooks for real large-model jobs. 
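The new smoke endpoints use the same request/response shapes as the saved CLI results. A minimal in-process sketch in the same `TestClient` style as `tests/unit/test_api.py`; wiring a bare `Settings()` is an assumption for illustration (the unit tests inject temp paths and fake providers instead):

```python
# Exercise the smoke routes in-process; readiness mode performs healthchecks
# only, so no live provider prompts are sent.
from fastapi.testclient import TestClient

from lai.api.app import create_api
from lai.settings import Settings

client = TestClient(create_api(settings=Settings()))

run = client.post(
    "/smoke/run",
    json={"suite_id": "providers", "live": False, "include_airllm": False, "save": True},
)
run.raise_for_status()
print(run.json()["result"]["passed"], run.json()["path"])

latest = client.get("/smoke/latest", params={"suite_id": "providers"})
print(latest.json()["path"])
```

Against a running server (`uvicorn lai.api.app:create_api --factory`) the same payloads apply over plain HTTP.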
diff --git a/apps/web/README.md b/apps/web/README.md index 9d506f7..1792b15 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -12,6 +12,7 @@ Current dashboard capabilities: - job replay controls for inline and queued reruns - persisted worker monitoring with heartbeat, current job, and queue depth - service-aware worker monitoring with daemon lock and stop-signal visibility +- saved smoke diagnostics for provider and local readiness - bounded live worker controls for processing queued jobs, including queue drain until idle - model health cards - job output inspector diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 2e9617f..64a68a6 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Mapping +from typing import Literal, Mapping from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse, RedirectResponse @@ -12,6 +12,13 @@ from ..domain import ExecutionRequest, QueueMode from ..providers import Provider from ..settings import Settings +from ..smoke import ( + latest_smoke_result, + list_smoke_results, + load_smoke_result, + run_smoke_suite, + save_smoke_result, +) class JobCreatePayload(BaseModel): @@ -37,6 +44,13 @@ class WorkerRunPayload(BaseModel): max_idle_cycles: int = Field(default=2, ge=1, le=20) +class SmokeRunPayload(BaseModel): + suite_id: Literal["providers", "local"] + live: bool = False + include_airllm: bool = False + save: bool = True + + def create_api( settings: Settings | None = None, providers: Mapping[str, Provider] | None = None, @@ -103,6 +117,52 @@ def worker_status() -> dict[str, object]: payload["log_file"] = str(app_state.settings.resolved_worker_service_log_path) return payload + @api.get("/smoke/latest") + def smoke_latest(suite_id: Literal["providers", "local"] | None = None) -> dict[str, object]: + app_state = application() + results_dir = app_state.settings.root_dir / "evals/results/smoke" + latest_path = latest_smoke_result(results_dir, suite_id=suite_id) + if latest_path is None: + raise HTTPException(status_code=404, detail="No smoke result found") + result = load_smoke_result(latest_path) + return {"result": result.model_dump(), "path": str(latest_path)} + + @api.get("/smoke/results") + def smoke_results( + limit: int = 5, + suite_id: Literal["providers", "local"] | None = None, + ) -> dict[str, object]: + app_state = application() + results_dir = app_state.settings.root_dir / "evals/results/smoke" + paths = list_smoke_results(results_dir, suite_id=suite_id, limit=limit) + return { + "results": [ + {"path": str(path), "result": load_smoke_result(path).model_dump()} + for path in paths + ] + } + + @api.post("/smoke/run") + def smoke_run(payload: SmokeRunPayload) -> dict[str, object]: + app_state = application() + result = run_smoke_suite( + app_state.settings, + suite_id=payload.suite_id, + live=payload.live, + include_airllm=payload.include_airllm, + providers=providers, + ) + saved_path = None + if payload.save: + saved_path = save_smoke_result( + app_state.settings.root_dir / "evals/results/smoke", + result, + ) + return { + "result": result.model_dump(), + "path": str(saved_path) if saved_path is not None else None, + } + @api.post("/route/explain") def explain_route(payload: JobCreatePayload) -> dict[str, object]: app_state = application() diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index d80c4f4..f371252 100644 --- a/src/lai/api/static/dashboard.css +++ 
b/src/lai/api/static/dashboard.css @@ -208,7 +208,8 @@ body { .route-output, .jobs-panel, .detail-panel, -.models-panel { +.models-panel, +.smoke-panel { padding: 1.35rem; } @@ -341,7 +342,9 @@ textarea { .reason-item, .job-card, .model-card, -.timeline-event { +.timeline-event, +.smoke-card, +.smoke-check { border: 1px solid rgba(33, 63, 100, 0.11); background: rgba(255, 255, 255, 0.64); border-radius: 1rem; @@ -374,6 +377,12 @@ textarea { gap: 0.75rem; } +.smoke-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); + gap: 0.85rem; +} + .job-card { padding: 0.95rem 1rem; cursor: pointer; @@ -488,6 +497,34 @@ textarea { color: var(--ink); } +.smoke-card { + padding: 1rem; + display: grid; + gap: 0.8rem; +} + +.smoke-check-list { + display: grid; + gap: 0.6rem; +} + +.smoke-check { + padding: 0.8rem 0.9rem; + display: grid; + gap: 0.55rem; +} + +.smoke-check-header { + display: flex; + justify-content: space-between; + gap: 0.75rem; + align-items: start; +} + +.smoke-check-message { + color: var(--ink); +} + .models-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); @@ -543,4 +580,8 @@ textarea { .timeline-event-head { flex-direction: column; } + + .smoke-check-header { + flex-direction: column; + } } diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index aaf9800..f129955 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -2,6 +2,10 @@ const state = { jobs: [], models: [], worker: null, + smoke: { + providers: null, + local: null, + }, selectedJobId: null, selectedArtifactId: null, }; @@ -25,6 +29,12 @@ const elements = { drainQueue: document.getElementById("drain-queue"), queueActionStatus: document.getElementById("queue-action-status"), workerSummary: document.getElementById("worker-summary"), + smokeGrid: document.getElementById("smoke-grid"), + smokeActionStatus: document.getElementById("smoke-action-status"), + runProviderSmoke: document.getElementById("run-provider-smoke"), + runLocalSmoke: document.getElementById("run-local-smoke"), + runLocalAirllmSmoke: document.getElementById("run-local-airllm-smoke"), + refreshSmoke: document.getElementById("refresh-smoke"), }; async function fetchJson(url, options = undefined) { @@ -36,6 +46,18 @@ async function fetchJson(url, options = undefined) { return response.json(); } +async function fetchOptionalJson(url, options = undefined) { + const response = await fetch(url, options); + if (response.status === 404) { + return null; + } + if (!response.ok) { + const body = await response.text(); + throw new Error(body || `Request failed with ${response.status}`); + } + return response.json(); +} + function payloadFromForm() { return { user_prompt: document.getElementById("user-prompt").value.trim(), @@ -65,6 +87,15 @@ function setQueueActionStatus(message, tone = "") { elements.queueActionStatus.innerHTML = `${toneChip}${escapeHtml(message)}`; } +function setSmokeActionStatus(message, tone = "") { + if (!elements.smokeActionStatus) { + return; + } + const toneChip = tone ? 
`${chip("smoke", tone, tone)} ` : ""; + elements.smokeActionStatus.className = "inline-status fade-in"; + elements.smokeActionStatus.innerHTML = `${toneChip}${escapeHtml(message)}`; +} + function escapeHtml(value) { return value .replaceAll("&", "&") @@ -75,22 +106,27 @@ function escapeHtml(value) { } async function loadOverview() { - const [health, modelsResponse, jobsResponse, workerResponse] = await Promise.all([ + const [health, modelsResponse, jobsResponse, workerResponse, providerSmoke, localSmoke] = await Promise.all([ fetchJson("/health"), fetchJson("/models"), fetchJson("/jobs?limit=12"), fetchJson("/worker/status"), + fetchOptionalJson("/smoke/latest?suite_id=providers"), + fetchOptionalJson("/smoke/latest?suite_id=local"), ]); state.models = modelsResponse.models; state.jobs = jobsResponse.jobs; state.worker = workerResponse; + state.smoke.providers = providerSmoke?.result || null; + state.smoke.local = localSmoke?.result || null; elements.environment.textContent = health.environment; elements.modelCount.textContent = String(health.model_count); elements.jobCount.textContent = String(state.jobs.length); elements.queueState.textContent = summarizeQueue(state.jobs); renderWorkerMonitor(); + renderSmokeDiagnostics(); renderModels(); renderJobs(); @@ -102,6 +138,88 @@ async function loadOverview() { } } +function renderSmokeDiagnostics() { + if (!elements.smokeGrid) { + return; + } + + const suites = [ + { key: "providers", label: "Providers", result: state.smoke.providers }, + { key: "local", label: "Local", result: state.smoke.local }, + ]; + elements.smokeGrid.className = "smoke-grid fade-in"; + elements.smokeGrid.innerHTML = suites + .map((suite) => renderSmokeCard(suite.label, suite.result)) + .join(""); +} + +function renderSmokeCard(label, result) { + if (!result) { + return ` +
+
+

${escapeHtml(label)}

+

No saved smoke result

+
+
Run a readiness sweep to capture the latest diagnostics for ${escapeHtml(label.toLowerCase())}.
+
+ `; + } + + const summaryTone = result.passed ? "ready" : result.failed_count > 0 ? "danger" : "warn"; + const executedAt = new Date(result.executed_at).toLocaleString(); + const checks = (result.checks || []) + .map((check) => renderSmokeCheck(check)) + .join(""); + + return ` +
+
+

${escapeHtml(label)}

+

${escapeHtml(result.mode)} summary

+
+
+ ${chip("status", result.passed ? "passing" : "attention", summaryTone)} + ${chip("checks", result.total)} + ${chip("success", result.success_count, result.success_count > 0 ? "ready" : "")} + ${chip("blocked", result.blocked_count, result.blocked_count > 0 ? "warn" : "")} + ${chip("failed", result.failed_count, result.failed_count > 0 ? "danger" : "")} +
+
executed: ${escapeHtml(executedAt)} | prompt: ${escapeHtml(result.prompt)}
+
${checks}
+
+ `; +} + +function renderSmokeCheck(check) { + const tone = + check.status === "failed" + ? "danger" + : check.status === "blocked" + ? "warn" + : "ready"; + const detailChips = [ + chip("status", check.status, tone), + chip("mode", check.mode), + ]; + if (check.duration_seconds !== null && check.duration_seconds !== undefined) { + detailChips.push(chip("duration", `${check.duration_seconds.toFixed(2)}s`)); + } + + return ` +
+
+
+ ${escapeHtml(check.provider_id)} +
${escapeHtml(check.model_ref)}
+
+
${detailChips.join("")}
+
+
${escapeHtml(check.message)}
+
+ `; +} + function renderWorkerMonitor() { if (!state.worker) { elements.workerState.textContent = "unknown"; @@ -553,6 +671,26 @@ async function runWorker(maxJobs, untilIdle = false) { } } +async function runSmokeSuite(suiteId, includeAirllm = false) { + const response = await fetchJson("/smoke/run", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + suite_id: suiteId, + live: false, + include_airllm: includeAirllm, + save: true, + }), + }); + state.smoke[suiteId] = response.result; + renderSmokeDiagnostics(); + const label = suiteId === "providers" ? "provider" : includeAirllm ? "local + AirLLM" : "local"; + setSmokeActionStatus( + `Saved ${label} readiness diagnostics at ${response.path || "the smoke results directory"}.`, + response.result.passed ? "ready" : "warn", + ); +} + function bindEvents() { elements.routeForm.addEventListener("submit", (event) => { void explainRoute(event).catch(handleError); @@ -572,6 +710,18 @@ function bindEvents() { elements.drainQueue.addEventListener("click", () => { void runWorker(null, true).catch(handleQueueError); }); + elements.runProviderSmoke.addEventListener("click", () => { + void runSmokeSuite("providers").catch(handleSmokeError); + }); + elements.runLocalSmoke.addEventListener("click", () => { + void runSmokeSuite("local").catch(handleSmokeError); + }); + elements.runLocalAirllmSmoke.addEventListener("click", () => { + void runSmokeSuite("local", true).catch(handleSmokeError); + }); + elements.refreshSmoke.addEventListener("click", () => { + void loadOverview().catch(handleError); + }); } function handleError(error) { @@ -583,6 +733,10 @@ function handleQueueError(error) { setQueueActionStatus(error.message || String(error), "danger"); } +function handleSmokeError(error) { + setSmokeActionStatus(error.message || String(error), "danger"); +} + async function boot() { bindEvents(); await loadOverview(); diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html index 0199d98..8beb127 100644 --- a/src/lai/api/static/index.html +++ b/src/lai/api/static/index.html @@ -148,6 +148,33 @@

Model health

Loading model health...
+ +
+
+
+

Readiness

+

Smoke diagnostics

+
+
+ + + + +
+
+
+ Run a readiness sweep to save provider diagnostics without sending live prompts. +
+
Loading smoke diagnostics...
+
diff --git a/src/lai/smoke.py b/src/lai/smoke.py index 58bb5e4..d0a6525 100644 --- a/src/lai/smoke.py +++ b/src/lai/smoke.py @@ -198,10 +198,29 @@ def save_smoke_result(results_dir: Path, result: SmokeSuiteResult) -> Path: return path -def latest_smoke_result(results_dir: Path) -> Path | None: +def load_smoke_result(path: Path) -> SmokeSuiteResult: + return SmokeSuiteResult.model_validate_json(path.read_text(encoding="utf-8")) + + +def list_smoke_results( + results_dir: Path, + *, + suite_id: Literal["providers", "local"] | None = None, + limit: int = 10, +) -> list[Path]: if not results_dir.exists(): - return None - candidates = sorted(results_dir.glob("smoke-*.json"), reverse=True) + return [] + pattern = f"smoke-{suite_id}-*.json" if suite_id else "smoke-*.json" + candidates = sorted(results_dir.glob(pattern), reverse=True) + return candidates[:limit] + + +def latest_smoke_result( + results_dir: Path, + *, + suite_id: Literal["providers", "local"] | None = None, +) -> Path | None: + candidates = list_smoke_results(results_dir, suite_id=suite_id, limit=1) return candidates[0] if candidates else None diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 00daa45..533f6b9 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -42,6 +42,9 @@ def test_api_exposes_expected_routes() -> None: assert "/dashboard" in routes assert "/health" in routes assert "/models" in routes + assert "/smoke/latest" in routes + assert "/smoke/results" in routes + assert "/smoke/run" in routes assert "/worker/status" in routes assert "/route/explain" in routes assert "/jobs" in routes @@ -66,6 +69,7 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert root_response.headers["location"] == "/dashboard" assert dashboard_response.status_code == 200 assert "LAI CONTROL ROOM" in dashboard_response.text + assert "Smoke diagnostics" in dashboard_response.text assert asset_response.status_code == 200 assert "async function loadOverview" in asset_response.text @@ -121,6 +125,32 @@ def test_job_timeline_route_returns_stage_events(repo_root, tmp_path) -> None: ) +def test_smoke_routes_run_and_return_latest_result(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + client = TestClient(create_api(settings=settings, providers=providers)) + + run_response = client.post( + "/smoke/run", + json={"suite_id": "providers", "save": True}, + ) + + assert run_response.status_code == 200 + payload = run_response.json() + assert payload["result"]["suite_id"] == "providers" + assert payload["result"]["mode"] == "readiness" + assert payload["path"] + + latest_response = client.get("/smoke/latest?suite_id=providers") + assert latest_response.status_code == 200 + latest_payload = latest_response.json() + assert latest_payload["result"]["suite_id"] == "providers" + + history_response = client.get("/smoke/results?limit=3") + assert history_response.status_code == 200 + assert history_response.json()["results"] + + def test_worker_status_and_until_idle_routes(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) providers = _test_providers(settings) From 4a1c3679f42e89d82b926304e528b9ec4ccfe2f3 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 20:34:15 +0200 Subject: [PATCH 10/16] feat: add smoke history drilldown --- README.md | 14 ++- apps/web/README.md | 3 +- src/lai/api/app.py | 29 +++-- src/lai/api/static/dashboard.css | 72 ++++++++++++ src/lai/api/static/dashboard.js | 
182 ++++++++++++++++++++++++++++--- src/lai/smoke.py | 18 +++ tests/unit/test_api.py | 14 ++- tests/unit/test_smoke.py | 37 ++++++- 8 files changed, 334 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 8d5520c..d0b80d2 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ Available endpoints: - `GET /models` - `GET /smoke/latest` - `GET /smoke/results` +- `GET /smoke/results/{result_id}` - `POST /smoke/run` - `GET /worker/status` - `POST /route/explain` @@ -145,8 +146,9 @@ Available endpoints: The API now also serves a read-mostly dashboard at `/dashboard` with live model health, route explanation, recent job inspection, stage telemetry, artifact/trace browsing, -replay actions, persisted worker monitoring, saved smoke diagnostics, and bounded queue -worker controls including an until-idle drain path. +replay actions, persisted worker monitoring, saved smoke diagnostics with history drill-down, +per-check metadata/output previews, and bounded queue worker controls including an +until-idle drain path. ## Dedicated worker service @@ -199,10 +201,10 @@ Smoke results are saved under `evals/results/smoke/` by default. ## Near-term priorities -1. Add deeper dashboard drill-down for saved smoke history and provider run artifacts. -2. Harden the AirLLM local runtime path with real workstation validation. -3. Expand eval scenarios and richer reviewer/final-output refinement. -4. Add richer live provider execution visibility and workstation runbooks for real large-model jobs. +1. Harden the AirLLM local runtime path with real workstation validation. +2. Expand eval scenarios and richer reviewer/final-output refinement. +3. Add richer live provider execution visibility and workstation runbooks for real large-model jobs. +4. Add operator-facing controls for smoke result retention and cleanup over time. 
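The smoke-history endpoints listed above are easy to exercise end to end. A minimal client sketch, assuming the API is served locally on uvicorn's default port and using `requests` (illustrative only, not a dependency this patch series declares):

```python
# Minimal sketch: drive the new smoke-history surface over HTTP.
# Assumes a local API at http://127.0.0.1:8000; `requests` is illustrative.
import requests

BASE = "http://127.0.0.1:8000"

# Trigger a provider readiness run and persist the result.
run = requests.post(f"{BASE}/smoke/run", json={"suite_id": "providers", "save": True})
run.raise_for_status()
saved_id = run.json()["id"]  # file name of the saved result (present because save=True)

# List recent saved runs for the same suite.
history = requests.get(f"{BASE}/smoke/results", params={"suite_id": "providers", "limit": 6})
for entry in history.json()["results"]:
    result = entry["result"]
    print(entry["id"], result["passed"], result["failed_count"])

# Drill into one specific saved result by its id.
detail = requests.get(f"{BASE}/smoke/results/{saved_id}")
print(detail.json()["path"])
```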
## References diff --git a/apps/web/README.md b/apps/web/README.md index 1792b15..8d580ba 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -12,7 +12,8 @@ Current dashboard capabilities: - job replay controls for inline and queued reruns - persisted worker monitoring with heartbeat, current job, and queue depth - service-aware worker monitoring with daemon lock and stop-signal visibility -- saved smoke diagnostics for provider and local readiness +- saved smoke diagnostics for provider and local readiness, with recent-run history browsing +- per-check smoke metadata and live-output preview inspection from saved results - bounded live worker controls for processing queued jobs, including queue drain until idle - model health cards - job output inspector diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 64a68a6..25f4251 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -16,6 +16,7 @@ latest_smoke_result, list_smoke_results, load_smoke_result, + resolve_smoke_result_path, run_smoke_suite, save_smoke_result, ) @@ -124,8 +125,7 @@ def smoke_latest(suite_id: Literal["providers", "local"] | None = None) -> dict[ latest_path = latest_smoke_result(results_dir, suite_id=suite_id) if latest_path is None: raise HTTPException(status_code=404, detail="No smoke result found") - result = load_smoke_result(latest_path) - return {"result": result.model_dump(), "path": str(latest_path)} + return _serialize_smoke_result(latest_path) @api.get("/smoke/results") def smoke_results( @@ -135,12 +135,16 @@ def smoke_results( app_state = application() results_dir = app_state.settings.root_dir / "evals/results/smoke" paths = list_smoke_results(results_dir, suite_id=suite_id, limit=limit) - return { - "results": [ - {"path": str(path), "result": load_smoke_result(path).model_dump()} - for path in paths - ] - } + return {"results": [_serialize_smoke_result(path) for path in paths]} + + @api.get("/smoke/results/{result_id}") + def smoke_result_detail(result_id: str) -> dict[str, object]: + app_state = application() + results_dir = app_state.settings.root_dir / "evals/results/smoke" + result_path = resolve_smoke_result_path(results_dir, result_id) + if result_path is None: + raise HTTPException(status_code=404, detail="Smoke result not found") + return _serialize_smoke_result(result_path) @api.post("/smoke/run") def smoke_run(payload: SmokeRunPayload) -> dict[str, object]: @@ -159,6 +163,7 @@ def smoke_run(payload: SmokeRunPayload) -> dict[str, object]: result, ) return { + "id": saved_path.name if saved_path is not None else None, "result": result.model_dump(), "path": str(saved_path) if saved_path is not None else None, } @@ -297,3 +302,11 @@ def _build_execution_request(application, payload: JobCreatePayload) -> Executio else application.settings.default_timeout_seconds ), ) + + +def _serialize_smoke_result(path: Path) -> dict[str, object]: + return { + "id": path.name, + "path": str(path), + "result": load_smoke_result(path).model_dump(), + } diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index f371252..f4bbfaa 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -508,6 +508,44 @@ textarea { gap: 0.6rem; } +.smoke-history { + display: grid; + gap: 0.65rem; +} + +.smoke-history-list { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); + gap: 0.6rem; +} + +.smoke-history-entry { + border: 1px solid rgba(33, 63, 100, 0.12); + background: rgba(255, 255, 255, 0.74); + color: var(--ink); + 
border-radius: 1rem; + padding: 0.78rem 0.85rem; + text-align: left; + font: inherit; + cursor: pointer; + transition: transform 160ms ease, border-color 160ms ease, background 160ms ease; +} + +.smoke-history-entry:hover, +.smoke-history-entry.active { + transform: translateY(-1px); + border-color: rgba(13, 107, 97, 0.28); + background: rgba(214, 239, 232, 0.62); +} + +.smoke-history-title { + font-weight: 700; +} + +.smoke-history-meta { + margin-top: 0.55rem; +} + .smoke-check { padding: 0.8rem 0.9rem; display: grid; @@ -525,6 +563,40 @@ textarea { color: var(--ink); } +.smoke-detail-line { + overflow-wrap: anywhere; +} + +.smoke-check-reasons { + margin: 0; + padding-left: 1.2rem; + color: var(--muted); +} + +.smoke-output-preview { + margin: 0; + max-height: 14rem; + overflow: auto; +} + +.smoke-meta-details { + display: grid; + gap: 0.55rem; +} + +.smoke-meta-details summary { + cursor: pointer; + color: var(--accent-deep); + font-family: "IBM Plex Mono", monospace; + font-size: 0.76rem; + letter-spacing: 0.05em; + text-transform: uppercase; +} + +.smoke-meta-preview { + font-size: 0.78rem; +} + .models-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index f129955..169a3df 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -3,8 +3,8 @@ const state = { models: [], worker: null, smoke: { - providers: null, - local: null, + providers: createSmokeSuiteState(), + local: createSmokeSuiteState(), }, selectedJobId: null, selectedArtifactId: null, @@ -58,6 +58,15 @@ async function fetchOptionalJson(url, options = undefined) { return response.json(); } +function createSmokeSuiteState() { + return { + latest: null, + history: [], + selectedResultId: null, + selected: null, + }; +} + function payloadFromForm() { return { user_prompt: document.getElementById("user-prompt").value.trim(), @@ -106,27 +115,23 @@ function escapeHtml(value) { } async function loadOverview() { - const [health, modelsResponse, jobsResponse, workerResponse, providerSmoke, localSmoke] = await Promise.all([ + const [health, modelsResponse, jobsResponse, workerResponse] = await Promise.all([ fetchJson("/health"), fetchJson("/models"), fetchJson("/jobs?limit=12"), fetchJson("/worker/status"), - fetchOptionalJson("/smoke/latest?suite_id=providers"), - fetchOptionalJson("/smoke/latest?suite_id=local"), ]); state.models = modelsResponse.models; state.jobs = jobsResponse.jobs; state.worker = workerResponse; - state.smoke.providers = providerSmoke?.result || null; - state.smoke.local = localSmoke?.result || null; elements.environment.textContent = health.environment; elements.modelCount.textContent = String(health.model_count); elements.jobCount.textContent = String(state.jobs.length); elements.queueState.textContent = summarizeQueue(state.jobs); renderWorkerMonitor(); - renderSmokeDiagnostics(); + await loadSmokeDiagnostics(); renderModels(); renderJobs(); @@ -138,23 +143,71 @@ async function loadOverview() { } } +async function loadSmokeDiagnostics() { + const [providerLatest, localLatest, providerHistory, localHistory] = await Promise.all([ + fetchOptionalJson("/smoke/latest?suite_id=providers"), + fetchOptionalJson("/smoke/latest?suite_id=local"), + fetchJson("/smoke/results?suite_id=providers&limit=6"), + fetchJson("/smoke/results?suite_id=local&limit=6"), + ]); + + updateSmokeSuiteState("providers", providerLatest, providerHistory?.results || []); + 
updateSmokeSuiteState("local", localLatest, localHistory?.results || []); + renderSmokeDiagnostics(); +} + +function updateSmokeSuiteState(suiteId, latestEntry, historyEntries) { + const suiteState = state.smoke[suiteId]; + const mergedHistory = [...historyEntries]; + if (latestEntry && !mergedHistory.some((entry) => entry.id === latestEntry.id)) { + mergedHistory.unshift(latestEntry); + } + if ( + suiteState.selected && + !mergedHistory.some((entry) => entry.id === suiteState.selected.id) + ) { + mergedHistory.unshift(suiteState.selected); + } + + suiteState.latest = latestEntry || null; + suiteState.history = mergedHistory; + + const selectedEntry = + mergedHistory.find((entry) => entry.id === suiteState.selectedResultId) || + latestEntry || + mergedHistory[0] || + null; + suiteState.selectedResultId = selectedEntry?.id || null; + suiteState.selected = selectedEntry; +} + function renderSmokeDiagnostics() { if (!elements.smokeGrid) { return; } const suites = [ - { key: "providers", label: "Providers", result: state.smoke.providers }, - { key: "local", label: "Local", result: state.smoke.local }, + { key: "providers", label: "Providers", suite: state.smoke.providers }, + { key: "local", label: "Local", suite: state.smoke.local }, ]; elements.smokeGrid.className = "smoke-grid fade-in"; elements.smokeGrid.innerHTML = suites - .map((suite) => renderSmokeCard(suite.label, suite.result)) + .map((suite) => renderSmokeCard(suite.key, suite.label, suite.suite)) .join(""); + + elements.smokeGrid.querySelectorAll("[data-smoke-suite][data-smoke-result-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectSmokeResult( + button.getAttribute("data-smoke-suite"), + button.getAttribute("data-smoke-result-id"), + ).catch(handleSmokeError); + }); + }); } -function renderSmokeCard(label, result) { - if (!result) { +function renderSmokeCard(suiteId, label, suiteState) { + const selectedEntry = suiteState.selected || suiteState.latest || suiteState.history[0] || null; + if (!selectedEntry) { return `
@@ -166,8 +219,12 @@ function renderSmokeCard(label, result) { `; } + const result = selectedEntry.result; const summaryTone = result.passed ? "ready" : result.failed_count > 0 ? "danger" : "warn"; const executedAt = new Date(result.executed_at).toLocaleString(); + const history = suiteState.history + .map((entry) => renderSmokeHistoryEntry(suiteId, entry, suiteState.selectedResultId)) + .join(""); const checks = (result.checks || []) .map((check) => renderSmokeCheck(check)) .join(""); @@ -185,12 +242,39 @@ function renderSmokeCard(label, result) { ${chip("blocked", result.blocked_count, result.blocked_count > 0 ? "warn" : "")} ${chip("failed", result.failed_count, result.failed_count > 0 ? "danger" : "")}
-        <div>executed: ${escapeHtml(executedAt)} | prompt: ${escapeHtml(result.prompt)}</div>
+        <div>executed: ${escapeHtml(executedAt)} | saved: ${escapeHtml(selectedEntry.id)}</div>
+        <div class="smoke-detail-line">path: ${escapeHtml(selectedEntry.path)}</div>
+        <div>prompt: ${escapeHtml(result.prompt)}</div>
+        <div class="smoke-history">
+          <div class="smoke-history-title">Recent saved runs</div>
+          <div class="smoke-history-list">${history}</div>
+        </div>
       ${checks}
`; } +function renderSmokeHistoryEntry(suiteId, entry, selectedResultId) { + const result = entry.result; + const tone = result.passed ? "ready" : result.failed_count > 0 ? "danger" : "warn"; + const activeClass = entry.id === selectedResultId ? " active" : ""; + return ` + + `; +} + function renderSmokeCheck(check) { const tone = check.status === "failed" @@ -201,10 +285,32 @@ function renderSmokeCheck(check) { const detailChips = [ chip("status", check.status, tone), chip("mode", check.mode), + chip("runtime", check.runtime), + chip("available", check.available ? "yes" : "no", check.available ? "ready" : "warn"), + chip("healthy", check.healthy ? "yes" : "no", check.healthy ? "ready" : "warn"), ]; if (check.duration_seconds !== null && check.duration_seconds !== undefined) { detailChips.push(chip("duration", `${check.duration_seconds.toFixed(2)}s`)); } + if (check.executed) { + detailChips.push(chip("executed", check.success ? "yes" : "failed", check.success ? "ready" : "danger")); + } + + const reasons = (check.reasons || []).length + ? ` +
+        <ul class="smoke-check-reasons">
+          ${check.reasons.map((reason) => `<li>${escapeHtml(reason)}</li>`).join("")}
+        </ul>
+      `
+    : "";
+  const preview = check.output_preview
+    ? `
+        <details class="smoke-meta-details">
+          <summary>Output preview</summary>
+          <pre class="smoke-output-preview">${escapeHtml(check.output_preview)}</pre>
+        </details>
+      `
+    : "";
 
   return `
@@ -216,10 +322,50 @@ function renderSmokeCheck(check) {
       <div>${detailChips.join("")}</div>
       <div>${escapeHtml(check.message)}</div>
+      ${reasons}
+      ${preview}
+      ${renderSmokeSnapshot("Metadata snapshot", check.metadata)}
+      ${renderSmokeSnapshot("Capabilities snapshot", check.capabilities)}
     </div>
   `;
 }
 
+function renderSmokeSnapshot(label, payload) {
+  if (!payload || Object.keys(payload).length === 0) {
+    return "";
+  }
+  return `
+    <details class="smoke-meta-details">
+      <summary>${escapeHtml(label)}</summary>
+      <pre class="smoke-meta-preview">${escapeHtml(
+        JSON.stringify(payload, null, 2),
+      )}</pre>
+    </details>
+ `; +} + +async function selectSmokeResult(suiteId, resultId) { + if (!suiteId || !resultId || !state.smoke[suiteId]) { + return; + } + + const suiteState = state.smoke[suiteId]; + const knownEntry = + suiteState.history.find((entry) => entry.id === resultId) || + (suiteState.latest?.id === resultId ? suiteState.latest : null) || + (suiteState.selected?.id === resultId ? suiteState.selected : null); + const entry = + knownEntry || + (await fetchJson(`/smoke/results/${encodeURIComponent(resultId)}`)); + + if (!suiteState.history.some((item) => item.id === entry.id)) { + suiteState.history.unshift(entry); + } + suiteState.selectedResultId = entry.id; + suiteState.selected = entry; + renderSmokeDiagnostics(); +} + function renderWorkerMonitor() { if (!state.worker) { elements.workerState.textContent = "unknown"; @@ -682,9 +828,11 @@ async function runSmokeSuite(suiteId, includeAirllm = false) { save: true, }), }); - state.smoke[suiteId] = response.result; - renderSmokeDiagnostics(); const label = suiteId === "providers" ? "provider" : includeAirllm ? "local + AirLLM" : "local"; + await loadSmokeDiagnostics(); + if (response.id) { + await selectSmokeResult(suiteId, response.id); + } setSmokeActionStatus( `Saved ${label} readiness diagnostics at ${response.path || "the smoke results directory"}.`, response.result.passed ? "ready" : "warn", @@ -720,7 +868,7 @@ function bindEvents() { void runSmokeSuite("local", true).catch(handleSmokeError); }); elements.refreshSmoke.addEventListener("click", () => { - void loadOverview().catch(handleError); + void loadSmokeDiagnostics().catch(handleSmokeError); }); } diff --git a/src/lai/smoke.py b/src/lai/smoke.py index d0a6525..cb61590 100644 --- a/src/lai/smoke.py +++ b/src/lai/smoke.py @@ -202,6 +202,24 @@ def load_smoke_result(path: Path) -> SmokeSuiteResult: return SmokeSuiteResult.model_validate_json(path.read_text(encoding="utf-8")) +def resolve_smoke_result_path(results_dir: Path, result_id: str) -> Path | None: + candidate_name = Path(result_id).name + if candidate_name != result_id or not candidate_name.endswith(".json"): + return None + + candidate = (results_dir / candidate_name).resolve() + try: + candidate.relative_to(results_dir.resolve()) + except ValueError: + return None + + if not candidate.exists() or not candidate.is_file(): + return None + if not candidate.name.startswith("smoke-"): + return None + return candidate + + def list_smoke_results( results_dir: Path, *, diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 533f6b9..f80a4df 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -44,6 +44,7 @@ def test_api_exposes_expected_routes() -> None: assert "/models" in routes assert "/smoke/latest" in routes assert "/smoke/results" in routes + assert "/smoke/results/{result_id}" in routes assert "/smoke/run" in routes assert "/worker/status" in routes assert "/route/explain" in routes @@ -71,7 +72,7 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert "LAI CONTROL ROOM" in dashboard_response.text assert "Smoke diagnostics" in dashboard_response.text assert asset_response.status_code == 200 - assert "async function loadOverview" in asset_response.text + assert "async function loadSmokeDiagnostics" in asset_response.text def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> None: @@ -139,16 +140,25 @@ def test_smoke_routes_run_and_return_latest_result(repo_root, tmp_path) -> None: payload = run_response.json() assert payload["result"]["suite_id"] == 
"providers" assert payload["result"]["mode"] == "readiness" + assert payload["id"] assert payload["path"] latest_response = client.get("/smoke/latest?suite_id=providers") assert latest_response.status_code == 200 latest_payload = latest_response.json() assert latest_payload["result"]["suite_id"] == "providers" + assert latest_payload["id"] history_response = client.get("/smoke/results?limit=3") assert history_response.status_code == 200 - assert history_response.json()["results"] + history_payload = history_response.json() + assert history_payload["results"] + detail_response = client.get(f"/smoke/results/{payload['id']}") + assert detail_response.status_code == 200 + assert detail_response.json()["id"] == payload["id"] + + missing_response = client.get("/smoke/results/not-a-smoke-result.json") + assert missing_response.status_code == 404 def test_worker_status_and_until_idle_routes(repo_root, tmp_path) -> None: diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py index d72452d..91bfbe4 100644 --- a/tests/unit/test_smoke.py +++ b/tests/unit/test_smoke.py @@ -1,5 +1,11 @@ from lai.settings import Settings -from lai.smoke import latest_smoke_result, run_smoke_suite, save_smoke_result +from lai.smoke import ( + latest_smoke_result, + list_smoke_results, + resolve_smoke_result_path, + run_smoke_suite, + save_smoke_result, +) from lai.system import collect_system_snapshot from tests.helpers import FakeProvider @@ -62,6 +68,35 @@ def test_provider_smoke_live_executes_and_saves_result(repo_root, tmp_path) -> N assert latest_smoke_result(tmp_path / "evals/results/smoke") == path +def test_smoke_result_listing_and_resolution(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + results_dir = tmp_path / "evals/results/smoke" + + provider_path = save_smoke_result( + results_dir, + run_smoke_suite( + settings, + suite_id="providers", + providers=_test_providers(settings), + ), + ) + local_path = save_smoke_result( + results_dir, + run_smoke_suite( + settings, + suite_id="local", + providers=_test_providers(settings), + ), + ) + + listed = list_smoke_results(results_dir, suite_id="providers", limit=5) + + assert provider_path in listed + assert local_path not in listed + assert resolve_smoke_result_path(results_dir, provider_path.name) == provider_path + assert resolve_smoke_result_path(results_dir, "..\\outside.json") is None + + def test_local_smoke_excludes_airllm_by_default_and_can_include_it(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) providers = _test_providers(settings) From 12697a5d10845af2f192065e411a4597825108a1 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 20:47:10 +0200 Subject: [PATCH 11/16] feat: add workstation readiness validation --- README.md | 33 +- apps/web/README.md | 1 + docs/setup/airllm-runbook.md | 99 +++++ docs/setup/workstation.md | 20 + src/lai/api/app.py | 11 + src/lai/api/static/dashboard.css | 6 + src/lai/api/static/dashboard.js | 124 +++++- src/lai/api/static/index.html | 16 + src/lai/cli.py | 107 ++++- src/lai/system.py | 108 ++++- src/lai/workstation.py | 677 +++++++++++++++++++++++++++++++ tests/unit/test_api.py | 18 +- tests/unit/test_workstation.py | 105 +++++ 13 files changed, 1305 insertions(+), 20 deletions(-) create mode 100644 docs/setup/airllm-runbook.md create mode 100644 src/lai/workstation.py create mode 100644 tests/unit/test_workstation.py diff --git a/README.md b/README.md index d0b80d2..1de5ffc 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,8 @@ python 
-m pip install -e .[local,providers] python -m lai.cli doctor python -m lai.cli models list python -m lai.cli models check +python -m lai.cli workstation validate +python -m lai.cli workstation validate --profile airllm python -m lai.cli route explain "Summarize this note." python -m lai.cli run "Create a detailed implementation strategy." python -m lai.cli jobs list @@ -128,6 +130,7 @@ Available endpoints: - `GET /dashboard` - `GET /health` - `GET /models` +- `GET /workstation/readiness` - `GET /smoke/latest` - `GET /smoke/results` - `GET /smoke/results/{result_id}` @@ -145,10 +148,10 @@ Available endpoints: - `POST /worker/run` The API now also serves a read-mostly dashboard at `/dashboard` with live model health, -route explanation, recent job inspection, stage telemetry, artifact/trace browsing, -replay actions, persisted worker monitoring, saved smoke diagnostics with history drill-down, -per-check metadata/output previews, and bounded queue worker controls including an -until-idle drain path. +workstation readiness for heavy local execution, route explanation, recent job inspection, +stage telemetry, artifact/trace browsing, replay actions, persisted worker monitoring, +saved smoke diagnostics with history drill-down, per-check metadata/output previews, +and bounded queue worker controls including an until-idle drain path. ## Dedicated worker service @@ -191,6 +194,20 @@ python -m lai.cli smoke local --live --include-airllm Smoke results are saved under `evals/results/smoke/` by default. +## Workstation readiness + +Validate the local machine before you depend on the heavy local path: + +```powershell +python -m lai.cli workstation validate +python -m lai.cli workstation validate --profile airllm +``` + +This surfaces Python compatibility, local package readiness, Hugging Face credentials, +disk headroom, RAM posture, GPU availability, and AirLLM-specific readiness in one report. +For step-by-step remediation and overnight run guidance, see +`docs/setup/airllm-runbook.md`. + ## Initial GitHub rules encoded in this repo - Pull request template and issue forms for consistent planning. @@ -201,10 +218,10 @@ Smoke results are saved under `evals/results/smoke/` by default. ## Near-term priorities -1. Harden the AirLLM local runtime path with real workstation validation. -2. Expand eval scenarios and richer reviewer/final-output refinement. -3. Add richer live provider execution visibility and workstation runbooks for real large-model jobs. -4. Add operator-facing controls for smoke result retention and cleanup over time. +1. Expand eval scenarios and richer reviewer/final-output refinement. +2. Add richer live provider execution visibility for real provider calls. +3. Add operator-facing controls for smoke result retention and cleanup over time. +4. Add richer heavy-job run telemetry and recovery tooling for interrupted overnight executions. ## References diff --git a/apps/web/README.md b/apps/web/README.md index 8d580ba..d8e3deb 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -5,6 +5,7 @@ The first dashboard is now served directly by the FastAPI app at `/dashboard`. 
 Current dashboard capabilities:
 
 - live health and catalog summary
+- workstation readiness summary for heavy local and AirLLM execution paths
 - route explanation form
 - job submission and recent queue inspection
 - persisted stage telemetry timeline for planner, executor, and reviewer flow
diff --git a/docs/setup/airllm-runbook.md b/docs/setup/airllm-runbook.md
new file mode 100644
index 0000000..00fbc13
--- /dev/null
+++ b/docs/setup/airllm-runbook.md
@@ -0,0 +1,99 @@
+# AirLLM Heavy-Local Runbook
+
+## Purpose
+
+Use this runbook before starting the `execution-large` AirLLM path or any similar heavy
+local executor that may run overnight.
+
+## Pre-flight checklist
+
+Run the workstation validator first:
+
+```powershell
+python -m lai.cli workstation validate --profile airllm
+```
+
+Look for failures in these areas:
+
+- `AirLLM dependency stack`
+- `Hugging Face credentials`
+- `GPU availability`
+- `GPU VRAM budget`
+- `AirLLM shard storage`
+- `Hugging Face cache staging space`
+
+If the validator reports only warnings, the heavy path is still usable, but you should expect
+slower runtime or more disk pressure.
+
+## First-time setup
+
+1. Install the local runtime packages:
+
+```powershell
+python -m pip install -e .[local]
+```
+
+2. Add the Hugging Face token to `.env`:
+
+```text
+LAI_HF_TOKEN=hf_...
+```
+
+3. Make sure these directories live on a large drive if possible:
+
+- `data/cache/huggingface`
+- `data/models/airllm-shards`
+
+## Common failure modes
+
+### Missing `airllm`, `torch`, `accelerate`, or `transformers`
+
+- Install the local extra again with `python -m pip install -e .[local]`
+- Re-run `python -m lai.cli workstation validate --profile airllm`
+
+### Missing Hugging Face token
+
+- Set `LAI_HF_TOKEN` in `.env`
+- Restart the shell if needed
+- Re-run the workstation validator
+
+### Not enough disk
+
+- Move `LAI_HUGGINGFACE_CACHE_DIR` or `LAI_AIRLLM_SHARDS_DIR` to a larger drive
+- Clear older downloaded models or old shard outputs you no longer need
+- Re-run the validator and confirm both cache and shard paths meet the reported target
+
+### No GPU detected
+
+- Confirm CUDA drivers and `nvidia-smi` work
+- If CPU fallback is allowed, accept that the run may take many hours
+- If turnaround matters, route the executor to a remote provider instead
+
+### VRAM below recommendation
+
+- AirLLM can still work with aggressive layer sharding on smaller GPUs, but expect slower
+  movement between disk, RAM, and VRAM
+- Keep other GPU-heavy programs closed
+
+## Overnight run tips
+
+- Start the dedicated worker service instead of manually stepping jobs:
+
+```powershell
+lai-worker run --poll-interval 10
+```
+
+- Check the dashboard at `/dashboard` for:
+  - workstation readiness
+  - worker status
+  - queue depth
+  - job telemetry and artifacts
+
+- Keep the machine awake and disable sleep for long local runs.
+
+## After a failure
+
+1. Inspect the job timeline and artifacts in the dashboard or `lai jobs show <job-id>`.
+2. Re-run `python -m lai.cli workstation validate --profile airllm`.
+3. If the machine is still blocked, replay the job against a remote executor until the local
+   workstation is ready again.
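Because `workstation validate` exits with a non-zero code when a selected profile is blocked, the runbook's pre-flight step can also be scripted as a gate in front of an overnight launch. A minimal sketch, assuming only that the package is importable in the current environment; the final submission step is a placeholder, not a real command:

```python
# Minimal pre-flight gate sketch: only start the heavy job when the AirLLM
# profile validates cleanly. Relies on the CLI exiting with a non-zero code
# when a selected profile is blocked.
import subprocess
import sys

check = subprocess.run(
    [sys.executable, "-m", "lai.cli", "workstation", "validate", "--profile", "airllm"],
)
if check.returncode != 0:
    print("Workstation blocked for the AirLLM path; fix the failing checks first.")
    raise SystemExit(check.returncode)

# Placeholder: submit the overnight job however your queue expects, e.g. via
# `python -m lai.cli run "..."` or the POST /jobs endpoint.
print("Workstation ready; safe to start the overnight AirLLM job.")
```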
diff --git a/docs/setup/workstation.md b/docs/setup/workstation.md index 64f920f..2ef9156 100644 --- a/docs/setup/workstation.md +++ b/docs/setup/workstation.md @@ -30,6 +30,26 @@ Add AirLLM only when you are ready to test the large-model path: python -m pip install airllm ``` +## Readiness validation + +Before you trust the heavy local executor, validate the workstation directly: + +```powershell +python -m lai.cli workstation validate +python -m lai.cli workstation validate --profile airllm +``` + +This report checks: + +- Python compatibility for the local tooling stack +- local package availability for `torch`, `transformers`, `accelerate`, `huggingface-hub`, and `airllm` +- Hugging Face token presence for gated local models +- free disk in the Hugging Face cache and AirLLM shard directories +- RAM, GPU presence, and measured VRAM when available + +For a deeper remediation checklist and overnight-run guidance, use the AirLLM runbook in +`docs/setup/airllm-runbook.md`. + ## Always-on local worker Once the environment is ready, you can keep queued jobs processing in the background with the diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 25f4251..e3fe6fc 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -20,6 +20,7 @@ run_smoke_suite, save_smoke_result, ) +from ..workstation import build_workstation_readiness class JobCreatePayload(BaseModel): @@ -104,6 +105,16 @@ def list_models() -> dict[str, object]: ] } + @api.get("/workstation/readiness") + def workstation_readiness() -> dict[str, object]: + app_state = application() + report = build_workstation_readiness( + app_state.settings, + app_state.config, + snapshot=app_state.provider_registry.system_snapshot, + ) + return report.model_dump() + @api.get("/worker/status") def worker_status() -> dict[str, object]: app_state = application() diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index f4bbfaa..7abc4ad 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -209,6 +209,7 @@ body { .jobs-panel, .detail-panel, .models-panel, +.workstation-panel, .smoke-panel { padding: 1.35rem; } @@ -563,6 +564,11 @@ textarea { color: var(--ink); } +.smoke-check-remediation { + color: var(--muted); + font-size: 0.92rem; +} + .smoke-detail-line { overflow-wrap: anywhere; } diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index 169a3df..b6de5dc 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -2,6 +2,7 @@ const state = { jobs: [], models: [], worker: null, + workstation: null, smoke: { providers: createSmokeSuiteState(), local: createSmokeSuiteState(), @@ -29,6 +30,9 @@ const elements = { drainQueue: document.getElementById("drain-queue"), queueActionStatus: document.getElementById("queue-action-status"), workerSummary: document.getElementById("worker-summary"), + workstationSummary: document.getElementById("workstation-summary"), + workstationGrid: document.getElementById("workstation-grid"), + refreshWorkstation: document.getElementById("refresh-workstation"), smokeGrid: document.getElementById("smoke-grid"), smokeActionStatus: document.getElementById("smoke-action-status"), runProviderSmoke: document.getElementById("run-provider-smoke"), @@ -115,22 +119,25 @@ function escapeHtml(value) { } async function loadOverview() { - const [health, modelsResponse, jobsResponse, workerResponse] = await Promise.all([ + const [health, modelsResponse, jobsResponse, workerResponse, workstationResponse] = await 
Promise.all([ fetchJson("/health"), fetchJson("/models"), fetchJson("/jobs?limit=12"), fetchJson("/worker/status"), + fetchJson("/workstation/readiness"), ]); state.models = modelsResponse.models; state.jobs = jobsResponse.jobs; state.worker = workerResponse; + state.workstation = workstationResponse; elements.environment.textContent = health.environment; elements.modelCount.textContent = String(health.model_count); elements.jobCount.textContent = String(state.jobs.length); elements.queueState.textContent = summarizeQueue(state.jobs); renderWorkerMonitor(); + renderWorkstationReadiness(); await loadSmokeDiagnostics(); renderModels(); @@ -143,6 +150,11 @@ async function loadOverview() { } } +async function loadWorkstationReadiness() { + state.workstation = await fetchJson("/workstation/readiness"); + renderWorkstationReadiness(); +} + async function loadSmokeDiagnostics() { const [providerLatest, localLatest, providerHistory, localHistory] = await Promise.all([ fetchOptionalJson("/smoke/latest?suite_id=providers"), @@ -205,6 +217,106 @@ function renderSmokeDiagnostics() { }); } +function renderWorkstationReadiness() { + if (!elements.workstationSummary || !elements.workstationGrid) { + return; + } + + if (!state.workstation) { + elements.workstationSummary.className = "worker-summary empty"; + elements.workstationSummary.textContent = "Workstation readiness unavailable."; + elements.workstationGrid.className = "smoke-grid empty"; + elements.workstationGrid.textContent = "No workstation profiles loaded."; + return; + } + + const system = state.workstation.system || {}; + const profiles = state.workstation.profiles || []; + const blockingProfiles = profiles.filter((profile) => !profile.ready).length; + const warningProfiles = profiles.filter( + (profile) => profile.ready && profile.warning_count > 0, + ).length; + + elements.workstationSummary.className = "worker-summary fade-in"; + elements.workstationSummary.innerHTML = ` +
+ ${chip("platform", system.platform || "n/a")} + ${chip("python", system.python_version || "n/a")} + ${chip("gpu", system.gpu_name || "none", system.gpu_name ? "ready" : "warn")} + ${chip("vram", formatGigabytes(system.gpu_total_memory_gb), system.gpu_total_memory_gb ? "ready" : "warn")} + ${chip("ram", formatGigabytes(system.total_ram_gb), system.total_ram_gb ? "ready" : "warn")} + ${chip("blocked", blockingProfiles, blockingProfiles > 0 ? "danger" : "ready")} + ${chip("warnings", warningProfiles, warningProfiles > 0 ? "warn" : "ready")} +
+
+ hf cache: ${escapeHtml(system.huggingface_cache_dir || "n/a")} | + airllm shards: ${escapeHtml(system.airllm_shards_dir || "n/a")} +
+ `; + + if (profiles.length === 0) { + elements.workstationGrid.className = "smoke-grid empty"; + elements.workstationGrid.textContent = "No workstation profiles defined."; + return; + } + + elements.workstationGrid.className = "smoke-grid fade-in"; + elements.workstationGrid.innerHTML = profiles + .map((profile) => renderWorkstationProfile(profile)) + .join(""); +} + +function renderWorkstationProfile(profile) { + const tone = !profile.ready + ? "danger" + : profile.warning_count > 0 + ? "warn" + : "ready"; + const checks = (profile.checks || []) + .map((check) => renderWorkstationCheck(check)) + .join(""); + + return ` +
+    <article>
+      <header>
+        <p>${escapeHtml(profile.profile_id)}</p>
+        <h3>${escapeHtml(profile.title)}</h3>
+      </header>
+      <div>
+        ${chip("status", profile.ready ? "ready" : "blocked", tone)}
+        ${chip("warnings", profile.warning_count, profile.warning_count > 0 ? "warn" : "ready")}
+        ${chip("failures", profile.failure_count, profile.failure_count > 0 ? "danger" : "ready")}
+      </div>
+      <div class="smoke-detail-line">${escapeHtml(profile.summary)}</div>
+      <div>${checks}</div>
+    </article>
+ `; +} + +function renderWorkstationCheck(check) { + const tone = + check.status === "fail" + ? "danger" + : check.status === "warn" + ? "warn" + : "ready"; + const remediation = check.remediation + ? `
+        <div class="smoke-check-remediation">${escapeHtml(check.remediation)}</div>
+      `
+    : "";
+
+  return `
+    <div class="smoke-check">
+      <div>
+        <strong>${escapeHtml(check.title)}</strong>
+        ${chip("status", check.status, tone)}
+      </div>
+      <div class="smoke-detail-line">${escapeHtml(check.summary)}</div>
+      ${remediation}
+      ${renderSmokeSnapshot("Check metadata", check.metadata)}
+    </div>
+ `; +} + function renderSmokeCard(suiteId, label, suiteState) { const selectedEntry = suiteState.selected || suiteState.latest || suiteState.history[0] || null; if (!selectedEntry) { @@ -344,6 +456,13 @@ function renderSmokeSnapshot(label, payload) { `; } +function formatGigabytes(value) { + if (value === null || value === undefined) { + return "n/a"; + } + return `${Number(value).toFixed(2)} GB`; +} + async function selectSmokeResult(suiteId, resultId) { if (!suiteId || !resultId || !state.smoke[suiteId]) { return; @@ -858,6 +977,9 @@ function bindEvents() { elements.drainQueue.addEventListener("click", () => { void runWorker(null, true).catch(handleQueueError); }); + elements.refreshWorkstation.addEventListener("click", () => { + void loadWorkstationReadiness().catch(handleError); + }); elements.runProviderSmoke.addEventListener("click", () => { void runSmokeSuite("providers").catch(handleSmokeError); }); diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html index 8beb127..cac937d 100644 --- a/src/lai/api/static/index.html +++ b/src/lai/api/static/index.html @@ -149,6 +149,22 @@

        <h2>Model health</h2>
      </div>
      <div id="models-grid" class="models-grid empty">Loading model health...</div>
    </section>
+
+    <section class="workstation-panel">
+      <div>
+        <div>
+          <p>Heavy Local</p>
+          <h2>Workstation readiness</h2>
+        </div>
+        <button id="refresh-workstation" type="button">Refresh</button>
+      </div>
+      <div id="workstation-summary" class="worker-summary empty">
+        Loading workstation readiness...
+      </div>
+      <div id="workstation-grid" class="smoke-grid empty">
+        Loading workstation profiles...
+      </div>
+    </section>
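The new panel above is fed by `GET /workstation/readiness`; the same report can be polled from a script. A minimal stdlib-only sketch, assuming the API is served at `http://127.0.0.1:8000`:

```python
# Minimal sketch: poll the readiness report the dashboard panel renders.
# Stdlib only; the base URL is an assumption about the local deployment.
import json
from urllib.request import urlopen

with urlopen("http://127.0.0.1:8000/workstation/readiness") as response:
    report = json.load(response)

print(report["system"]["platform"], report["system"]["python_version"])
for profile in report["profiles"]:
    flag = "ready" if profile["ready"] else "BLOCKED"
    print(f"{profile['profile_id']}: {flag} "
          f"({profile['failure_count']} fail / {profile['warning_count']} warn)")
    for check in profile["checks"]:
        if check["status"] != "pass":
            print(f"  - [{check['status']}] {check['title']}: {check['summary']}")
```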
diff --git a/src/lai/cli.py b/src/lai/cli.py index d495d2c..17b5a41 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -16,6 +16,7 @@ from .smoke import latest_smoke_result, run_smoke_suite, save_smoke_result from .system import collect_system_snapshot from .worker.service import WorkerServiceConfig, WorkerServiceHost, request_worker_service_stop +from .workstation import build_workstation_readiness app = typer.Typer(help="Utilities for the LAI orchestration platform.", no_args_is_help=True) models_app = typer.Typer(help="Inspect model definitions and provider readiness.") @@ -24,13 +25,15 @@ worker_app = typer.Typer(help="Run the local job worker.") eval_app = typer.Typer(help="Run evaluation suites against routing behavior.") smoke_app = typer.Typer(help="Run provider readiness and smoke diagnostics.") +workstation_app = typer.Typer(help="Validate the local workstation for heavy local execution.") app.add_typer(models_app, name="models") app.add_typer(route_app, name="route") app.add_typer(jobs_app, name="jobs") app.add_typer(worker_app, name="worker") app.add_typer(eval_app, name="eval") app.add_typer(smoke_app, name="smoke") -console = Console() +app.add_typer(workstation_app, name="workstation") +console = Console(markup=False) @app.callback() @@ -46,6 +49,13 @@ def doctor() -> None: *runtime_directories(settings.root_dir), enable_gpu=settings.enable_gpu ) application = create_application(settings=settings) + workstation = build_workstation_readiness( + settings, + application.config, + snapshot=snapshot, + ) + local_profile = workstation.profile("local") + airllm_profile = workstation.profile("airllm") if len(workstation.profiles) > 1 else None table = Table(title="LAI Workspace Doctor") table.add_column("Item") @@ -58,7 +68,20 @@ def doctor() -> None: table.add_row("Database", _status_line(settings.resolved_database_path)) table.add_row("GPU available", "yes" if snapshot.has_gpu else "no") table.add_row("GPU name", snapshot.gpu_name or "n/a") + table.add_row( + "GPU VRAM", + f"{snapshot.gpu_total_memory_gb:.2f} GB" + if snapshot.gpu_total_memory_gb is not None + else "n/a", + ) + table.add_row( + "System RAM", + f"{snapshot.total_ram_gb:.2f} GB" if snapshot.total_ram_gb is not None else "n/a", + ) table.add_row("Catalog models", str(len(application.config.model_catalog.models))) + table.add_row("Local runtime", _profile_status(local_profile)) + if airllm_profile is not None: + table.add_row("AirLLM heavy path", _profile_status(airllm_profile)) for path in runtime_directories(settings.root_dir): table.add_row("Runtime path", _status_line(path)) @@ -113,6 +136,54 @@ def check_models() -> None: console.print(table) +@workstation_app.command("validate") +def validate_workstation( + profile: Literal["all", "local", "airllm"] = typer.Option( + "all", + help="Show all profiles, or focus on the local baseline or AirLLM heavy path.", + ), +) -> None: + """Validate workstation readiness for local and AirLLM execution paths.""" + settings = Settings() + application = create_application(settings=settings) + report = build_workstation_readiness( + settings, + application.config, + snapshot=application.provider_registry.system_snapshot, + ) + + system_table = Table(title="LAI Workstation") + system_table.add_column("Fact") + system_table.add_column("Value") + system_table.add_row("Platform", report.system["platform"]) + system_table.add_row("Python", report.system["python_version"]) + system_table.add_row("GPU", report.system["gpu_name"] or "n/a") + system_table.add_row( + "GPU VRAM", + 
f"{report.system['gpu_total_memory_gb']:.2f} GB" + if report.system["gpu_total_memory_gb"] is not None + else "n/a", + ) + system_table.add_row( + "System RAM", + f"{report.system['total_ram_gb']:.2f} GB" + if report.system["total_ram_gb"] is not None + else "n/a", + ) + console.print(system_table) + + selected_profiles = [ + candidate + for candidate in report.profiles + if profile == "all" or candidate.profile_id == profile + ] + for candidate in selected_profiles: + _print_workstation_profile(candidate) + + if any(not candidate.ready for candidate in selected_profiles): + raise typer.Exit(code=1) + + @route_app.command("explain") def explain_route( prompt: str = typer.Argument(..., help="The user request to classify and route."), @@ -541,6 +612,40 @@ def _status_line(path: Path) -> str: return f"{path} ({suffix})" +def _profile_status(profile) -> str: + if profile.ready and profile.warning_count == 0: + return "ready" + if profile.ready: + return f"ready with {profile.warning_count} warning(s)" + return f"blocked ({profile.failure_count} fail / {profile.warning_count} warn)" + + +def _print_workstation_profile(profile) -> None: + summary = Table(title=profile.title) + summary.add_column("Field") + summary.add_column("Value") + summary.add_row("Profile", profile.profile_id) + summary.add_row("Ready", "yes" if profile.ready else "no") + summary.add_row("Warnings", str(profile.warning_count)) + summary.add_row("Failures", str(profile.failure_count)) + summary.add_row("Summary", profile.summary) + console.print(summary) + + checks = Table(title=f"{profile.title} Checks") + checks.add_column("Check") + checks.add_column("Status") + checks.add_column("Summary") + checks.add_column("Remediation") + for check in profile.checks: + checks.add_row( + check.title, + check.status, + check.summary, + check.remediation or "n/a", + ) + console.print(checks) + + def _run_smoke_command( *, suite_id: Literal["providers", "local"], diff --git a/src/lai/system.py b/src/lai/system.py index cf1b1d4..8a92879 100644 --- a/src/lai/system.py +++ b/src/lai/system.py @@ -1,7 +1,12 @@ from __future__ import annotations +import ctypes +import os +import platform import shutil import subprocess +import sys +from ctypes import Structure, byref, c_ulong, c_ulonglong, sizeof from pathlib import Path from pydantic import BaseModel, Field @@ -10,6 +15,12 @@ class SystemSnapshot(BaseModel): has_gpu: bool gpu_name: str | None = None + gpu_total_memory_gb: float | None = None + gpu_driver_version: str | None = None + total_ram_gb: float | None = None + available_ram_gb: float | None = None + platform_name: str = Field(default_factory=platform.platform) + python_version: str = Field(default_factory=platform.python_version) free_disk_gb: dict[str, float] = Field(default_factory=dict) @@ -17,39 +28,91 @@ def collect_system_snapshot(*paths: Path, enable_gpu: bool = True) -> SystemSnap free_disk_gb: dict[str, float] = {} for path in paths: free_disk_gb[str(path)] = round(available_disk_gb(path), 2) - has_gpu, gpu_name = detect_gpu(enable_gpu=enable_gpu) - return SystemSnapshot(has_gpu=has_gpu, gpu_name=gpu_name, free_disk_gb=free_disk_gb) + has_gpu, gpu_name, gpu_total_memory_gb, gpu_driver_version = detect_gpu_details( + enable_gpu=enable_gpu + ) + total_ram_gb, available_ram_gb = detect_memory_gb() + return SystemSnapshot( + has_gpu=has_gpu, + gpu_name=gpu_name, + gpu_total_memory_gb=gpu_total_memory_gb, + gpu_driver_version=gpu_driver_version, + total_ram_gb=total_ram_gb, + available_ram_gb=available_ram_gb, + 
free_disk_gb=free_disk_gb, + ) def detect_gpu(enable_gpu: bool = True) -> tuple[bool, str | None]: + has_gpu, gpu_name, _, _ = detect_gpu_details(enable_gpu=enable_gpu) + return has_gpu, gpu_name + + +def detect_gpu_details( + enable_gpu: bool = True, +) -> tuple[bool, str | None, float | None, str | None]: if not enable_gpu: - return False, None + return False, None, None, None try: import torch if torch.cuda.is_available(): - return True, str(torch.cuda.get_device_name(0)) + properties = torch.cuda.get_device_properties(0) + total_memory_gb = round(properties.total_memory / (1024**3), 2) + return True, str(torch.cuda.get_device_name(0)), total_memory_gb, None except Exception: pass try: result = subprocess.run( - ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + [ + "nvidia-smi", + "--query-gpu=name,memory.total,driver_version", + "--format=csv,noheader,nounits", + ], check=False, capture_output=True, text=True, ) except FileNotFoundError: - return False, None + return False, None, None, None if result.returncode != 0: - return False, None + return False, None, None, None lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] if not lines: - return False, None - return True, lines[0] + return False, None, None, None + + parts = [part.strip() for part in lines[0].split(",")] + gpu_name = parts[0] if parts else None + memory_total_gb = None + if len(parts) > 1: + try: + memory_total_gb = round(float(parts[1]) / 1024, 2) + except ValueError: + memory_total_gb = None + driver_version = parts[2] if len(parts) > 2 else None + return True, gpu_name, memory_total_gb, driver_version + + +def detect_memory_gb() -> tuple[float | None, float | None]: + if sys.platform == "win32": + return _detect_windows_memory_gb() + + if hasattr(os, "sysconf"): + try: + page_size = os.sysconf("SC_PAGE_SIZE") + total_pages = os.sysconf("SC_PHYS_PAGES") + available_pages = os.sysconf("SC_AVPHYS_PAGES") + except (OSError, ValueError): + return None, None + total_gb = round((page_size * total_pages) / (1024**3), 2) + available_gb = round((page_size * available_pages) / (1024**3), 2) + return total_gb, available_gb + + return None, None def available_disk_gb(path: Path) -> float: @@ -65,3 +128,30 @@ def _nearest_existing_path(path: Path) -> Path: break candidate = candidate.parent return candidate + + +def _detect_windows_memory_gb() -> tuple[float | None, float | None]: + try: + + class MEMORYSTATUSEX(Structure): + _fields_ = [ + ("dwLength", c_ulong), + ("dwMemoryLoad", c_ulong), + ("ullTotalPhys", c_ulonglong), + ("ullAvailPhys", c_ulonglong), + ("ullTotalPageFile", c_ulonglong), + ("ullAvailPageFile", c_ulonglong), + ("ullTotalVirtual", c_ulonglong), + ("ullAvailVirtual", c_ulonglong), + ("ullAvailExtendedVirtual", c_ulonglong), + ] + + status = MEMORYSTATUSEX() + status.dwLength = sizeof(MEMORYSTATUSEX) + if not ctypes.windll.kernel32.GlobalMemoryStatusEx(byref(status)): # type: ignore[attr-defined] + return None, None + total_gb = round(status.ullTotalPhys / (1024**3), 2) + available_gb = round(status.ullAvailPhys / (1024**3), 2) + return total_gb, available_gb + except Exception: + return None, None diff --git a/src/lai/workstation.py b/src/lai/workstation.py new file mode 100644 index 0000000..91d4405 --- /dev/null +++ b/src/lai/workstation.py @@ -0,0 +1,677 @@ +from __future__ import annotations + +import sys +from datetime import datetime, timezone +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any, Literal, 
Mapping + +from pydantic import Field + +from .config import AppConfig +from .domain import LAIModel, ModelRuntime, ModelSpec +from .layout import runtime_directories +from .settings import Settings +from .system import SystemSnapshot, collect_system_snapshot + +SUPPORTED_PYTHON = ((3, 11), (3, 12)) + + +class WorkstationCheck(LAIModel): + id: str + status: Literal["pass", "warn", "fail"] + title: str + summary: str + remediation: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class WorkstationProfileReport(LAIModel): + profile_id: str + title: str + summary: str + ready: bool + warning_count: int + failure_count: int + checks: list[WorkstationCheck] + + +class WorkstationReadinessReport(LAIModel): + executed_at: str + system: dict[str, Any] + profiles: list[WorkstationProfileReport] + + def profile(self, profile_id: str) -> WorkstationProfileReport: + for profile in self.profiles: + if profile.profile_id == profile_id: + return profile + raise KeyError(f"Unknown workstation profile: {profile_id}") + + +def build_workstation_readiness( + settings: Settings, + config: AppConfig, + *, + snapshot: SystemSnapshot | None = None, + dependency_versions: Mapping[str, str | None] | None = None, +) -> WorkstationReadinessReport: + snapshot = snapshot or collect_system_snapshot( + *runtime_directories(settings.root_dir), enable_gpu=settings.enable_gpu + ) + dependencies = dict(dependency_versions or _detect_dependency_versions()) + + local_models = [ + model + for model in config.model_catalog.enabled_models() + if model.runtime in {ModelRuntime.TRANSFORMERS, ModelRuntime.AIRLLM} + ] + transformer_models = [ + model for model in local_models if model.runtime == ModelRuntime.TRANSFORMERS + ] + airllm_models = [model for model in local_models if model.runtime == ModelRuntime.AIRLLM] + + profiles = [ + _build_local_profile( + settings, + config, + snapshot, + dependencies, + local_models=local_models, + transformer_models=transformer_models, + airllm_models=airllm_models, + ), + ] + if airllm_models: + profiles.append( + _build_airllm_profile( + settings, + config, + snapshot, + dependencies, + airllm_models=airllm_models, + ) + ) + + return WorkstationReadinessReport( + executed_at=_utcnow_iso(), + system={ + "platform": snapshot.platform_name, + "python_version": snapshot.python_version, + "gpu_name": snapshot.gpu_name, + "gpu_total_memory_gb": snapshot.gpu_total_memory_gb, + "gpu_driver_version": snapshot.gpu_driver_version, + "total_ram_gb": snapshot.total_ram_gb, + "available_ram_gb": snapshot.available_ram_gb, + "catalog_disk_headroom_gb": config.model_catalog.defaults.require_disk_headroom_gb, + "huggingface_cache_dir": str(settings.resolved_huggingface_cache_dir), + "airllm_shards_dir": str(settings.resolved_airllm_shards_dir), + "paths": { + str(path): snapshot.free_disk_gb.get(str(path)) + for path in runtime_directories(settings.root_dir) + }, + "local_model_ids": [model.id for model in local_models], + "airllm_model_ids": [model.id for model in airllm_models], + }, + profiles=profiles, + ) + + +def _build_local_profile( + settings: Settings, + config: AppConfig, + snapshot: SystemSnapshot, + dependencies: Mapping[str, str | None], + *, + local_models: list[ModelSpec], + transformer_models: list[ModelSpec], + airllm_models: list[ModelSpec], +) -> WorkstationProfileReport: + checks = [ + _python_version_check(), + _path_headroom_check( + check_id="runtime-paths", + title="Runtime paths and disk headroom", + path=settings.resolved_huggingface_cache_dir, + 
free_disk_gb=snapshot.free_disk_gb.get( + str(settings.resolved_huggingface_cache_dir) + ), + required_free_gb=min(config.model_catalog.defaults.require_disk_headroom_gb, 80), + remediation=( + "Move the Hugging Face cache to a drive with more free space before " + "large local downloads." + ), + ), + _dependency_stack_check( + check_id="local-dependencies", + title="Local inference dependency stack", + required_dependencies={ + "torch": dependencies.get("torch"), + "transformers": dependencies.get("transformers"), + "accelerate": dependencies.get("accelerate"), + "huggingface-hub": dependencies.get("huggingface-hub"), + }, + remediation="Install the local extra with `python -m pip install -e .[local]`.", + ), + ] + + if local_models: + max_recommended_ram = max( + (model.hardware.recommended_ram_gb or 0 for model in local_models), + default=0, + ) + checks.append( + _ram_check( + check_id="local-ram", + title="System RAM for local runtimes", + total_ram_gb=snapshot.total_ram_gb, + recommended_ram_gb=max_recommended_ram or None, + remediation=( + "Close other heavy applications or move larger execution stages " + "to remote providers when RAM is constrained." + ), + ) + ) + + if transformer_models or airllm_models: + checks.append( + _huggingface_token_check( + has_token=bool(settings.huggingface_token_value), + local_models=local_models, + ) + ) + + return _build_profile( + profile_id="local", + title="Local runtime baseline", + checks=checks, + ) + + +def _build_airllm_profile( + settings: Settings, + config: AppConfig, + snapshot: SystemSnapshot, + dependencies: Mapping[str, str | None], + *, + airllm_models: list[ModelSpec], +) -> WorkstationProfileReport: + heaviest_model = max( + airllm_models, + key=lambda model: ( + model.hardware.expected_disk_gb or 0, + model.hardware.recommended_ram_gb or 0, + model.hardware.recommended_vram_gb or 0, + ), + ) + required_free_disk_gb = max( + heaviest_model.hardware.expected_disk_gb or 0, + config.model_catalog.defaults.require_disk_headroom_gb, + ) + checks = [ + _python_version_check(), + _dependency_stack_check( + check_id="airllm-dependencies", + title="AirLLM dependency stack", + required_dependencies={ + "torch": dependencies.get("torch"), + "transformers": dependencies.get("transformers"), + "accelerate": dependencies.get("accelerate"), + "huggingface-hub": dependencies.get("huggingface-hub"), + "airllm": dependencies.get("airllm"), + }, + remediation=( + "Install the heavy local extras with `python -m pip install -e " + ".[local]` and verify AirLLM imports cleanly." + ), + ), + _optional_dependency_check( + check_id="bitsandbytes", + title="Optional bitsandbytes compression", + dependency_name="bitsandbytes", + dependency_version=dependencies.get("bitsandbytes"), + remediation=( + "Install `bitsandbytes` if you want lower-memory quantized local " + "experiments." + ), + ), + _huggingface_token_check( + has_token=bool(settings.huggingface_token_value), + local_models=airllm_models, + ), + _gpu_airllm_check(snapshot=snapshot, model=heaviest_model), + _vram_check(snapshot=snapshot, model=heaviest_model), + _ram_check( + check_id="airllm-ram", + title="System RAM for AirLLM orchestration", + total_ram_gb=snapshot.total_ram_gb, + recommended_ram_gb=heaviest_model.hardware.recommended_ram_gb, + remediation=( + "Use a workstation with more system RAM, or avoid the heaviest " + "local executor until RAM is expanded." 
+            ),
+        ),
+        _path_headroom_check(
+            check_id="airllm-shards-disk",
+            title="AirLLM shard storage",
+            path=settings.resolved_airllm_shards_dir,
+            free_disk_gb=snapshot.free_disk_gb.get(str(settings.resolved_airllm_shards_dir)),
+            required_free_gb=required_free_disk_gb,
+            remediation=(
+                "Move `LAI_AIRLLM_SHARDS_DIR` to a larger drive or clear space before preparing "
+                f"{heaviest_model.id}."
+            ),
+            metadata={
+                "model_id": heaviest_model.id,
+                "expected_disk_gb": heaviest_model.hardware.expected_disk_gb,
+            },
+        ),
+        _path_headroom_check(
+            check_id="huggingface-cache-disk",
+            title="Hugging Face cache staging space",
+            path=settings.resolved_huggingface_cache_dir,
+            free_disk_gb=snapshot.free_disk_gb.get(
+                str(settings.resolved_huggingface_cache_dir)
+            ),
+            required_free_gb=min(required_free_disk_gb, 160),
+            remediation=(
+                "Move `LAI_HUGGINGFACE_CACHE_DIR` to a drive with more free space "
+                "for large model downloads."
+            ),
+            metadata={"model_id": heaviest_model.id},
+        ),
+    ]
+
+    return _build_profile(
+        profile_id="airllm",
+        title="AirLLM heavy local executor",
+        checks=checks,
+    )
+
+
+def platform_version() -> str:
+    # Helper used by the Python runtime check below; reports the interpreter
+    # version as a dotted string without importing the platform module.
+    return ".".join(str(part) for part in sys.version_info[:3])
+
+
+def _python_version_check() -> WorkstationCheck:
+    minimum, maximum = SUPPORTED_PYTHON
+    current = sys.version_info[:2]
+    if minimum <= current <= maximum:
+        return WorkstationCheck(
+            id="python-version",
+            status="pass",
+            title="Python runtime",
+            summary=(
+                f"Python {platform_version()} is within the supported range for LAI "
+                "local tooling."
+            ),
+            metadata={"python_version": platform_version()},
+        )
+
+    return WorkstationCheck(
+        id="python-version",
+        status="fail",
+        title="Python runtime",
+        summary=(
+            f"Python {platform_version()} is outside the supported range {minimum[0]}.{minimum[1]} "
+            f"to {maximum[0]}.{maximum[1]}."
+        ),
+        remediation=(
+            "Create a Python 3.11 virtual environment before running local model "
+            "workflows."
+ ), + metadata={"python_version": platform_version()}, + ) + + +def _dependency_stack_check( + *, + check_id: str, + title: str, + required_dependencies: Mapping[str, str | None], + remediation: str, +) -> WorkstationCheck: + missing = [ + name for name, package_version in required_dependencies.items() if not package_version + ] + if not missing: + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary="All required packages are installed.", + metadata={"versions": required_dependencies}, + ) + + return WorkstationCheck( + id=check_id, + status="fail", + title=title, + summary=f"Missing required packages: {', '.join(missing)}.", + remediation=remediation, + metadata={"versions": required_dependencies}, + ) + + +def _optional_dependency_check( + *, + check_id: str, + title: str, + dependency_name: str, + dependency_version: str | None, + remediation: str, +) -> WorkstationCheck: + if dependency_version: + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary=f"{dependency_name} {dependency_version} is installed.", + metadata={"version": dependency_version}, + ) + + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary=f"{dependency_name} is not installed.", + remediation=remediation, + ) + + +def _huggingface_token_check( + *, + has_token: bool, + local_models: list[ModelSpec], +) -> WorkstationCheck: + gated_models = [model.id for model in local_models if model.model_ref.startswith("meta-llama/")] + if not gated_models: + return WorkstationCheck( + id="huggingface-token", + status="pass", + title="Hugging Face credentials", + summary="No gated local models currently require a Hugging Face token.", + ) + + if has_token: + return WorkstationCheck( + id="huggingface-token", + status="pass", + title="Hugging Face credentials", + summary="A Hugging Face token is configured for gated local models.", + metadata={"model_ids": gated_models}, + ) + + return WorkstationCheck( + id="huggingface-token", + status="fail", + title="Hugging Face credentials", + summary=f"Gated local models need `LAI_HF_TOKEN`: {', '.join(gated_models)}.", + remediation=( + "Set `LAI_HF_TOKEN` in `.env` before downloading or running gated local " + "models." + ), + metadata={"model_ids": gated_models}, + ) + + +def _gpu_airllm_check(*, snapshot: SystemSnapshot, model: ModelSpec) -> WorkstationCheck: + if snapshot.has_gpu: + return WorkstationCheck( + id="airllm-gpu", + status="pass", + title="GPU availability", + summary=f"Detected GPU {snapshot.gpu_name or 'unknown'} for AirLLM workloads.", + metadata={ + "gpu_name": snapshot.gpu_name, + "gpu_total_memory_gb": snapshot.gpu_total_memory_gb, + "gpu_driver_version": snapshot.gpu_driver_version, + }, + ) + + if model.runtime_hints.allow_cpu_fallback: + return WorkstationCheck( + id="airllm-gpu", + status="warn", + title="GPU availability", + summary=( + "No GPU detected. AirLLM can still fall back to CPU for this model, but expect " + "very slow overnight execution." + ), + remediation=( + "Use a CUDA-capable NVIDIA GPU if you want practical turnaround " + "for the large local executor." + ), + ) + + return WorkstationCheck( + id="airllm-gpu", + status="fail", + title="GPU availability", + summary="No GPU detected and CPU fallback is disabled for the selected AirLLM model.", + remediation=( + "Enable a CUDA-capable GPU or switch the model policy to a provider " + "that supports CPU fallback." 
+ ), + ) + + +def _vram_check(*, snapshot: SystemSnapshot, model: ModelSpec) -> WorkstationCheck: + minimum_vram_gb = model.hardware.minimum_vram_gb + recommended_vram_gb = model.hardware.recommended_vram_gb + gpu_vram = snapshot.gpu_total_memory_gb + + if not snapshot.has_gpu: + return WorkstationCheck( + id="airllm-vram", + status="warn", + title="GPU VRAM budget", + summary="VRAM checks are skipped because no GPU is available.", + ) + + if gpu_vram is None: + return WorkstationCheck( + id="airllm-vram", + status="warn", + title="GPU VRAM budget", + summary="GPU detected, but VRAM could not be measured automatically.", + remediation=( + "Check `nvidia-smi` manually to confirm the GPU can host the " + "AirLLM working set." + ), + ) + + if minimum_vram_gb and gpu_vram < minimum_vram_gb: + return WorkstationCheck( + id="airllm-vram", + status="fail", + title="GPU VRAM budget", + summary=( + f"Detected {gpu_vram:.2f} GB VRAM, below the minimum " + f"{minimum_vram_gb} GB for {model.id}." + ), + remediation=( + "Use a larger GPU or route the executor to a remote provider " + "until more VRAM is available." + ), + metadata={"gpu_total_memory_gb": gpu_vram, "minimum_vram_gb": minimum_vram_gb}, + ) + + if recommended_vram_gb and gpu_vram < recommended_vram_gb: + return WorkstationCheck( + id="airllm-vram", + status="warn", + title="GPU VRAM budget", + summary=( + f"Detected {gpu_vram:.2f} GB VRAM, below the recommended " + f"{recommended_vram_gb} GB for smoother AirLLM execution." + ), + remediation=( + "Expect more disk sharding and slower layer movement until a " + "larger GPU is available." + ), + metadata={"gpu_total_memory_gb": gpu_vram, "recommended_vram_gb": recommended_vram_gb}, + ) + + return WorkstationCheck( + id="airllm-vram", + status="pass", + title="GPU VRAM budget", + summary=f"Detected {gpu_vram:.2f} GB VRAM, which satisfies the configured AirLLM target.", + metadata={ + "gpu_total_memory_gb": gpu_vram, + "recommended_vram_gb": recommended_vram_gb, + "minimum_vram_gb": minimum_vram_gb, + }, + ) + + +def _ram_check( + *, + check_id: str, + title: str, + total_ram_gb: float | None, + recommended_ram_gb: int | None, + remediation: str, +) -> WorkstationCheck: + if recommended_ram_gb is None: + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary="No explicit RAM recommendation is configured for the current model set.", + ) + + if total_ram_gb is None: + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary="System RAM could not be measured automatically.", + remediation=( + "Check the host RAM budget manually before starting long-running " + "local jobs." + ), + ) + + if total_ram_gb < recommended_ram_gb: + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary=( + f"Detected {total_ram_gb:.2f} GB RAM, below the recommended " + f"{recommended_ram_gb} GB." 
+ ), + remediation=remediation, + metadata={"total_ram_gb": total_ram_gb, "recommended_ram_gb": recommended_ram_gb}, + ) + + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary=f"Detected {total_ram_gb:.2f} GB RAM, meeting the configured recommendation.", + metadata={"total_ram_gb": total_ram_gb, "recommended_ram_gb": recommended_ram_gb}, + ) + + +def _path_headroom_check( + *, + check_id: str, + title: str, + path: Path, + free_disk_gb: float | None, + required_free_gb: int, + remediation: str, + metadata: Mapping[str, Any] | None = None, +) -> WorkstationCheck: + if free_disk_gb is None: + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary=f"Could not measure free disk space for {path}.", + remediation=remediation, + metadata={"path": str(path), **(metadata or {})}, + ) + + status: Literal["pass", "warn", "fail"] + if free_disk_gb >= required_free_gb: + status = "pass" + summary = ( + f"{path} has {free_disk_gb:.2f} GB free, meeting the {required_free_gb} GB target." + ) + elif free_disk_gb >= required_free_gb * 0.7: + status = "warn" + summary = ( + f"{path} has {free_disk_gb:.2f} GB free, below the target " + f"{required_free_gb} GB but still close enough for careful staging." + ) + else: + status = "fail" + summary = ( + f"{path} has {free_disk_gb:.2f} GB free, well below the target {required_free_gb} GB." + ) + + return WorkstationCheck( + id=check_id, + status=status, + title=title, + summary=summary, + remediation=None if status == "pass" else remediation, + metadata={ + "path": str(path), + "free_disk_gb": free_disk_gb, + "required_free_gb": required_free_gb, + **(metadata or {}), + }, + ) + + +def _build_profile( + *, + profile_id: str, + title: str, + checks: list[WorkstationCheck], +) -> WorkstationProfileReport: + failure_count = sum(1 for check in checks if check.status == "fail") + warning_count = sum(1 for check in checks if check.status == "warn") + ready = failure_count == 0 + if ready and warning_count == 0: + summary = "Ready for this profile with no blocking issues detected." + elif ready: + summary = f"Ready with {warning_count} warning(s)." + else: + summary = f"Blocked by {failure_count} failing check(s) and {warning_count} warning(s)." 
+ + return WorkstationProfileReport( + profile_id=profile_id, + title=title, + summary=summary, + ready=ready, + warning_count=warning_count, + failure_count=failure_count, + checks=checks, + ) + + +def _detect_dependency_versions() -> dict[str, str | None]: + return { + "torch": _package_version("torch"), + "transformers": _package_version("transformers"), + "accelerate": _package_version("accelerate"), + "huggingface-hub": _package_version("huggingface-hub"), + "airllm": _package_version("airllm"), + "bitsandbytes": _package_version("bitsandbytes"), + } + + +def _package_version(package_name: str) -> str | None: + try: + return version(package_name) + except PackageNotFoundError: + return None + + +def platform_version() -> str: + return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + +def _utcnow_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat() diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index f80a4df..0b38b1d 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -42,6 +42,7 @@ def test_api_exposes_expected_routes() -> None: assert "/dashboard" in routes assert "/health" in routes assert "/models" in routes + assert "/workstation/readiness" in routes assert "/smoke/latest" in routes assert "/smoke/results" in routes assert "/smoke/results/{result_id}" in routes @@ -70,9 +71,24 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert root_response.headers["location"] == "/dashboard" assert dashboard_response.status_code == 200 assert "LAI CONTROL ROOM" in dashboard_response.text + assert "Workstation readiness" in dashboard_response.text assert "Smoke diagnostics" in dashboard_response.text assert asset_response.status_code == 200 - assert "async function loadSmokeDiagnostics" in asset_response.text + assert "async function loadWorkstationReadiness" in asset_response.text + + +def test_workstation_readiness_route_returns_profiles(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + client = TestClient(create_api(settings=settings, providers=providers)) + + response = client.get("/workstation/readiness") + + assert response.status_code == 200 + payload = response.json() + assert payload["system"]["python_version"] + assert payload["profiles"] + assert any(profile["profile_id"] == "local" for profile in payload["profiles"]) def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> None: diff --git a/tests/unit/test_workstation.py b/tests/unit/test_workstation.py new file mode 100644 index 0000000..397b387 --- /dev/null +++ b/tests/unit/test_workstation.py @@ -0,0 +1,105 @@ +from lai.config import load_app_config +from lai.settings import Settings +from lai.system import SystemSnapshot +from lai.workstation import build_workstation_readiness + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def 
test_workstation_readiness_blocks_missing_airllm_prerequisites(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + config = load_app_config(settings.resolved_model_catalog, settings.resolved_routing_policy) + snapshot = SystemSnapshot( + has_gpu=False, + total_ram_gb=16.0, + available_ram_gb=10.0, + free_disk_gb={ + str(settings.resolved_huggingface_cache_dir): 20.0, + str(settings.resolved_airllm_shards_dir): 40.0, + }, + ) + + report = build_workstation_readiness( + settings, + config, + snapshot=snapshot, + dependency_versions={ + "torch": None, + "transformers": "4.55.0", + "accelerate": None, + "huggingface-hub": "0.34.0", + "airllm": None, + "bitsandbytes": None, + }, + ) + + local_profile = report.profile("local") + airllm_profile = report.profile("airllm") + + assert not local_profile.ready + assert not airllm_profile.ready + assert any( + check.id == "airllm-dependencies" and check.status == "fail" + for check in airllm_profile.checks + ) + assert any( + check.id == "huggingface-token" and check.status == "fail" + for check in airllm_profile.checks + ) + assert any( + check.id == "airllm-shards-disk" and check.status == "fail" + for check in airllm_profile.checks + ) + + +def test_workstation_readiness_allows_cpu_fallback_with_warnings(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path, hf_token="test-token") + config = load_app_config(settings.resolved_model_catalog, settings.resolved_routing_policy) + snapshot = SystemSnapshot( + has_gpu=False, + total_ram_gb=48.0, + available_ram_gb=30.0, + free_disk_gb={ + str(settings.resolved_huggingface_cache_dir): 220.0, + str(settings.resolved_airllm_shards_dir): 320.0, + }, + ) + + report = build_workstation_readiness( + settings, + config, + snapshot=snapshot, + dependency_versions={ + "torch": "2.7.0", + "transformers": "4.55.0", + "accelerate": "1.9.0", + "huggingface-hub": "0.34.0", + "airllm": "2.1.0", + "bitsandbytes": "0.45.0", + }, + ) + + local_profile = report.profile("local") + airllm_profile = report.profile("airllm") + + assert local_profile.ready + assert airllm_profile.ready + assert airllm_profile.warning_count >= 1 + assert any( + check.id == "airllm-gpu" and check.status == "warn" for check in airllm_profile.checks + ) From ebb956a27bdddcc5a71c9a73c91a1eaf61981ad1 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 20:54:12 +0200 Subject: [PATCH 12/16] feat: add execution trace summaries --- README.md | 8 +- apps/web/README.md | 1 + src/lai/api/app.py | 11 +++ src/lai/api/static/dashboard.css | 14 ++++ src/lai/api/static/dashboard.js | 132 +++++++++++++++++++++++++++++++ src/lai/cli.py | 25 ++++++ src/lai/domain.py | 24 ++++++ src/lai/jobs/service.py | 44 ++++++++--- src/lai/observability.py | 102 ++++++++++++++++++++++++ tests/unit/test_api.py | 25 ++++++ tests/unit/test_observability.py | 58 ++++++++++++++ 11 files changed, 431 insertions(+), 13 deletions(-) create mode 100644 src/lai/observability.py create mode 100644 tests/unit/test_observability.py diff --git a/README.md b/README.md index 1de5ffc..2127d51 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,7 @@ Available endpoints: - `GET /jobs` - `GET /jobs/{job_id}` - `GET /jobs/{job_id}/timeline` +- `GET /jobs/{job_id}/execution` - `POST /jobs/{job_id}/replay` - `GET /jobs/{job_id}/artifacts` - `GET /jobs/{job_id}/artifacts/{artifact_id}` @@ -149,9 +150,10 @@ Available endpoints: The API now also serves a read-mostly dashboard at `/dashboard` with live model health, workstation 
readiness for heavy local execution, route explanation, recent job inspection, -stage telemetry, artifact/trace browsing, replay actions, persisted worker monitoring, -saved smoke diagnostics with history drill-down, per-check metadata/output previews, -and bounded queue worker controls including an until-idle drain path. +stage telemetry, provider execution summaries, artifact/trace browsing, replay actions, +persisted worker monitoring, saved smoke diagnostics with history drill-down, +per-check metadata/output previews, and bounded queue worker controls including an +until-idle drain path. ## Dedicated worker service diff --git a/apps/web/README.md b/apps/web/README.md index d8e3deb..9ac942d 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -9,6 +9,7 @@ Current dashboard capabilities: - route explanation form - job submission and recent queue inspection - persisted stage telemetry timeline for planner, executor, and reviewer flow +- provider execution summaries with duration, output preview, and artifact linkage per stage - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns - persisted worker monitoring with heartbeat, current job, and queue depth diff --git a/src/lai/api/app.py b/src/lai/api/app.py index e3fe6fc..46c5f4d 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -10,6 +10,7 @@ from ..application import create_application from ..domain import ExecutionRequest, QueueMode +from ..observability import summarize_job_execution from ..providers import Provider from ..settings import Settings from ..smoke import ( @@ -223,6 +224,16 @@ def get_job_timeline(job_id: str) -> dict[str, object]: raise HTTPException(status_code=404, detail="Job not found") return {"stage_events": [event.model_dump() for event in job.stage_events]} + @api.get("/jobs/{job_id}/execution") + def get_job_execution(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return { + "stages": [summary.model_dump() for summary in summarize_job_execution(job)] + } + @api.get("/jobs/{job_id}/artifacts") def list_job_artifacts(job_id: str) -> dict[str, object]: app_state = application() diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index 7abc4ad..d6e43a2 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -498,6 +498,20 @@ textarea { color: var(--ink); } +.execution-line { + margin-top: 0.55rem; +} + +.execution-artifacts { + margin-top: 0.65rem; +} + +.execution-preview { + margin: 0.65rem 0 0; + max-height: 12rem; + overflow: auto; +} + .smoke-card { padding: 1rem; display: grid; diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index b6de5dc..ace7b26 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -685,6 +685,10 @@ async function selectJob(jobId) {

        <h3>Route trace</h3>
        <ul>
          ${reasons || '<li>No route reasons recorded.</li>'}
        </ul>
+       <h3>Provider execution</h3>
+       <ul id="execution-list" class="timeline-list empty">
+         Loading execution summaries...
+       </ul>
        <h3>Stage timeline</h3>
        <ul id="timeline-list" class="timeline-list empty">
          Loading stage telemetry...
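The `execution-list` panel above is filled in by `loadExecution` below, which reads `GET /jobs/{job_id}/execution`. For reference, a sketch of one entry in that payload: the keys mirror `StageExecutionSummary` in `src/lai/domain.py` and the status values follow `_resolved_status`, while every value here is illustrative rather than taken from a real run:

```python
# Illustrative stage entry from GET /jobs/{job_id}/execution (values invented;
# keys follow StageExecutionSummary, statuses follow _resolved_status).
example_stage = {
    "stage": "executor",
    "status": "completed",  # one of: completed, failed, blocked, running, pending
    "provider_id": "transformers",
    "model_id": "local-executor",  # hypothetical model id
    "duration_seconds": 42.17,
    "finish_reason": "stop",
    "prompt_characters": 512,
    "output_characters": 1814,
    "output_preview": "1. Define the rollout phases ...",
    "artifact_ids": ["d41f0c"],  # hypothetical artifact id
    "artifact_types": ["executor-request", "executor-response", "final-output"],
    "health_reasons": [],
    "error_type": None,
    "last_message": "Executor stage completed.",
}
```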
@@ -706,10 +710,138 @@ async function selectJob(jobId) { void cancelJob(job.id).catch(handleQueueError); }); + await loadExecution(job.id); await loadTimeline(job.id, job.stage_events || []); await loadArtifacts(job.id, job.artifacts || []); } +async function loadExecution(jobId) { + const response = await fetchJson(`/jobs/${jobId}/execution`); + const stages = response.stages || []; + const executionList = document.getElementById("execution-list"); + if (!executionList) { + return; + } + + if (stages.length === 0) { + executionList.className = "timeline-list empty"; + executionList.textContent = "No provider execution summaries recorded for this job yet."; + return; + } + + executionList.className = "timeline-list fade-in"; + executionList.innerHTML = stages + .map((stage) => renderExecutionStage(jobId, stage)) + .join(""); + + executionList.querySelectorAll("[data-execution-artifact-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectArtifact( + jobId, + button.getAttribute("data-execution-artifact-id"), + [], + ).catch(handleError); + }); + }); +} + +function renderExecutionStage(jobId, stage) { + const tone = + stage.status === "failed" || stage.status === "blocked" + ? "danger" + : stage.status === "running" + ? "warn" + : "ready"; + const detailChips = [ + chip("status", stage.status, tone), + chip("provider", stage.provider_id || "n/a"), + chip("model", stage.model_id || "n/a"), + ]; + if (stage.duration_seconds !== null && stage.duration_seconds !== undefined) { + detailChips.push(chip("duration", `${stage.duration_seconds.toFixed(2)}s`)); + } + if (stage.finish_reason) { + detailChips.push(chip("finish", stage.finish_reason)); + } + if (stage.prompt_characters !== null && stage.prompt_characters !== undefined) { + detailChips.push(chip("prompt", `${stage.prompt_characters} chars`)); + } + if (stage.output_characters !== null && stage.output_characters !== undefined) { + detailChips.push(chip("output", `${stage.output_characters} chars`)); + } + if (stage.usage?.total_tokens !== null && stage.usage?.total_tokens !== undefined) { + detailChips.push(chip("tokens", stage.usage.total_tokens)); + } + + const artifactChips = (stage.artifact_types || []) + .map((artifactType, index) => renderExecutionArtifactLink(jobId, stage, artifactType, index)) + .join(""); + const preview = stage.output_preview + ? `
+        <pre class="execution-preview">${escapeHtml(stage.output_preview)}</pre>`
+    : "";
+  const healthReasons = (stage.health_reasons || []).length
+    ? `<ul class="execution-line">
+        ${stage.health_reasons.map((reason) => `<li>• ${escapeHtml(reason)}</li>`).join("")}
+      </ul>`
+    : "";
+  const error = stage.error_type
+    ? `<p class="execution-line">error: ${escapeHtml(stage.error_type)}</p>`
+    : "";
+  const timestamps = [
+    stage.started_at ? `started: ${new Date(stage.started_at).toLocaleString()}` : null,
+    stage.finished_at ? `finished: ${new Date(stage.finished_at).toLocaleString()}` : null,
+  ]
+    .filter(Boolean)
+    .join(" | ");
+  const requestKnobs = [
+    stage.temperature !== null && stage.temperature !== undefined
+      ? `temp ${stage.temperature}`
+      : null,
+    stage.max_output_tokens ? `max ${stage.max_output_tokens}` : null,
+    stage.timeout_seconds ? `timeout ${stage.timeout_seconds}s` : null,
+  ]
+    .filter(Boolean)
+    .join(" | ");
+
+  return `
+    <li>
+      <strong>${escapeHtml(stage.stage)}</strong>
+      <span>${detailChips.join("")}</span>
+      <div>${escapeHtml(timestamps || "timing unavailable")}</div>
+      <p>${escapeHtml(stage.last_message || "No execution message recorded.")}</p>
+      ${requestKnobs ? `<p class="execution-line">${escapeHtml(requestKnobs)}</p>` : ""}
+      ${artifactChips ? `<div class="execution-artifacts">${artifactChips}</div>` : ""}
+      ${error}
+      ${healthReasons}
+      ${preview}
+    </li>
+ `; +} + +function renderExecutionArtifactLink(jobId, stage, artifactType, index) { + const artifactId = stage.artifact_ids?.[index]; + if (!artifactId) { + return chip("artifact", artifactType); + } + + return ` + + `; +} + async function loadTimeline(jobId, stageEvents) { const response = stageEvents.length > 0 diff --git a/src/lai/cli.py b/src/lai/cli.py index 17b5a41..5f65ae0 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -12,6 +12,7 @@ from .domain import ExecutionRequest, QueueMode from .evals import run_route_eval_suite, save_route_eval_result from .layout import runtime_directories +from .observability import summarize_job_execution from .settings import Settings from .smoke import latest_smoke_result, run_smoke_suite, save_smoke_result from .system import collect_system_snapshot @@ -331,6 +332,30 @@ def show_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: ) console.print(timeline) + execution = summarize_job_execution(job) + if execution: + execution_table = Table(title="Execution Trace") + execution_table.add_column("Stage") + execution_table.add_column("Status") + execution_table.add_column("Provider") + execution_table.add_column("Model") + execution_table.add_column("Duration") + execution_table.add_column("Output") + for summary in execution: + execution_table.add_row( + summary.stage, + summary.status, + summary.provider_id or "n/a", + summary.model_id or "n/a", + ( + f"{summary.duration_seconds:.2f}s" + if summary.duration_seconds is not None + else "n/a" + ), + summary.output_preview or "n/a", + ) + console.print(execution_table) + @jobs_app.command("cancel") def cancel_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: diff --git a/src/lai/domain.py b/src/lai/domain.py index 83a20e8..89605cb 100644 --- a/src/lai/domain.py +++ b/src/lai/domain.py @@ -319,6 +319,30 @@ class StageEventRecord(LAIModel): metadata: dict[str, Any] = Field(default_factory=dict) +class StageExecutionSummary(LAIModel): + stage: str + status: str + model_id: str | None = None + provider_id: str | None = None + started_at: datetime | None = None + finished_at: datetime | None = None + duration_seconds: float | None = None + finish_reason: str | None = None + prompt_characters: int | None = None + system_prompt_characters: int | None = None + output_characters: int | None = None + output_preview: str | None = None + temperature: float | None = None + max_output_tokens: int | None = None + timeout_seconds: int | None = None + usage: UsageStats | None = None + artifact_ids: list[str] = Field(default_factory=list) + artifact_types: list[str] = Field(default_factory=list) + health_reasons: list[str] = Field(default_factory=list) + error_type: str | None = None + last_message: str | None = None + + class WorkerStateRecord(LAIModel): worker_id: str status: WorkerStatus = WorkerStatus.IDLE diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 18c1e54..04b0eb5 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -564,6 +564,14 @@ def _run_stage( model = self.config.model_catalog.get_model(model_id) provider = self.provider_registry.provider_for_model(model) provider_id = getattr(provider, "provider_id", model.provider_id) + request = ProviderRequest( + system_prompt=system_prompt, + user_prompt=user_prompt, + metadata={"job_id": job.id, "stage": stage}, + temperature=job.request.temperature, + max_output_tokens=job.request.max_output_tokens, + timeout_seconds=job.request.timeout_seconds, + ) self._record_event( job_id=job.id, 
stage=stage, @@ -571,6 +579,14 @@ def _run_stage( message=f"{stage.title()} stage started with model {model.id}.", model_id=model.id, provider_id=provider_id, + metadata={ + "temperature": request.temperature, + "max_output_tokens": request.max_output_tokens, + "timeout_seconds": request.timeout_seconds, + "prompt_characters": len(user_prompt), + "system_prompt_characters": len(system_prompt or ""), + "request_preview": _preview_text(user_prompt), + }, ) health = self.provider_registry.healthcheck(model) if not health.available: @@ -582,19 +598,14 @@ def _run_stage( message=f"{stage.title()} stage blocked: {reason}", model_id=model.id, provider_id=provider_id, - metadata={"health_reasons": health.reasons}, + metadata={ + "health_reasons": health.reasons, + "health_metadata": health.metadata, + }, ) raise RuntimeError( f"Model {model.id!r} is not executable: {reason}" ) - request = ProviderRequest( - system_prompt=system_prompt, - user_prompt=user_prompt, - metadata={"job_id": job.id, "stage": stage}, - temperature=job.request.temperature, - max_output_tokens=job.request.max_output_tokens, - timeout_seconds=job.request.timeout_seconds, - ) self.artifacts.write_json( job.id, f"{stage}-request", f"{stage}_request.json", request.model_dump() ) @@ -608,7 +619,10 @@ def _run_stage( message=f"{stage.title()} stage failed with {type(exc).__name__}.", model_id=model.id, provider_id=provider_id, - metadata={"error_type": type(exc).__name__}, + metadata={ + "error_type": type(exc).__name__, + "error_message": str(exc), + }, ) raise result.stage = stage @@ -626,6 +640,9 @@ def _run_stage( "duration_seconds": result.duration_seconds, "finish_reason": result.finish_reason, "usage": result.usage.model_dump() if result.usage else None, + "output_characters": len(result.text), + "output_preview": _preview_text(result.text), + "raw_keys": sorted(result.raw.keys()), }, ) if stage == "executor": @@ -695,3 +712,10 @@ def _enum_value(value: object) -> str: if isinstance(value, Enum): return str(value.value) return str(value) + + +def _preview_text(value: str, *, limit: int = 240) -> str: + normalized = " ".join(value.split()) + if len(normalized) <= limit: + return normalized + return normalized[: limit - 3] + "..." 
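The `_run_stage` changes above mean each stage's start and finish events now carry the request knobs, prompt sizes, durations, and output previews; the new `observability` module below only has to fold those events into per-stage summaries. A minimal sketch of a client walking that trace over HTTP, assuming the API from `src/lai/api/app.py` is running locally on port 8000 (the base URL and the use of `httpx` are assumptions, not part of this patch):

```python
# Hedged example: print a compact execution trace for one job.
import httpx


def print_execution_trace(job_id: str, base_url: str = "http://127.0.0.1:8000") -> None:
    # GET /jobs/{job_id}/execution returns {"stages": [...]} per the route above.
    response = httpx.get(f"{base_url}/jobs/{job_id}/execution", timeout=10.0)
    response.raise_for_status()
    for stage in response.json()["stages"]:
        duration = stage.get("duration_seconds")
        rendered = f"{duration:.2f}s" if duration is not None else "n/a"
        print(
            f"{stage['stage']:<9} {stage['status']:<10} "
            f"provider={stage.get('provider_id') or 'n/a'} duration={rendered}"
        )
```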
diff --git a/src/lai/observability.py b/src/lai/observability.py new file mode 100644 index 0000000..f1867c6 --- /dev/null +++ b/src/lai/observability.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from .domain import ArtifactRecord, JobRecord, StageEventRecord, StageExecutionSummary, UsageStats + +EXECUTION_STAGES = ("planner", "executor", "reviewer") + + +def summarize_job_execution(job: JobRecord) -> list[StageExecutionSummary]: + summaries: list[StageExecutionSummary] = [] + for stage in EXECUTION_STAGES: + stage_events = [event for event in job.stage_events if event.stage == stage] + if not stage_events: + continue + + started = _first_event(stage_events, "started") + completed = _last_event(stage_events, "completed") + failed = _last_event(stage_events, "failed") + blocked = _last_event(stage_events, "blocked") + final_event = completed or failed or blocked or stage_events[-1] + request_metadata = started.metadata if started is not None else {} + final_metadata = final_event.metadata if final_event is not None else {} + status = _resolved_status(started, completed, failed, blocked) + usage_payload = final_metadata.get("usage") + + stage_artifacts = _stage_artifacts(job.artifacts, stage) + summaries.append( + StageExecutionSummary( + stage=stage, + status=status, + model_id=(final_event.model_id if final_event is not None else None) + or (started.model_id if started is not None else None), + provider_id=(final_event.provider_id if final_event is not None else None) + or (started.provider_id if started is not None else None), + started_at=started.created_at if started is not None else None, + finished_at=( + final_event.created_at + if final_event is not None and status in {"completed", "failed", "blocked"} + else None + ), + duration_seconds=final_metadata.get("duration_seconds"), + finish_reason=final_metadata.get("finish_reason"), + prompt_characters=request_metadata.get("prompt_characters"), + system_prompt_characters=request_metadata.get("system_prompt_characters"), + output_characters=final_metadata.get("output_characters"), + output_preview=final_metadata.get("output_preview"), + temperature=request_metadata.get("temperature"), + max_output_tokens=request_metadata.get("max_output_tokens"), + timeout_seconds=request_metadata.get("timeout_seconds"), + usage=( + UsageStats.model_validate(usage_payload) + if isinstance(usage_payload, dict) + else None + ), + artifact_ids=[artifact.id for artifact in stage_artifacts], + artifact_types=[artifact.artifact_type for artifact in stage_artifacts], + health_reasons=list(final_metadata.get("health_reasons") or []), + error_type=final_metadata.get("error_type"), + last_message=final_event.message if final_event is not None else None, + ) + ) + + return summaries + + +def _stage_artifacts(artifacts: list[ArtifactRecord], stage: str) -> list[ArtifactRecord]: + target_types = {f"{stage}-request", f"{stage}-response"} + if stage == "reviewer": + target_types.add("reviewer-notes") + if stage == "executor": + target_types.add("final-output") + return [artifact for artifact in artifacts if artifact.artifact_type in target_types] + + +def _first_event(events: list[StageEventRecord], event_type: str) -> StageEventRecord | None: + for event in events: + if event.event_type == event_type: + return event + return None + + +def _last_event(events: list[StageEventRecord], event_type: str) -> StageEventRecord | None: + for event in reversed(events): + if event.event_type == event_type: + return event + return None + + +def _resolved_status( + 
started: StageEventRecord | None, + completed: StageEventRecord | None, + failed: StageEventRecord | None, + blocked: StageEventRecord | None, +) -> str: + if completed is not None: + return "completed" + if failed is not None: + return "failed" + if blocked is not None: + return "blocked" + if started is not None: + return "running" + return "pending" diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 0b38b1d..46f1194 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -52,6 +52,7 @@ def test_api_exposes_expected_routes() -> None: assert "/jobs" in routes assert "/jobs/{job_id}" in routes assert "/jobs/{job_id}/timeline" in routes + assert "/jobs/{job_id}/execution" in routes assert "/jobs/{job_id}/replay" in routes assert "/jobs/{job_id}/artifacts" in routes assert "/jobs/{job_id}/artifacts/{artifact_id}" in routes @@ -75,6 +76,7 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert "Smoke diagnostics" in dashboard_response.text assert asset_response.status_code == 200 assert "async function loadWorkstationReadiness" in asset_response.text + assert "async function loadExecution" in asset_response.text def test_workstation_readiness_route_returns_profiles(repo_root, tmp_path) -> None: @@ -142,6 +144,29 @@ def test_job_timeline_route_returns_stage_events(repo_root, tmp_path) -> None: ) +def test_job_execution_route_returns_stage_summaries(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + execution_response = client.get(f"/jobs/{job.id}/execution") + + assert execution_response.status_code == 200 + payload = execution_response.json() + assert payload["stages"] + executor = next(stage for stage in payload["stages"] if stage["stage"] == "executor") + assert executor["status"] == "completed" + assert executor["provider_id"] == "transformers" + assert executor["artifact_types"] + + def test_smoke_routes_run_and_return_latest_result(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) providers = _test_providers(settings) diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py new file mode 100644 index 0000000..15b2281 --- /dev/null +++ b/tests/unit/test_observability.py @@ -0,0 +1,58 @@ +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode +from lai.observability import summarize_job_execution +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = 
collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_summarize_job_execution_returns_stage_details(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + application = create_application(settings=settings, providers=_test_providers(settings)) + job = application.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a detailed implementation strategy for this platform.", + queue_mode=QueueMode.INLINE, + ) + ) + + summaries = summarize_job_execution(job) + + assert summaries + assert [summary.stage for summary in summaries] == ["planner", "executor", "reviewer"] + + executor = next(summary for summary in summaries if summary.stage == "executor") + assert executor.status == "completed" + assert executor.prompt_characters is not None + assert executor.output_preview + assert "final-output" in executor.artifact_types + + reviewer = next(summary for summary in summaries if summary.stage == "reviewer") + assert reviewer.status == "completed" + assert "reviewer-notes" in reviewer.artifact_types From bddcc28a3072896df51b0058a73b27c103507288 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 21:06:49 +0200 Subject: [PATCH 13/16] feat: add stale job recovery tooling --- README.md | 10 +- apps/web/README.md | 1 + src/lai/api/app.py | 26 ++++ src/lai/api/static/dashboard.css | 23 ++++ src/lai/api/static/dashboard.js | 72 ++++++++++- src/lai/api/static/index.html | 4 + src/lai/cli.py | 77 ++++++++++++ src/lai/domain.py | 13 ++ src/lai/jobs/service.py | 200 +++++++++++++++++++++++++++++-- src/lai/jobs/store.py | 43 ++++++- src/lai/observability.py | 73 ++++++++++- src/lai/settings.py | 1 + tests/unit/test_api.py | 54 ++++++++- tests/unit/test_recovery.py | 134 +++++++++++++++++++++ 14 files changed, 716 insertions(+), 15 deletions(-) create mode 100644 tests/unit/test_recovery.py diff --git a/README.md b/README.md index 2127d51..597e9ef 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,8 @@ python -m lai.cli workstation validate --profile airllm python -m lai.cli route explain "Summarize this note." python -m lai.cli run "Create a detailed implementation strategy." python -m lai.cli jobs list +python -m lai.cli jobs stale +python -m lai.cli jobs recover python -m lai.cli jobs replay --queue-mode queued python -m lai.cli worker run --once python -m lai.cli worker run --max-jobs 3 @@ -139,9 +141,11 @@ Available endpoints: - `POST /route/explain` - `POST /jobs` - `GET /jobs` +- `GET /jobs/stale` - `GET /jobs/{job_id}` - `GET /jobs/{job_id}/timeline` - `GET /jobs/{job_id}/execution` +- `POST /jobs/recover` - `POST /jobs/{job_id}/replay` - `GET /jobs/{job_id}/artifacts` - `GET /jobs/{job_id}/artifacts/{artifact_id}` @@ -152,8 +156,8 @@ The API now also serves a read-mostly dashboard at `/dashboard` with live model workstation readiness for heavy local execution, route explanation, recent job inspection, stage telemetry, provider execution summaries, artifact/trace browsing, replay actions, persisted worker monitoring, saved smoke diagnostics with history drill-down, -per-check metadata/output previews, and bounded queue worker controls including an -until-idle drain path. 
+per-check metadata/output previews, stale-job detection/recovery for interrupted runs, +and bounded queue worker controls including an until-idle drain path. ## Dedicated worker service @@ -223,7 +227,7 @@ For step-by-step remediation and overnight run guidance, see 1. Expand eval scenarios and richer reviewer/final-output refinement. 2. Add richer live provider execution visibility for real provider calls. 3. Add operator-facing controls for smoke result retention and cleanup over time. -4. Add richer heavy-job run telemetry and recovery tooling for interrupted overnight executions. +4. Add provider-specific live progress hooks where SDKs expose deeper execution signals. ## References diff --git a/apps/web/README.md b/apps/web/README.md index 9ac942d..c8e75d8 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -12,6 +12,7 @@ Current dashboard capabilities: - provider execution summaries with duration, output preview, and artifact linkage per stage - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns +- stale-job detection and one-click queue recovery for interrupted long-running jobs - persisted worker monitoring with heartbeat, current job, and queue depth - service-aware worker monitoring with daemon lock and stop-signal visibility - saved smoke diagnostics for provider and local readiness, with recent-run history browsing diff --git a/src/lai/api/app.py b/src/lai/api/app.py index 46c5f4d..fb18879 100644 --- a/src/lai/api/app.py +++ b/src/lai/api/app.py @@ -40,6 +40,11 @@ class JobReplayPayload(BaseModel): queue_mode: QueueMode | None = None +class JobRecoverPayload(BaseModel): + stale_after_seconds: int | None = Field(default=None, ge=1, le=86400) + limit: int = Field(default=20, ge=1, le=200) + + class WorkerRunPayload(BaseModel): max_jobs: int | None = Field(default=None, ge=1, le=25) resume_running: bool = False @@ -208,6 +213,27 @@ def list_jobs(limit: int = 20) -> dict[str, object]: app_state = application() return {"jobs": [job.model_dump() for job in app_state.job_store.list_jobs(limit=limit)]} + @api.get("/jobs/stale") + def list_stale_jobs( + stale_after_seconds: int | None = None, + limit: int = 20, + ) -> dict[str, object]: + app_state = application() + stale_jobs = app_state.orchestration.list_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + return {"jobs": [job.model_dump() for job in stale_jobs]} + + @api.post("/jobs/recover") + def recover_stale_jobs(payload: JobRecoverPayload) -> dict[str, object]: + app_state = application() + recovered = app_state.orchestration.recover_stale_jobs( + stale_after_seconds=payload.stale_after_seconds, + limit=payload.limit, + ) + return {"recovered": [job.model_dump() for job in recovered]} + @api.get("/jobs/{job_id}") def get_job(job_id: str) -> dict[str, object]: app_state = application() diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index d6e43a2..866dd1c 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -313,6 +313,29 @@ textarea { background: rgba(255, 255, 255, 0.58); } +.stale-job-list { + display: grid; + gap: 0.65rem; +} + +.stale-job-button { + border: 1px solid rgba(33, 63, 100, 0.12); + background: rgba(255, 255, 255, 0.72); + color: var(--ink); + border-radius: 1rem; + padding: 0.85rem 0.95rem; + text-align: left; + font: inherit; + cursor: pointer; + transition: transform 160ms ease, border-color 160ms ease, background 160ms ease; +} + +.stale-job-button:hover { + 
transform: translateY(-1px); + border-color: rgba(201, 109, 57, 0.28); + background: rgba(255, 255, 255, 0.9); +} + .chip { display: inline-flex; align-items: center; diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index ace7b26..3db7925 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -1,5 +1,6 @@ const state = { jobs: [], + staleJobs: [], models: [], worker: null, workstation: null, @@ -28,7 +29,9 @@ const elements = { runNextJob: document.getElementById("run-next-job"), runBatchJobs: document.getElementById("run-batch-jobs"), drainQueue: document.getElementById("drain-queue"), + recoverStaleJobs: document.getElementById("recover-stale-jobs"), queueActionStatus: document.getElementById("queue-action-status"), + staleJobSummary: document.getElementById("stale-job-summary"), workerSummary: document.getElementById("worker-summary"), workstationSummary: document.getElementById("workstation-summary"), workstationGrid: document.getElementById("workstation-grid"), @@ -119,16 +122,18 @@ function escapeHtml(value) { } async function loadOverview() { - const [health, modelsResponse, jobsResponse, workerResponse, workstationResponse] = await Promise.all([ + const [health, modelsResponse, jobsResponse, staleJobsResponse, workerResponse, workstationResponse] = await Promise.all([ fetchJson("/health"), fetchJson("/models"), fetchJson("/jobs?limit=12"), + fetchJson("/jobs/stale?limit=12"), fetchJson("/worker/status"), fetchJson("/workstation/readiness"), ]); state.models = modelsResponse.models; state.jobs = jobsResponse.jobs; + state.staleJobs = staleJobsResponse.jobs || []; state.worker = workerResponse; state.workstation = workstationResponse; @@ -136,6 +141,7 @@ async function loadOverview() { elements.modelCount.textContent = String(health.model_count); elements.jobCount.textContent = String(state.jobs.length); elements.queueState.textContent = summarizeQueue(state.jobs); + renderStaleJobs(); renderWorkerMonitor(); renderWorkstationReadiness(); await loadSmokeDiagnostics(); @@ -150,6 +156,48 @@ async function loadOverview() { } } +function renderStaleJobs() { + if (!elements.staleJobSummary) { + return; + } + + if (!state.staleJobs || state.staleJobs.length === 0) { + elements.staleJobSummary.className = "worker-summary empty"; + elements.staleJobSummary.textContent = "No stale running jobs detected."; + return; + } + + const cards = state.staleJobs + .map( + (job) => ` + + `, + ) + .join(""); + + elements.staleJobSummary.className = "worker-summary fade-in"; + elements.staleJobSummary.innerHTML = ` +
+    <div>
+      ${chip("stale", state.staleJobs.length, "danger")}
+      ${chip("threshold", `${state.staleJobs[0].stale_after_seconds}s`, "warn")}
+    </div>
+    <div class="stale-job-list">${cards}</div>
+ `; + + elements.staleJobSummary.querySelectorAll("[data-stale-job-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectJob(button.getAttribute("data-stale-job-id")).catch(handleError); + }); + }); +} + async function loadWorkstationReadiness() { state.workstation = await fetchJson("/workstation/readiness"); renderWorkstationReadiness(); @@ -1068,6 +1116,25 @@ async function runWorker(maxJobs, untilIdle = false) { } } +async function recoverStaleJobs() { + const response = await fetchJson("/jobs/recover", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ limit: 20 }), + }); + const recovered = response.recovered || []; + setQueueActionStatus( + recovered.length > 0 + ? `Recovered ${recovered.length} stale job(s) back into the queue.` + : "No stale jobs needed recovery.", + recovered.length > 0 ? "warn" : "ready", + ); + await loadOverview(); + if (recovered.length > 0) { + await selectJob(recovered[0].id); + } +} + async function runSmokeSuite(suiteId, includeAirllm = false) { const response = await fetchJson("/smoke/run", { method: "POST", @@ -1109,6 +1176,9 @@ function bindEvents() { elements.drainQueue.addEventListener("click", () => { void runWorker(null, true).catch(handleQueueError); }); + elements.recoverStaleJobs.addEventListener("click", () => { + void recoverStaleJobs().catch(handleQueueError); + }); elements.refreshWorkstation.addEventListener("click", () => { void loadWorkstationReadiness().catch(handleError); }); diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html index cac937d..1fafa90 100644 --- a/src/lai/api/static/index.html +++ b/src/lai/api/static/index.html @@ -123,11 +123,15 @@

        <h2>Recent jobs</h2>
+       <button id="recover-stale-jobs" type="button">Recover stale jobs</button>
        <p id="queue-action-status">Worker idle. Use the queue controls to process queued jobs.</p>
+       <div id="stale-job-summary" class="worker-summary empty">Checking stale jobs...</div>
        <div id="worker-summary" class="worker-summary empty">Loading worker monitor...</div>
        <ul id="job-list" class="empty">No jobs yet.</ul>
diff --git a/src/lai/cli.py b/src/lai/cli.py index 5f65ae0..2e0fa9d 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -390,6 +390,83 @@ def replay_job( ) +@jobs_app.command("stale") +def list_stale_jobs( + stale_after_seconds: int | None = typer.Option( + None, + min=1, + help="Override the stale-job timeout in seconds.", + ), + limit: int = typer.Option(20, min=1, max=200, help="Maximum stale jobs to show."), +) -> None: + """List stale running jobs that likely need recovery.""" + application = create_application() + stale_jobs = application.orchestration.list_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + if not stale_jobs: + console.print("No stale running jobs detected.") + return + + table = Table(title="Stale Jobs") + table.add_column("Job") + table.add_column("Stage") + table.add_column("Executor") + table.add_column("Age") + table.add_column("Last Event") + table.add_column("Message") + for job in stale_jobs: + table.add_row( + job.job_id, + job.current_stage or "n/a", + job.executor_model_id or "n/a", + f"{job.age_seconds:.2f}s", + job.last_event_type or "n/a", + job.last_message or "n/a", + ) + console.print(table) + + +@jobs_app.command("recover") +def recover_stale_jobs( + stale_after_seconds: int | None = typer.Option( + None, + min=1, + help="Override the stale-job timeout in seconds.", + ), + limit: int = typer.Option( + 20, + min=1, + max=200, + help="Maximum stale jobs to recover in one pass.", + ), +) -> None: + """Recover stale running jobs by returning them to the queue.""" + application = create_application() + recovered = application.orchestration.recover_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + if not recovered: + console.print("No stale jobs were recovered.") + return + + table = Table(title="Recovered Jobs") + table.add_column("Job") + table.add_column("Status") + table.add_column("Queue") + table.add_column("Executor") + for job in recovered: + table.add_row( + job.id, + job.status, + job.queue_mode, + job.route_decision.executor_model_id if job.route_decision else "n/a", + ) + console.print(table) + + @worker_app.command("run") def run_worker( once: bool = typer.Option(False, help="Process at most one queued job and then exit."), diff --git a/src/lai/domain.py b/src/lai/domain.py index 89605cb..1e2e301 100644 --- a/src/lai/domain.py +++ b/src/lai/domain.py @@ -343,6 +343,19 @@ class StageExecutionSummary(LAIModel): last_message: str | None = None +class StaleJobSummary(LAIModel): + job_id: str + queue_mode: QueueMode + attempts: int + executor_model_id: str | None = None + current_stage: str | None = None + last_event_type: str | None = None + last_message: str | None = None + last_activity_at: datetime + stale_after_seconds: int + age_seconds: float + + class WorkerStateRecord(LAIModel): worker_id: str status: WorkerStatus = WorkerStatus.IDLE diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 04b0eb5..0ed169f 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading import time from enum import Enum from uuid import uuid4 @@ -15,11 +16,13 @@ QueueMode, RoutingDecision, StageEventRecord, + StaleJobSummary, WorkerStateRecord, WorkerStatus, utcnow, ) from ..errors import RetryableProviderError +from ..observability import summarize_stale_jobs from ..providers import ProviderRegistry from ..routing import RoutingEngine from ..settings import Settings @@ -152,7 +155,7 @@ def cancel_job(self, job_id: str) -> 
JobRecord | None: ) return self.job_store.get_job(job_id) - def execute_job(self, job_id: str) -> JobRecord: + def execute_job(self, job_id: str, *, worker_id: str | None = None) -> JobRecord: job = self.job_store.get_job(job_id) if job is None: raise KeyError(f"Unknown job id {job_id!r}.") @@ -174,7 +177,7 @@ def execute_job(self, job_id: str) -> JobRecord: ) try: - final_text = self._run_pipeline(job) + final_text = self._run_pipeline(job, worker_id=worker_id) if job.result: job.result = job.result.model_copy(update={"text": final_text}) job.status = JobStatus.SUCCEEDED @@ -220,6 +223,50 @@ def execute_job(self, job_id: str) -> JobRecord: self.artifacts.write_text(job.id, "error", "error.txt", str(exc)) return self.job_store.get_job(job.id) or job + def list_stale_jobs( + self, + *, + stale_after_seconds: int | None = None, + limit: int = 100, + ) -> list[StaleJobSummary]: + threshold = stale_after_seconds or self.settings.stale_running_job_timeout_seconds + running_jobs = self.job_store.list_jobs_by_status(JobStatus.RUNNING, limit=limit) + return summarize_stale_jobs( + running_jobs, + stale_after_seconds=threshold, + ) + + def recover_stale_jobs( + self, + *, + stale_after_seconds: int | None = None, + limit: int = 100, + ) -> list[JobRecord]: + stale_jobs = self.list_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + recovered: list[JobRecord] = [] + for stale_job in stale_jobs: + if not self.job_store.requeue_job(stale_job.job_id): + continue + self._record_event( + job_id=stale_job.job_id, + stage="job", + event_type="recovered", + message="Recovered a stale running job and returned it to the queue.", + metadata={ + "age_seconds": stale_job.age_seconds, + "stale_after_seconds": stale_job.stale_after_seconds, + "last_event_type": stale_job.last_event_type, + "current_stage": stale_job.current_stage, + }, + ) + refreshed = self.job_store.get_job(stale_job.job_id) + if refreshed is not None: + recovered.append(refreshed) + return recovered + def run_worker_batch( self, *, @@ -262,7 +309,7 @@ def run_worker_batch( mode=mode, queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), ) - processed_job = self.execute_job(job.id) + processed_job = self.execute_job(job.id, worker_id=resolved_worker_id) processed.append(processed_job) self._update_worker_state( worker_id=resolved_worker_id, @@ -369,7 +416,7 @@ def run_worker_until_idle( mode=mode, metadata={"phase": "processing"}, ) - processed_job = self.execute_job(job.id) + processed_job = self.execute_job(job.id, worker_id=resolved_worker_id) processed += 1 self._update_worker_state( worker_id=resolved_worker_id, @@ -471,7 +518,7 @@ def run_worker( mode=mode, metadata={"phase": "processing"}, ) - self.execute_job(job.id) + self.execute_job(job.id, worker_id=self.default_worker_id) processed += 1 if once: self._update_worker_state( @@ -496,7 +543,7 @@ def run_worker( metadata={"phase": "monitoring"}, ) - def _run_pipeline(self, job: JobRecord) -> str: + def _run_pipeline(self, job: JobRecord, *, worker_id: str | None = None) -> str: assert job.route_decision is not None planning_output = "" @@ -511,6 +558,7 @@ def _run_pipeline(self, job: JobRecord) -> str: "Keep it concise and implementation-focused." 
), user_prompt=job.request.user_prompt, + worker_id=worker_id, ) executor_prompt = job.request.user_prompt @@ -528,6 +576,7 @@ def _run_pipeline(self, job: JobRecord) -> str: model_id=job.route_decision.executor_model_id, system_prompt=job.request.system_prompt, user_prompt=executor_prompt, + worker_id=worker_id, ) if job.route_decision.should_review and job.route_decision.reviewer_model_id: @@ -545,6 +594,7 @@ def _run_pipeline(self, job: JobRecord) -> str: "Candidate answer:\n" f"{executor_output}" ), + worker_id=worker_id, ) self.artifacts.write_text( job.id, "reviewer-notes", "reviewer_notes.txt", reviewer_output @@ -560,6 +610,7 @@ def _run_stage( model_id: str, system_prompt: str | None, user_prompt: str, + worker_id: str | None = None, ) -> str: model = self.config.model_catalog.get_model(model_id) provider = self.provider_registry.provider_for_model(model) @@ -588,6 +639,14 @@ def _run_stage( "request_preview": _preview_text(user_prompt), }, ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + phase="processing", + ) health = self.provider_registry.healthcheck(model) if not health.available: reason = "; ".join(health.reasons) or "unavailable" @@ -610,7 +669,14 @@ def _run_stage( job.id, f"{stage}-request", f"{stage}_request.json", request.model_dump() ) try: - result = provider.generate(model, request) + with self._stage_heartbeat( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + ): + result = provider.generate(model, request) except Exception as exc: self._record_event( job_id=job.id, @@ -624,6 +690,15 @@ def _run_stage( "error_message": str(exc), }, ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + phase="error", + extra_metadata={"error_type": type(exc).__name__}, + ) raise result.stage = stage self.artifacts.write_json( @@ -645,12 +720,115 @@ def _run_stage( "raw_keys": sorted(result.raw.keys()), }, ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + phase="processing", + extra_metadata={ + "last_completed_stage": stage, + "last_stage_duration_seconds": result.duration_seconds, + "last_stage_finish_reason": result.finish_reason, + }, + ) if stage == "executor": job.result = result job.updated_at = utcnow() self.job_store.save_job(job) return result.text + def _mark_stage_worker_state( + self, + *, + worker_id: str | None, + job_id: str, + stage: str, + model_id: str, + provider_id: str, + phase: str, + extra_metadata: dict[str, object] | None = None, + ) -> None: + if worker_id is None: + return + current_state = self.get_worker_state(worker_id) + metadata = dict(current_state.metadata) + metadata.update( + { + "phase": phase, + "current_stage": stage, + "current_model_id": model_id, + "current_provider_id": provider_id, + "last_stage_heartbeat_at": utcnow().isoformat(), + } + ) + if extra_metadata: + metadata.update(extra_metadata) + self._update_worker_state( + worker_id=worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job_id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=current_state.mode, + metadata=metadata, + ) + + def _stage_heartbeat( + self, + *, + worker_id: str | None, + job_id: str, + stage: str, + model_id: str, + provider_id: str, + ): + interval = self.settings.job_heartbeat_interval_seconds + if interval <= 
0: + return _NullContext() + + service = self + stop_event = threading.Event() + started_at = utcnow() + + class _StageHeartbeatContext: + def __enter__(self_inner): + def _heartbeat_loop() -> None: + while not stop_event.wait(interval): + elapsed_seconds = round((utcnow() - started_at).total_seconds(), 2) + service._record_event( + job_id=job_id, + stage=stage, + event_type="heartbeat", + message=f"{stage.title()} stage is still running.", + model_id=model_id, + provider_id=provider_id, + metadata={"elapsed_seconds": elapsed_seconds}, + ) + service._mark_stage_worker_state( + worker_id=worker_id, + job_id=job_id, + stage=stage, + model_id=model_id, + provider_id=provider_id, + phase="processing", + extra_metadata={"elapsed_seconds": elapsed_seconds}, + ) + + self_inner.thread = threading.Thread( + target=_heartbeat_loop, + name=f"lai-stage-heartbeat-{job_id}-{stage}", + daemon=True, + ) + self_inner.thread.start() + return self_inner + + def __exit__(self_inner, *_args: object) -> None: + stop_event.set() + self_inner.thread.join(timeout=interval + 1) + + return _StageHeartbeatContext() + def _record_event( self, *, @@ -719,3 +897,11 @@ def _preview_text(value: str, *, limit: int = 240) -> str: if len(normalized) <= limit: return normalized return normalized[: limit - 3] + "..." + + +class _NullContext: + def __enter__(self): + return self + + def __exit__(self, *_args: object) -> None: + return None diff --git a/src/lai/jobs/store.py b/src/lai/jobs/store.py index 0b330d5..8e76afb 100644 --- a/src/lai/jobs/store.py +++ b/src/lai/jobs/store.py @@ -234,6 +234,19 @@ def list_jobs(self, limit: int = 20) -> list[JobRecord]: ).fetchall() return [self._deserialize_job(connection, row) for row in rows] + def list_jobs_by_status( + self, + status: JobStatus, + *, + limit: int = 100, + ) -> list[JobRecord]: + with self._connect() as connection: + rows = connection.execute( + "SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC LIMIT ?", + (status, limit), + ).fetchall() + return [self._deserialize_job(connection, row) for row in rows] + def count_jobs_by_status(self, status: JobStatus) -> int: with self._connect() as connection: row = connection.execute( @@ -289,12 +302,38 @@ def claim_next_queued_job(self) -> JobRecord | None: def requeue_running_jobs(self) -> int: with self._connect() as connection: + now = utcnow().isoformat() result = connection.execute( - "UPDATE jobs SET status = ?, updated_at = ? WHERE status = ?", - (JobStatus.QUEUED, utcnow().isoformat(), JobStatus.RUNNING), + """ + UPDATE jobs + SET status = ?, + updated_at = ?, + queued_at = ?, + started_at = NULL, + finished_at = NULL + WHERE status = ? + """, + (JobStatus.QUEUED, now, now, JobStatus.RUNNING), ) return result.rowcount + def requeue_job(self, job_id: str) -> bool: + with self._connect() as connection: + now = utcnow().isoformat() + result = connection.execute( + """ + UPDATE jobs + SET status = ?, + updated_at = ?, + queued_at = ?, + started_at = NULL, + finished_at = NULL + WHERE id = ? AND status = ? 
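+                  -- Guard on status: only RUNNING jobs are flipped back to QUEUED;
+                  -- succeeded, failed, or canceled jobs are never resurrected.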
+ """, + (JobStatus.QUEUED, now, now, job_id, JobStatus.RUNNING), + ) + return result.rowcount > 0 + def _deserialize_job(self, connection: sqlite3.Connection, row: sqlite3.Row) -> JobRecord: artifacts = self._artifacts_for_job(connection, row["id"]) stage_events = self._stage_events_for_job(connection, row["id"]) diff --git a/src/lai/observability.py b/src/lai/observability.py index f1867c6..0b2f99b 100644 --- a/src/lai/observability.py +++ b/src/lai/observability.py @@ -1,8 +1,29 @@ from __future__ import annotations -from .domain import ArtifactRecord, JobRecord, StageEventRecord, StageExecutionSummary, UsageStats +from datetime import datetime, timezone + +from .domain import ( + ArtifactRecord, + JobRecord, + StageEventRecord, + StageExecutionSummary, + StaleJobSummary, + UsageStats, +) EXECUTION_STAGES = ("planner", "executor", "reviewer") +STALE_ACTIVITY_EVENT_TYPES = { + "running", + "started", + "heartbeat", + "completed", + "failed", + "blocked", + "retry-queued", + "recovered", + "succeeded", + "canceled", +} def summarize_job_execution(job: JobRecord) -> list[StageExecutionSummary]: @@ -62,6 +83,56 @@ def summarize_job_execution(job: JobRecord) -> list[StageExecutionSummary]: return summaries +def summarize_stale_jobs( + jobs: list[JobRecord], + *, + stale_after_seconds: int, + now: datetime | None = None, +) -> list[StaleJobSummary]: + now = now or datetime.now(tz=timezone.utc) + summaries: list[StaleJobSummary] = [] + for job in jobs: + activity_events = [ + event for event in job.stage_events if event.event_type in STALE_ACTIVITY_EVENT_TYPES + ] + last_event = activity_events[-1] if activity_events else ( + job.stage_events[-1] if job.stage_events else None + ) + last_activity = latest_job_activity(job, stage_events=activity_events) + age_seconds = (now - last_activity).total_seconds() + if age_seconds < stale_after_seconds: + continue + summaries.append( + StaleJobSummary( + job_id=job.id, + queue_mode=job.queue_mode, + attempts=job.attempts, + executor_model_id=( + job.route_decision.executor_model_id if job.route_decision else None + ), + current_stage=last_event.stage if last_event is not None else None, + last_event_type=last_event.event_type if last_event is not None else None, + last_message=last_event.message if last_event is not None else None, + last_activity_at=last_activity, + stale_after_seconds=stale_after_seconds, + age_seconds=round(age_seconds, 2), + ) + ) + return summaries + + +def latest_job_activity( + job: JobRecord, + *, + stage_events: list[StageEventRecord] | None = None, +) -> datetime: + candidates = [job.updated_at, job.started_at, job.finished_at] + events = stage_events if stage_events is not None else job.stage_events + candidates.extend(event.created_at for event in events) + activity = [candidate for candidate in candidates if candidate is not None] + return max(activity) if activity else job.created_at + + def _stage_artifacts(artifacts: list[ArtifactRecord], stage: str) -> list[ArtifactRecord]: target_types = {f"{stage}-request", f"{stage}-response"} if stage == "reviewer": diff --git a/src/lai/settings.py b/src/lai/settings.py index 059ed41..c33390f 100644 --- a/src/lai/settings.py +++ b/src/lai/settings.py @@ -49,6 +49,7 @@ class Settings(BaseSettings): queue_poll_interval_seconds: int = 5 worker_idle_sleep_seconds: float = 2.0 worker_service_poll_interval_seconds: float = 10.0 + job_heartbeat_interval_seconds: float = 30.0 max_retry_attempts: int = 1 stale_running_job_timeout_seconds: int = 900 diff --git a/tests/unit/test_api.py 
b/tests/unit/test_api.py index 46f1194..4213bfd 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,8 +1,11 @@ +from datetime import timedelta +from uuid import uuid4 + from fastapi.testclient import TestClient from lai.api import create_api from lai.application import create_application -from lai.domain import ExecutionRequest, QueueMode +from lai.domain import ExecutionRequest, QueueMode, StageEventRecord, utcnow from lai.settings import Settings from lai.system import collect_system_snapshot from tests.helpers import FakeProvider @@ -50,6 +53,8 @@ def test_api_exposes_expected_routes() -> None: assert "/worker/status" in routes assert "/route/explain" in routes assert "/jobs" in routes + assert "/jobs/stale" in routes + assert "/jobs/recover" in routes assert "/jobs/{job_id}" in routes assert "/jobs/{job_id}/timeline" in routes assert "/jobs/{job_id}/execution" in routes @@ -77,6 +82,7 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert asset_response.status_code == 200 assert "async function loadWorkstationReadiness" in asset_response.text assert "async function loadExecution" in asset_response.text + assert "async function recoverStaleJobs" in asset_response.text def test_workstation_readiness_route_returns_profiles(repo_root, tmp_path) -> None: @@ -167,6 +173,52 @@ def test_job_execution_route_returns_stage_summaries(repo_root, tmp_path) -> Non assert executor["artifact_types"] +def test_stale_job_routes_list_and_recover_jobs(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + stale_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Recover this interrupted overnight execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = seeded_app.job_store.claim_next_queued_job() + assert running is not None + stale_at = utcnow() - timedelta(minutes=30) + running.updated_at = stale_at + running.started_at = stale_at + seeded_app.job_store.save_job(running) + seeded_app.job_store.add_stage_event( + StageEventRecord( + id=str(uuid4()), + job_id=stale_job.id, + stage="executor", + event_type="started", + message="Executor stage started long ago.", + created_at=stale_at, + model_id=running.route_decision.executor_model_id if running.route_decision else None, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + stale_response = client.get("/jobs/stale?stale_after_seconds=60") + assert stale_response.status_code == 200 + stale_payload = stale_response.json() + assert stale_payload["jobs"] + assert stale_payload["jobs"][0]["job_id"] == stale_job.id + + recover_response = client.post( + "/jobs/recover", + json={"stale_after_seconds": 60, "limit": 10}, + ) + assert recover_response.status_code == 200 + recovered = recover_response.json()["recovered"] + assert recovered + assert recovered[0]["id"] == stale_job.id + assert recovered[0]["status"] == "queued" + + def test_smoke_routes_run_and_return_latest_result(repo_root, tmp_path) -> None: settings = _test_settings(repo_root, tmp_path) providers = _test_providers(settings) diff --git a/tests/unit/test_recovery.py b/tests/unit/test_recovery.py new file mode 100644 index 0000000..48b28bb --- /dev/null +++ b/tests/unit/test_recovery.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import time +from datetime import timedelta +from uuid import uuid4 + +from lai.application import 
create_application +from lai.domain import ExecutionRequest, QueueMode, StageEventRecord, utcnow +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +class SlowFakeProvider(FakeProvider): + def __init__(self, *args, delay_seconds: float = 0.05, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.delay_seconds = delay_seconds + + def generate(self, model, request): + time.sleep(self.delay_seconds) + return super().generate(model, request) + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def _test_providers( + settings: Settings, + *, + slow_transformers: bool = False, +) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + transformers_provider: FakeProvider + if slow_transformers: + transformers_provider = SlowFakeProvider( + settings, + snapshot, + "transformers", + delay_seconds=0.05, + ) + else: + transformers_provider = FakeProvider(settings, snapshot, "transformers") + return { + "transformers": transformers_provider, + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_slow_stage_persists_heartbeat_events(repo_root, tmp_path) -> None: + settings = _test_settings( + repo_root, + tmp_path, + job_heartbeat_interval_seconds=0.01, + ) + application = create_application( + settings=settings, + providers=_test_providers(settings, slow_transformers=True), + ) + + job = application.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + assert any( + event.stage == "executor" and event.event_type == "heartbeat" + for event in job.stage_events + ) + + +def test_stale_job_detection_and_recovery(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + application = create_application( + settings=settings, + providers=_test_providers(settings), + ) + job = application.orchestration.submit_request( + ExecutionRequest( + user_prompt="Recover this interrupted overnight execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = application.job_store.claim_next_queued_job() + assert running is not None + + stale_at = utcnow() - timedelta(minutes=20) + running.updated_at = stale_at + running.started_at = stale_at + application.job_store.save_job(running) + application.job_store.add_stage_event( + StageEventRecord( + id=str(uuid4()), + job_id=job.id, + stage="executor", + event_type="started", + message="Executor stage started long ago.", + created_at=stale_at, + model_id=running.route_decision.executor_model_id if running.route_decision else None, + ) + ) + + stale_jobs = application.orchestration.list_stale_jobs(stale_after_seconds=60) + recovered = 
application.orchestration.recover_stale_jobs(stale_after_seconds=60) + refreshed = application.job_store.get_job(job.id) + + assert stale_jobs + assert stale_jobs[0].job_id == job.id + assert recovered + assert refreshed is not None + assert refreshed.status == "queued" + assert refreshed.started_at is None + assert any( + event.stage == "job" and event.event_type == "recovered" + for event in refreshed.stage_events + ) From d2fb5cdebbef323f56b5c43b50cbf3b31d218508 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 21:14:10 +0200 Subject: [PATCH 14/16] feat: add provider progress telemetry --- README.md | 7 +- apps/web/README.md | 2 +- src/lai/api/static/dashboard.css | 5 + src/lai/api/static/dashboard.js | 49 ++++++- src/lai/cli.py | 10 ++ src/lai/domain.py | 12 ++ src/lai/evals.py | 16 ++- src/lai/jobs/service.py | 76 ++++++++++- src/lai/observability.py | 28 +++- src/lai/providers/base.py | 43 +++++- src/lai/providers/implementations.py | 169 +++++++++++++++++++++++- tests/helpers.py | 31 ++++- tests/integration/test_orchestration.py | 4 + tests/unit/test_api.py | 8 ++ tests/unit/test_observability.py | 4 + tests/unit/test_recovery.py | 4 +- 16 files changed, 443 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 597e9ef..411e926 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,8 @@ workstation readiness for heavy local execution, route explanation, recent job i stage telemetry, provider execution summaries, artifact/trace browsing, replay actions, persisted worker monitoring, saved smoke diagnostics with history drill-down, per-check metadata/output previews, stale-job detection/recovery for interrupted runs, -and bounded queue worker controls including an until-idle drain path. +bounded queue worker controls including an until-idle drain path, and live provider +progress phases for long-running execution stages. ## Dedicated worker service @@ -225,9 +226,9 @@ For step-by-step remediation and overnight run guidance, see ## Near-term priorities 1. Expand eval scenarios and richer reviewer/final-output refinement. -2. Add richer live provider execution visibility for real provider calls. +2. Add restart-aware recovery policies for interrupted daemon and overnight jobs. 3. Add operator-facing controls for smoke result retention and cleanup over time. -4. Add provider-specific live progress hooks where SDKs expose deeper execution signals. +4. Add partial-output capture where provider SDKs expose safe non-streaming response chunks. 
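For reviewers skimming this series, here is a minimal sketch of the progress-telemetry contract the patch introduces. It relies only on the `ProviderProgressUpdate` model and the callback shape visible in the diffs; the `report` helper and the sample field values are illustrative assumptions, not code from the patch.

```python
# Editor's sketch, assuming only the ProviderProgressUpdate fields shown in
# this patch (phase, message, optional progress in [0, 1], metadata dict).
from lai.domain import ProviderProgressUpdate


def report(update: ProviderProgressUpdate) -> None:
    # The job service installs a callback of this shape per stage; it persists
    # each update as a "progress" stage event and mirrors it into the worker
    # state metadata that the CLI and dashboard read.
    percent = f"{update.progress * 100:.0f}%" if update.progress is not None else "n/a"
    print(f"[{update.phase}] {update.message} ({percent})")


report(
    ProviderProgressUpdate(
        phase="generating",
        message="Generating text with the executor model.",
        progress=0.55,
        metadata={"prompt_characters": 42},
    )
)
```

Providers never construct the callback themselves; they call `Provider._emit_progress(...)`, which no-ops when no callback was supplied, so runs without telemetry stay cheap.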
## References diff --git a/apps/web/README.md b/apps/web/README.md index c8e75d8..fc9a129 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -9,7 +9,7 @@ Current dashboard capabilities: - route explanation form - job submission and recent queue inspection - persisted stage telemetry timeline for planner, executor, and reviewer flow -- provider execution summaries with duration, output preview, and artifact linkage per stage +- provider execution summaries with duration, live progress phases, output preview, and artifact linkage per stage - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns - stale-job detection and one-click queue recovery for interrupted long-running jobs diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css index 866dd1c..e98119a 100644 --- a/src/lai/api/static/dashboard.css +++ b/src/lai/api/static/dashboard.css @@ -525,6 +525,11 @@ textarea { margin-top: 0.55rem; } +.timeline-detail, +.worker-detail-line { + margin-top: 0.45rem; +} + .execution-artifacts { margin-top: 0.65rem; } diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index 3db7925..6ab0bd5 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -121,6 +121,13 @@ function escapeHtml(value) { .replaceAll("'", "'"); } +function formatProgressPercent(value) { + if (value === null || value === undefined || Number.isNaN(Number(value))) { + return null; + } + return `${Math.round(Number(value) * 100)}%`; +} + async function loadOverview() { const [health, modelsResponse, jobsResponse, staleJobsResponse, workerResponse, workstationResponse] = await Promise.all([ fetchJson("/health"), @@ -575,6 +582,11 @@ function renderWorkerMonitor() { : "n/a"; const startedAt = worker.started_at ? new Date(worker.started_at).toLocaleString() : "n/a"; const phase = worker.metadata?.phase ? `phase: ${worker.metadata.phase}` : "phase: n/a"; + const currentStage = worker.metadata?.current_stage || "n/a"; + const currentStagePhase = worker.metadata?.current_stage_phase || "n/a"; + const currentStageProgress = + formatProgressPercent(worker.metadata?.current_stage_progress) || "n/a"; + const currentStageMessage = worker.metadata?.current_stage_message || ""; const logFile = worker.log_file ? `log: ${worker.log_file}` : "log: n/a"; const stopSignal = worker.stop_signal_present ? '
<div class="worker-detail-line">stop signal: present</div>
' : ""; const error = worker.last_error @@ -585,6 +597,8 @@ function renderWorkerMonitor() { elements.workerSummary.innerHTML = `
${chips.join("")}
<div class="worker-detail-line">heartbeat: ${escapeHtml(heartbeat)} | started: ${escapeHtml(startedAt)} | ${escapeHtml(phase)}</div>
+ <div class="worker-detail-line">stage: ${escapeHtml(currentStage)} | provider phase: ${escapeHtml(currentStagePhase)} | progress: ${escapeHtml(currentStageProgress)}</div>
+ ${currentStageMessage ? `<div class="worker-detail-line">${escapeHtml(currentStageMessage)}</div>` : ""}
<div class="worker-detail-line">${escapeHtml(logFile)}</div>
${stopSignal} ${error} @@ -817,9 +831,17 @@ function renderExecutionStage(jobId, stage) { if (stage.output_characters !== null && stage.output_characters !== undefined) { detailChips.push(chip("output", `${stage.output_characters} chars`)); } + if (stage.progress_event_count) { + detailChips.push(chip("progress", stage.progress_event_count, "warn")); + } if (stage.usage?.total_tokens !== null && stage.usage?.total_tokens !== undefined) { detailChips.push(chip("tokens", stage.usage.total_tokens)); } + if (stage.latest_progress_percent !== null && stage.latest_progress_percent !== undefined) { + detailChips.push( + chip("latest", formatProgressPercent(stage.latest_progress_percent), "warn"), + ); + } const artifactChips = (stage.artifact_types || []) .map((artifactType, index) => renderExecutionArtifactLink(jobId, stage, artifactType, index)) @@ -852,6 +874,17 @@ function renderExecutionStage(jobId, stage) { ] .filter(Boolean) .join(" | "); + const latestProgress = [ + stage.latest_progress_phase ? `phase ${stage.latest_progress_phase}` : null, + stage.latest_progress_percent !== null && stage.latest_progress_percent !== undefined + ? formatProgressPercent(stage.latest_progress_percent) + : null, + stage.latest_progress_at + ? `updated ${new Date(stage.latest_progress_at).toLocaleString()}` + : null, + ] + .filter(Boolean) + .join(" | "); return `
@@ -864,6 +897,8 @@ function renderExecutionStage(jobId, stage) {
<div>${escapeHtml(stage.last_message || "No execution message recorded.")}</div>
${requestKnobs ? `<div>${escapeHtml(requestKnobs)}</div>` : ""}
+ ${latestProgress ? `<div>${escapeHtml(latestProgress)}</div>` : ""}
+ ${stage.latest_progress_message ? `<div>${escapeHtml(stage.latest_progress_message)}</div>` : ""}
${artifactChips ? `<div class="execution-artifacts">${artifactChips}</div>
` : ""} ${error} ${healthReasons} @@ -913,7 +948,7 @@ async function loadTimeline(jobId, stageEvents) { const tone = event.event_type === "failed" || event.event_type === "blocked" ? "danger" - : event.event_type === "running" || event.event_type === "started" + : event.event_type === "running" || event.event_type === "started" || event.event_type === "progress" ? "warn" : "ready"; const detailChips = [ @@ -926,6 +961,17 @@ async function loadTimeline(jobId, stageEvents) { if (event.provider_id) { detailChips.push(chip("provider", event.provider_id)); } + const metadataLine = [ + event.metadata?.phase ? `phase ${event.metadata.phase}` : null, + event.metadata?.progress !== null && event.metadata?.progress !== undefined + ? formatProgressPercent(event.metadata.progress) + : null, + event.metadata?.elapsed_seconds !== null && event.metadata?.elapsed_seconds !== undefined + ? `elapsed ${event.metadata.elapsed_seconds}s` + : null, + ] + .filter(Boolean) + .join(" | "); return `
@@ -933,6 +979,7 @@ async function loadTimeline(jobId, stageEvents) {
<div>${escapeHtml(new Date(event.created_at).toLocaleString())}</div>
<div>${escapeHtml(event.message)}</div>
+ ${metadataLine ? `<div class="timeline-detail">${escapeHtml(metadataLine)}</div>` : ""}
`; }) diff --git a/src/lai/cli.py b/src/lai/cli.py index 2e0fa9d..75bea62 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -340,8 +340,17 @@ def show_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: execution_table.add_column("Provider") execution_table.add_column("Model") execution_table.add_column("Duration") + execution_table.add_column("Progress") execution_table.add_column("Output") for summary in execution: + progress = "n/a" + if summary.latest_progress_phase or summary.latest_progress_percent is not None: + parts = [] + if summary.latest_progress_phase: + parts.append(summary.latest_progress_phase) + if summary.latest_progress_percent is not None: + parts.append(f"{summary.latest_progress_percent * 100:.0f}%") + progress = " / ".join(parts) execution_table.add_row( summary.stage, summary.status, @@ -352,6 +361,7 @@ def show_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: if summary.duration_seconds is not None else "n/a" ), + progress, summary.output_preview or "n/a", ) console.print(execution_table) diff --git a/src/lai/domain.py b/src/lai/domain.py index 1e2e301..8166aec 100644 --- a/src/lai/domain.py +++ b/src/lai/domain.py @@ -273,6 +273,13 @@ class ProviderRequest(LAIModel): model_override: str | None = None +class ProviderProgressUpdate(LAIModel): + phase: str + message: str + progress: float | None = Field(default=None, ge=0, le=1) + metadata: dict[str, Any] = Field(default_factory=dict) + + class ProviderHealth(LAIModel): provider_id: str available: bool @@ -341,6 +348,11 @@ class StageExecutionSummary(LAIModel): health_reasons: list[str] = Field(default_factory=list) error_type: str | None = None last_message: str | None = None + progress_event_count: int = 0 + latest_progress_message: str | None = None + latest_progress_phase: str | None = None + latest_progress_at: datetime | None = None + latest_progress_percent: float | None = None class StaleJobSummary(LAIModel): diff --git a/src/lai/evals.py b/src/lai/evals.py index 361b2f4..9c64ca1 100644 --- a/src/lai/evals.py +++ b/src/lai/evals.py @@ -17,7 +17,7 @@ ProviderHealth, ProviderRequest, ) -from .providers.base import Provider +from .providers.base import ProgressCallback, Provider from .settings import Settings from .system import collect_system_snapshot @@ -139,7 +139,19 @@ def healthcheck(self, model: ModelSpec) -> ProviderHealth: ) return ProviderHealth(provider_id=self.provider_id, available=True, healthy=True) - def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResult: + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ) -> ExecutionResult: + self._emit_progress( + progress_callback, + phase="evaluate", + message=f"Running evaluation provider for {model.id}.", + progress=0.5, + ) if model.role == "router": text = '{"tier_id": "standard"}' else: diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 0ed169f..9ae8234 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -12,6 +12,7 @@ JobError, JobRecord, JobStatus, + ProviderProgressUpdate, ProviderRequest, QueueMode, RoutingDecision, @@ -646,6 +647,11 @@ def _run_stage( model_id=model.id, provider_id=provider_id, phase="processing", + extra_metadata={ + "current_stage_phase": "starting", + "current_stage_progress": 0.0, + "current_stage_message": f"{stage.title()} stage started.", + }, ) health = self.provider_registry.healthcheck(model) if not health.available: @@ -676,7 +682,17 
@@ def _run_stage( model_id=model.id, provider_id=provider_id, ): - result = provider.generate(model, request) + result = provider.generate( + model, + request, + progress_callback=self._provider_progress_callback( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + ), + ) except Exception as exc: self._record_event( job_id=job.id, @@ -697,7 +713,11 @@ def _run_stage( model_id=model.id, provider_id=provider_id, phase="error", - extra_metadata={"error_type": type(exc).__name__}, + extra_metadata={ + "error_type": type(exc).__name__, + "current_stage_phase": "error", + "current_stage_message": str(exc), + }, ) raise result.stage = stage @@ -731,6 +751,9 @@ def _run_stage( "last_completed_stage": stage, "last_stage_duration_seconds": result.duration_seconds, "last_stage_finish_reason": result.finish_reason, + "current_stage_phase": "completed", + "current_stage_progress": 1.0, + "current_stage_message": f"{stage.title()} stage completed.", }, ) if stage == "executor": @@ -774,6 +797,50 @@ def _mark_stage_worker_state( metadata=metadata, ) + def _provider_progress_callback( + self, + *, + worker_id: str | None, + job_id: str, + stage: str, + model_id: str, + provider_id: str, + ): + def _callback(update: ProviderProgressUpdate) -> None: + metadata = dict(update.metadata) + metadata["phase"] = update.phase + if update.progress is not None: + metadata["progress"] = update.progress + self._record_event( + job_id=job_id, + stage=stage, + event_type="progress", + message=update.message, + model_id=model_id, + provider_id=provider_id, + metadata=metadata, + ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job_id, + stage=stage, + model_id=model_id, + provider_id=provider_id, + phase="processing", + extra_metadata={ + "current_stage_phase": update.phase, + "current_stage_message": update.message, + "current_stage_updated_at": utcnow().isoformat(), + **( + {"current_stage_progress": update.progress} + if update.progress is not None + else {} + ), + }, + ) + + return _callback + def _stage_heartbeat( self, *, @@ -812,7 +879,10 @@ def _heartbeat_loop() -> None: model_id=model_id, provider_id=provider_id, phase="processing", - extra_metadata={"elapsed_seconds": elapsed_seconds}, + extra_metadata={ + "elapsed_seconds": elapsed_seconds, + "current_stage_updated_at": utcnow().isoformat(), + }, ) self_inner.thread = threading.Thread( diff --git a/src/lai/observability.py b/src/lai/observability.py index 0b2f99b..344594c 100644 --- a/src/lai/observability.py +++ b/src/lai/observability.py @@ -15,6 +15,7 @@ STALE_ACTIVITY_EVENT_TYPES = { "running", "started", + "progress", "heartbeat", "completed", "failed", @@ -33,6 +34,8 @@ def summarize_job_execution(job: JobRecord) -> list[StageExecutionSummary]: if not stage_events: continue + progress_events = [event for event in stage_events if event.event_type == "progress"] + latest_progress = progress_events[-1] if progress_events else None started = _first_event(stage_events, "started") completed = _last_event(stage_events, "completed") failed = _last_event(stage_events, "failed") @@ -76,7 +79,30 @@ def summarize_job_execution(job: JobRecord) -> list[StageExecutionSummary]: artifact_types=[artifact.artifact_type for artifact in stage_artifacts], health_reasons=list(final_metadata.get("health_reasons") or []), error_type=final_metadata.get("error_type"), - last_message=final_event.message if final_event is not None else None, + last_message=( + latest_progress.message + if status == "running" and 
latest_progress is not None + else (final_event.message if final_event is not None else None) + ), + progress_event_count=len(progress_events), + latest_progress_message=( + latest_progress.message if latest_progress is not None else None + ), + latest_progress_phase=( + str(latest_progress.metadata.get("phase")) + if latest_progress is not None + and latest_progress.metadata.get("phase") is not None + else None + ), + latest_progress_at=( + latest_progress.created_at if latest_progress is not None else None + ), + latest_progress_percent=( + float(latest_progress.metadata["progress"]) + if latest_progress is not None + and latest_progress.metadata.get("progress") is not None + else None + ), ) ) diff --git a/src/lai/providers/base.py b/src/lai/providers/base.py index a154656..83362be 100644 --- a/src/lai/providers/base.py +++ b/src/lai/providers/base.py @@ -2,12 +2,21 @@ from abc import ABC, abstractmethod from time import perf_counter -from typing import Any - -from ..domain import ExecutionResult, FinishReason, ModelSpec, ProviderHealth, ProviderRequest +from typing import Any, Callable + +from ..domain import ( + ExecutionResult, + FinishReason, + ModelSpec, + ProviderHealth, + ProviderProgressUpdate, + ProviderRequest, +) from ..settings import Settings from ..system import SystemSnapshot +ProgressCallback = Callable[[ProviderProgressUpdate], None] + def serialize_raw_object(value: Any) -> dict[str, Any]: if hasattr(value, "model_dump"): @@ -38,7 +47,13 @@ def available(self, model: ModelSpec) -> ProviderHealth: return self.healthcheck(model) @abstractmethod - def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResult: + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ) -> ExecutionResult: raise NotImplementedError @abstractmethod @@ -91,3 +106,23 @@ def _ok_health(self, provider_id: str, **metadata: Any) -> ProviderHealth: def _timer(self) -> float: return perf_counter() + + def _emit_progress( + self, + progress_callback: ProgressCallback | None, + *, + phase: str, + message: str, + progress: float | None = None, + **metadata: Any, + ) -> None: + if progress_callback is None: + return + progress_callback( + ProviderProgressUpdate( + phase=phase, + message=message, + progress=progress, + metadata=metadata, + ) + ) diff --git a/src/lai/providers/implementations.py b/src/lai/providers/implementations.py index 3d984e7..354d0c3 100644 --- a/src/lai/providers/implementations.py +++ b/src/lai/providers/implementations.py @@ -13,7 +13,7 @@ ) from ..settings import Settings from ..system import available_disk_gb -from .base import Provider, serialize_raw_object +from .base import ProgressCallback, Provider, serialize_raw_object class TransformersProvider(Provider): @@ -41,12 +41,25 @@ def describe_capabilities(self) -> dict[str, Any]: "requires_network_once": True, } - def generate(self, model: ModelSpec, request: ProviderRequest): + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): from transformers import pipeline started = perf_counter() pipe = self._pipeline_cache.get(model.id) if pipe is None: + self._emit_progress( + progress_callback, + phase="loading-pipeline", + message=f"Loading Transformers pipeline for {model.id}.", + progress=0.1, + cached=False, + ) pipe = pipeline( "text-generation", model=model.model_ref, @@ -59,14 +72,36 @@ def generate(self, model: ModelSpec, request: ProviderRequest): }, 
) self._pipeline_cache[model.id] = pipe + else: + self._emit_progress( + progress_callback, + phase="pipeline-ready", + message=f"Reusing cached Transformers pipeline for {model.id}.", + progress=0.2, + cached=True, + ) prompt = _render_prompt(request.system_prompt, request.user_prompt) + self._emit_progress( + progress_callback, + phase="generating", + message=f"Generating text with {model.id}.", + progress=0.55, + prompt_characters=len(prompt), + ) result = pipe( prompt, max_new_tokens=request.max_output_tokens, temperature=request.temperature, return_full_text=False, ) + self._emit_progress( + progress_callback, + phase="parsing-output", + message=f"Parsing model output for {model.id}.", + progress=0.9, + result_count=len(result), + ) text = result[0]["generated_text"].strip() if result else "" duration = perf_counter() - started return self._result( @@ -132,20 +167,48 @@ def describe_capabilities(self) -> dict[str, Any]: "supports_cpu_fallback": True, } - def generate(self, model: ModelSpec, request: ProviderRequest): + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): from airllm import AutoModel started = perf_counter() loaded_model = self._model_cache.get(model.id) if loaded_model is None: + self._emit_progress( + progress_callback, + phase="loading-shards", + message=f"Loading AirLLM model shards for {model.id}.", + progress=0.1, + cached=False, + ) loaded_model = AutoModel.from_pretrained( model.model_ref, hf_token=self.settings.huggingface_token_value, layer_shards_saving_path=str(self.settings.resolved_airllm_shards_dir / model.id), ) self._model_cache[model.id] = loaded_model + else: + self._emit_progress( + progress_callback, + phase="model-ready", + message=f"Reusing cached AirLLM model for {model.id}.", + progress=0.2, + cached=True, + ) prompt = _render_prompt(request.system_prompt, request.user_prompt) + self._emit_progress( + progress_callback, + phase="tokenizing", + message=f"Tokenizing prompt for {model.id}.", + progress=0.35, + prompt_characters=len(prompt), + ) tokenized = loaded_model.tokenizer( [prompt], return_tensors="pt", @@ -157,14 +220,33 @@ def generate(self, model: ModelSpec, request: ProviderRequest): input_ids = tokenized["input_ids"] if self.system_snapshot.has_gpu: + self._emit_progress( + progress_callback, + phase="transferring-input", + message=f"Moving prompt tensors to GPU for {model.id}.", + progress=0.45, + ) input_ids = input_ids.cuda() + self._emit_progress( + progress_callback, + phase="generating", + message=f"Generating response with AirLLM model {model.id}.", + progress=0.7, + ) generation = loaded_model.generate( input_ids, max_new_tokens=request.max_output_tokens, use_cache=True, return_dict_in_generate=True, ) + self._emit_progress( + progress_callback, + phase="decoding", + message=f"Decoding generated tokens for {model.id}.", + progress=0.9, + generated_token_count=len(generation.sequences[0]), + ) decoded = loaded_model.tokenizer.decode(generation.sequences[0], skip_special_tokens=True) text = _strip_prompt_prefix(decoded, prompt) duration = perf_counter() - started @@ -193,11 +275,30 @@ def healthcheck(self, model: ModelSpec) -> ProviderHealth: def describe_capabilities(self) -> dict[str, Any]: return {"supports_remote_generation": True, "supports_streaming": False} - def generate(self, model: ModelSpec, request: ProviderRequest): + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = 
None, + ): from openai import OpenAI started = perf_counter() + self._emit_progress( + progress_callback, + phase="creating-client", + message=f"Creating OpenAI client for {model.id}.", + progress=0.1, + ) client = OpenAI(api_key=self.settings.openai_api_key_value, timeout=request.timeout_seconds) + self._emit_progress( + progress_callback, + phase="request-sent", + message=f"Sending OpenAI request for {model.id}.", + progress=0.45, + prompt_characters=len(request.user_prompt), + ) response = client.responses.create( model=request.model_override or model.model_ref, instructions=request.system_prompt or None, @@ -205,6 +306,12 @@ def generate(self, model: ModelSpec, request: ProviderRequest): max_output_tokens=request.max_output_tokens, temperature=request.temperature, ) + self._emit_progress( + progress_callback, + phase="response-received", + message=f"OpenAI response received for {model.id}.", + progress=0.9, + ) duration = perf_counter() - started usage = getattr(response, "usage", None) usage_stats = ( @@ -242,13 +349,32 @@ def healthcheck(self, model: ModelSpec) -> ProviderHealth: def describe_capabilities(self) -> dict[str, Any]: return {"supports_remote_generation": True, "supports_streaming": False} - def generate(self, model: ModelSpec, request: ProviderRequest): + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): from anthropic import Anthropic started = perf_counter() + self._emit_progress( + progress_callback, + phase="creating-client", + message=f"Creating Anthropic client for {model.id}.", + progress=0.1, + ) client = Anthropic( api_key=self.settings.anthropic_api_key_value, timeout=request.timeout_seconds ) + self._emit_progress( + progress_callback, + phase="request-sent", + message=f"Sending Anthropic request for {model.id}.", + progress=0.45, + prompt_characters=len(request.user_prompt), + ) response = client.messages.create( model=request.model_override or model.model_ref, system=request.system_prompt or "", @@ -256,6 +382,12 @@ def generate(self, model: ModelSpec, request: ProviderRequest): max_tokens=request.max_output_tokens, temperature=request.temperature, ) + self._emit_progress( + progress_callback, + phase="response-received", + message=f"Anthropic response received for {model.id}.", + progress=0.9, + ) text_parts = [ block.text for block in response.content if getattr(block, "type", None) == "text" ] @@ -296,10 +428,22 @@ def healthcheck(self, model: ModelSpec) -> ProviderHealth: def describe_capabilities(self) -> dict[str, Any]: return {"supports_remote_generation": True, "supports_streaming": False} - def generate(self, model: ModelSpec, request: ProviderRequest): + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): from google import genai started = perf_counter() + self._emit_progress( + progress_callback, + phase="creating-client", + message=f"Creating Gemini client for {model.id}.", + progress=0.1, + ) client = genai.Client(api_key=self.settings.gemini_api_key_value) config: dict[str, Any] = { "temperature": request.temperature, @@ -307,11 +451,24 @@ def generate(self, model: ModelSpec, request: ProviderRequest): } if request.system_prompt: config["system_instruction"] = request.system_prompt + self._emit_progress( + progress_callback, + phase="request-sent", + message=f"Sending Gemini request for {model.id}.", + progress=0.45, + prompt_characters=len(request.user_prompt), + ) response = 
client.models.generate_content( model=request.model_override or model.model_ref, contents=request.user_prompt, config=config, ) + self._emit_progress( + progress_callback, + phase="response-received", + message=f"Gemini response received for {model.id}.", + progress=0.9, + ) duration = perf_counter() - started return self._result( text=(getattr(response, "text", "") or "").strip(), diff --git a/tests/helpers.py b/tests/helpers.py index 72fefb7..998cf16 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -3,7 +3,7 @@ from typing import Any from lai.domain import ExecutionResult, FinishReason, ModelSpec, ProviderHealth, ProviderRequest -from lai.providers.base import Provider +from lai.providers.base import ProgressCallback, Provider from lai.settings import Settings from lai.system import SystemSnapshot @@ -31,8 +31,28 @@ def healthcheck(self, model: ModelSpec) -> ProviderHealth: reasons=[f"{self.provider_id} unavailable in test harness"], ) - def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResult: + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ) -> ExecutionResult: stage = str(request.metadata.get("stage", "executor")) + self._emit_progress( + progress_callback, + phase="prepare", + message=f"{self.provider_id} preparing {stage} stage.", + progress=0.2, + stage=stage, + ) + self._emit_progress( + progress_callback, + phase="generate", + message=f"{self.provider_id} generating {stage} stage output.", + progress=0.7, + stage=stage, + ) if stage == "planner": text = "plan: inspect, execute, verify" elif stage == "reviewer": @@ -43,6 +63,13 @@ def generate(self, model: ModelSpec, request: ProviderRequest) -> ExecutionResul text = '{"tier_id": "standard"}' else: text = f"generated by {model.id}: {request.user_prompt}" + self._emit_progress( + progress_callback, + phase="finalize", + message=f"{self.provider_id} finalized {stage} stage output.", + progress=0.95, + stage=stage, + ) return ExecutionResult( text=text, finish_reason=FinishReason.STOP, diff --git a/tests/integration/test_orchestration.py b/tests/integration/test_orchestration.py index 9bb4cd0..1da621f 100644 --- a/tests/integration/test_orchestration.py +++ b/tests/integration/test_orchestration.py @@ -48,6 +48,10 @@ def test_inline_execution_persists_artifacts(tmp_path, repo_root) -> None: event.stage == "executor" and event.event_type == "completed" for event in job.stage_events ) + assert any( + event.stage == "executor" and event.event_type == "progress" + for event in job.stage_events + ) assert any( event.stage == "job" and event.event_type == "succeeded" for event in job.stage_events diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 4213bfd..7c650ee 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -83,6 +83,7 @@ def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: assert "async function loadWorkstationReadiness" in asset_response.text assert "async function loadExecution" in asset_response.text assert "async function recoverStaleJobs" in asset_response.text + assert "function formatProgressPercent" in asset_response.text def test_workstation_readiness_route_returns_profiles(repo_root, tmp_path) -> None: @@ -148,6 +149,10 @@ def test_job_timeline_route_returns_stage_events(repo_root, tmp_path) -> None: event["stage"] == "executor" and event["event_type"] == "completed" for event in payload["stage_events"] ) + assert any( + event["stage"] == "executor" 
and event["event_type"] == "progress" + for event in payload["stage_events"] + ) def test_job_execution_route_returns_stage_summaries(repo_root, tmp_path) -> None: @@ -171,6 +176,9 @@ def test_job_execution_route_returns_stage_summaries(repo_root, tmp_path) -> Non assert executor["status"] == "completed" assert executor["provider_id"] == "transformers" assert executor["artifact_types"] + assert executor["progress_event_count"] > 0 + assert executor["latest_progress_phase"] + assert executor["latest_progress_message"] def test_stale_job_routes_list_and_recover_jobs(repo_root, tmp_path) -> None: diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py index 15b2281..47648bc 100644 --- a/tests/unit/test_observability.py +++ b/tests/unit/test_observability.py @@ -52,7 +52,11 @@ def test_summarize_job_execution_returns_stage_details(repo_root, tmp_path) -> N assert executor.prompt_characters is not None assert executor.output_preview assert "final-output" in executor.artifact_types + assert executor.progress_event_count > 0 + assert executor.latest_progress_phase + assert executor.latest_progress_message reviewer = next(summary for summary in summaries if summary.stage == "reviewer") assert reviewer.status == "completed" assert "reviewer-notes" in reviewer.artifact_types + assert reviewer.progress_event_count > 0 diff --git a/tests/unit/test_recovery.py b/tests/unit/test_recovery.py index 48b28bb..d03400f 100644 --- a/tests/unit/test_recovery.py +++ b/tests/unit/test_recovery.py @@ -16,9 +16,9 @@ def __init__(self, *args, delay_seconds: float = 0.05, **kwargs) -> None: super().__init__(*args, **kwargs) self.delay_seconds = delay_seconds - def generate(self, model, request): + def generate(self, model, request, *, progress_callback=None): time.sleep(self.delay_seconds) - return super().generate(model, request) + return super().generate(model, request, progress_callback=progress_callback) def _test_settings(repo_root, tmp_path, **overrides) -> Settings: From 11f258dd276ecae29a22afc418212f01b678dc32 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 21:22:02 +0200 Subject: [PATCH 15/16] feat: add daemon restart recovery --- README.md | 6 +- apps/web/README.md | 2 +- src/lai/api/static/dashboard.js | 9 ++ src/lai/cli.py | 11 ++ src/lai/jobs/service.py | 56 +++++++++-- src/lai/settings.py | 2 + src/lai/worker/service.py | 161 ++++++++++++++++++++++++++++-- tests/unit/test_worker_service.py | 143 +++++++++++++++++++++++--- 8 files changed, 361 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 411e926..30c57a1 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,8 @@ The service: - acquires a lock at `data/state/worker-service.lock` - watches for a graceful stop signal at `data/state/worker-service.stop` - writes service logs to `logs/worker-service.log` +- recovers interrupted running jobs on startup when a stale lock is explicitly replaced +- auto-recovers stale running jobs during daemon cycles using the configured recovery timeout - keeps the persisted worker state fresh for the CLI, API, and dashboard To stop it gracefully: @@ -226,8 +228,8 @@ For step-by-step remediation and overnight run guidance, see ## Near-term priorities 1. Expand eval scenarios and richer reviewer/final-output refinement. -2. Add restart-aware recovery policies for interrupted daemon and overnight jobs. -3. Add operator-facing controls for smoke result retention and cleanup over time. +2. 
Add operator-facing controls for smoke result retention and cleanup over time. +3. Add worker-service lock inspection and supervised restart tooling. 4. Add partial-output capture where provider SDKs expose safe non-streaming response chunks. ## References diff --git a/apps/web/README.md b/apps/web/README.md index fc9a129..7f6092e 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -13,7 +13,7 @@ Current dashboard capabilities: - artifact and trace browsing for persisted jobs - job replay controls for inline and queued reruns - stale-job detection and one-click queue recovery for interrupted long-running jobs -- persisted worker monitoring with heartbeat, current job, and queue depth +- persisted worker monitoring with heartbeat, current job, queue depth, and latest daemon recovery summary - service-aware worker monitoring with daemon lock and stop-signal visibility - saved smoke diagnostics for provider and local readiness, with recent-run history browsing - per-check smoke metadata and live-output preview inspection from saved results diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js index 6ab0bd5..83098cb 100644 --- a/src/lai/api/static/dashboard.js +++ b/src/lai/api/static/dashboard.js @@ -587,11 +587,18 @@ function renderWorkerMonitor() { const currentStageProgress = formatProgressPercent(worker.metadata?.current_stage_progress) || "n/a"; const currentStageMessage = worker.metadata?.current_stage_message || ""; + const lastRecovery = worker.metadata?.last_recovery || null; const logFile = worker.log_file ? `log: ${worker.log_file}` : "log: n/a"; const stopSignal = worker.stop_signal_present ? '
<div class="worker-detail-line">stop signal: present</div>
' : ""; const error = worker.last_error ? `
<div class="worker-detail-line">error: ${escapeHtml(worker.last_error)}</div>
` : ""; + const recoveryLine = lastRecovery + ? `last recovery: ${lastRecovery.count || 0} job(s) via ${lastRecovery.reason || "n/a"} at ${lastRecovery.at || "n/a"}` + : ""; + const recoveryJobs = Array.isArray(lastRecovery?.job_ids) && lastRecovery.job_ids.length > 0 + ? `recovered jobs: ${lastRecovery.job_ids.map((jobId) => String(jobId).slice(0, 8)).join(", ")}` + : ""; elements.workerSummary.className = "worker-summary fade-in"; elements.workerSummary.innerHTML = ` @@ -599,6 +606,8 @@ function renderWorkerMonitor() {
<div class="worker-detail-line">heartbeat: ${escapeHtml(heartbeat)} | started: ${escapeHtml(startedAt)} | ${escapeHtml(phase)}</div>
<div class="worker-detail-line">stage: ${escapeHtml(currentStage)} | provider phase: ${escapeHtml(currentStagePhase)} | progress: ${escapeHtml(currentStageProgress)}</div>
${currentStageMessage ? `<div class="worker-detail-line">${escapeHtml(currentStageMessage)}</div>` : ""}
+ ${recoveryLine ? `<div class="worker-detail-line">${escapeHtml(recoveryLine)}</div>` : ""}
+ ${recoveryJobs ? `<div class="worker-detail-line">${escapeHtml(recoveryJobs)}</div>` : ""}
<div class="worker-detail-line">${escapeHtml(logFile)}</div>
${stopSignal} ${error} diff --git a/src/lai/cli.py b/src/lai/cli.py index 75bea62..5cdb308 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -556,6 +556,17 @@ def worker_status() -> None: ) table.add_row("Service log", str(application.settings.resolved_worker_service_log_path)) table.add_row("Last error", worker.last_error or "none") + last_recovery = worker.metadata.get("last_recovery") + if isinstance(last_recovery, dict): + recovered_at = str(last_recovery.get("at", "n/a")) + recovered_count = str(last_recovery.get("count", 0)) + recovered_reason = str(last_recovery.get("reason", "n/a")) + recovered_jobs = ", ".join(str(job_id) for job_id in last_recovery.get("job_ids", [])) + table.add_row( + "Last recovery", + f"{recovered_count} job(s) via {recovered_reason} at {recovered_at}", + ) + table.add_row("Recovered jobs", recovered_jobs or "none") console.print(table) diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py index 9ae8234..19ca935 100644 --- a/src/lai/jobs/service.py +++ b/src/lai/jobs/service.py @@ -249,21 +249,40 @@ def recover_stale_jobs( ) recovered: list[JobRecord] = [] for stale_job in stale_jobs: - if not self.job_store.requeue_job(stale_job.job_id): - continue - self._record_event( + refreshed = self._recover_job( job_id=stale_job.job_id, - stage="job", - event_type="recovered", message="Recovered a stale running job and returned it to the queue.", metadata={ + "recovery_reason": "stale-running-job", "age_seconds": stale_job.age_seconds, "stale_after_seconds": stale_job.stale_after_seconds, "last_event_type": stale_job.last_event_type, "current_stage": stale_job.current_stage, }, ) - refreshed = self.job_store.get_job(stale_job.job_id) + if refreshed is not None: + recovered.append(refreshed) + return recovered + + def recover_running_jobs( + self, + *, + limit: int = 100, + message: str = "Recovered an interrupted running job and returned it to the queue.", + metadata: dict[str, object] | None = None, + ) -> list[JobRecord]: + running_jobs = self.job_store.list_jobs_by_status(JobStatus.RUNNING, limit=limit) + recovered: list[JobRecord] = [] + for job in running_jobs: + refreshed = self._recover_job( + job_id=job.id, + message=message, + metadata={ + "recovery_reason": "running-job-recovery", + "attempts": job.attempts, + **(metadata or {}), + }, + ) if refreshed is not None: recovered.append(refreshed) return recovered @@ -923,6 +942,24 @@ def _record_event( self.job_store.add_stage_event(event) return event + def _recover_job( + self, + *, + job_id: str, + message: str, + metadata: dict[str, object] | None = None, + ) -> JobRecord | None: + if not self.job_store.requeue_job(job_id): + return None + self._record_event( + job_id=job_id, + stage="job", + event_type="recovered", + message=message, + metadata=metadata, + ) + return self.job_store.get_job(job_id) + def _update_worker_state( self, *, @@ -931,6 +968,13 @@ def _update_worker_state( ) -> WorkerStateRecord: current = self.get_worker_state(worker_id) payload = current.model_dump() + if ( + isinstance(payload.get("metadata"), dict) + and isinstance(updates.get("metadata"), dict) + ): + merged_metadata = dict(payload["metadata"]) + merged_metadata.update(updates["metadata"]) + updates = {**updates, "metadata": merged_metadata} payload.update(updates) now = utcnow() payload["heartbeat_at"] = now diff --git a/src/lai/settings.py b/src/lai/settings.py index c33390f..f7962c4 100644 --- a/src/lai/settings.py +++ b/src/lai/settings.py @@ -49,6 +49,8 @@ class Settings(BaseSettings): queue_poll_interval_seconds: 
int = 5 worker_idle_sleep_seconds: float = 2.0 worker_service_poll_interval_seconds: float = 10.0 + worker_service_recovery_stale_after_seconds: int = 120 + worker_service_recovery_limit: int = 100 job_heartbeat_interval_seconds: float = 30.0 max_retry_attempts: int = 1 stale_running_job_timeout_seconds: int = 900 diff --git a/src/lai/worker/service.py b/src/lai/worker/service.py index 8e1deeb..024c86d 100644 --- a/src/lai/worker/service.py +++ b/src/lai/worker/service.py @@ -9,7 +9,7 @@ from typing import Callable, Mapping from ..application import LAIApplication, create_application -from ..domain import JobStatus, WorkerStatus +from ..domain import JobStatus, WorkerStatus, utcnow from ..layout import ensure_runtime_directories from ..providers import Provider from ..settings import Settings @@ -25,6 +25,15 @@ class WorkerServiceConfig: replace_existing_lock: bool = False +def read_worker_service_lock_metadata(lock_path: Path) -> dict[str, object] | None: + if not lock_path.exists(): + return None + try: + return json.loads(lock_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None + + def request_worker_service_stop(settings: Settings | None = None) -> Path: settings = settings or Settings() stop_path = settings.resolved_worker_service_stop_path @@ -48,6 +57,8 @@ def __init__(self, lock_path: Path, *, replace_existing: bool = False) -> None: self.lock_path = lock_path self.replace_existing = replace_existing self.acquired = False + self.replaced_metadata: dict[str, object] | None = None + self.current_metadata: dict[str, object] | None = None def __enter__(self) -> "WorkerServiceLock": self.acquire() @@ -59,6 +70,7 @@ def __exit__(self, *_args: object) -> None: def acquire(self) -> None: self.lock_path.parent.mkdir(parents=True, exist_ok=True) if self.replace_existing and self.lock_path.exists(): + self.replaced_metadata = read_worker_service_lock_metadata(self.lock_path) self.lock_path.unlink() try: descriptor = os.open( @@ -66,9 +78,20 @@ def acquire(self) -> None: os.O_CREAT | os.O_EXCL | os.O_WRONLY, ) except FileExistsError as exc: + existing = read_worker_service_lock_metadata(self.lock_path) + details = "" + if existing: + pid = existing.get("pid") + created_at = existing.get("created_at") + if isinstance(created_at, (int, float)): + age_seconds = max(time.time() - created_at, 0.0) + details = f" Existing lock pid={pid}, age={age_seconds:.1f}s." + else: + details = f" Existing lock pid={pid}." raise RuntimeError( f"Worker service lock already exists at {self.lock_path}. " "Use replace_existing_lock only when you are certain the prior worker is gone." 
+ f"{details}" ) from exc payload = { @@ -78,6 +101,7 @@ def acquire(self) -> None: with os.fdopen(descriptor, "w", encoding="utf-8") as handle: json.dump(payload, handle) self.acquired = True + self.current_metadata = payload def release(self) -> None: if self.acquired: @@ -115,16 +139,27 @@ def run(self) -> int: cycles = 0 clear_worker_service_stop(self.settings) - with WorkerServiceLock( + lock = WorkerServiceLock( self.settings.resolved_worker_service_lock_path, replace_existing=self.config.replace_existing_lock, - ): + ) + with lock: self.logger.info("LAI worker service started for worker_id=%s", self.config.worker_id) + app = self.application_factory() + previous_state = app.orchestration.get_worker_state(self.config.worker_id) + recovery_summary = self._recover_startup_jobs( + app=app, + previous_state=previous_state, + lock=lock, + ) self._mark_service_state( status=WorkerStatus.IDLE, phase="starting", processed_jobs=0, - metadata={"service_cycles": 0}, + metadata={ + "service_cycles": 0, + **({"last_recovery": recovery_summary} if recovery_summary else {}), + }, ) try: while True: @@ -139,10 +174,11 @@ def run(self) -> int: break app = self.application_factory() + cycle_recovery = self._recover_cycle_stale_jobs(app, cycle=cycles + 1) processed = app.orchestration.run_worker_until_idle( max_idle_cycles=self.config.max_idle_cycles_per_drain, max_jobs=self.config.max_jobs_per_cycle, - requeue_running=(cycles == 0), + requeue_running=False, worker_id=self.config.worker_id, mode="daemon", ) @@ -161,6 +197,11 @@ def run(self) -> int: metadata={ "service_cycles": cycles, "last_cycle_processed": processed, + **( + {"last_recovery": cycle_recovery} + if cycle_recovery is not None + else {} + ), }, ) @@ -188,6 +229,99 @@ def run(self) -> int: self.logger.info("LAI worker service exited after processing %s job(s).", total_processed) return total_processed + def _recover_startup_jobs( + self, + *, + app: LAIApplication, + previous_state, + lock: WorkerServiceLock, + ) -> dict[str, object] | None: + running_jobs = app.job_store.list_jobs_by_status( + JobStatus.RUNNING, + limit=self.settings.worker_service_recovery_limit, + ) + if not running_jobs: + return None + + heartbeat_age_seconds = round( + max((utcnow() - previous_state.heartbeat_at).total_seconds(), 0.0), + 2, + ) + replaced_lock = lock.replaced_metadata is not None + previous_status = previous_state.status + stale_previous_worker = ( + previous_status in {WorkerStatus.RUNNING, WorkerStatus.ERROR} + and heartbeat_age_seconds >= self.settings.worker_service_recovery_stale_after_seconds + ) + if not replaced_lock and not stale_previous_worker: + return None + + reason = "startup-lock-replaced" if replaced_lock else "startup-stale-worker" + recovered = app.orchestration.recover_running_jobs( + limit=self.settings.worker_service_recovery_limit, + message="Recovered an interrupted running job during worker service startup.", + metadata={ + "recovery_reason": reason, + "previous_worker_status": previous_status, + "previous_worker_heartbeat_at": previous_state.heartbeat_at.isoformat(), + "previous_worker_current_job_id": previous_state.current_job_id, + "replaced_lock": replaced_lock, + "replaced_lock_pid": ( + lock.replaced_metadata.get("pid") if lock.replaced_metadata else None + ), + }, + ) + if not recovered: + return None + + summary = { + "reason": reason, + "count": len(recovered), + "job_ids": [job.id for job in recovered], + "at": utcnow().isoformat(), + "previous_worker_status": previous_status, + "previous_worker_heartbeat_at": 
previous_state.heartbeat_at.isoformat(), + "previous_worker_heartbeat_age_seconds": heartbeat_age_seconds, + "replaced_lock": replaced_lock, + "replaced_lock_pid": ( + lock.replaced_metadata.get("pid") if lock.replaced_metadata else None + ), + } + self.logger.warning( + "Worker service startup recovered %s interrupted running job(s) via %s.", + len(recovered), + reason, + ) + return summary + + def _recover_cycle_stale_jobs( + self, + app: LAIApplication, + *, + cycle: int, + ) -> dict[str, object] | None: + recovered = app.orchestration.recover_stale_jobs( + stale_after_seconds=self.settings.worker_service_recovery_stale_after_seconds, + limit=self.settings.worker_service_recovery_limit, + ) + if not recovered: + return None + + summary = { + "reason": "cycle-stale-recovery", + "count": len(recovered), + "job_ids": [job.id for job in recovered], + "at": utcnow().isoformat(), + "cycle": cycle, + "stale_after_seconds": self.settings.worker_service_recovery_stale_after_seconds, + } + self.logger.warning( + "Worker service cycle %s recovered %s stale running job(s).", + cycle, + len(recovered), + ) + return summary + def _mark_service_state( self, *, @@ -199,6 +333,21 @@ def _mark_service_state( ) -> None: app = self.application_factory() current_state = app.orchestration.get_worker_state(self.config.worker_id) + merged_metadata = dict(current_state.metadata) + merged_metadata.update(metadata or {}) + merged_metadata["phase"] = phase + if status != WorkerStatus.RUNNING: + for key in ( + "current_stage", + "current_stage_phase", + "current_stage_progress", + "current_stage_message", + "current_stage_updated_at", + "current_model_id", + "current_provider_id", + "elapsed_seconds", + ): + merged_metadata[key] = None app.orchestration.update_worker_state( worker_id=self.config.worker_id, status=status, @@ -208,7 +357,7 @@ def _mark_service_state( queued_jobs=app.job_store.count_jobs_by_status(JobStatus.QUEUED), started_at=current_state.started_at or current_state.heartbeat_at, last_error=last_error, - metadata={"phase": phase, **(metadata or {})}, + metadata=merged_metadata, ) diff --git a/tests/unit/test_worker_service.py b/tests/unit/test_worker_service.py index 1852e44..7e8ab68 100644 --- a/tests/unit/test_worker_service.py +++ b/tests/unit/test_worker_service.py @@ -1,5 +1,9 @@ +import json +import time +from datetime import timedelta + from lai.application import create_application -from lai.domain import ExecutionRequest, QueueMode +from lai.domain import ExecutionRequest, QueueMode, WorkerStatus, utcnow from lai.settings import Settings from lai.system import collect_system_snapshot from lai.worker.service import ( @@ -10,19 +14,21 @@ from tests.helpers import FakeProvider -def _test_settings(repo_root, tmp_path) -> Settings: - return Settings( - root_dir=tmp_path, - model_catalog=repo_root / "configs/models/catalog.yaml", - routing_policy=repo_root / "configs/routing/policies.yaml", - database_path=tmp_path / "data/state/lai.db", - state_dir=tmp_path / "data/state", - artifacts_dir=tmp_path / "data/artifacts", - logs_dir=tmp_path / "logs", - huggingface_cache_dir=tmp_path / "data/cache/huggingface", - airllm_shards_dir=tmp_path / "data/models/airllm-shards", - raw_models_dir=tmp_path / "data/models/raw", - ) +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + 
"state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) def _test_providers(settings: Settings) -> dict[str, FakeProvider]: @@ -92,3 +98,112 @@ def _sleep_and_request_stop(_seconds: float) -> None: assert worker.status == "stopped" assert worker.mode == "daemon" assert not settings.resolved_worker_service_stop_path.exists() + + +def test_worker_service_recovers_interrupted_running_job_after_lock_replace( + repo_root, + tmp_path, +) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Resume this interrupted overnight daemon execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = app.job_store.claim_next_queued_job() + assert running is not None + app.orchestration.update_worker_state( + worker_id="local-worker", + status=WorkerStatus.RUNNING, + mode="daemon", + current_job_id=job.id, + processed_jobs=0, + queued_jobs=0, + metadata={"phase": "processing"}, + ) + settings.resolved_worker_service_lock_path.write_text( + json.dumps({"pid": 4242, "created_at": time.time() - 300}), + encoding="utf-8", + ) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig( + max_cycles=1, + poll_interval_seconds=0.01, + replace_existing_lock=True, + ), + sleep_fn=lambda _seconds: None, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + persisted = refreshed.job_store.get_job(job.id) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert persisted is not None + assert persisted.status == "succeeded" + assert any( + event.stage == "job" and event.event_type == "recovered" + for event in persisted.stage_events + ) + last_recovery = worker.metadata.get("last_recovery") + assert isinstance(last_recovery, dict) + assert last_recovery["count"] == 1 + assert last_recovery["reason"] == "startup-lock-replaced" + assert last_recovery["replaced_lock"] is True + assert last_recovery["replaced_lock_pid"] == 4242 + + +def test_worker_service_recovers_stale_running_jobs_during_cycle(repo_root, tmp_path) -> None: + settings = _test_settings( + repo_root, + tmp_path, + worker_service_recovery_stale_after_seconds=60, + ) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Recover this stale interrupted run during daemon sweep.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = app.job_store.claim_next_queued_job() + assert running is not None + stale_at = utcnow() - timedelta(minutes=10) + running.updated_at = stale_at + running.started_at = stale_at + app.job_store.save_job(running) + app.orchestration.update_worker_state( + worker_id="local-worker", + status=WorkerStatus.IDLE, + mode="daemon", + current_job_id=None, + processed_jobs=0, + queued_jobs=0, + metadata={"phase": "sleeping"}, + ) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig(max_cycles=1, poll_interval_seconds=0.01), + sleep_fn=lambda _seconds: None, + ) + 
processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + persisted = refreshed.job_store.get_job(job.id) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert persisted is not None + assert persisted.status == "succeeded" + last_recovery = worker.metadata.get("last_recovery") + assert isinstance(last_recovery, dict) + assert last_recovery["count"] == 1 + assert last_recovery["reason"] == "cycle-stale-recovery" From 2b0ddbd91f36c3c89731ecd26dd504dc1d6cb4f4 Mon Sep 17 00:00:00 2001 From: N3uralCreativity Date: Tue, 7 Apr 2026 21:45:01 +0200 Subject: [PATCH 16/16] feat: refactor TransformersProvider to streamline token handling in pipeline --- src/lai/providers/implementations.py | 9 ++-- tests/unit/test_providers.py | 68 ++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_providers.py diff --git a/src/lai/providers/implementations.py b/src/lai/providers/implementations.py index 354d0c3..831687b 100644 --- a/src/lai/providers/implementations.py +++ b/src/lai/providers/implementations.py @@ -60,16 +60,17 @@ def generate( progress=0.1, cached=False, ) + model_kwargs: dict[str, Any] = { + "torch_dtype": "auto", + } pipe = pipeline( "text-generation", model=model.model_ref, tokenizer=model.model_ref, device_map="auto" if self.system_snapshot.has_gpu else None, trust_remote_code=True, - model_kwargs={ - "token": self.settings.huggingface_token_value, - "torch_dtype": "auto", - }, + token=self.settings.huggingface_token_value, + model_kwargs=model_kwargs, ) self._pipeline_cache[model.id] = pipe else: diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py new file mode 100644 index 0000000..e06a84f --- /dev/null +++ b/tests/unit/test_providers.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import sys +import types + +from lai.domain import ModelRole, ModelRuntime, ModelSpec, ProviderRequest +from lai.providers.implementations import TransformersProvider +from lai.settings import Settings +from lai.system import collect_system_snapshot + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def test_transformers_provider_passes_token_once_to_pipeline( + repo_root, + tmp_path, + monkeypatch, +) -> None: + settings = _test_settings(repo_root, tmp_path, hf_token="fake-token") + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + provider = TransformersProvider(settings, snapshot) + model = ModelSpec( + id="router-small", + role=ModelRole.ROUTER, + runtime=ModelRuntime.TRANSFORMERS, + model_ref="Qwen/Qwen2.5-3B-Instruct", + capabilities=["summarization"], + ) + + captured: dict[str, object] = {} + + def fake_pipeline(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + def _pipe(*_pipe_args, **_pipe_kwargs): + return [{"generated_text": "hello"}] + + return _pipe + + fake_transformers = types.ModuleType("transformers") 
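+    # Register the stub in sys.modules so the provider's transformers import
+    # resolves to fake_pipeline and the test never loads a real model.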
+ fake_transformers.pipeline = fake_pipeline + monkeypatch.setitem(sys.modules, "transformers", fake_transformers) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Say hello."), + ) + + kwargs = captured["kwargs"] + assert kwargs["token"] == "fake-token" + assert "token" not in kwargs["model_kwargs"] + assert result.text == "hello"
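
Note on the token handling in this patch: recent transformers releases accept
`token` directly on `pipeline()` and forward it to the underlying
`from_pretrained()` calls themselves, so keeping a second copy inside
`model_kwargs` risks a duplicate-keyword `TypeError`. A minimal sketch of the
intended call shape, assuming that forwarding behavior (the token value below
is a placeholder; the model id mirrors the one used in the test above):

    from transformers import pipeline

    # Auth is supplied once, at the pipeline level; model_kwargs carries only
    # from_pretrained-specific options such as dtype selection.
    pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen2.5-3B-Instruct",
        token="hf_xxx",
        model_kwargs={"torch_dtype": "auto"},
    )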