diff --git a/.env.example b/.env.example index 8986f00..0415ee7 100644 --- a/.env.example +++ b/.env.example @@ -1,12 +1,24 @@ LAI_ENV=local LAI_HF_TOKEN= +LAI_OPENAI_API_KEY= +LAI_ANTHROPIC_API_KEY= +LAI_GEMINI_API_KEY= LAI_MODEL_CATALOG=configs/models/catalog.yaml LAI_ROUTING_POLICY=configs/routing/policies.yaml LAI_PROMPT_ROOT=configs/prompts LAI_HUGGINGFACE_CACHE_DIR=data/cache/huggingface LAI_AIRLLM_SHARDS_DIR=data/models/airllm-shards +LAI_RAW_MODELS_DIR=data/models/raw LAI_ARTIFACTS_DIR=data/artifacts +LAI_STATE_DIR=data/state +LAI_DATABASE_PATH=data/state/lai.db LAI_LOGS_DIR=logs LAI_ALLOW_OVERNIGHT_JOBS=true LAI_ENABLE_GPU=true LAI_MAX_ROUTER_TOKENS=2048 +LAI_DEFAULT_TIMEOUT_SECONDS=120 +LAI_DEFAULT_MAX_OUTPUT_TOKENS=1024 +LAI_DEFAULT_TEMPERATURE=0.2 +LAI_QUEUE_POLL_INTERVAL_SECONDS=5 +LAI_WORKER_IDLE_SLEEP_SECONDS=2.0 +LAI_MAX_RETRY_ATTEMPTS=1 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0bfdf76..b68fa21 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: - name: Install package run: | python -m pip install --upgrade pip - python -m pip install -e .[dev] + python -m pip install -e .[dev,api] - name: Ruff run: ruff check . diff --git a/.gitignore b/.gitignore index 61aa09e..2e9020a 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,9 @@ build/ data/cache/ data/models/ data/artifacts/ +data/state/* +!data/state/.gitignore +evals/results/* +!evals/results/.gitignore logs/* !logs/.gitignore diff --git a/README.md b/README.md index 9cf8da1..30c57a1 100644 --- a/README.md +++ b/README.md @@ -78,18 +78,145 @@ cd LAI py -3.11 -m venv .venv .venv\Scripts\Activate.ps1 python -m pip install --upgrade pip -python -m pip install -e .[dev] +python -m pip install -e .[dev,api] Copy-Item .env.example .env python -m lai.cli doctor ``` -To add the large-model runtime later: +To add the local heavy-model and provider backends later: ```powershell -python -m pip install -e .[dev,api] -python -m pip install airllm +python -m pip install -e .[local,providers] +``` + +## Current commands + +```powershell +python -m lai.cli doctor +python -m lai.cli models list +python -m lai.cli models check +python -m lai.cli workstation validate +python -m lai.cli workstation validate --profile airllm +python -m lai.cli route explain "Summarize this note." +python -m lai.cli run "Create a detailed implementation strategy." 
+python -m lai.cli jobs list +python -m lai.cli jobs stale +python -m lai.cli jobs recover +python -m lai.cli jobs replay --queue-mode queued +python -m lai.cli worker run --once +python -m lai.cli worker run --max-jobs 3 +python -m lai.cli worker run --until-idle --max-idle-cycles 1 +python -m lai.cli worker status +python -m lai.cli worker serve --poll-interval 10 +python -m lai.cli worker stop +python -m lai.cli smoke providers +python -m lai.cli smoke providers --live +python -m lai.cli smoke local --live --include-airllm +python -m lai.cli smoke latest +lai-worker run --poll-interval 10 +lai-worker stop +python -m lai.cli eval route --no-save +``` + +## Current API + +After installing the `api` extra: + +```powershell +uvicorn lai.api.app:create_api --factory --reload +``` + +Available endpoints: + +- `GET /` +- `GET /dashboard` +- `GET /health` +- `GET /models` +- `GET /workstation/readiness` +- `GET /smoke/latest` +- `GET /smoke/results` +- `GET /smoke/results/{result_id}` +- `POST /smoke/run` +- `GET /worker/status` +- `POST /route/explain` +- `POST /jobs` +- `GET /jobs` +- `GET /jobs/stale` +- `GET /jobs/{job_id}` +- `GET /jobs/{job_id}/timeline` +- `GET /jobs/{job_id}/execution` +- `POST /jobs/recover` +- `POST /jobs/{job_id}/replay` +- `GET /jobs/{job_id}/artifacts` +- `GET /jobs/{job_id}/artifacts/{artifact_id}` +- `POST /jobs/{job_id}/cancel` +- `POST /worker/run` + +The API now also serves a read-mostly dashboard at `/dashboard` with live model health, +workstation readiness for heavy local execution, route explanation, recent job inspection, +stage telemetry, provider execution summaries, artifact/trace browsing, replay actions, +persisted worker monitoring, saved smoke diagnostics with history drill-down, +per-check metadata/output previews, stale-job detection/recovery for interrupted runs, +bounded queue worker controls including an until-idle drain path, and live provider +progress phases for long-running execution stages. + +## Dedicated worker service + +For always-on local queue processing, run the dedicated worker service instead of manually +triggering bounded worker batches: + +```powershell +lai-worker run --poll-interval 10 +``` + +The service: + +- acquires a lock at `data/state/worker-service.lock` +- watches for a graceful stop signal at `data/state/worker-service.stop` +- writes service logs to `logs/worker-service.log` +- recovers interrupted running jobs on startup when a stale lock is explicitly replaced +- auto-recovers stale running jobs during daemon cycles using the configured recovery timeout +- keeps the persisted worker state fresh for the CLI, API, and dashboard + +To stop it gracefully: + +```powershell +lai-worker stop +``` + +## Provider smoke diagnostics + +Use readiness-only smoke checks to verify credentials, optional dependencies, and provider health +without making live requests: + +```powershell +python -m lai.cli smoke providers +python -m lai.cli smoke local ``` +Use live mode only when you want a real tiny prompt executed: + +```powershell +python -m lai.cli smoke providers --live +python -m lai.cli smoke local --live --include-airllm +``` + +Smoke results are saved under `evals/results/smoke/` by default. 
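+
+To post-process a saved report outside the CLI or dashboard, a minimal sketch like the one below
+is enough. It assumes each saved result is a JSON file whose top level carries the `passed`,
+`failed_count`, and `checks` fields the dashboard reads; treat the exact schema as owned by the
+smoke module.
+
+```python
+import json
+from pathlib import Path
+
+results_dir = Path("evals/results/smoke")
+# Pick the newest saved smoke report, if any exist yet.
+latest = max(results_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None)
+if latest is None:
+    print("no saved smoke results")
+else:
+    report = json.loads(latest.read_text(encoding="utf-8"))
+    status = "passed" if report.get("passed") else "needs attention"
+    print(f"{latest.name}: {status}, {report.get('failed_count', 0)} failed check(s)")
+```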
+ +## Workstation readiness + +Validate the local machine before you depend on the heavy local path: + +```powershell +python -m lai.cli workstation validate +python -m lai.cli workstation validate --profile airllm +``` + +This surfaces Python compatibility, local package readiness, Hugging Face credentials, +disk headroom, RAM posture, GPU availability, and AirLLM-specific readiness in one report. +For step-by-step remediation and overnight run guidance, see +`docs/setup/airllm-runbook.md`. + ## Initial GitHub rules encoded in this repo - Pull request template and issue forms for consistent planning. @@ -100,10 +227,10 @@ python -m pip install airllm ## Near-term priorities -1. Implement the model registry and routing engine under `src/lai/`. -2. Add the first AirLLM runtime adapter and smoke-test workflows. -3. Introduce an API surface in `apps/api`. -4. Add evaluation scenarios that compare small-model routing against large-model final execution. +1. Expand eval scenarios and richer reviewer/final-output refinement. +2. Add operator-facing controls for smoke result retention and cleanup over time. +3. Add worker-service lock inspection and supervised restart tooling. +4. Add partial-output capture where provider SDKs expose safe non-streaming response chunks. ## References diff --git a/apps/api/README.md b/apps/api/README.md index a58fdc9..fa2d909 100644 --- a/apps/api/README.md +++ b/apps/api/README.md @@ -1,3 +1,20 @@ # API App -This folder is reserved for the future control-plane API. It will expose request submission, job status, artifact retrieval, and model availability endpoints. +The first API surface lives in `src/lai/api/app.py` and mirrors the CLI-first orchestration core. + +Current endpoints: + +- `GET /health` +- `GET /models` +- `POST /route/explain` +- `POST /jobs` +- `GET /jobs` +- `GET /jobs/{job_id}` +- `POST /jobs/{job_id}/cancel` + +Run it after installing the `api` extra: + +```powershell +python -m pip install -e .[dev,api] +uvicorn lai.api.app:create_api --factory --reload +``` diff --git a/apps/web/README.md b/apps/web/README.md index 3a3f94f..7f6092e 100644 --- a/apps/web/README.md +++ b/apps/web/README.md @@ -1,3 +1,25 @@ # Web App -This folder is reserved for the future dashboard that will surface routing decisions, queue state, artifacts, and model health. +The first dashboard is now served directly by the FastAPI app at `/dashboard`. 
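+
+A quick way to confirm the backing API is up before opening the dashboard is a small health probe,
+sketched here with `httpx` (any HTTP client works; `/health` is the same endpoint the dashboard
+reads on load, and the port assumes uvicorn's default):
+
+```python
+import httpx
+
+# Assumes the API was started with: uvicorn lai.api.app:create_api --factory --reload
+response = httpx.get("http://127.0.0.1:8000/health", timeout=10)
+response.raise_for_status()
+payload = response.json()
+print(payload["status"], "-", payload["model_count"], "models in the catalog")
+```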
+ +Current dashboard capabilities: + +- live health and catalog summary +- workstation readiness summary for heavy local and AirLLM execution paths +- route explanation form +- job submission and recent queue inspection +- persisted stage telemetry timeline for planner, executor, and reviewer flow +- provider execution summaries with duration, live progress phases, output preview, and artifact linkage per stage +- artifact and trace browsing for persisted jobs +- job replay controls for inline and queued reruns +- stale-job detection and one-click queue recovery for interrupted long-running jobs +- persisted worker monitoring with heartbeat, current job, queue depth, and latest daemon recovery summary +- service-aware worker monitoring with daemon lock and stop-signal visibility +- saved smoke diagnostics for provider and local readiness, with recent-run history browsing +- per-check smoke metadata and live-output preview inspection from saved results +- bounded live worker controls for processing queued jobs, including queue drain until idle +- model health cards +- job output inspector + +The implementation intentionally stays lightweight for now by using static assets served from +the API package instead of a separate frontend build pipeline. diff --git a/apps/worker/README.md b/apps/worker/README.md index 1c41b7f..8a4edc2 100644 --- a/apps/worker/README.md +++ b/apps/worker/README.md @@ -1,3 +1,22 @@ # Worker App -This folder is reserved for long-running local or remote workers that execute model jobs, including AirLLM-backed heavy inference tasks. +This area now maps to the dedicated local worker service surface used for always-on queue +processing. + +Primary entrypoints: + +- `lai-worker run --poll-interval 10` +- `lai-worker stop` +- `python -m lai.cli worker serve --poll-interval 10` +- `python -m lai.cli worker stop` + +Service runtime contract: + +- lock file: `data/state/worker-service.lock` +- stop signal: `data/state/worker-service.stop` +- service log: `logs/worker-service.log` +- persisted worker state: `data/state/lai.db` + +The worker service repeatedly drains queued jobs until idle, sleeps for the configured poll +interval, then checks again. This keeps the platform ready for overnight work without requiring +manual `worker run` commands. 
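+
+The loop is easiest to picture as a sketch (illustrative only, not the service's actual
+implementation; `drain_queue_until_idle` is a placeholder for the real queue-processing call):
+
+```python
+import time
+from pathlib import Path
+
+STOP_SIGNAL = Path("data/state/worker-service.stop")
+POLL_INTERVAL_SECONDS = 10
+
+
+def drain_queue_until_idle() -> int:
+    """Placeholder: process queued jobs until none remain and return how many were handled."""
+    return 0
+
+
+while not STOP_SIGNAL.exists():
+    processed = drain_queue_until_idle()  # drain queued jobs until the queue is empty
+    print(f"processed {processed} job(s); sleeping {POLL_INTERVAL_SECONDS}s")
+    time.sleep(POLL_INTERVAL_SECONDS)     # then sleep for the configured poll interval
+print("stop signal seen; shutting down gracefully")
+```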
diff --git a/configs/models/catalog.yaml b/configs/models/catalog.yaml index 0eaa696..7c1b248 100644 --- a/configs/models/catalog.yaml +++ b/configs/models/catalog.yaml @@ -52,3 +52,42 @@ models: allow_layer_sharding: true allow_prefetching: true allow_cpu_fallback: true + + - id: openai-general + role: executor + runtime: openai + model: gpt-5.4-mini + context_window: 128000 + capabilities: + - planning + - summarization + - critique + - validation + - deep-reasoning + - long-form-generation + + - id: anthropic-general + role: executor + runtime: anthropic + model: claude-sonnet-4-20250514 + context_window: 200000 + capabilities: + - planning + - summarization + - critique + - validation + - deep-reasoning + - long-form-generation + + - id: gemini-general + role: executor + runtime: gemini + model: gemini-2.5-flash + context_window: 1000000 + capabilities: + - planning + - summarization + - critique + - validation + - deep-reasoning + - long-form-generation diff --git a/configs/routing/policies.yaml b/configs/routing/policies.yaml index 73f6749..be9d2a8 100644 --- a/configs/routing/policies.yaml +++ b/configs/routing/policies.yaml @@ -42,5 +42,5 @@ tiers: fallbacks: when_gpu_unavailable: planner_model_id: router-small - executor_model_id: verifier-medium - reviewer_model_id: verifier-medium + executor_model_id: openai-general + reviewer_model_id: openai-general diff --git a/data/state/.gitignore b/data/state/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/data/state/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/docs/setup/airllm-runbook.md b/docs/setup/airllm-runbook.md new file mode 100644 index 0000000..00fbc13 --- /dev/null +++ b/docs/setup/airllm-runbook.md @@ -0,0 +1,99 @@ +# AirLLM Heavy-Local Runbook + +## Purpose + +Use this runbook before starting the `execution-large` AirLLM path or any similar heavy +local executor that may run overnight. + +## Pre-flight checklist + +Run the workstation validator first: + +```powershell +python -m lai.cli workstation validate --profile airllm +``` + +Look for failures in these areas: + +- `AirLLM dependency stack` +- `Hugging Face credentials` +- `GPU availability` +- `GPU VRAM budget` +- `AirLLM shard storage` +- `Hugging Face cache staging space` + +If the validator reports only warnings, the heavy path is still usable, but you should expect +slower runtime or more disk pressure. + +## First-time setup + +1. Install the local runtime packages: + +```powershell +python -m pip install -e .[local] +``` + +2. Add the Hugging Face token to `.env`: + +```text +LAI_HF_TOKEN=hf_... +``` + +3. 
Make sure these directories live on a large drive if possible: + +- `data/cache/huggingface` +- `data/models/airllm-shards` + +## Common failure modes + +### Missing `airllm`, `torch`, `accelerate`, or `transformers` + +- Install the local extra again with `python -m pip install -e .[local]` +- Re-run `python -m lai.cli workstation validate --profile airllm` + +### Missing Hugging Face token + +- Set `LAI_HF_TOKEN` in `.env` +- Restart the shell if needed +- Re-run the workstation validator + +### Not enough disk + +- Move `LAI_HUGGINGFACE_CACHE_DIR` or `LAI_AIRLLM_SHARDS_DIR` to a larger drive +- Clear older downloaded models or old shard outputs you no longer need +- Re-run the validator and confirm both cache and shard paths meet the reported target + +### No GPU detected + +- Confirm CUDA drivers and `nvidia-smi` work +- If CPU fallback is allowed, accept that the run may take many hours +- If turnaround matters, route the executor to a remote provider instead + +### VRAM below recommendation + +- AirLLM can still work with aggressive layer sharding on smaller GPUs, but expect slower + movement between disk, RAM, and VRAM +- Keep other GPU-heavy programs closed + +## Overnight run tips + +- Start the dedicated worker service instead of manually stepping jobs: + +```powershell +lai-worker run --poll-interval 10 +``` + +- Check the dashboard at `/dashboard` for: + - workstation readiness + - worker status + - queue depth + - job telemetry and artifacts + +- Keep the machine awake and disable sleep for long local runs. + +## After a failure + +1. Inspect the job timeline and artifacts in the dashboard or `lai jobs show `. +2. Re-run `python -m lai.cli workstation validate --profile airllm`. +3. If the machine is still blocked, replay the job against a remote executor until the local + workstation is ready again. diff --git a/docs/setup/workstation.md b/docs/setup/workstation.md index 5a507f5..2ef9156 100644 --- a/docs/setup/workstation.md +++ b/docs/setup/workstation.md @@ -29,3 +29,64 @@ Add AirLLM only when you are ready to test the large-model path: ```powershell python -m pip install airllm ``` + +## Readiness validation + +Before you trust the heavy local executor, validate the workstation directly: + +```powershell +python -m lai.cli workstation validate +python -m lai.cli workstation validate --profile airllm +``` + +This report checks: + +- Python compatibility for the local tooling stack +- local package availability for `torch`, `transformers`, `accelerate`, `huggingface-hub`, and `airllm` +- Hugging Face token presence for gated local models +- free disk in the Hugging Face cache and AirLLM shard directories +- RAM, GPU presence, and measured VRAM when available + +For a deeper remediation checklist and overnight-run guidance, use the AirLLM runbook in +`docs/setup/airllm-runbook.md`. 
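+
+To spot-check a few of the same signals by hand, a short script like this works (illustrative
+only; the validator's own checks and thresholds live in the `lai` package):
+
+```python
+import os
+import shutil
+from pathlib import Path
+
+# Hugging Face token presence, as seen by the current environment.
+print("HF token set:", bool(os.environ.get("LAI_HF_TOKEN")))
+
+# Free disk where the Hugging Face cache and AirLLM shards are staged.
+for label, path in [("hf cache", "data/cache/huggingface"), ("airllm shards", "data/models/airllm-shards")]:
+    if Path(path).exists():
+        free_gb = shutil.disk_usage(path).free / 1e9
+        print(f"{label}: {free_gb:.1f} GB free")
+    else:
+        print(f"{label}: {path} missing")
+
+# GPU presence and VRAM, if torch is installed (part of the `local` extra).
+try:
+    import torch
+
+    if torch.cuda.is_available():
+        props = torch.cuda.get_device_properties(0)
+        print(f"gpu: {props.name}, {props.total_memory / 1e9:.1f} GB VRAM")
+    else:
+        print("gpu: none detected")
+except ImportError:
+    print("torch not installed")
+```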
+ +## Always-on local worker + +Once the environment is ready, you can keep queued jobs processing in the background with the +dedicated worker service: + +```powershell +lai-worker run --poll-interval 10 +``` + +Useful local paths: + +- lock file: `data/state/worker-service.lock` +- stop signal: `data/state/worker-service.stop` +- log file: `logs/worker-service.log` + +To stop the worker cleanly from another terminal: + +```powershell +lai-worker stop +``` + +## Provider smoke checks + +Readiness-only smoke checks are safe to run on a fresh workstation because they verify +credentials, packages, and health without sending live prompts: + +```powershell +python -m lai.cli smoke providers +python -m lai.cli smoke local +``` + +Live smoke mode sends a tiny prompt and should be used only when you explicitly want a real +provider call: + +```powershell +python -m lai.cli smoke providers --live +python -m lai.cli smoke local --live --include-airllm +``` + +Saved smoke reports land under `evals/results/smoke/`. diff --git a/evals/results/.gitignore b/evals/results/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/evals/results/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/evals/scenarios/routing-basics.yaml b/evals/scenarios/routing-basics.yaml new file mode 100644 index 0000000..56e86ad --- /dev/null +++ b/evals/scenarios/routing-basics.yaml @@ -0,0 +1,22 @@ +version: 1 +scenarios: + - id: instant-summary + description: Short prompt should stay on a small or fast tier. + prompt: Summarize this short note in two bullets. + expected_tier: instant + + - id: deep-work-architecture + description: Comprehensive architecture requests should route to deep-work. + prompt: Create a comprehensive advanced architecture and complete implementation strategy. + expected_tier: deep-work + + - id: provider-fallback + description: If the preferred local heavy model is unavailable, routing should fall back. + prompt: Produce a detailed long-form research answer with strong quality. 
+ expected_tier: deep-work + unavailable_providers: + - airllm + expected_executor_in: + - openai-general + - anthropic-general + - gemini-general diff --git a/pyproject.toml b/pyproject.toml index 28044d6..8a4ff25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,18 @@ api = [ "fastapi>=0.116,<1.0", "uvicorn[standard]>=0.35,<1.0", ] +local = [ + "accelerate>=1.9,<2.0", + "airllm>=2.0,<3.0", + "huggingface-hub>=0.34,<1.0", + "torch>=2.7,<3.0", + "transformers>=4.54,<5.0", +] +providers = [ + "anthropic>=0.60,<1.0", + "google-genai>=1.30,<2.0", + "openai>=2.0,<3.0", +] dev = [ "httpx>=0.28,<1.0", "pytest>=8.3,<9.0", @@ -33,6 +45,12 @@ dev = [ [project.scripts] lai = "lai.cli:app" +lai-worker = "lai.worker.cli:app" [tool.hatch.build.targets.wheel] packages = ["src/lai"] + +[tool.hatch.build] +include = [ + "src/lai/api/static/**", +] diff --git a/scripts/bootstrap/setup-local.ps1 b/scripts/bootstrap/setup-local.ps1 index 4103920..cfdafcc 100644 --- a/scripts/bootstrap/setup-local.ps1 +++ b/scripts/bootstrap/setup-local.ps1 @@ -5,6 +5,7 @@ $directories = @( "data/models/airllm-shards", "data/models/raw", "data/artifacts", + "data/state", "logs" ) diff --git a/src/lai/__init__.py b/src/lai/__init__.py index 0f0a87b..c27714a 100644 --- a/src/lai/__init__.py +++ b/src/lai/__init__.py @@ -1,5 +1,16 @@ """LAI core package.""" +from .config import AppConfig, load_app_config +from .domain import ExecutionRequest, JobStatus, QueueMode, RoutingDecision, WorkerStatus from .settings import Settings -__all__ = ["Settings"] +__all__ = [ + "AppConfig", + "ExecutionRequest", + "JobStatus", + "QueueMode", + "RoutingDecision", + "Settings", + "WorkerStatus", + "load_app_config", +] diff --git a/src/lai/api/__init__.py b/src/lai/api/__init__.py new file mode 100644 index 0000000..d0c81a0 --- /dev/null +++ b/src/lai/api/__init__.py @@ -0,0 +1,3 @@ +from .app import create_api + +__all__ = ["create_api"] diff --git a/src/lai/api/app.py b/src/lai/api/app.py new file mode 100644 index 0000000..fb18879 --- /dev/null +++ b/src/lai/api/app.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Literal, Mapping + +from fastapi import FastAPI, HTTPException +from fastapi.responses import FileResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel, Field + +from ..application import create_application +from ..domain import ExecutionRequest, QueueMode +from ..observability import summarize_job_execution +from ..providers import Provider +from ..settings import Settings +from ..smoke import ( + latest_smoke_result, + list_smoke_results, + load_smoke_result, + resolve_smoke_result_path, + run_smoke_suite, + save_smoke_result, +) +from ..workstation import build_workstation_readiness + + +class JobCreatePayload(BaseModel): + user_prompt: str + system_prompt: str | None = None + queue_mode: QueueMode = QueueMode.AUTO + model_override: str | None = None + provider_override: str | None = None + temperature: float | None = Field(default=None, ge=0) + max_output_tokens: int | None = Field(default=None, ge=1) + timeout_seconds: int | None = Field(default=None, ge=1) + reviewer_enabled: bool | None = None + + +class JobReplayPayload(BaseModel): + queue_mode: QueueMode | None = None + + +class JobRecoverPayload(BaseModel): + stale_after_seconds: int | None = Field(default=None, ge=1, le=86400) + limit: int = Field(default=20, ge=1, le=200) + + +class WorkerRunPayload(BaseModel): + max_jobs: int | None = Field(default=None, 
ge=1, le=25) + resume_running: bool = False + until_idle: bool = False + max_idle_cycles: int = Field(default=2, ge=1, le=20) + + +class SmokeRunPayload(BaseModel): + suite_id: Literal["providers", "local"] + live: bool = False + include_airllm: bool = False + save: bool = True + + +def create_api( + settings: Settings | None = None, + providers: Mapping[str, Provider] | None = None, +) -> FastAPI: + settings = settings or Settings() + api = FastAPI(title="LAI API", version="0.1.0") + static_dir = Path(__file__).resolve().parent / "static" + api.mount( + "/dashboard/assets", + StaticFiles(directory=static_dir), + name="dashboard-assets", + ) + + def application(): + return create_application(settings=settings, providers=providers) + + @api.get("/", include_in_schema=False) + def root() -> RedirectResponse: + return RedirectResponse(url="/dashboard", status_code=307) + + @api.get("/dashboard", include_in_schema=False) + def dashboard() -> FileResponse: + return FileResponse(static_dir / "index.html") + + @api.get("/health") + def health() -> dict[str, object]: + app_state = application() + return { + "status": "ok", + "environment": app_state.settings.environment, + "database": str(app_state.settings.resolved_database_path), + "model_count": len(app_state.config.model_catalog.models), + } + + @api.get("/models") + def list_models() -> dict[str, object]: + app_state = application() + healthchecks = app_state.provider_registry.model_healthchecks(app_state.config) + return { + "models": [ + { + "id": model.id, + "role": model.role, + "runtime": model.runtime, + "model_ref": model.model_ref, + "capabilities": model.capabilities, + "health": healthchecks[model.id].model_dump(), + } + for model in app_state.config.model_catalog.models + ] + } + + @api.get("/workstation/readiness") + def workstation_readiness() -> dict[str, object]: + app_state = application() + report = build_workstation_readiness( + app_state.settings, + app_state.config, + snapshot=app_state.provider_registry.system_snapshot, + ) + return report.model_dump() + + @api.get("/worker/status") + def worker_status() -> dict[str, object]: + app_state = application() + worker = app_state.orchestration.get_worker_state() + payload = worker.model_dump() + payload["service_lock_present"] = ( + app_state.settings.resolved_worker_service_lock_path.exists() + ) + payload["stop_signal_present"] = ( + app_state.settings.resolved_worker_service_stop_path.exists() + ) + payload["log_file"] = str(app_state.settings.resolved_worker_service_log_path) + return payload + + @api.get("/smoke/latest") + def smoke_latest(suite_id: Literal["providers", "local"] | None = None) -> dict[str, object]: + app_state = application() + results_dir = app_state.settings.root_dir / "evals/results/smoke" + latest_path = latest_smoke_result(results_dir, suite_id=suite_id) + if latest_path is None: + raise HTTPException(status_code=404, detail="No smoke result found") + return _serialize_smoke_result(latest_path) + + @api.get("/smoke/results") + def smoke_results( + limit: int = 5, + suite_id: Literal["providers", "local"] | None = None, + ) -> dict[str, object]: + app_state = application() + results_dir = app_state.settings.root_dir / "evals/results/smoke" + paths = list_smoke_results(results_dir, suite_id=suite_id, limit=limit) + return {"results": [_serialize_smoke_result(path) for path in paths]} + + @api.get("/smoke/results/{result_id}") + def smoke_result_detail(result_id: str) -> dict[str, object]: + app_state = application() + results_dir = 
app_state.settings.root_dir / "evals/results/smoke" + result_path = resolve_smoke_result_path(results_dir, result_id) + if result_path is None: + raise HTTPException(status_code=404, detail="Smoke result not found") + return _serialize_smoke_result(result_path) + + @api.post("/smoke/run") + def smoke_run(payload: SmokeRunPayload) -> dict[str, object]: + app_state = application() + result = run_smoke_suite( + app_state.settings, + suite_id=payload.suite_id, + live=payload.live, + include_airllm=payload.include_airllm, + providers=providers, + ) + saved_path = None + if payload.save: + saved_path = save_smoke_result( + app_state.settings.root_dir / "evals/results/smoke", + result, + ) + return { + "id": saved_path.name if saved_path is not None else None, + "result": result.model_dump(), + "path": str(saved_path) if saved_path is not None else None, + } + + @api.post("/route/explain") + def explain_route(payload: JobCreatePayload) -> dict[str, object]: + app_state = application() + request = _build_execution_request(app_state, payload) + return app_state.routing_engine.route(request).model_dump() + + @api.post("/jobs") + def create_job(payload: JobCreatePayload) -> dict[str, object]: + app_state = application() + request = _build_execution_request(app_state, payload) + return app_state.orchestration.submit_request(request).model_dump() + + @api.post("/jobs/{job_id}/replay") + def replay_job(job_id: str, payload: JobReplayPayload) -> dict[str, object]: + app_state = application() + try: + return app_state.orchestration.replay_job( + job_id, + queue_mode=payload.queue_mode, + ).model_dump() + except KeyError as exc: + raise HTTPException(status_code=404, detail="Job not found") from exc + + @api.get("/jobs") + def list_jobs(limit: int = 20) -> dict[str, object]: + app_state = application() + return {"jobs": [job.model_dump() for job in app_state.job_store.list_jobs(limit=limit)]} + + @api.get("/jobs/stale") + def list_stale_jobs( + stale_after_seconds: int | None = None, + limit: int = 20, + ) -> dict[str, object]: + app_state = application() + stale_jobs = app_state.orchestration.list_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + return {"jobs": [job.model_dump() for job in stale_jobs]} + + @api.post("/jobs/recover") + def recover_stale_jobs(payload: JobRecoverPayload) -> dict[str, object]: + app_state = application() + recovered = app_state.orchestration.recover_stale_jobs( + stale_after_seconds=payload.stale_after_seconds, + limit=payload.limit, + ) + return {"recovered": [job.model_dump() for job in recovered]} + + @api.get("/jobs/{job_id}") + def get_job(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return job.model_dump() + + @api.get("/jobs/{job_id}/timeline") + def get_job_timeline(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return {"stage_events": [event.model_dump() for event in job.stage_events]} + + @api.get("/jobs/{job_id}/execution") + def get_job_execution(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return { + "stages": [summary.model_dump() for summary in summarize_job_execution(job)] + } + + 
@api.get("/jobs/{job_id}/artifacts") + def list_job_artifacts(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return {"artifacts": [artifact.model_dump() for artifact in job.artifacts]} + + @api.get("/jobs/{job_id}/artifacts/{artifact_id}") + def get_job_artifact(job_id: str, artifact_id: str) -> dict[str, object]: + app_state = application() + job = app_state.job_store.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + + artifact = next((item for item in job.artifacts if item.id == artifact_id), None) + if artifact is None: + raise HTTPException(status_code=404, detail="Artifact not found") + + try: + content = app_state.artifacts.read_text(artifact) + except FileNotFoundError as exc: + raise HTTPException(status_code=404, detail="Artifact file missing") from exc + except ValueError as exc: + raise HTTPException(status_code=400, detail="Artifact path is invalid") from exc + + return { + "artifact": artifact.model_dump(), + "content_type": app_state.artifacts.content_type(artifact), + "content": content, + } + + @api.post("/jobs/{job_id}/cancel") + def cancel_job(job_id: str) -> dict[str, object]: + app_state = application() + job = app_state.orchestration.cancel_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found or not cancelable") + return {"job": job.model_dump()} + + @api.post("/worker/run") + def run_worker(payload: WorkerRunPayload) -> dict[str, object]: + app_state = application() + if payload.until_idle: + processed = app_state.orchestration.run_worker_until_idle( + max_idle_cycles=payload.max_idle_cycles, + max_jobs=payload.max_jobs, + requeue_running=payload.resume_running, + ) + worker_state = app_state.orchestration.get_worker_state() + return {"processed": processed, "jobs": [], "worker": worker_state.model_dump()} + + jobs = app_state.orchestration.run_worker_batch( + max_jobs=payload.max_jobs or 1, + requeue_running=payload.resume_running, + ) + worker_state = app_state.orchestration.get_worker_state() + return { + "processed": len(jobs), + "jobs": [job.model_dump() for job in jobs], + "worker": worker_state.model_dump(), + } + + return api + + +def _build_execution_request(application, payload: JobCreatePayload) -> ExecutionRequest: + return ExecutionRequest( + system_prompt=payload.system_prompt, + user_prompt=payload.user_prompt, + queue_mode=payload.queue_mode, + model_override=payload.model_override, + provider_override=payload.provider_override, + reviewer_enabled=payload.reviewer_enabled, + temperature=( + payload.temperature + if payload.temperature is not None + else application.settings.default_temperature + ), + max_output_tokens=( + payload.max_output_tokens + if payload.max_output_tokens is not None + else application.settings.default_max_output_tokens + ), + timeout_seconds=( + payload.timeout_seconds + if payload.timeout_seconds is not None + else application.settings.default_timeout_seconds + ), + ) + + +def _serialize_smoke_result(path: Path) -> dict[str, object]: + return { + "id": path.name, + "path": str(path), + "result": load_smoke_result(path).model_dump(), + } diff --git a/src/lai/api/static/dashboard.css b/src/lai/api/static/dashboard.css new file mode 100644 index 0000000..e98119a --- /dev/null +++ b/src/lai/api/static/dashboard.css @@ -0,0 +1,707 @@ +:root { + --bg: #f7f1e5; + --panel: rgba(255, 250, 240, 0.9); + 
--panel-strong: rgba(255, 248, 232, 0.96); + --line: rgba(37, 56, 68, 0.12); + --ink: #1e2c35; + --muted: #5a6a74; + --accent: #0d6b61; + --accent-soft: #d6efe8; + --accent-warm: #c96d39; + --accent-deep: #213f64; + --danger: #9f3b2c; + --shadow: 0 20px 60px rgba(33, 63, 100, 0.12); +} + +* { + box-sizing: border-box; +} + +body { + margin: 0; + min-height: 100vh; + background: + radial-gradient(circle at top left, rgba(13, 107, 97, 0.18), transparent 28%), + radial-gradient(circle at 80% 18%, rgba(201, 109, 57, 0.16), transparent 22%), + linear-gradient(180deg, #f8f2e7 0%, #efe4d0 100%); + color: var(--ink); + font-family: "Space Grotesk", "Aptos", sans-serif; + overflow-x: hidden; +} + +.orb { + position: fixed; + border-radius: 999px; + filter: blur(18px); + pointer-events: none; + z-index: 0; +} + +.orb-a { + width: 18rem; + height: 18rem; + background: rgba(13, 107, 97, 0.12); + top: 3rem; + right: -4rem; +} + +.orb-b { + width: 14rem; + height: 14rem; + background: rgba(201, 109, 57, 0.16); + bottom: 8rem; + left: -3rem; +} + +.grid-lines { + position: fixed; + inset: 0; + background-image: + linear-gradient(rgba(33, 63, 100, 0.04) 1px, transparent 1px), + linear-gradient(90deg, rgba(33, 63, 100, 0.04) 1px, transparent 1px); + background-size: 2rem 2rem; + mask-image: linear-gradient(180deg, rgba(0, 0, 0, 0.45), transparent 85%); + pointer-events: none; + z-index: 0; +} + +.shell { + position: relative; + z-index: 1; + width: min(1360px, calc(100vw - 2rem)); + margin: 0 auto; + padding: 2rem 0 3rem; + display: grid; + gap: 1.25rem; +} + +.panel { + background: var(--panel); + border: 1px solid var(--line); + border-radius: 1.4rem; + box-shadow: var(--shadow); + backdrop-filter: blur(12px); +} + +.hero { + display: grid; + grid-template-columns: minmax(0, 1.5fr) minmax(18rem, 0.9fr); + gap: 1.25rem; + padding: 1.6rem; +} + +.hero-copy h1, +.section-head h2 { + margin: 0; + letter-spacing: -0.04em; +} + +.hero-copy h1 { + font-size: clamp(2.4rem, 5vw, 4.8rem); + line-height: 0.96; + max-width: 11ch; +} + +.lede { + max-width: 58ch; + margin: 1rem 0 0; + color: var(--muted); + font-size: 1rem; +} + +.eyebrow { + margin: 0 0 0.55rem; + color: var(--accent-deep); + font-family: "IBM Plex Mono", monospace; + font-size: 0.74rem; + letter-spacing: 0.16em; + text-transform: uppercase; +} + +.hero-actions, +.action-row { + display: flex; + flex-wrap: wrap; + gap: 0.75rem; + margin-top: 1.2rem; +} + +.button { + border: 0; + border-radius: 999px; + padding: 0.82rem 1.1rem; + font: inherit; + font-weight: 700; + cursor: pointer; + transition: transform 180ms ease, box-shadow 180ms ease, background 180ms ease; + text-decoration: none; +} + +.button:hover { + transform: translateY(-1px); +} + +.primary { + background: linear-gradient(135deg, var(--accent), #0f7e73); + color: white; + box-shadow: 0 16px 32px rgba(13, 107, 97, 0.18); +} + +.secondary { + background: linear-gradient(135deg, var(--accent-warm), #df8755); + color: white; +} + +.ghost { + background: rgba(255, 255, 255, 0.55); + color: var(--ink); + border: 1px solid rgba(33, 63, 100, 0.12); +} + +.small { + padding: 0.58rem 0.9rem; + font-size: 0.84rem; +} + +.hero-metrics { + display: grid; + gap: 0.85rem; +} + +.metric-card { + background: var(--panel-strong); + border: 1px solid rgba(33, 63, 100, 0.12); + border-radius: 1rem; + padding: 1rem; +} + +.metric-label { + display: block; + color: var(--muted); + font-family: "IBM Plex Mono", monospace; + font-size: 0.74rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + 
+.metric-card strong { + display: block; + margin-top: 0.45rem; + font-size: 1.4rem; +} + +.route-grid, +.dashboard-grid { + display: grid; + gap: 1.25rem; +} + +.route-grid { + grid-template-columns: minmax(0, 1.1fr) minmax(18rem, 0.9fr); +} + +.dashboard-grid { + grid-template-columns: minmax(0, 1fr) minmax(0, 1fr); +} + +.route-panel, +.route-output, +.jobs-panel, +.detail-panel, +.models-panel, +.workstation-panel, +.smoke-panel { + padding: 1.35rem; +} + +.section-head { + display: flex; + flex-direction: column; + gap: 0.2rem; + margin-bottom: 1rem; +} + +.split-head { + flex-direction: row; + justify-content: space-between; + align-items: start; + gap: 1rem; +} + +.mini-actions { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.route-form, +.models-grid, +.jobs-list, +.job-detail, +.reason-list, +.stat-block { + min-height: 8rem; +} + +.route-form label { + display: block; + margin-bottom: 0.9rem; +} + +.route-form span { + display: block; + margin-bottom: 0.35rem; + font-size: 0.92rem; + font-weight: 700; +} + +textarea, +input, +select { + width: 100%; + border: 1px solid rgba(33, 63, 100, 0.14); + background: rgba(255, 255, 255, 0.7); + color: var(--ink); + border-radius: 1rem; + padding: 0.9rem 1rem; + font: inherit; +} + +textarea { + resize: vertical; + min-height: 6rem; +} + +.form-row { + display: grid; + grid-template-columns: minmax(0, 0.45fr) minmax(0, 0.55fr); + gap: 0.8rem; +} + +.stat-block, +.job-detail, +.empty { + display: grid; + gap: 0.75rem; + align-content: start; + color: var(--muted); +} + +.route-chip-row, +.job-meta { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.inline-status { + margin-bottom: 0.95rem; + min-height: 2.6rem; + padding: 0.8rem 0.95rem; + border-radius: 1rem; + border: 1px dashed rgba(33, 63, 100, 0.14); + background: rgba(255, 255, 255, 0.5); +} + +.worker-summary { + display: grid; + gap: 0.7rem; + margin-bottom: 0.95rem; + padding: 0.9rem 1rem; + border-radius: 1rem; + border: 1px solid rgba(33, 63, 100, 0.12); + background: rgba(255, 255, 255, 0.58); +} + +.stale-job-list { + display: grid; + gap: 0.65rem; +} + +.stale-job-button { + border: 1px solid rgba(33, 63, 100, 0.12); + background: rgba(255, 255, 255, 0.72); + color: var(--ink); + border-radius: 1rem; + padding: 0.85rem 0.95rem; + text-align: left; + font: inherit; + cursor: pointer; + transition: transform 160ms ease, border-color 160ms ease, background 160ms ease; +} + +.stale-job-button:hover { + transform: translateY(-1px); + border-color: rgba(201, 109, 57, 0.28); + background: rgba(255, 255, 255, 0.9); +} + +.chip { + display: inline-flex; + align-items: center; + gap: 0.3rem; + padding: 0.45rem 0.7rem; + border-radius: 999px; + background: var(--accent-soft); + color: var(--accent); + font-family: "IBM Plex Mono", monospace; + font-size: 0.78rem; +} + +.chip.ready { + background: rgba(13, 107, 97, 0.18); + color: var(--accent); +} + +.chip.warn { + background: rgba(201, 109, 57, 0.14); + color: var(--accent-warm); +} + +.chip.danger { + background: rgba(159, 59, 44, 0.14); + color: var(--danger); +} + +.reason-item, +.job-card, +.model-card, +.timeline-event, +.smoke-card, +.smoke-check { + border: 1px solid rgba(33, 63, 100, 0.11); + background: rgba(255, 255, 255, 0.64); + border-radius: 1rem; +} + +.reason-item { + padding: 0.9rem 1rem; +} + +.reason-stage { + color: var(--accent-deep); + font-family: "IBM Plex Mono", monospace; + font-size: 0.76rem; + letter-spacing: 0.05em; + text-transform: uppercase; +} + +.reason-text { + margin-top: 0.35rem; 
+ color: var(--ink); +} + +.jobs-list { + display: grid; + gap: 0.75rem; +} + +.timeline-list { + display: grid; + gap: 0.75rem; +} + +.smoke-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); + gap: 0.85rem; +} + +.job-card { + padding: 0.95rem 1rem; + cursor: pointer; + transition: transform 160ms ease, border-color 160ms ease, background 160ms ease; +} + +.job-card:hover, +.job-card.active { + transform: translateY(-1px); + border-color: rgba(13, 107, 97, 0.3); + background: rgba(255, 255, 255, 0.88); +} + +.job-card-header, +.model-card-header { + display: flex; + justify-content: space-between; + gap: 1rem; + align-items: start; +} + +.job-id, +.model-runtime, +.mono { + font-family: "IBM Plex Mono", monospace; +} + +.job-status { + font-size: 0.78rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.job-status.succeeded { + color: var(--accent); +} + +.job-status.failed, +.job-status.canceled { + color: var(--danger); +} + +.job-status.running, +.job-status.queued { + color: var(--accent-warm); +} + +.job-output { + padding: 1rem; + border-radius: 1rem; + background: rgba(30, 44, 53, 0.92); + color: #eef4f2; + font-family: "IBM Plex Mono", monospace; + font-size: 0.84rem; + line-height: 1.55; + white-space: pre-wrap; + overflow-wrap: anywhere; +} + +.artifact-strip { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.artifact-button { + border: 1px solid rgba(33, 63, 100, 0.14); + background: rgba(255, 255, 255, 0.72); + color: var(--ink); + border-radius: 999px; + padding: 0.58rem 0.8rem; + font: inherit; + font-size: 0.84rem; + cursor: pointer; +} + +.artifact-button.active { + background: var(--accent-soft); + border-color: rgba(13, 107, 97, 0.26); + color: var(--accent); +} + +.stack { + display: grid; + gap: 0.75rem; +} + +.compact-actions { + margin: 0; +} + +.timeline-event { + padding: 0.95rem 1rem; +} + +.timeline-event-head { + display: flex; + justify-content: space-between; + gap: 0.9rem; + align-items: start; +} + +.timeline-time { + color: var(--muted); + font-family: "IBM Plex Mono", monospace; + font-size: 0.76rem; + white-space: nowrap; +} + +.timeline-message { + margin-top: 0.55rem; + color: var(--ink); +} + +.execution-line { + margin-top: 0.55rem; +} + +.timeline-detail, +.worker-detail-line { + margin-top: 0.45rem; +} + +.execution-artifacts { + margin-top: 0.65rem; +} + +.execution-preview { + margin: 0.65rem 0 0; + max-height: 12rem; + overflow: auto; +} + +.smoke-card { + padding: 1rem; + display: grid; + gap: 0.8rem; +} + +.smoke-check-list { + display: grid; + gap: 0.6rem; +} + +.smoke-history { + display: grid; + gap: 0.65rem; +} + +.smoke-history-list { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); + gap: 0.6rem; +} + +.smoke-history-entry { + border: 1px solid rgba(33, 63, 100, 0.12); + background: rgba(255, 255, 255, 0.74); + color: var(--ink); + border-radius: 1rem; + padding: 0.78rem 0.85rem; + text-align: left; + font: inherit; + cursor: pointer; + transition: transform 160ms ease, border-color 160ms ease, background 160ms ease; +} + +.smoke-history-entry:hover, +.smoke-history-entry.active { + transform: translateY(-1px); + border-color: rgba(13, 107, 97, 0.28); + background: rgba(214, 239, 232, 0.62); +} + +.smoke-history-title { + font-weight: 700; +} + +.smoke-history-meta { + margin-top: 0.55rem; +} + +.smoke-check { + padding: 0.8rem 0.9rem; + display: grid; + gap: 0.55rem; +} + +.smoke-check-header { + display: flex; + justify-content: space-between; + 
gap: 0.75rem; + align-items: start; +} + +.smoke-check-message { + color: var(--ink); +} + +.smoke-check-remediation { + color: var(--muted); + font-size: 0.92rem; +} + +.smoke-detail-line { + overflow-wrap: anywhere; +} + +.smoke-check-reasons { + margin: 0; + padding-left: 1.2rem; + color: var(--muted); +} + +.smoke-output-preview { + margin: 0; + max-height: 14rem; + overflow: auto; +} + +.smoke-meta-details { + display: grid; + gap: 0.55rem; +} + +.smoke-meta-details summary { + cursor: pointer; + color: var(--accent-deep); + font-family: "IBM Plex Mono", monospace; + font-size: 0.76rem; + letter-spacing: 0.05em; + text-transform: uppercase; +} + +.smoke-meta-preview { + font-size: 0.78rem; +} + +.models-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 0.85rem; +} + +.model-card { + padding: 1rem; +} + +.model-card ul { + margin: 0.75rem 0 0; + padding-left: 1.1rem; + color: var(--muted); +} + +.fade-in { + animation: rise-in 260ms ease; +} + +@keyframes rise-in { + from { + opacity: 0; + transform: translateY(8px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +@media (max-width: 960px) { + .hero, + .route-grid, + .dashboard-grid, + .form-row { + grid-template-columns: 1fr; + } + + .hero-copy h1 { + max-width: none; + } + + .shell { + width: min(100vw - 1rem, 1360px); + padding: 1rem 0 2rem; + } + + .split-head { + flex-direction: column; + } + + .timeline-event-head { + flex-direction: column; + } + + .smoke-check-header { + flex-direction: column; + } +} diff --git a/src/lai/api/static/dashboard.js b/src/lai/api/static/dashboard.js new file mode 100644 index 0000000..83098cb --- /dev/null +++ b/src/lai/api/static/dashboard.js @@ -0,0 +1,1276 @@ +const state = { + jobs: [], + staleJobs: [], + models: [], + worker: null, + workstation: null, + smoke: { + providers: createSmokeSuiteState(), + local: createSmokeSuiteState(), + }, + selectedJobId: null, + selectedArtifactId: null, +}; + +const elements = { + environment: document.getElementById("metric-environment"), + modelCount: document.getElementById("metric-model-count"), + jobCount: document.getElementById("metric-job-count"), + queueState: document.getElementById("metric-queue-state"), + workerState: document.getElementById("metric-worker-state"), + jobsList: document.getElementById("jobs-list"), + modelsGrid: document.getElementById("models-grid"), + routeSummary: document.getElementById("route-summary"), + routeReasons: document.getElementById("route-reasons"), + jobDetail: document.getElementById("job-detail"), + routeForm: document.getElementById("route-form"), + refreshAll: document.getElementById("refresh-all"), + submitJob: document.getElementById("submit-job"), + runNextJob: document.getElementById("run-next-job"), + runBatchJobs: document.getElementById("run-batch-jobs"), + drainQueue: document.getElementById("drain-queue"), + recoverStaleJobs: document.getElementById("recover-stale-jobs"), + queueActionStatus: document.getElementById("queue-action-status"), + staleJobSummary: document.getElementById("stale-job-summary"), + workerSummary: document.getElementById("worker-summary"), + workstationSummary: document.getElementById("workstation-summary"), + workstationGrid: document.getElementById("workstation-grid"), + refreshWorkstation: document.getElementById("refresh-workstation"), + smokeGrid: document.getElementById("smoke-grid"), + smokeActionStatus: document.getElementById("smoke-action-status"), + runProviderSmoke: 
document.getElementById("run-provider-smoke"), + runLocalSmoke: document.getElementById("run-local-smoke"), + runLocalAirllmSmoke: document.getElementById("run-local-airllm-smoke"), + refreshSmoke: document.getElementById("refresh-smoke"), +}; + +async function fetchJson(url, options = undefined) { + const response = await fetch(url, options); + if (!response.ok) { + const body = await response.text(); + throw new Error(body || `Request failed with ${response.status}`); + } + return response.json(); +} + +async function fetchOptionalJson(url, options = undefined) { + const response = await fetch(url, options); + if (response.status === 404) { + return null; + } + if (!response.ok) { + const body = await response.text(); + throw new Error(body || `Request failed with ${response.status}`); + } + return response.json(); +} + +function createSmokeSuiteState() { + return { + latest: null, + history: [], + selectedResultId: null, + selected: null, + }; +} + +function payloadFromForm() { + return { + user_prompt: document.getElementById("user-prompt").value.trim(), + system_prompt: document.getElementById("system-prompt").value.trim() || null, + queue_mode: document.getElementById("queue-mode").value, + model_override: document.getElementById("model-override").value.trim() || null, + }; +} + +function ensurePrompt(payload) { + if (!payload.user_prompt) { + throw new Error("Add a user prompt first."); + } +} + +function chip(label, value, extraClass = "") { + const className = extraClass ? `chip ${extraClass}` : "chip"; + return `${label}: ${escapeHtml(String(value))}`; +} + +function setQueueActionStatus(message, tone = "") { + if (!elements.queueActionStatus) { + return; + } + const toneChip = tone ? `${chip("worker", tone, tone)} ` : ""; + elements.queueActionStatus.className = "inline-status fade-in"; + elements.queueActionStatus.innerHTML = `${toneChip}${escapeHtml(message)}`; +} + +function setSmokeActionStatus(message, tone = "") { + if (!elements.smokeActionStatus) { + return; + } + const toneChip = tone ? 
`${chip("smoke", tone, tone)} ` : ""; + elements.smokeActionStatus.className = "inline-status fade-in"; + elements.smokeActionStatus.innerHTML = `${toneChip}${escapeHtml(message)}`; +} + +function escapeHtml(value) { + return value + .replaceAll("&", "&") + .replaceAll("<", "<") + .replaceAll(">", ">") + .replaceAll('"', """) + .replaceAll("'", "'"); +} + +function formatProgressPercent(value) { + if (value === null || value === undefined || Number.isNaN(Number(value))) { + return null; + } + return `${Math.round(Number(value) * 100)}%`; +} + +async function loadOverview() { + const [health, modelsResponse, jobsResponse, staleJobsResponse, workerResponse, workstationResponse] = await Promise.all([ + fetchJson("/health"), + fetchJson("/models"), + fetchJson("/jobs?limit=12"), + fetchJson("/jobs/stale?limit=12"), + fetchJson("/worker/status"), + fetchJson("/workstation/readiness"), + ]); + + state.models = modelsResponse.models; + state.jobs = jobsResponse.jobs; + state.staleJobs = staleJobsResponse.jobs || []; + state.worker = workerResponse; + state.workstation = workstationResponse; + + elements.environment.textContent = health.environment; + elements.modelCount.textContent = String(health.model_count); + elements.jobCount.textContent = String(state.jobs.length); + elements.queueState.textContent = summarizeQueue(state.jobs); + renderStaleJobs(); + renderWorkerMonitor(); + renderWorkstationReadiness(); + await loadSmokeDiagnostics(); + + renderModels(); + renderJobs(); + + if (state.selectedJobId) { + await selectJob(state.selectedJobId); + } else if (state.jobs.length > 0) { + await selectJob(state.jobs[0].id); + } +} + +function renderStaleJobs() { + if (!elements.staleJobSummary) { + return; + } + + if (!state.staleJobs || state.staleJobs.length === 0) { + elements.staleJobSummary.className = "worker-summary empty"; + elements.staleJobSummary.textContent = "No stale running jobs detected."; + return; + } + + const cards = state.staleJobs + .map( + (job) => ` + + `, + ) + .join(""); + + elements.staleJobSummary.className = "worker-summary fade-in"; + elements.staleJobSummary.innerHTML = ` +
+ ${chip("stale", state.staleJobs.length, "danger")} + ${chip("threshold", `${state.staleJobs[0].stale_after_seconds}s`, "warn")} +
+
${cards}
+ `; + + elements.staleJobSummary.querySelectorAll("[data-stale-job-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectJob(button.getAttribute("data-stale-job-id")).catch(handleError); + }); + }); +} + +async function loadWorkstationReadiness() { + state.workstation = await fetchJson("/workstation/readiness"); + renderWorkstationReadiness(); +} + +async function loadSmokeDiagnostics() { + const [providerLatest, localLatest, providerHistory, localHistory] = await Promise.all([ + fetchOptionalJson("/smoke/latest?suite_id=providers"), + fetchOptionalJson("/smoke/latest?suite_id=local"), + fetchJson("/smoke/results?suite_id=providers&limit=6"), + fetchJson("/smoke/results?suite_id=local&limit=6"), + ]); + + updateSmokeSuiteState("providers", providerLatest, providerHistory?.results || []); + updateSmokeSuiteState("local", localLatest, localHistory?.results || []); + renderSmokeDiagnostics(); +} + +function updateSmokeSuiteState(suiteId, latestEntry, historyEntries) { + const suiteState = state.smoke[suiteId]; + const mergedHistory = [...historyEntries]; + if (latestEntry && !mergedHistory.some((entry) => entry.id === latestEntry.id)) { + mergedHistory.unshift(latestEntry); + } + if ( + suiteState.selected && + !mergedHistory.some((entry) => entry.id === suiteState.selected.id) + ) { + mergedHistory.unshift(suiteState.selected); + } + + suiteState.latest = latestEntry || null; + suiteState.history = mergedHistory; + + const selectedEntry = + mergedHistory.find((entry) => entry.id === suiteState.selectedResultId) || + latestEntry || + mergedHistory[0] || + null; + suiteState.selectedResultId = selectedEntry?.id || null; + suiteState.selected = selectedEntry; +} + +function renderSmokeDiagnostics() { + if (!elements.smokeGrid) { + return; + } + + const suites = [ + { key: "providers", label: "Providers", suite: state.smoke.providers }, + { key: "local", label: "Local", suite: state.smoke.local }, + ]; + elements.smokeGrid.className = "smoke-grid fade-in"; + elements.smokeGrid.innerHTML = suites + .map((suite) => renderSmokeCard(suite.key, suite.label, suite.suite)) + .join(""); + + elements.smokeGrid.querySelectorAll("[data-smoke-suite][data-smoke-result-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectSmokeResult( + button.getAttribute("data-smoke-suite"), + button.getAttribute("data-smoke-result-id"), + ).catch(handleSmokeError); + }); + }); +} + +function renderWorkstationReadiness() { + if (!elements.workstationSummary || !elements.workstationGrid) { + return; + } + + if (!state.workstation) { + elements.workstationSummary.className = "worker-summary empty"; + elements.workstationSummary.textContent = "Workstation readiness unavailable."; + elements.workstationGrid.className = "smoke-grid empty"; + elements.workstationGrid.textContent = "No workstation profiles loaded."; + return; + } + + const system = state.workstation.system || {}; + const profiles = state.workstation.profiles || []; + const blockingProfiles = profiles.filter((profile) => !profile.ready).length; + const warningProfiles = profiles.filter( + (profile) => profile.ready && profile.warning_count > 0, + ).length; + + elements.workstationSummary.className = "worker-summary fade-in"; + elements.workstationSummary.innerHTML = ` +
+ ${chip("platform", system.platform || "n/a")} + ${chip("python", system.python_version || "n/a")} + ${chip("gpu", system.gpu_name || "none", system.gpu_name ? "ready" : "warn")} + ${chip("vram", formatGigabytes(system.gpu_total_memory_gb), system.gpu_total_memory_gb ? "ready" : "warn")} + ${chip("ram", formatGigabytes(system.total_ram_gb), system.total_ram_gb ? "ready" : "warn")} + ${chip("blocked", blockingProfiles, blockingProfiles > 0 ? "danger" : "ready")} + ${chip("warnings", warningProfiles, warningProfiles > 0 ? "warn" : "ready")} +
+
+ hf cache: ${escapeHtml(system.huggingface_cache_dir || "n/a")} | + airllm shards: ${escapeHtml(system.airllm_shards_dir || "n/a")} +
+ `; + + if (profiles.length === 0) { + elements.workstationGrid.className = "smoke-grid empty"; + elements.workstationGrid.textContent = "No workstation profiles defined."; + return; + } + + elements.workstationGrid.className = "smoke-grid fade-in"; + elements.workstationGrid.innerHTML = profiles + .map((profile) => renderWorkstationProfile(profile)) + .join(""); +} + +function renderWorkstationProfile(profile) { + const tone = !profile.ready + ? "danger" + : profile.warning_count > 0 + ? "warn" + : "ready"; + const checks = (profile.checks || []) + .map((check) => renderWorkstationCheck(check)) + .join(""); + + return ` +
+
+

${escapeHtml(profile.profile_id)}

+

${escapeHtml(profile.title)}

+
+
+ ${chip("status", profile.ready ? "ready" : "blocked", tone)} + ${chip("warnings", profile.warning_count, profile.warning_count > 0 ? "warn" : "ready")} + ${chip("failures", profile.failure_count, profile.failure_count > 0 ? "danger" : "ready")} +
+
${escapeHtml(profile.summary)}
+
${checks}
+
+ `; +} + +function renderWorkstationCheck(check) { + const tone = + check.status === "fail" + ? "danger" + : check.status === "warn" + ? "warn" + : "ready"; + const remediation = check.remediation + ? `
${escapeHtml(check.remediation)}
` + : ""; + + return ` +
+
+ ${escapeHtml(check.title)} +
${chip("status", check.status, tone)}
+
+
${escapeHtml(check.summary)}
+ ${remediation} + ${renderSmokeSnapshot("Check metadata", check.metadata)} +
+ `; +} + +function renderSmokeCard(suiteId, label, suiteState) { + const selectedEntry = suiteState.selected || suiteState.latest || suiteState.history[0] || null; + if (!selectedEntry) { + return ` +
+
+

${escapeHtml(label)}

+

No saved smoke result

+
+
Run a readiness sweep to capture the latest diagnostics for ${escapeHtml(label.toLowerCase())}.
+
+ `; + } + + const result = selectedEntry.result; + const summaryTone = result.passed ? "ready" : result.failed_count > 0 ? "danger" : "warn"; + const executedAt = new Date(result.executed_at).toLocaleString(); + const history = suiteState.history + .map((entry) => renderSmokeHistoryEntry(suiteId, entry, suiteState.selectedResultId)) + .join(""); + const checks = (result.checks || []) + .map((check) => renderSmokeCheck(check)) + .join(""); + + return ` +
+
+

${escapeHtml(label)}

+

${escapeHtml(result.mode)} summary

+
+
+ ${chip("status", result.passed ? "passing" : "attention", summaryTone)} + ${chip("checks", result.total)} + ${chip("success", result.success_count, result.success_count > 0 ? "ready" : "")} + ${chip("blocked", result.blocked_count, result.blocked_count > 0 ? "warn" : "")} + ${chip("failed", result.failed_count, result.failed_count > 0 ? "danger" : "")} +
+
executed: ${escapeHtml(executedAt)} | saved: ${escapeHtml(selectedEntry.id)}
+
path: ${escapeHtml(selectedEntry.path)}
+
prompt: ${escapeHtml(result.prompt)}
+
+
Recent saved runs
+
${history}
+
+
${checks}
+
+ `; +} + +function renderSmokeHistoryEntry(suiteId, entry, selectedResultId) { + const result = entry.result; + const tone = result.passed ? "ready" : result.failed_count > 0 ? "danger" : "warn"; + const activeClass = entry.id === selectedResultId ? " active" : ""; + return ` + + `; +} + +function renderSmokeCheck(check) { + const tone = + check.status === "failed" + ? "danger" + : check.status === "blocked" + ? "warn" + : "ready"; + const detailChips = [ + chip("status", check.status, tone), + chip("mode", check.mode), + chip("runtime", check.runtime), + chip("available", check.available ? "yes" : "no", check.available ? "ready" : "warn"), + chip("healthy", check.healthy ? "yes" : "no", check.healthy ? "ready" : "warn"), + ]; + if (check.duration_seconds !== null && check.duration_seconds !== undefined) { + detailChips.push(chip("duration", `${check.duration_seconds.toFixed(2)}s`)); + } + if (check.executed) { + detailChips.push(chip("executed", check.success ? "yes" : "failed", check.success ? "ready" : "danger")); + } + + const reasons = (check.reasons || []).length + ? ` + + ` + : ""; + const preview = check.output_preview + ? ` +
+
Output preview
+
${escapeHtml(check.output_preview)}
+
+ ` + : ""; + + return ` +
+
+
+ ${escapeHtml(check.provider_id)} +
${escapeHtml(check.model_ref)}
+
+
${detailChips.join("")}
+
+
${escapeHtml(check.message)}
+ ${reasons} + ${preview} + ${renderSmokeSnapshot("Metadata snapshot", check.metadata)} + ${renderSmokeSnapshot("Capabilities snapshot", check.capabilities)} +
+ `; +} + +function renderSmokeSnapshot(label, payload) { + if (!payload || Object.keys(payload).length === 0) { + return ""; + } + return ` +
+ ${escapeHtml(label)} +
${escapeHtml(
+        JSON.stringify(payload, null, 2),
+      )}
+
+ `; +} + +function formatGigabytes(value) { + if (value === null || value === undefined) { + return "n/a"; + } + return `${Number(value).toFixed(2)} GB`; +} + +async function selectSmokeResult(suiteId, resultId) { + if (!suiteId || !resultId || !state.smoke[suiteId]) { + return; + } + + const suiteState = state.smoke[suiteId]; + const knownEntry = + suiteState.history.find((entry) => entry.id === resultId) || + (suiteState.latest?.id === resultId ? suiteState.latest : null) || + (suiteState.selected?.id === resultId ? suiteState.selected : null); + const entry = + knownEntry || + (await fetchJson(`/smoke/results/${encodeURIComponent(resultId)}`)); + + if (!suiteState.history.some((item) => item.id === entry.id)) { + suiteState.history.unshift(entry); + } + suiteState.selectedResultId = entry.id; + suiteState.selected = entry; + renderSmokeDiagnostics(); +} + +function renderWorkerMonitor() { + if (!state.worker) { + elements.workerState.textContent = "unknown"; + if (elements.workerSummary) { + elements.workerSummary.className = "worker-summary empty"; + elements.workerSummary.textContent = "Worker state unavailable."; + } + return; + } + + const worker = state.worker; + const tone = + worker.status === "error" + ? "danger" + : worker.status === "running" + ? "warn" + : "ready"; + + elements.workerState.textContent = `${worker.status} / ${worker.mode}`; + if (!elements.workerSummary) { + return; + } + + const chips = [ + chip("status", worker.status, tone), + chip("mode", worker.mode), + chip("queued", worker.queued_jobs), + chip("processed", worker.processed_jobs), + chip("service", worker.service_lock_present ? "locked" : "inactive", worker.service_lock_present ? "warn" : ""), + ]; + if (worker.current_job_id) { + chips.push(chip("current", worker.current_job_id.slice(0, 8), tone)); + } + if (worker.last_job_id) { + chips.push(chip("last", worker.last_job_id.slice(0, 8))); + } + + const heartbeat = worker.heartbeat_at + ? new Date(worker.heartbeat_at).toLocaleString() + : "n/a"; + const startedAt = worker.started_at ? new Date(worker.started_at).toLocaleString() : "n/a"; + const phase = worker.metadata?.phase ? `phase: ${worker.metadata.phase}` : "phase: n/a"; + const currentStage = worker.metadata?.current_stage || "n/a"; + const currentStagePhase = worker.metadata?.current_stage_phase || "n/a"; + const currentStageProgress = + formatProgressPercent(worker.metadata?.current_stage_progress) || "n/a"; + const currentStageMessage = worker.metadata?.current_stage_message || ""; + const lastRecovery = worker.metadata?.last_recovery || null; + const logFile = worker.log_file ? `log: ${worker.log_file}` : "log: n/a"; + const stopSignal = worker.stop_signal_present ? '
stop signal: present
' : ""; + const error = worker.last_error + ? `
error: ${escapeHtml(worker.last_error)}
` + : ""; + const recoveryLine = lastRecovery + ? `last recovery: ${lastRecovery.count || 0} job(s) via ${lastRecovery.reason || "n/a"} at ${lastRecovery.at || "n/a"}` + : ""; + const recoveryJobs = Array.isArray(lastRecovery?.job_ids) && lastRecovery.job_ids.length > 0 + ? `recovered jobs: ${lastRecovery.job_ids.map((jobId) => String(jobId).slice(0, 8)).join(", ")}` + : ""; + + elements.workerSummary.className = "worker-summary fade-in"; + elements.workerSummary.innerHTML = ` +
${chips.join("")}
+
heartbeat: ${escapeHtml(heartbeat)} | started: ${escapeHtml(startedAt)} | ${escapeHtml(phase)}
+
stage: ${escapeHtml(currentStage)} | provider phase: ${escapeHtml(currentStagePhase)} | progress: ${escapeHtml(currentStageProgress)}
+ ${currentStageMessage ? `
${escapeHtml(currentStageMessage)}
` : ""} + ${recoveryLine ? `
${escapeHtml(recoveryLine)}
` : ""} + ${recoveryJobs ? `
${escapeHtml(recoveryJobs)}
` : ""} +
${escapeHtml(logFile)}
+ ${stopSignal} + ${error} + `; +} + +function summarizeQueue(jobs) { + if (jobs.length === 0) { + return "idle"; + } + const counts = jobs.reduce( + (accumulator, job) => { + accumulator[job.status] = (accumulator[job.status] || 0) + 1; + return accumulator; + }, + {}, + ); + return Object.entries(counts) + .map(([key, value]) => `${key}:${value}`) + .join(" | "); +} + +function renderModels() { + if (state.models.length === 0) { + elements.modelsGrid.className = "models-grid empty"; + elements.modelsGrid.textContent = "No models loaded."; + return; + } + + elements.modelsGrid.className = "models-grid fade-in"; + elements.modelsGrid.innerHTML = state.models + .map((model) => { + const health = model.health; + const healthClass = health.available && health.healthy ? "" : "warn"; + const details = health.reasons && health.reasons.length > 0 ? health.reasons : ["ready"]; + return ` +
+
+
+ ${escapeHtml(model.id)} +
${escapeHtml(model.runtime)}
+
+ ${chip("health", health.available ? "ready" : "blocked", healthClass)} +
+

${escapeHtml(model.model_ref)}

+
    + ${details.map((detail) => `
  • ${escapeHtml(detail)}
`).join("")} +
+
+ `; + }) + .join(""); +} + +function renderJobs() { + if (state.jobs.length === 0) { + elements.jobsList.className = "jobs-list empty"; + elements.jobsList.textContent = "No persisted jobs yet."; + return; + } + + elements.jobsList.className = "jobs-list fade-in"; + elements.jobsList.innerHTML = state.jobs + .map((job) => { + const activeClass = job.id === state.selectedJobId ? " active" : ""; + return ` +
+
+
+
${escapeHtml(job.id.slice(0, 8))}
+ ${escapeHtml(job.route_decision?.executor_model_id || "n/a")} +
+
${escapeHtml(job.status)}
+
+

${escapeHtml(job.request.user_prompt.slice(0, 120))}

+
+ `; + }) + .join(""); + + elements.jobsList.querySelectorAll("[data-job-id]").forEach((card) => { + card.addEventListener("click", () => { + void selectJob(card.getAttribute("data-job-id")); + }); + }); +} + +async function selectJob(jobId) { + if (!jobId) { + return; + } + + state.selectedJobId = jobId; + state.selectedArtifactId = null; + renderJobs(); + + const job = await fetchJson(`/jobs/${jobId}`); + const chips = [ + chip( + "status", + job.status, + job.status === "failed" ? "danger" : job.status === "succeeded" ? "" : "warn", + ), + chip("queue", job.queue_mode), + chip("tier", job.route_decision?.matched_tier_id || "n/a"), + chip("executor", job.route_decision?.executor_model_id || "n/a"), + ].join(""); + + const output = job.result?.text + ? `
${escapeHtml(job.result.text)}
` + : `
No final output stored for this job yet.
`; + + const error = job.error?.message + ? `
error: ${escapeHtml(job.error.message)}
` + : ""; + const cancelAction = + job.status === "queued" || job.status === "running" + ? '' + : ""; + + const reasons = (job.route_decision?.reasons || []) + .map( + (reason) => ` +
+
${escapeHtml(reason.stage)}
+
${escapeHtml(reason.message)}
+
+ `, + ) + .join(""); + + elements.jobDetail.className = "job-detail fade-in"; + elements.jobDetail.innerHTML = ` +
${chips}
+
+ + + ${cancelAction} +
+ ${error} +
+

Final output

+ ${output} +
+
+

Route trace

+
${reasons || '
No route reasons recorded.
'}
+
+
+

Provider execution

+
Loading execution summaries...
+
+
+

Stage timeline

+
Loading stage telemetry...
+
+
+

Artifacts

+
Loading artifacts...
+
Select an artifact to inspect it.
+
+ `; + + document.getElementById("replay-inline-job")?.addEventListener("click", () => { + void replayJob(job.id, "inline").catch(handleQueueError); + }); + document.getElementById("replay-queued-job")?.addEventListener("click", () => { + void replayJob(job.id, "queued").catch(handleQueueError); + }); + document.getElementById("cancel-selected-job")?.addEventListener("click", () => { + void cancelJob(job.id).catch(handleQueueError); + }); + + await loadExecution(job.id); + await loadTimeline(job.id, job.stage_events || []); + await loadArtifacts(job.id, job.artifacts || []); +} + +async function loadExecution(jobId) { + const response = await fetchJson(`/jobs/${jobId}/execution`); + const stages = response.stages || []; + const executionList = document.getElementById("execution-list"); + if (!executionList) { + return; + } + + if (stages.length === 0) { + executionList.className = "timeline-list empty"; + executionList.textContent = "No provider execution summaries recorded for this job yet."; + return; + } + + executionList.className = "timeline-list fade-in"; + executionList.innerHTML = stages + .map((stage) => renderExecutionStage(jobId, stage)) + .join(""); + + executionList.querySelectorAll("[data-execution-artifact-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectArtifact( + jobId, + button.getAttribute("data-execution-artifact-id"), + [], + ).catch(handleError); + }); + }); +} + +function renderExecutionStage(jobId, stage) { + const tone = + stage.status === "failed" || stage.status === "blocked" + ? "danger" + : stage.status === "running" + ? "warn" + : "ready"; + const detailChips = [ + chip("status", stage.status, tone), + chip("provider", stage.provider_id || "n/a"), + chip("model", stage.model_id || "n/a"), + ]; + if (stage.duration_seconds !== null && stage.duration_seconds !== undefined) { + detailChips.push(chip("duration", `${stage.duration_seconds.toFixed(2)}s`)); + } + if (stage.finish_reason) { + detailChips.push(chip("finish", stage.finish_reason)); + } + if (stage.prompt_characters !== null && stage.prompt_characters !== undefined) { + detailChips.push(chip("prompt", `${stage.prompt_characters} chars`)); + } + if (stage.output_characters !== null && stage.output_characters !== undefined) { + detailChips.push(chip("output", `${stage.output_characters} chars`)); + } + if (stage.progress_event_count) { + detailChips.push(chip("progress", stage.progress_event_count, "warn")); + } + if (stage.usage?.total_tokens !== null && stage.usage?.total_tokens !== undefined) { + detailChips.push(chip("tokens", stage.usage.total_tokens)); + } + if (stage.latest_progress_percent !== null && stage.latest_progress_percent !== undefined) { + detailChips.push( + chip("latest", formatProgressPercent(stage.latest_progress_percent), "warn"), + ); + } + + const artifactChips = (stage.artifact_types || []) + .map((artifactType, index) => renderExecutionArtifactLink(jobId, stage, artifactType, index)) + .join(""); + const preview = stage.output_preview + ? `
${escapeHtml(stage.output_preview)}
` + : ""; + const healthReasons = (stage.health_reasons || []).length + ? ` + + ` + : ""; + const error = stage.error_type + ? `
error: ${escapeHtml(stage.error_type)}
` + : ""; + const timestamps = [ + stage.started_at ? `started: ${new Date(stage.started_at).toLocaleString()}` : null, + stage.finished_at ? `finished: ${new Date(stage.finished_at).toLocaleString()}` : null, + ] + .filter(Boolean) + .join(" | "); + const requestKnobs = [ + stage.temperature !== null && stage.temperature !== undefined + ? `temp ${stage.temperature}` + : null, + stage.max_output_tokens ? `max ${stage.max_output_tokens}` : null, + stage.timeout_seconds ? `timeout ${stage.timeout_seconds}s` : null, + ] + .filter(Boolean) + .join(" | "); + const latestProgress = [ + stage.latest_progress_phase ? `phase ${stage.latest_progress_phase}` : null, + stage.latest_progress_percent !== null && stage.latest_progress_percent !== undefined + ? formatProgressPercent(stage.latest_progress_percent) + : null, + stage.latest_progress_at + ? `updated ${new Date(stage.latest_progress_at).toLocaleString()}` + : null, + ] + .filter(Boolean) + .join(" | "); + + return ` +
+
+
+
${escapeHtml(stage.stage)}
+
${detailChips.join("")}
+
+
${escapeHtml(timestamps || "timing unavailable")}
+
+
${escapeHtml(stage.last_message || "No execution message recorded.")}
+ ${requestKnobs ? `
${escapeHtml(requestKnobs)}
` : ""} + ${latestProgress ? `
${escapeHtml(latestProgress)}
` : ""} + ${stage.latest_progress_message ? `
${escapeHtml(stage.latest_progress_message)}
` : ""} + ${artifactChips ? `
${artifactChips}
` : ""} + ${error} + ${healthReasons} + ${preview} +
+ `; +} + +function renderExecutionArtifactLink(jobId, stage, artifactType, index) { + const artifactId = stage.artifact_ids?.[index]; + if (!artifactId) { + return chip("artifact", artifactType); + } + + return ` + + `; +} + +async function loadTimeline(jobId, stageEvents) { + const response = + stageEvents.length > 0 + ? { stage_events: stageEvents } + : await fetchJson(`/jobs/${jobId}/timeline`); + const events = response.stage_events || []; + const timeline = document.getElementById("timeline-list"); + if (!timeline) { + return; + } + + if (events.length === 0) { + timeline.className = "timeline-list empty"; + timeline.textContent = "No stage telemetry recorded for this job yet."; + return; + } + + timeline.className = "timeline-list fade-in"; + timeline.innerHTML = events + .map((event) => { + const tone = + event.event_type === "failed" || event.event_type === "blocked" + ? "danger" + : event.event_type === "running" || event.event_type === "started" || event.event_type === "progress" + ? "warn" + : "ready"; + const detailChips = [ + chip("stage", event.stage), + chip("event", event.event_type, tone), + ]; + if (event.model_id) { + detailChips.push(chip("model", event.model_id)); + } + if (event.provider_id) { + detailChips.push(chip("provider", event.provider_id)); + } + const metadataLine = [ + event.metadata?.phase ? `phase ${event.metadata.phase}` : null, + event.metadata?.progress !== null && event.metadata?.progress !== undefined + ? formatProgressPercent(event.metadata.progress) + : null, + event.metadata?.elapsed_seconds !== null && event.metadata?.elapsed_seconds !== undefined + ? `elapsed ${event.metadata.elapsed_seconds}s` + : null, + ] + .filter(Boolean) + .join(" | "); + return ` +
+
+
${detailChips.join("")}
+
${escapeHtml(new Date(event.created_at).toLocaleString())}
+
+
${escapeHtml(event.message)}
+ ${metadataLine ? `
${escapeHtml(metadataLine)}
` : ""} +
+ `; + }) + .join(""); +} + +async function loadArtifacts(jobId, jobArtifacts) { + const response = + jobArtifacts.length > 0 ? { artifacts: jobArtifacts } : await fetchJson(`/jobs/${jobId}/artifacts`); + const artifacts = response.artifacts || []; + const strip = document.getElementById("artifact-strip"); + const preview = document.getElementById("artifact-preview"); + if (!strip || !preview) { + return; + } + + if (artifacts.length === 0) { + strip.className = "artifact-strip empty"; + strip.textContent = "No artifacts recorded for this job."; + preview.className = "empty"; + preview.textContent = "No artifact preview available."; + return; + } + + strip.className = "artifact-strip fade-in"; + strip.innerHTML = artifacts + .map((artifact) => { + const activeClass = artifact.id === state.selectedArtifactId ? " active" : ""; + return ` + + `; + }) + .join(""); + + strip.querySelectorAll("[data-artifact-id]").forEach((button) => { + button.addEventListener("click", () => { + void selectArtifact(jobId, button.getAttribute("data-artifact-id"), artifacts).catch(handleError); + }); + }); + + const preferredArtifact = + artifacts.find((artifact) => artifact.artifact_type === "final-output") || artifacts[0]; + await selectArtifact(jobId, preferredArtifact.id, artifacts); +} + +async function selectArtifact(jobId, artifactId, artifacts) { + if (!artifactId) { + return; + } + + state.selectedArtifactId = artifactId; + const strip = document.getElementById("artifact-strip"); + const preview = document.getElementById("artifact-preview"); + if (!strip || !preview) { + return; + } + + strip.querySelectorAll("[data-artifact-id]").forEach((button) => { + button.classList.toggle("active", button.getAttribute("data-artifact-id") === artifactId); + }); + + const response = await fetchJson(`/jobs/${jobId}/artifacts/${artifactId}`); + const artifact = response.artifact; + const artifactMeta = artifacts.find((item) => item.id === artifactId) || artifact; + preview.className = "stack fade-in"; + preview.innerHTML = ` +
+ ${chip("type", artifactMeta.artifact_type)} + ${chip("path", artifactMeta.relative_path)} + ${chip("format", response.content_type === "application/json" ? "json" : "text")} +
+
${escapeHtml(response.content)}
+ `; +} + +function renderRouteDecision(decision) { + elements.routeSummary.className = "stat-block fade-in"; + elements.routeSummary.innerHTML = ` +
+ ${chip("tier", decision.matched_tier_id)} + ${chip("planner", decision.planner_model_id || "none")} + ${chip("executor", decision.executor_model_id)} + ${chip("reviewer", decision.reviewer_model_id || "none")} + ${chip("queue", decision.queue_recommended ? "recommended" : "not needed", decision.queue_recommended ? "warn" : "")} +
+
fallbacks: ${escapeHtml((decision.fallback_model_ids || []).join(", ") || "none")}
+ `; + + elements.routeReasons.innerHTML = (decision.reasons || []) + .map( + (reason) => ` +
+
${escapeHtml(reason.stage)}
+
${escapeHtml(reason.message)}
+
+ `, + ) + .join(""); +} + +async function explainRoute(event) { + event.preventDefault(); + const payload = payloadFromForm(); + ensurePrompt(payload); + const decision = await fetchJson("/route/explain", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + renderRouteDecision(decision); +} + +async function submitJob() { + const payload = payloadFromForm(); + ensurePrompt(payload); + const job = await fetchJson("/jobs", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + renderRouteDecision(job.route_decision); + await loadOverview(); + await selectJob(job.id); +} + +async function replayJob(jobId, queueMode) { + const job = await fetchJson(`/jobs/${jobId}/replay`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ queue_mode: queueMode }), + }); + setQueueActionStatus( + `Created replay job ${job.id.slice(0, 8)} in ${job.queue_mode} mode.`, + "ready", + ); + await loadOverview(); + await selectJob(job.id); +} + +async function cancelJob(jobId) { + await fetchJson(`/jobs/${jobId}/cancel`, { method: "POST" }); + setQueueActionStatus(`Canceled job ${jobId.slice(0, 8)}.`, "warn"); + await loadOverview(); + await selectJob(jobId); +} + +async function runWorker(maxJobs, untilIdle = false) { + const selectedJobId = state.selectedJobId; + const body = untilIdle + ? { until_idle: true, max_idle_cycles: 2, max_jobs: maxJobs ?? null } + : { max_jobs: maxJobs }; + const response = await fetchJson("/worker/run", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + const processed = response.processed || 0; + const latestProcessed = + response.jobs && response.jobs.length > 0 ? response.jobs[response.jobs.length - 1] : null; + state.worker = response.worker || state.worker; + renderWorkerMonitor(); + setQueueActionStatus( + processed > 0 + ? untilIdle + ? `Drained ${processed} queued job(s) until the worker went idle.` + : `Processed ${processed} queued job(s).` + : untilIdle + ? "Queue drain completed with no queued jobs left to process." + : "No queued jobs were ready to run.", + processed > 0 ? "ready" : "warn", + ); + await loadOverview(); + if (latestProcessed?.id) { + await selectJob(latestProcessed.id); + return; + } + if (selectedJobId) { + await selectJob(selectedJobId); + } +} + +async function recoverStaleJobs() { + const response = await fetchJson("/jobs/recover", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ limit: 20 }), + }); + const recovered = response.recovered || []; + setQueueActionStatus( + recovered.length > 0 + ? `Recovered ${recovered.length} stale job(s) back into the queue.` + : "No stale jobs needed recovery.", + recovered.length > 0 ? "warn" : "ready", + ); + await loadOverview(); + if (recovered.length > 0) { + await selectJob(recovered[0].id); + } +} + +async function runSmokeSuite(suiteId, includeAirllm = false) { + const response = await fetchJson("/smoke/run", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + suite_id: suiteId, + live: false, + include_airllm: includeAirllm, + save: true, + }), + }); + const label = suiteId === "providers" ? "provider" : includeAirllm ? 
"local + AirLLM" : "local"; + await loadSmokeDiagnostics(); + if (response.id) { + await selectSmokeResult(suiteId, response.id); + } + setSmokeActionStatus( + `Saved ${label} readiness diagnostics at ${response.path || "the smoke results directory"}.`, + response.result.passed ? "ready" : "warn", + ); +} + +function bindEvents() { + elements.routeForm.addEventListener("submit", (event) => { + void explainRoute(event).catch(handleError); + }); + elements.submitJob.addEventListener("click", () => { + void submitJob().catch(handleError); + }); + elements.refreshAll.addEventListener("click", () => { + void loadOverview().catch(handleError); + }); + elements.runNextJob.addEventListener("click", () => { + void runWorker(1).catch(handleQueueError); + }); + elements.runBatchJobs.addEventListener("click", () => { + void runWorker(3).catch(handleQueueError); + }); + elements.drainQueue.addEventListener("click", () => { + void runWorker(null, true).catch(handleQueueError); + }); + elements.recoverStaleJobs.addEventListener("click", () => { + void recoverStaleJobs().catch(handleQueueError); + }); + elements.refreshWorkstation.addEventListener("click", () => { + void loadWorkstationReadiness().catch(handleError); + }); + elements.runProviderSmoke.addEventListener("click", () => { + void runSmokeSuite("providers").catch(handleSmokeError); + }); + elements.runLocalSmoke.addEventListener("click", () => { + void runSmokeSuite("local").catch(handleSmokeError); + }); + elements.runLocalAirllmSmoke.addEventListener("click", () => { + void runSmokeSuite("local", true).catch(handleSmokeError); + }); + elements.refreshSmoke.addEventListener("click", () => { + void loadSmokeDiagnostics().catch(handleSmokeError); + }); +} + +function handleError(error) { + elements.routeSummary.className = "stat-block fade-in"; + elements.routeSummary.innerHTML = `
error
${escapeHtml(error.message || String(error))}
`; +} + +function handleQueueError(error) { + setQueueActionStatus(error.message || String(error), "danger"); +} + +function handleSmokeError(error) { + setSmokeActionStatus(error.message || String(error), "danger"); +} + +async function boot() { + bindEvents(); + await loadOverview(); + setInterval(() => { + void loadOverview().catch(handleError); + }, 8000); +} + +void boot().catch(handleError); diff --git a/src/lai/api/static/index.html b/src/lai/api/static/index.html new file mode 100644 index 0000000..1fafa90 --- /dev/null +++ b/src/lai/api/static/index.html @@ -0,0 +1,202 @@ + + + + + + LAI Dashboard + + + + + + +
+
+
+ +
+
+
+

LAI CONTROL ROOM

+

Route fast. Think deep. Watch the queue move.

+

+ This dashboard sits on top of the local-first orchestration core. It exposes + model health, route reasoning, persistent jobs, and the current system stance + without needing a separate frontend stack. +

+
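+ <!-- The panels below are populated by the dashboard script, which calls the same JSON
+      endpoints the API exposes: POST /route/explain and POST /jobs for the Route Lab,
+      GET /workstation/readiness for Heavy Local, POST /worker/run and POST /jobs/recover
+      for the queue controls, and POST /smoke/run plus GET /smoke/results for Readiness. -->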
+ Open API Docs + +
+
+
+
+ Environment + ... +
+
+ Catalog Models + ... +
+
+ Jobs Tracked + ... +
+
+ Queue State + ... +
+
+ Worker + ... +
+
+
+ +
+
+
+

Route Lab

+

Explain or launch a request

+
+
+ + +
+ + +
+
+ + +
+
+
+ +
+
+

Decision Trace

+

Routing output

+
+
No route computed yet.
+
+
+
+ +
+
+
+
+

Queue

+

Recent jobs

+
+
+ + + + +
+
+
+ Worker idle. Use the queue controls to process queued jobs. +
+
Checking stale jobs...
+
Loading worker monitor...
+
No jobs yet.
+
+ +
+
+

Inspector

+

Job detail

+
+
Select a job to inspect its traces.
+
+
+ +
+
+

Runtime Matrix

+

Model health

+
+
Loading model health...
+
+ +
+
+
+

Heavy Local

+

Workstation readiness

+
+
+ +
+
+
Loading workstation readiness...
+
Loading workstation profiles...
+
+ +
+
+
+

Readiness

+

Smoke diagnostics

+
+
+ + + + +
+
+
+ Run a readiness sweep to save provider diagnostics without sending live prompts. +
+
Loading smoke diagnostics...
+
+
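The next files add the Python application layer. As a rough sketch of how its entry point is meant to be used (mirroring the `create_application()` calls in `cli.py` further down in this diff; the prompt text and option values below are illustrative, not taken from the patch):

```python
# Sketch only: build the application container and drive it the way the CLI does.
from lai.application import create_application
from lai.domain import ExecutionRequest, QueueMode

app = create_application()  # loads settings/config, providers, job store, routing engine

request = ExecutionRequest(
    user_prompt="Summarize this note.",  # illustrative prompt
    queue_mode=QueueMode.AUTO,           # let routing decide inline vs. queued
)

decision = app.routing_engine.route(request)     # explain-only routing decision
job = app.orchestration.submit_request(request)  # execute inline or enqueue
print(decision.matched_tier_id, job.status)
```

The apparent intent is that the CLI, the worker service, and the API all construct this same container rather than wiring the job store, artifact manager, and provider registry separately.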
+ + + + diff --git a/src/lai/application.py b/src/lai/application.py new file mode 100644 index 0000000..237c656 --- /dev/null +++ b/src/lai/application.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping + +from .artifacts import ArtifactManager +from .config import AppConfig, load_app_config +from .jobs.service import OrchestrationService +from .jobs.store import JobStore +from .layout import ensure_runtime_directories +from .providers import Provider, ProviderRegistry +from .routing import RoutingEngine +from .settings import Settings +from .system import collect_system_snapshot + + +@dataclass +class LAIApplication: + settings: Settings + config: AppConfig + provider_registry: ProviderRegistry + routing_engine: RoutingEngine + job_store: JobStore + artifacts: ArtifactManager + orchestration: OrchestrationService + + +def create_application( + settings: Settings | None = None, + providers: Mapping[str, Provider] | None = None, +) -> LAIApplication: + settings = settings or Settings() + ensure_runtime_directories(settings.root_dir) + config = load_app_config(settings.resolved_model_catalog, settings.resolved_routing_policy) + system_snapshot = collect_system_snapshot( + settings.resolved_huggingface_cache_dir, + settings.resolved_airllm_shards_dir, + settings.resolved_artifacts_dir, + settings.resolved_state_dir, + enable_gpu=settings.enable_gpu, + ) + registry = ProviderRegistry(settings, system_snapshot, providers=providers) + job_store = JobStore(settings.resolved_database_path) + job_store.initialize() + artifacts = ArtifactManager(settings.resolved_artifacts_dir, job_store) + routing_engine = RoutingEngine(settings, config, registry, system_snapshot) + orchestration = OrchestrationService( + settings=settings, + config=config, + provider_registry=registry, + routing_engine=routing_engine, + job_store=job_store, + artifacts=artifacts, + ) + return LAIApplication( + settings=settings, + config=config, + provider_registry=registry, + routing_engine=routing_engine, + job_store=job_store, + artifacts=artifacts, + orchestration=orchestration, + ) diff --git a/src/lai/artifacts.py b/src/lai/artifacts.py new file mode 100644 index 0000000..14d127a --- /dev/null +++ b/src/lai/artifacts.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from pathlib import Path +from uuid import uuid4 + +from .domain import ArtifactRecord +from .jobs.store import JobStore +from .serialization import dumps_pretty + + +class ArtifactManager: + def __init__(self, artifacts_root: Path, job_store: JobStore) -> None: + self.artifacts_root = artifacts_root + self.job_store = job_store + self.artifacts_root.mkdir(parents=True, exist_ok=True) + + def write_json( + self, job_id: str, artifact_type: str, filename: str, payload: object + ) -> ArtifactRecord: + return self._write(job_id, artifact_type, filename, dumps_pretty(payload)) + + def write_text( + self, job_id: str, artifact_type: str, filename: str, payload: str + ) -> ArtifactRecord: + return self._write(job_id, artifact_type, filename, payload) + + def resolve_path(self, artifact: ArtifactRecord) -> Path: + target = (self.artifacts_root / artifact.relative_path).resolve() + root = self.artifacts_root.resolve() + target.relative_to(root) + return target + + def read_text(self, artifact: ArtifactRecord) -> str: + return self.resolve_path(artifact).read_text(encoding="utf-8") + + def content_type(self, artifact: ArtifactRecord) -> str: + if artifact.relative_path.endswith(".json"): + 
return "application/json" + return "text/plain" + + def _write( + self, job_id: str, artifact_type: str, filename: str, payload: str + ) -> ArtifactRecord: + job_dir = self.artifacts_root / job_id + job_dir.mkdir(parents=True, exist_ok=True) + path = job_dir / filename + path.write_text(payload, encoding="utf-8") + artifact = ArtifactRecord( + id=str(uuid4()), + job_id=job_id, + artifact_type=artifact_type, + relative_path=str(path.relative_to(self.artifacts_root)), + ) + self.job_store.add_artifact(artifact) + return artifact diff --git a/src/lai/cli.py b/src/lai/cli.py index 49fb7d4..5cdb308 100644 --- a/src/lai/cli.py +++ b/src/lai/cli.py @@ -1,14 +1,40 @@ +from __future__ import annotations + from pathlib import Path +from typing import Literal import typer from rich.console import Console +from rich.panel import Panel from rich.table import Table +from .application import create_application +from .domain import ExecutionRequest, QueueMode +from .evals import run_route_eval_suite, save_route_eval_result from .layout import runtime_directories +from .observability import summarize_job_execution from .settings import Settings +from .smoke import latest_smoke_result, run_smoke_suite, save_smoke_result +from .system import collect_system_snapshot +from .worker.service import WorkerServiceConfig, WorkerServiceHost, request_worker_service_stop +from .workstation import build_workstation_readiness -app = typer.Typer(help="Utilities for the LAI repository scaffold.", no_args_is_help=True) -console = Console() +app = typer.Typer(help="Utilities for the LAI orchestration platform.", no_args_is_help=True) +models_app = typer.Typer(help="Inspect model definitions and provider readiness.") +route_app = typer.Typer(help="Explain routing decisions without executing.") +jobs_app = typer.Typer(help="Inspect and manage persistent jobs.") +worker_app = typer.Typer(help="Run the local job worker.") +eval_app = typer.Typer(help="Run evaluation suites against routing behavior.") +smoke_app = typer.Typer(help="Run provider readiness and smoke diagnostics.") +workstation_app = typer.Typer(help="Validate the local workstation for heavy local execution.") +app.add_typer(models_app, name="models") +app.add_typer(route_app, name="route") +app.add_typer(jobs_app, name="jobs") +app.add_typer(worker_app, name="worker") +app.add_typer(eval_app, name="eval") +app.add_typer(smoke_app, name="smoke") +app.add_typer(workstation_app, name="workstation") +console = Console(markup=False) @app.callback() @@ -18,8 +44,19 @@ def main() -> None: @app.command() def doctor() -> None: - """Print the expected repository configuration and runtime layout.""" + """Print repository, config, and system readiness.""" settings = Settings() + snapshot = collect_system_snapshot( + *runtime_directories(settings.root_dir), enable_gpu=settings.enable_gpu + ) + application = create_application(settings=settings) + workstation = build_workstation_readiness( + settings, + application.config, + snapshot=snapshot, + ) + local_profile = workstation.profile("local") + airllm_profile = workstation.profile("airllm") if len(workstation.profiles) > 1 else None table = Table(title="LAI Workspace Doctor") table.add_column("Item") @@ -29,6 +66,23 @@ def doctor() -> None: table.add_row("Root", str(settings.root_dir)) table.add_row("Model catalog", _status_line(settings.resolved_model_catalog)) table.add_row("Routing policy", _status_line(settings.resolved_routing_policy)) + table.add_row("Database", _status_line(settings.resolved_database_path)) + 
table.add_row("GPU available", "yes" if snapshot.has_gpu else "no") + table.add_row("GPU name", snapshot.gpu_name or "n/a") + table.add_row( + "GPU VRAM", + f"{snapshot.gpu_total_memory_gb:.2f} GB" + if snapshot.gpu_total_memory_gb is not None + else "n/a", + ) + table.add_row( + "System RAM", + f"{snapshot.total_ram_gb:.2f} GB" if snapshot.total_ram_gb is not None else "n/a", + ) + table.add_row("Catalog models", str(len(application.config.model_catalog.models))) + table.add_row("Local runtime", _profile_status(local_profile)) + if airllm_profile is not None: + table.add_row("AirLLM heavy path", _profile_status(airllm_profile)) for path in runtime_directories(settings.root_dir): table.add_row("Runtime path", _status_line(path)) @@ -36,10 +90,726 @@ def doctor() -> None: console.print(table) +@models_app.command("list") +def list_models() -> None: + """List catalog models and their main properties.""" + application = create_application() + table = Table(title="LAI Models") + table.add_column("Id") + table.add_column("Role") + table.add_column("Runtime") + table.add_column("Model Ref") + table.add_column("Capabilities") + + for model in application.config.model_catalog.models: + table.add_row( + model.id, + model.role, + model.runtime, + model.model_ref, + ", ".join(model.capabilities), + ) + console.print(table) + + +@models_app.command("check") +def check_models() -> None: + """Check provider availability and local constraints for each model.""" + application = create_application() + healthchecks = application.provider_registry.model_healthchecks(application.config) + + table = Table(title="LAI Model Health") + table.add_column("Id") + table.add_column("Provider") + table.add_column("Available") + table.add_column("Healthy") + table.add_column("Details") + + for model in application.config.model_catalog.models: + health = healthchecks[model.id] + table.add_row( + model.id, + model.provider_id, + "yes" if health.available else "no", + "yes" if health.healthy else "no", + "; ".join(health.reasons) if health.reasons else "ready", + ) + console.print(table) + + +@workstation_app.command("validate") +def validate_workstation( + profile: Literal["all", "local", "airllm"] = typer.Option( + "all", + help="Show all profiles, or focus on the local baseline or AirLLM heavy path.", + ), +) -> None: + """Validate workstation readiness for local and AirLLM execution paths.""" + settings = Settings() + application = create_application(settings=settings) + report = build_workstation_readiness( + settings, + application.config, + snapshot=application.provider_registry.system_snapshot, + ) + + system_table = Table(title="LAI Workstation") + system_table.add_column("Fact") + system_table.add_column("Value") + system_table.add_row("Platform", report.system["platform"]) + system_table.add_row("Python", report.system["python_version"]) + system_table.add_row("GPU", report.system["gpu_name"] or "n/a") + system_table.add_row( + "GPU VRAM", + f"{report.system['gpu_total_memory_gb']:.2f} GB" + if report.system["gpu_total_memory_gb"] is not None + else "n/a", + ) + system_table.add_row( + "System RAM", + f"{report.system['total_ram_gb']:.2f} GB" + if report.system["total_ram_gb"] is not None + else "n/a", + ) + console.print(system_table) + + selected_profiles = [ + candidate + for candidate in report.profiles + if profile == "all" or candidate.profile_id == profile + ] + for candidate in selected_profiles: + _print_workstation_profile(candidate) + + if any(not candidate.ready for candidate in selected_profiles): + 
raise typer.Exit(code=1) + + +@route_app.command("explain") +def explain_route( + prompt: str = typer.Argument(..., help="The user request to classify and route."), + system_prompt: str | None = typer.Option(None, help="Optional system prompt."), + model_override: str | None = typer.Option(None, help="Force a specific model id."), + provider_override: str | None = typer.Option(None, help="Force a specific provider id."), + no_review: bool = typer.Option(False, "--no-review", help="Disable the reviewer stage."), +) -> None: + """Explain how LAI would route a request.""" + application = create_application() + request = ExecutionRequest( + system_prompt=system_prompt, + user_prompt=prompt, + model_override=model_override, + provider_override=provider_override, + reviewer_enabled=False if no_review else None, + ) + decision = application.routing_engine.route(request) + + summary = Table(title="Routing Decision") + summary.add_column("Field") + summary.add_column("Value") + summary.add_row("Tier", decision.matched_tier_id) + summary.add_row("Planner", decision.planner_model_id or "none") + summary.add_row("Executor", decision.executor_model_id) + summary.add_row("Reviewer", decision.reviewer_model_id or "none") + summary.add_row("Queue recommended", "yes" if decision.queue_recommended else "no") + summary.add_row("Fallbacks", ", ".join(decision.fallback_model_ids) or "none") + console.print(summary) + + reasons = Table(title="Routing Reasons") + reasons.add_column("Stage") + reasons.add_column("Reason") + for reason in decision.reasons: + reasons.add_row(reason.stage, reason.message) + console.print(reasons) + + +@app.command() +def run( + prompt: str = typer.Argument(..., help="The user request to execute."), + system_prompt: str | None = typer.Option(None, help="Optional system prompt."), + queue_mode: QueueMode = typer.Option( + QueueMode.AUTO, help="Run inline, queued, or let LAI decide." 
+ ), + model_override: str | None = typer.Option(None, help="Force a specific model id."), + provider_override: str | None = typer.Option(None, help="Force a specific provider id."), + temperature: float | None = typer.Option(None, help="Override generation temperature."), + max_output_tokens: int | None = typer.Option(None, help="Override output token budget."), + timeout_seconds: int | None = typer.Option(None, help="Override request timeout."), + no_review: bool = typer.Option(False, "--no-review", help="Disable the reviewer stage."), +) -> None: + """Submit a request for inline or queued execution.""" + application = create_application() + request = ExecutionRequest( + system_prompt=system_prompt, + user_prompt=prompt, + queue_mode=queue_mode, + model_override=model_override, + provider_override=provider_override, + reviewer_enabled=False if no_review else None, + temperature=temperature + if temperature is not None + else application.settings.default_temperature, + max_output_tokens=( + max_output_tokens + if max_output_tokens is not None + else application.settings.default_max_output_tokens + ), + timeout_seconds=( + timeout_seconds + if timeout_seconds is not None + else application.settings.default_timeout_seconds + ), + ) + job = application.orchestration.submit_request(request) + + if job.queue_mode == QueueMode.QUEUED and job.status == "queued": + console.print(Panel.fit(f"Queued job {job.id}", title="LAI Run")) + return + + console.print( + Panel.fit(job.result.text if job.result else "No output produced.", title=f"Job {job.id}") + ) + + +@jobs_app.command("list") +def list_jobs(limit: int = typer.Option(20, min=1, max=100, help="Maximum jobs to list.")) -> None: + """List persisted jobs.""" + application = create_application() + jobs = application.job_store.list_jobs(limit=limit) + table = Table(title="LAI Jobs") + table.add_column("Id") + table.add_column("Status") + table.add_column("Queue Mode") + table.add_column("Created") + table.add_column("Executor") + + for job in jobs: + executor = job.route_decision.executor_model_id if job.route_decision else "n/a" + table.add_row(job.id, job.status, job.queue_mode, job.created_at.isoformat(), executor) + console.print(table) + + +@jobs_app.command("show") +def show_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: + """Show a job and its latest output.""" + application = create_application() + job = application.job_store.get_job(job_id) + if job is None: + raise typer.Exit(code=1) + + summary = Table(title=f"Job {job.id}") + summary.add_column("Field") + summary.add_column("Value") + summary.add_row("Status", job.status) + summary.add_row("Queue mode", job.queue_mode) + summary.add_row("Attempts", str(job.attempts)) + summary.add_row("Tier", job.route_decision.matched_tier_id if job.route_decision else "n/a") + summary.add_row( + "Executor", job.route_decision.executor_model_id if job.route_decision else "n/a" + ) + summary.add_row("Artifacts", str(len(job.artifacts))) + summary.add_row("Stage events", str(len(job.stage_events))) + console.print(summary) + + if job.result: + console.print(Panel.fit(job.result.text, title="Latest Output")) + if job.error: + console.print(Panel.fit(job.error.message, title="Error")) + if job.stage_events: + timeline = Table(title="Stage Timeline") + timeline.add_column("When") + timeline.add_column("Stage") + timeline.add_column("Event") + timeline.add_column("Model") + timeline.add_column("Message") + for event in job.stage_events: + timeline.add_row( + 
event.created_at.isoformat(timespec="seconds"), + event.stage, + event.event_type, + event.model_id or "n/a", + event.message, + ) + console.print(timeline) + + execution = summarize_job_execution(job) + if execution: + execution_table = Table(title="Execution Trace") + execution_table.add_column("Stage") + execution_table.add_column("Status") + execution_table.add_column("Provider") + execution_table.add_column("Model") + execution_table.add_column("Duration") + execution_table.add_column("Progress") + execution_table.add_column("Output") + for summary in execution: + progress = "n/a" + if summary.latest_progress_phase or summary.latest_progress_percent is not None: + parts = [] + if summary.latest_progress_phase: + parts.append(summary.latest_progress_phase) + if summary.latest_progress_percent is not None: + parts.append(f"{summary.latest_progress_percent * 100:.0f}%") + progress = " / ".join(parts) + execution_table.add_row( + summary.stage, + summary.status, + summary.provider_id or "n/a", + summary.model_id or "n/a", + ( + f"{summary.duration_seconds:.2f}s" + if summary.duration_seconds is not None + else "n/a" + ), + progress, + summary.output_preview or "n/a", + ) + console.print(execution_table) + + +@jobs_app.command("cancel") +def cancel_job(job_id: str = typer.Argument(..., help="Job identifier.")) -> None: + """Cancel a queued or running job.""" + application = create_application() + if application.orchestration.cancel_job(job_id): + console.print(f"Canceled job {job_id}") + return + raise typer.Exit(code=1) + + +@jobs_app.command("replay") +def replay_job( + job_id: str = typer.Argument(..., help="Source job identifier."), + queue_mode: QueueMode | None = typer.Option( + None, help="Override the replay queue mode for the new job." + ), +) -> None: + """Replay a persisted request as a new job.""" + application = create_application() + try: + job = application.orchestration.replay_job(job_id, queue_mode=queue_mode) + except KeyError: + raise typer.Exit(code=1) from None + + if job.queue_mode == QueueMode.QUEUED and job.status == "queued": + console.print(Panel.fit(f"Queued replay job {job.id}", title="LAI Replay")) + return + + console.print( + Panel.fit(job.result.text if job.result else "No output produced.", title=f"Job {job.id}") + ) + + +@jobs_app.command("stale") +def list_stale_jobs( + stale_after_seconds: int | None = typer.Option( + None, + min=1, + help="Override the stale-job timeout in seconds.", + ), + limit: int = typer.Option(20, min=1, max=200, help="Maximum stale jobs to show."), +) -> None: + """List stale running jobs that likely need recovery.""" + application = create_application() + stale_jobs = application.orchestration.list_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + if not stale_jobs: + console.print("No stale running jobs detected.") + return + + table = Table(title="Stale Jobs") + table.add_column("Job") + table.add_column("Stage") + table.add_column("Executor") + table.add_column("Age") + table.add_column("Last Event") + table.add_column("Message") + for job in stale_jobs: + table.add_row( + job.job_id, + job.current_stage or "n/a", + job.executor_model_id or "n/a", + f"{job.age_seconds:.2f}s", + job.last_event_type or "n/a", + job.last_message or "n/a", + ) + console.print(table) + + +@jobs_app.command("recover") +def recover_stale_jobs( + stale_after_seconds: int | None = typer.Option( + None, + min=1, + help="Override the stale-job timeout in seconds.", + ), + limit: int = typer.Option( + 20, + min=1, + max=200, + 
help="Maximum stale jobs to recover in one pass.", + ), +) -> None: + """Recover stale running jobs by returning them to the queue.""" + application = create_application() + recovered = application.orchestration.recover_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + if not recovered: + console.print("No stale jobs were recovered.") + return + + table = Table(title="Recovered Jobs") + table.add_column("Job") + table.add_column("Status") + table.add_column("Queue") + table.add_column("Executor") + for job in recovered: + table.add_row( + job.id, + job.status, + job.queue_mode, + job.route_decision.executor_model_id if job.route_decision else "n/a", + ) + console.print(table) + + +@worker_app.command("run") +def run_worker( + once: bool = typer.Option(False, help="Process at most one queued job and then exit."), + max_jobs: int | None = typer.Option( + None, + min=1, + help="Process at most this many queued jobs and then exit.", + ), + until_idle: bool = typer.Option( + False, + "--until-idle", + help="Keep polling until the queue stays idle for the configured idle cycles.", + ), + max_idle_cycles: int = typer.Option( + 2, + min=1, + help="Number of empty polls before an until-idle run stops.", + ), +) -> None: + """Run the local worker for queued jobs.""" + application = create_application() + if once: + processed = len( + application.orchestration.run_worker_batch( + max_jobs=1, + requeue_running=True, + ) + ) + elif until_idle: + processed = application.orchestration.run_worker_until_idle( + max_idle_cycles=max_idle_cycles, + max_jobs=max_jobs, + requeue_running=True, + ) + elif max_jobs is not None: + processed = len( + application.orchestration.run_worker_batch( + max_jobs=max_jobs, + requeue_running=True, + ) + ) + else: + processed = application.orchestration.run_worker() + console.print(f"Processed {processed} queued job(s).") + + +@worker_app.command("status") +def worker_status() -> None: + """Show the persisted local worker state.""" + application = create_application() + worker = application.orchestration.get_worker_state() + + table = Table(title=f"Worker {worker.worker_id}") + table.add_column("Field") + table.add_column("Value") + table.add_row("Status", worker.status) + table.add_row("Mode", worker.mode) + table.add_row("Current job", worker.current_job_id or "n/a") + table.add_row("Last job", worker.last_job_id or "n/a") + table.add_row("Processed jobs", str(worker.processed_jobs)) + table.add_row("Queued jobs", str(worker.queued_jobs)) + table.add_row( + "Heartbeat", + worker.heartbeat_at.isoformat(timespec="seconds"), + ) + table.add_row( + "Started at", + worker.started_at.isoformat(timespec="seconds") if worker.started_at else "n/a", + ) + table.add_row( + "Service lock", + "present" if application.settings.resolved_worker_service_lock_path.exists() else "missing", + ) + table.add_row( + "Stop signal", + "present" if application.settings.resolved_worker_service_stop_path.exists() else "missing", + ) + table.add_row("Service log", str(application.settings.resolved_worker_service_log_path)) + table.add_row("Last error", worker.last_error or "none") + last_recovery = worker.metadata.get("last_recovery") + if isinstance(last_recovery, dict): + recovered_at = str(last_recovery.get("at", "n/a")) + recovered_count = str(last_recovery.get("count", 0)) + recovered_reason = str(last_recovery.get("reason", "n/a")) + recovered_jobs = ", ".join(str(job_id) for job_id in last_recovery.get("job_ids", [])) + table.add_row( + "Last recovery", + 
f"{recovered_count} job(s) via {recovered_reason} at {recovered_at}", + ) + table.add_row("Recovered jobs", recovered_jobs or "none") + console.print(table) + + +@worker_app.command("serve") +def serve_worker( + poll_interval: float | None = typer.Option( + None, + min=0.1, + help="Seconds to wait between daemon polling cycles.", + ), + max_idle_cycles: int = typer.Option( + 1, + min=1, + help="Number of empty polls before a daemon cycle becomes idle.", + ), + max_jobs_per_cycle: int | None = typer.Option( + None, + min=1, + help="Optional cap for jobs processed in one daemon cycle.", + ), + max_cycles: int | None = typer.Option( + None, + min=1, + help="Optional cap for daemon cycles. Useful for supervised runs.", + ), + worker_id: str = typer.Option( + "local-worker", + help="Persisted worker id for the dedicated service.", + ), + replace_lock: bool = typer.Option( + False, + "--replace-lock", + help="Replace an existing worker-service lock if you are certain it is stale.", + ), +) -> None: + """Run the dedicated long-lived local worker service.""" + settings = Settings() + config = WorkerServiceConfig( + worker_id=worker_id, + poll_interval_seconds=( + poll_interval + if poll_interval is not None + else settings.worker_service_poll_interval_seconds + ), + max_idle_cycles_per_drain=max_idle_cycles, + max_jobs_per_cycle=max_jobs_per_cycle, + max_cycles=max_cycles, + replace_existing_lock=replace_lock, + ) + host = WorkerServiceHost(settings=settings, config=config) + processed = host.run() + console.print(f"Worker service exited after processing {processed} job(s).") + + +@worker_app.command("stop") +def stop_worker_service() -> None: + """Request a graceful stop for the dedicated worker service.""" + path = request_worker_service_stop(Settings()) + console.print(f"Requested worker service stop via {path}") + + +@eval_app.command("route") +def eval_route( + scenario_file: Path = typer.Option( + Path("evals/scenarios/routing-basics.yaml"), + exists=True, + file_okay=True, + dir_okay=False, + help="YAML file describing route evaluation scenarios.", + ), + save: bool = typer.Option( + True, + "--save/--no-save", + help="Save the evaluation result under evals/results.", + ), +) -> None: + """Run the route evaluation suite and report pass/fail.""" + settings = Settings() + result = run_route_eval_suite(settings, settings.root_dir / scenario_file) + + table = Table(title="LAI Route Eval") + table.add_column("Scenario") + table.add_column("Status") + table.add_column("Expected Tier") + table.add_column("Actual Tier") + table.add_column("Executor") + + for case in result.cases: + table.add_row( + case.scenario_id, + "PASS" if case.passed else "FAIL", + case.expected_tier, + case.actual_tier, + case.actual_executor, + ) + console.print(table) + + if save: + result_path = save_route_eval_result(settings.root_dir / "evals/results", result) + console.print(f"Saved evaluation result to {result_path}") + + if not result.passed: + raise typer.Exit(code=1) + + +@smoke_app.command("providers") +def smoke_providers( + live: bool = typer.Option( + False, + "--live", + help="Execute tiny live prompts instead of readiness-only checks.", + ), + save: bool = typer.Option( + True, + "--save/--no-save", + help="Save the smoke result under evals/results/smoke.", + ), +) -> None: + """Run readiness or live smoke checks for remote providers.""" + _run_smoke_command( + suite_id="providers", + live=live, + save=save, + ) + + +@smoke_app.command("local") +def smoke_local( + live: bool = typer.Option( + False, + 
"--live", + help="Execute tiny live prompts instead of readiness-only checks.", + ), + include_airllm: bool = typer.Option( + False, + "--include-airllm", + help="Include the heavier AirLLM smoke target alongside transformers.", + ), + save: bool = typer.Option( + True, + "--save/--no-save", + help="Save the smoke result under evals/results/smoke.", + ), +) -> None: + """Run readiness or live smoke checks for local providers.""" + _run_smoke_command( + suite_id="local", + live=live, + include_airllm=include_airllm, + save=save, + ) + + +@smoke_app.command("latest") +def smoke_latest() -> None: + """Show the latest saved smoke result path.""" + settings = Settings() + path = latest_smoke_result(settings.root_dir / "evals/results/smoke") + if path is None: + raise typer.Exit(code=1) + console.print(path) + + def _status_line(path: Path) -> str: suffix = "present" if path.exists() else "missing" return f"{path} ({suffix})" +def _profile_status(profile) -> str: + if profile.ready and profile.warning_count == 0: + return "ready" + if profile.ready: + return f"ready with {profile.warning_count} warning(s)" + return f"blocked ({profile.failure_count} fail / {profile.warning_count} warn)" + + +def _print_workstation_profile(profile) -> None: + summary = Table(title=profile.title) + summary.add_column("Field") + summary.add_column("Value") + summary.add_row("Profile", profile.profile_id) + summary.add_row("Ready", "yes" if profile.ready else "no") + summary.add_row("Warnings", str(profile.warning_count)) + summary.add_row("Failures", str(profile.failure_count)) + summary.add_row("Summary", profile.summary) + console.print(summary) + + checks = Table(title=f"{profile.title} Checks") + checks.add_column("Check") + checks.add_column("Status") + checks.add_column("Summary") + checks.add_column("Remediation") + for check in profile.checks: + checks.add_row( + check.title, + check.status, + check.summary, + check.remediation or "n/a", + ) + console.print(checks) + + +def _run_smoke_command( + *, + suite_id: Literal["providers", "local"], + live: bool, + save: bool, + include_airllm: bool = False, +) -> None: + settings = Settings() + result = run_smoke_suite( + settings, + suite_id=suite_id, + live=live, + include_airllm=include_airllm, + ) + + table = Table(title=f"LAI Smoke {suite_id}") + table.add_column("Provider") + table.add_column("Status") + table.add_column("Executed") + table.add_column("Model") + table.add_column("Duration") + table.add_column("Message") + + for check in result.checks: + table.add_row( + check.provider_id, + check.status, + "yes" if check.executed else "no", + check.model_ref, + f"{check.duration_seconds:.2f}s" if check.duration_seconds is not None else "n/a", + check.message, + ) + console.print(table) + + if save: + path = save_smoke_result(settings.root_dir / "evals/results/smoke", result) + console.print(f"Saved smoke result to {path}") + + if not result.passed: + raise typer.Exit(code=1) + + if __name__ == "__main__": app() diff --git a/src/lai/config.py b/src/lai/config.py new file mode 100644 index 0000000..5b7b552 --- /dev/null +++ b/src/lai/config.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel + +from .domain import ModelCatalog, RoutingPolicy + + +class AppConfig(BaseModel): + model_catalog: ModelCatalog + routing_policy: RoutingPolicy + + +def _load_yaml_file(path: Path) -> dict[str, Any]: + if not path.exists(): + raise FileNotFoundError(f"Configuration file 
not found: {path}") + + with path.open("r", encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + if not isinstance(data, dict): + raise ValueError(f"Configuration file must contain a mapping at the top level: {path}") + return data + + +def load_model_catalog(path: Path) -> ModelCatalog: + return ModelCatalog.model_validate(_load_yaml_file(path)) + + +def load_routing_policy(path: Path) -> RoutingPolicy: + return RoutingPolicy.model_validate(_load_yaml_file(path)) + + +def load_app_config(model_catalog_path: Path, routing_policy_path: Path) -> AppConfig: + config = AppConfig( + model_catalog=load_model_catalog(model_catalog_path), + routing_policy=load_routing_policy(routing_policy_path), + ) + validate_config_references(config) + return config + + +def validate_config_references(config: AppConfig) -> None: + model_ids = {model.id for model in config.model_catalog.models} + + if config.routing_policy.router.model_id not in model_ids: + router_model_id = config.routing_policy.router.model_id + raise ValueError( + f"Router model id {router_model_id!r} is not present in the catalog." + ) + + for tier in config.routing_policy.tiers: + _assert_model_reference(model_ids, tier.executor_model_id, f"tier {tier.id} executor") + if tier.planner_model_id: + _assert_model_reference(model_ids, tier.planner_model_id, f"tier {tier.id} planner") + if tier.reviewer_model_id: + _assert_model_reference(model_ids, tier.reviewer_model_id, f"tier {tier.id} reviewer") + + fallback = config.routing_policy.fallbacks.when_gpu_unavailable + if fallback: + _assert_model_reference(model_ids, fallback.executor_model_id, "GPU fallback executor") + if fallback.planner_model_id: + _assert_model_reference(model_ids, fallback.planner_model_id, "GPU fallback planner") + if fallback.reviewer_model_id: + _assert_model_reference(model_ids, fallback.reviewer_model_id, "GPU fallback reviewer") + + +def _assert_model_reference(model_ids: set[str], model_id: str, context: str) -> None: + if model_id not in model_ids: + raise ValueError(f"{context} references unknown model id {model_id!r}.") diff --git a/src/lai/domain.py b/src/lai/domain.py new file mode 100644 index 0000000..8166aec --- /dev/null +++ b/src/lai/domain.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + + +def utcnow() -> datetime: + return datetime.now(tz=timezone.utc) + + +class LAIModel(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + use_enum_values=True, + extra="forbid", + ) + + +class ModelRole(str, Enum): + ROUTER = "router" + REVIEWER = "reviewer" + EXECUTOR = "executor" + PLANNER = "planner" + + +class ModelRuntime(str, Enum): + TRANSFORMERS = "transformers" + AIRLLM = "airllm" + OPENAI = "openai" + ANTHROPIC = "anthropic" + GEMINI = "gemini" + + +class QueueMode(str, Enum): + AUTO = "auto" + INLINE = "inline" + QUEUED = "queued" + + +class JobStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELED = "canceled" + + +class WorkerStatus(str, Enum): + IDLE = "idle" + RUNNING = "running" + STOPPED = "stopped" + ERROR = "error" + + +class FinishReason(str, Enum): + STOP = "stop" + LENGTH = "length" + ERROR = "error" + CANCELED = "canceled" + UNKNOWN = "unknown" + + +class ModelHardwareSpec(LAIModel): + preferred_device: str = "cpu" + minimum_vram_gb: int | None = None + 
recommended_vram_gb: int | None = None + recommended_ram_gb: int | None = None + expected_disk_gb: int | None = None + + +class ModelRuntimeHints(LAIModel): + allow_layer_sharding: bool = False + allow_prefetching: bool = False + allow_cpu_fallback: bool = False + + +class ModelSpec(LAIModel): + id: str + role: ModelRole + runtime: ModelRuntime + model_ref: str = Field(validation_alias=AliasChoices("model_ref", "repo_id", "model")) + context_window: int = 8192 + capabilities: list[str] = Field(default_factory=list) + hardware: ModelHardwareSpec = Field(default_factory=ModelHardwareSpec) + runtime_hints: ModelRuntimeHints = Field(default_factory=ModelRuntimeHints) + enabled: bool = True + description: str | None = None + + @property + def provider_id(self) -> str: + if isinstance(self.runtime, Enum): + return self.runtime.value + return str(self.runtime) + + def supports_capabilities(self, required_capabilities: list[str]) -> bool: + return set(required_capabilities).issubset(set(self.capabilities)) + + +class CatalogDefaults(LAIModel): + allow_gated_models: bool = True + require_disk_headroom_gb: int = 50 + require_healthcheck_before_enable: bool = True + + +class ModelCatalog(LAIModel): + version: int + defaults: CatalogDefaults = Field(default_factory=CatalogDefaults) + models: list[ModelSpec] + + @model_validator(mode="after") + def validate_unique_model_ids(self) -> "ModelCatalog": + model_ids = [model.id for model in self.models] + duplicates = {model_id for model_id in model_ids if model_ids.count(model_id) > 1} + if duplicates: + raise ValueError(f"Duplicate model ids in catalog: {sorted(duplicates)}") + return self + + def get_model(self, model_id: str) -> ModelSpec: + for model in self.models: + if model.id == model_id: + return model + raise KeyError(f"Unknown model id: {model_id}") + + def enabled_models(self) -> list[ModelSpec]: + return [model for model in self.models if model.enabled] + + +class RouterConfig(LAIModel): + model_id: str + max_input_tokens: int = 2048 + classify_dimensions: list[str] = Field(default_factory=list) + use_model_router: bool = True + + +class RoutingTierMatch(LAIModel): + complexity: str | None = None + urgency: str | None = None + cost_sensitivity: str | None = None + safety_risk: str | None = None + expected_duration: str | None = None + + def score(self, context: "RouteContext") -> int: + score = 0 + if self.complexity and self.complexity == context.complexity: + score += 1 + if self.urgency and self.urgency == context.urgency: + score += 1 + if self.cost_sensitivity and self.cost_sensitivity == context.cost_sensitivity: + score += 1 + if self.safety_risk and self.safety_risk == context.safety_risk: + score += 1 + if self.expected_duration and self.expected_duration == context.expected_duration: + score += 1 + return score + + +class StageSelection(LAIModel): + planner_model_id: str | None = None + executor_model_id: str + reviewer_model_id: str | None = None + + +class RoutingTier(LAIModel): + id: str + description: str + match: RoutingTierMatch + planner_model_id: str | None = None + executor_model_id: str + reviewer_model_id: str | None = None + allow_overnight: bool = False + reviewer_enabled: bool | None = None + + @property + def resolved_reviewer_enabled(self) -> bool: + if self.reviewer_enabled is not None: + return self.reviewer_enabled + return self.id != "instant" + + @property + def should_plan(self) -> bool: + return self.id in {"standard", "deep-work"} + + +class RoutingFallbacks(LAIModel): + when_gpu_unavailable: StageSelection | 
None = None + + +class RoutingPolicy(LAIModel): + version: int + router: RouterConfig + tiers: list[RoutingTier] + fallbacks: RoutingFallbacks = Field(default_factory=RoutingFallbacks) + + @model_validator(mode="after") + def validate_unique_tier_ids(self) -> "RoutingPolicy": + tier_ids = [tier.id for tier in self.tiers] + duplicates = {tier_id for tier_id in tier_ids if tier_ids.count(tier_id) > 1} + if duplicates: + raise ValueError(f"Duplicate routing tier ids: {sorted(duplicates)}") + return self + + def get_tier(self, tier_id: str) -> RoutingTier: + for tier in self.tiers: + if tier.id == tier_id: + return tier + raise KeyError(f"Unknown routing tier id: {tier_id}") + + +class RouteContext(LAIModel): + prompt_length: int + complexity: str + urgency: str + cost_sensitivity: str + safety_risk: str + expected_duration: str + requires_high_quality: bool = False + overnight_requested: bool = False + required_capabilities: list[str] = Field(default_factory=list) + model_override: str | None = None + provider_override: str | None = None + + +class RouteReason(LAIModel): + stage: str + message: str + + +class RoutingDecision(LAIModel): + matched_tier_id: str + planner_model_id: str | None = None + executor_model_id: str + reviewer_model_id: str | None = None + fallback_model_ids: list[str] = Field(default_factory=list) + reasons: list[RouteReason] = Field(default_factory=list) + queue_recommended: bool = False + should_plan: bool = False + should_review: bool = False + context: RouteContext + + +class ExecutionRequest(LAIModel): + system_prompt: str | None = None + user_prompt: str + metadata: dict[str, Any] = Field(default_factory=dict) + temperature: float = 0.2 + max_output_tokens: int = 1024 + timeout_seconds: int = 120 + model_override: str | None = None + provider_override: str | None = None + queue_mode: QueueMode = QueueMode.AUTO + allow_overnight: bool = True + reviewer_enabled: bool | None = None + required_capabilities: list[str] = Field(default_factory=list) + source: str = "cli" + + +class UsageStats(LAIModel): + prompt_tokens: int | None = None + completion_tokens: int | None = None + total_tokens: int | None = None + + +class ProviderRequest(LAIModel): + system_prompt: str | None = None + user_prompt: str + metadata: dict[str, Any] = Field(default_factory=dict) + temperature: float = 0.2 + max_output_tokens: int = 1024 + timeout_seconds: int = 120 + model_override: str | None = None + + +class ProviderProgressUpdate(LAIModel): + phase: str + message: str + progress: float | None = Field(default=None, ge=0, le=1) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ProviderHealth(LAIModel): + provider_id: str + available: bool + healthy: bool + reasons: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ExecutionResult(LAIModel): + text: str + finish_reason: FinishReason = FinishReason.UNKNOWN + provider_id: str + model_id: str + duration_seconds: float + usage: UsageStats | None = None + raw: dict[str, Any] = Field(default_factory=dict) + stage: str = "executor" + + +class JobError(LAIModel): + error_type: str + message: str + retryable: bool = False + + +class ArtifactRecord(LAIModel): + id: str + job_id: str + artifact_type: str + relative_path: str + created_at: datetime = Field(default_factory=utcnow) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class StageEventRecord(LAIModel): + id: str + job_id: str + stage: str + event_type: str + message: str + created_at: datetime = 
Field(default_factory=utcnow) + model_id: str | None = None + provider_id: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class StageExecutionSummary(LAIModel): + stage: str + status: str + model_id: str | None = None + provider_id: str | None = None + started_at: datetime | None = None + finished_at: datetime | None = None + duration_seconds: float | None = None + finish_reason: str | None = None + prompt_characters: int | None = None + system_prompt_characters: int | None = None + output_characters: int | None = None + output_preview: str | None = None + temperature: float | None = None + max_output_tokens: int | None = None + timeout_seconds: int | None = None + usage: UsageStats | None = None + artifact_ids: list[str] = Field(default_factory=list) + artifact_types: list[str] = Field(default_factory=list) + health_reasons: list[str] = Field(default_factory=list) + error_type: str | None = None + last_message: str | None = None + progress_event_count: int = 0 + latest_progress_message: str | None = None + latest_progress_phase: str | None = None + latest_progress_at: datetime | None = None + latest_progress_percent: float | None = None + + +class StaleJobSummary(LAIModel): + job_id: str + queue_mode: QueueMode + attempts: int + executor_model_id: str | None = None + current_stage: str | None = None + last_event_type: str | None = None + last_message: str | None = None + last_activity_at: datetime + stale_after_seconds: int + age_seconds: float + + +class WorkerStateRecord(LAIModel): + worker_id: str + status: WorkerStatus = WorkerStatus.IDLE + mode: str = "batch" + current_job_id: str | None = None + last_job_id: str | None = None + processed_jobs: int = 0 + queued_jobs: int = 0 + started_at: datetime | None = None + heartbeat_at: datetime = Field(default_factory=utcnow) + updated_at: datetime = Field(default_factory=utcnow) + last_error: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class JobRecord(LAIModel): + id: str + status: JobStatus + created_at: datetime = Field(default_factory=utcnow) + updated_at: datetime = Field(default_factory=utcnow) + queued_at: datetime | None = None + started_at: datetime | None = None + finished_at: datetime | None = None + queue_mode: QueueMode = QueueMode.AUTO + request: ExecutionRequest + route_decision: RoutingDecision | None = None + result: ExecutionResult | None = None + error: JobError | None = None + attempts: int = 0 + artifacts: list[ArtifactRecord] = Field(default_factory=list) + stage_events: list[StageEventRecord] = Field(default_factory=list) diff --git a/src/lai/errors.py b/src/lai/errors.py new file mode 100644 index 0000000..5a74e37 --- /dev/null +++ b/src/lai/errors.py @@ -0,0 +1,10 @@ +class LAIError(Exception): + """Base exception for LAI.""" + + +class RetryableProviderError(LAIError): + """Raised when a provider failure can be retried safely.""" + + +class ConfigurationError(LAIError): + """Raised when LAI configuration is invalid.""" diff --git a/src/lai/evals.py b/src/lai/evals.py new file mode 100644 index 0000000..9c64ca1 --- /dev/null +++ b/src/lai/evals.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel, Field + +from .application import create_application +from .domain import ( + ExecutionRequest, + ExecutionResult, + FinishReason, + ModelRuntime, + ModelSpec, + ProviderHealth, + ProviderRequest, +) +from 
.providers.base import ProgressCallback, Provider +from .settings import Settings +from .system import collect_system_snapshot + + +def utcnow_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat() + + +class RouteEvalScenario(BaseModel): + id: str + description: str + prompt: str + expected_tier: str + expected_executor_in: list[str] = Field(default_factory=list) + unavailable_providers: list[str] = Field(default_factory=list) + + +class RouteEvalSuite(BaseModel): + version: int + scenarios: list[RouteEvalScenario] + + +class RouteEvalCaseResult(BaseModel): + scenario_id: str + passed: bool + expected_tier: str + actual_tier: str + expected_executor_in: list[str] = Field(default_factory=list) + actual_executor: str + reasons: list[str] = Field(default_factory=list) + + +class RouteEvalResult(BaseModel): + executed_at: str + scenario_file: str + passed: bool + total: int + passed_count: int + failed_count: int + cases: list[RouteEvalCaseResult] + + +def load_route_eval_suite(path: Path) -> RouteEvalSuite: + with path.open("r", encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + return RouteEvalSuite.model_validate(data) + + +def run_route_eval_suite(settings: Settings, scenario_file: Path) -> RouteEvalResult: + suite = load_route_eval_suite(scenario_file) + cases: list[RouteEvalCaseResult] = [] + + for scenario in suite.scenarios: + application = create_application( + settings=settings, + providers=_eval_providers(settings, scenario.unavailable_providers), + ) + decision = application.routing_engine.route( + ExecutionRequest( + user_prompt=scenario.prompt, + ) + ) + passed = decision.matched_tier_id == scenario.expected_tier + if scenario.expected_executor_in: + passed = passed and decision.executor_model_id in scenario.expected_executor_in + + cases.append( + RouteEvalCaseResult( + scenario_id=scenario.id, + passed=passed, + expected_tier=scenario.expected_tier, + actual_tier=decision.matched_tier_id, + expected_executor_in=scenario.expected_executor_in, + actual_executor=decision.executor_model_id, + reasons=[f"{reason.stage}: {reason.message}" for reason in decision.reasons], + ) + ) + + passed_count = sum(1 for case in cases if case.passed) + return RouteEvalResult( + executed_at=utcnow_iso(), + scenario_file=str(scenario_file), + passed=passed_count == len(cases), + total=len(cases), + passed_count=passed_count, + failed_count=len(cases) - passed_count, + cases=cases, + ) + + +def save_route_eval_result(results_dir: Path, result: RouteEvalResult) -> Path: + results_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = results_dir / f"route-eval-{timestamp}.json" + path.write_text(result.model_dump_json(indent=2), encoding="utf-8") + return path + + +class EvalProvider(Provider): + def __init__( + self, + settings: Settings, + provider_id: str, + *, + unavailable_providers: list[str] | None = None, + ) -> None: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=settings.enable_gpu) + super().__init__(settings, snapshot) + self.provider_id = provider_id + self.unavailable_providers = set(unavailable_providers or []) + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + if self.provider_id in self.unavailable_providers: + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=[f"{self.provider_id} marked unavailable for this evaluation scenario"], + ) + return ProviderHealth(provider_id=self.provider_id, available=True, 
healthy=True) + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ) -> ExecutionResult: + self._emit_progress( + progress_callback, + phase="evaluate", + message=f"Running evaluation provider for {model.id}.", + progress=0.5, + ) + if model.role == "router": + text = '{"tier_id": "standard"}' + else: + text = f"eval:{model.id}" + return ExecutionResult( + text=text, + finish_reason=FinishReason.STOP, + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=0.0, + ) + + def describe_capabilities(self) -> dict[str, Any]: + return {"evaluation_provider": True} + + +def _eval_providers(settings: Settings, unavailable_providers: list[str]) -> dict[str, Provider]: + providers: dict[str, Provider] = {} + for runtime in ModelRuntime: + providers[runtime.value] = EvalProvider( + settings, + runtime.value, + unavailable_providers=unavailable_providers, + ) + return providers diff --git a/src/lai/jobs/__init__.py b/src/lai/jobs/__init__.py index eeaf73a..96b0b6f 100644 --- a/src/lai/jobs/__init__.py +++ b/src/lai/jobs/__init__.py @@ -1 +1 @@ -"""Job orchestration primitives live here.""" +"""Job persistence and orchestration modules.""" diff --git a/src/lai/jobs/service.py b/src/lai/jobs/service.py new file mode 100644 index 0000000..19ca935 --- /dev/null +++ b/src/lai/jobs/service.py @@ -0,0 +1,1021 @@ +from __future__ import annotations + +import threading +import time +from enum import Enum +from uuid import uuid4 + +from ..artifacts import ArtifactManager +from ..config import AppConfig +from ..domain import ( + ExecutionRequest, + JobError, + JobRecord, + JobStatus, + ProviderProgressUpdate, + ProviderRequest, + QueueMode, + RoutingDecision, + StageEventRecord, + StaleJobSummary, + WorkerStateRecord, + WorkerStatus, + utcnow, +) +from ..errors import RetryableProviderError +from ..observability import summarize_stale_jobs +from ..providers import ProviderRegistry +from ..routing import RoutingEngine +from ..settings import Settings +from .store import JobStore + + +class OrchestrationService: + default_worker_id = "local-worker" + + def __init__( + self, + settings: Settings, + config: AppConfig, + provider_registry: ProviderRegistry, + routing_engine: RoutingEngine, + job_store: JobStore, + artifacts: ArtifactManager, + ) -> None: + self.settings = settings + self.config = config + self.provider_registry = provider_registry + self.routing_engine = routing_engine + self.job_store = job_store + self.artifacts = artifacts + + def submit_request(self, request: ExecutionRequest) -> JobRecord: + route_decision = self.routing_engine.route(request) + now = utcnow() + queue_mode = self._resolved_queue_mode(request.queue_mode, route_decision) + job = JobRecord( + id=str(uuid4()), + status=JobStatus.QUEUED, + created_at=now, + updated_at=now, + queued_at=now, + queue_mode=queue_mode, + request=request, + route_decision=route_decision, + ) + self.job_store.save_job(job) + self.artifacts.write_json( + job.id, "normalized-request", "request.json", request.model_dump() + ) + self.artifacts.write_json( + job.id, "route-decision", "route.json", route_decision.model_dump() + ) + self._record_event( + job_id=job.id, + stage="routing", + event_type="decision", + message=( + f"Matched tier {route_decision.matched_tier_id} with executor " + f"{route_decision.executor_model_id}." 
+ ), + metadata={ + "planner_model_id": route_decision.planner_model_id, + "reviewer_model_id": route_decision.reviewer_model_id, + "queue_recommended": route_decision.queue_recommended, + "fallback_model_ids": route_decision.fallback_model_ids, + }, + ) + self._record_event( + job_id=job.id, + stage="job", + event_type="queued" if queue_mode == QueueMode.QUEUED else "inline-dispatch", + message=( + "Queued for worker execution." + if queue_mode == QueueMode.QUEUED + else "Scheduled for immediate inline execution." + ), + metadata={"queue_mode": queue_mode}, + ) + + if queue_mode == QueueMode.INLINE: + return self.execute_job(job.id) + return self.job_store.get_job(job.id) or job + + def replay_job(self, job_id: str, *, queue_mode: QueueMode | None = None) -> JobRecord: + job = self.job_store.get_job(job_id) + if job is None: + raise KeyError(f"Unknown job id {job_id!r}.") + + metadata = dict(job.request.metadata) + metadata["replayed_from_job_id"] = job.id + replay_request = job.request.model_copy( + update={ + "metadata": metadata, + "queue_mode": queue_mode or job.request.queue_mode, + "source": "replay", + } + ) + replayed_job = self.submit_request(replay_request) + self._record_event( + job_id=replayed_job.id, + stage="job", + event_type="replayed", + message=f"Replayed from job {job.id}.", + metadata={"source_job_id": job.id}, + ) + return self.job_store.get_job(replayed_job.id) or replayed_job + + def get_worker_state(self, worker_id: str | None = None) -> WorkerStateRecord: + resolved_worker_id = worker_id or self.default_worker_id + worker_state = self.job_store.get_worker_state(resolved_worker_id) + if worker_state is not None: + return worker_state + worker_state = WorkerStateRecord( + worker_id=resolved_worker_id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode="batch", + ) + self.job_store.save_worker_state(worker_state) + return worker_state + + def update_worker_state( + self, + worker_id: str | None = None, + **updates: object, + ) -> WorkerStateRecord: + return self._update_worker_state(worker_id=worker_id or self.default_worker_id, **updates) + + def cancel_job(self, job_id: str) -> JobRecord | None: + if not self.job_store.cancel_job(job_id): + return None + self._record_event( + job_id=job_id, + stage="job", + event_type="canceled", + message="Job canceled by user action.", + ) + return self.job_store.get_job(job_id) + + def execute_job(self, job_id: str, *, worker_id: str | None = None) -> JobRecord: + job = self.job_store.get_job(job_id) + if job is None: + raise KeyError(f"Unknown job id {job_id!r}.") + if job.status == JobStatus.CANCELED: + return job + + if job.status != JobStatus.RUNNING: + job.attempts += 1 + job.status = JobStatus.RUNNING + job.started_at = job.started_at or utcnow() + job.updated_at = utcnow() + self.job_store.save_job(job) + self._record_event( + job_id=job.id, + stage="job", + event_type="running", + message="Job execution started.", + metadata={"attempt": job.attempts}, + ) + + try: + final_text = self._run_pipeline(job, worker_id=worker_id) + if job.result: + job.result = job.result.model_copy(update={"text": final_text}) + job.status = JobStatus.SUCCEEDED + job.finished_at = utcnow() + job.updated_at = utcnow() + self.job_store.save_job(job) + self.artifacts.write_text(job.id, "final-output", "final_output.txt", final_text) + self._record_event( + job_id=job.id, + stage="job", + event_type="succeeded", + message="Job completed successfully.", + metadata={"artifacts": len(job.artifacts) + 1}, + ) + return 
self.job_store.get_job(job.id) or job + except Exception as exc: + retryable = _is_retryable_exception(exc) + job.error = JobError( + error_type=type(exc).__name__, message=str(exc), retryable=retryable + ) + job.updated_at = utcnow() + if retryable and job.attempts <= self.settings.max_retry_attempts: + job.status = JobStatus.QUEUED + self.job_store.save_job(job) + self._record_event( + job_id=job.id, + stage="job", + event_type="retry-queued", + message="Retryable failure encountered. Job returned to queue.", + metadata={"error_type": type(exc).__name__}, + ) + else: + job.status = JobStatus.FAILED + job.finished_at = utcnow() + self.job_store.save_job(job) + self._record_event( + job_id=job.id, + stage="job", + event_type="failed", + message=f"Job failed with {type(exc).__name__}.", + metadata={"error_type": type(exc).__name__}, + ) + self.artifacts.write_text(job.id, "error", "error.txt", str(exc)) + return self.job_store.get_job(job.id) or job + + def list_stale_jobs( + self, + *, + stale_after_seconds: int | None = None, + limit: int = 100, + ) -> list[StaleJobSummary]: + threshold = stale_after_seconds or self.settings.stale_running_job_timeout_seconds + running_jobs = self.job_store.list_jobs_by_status(JobStatus.RUNNING, limit=limit) + return summarize_stale_jobs( + running_jobs, + stale_after_seconds=threshold, + ) + + def recover_stale_jobs( + self, + *, + stale_after_seconds: int | None = None, + limit: int = 100, + ) -> list[JobRecord]: + stale_jobs = self.list_stale_jobs( + stale_after_seconds=stale_after_seconds, + limit=limit, + ) + recovered: list[JobRecord] = [] + for stale_job in stale_jobs: + refreshed = self._recover_job( + job_id=stale_job.job_id, + message="Recovered a stale running job and returned it to the queue.", + metadata={ + "recovery_reason": "stale-running-job", + "age_seconds": stale_job.age_seconds, + "stale_after_seconds": stale_job.stale_after_seconds, + "last_event_type": stale_job.last_event_type, + "current_stage": stale_job.current_stage, + }, + ) + if refreshed is not None: + recovered.append(refreshed) + return recovered + + def recover_running_jobs( + self, + *, + limit: int = 100, + message: str = "Recovered an interrupted running job and returned it to the queue.", + metadata: dict[str, object] | None = None, + ) -> list[JobRecord]: + running_jobs = self.job_store.list_jobs_by_status(JobStatus.RUNNING, limit=limit) + recovered: list[JobRecord] = [] + for job in running_jobs: + refreshed = self._recover_job( + job_id=job.id, + message=message, + metadata={ + "recovery_reason": "running-job-recovery", + "attempts": job.attempts, + **(metadata or {}), + }, + ) + if refreshed is not None: + recovered.append(refreshed) + return recovered + + def run_worker_batch( + self, + *, + max_jobs: int = 1, + requeue_running: bool = False, + worker_id: str | None = None, + mode: str = "batch", + ) -> list[JobRecord]: + if max_jobs < 1: + raise ValueError("max_jobs must be at least 1.") + + resolved_worker_id = worker_id or self.default_worker_id + if requeue_running: + requeued = self.job_store.requeue_running_jobs() + else: + requeued = 0 + + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + mode=mode, + current_job_id=None, + processed_jobs=0, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=utcnow(), + last_error=None, + metadata={"max_jobs": max_jobs, "requeued_running_jobs": requeued}, + ) + + processed: list[JobRecord] = [] + try: + for _ in range(max_jobs): + job = 
self.job_store.claim_next_queued_job() + if job is None: + break + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job.id, + mode=mode, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + ) + processed_job = self.execute_job(job.id, worker_id=resolved_worker_id) + processed.append(processed_job) + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + last_job_id=processed_job.id, + processed_jobs=len(processed), + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + ) + except Exception as exc: + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.ERROR, + current_job_id=None, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + last_error=f"{type(exc).__name__}: {exc}", + mode=mode, + ) + raise + + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=len(processed), + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"max_jobs": max_jobs, "phase": "complete"}, + ) + return processed + + def run_worker_until_idle( + self, + *, + max_idle_cycles: int = 2, + max_jobs: int | None = None, + requeue_running: bool = True, + worker_id: str | None = None, + mode: str = "until-idle", + ) -> int: + if max_idle_cycles < 1: + raise ValueError("max_idle_cycles must be at least 1.") + if max_jobs is not None and max_jobs < 1: + raise ValueError("max_jobs must be at least 1 when provided.") + + resolved_worker_id = worker_id or self.default_worker_id + if requeue_running: + requeued = self.job_store.requeue_running_jobs() + else: + requeued = 0 + + processed = 0 + idle_cycles = 0 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + mode=mode, + current_job_id=None, + processed_jobs=0, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=utcnow(), + last_error=None, + metadata={ + "max_idle_cycles": max_idle_cycles, + "max_jobs": max_jobs, + "requeued_running_jobs": requeued, + "phase": "draining", + }, + ) + + try: + while True: + if max_jobs is not None and processed >= max_jobs: + break + + job = self.job_store.claim_next_queued_job() + if job is None: + idle_cycles += 1 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + queued_jobs=0, + processed_jobs=processed, + mode=mode, + metadata={ + "max_idle_cycles": max_idle_cycles, + "idle_cycles": idle_cycles, + "phase": "waiting", + }, + ) + if idle_cycles >= max_idle_cycles: + break + time.sleep(self.settings.worker_idle_sleep_seconds) + continue + + idle_cycles = 0 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "processing"}, + ) + processed_job = self.execute_job(job.id, worker_id=resolved_worker_id) + processed += 1 + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + last_job_id=processed_job.id, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "draining"}, + ) + except Exception as exc: + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.ERROR, + 
current_job_id=None, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + processed_jobs=processed, + last_error=f"{type(exc).__name__}: {exc}", + mode=mode, + ) + raise + + self._update_worker_state( + worker_id=resolved_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "idle", "idle_cycles": idle_cycles}, + ) + return processed + + def run_worker( + self, + once: bool = False, + max_jobs: int | None = None, + mode: str = "continuous", + ) -> int: + if max_jobs is not None and max_jobs < 1: + raise ValueError("max_jobs must be at least 1 when provided.") + + processed = 0 + requeued = self.job_store.requeue_running_jobs() + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + mode=mode, + current_job_id=None, + processed_jobs=0, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=utcnow(), + last_error=None, + metadata={"requeued_running_jobs": requeued, "phase": "monitoring"}, + ) + while True: + if max_jobs is not None and processed >= max_jobs: + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "idle"}, + ) + return processed + job = self.job_store.claim_next_queued_job() + if job is None: + if once: + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + queued_jobs=0, + mode=mode, + metadata={"phase": "idle"}, + ) + return processed + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + processed_jobs=processed, + queued_jobs=0, + mode=mode, + metadata={"phase": "waiting"}, + ) + time.sleep(self.settings.worker_idle_sleep_seconds) + continue + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job.id, + processed_jobs=processed, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "processing"}, + ) + self.execute_job(job.id, worker_id=self.default_worker_id) + processed += 1 + if once: + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.IDLE, + current_job_id=None, + processed_jobs=processed, + last_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "idle"}, + ) + return processed + self._update_worker_state( + worker_id=self.default_worker_id, + status=WorkerStatus.RUNNING, + current_job_id=None, + processed_jobs=processed, + last_job_id=job.id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=mode, + metadata={"phase": "monitoring"}, + ) + + def _run_pipeline(self, job: JobRecord, *, worker_id: str | None = None) -> str: + assert job.route_decision is not None + + planning_output = "" + if job.route_decision.should_plan and job.route_decision.planner_model_id: + planning_output = self._run_stage( + job=job, + stage="planner", + model_id=job.route_decision.planner_model_id, + system_prompt=job.request.system_prompt + or ( + "Create a brief internal execution plan. " + "Keep it concise and implementation-focused." 
+ ), + user_prompt=job.request.user_prompt, + worker_id=worker_id, + ) + + executor_prompt = job.request.user_prompt + if planning_output: + executor_prompt = ( + "Internal execution plan:\n" + f"{planning_output}\n\n" + "Follow that plan while answering the original request.\n\n" + f"{job.request.user_prompt}" + ) + + executor_output = self._run_stage( + job=job, + stage="executor", + model_id=job.route_decision.executor_model_id, + system_prompt=job.request.system_prompt, + user_prompt=executor_prompt, + worker_id=worker_id, + ) + + if job.route_decision.should_review and job.route_decision.reviewer_model_id: + reviewer_output = self._run_stage( + job=job, + stage="reviewer", + model_id=job.route_decision.reviewer_model_id, + system_prompt=( + "Review the answer for correctness, completeness, " + "and structure. Keep notes concise." + ), + user_prompt=( + "Original request:\n" + f"{job.request.user_prompt}\n\n" + "Candidate answer:\n" + f"{executor_output}" + ), + worker_id=worker_id, + ) + self.artifacts.write_text( + job.id, "reviewer-notes", "reviewer_notes.txt", reviewer_output + ) + + return executor_output + + def _run_stage( + self, + *, + job: JobRecord, + stage: str, + model_id: str, + system_prompt: str | None, + user_prompt: str, + worker_id: str | None = None, + ) -> str: + model = self.config.model_catalog.get_model(model_id) + provider = self.provider_registry.provider_for_model(model) + provider_id = getattr(provider, "provider_id", model.provider_id) + request = ProviderRequest( + system_prompt=system_prompt, + user_prompt=user_prompt, + metadata={"job_id": job.id, "stage": stage}, + temperature=job.request.temperature, + max_output_tokens=job.request.max_output_tokens, + timeout_seconds=job.request.timeout_seconds, + ) + self._record_event( + job_id=job.id, + stage=stage, + event_type="started", + message=f"{stage.title()} stage started with model {model.id}.", + model_id=model.id, + provider_id=provider_id, + metadata={ + "temperature": request.temperature, + "max_output_tokens": request.max_output_tokens, + "timeout_seconds": request.timeout_seconds, + "prompt_characters": len(user_prompt), + "system_prompt_characters": len(system_prompt or ""), + "request_preview": _preview_text(user_prompt), + }, + ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + phase="processing", + extra_metadata={ + "current_stage_phase": "starting", + "current_stage_progress": 0.0, + "current_stage_message": f"{stage.title()} stage started.", + }, + ) + health = self.provider_registry.healthcheck(model) + if not health.available: + reason = "; ".join(health.reasons) or "unavailable" + self._record_event( + job_id=job.id, + stage=stage, + event_type="blocked", + message=f"{stage.title()} stage blocked: {reason}", + model_id=model.id, + provider_id=provider_id, + metadata={ + "health_reasons": health.reasons, + "health_metadata": health.metadata, + }, + ) + raise RuntimeError( + f"Model {model.id!r} is not executable: {reason}" + ) + self.artifacts.write_json( + job.id, f"{stage}-request", f"{stage}_request.json", request.model_dump() + ) + try: + with self._stage_heartbeat( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + ): + result = provider.generate( + model, + request, + progress_callback=self._provider_progress_callback( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + ), + ) + except 
Exception as exc: + self._record_event( + job_id=job.id, + stage=stage, + event_type="failed", + message=f"{stage.title()} stage failed with {type(exc).__name__}.", + model_id=model.id, + provider_id=provider_id, + metadata={ + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + phase="error", + extra_metadata={ + "error_type": type(exc).__name__, + "current_stage_phase": "error", + "current_stage_message": str(exc), + }, + ) + raise + result.stage = stage + self.artifacts.write_json( + job.id, f"{stage}-response", f"{stage}_response.json", result.model_dump() + ) + self._record_event( + job_id=job.id, + stage=stage, + event_type="completed", + message=f"{stage.title()} stage completed.", + model_id=model.id, + provider_id=provider_id, + metadata={ + "duration_seconds": result.duration_seconds, + "finish_reason": result.finish_reason, + "usage": result.usage.model_dump() if result.usage else None, + "output_characters": len(result.text), + "output_preview": _preview_text(result.text), + "raw_keys": sorted(result.raw.keys()), + }, + ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job.id, + stage=stage, + model_id=model.id, + provider_id=provider_id, + phase="processing", + extra_metadata={ + "last_completed_stage": stage, + "last_stage_duration_seconds": result.duration_seconds, + "last_stage_finish_reason": result.finish_reason, + "current_stage_phase": "completed", + "current_stage_progress": 1.0, + "current_stage_message": f"{stage.title()} stage completed.", + }, + ) + if stage == "executor": + job.result = result + job.updated_at = utcnow() + self.job_store.save_job(job) + return result.text + + def _mark_stage_worker_state( + self, + *, + worker_id: str | None, + job_id: str, + stage: str, + model_id: str, + provider_id: str, + phase: str, + extra_metadata: dict[str, object] | None = None, + ) -> None: + if worker_id is None: + return + current_state = self.get_worker_state(worker_id) + metadata = dict(current_state.metadata) + metadata.update( + { + "phase": phase, + "current_stage": stage, + "current_model_id": model_id, + "current_provider_id": provider_id, + "last_stage_heartbeat_at": utcnow().isoformat(), + } + ) + if extra_metadata: + metadata.update(extra_metadata) + self._update_worker_state( + worker_id=worker_id, + status=WorkerStatus.RUNNING, + current_job_id=job_id, + queued_jobs=self.job_store.count_jobs_by_status(JobStatus.QUEUED), + mode=current_state.mode, + metadata=metadata, + ) + + def _provider_progress_callback( + self, + *, + worker_id: str | None, + job_id: str, + stage: str, + model_id: str, + provider_id: str, + ): + def _callback(update: ProviderProgressUpdate) -> None: + metadata = dict(update.metadata) + metadata["phase"] = update.phase + if update.progress is not None: + metadata["progress"] = update.progress + self._record_event( + job_id=job_id, + stage=stage, + event_type="progress", + message=update.message, + model_id=model_id, + provider_id=provider_id, + metadata=metadata, + ) + self._mark_stage_worker_state( + worker_id=worker_id, + job_id=job_id, + stage=stage, + model_id=model_id, + provider_id=provider_id, + phase="processing", + extra_metadata={ + "current_stage_phase": update.phase, + "current_stage_message": update.message, + "current_stage_updated_at": utcnow().isoformat(), + **( + {"current_stage_progress": update.progress} + if update.progress is not None + else {} 
+ ), + }, + ) + + return _callback + + def _stage_heartbeat( + self, + *, + worker_id: str | None, + job_id: str, + stage: str, + model_id: str, + provider_id: str, + ): + interval = self.settings.job_heartbeat_interval_seconds + if interval <= 0: + return _NullContext() + + service = self + stop_event = threading.Event() + started_at = utcnow() + + class _StageHeartbeatContext: + def __enter__(self_inner): + def _heartbeat_loop() -> None: + while not stop_event.wait(interval): + elapsed_seconds = round((utcnow() - started_at).total_seconds(), 2) + service._record_event( + job_id=job_id, + stage=stage, + event_type="heartbeat", + message=f"{stage.title()} stage is still running.", + model_id=model_id, + provider_id=provider_id, + metadata={"elapsed_seconds": elapsed_seconds}, + ) + service._mark_stage_worker_state( + worker_id=worker_id, + job_id=job_id, + stage=stage, + model_id=model_id, + provider_id=provider_id, + phase="processing", + extra_metadata={ + "elapsed_seconds": elapsed_seconds, + "current_stage_updated_at": utcnow().isoformat(), + }, + ) + + self_inner.thread = threading.Thread( + target=_heartbeat_loop, + name=f"lai-stage-heartbeat-{job_id}-{stage}", + daemon=True, + ) + self_inner.thread.start() + return self_inner + + def __exit__(self_inner, *_args: object) -> None: + stop_event.set() + self_inner.thread.join(timeout=interval + 1) + + return _StageHeartbeatContext() + + def _record_event( + self, + *, + job_id: str, + stage: str, + event_type: str, + message: str, + model_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, object] | None = None, + ) -> StageEventRecord: + event = StageEventRecord( + id=str(uuid4()), + job_id=job_id, + stage=stage, + event_type=event_type, + message=message, + model_id=model_id, + provider_id=provider_id, + metadata=metadata or {}, + ) + self.job_store.add_stage_event(event) + return event + + def _recover_job( + self, + *, + job_id: str, + message: str, + metadata: dict[str, object] | None = None, + ) -> JobRecord | None: + if not self.job_store.requeue_job(job_id): + return None + self._record_event( + job_id=job_id, + stage="job", + event_type="recovered", + message=message, + metadata=metadata, + ) + return self.job_store.get_job(job_id) + + def _update_worker_state( + self, + *, + worker_id: str, + **updates: object, + ) -> WorkerStateRecord: + current = self.get_worker_state(worker_id) + payload = current.model_dump() + if ( + isinstance(payload.get("metadata"), dict) + and isinstance(updates.get("metadata"), dict) + ): + merged_metadata = dict(payload["metadata"]) + merged_metadata.update(updates["metadata"]) + updates = {**updates, "metadata": merged_metadata} + payload.update(updates) + now = utcnow() + payload["heartbeat_at"] = now + payload["updated_at"] = now + worker_state = WorkerStateRecord.model_validate(payload) + self.job_store.save_worker_state(worker_state) + return worker_state + + @staticmethod + def _resolved_queue_mode(queue_mode: QueueMode, route_decision: RoutingDecision) -> QueueMode: + queue_mode_value = _enum_value(queue_mode) + if queue_mode_value == "auto": + return QueueMode.QUEUED if route_decision.queue_recommended else QueueMode.INLINE + return QueueMode(queue_mode_value) + + +def _is_retryable_exception(exc: Exception) -> bool: + if isinstance(exc, RetryableProviderError): + return True + name = type(exc).__name__.lower() + module = type(exc).__module__.lower() + retry_markers = ("timeout", "connection", "rate", "tempor") + return any(marker in name or marker in module for 
marker in retry_markers) + + +def _enum_value(value: object) -> str: + if isinstance(value, Enum): + return str(value.value) + return str(value) + + +def _preview_text(value: str, *, limit: int = 240) -> str: + normalized = " ".join(value.split()) + if len(normalized) <= limit: + return normalized + return normalized[: limit - 3] + "..." + + +class _NullContext: + def __enter__(self): + return self + + def __exit__(self, *_args: object) -> None: + return None diff --git a/src/lai/jobs/store.py b/src/lai/jobs/store.py new file mode 100644 index 0000000..8e76afb --- /dev/null +++ b/src/lai/jobs/store.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +from ..domain import ( + ArtifactRecord, + JobRecord, + JobStatus, + StageEventRecord, + WorkerStateRecord, + utcnow, +) + + +class JobStore: + def __init__(self, database_path: Path) -> None: + self.database_path = database_path + self.database_path.parent.mkdir(parents=True, exist_ok=True) + + def initialize(self) -> None: + with self._connect() as connection: + connection.executescript( + """ + CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + queued_at TEXT, + started_at TEXT, + finished_at TEXT, + queue_mode TEXT NOT NULL, + request_json TEXT NOT NULL, + route_json TEXT, + result_json TEXT, + error_json TEXT, + attempts INTEGER NOT NULL DEFAULT 0 + ); + + CREATE TABLE IF NOT EXISTS artifacts ( + id TEXT PRIMARY KEY, + job_id TEXT NOT NULL, + artifact_type TEXT NOT NULL, + relative_path TEXT NOT NULL, + created_at TEXT NOT NULL, + metadata_json TEXT NOT NULL, + FOREIGN KEY(job_id) REFERENCES jobs(id) + ); + + CREATE TABLE IF NOT EXISTS stage_events ( + id TEXT PRIMARY KEY, + job_id TEXT NOT NULL, + stage TEXT NOT NULL, + event_type TEXT NOT NULL, + message TEXT NOT NULL, + created_at TEXT NOT NULL, + model_id TEXT, + provider_id TEXT, + metadata_json TEXT NOT NULL, + FOREIGN KEY(job_id) REFERENCES jobs(id) + ); + + CREATE TABLE IF NOT EXISTS worker_state ( + worker_id TEXT PRIMARY KEY, + status TEXT NOT NULL, + mode TEXT NOT NULL, + current_job_id TEXT, + last_job_id TEXT, + processed_jobs INTEGER NOT NULL, + queued_jobs INTEGER NOT NULL, + started_at TEXT, + heartbeat_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + last_error TEXT, + metadata_json TEXT NOT NULL + ); + """ + ) + + def save_job(self, job: JobRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT INTO jobs ( + id, status, created_at, updated_at, queued_at, started_at, finished_at, + queue_mode, request_json, route_json, result_json, error_json, attempts + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(id) DO UPDATE SET + status=excluded.status, + updated_at=excluded.updated_at, + queued_at=excluded.queued_at, + started_at=excluded.started_at, + finished_at=excluded.finished_at, + queue_mode=excluded.queue_mode, + request_json=excluded.request_json, + route_json=excluded.route_json, + result_json=excluded.result_json, + error_json=excluded.error_json, + attempts=excluded.attempts + """, + ( + job.id, + job.status, + job.created_at.isoformat(), + job.updated_at.isoformat(), + job.queued_at.isoformat() if job.queued_at else None, + job.started_at.isoformat() if job.started_at else None, + job.finished_at.isoformat() if job.finished_at else None, + job.queue_mode, + job.request.model_dump_json(), + job.route_decision.model_dump_json() if job.route_decision else None, + job.result.model_dump_json() if job.result else None, + job.error.model_dump_json() if job.error else None, + job.attempts, + ), + ) + + def add_artifact(self, artifact: ArtifactRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT OR REPLACE INTO artifacts ( + id, job_id, artifact_type, relative_path, created_at, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ( + artifact.id, + artifact.job_id, + artifact.artifact_type, + artifact.relative_path, + artifact.created_at.isoformat(), + json.dumps(artifact.metadata), + ), + ) + + def add_stage_event(self, event: StageEventRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT OR REPLACE INTO stage_events ( + id, job_id, stage, event_type, message, created_at, + model_id, provider_id, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + event.id, + event.job_id, + event.stage, + event.event_type, + event.message, + event.created_at.isoformat(), + event.model_id, + event.provider_id, + json.dumps(event.metadata), + ), + ) + + def save_worker_state(self, worker_state: WorkerStateRecord) -> None: + with self._connect() as connection: + connection.execute( + """ + INSERT INTO worker_state ( + worker_id, status, mode, current_job_id, last_job_id, processed_jobs, + queued_jobs, started_at, heartbeat_at, updated_at, last_error, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(worker_id) DO UPDATE SET + status=excluded.status, + mode=excluded.mode, + current_job_id=excluded.current_job_id, + last_job_id=excluded.last_job_id, + processed_jobs=excluded.processed_jobs, + queued_jobs=excluded.queued_jobs, + started_at=excluded.started_at, + heartbeat_at=excluded.heartbeat_at, + updated_at=excluded.updated_at, + last_error=excluded.last_error, + metadata_json=excluded.metadata_json + """, + ( + worker_state.worker_id, + worker_state.status, + worker_state.mode, + worker_state.current_job_id, + worker_state.last_job_id, + worker_state.processed_jobs, + worker_state.queued_jobs, + worker_state.started_at.isoformat() if worker_state.started_at else None, + worker_state.heartbeat_at.isoformat(), + worker_state.updated_at.isoformat(), + worker_state.last_error, + json.dumps(worker_state.metadata), + ), + ) + + def get_worker_state(self, worker_id: str = "local-worker") -> WorkerStateRecord | None: + with self._connect() as connection: + row = connection.execute( + "SELECT * FROM worker_state WHERE worker_id = ?", + (worker_id,), + ).fetchone() + if row is None: + return None + return WorkerStateRecord.model_validate( + { + "worker_id": row["worker_id"], + "status": row["status"], + "mode": row["mode"], + "current_job_id": row["current_job_id"], + "last_job_id": row["last_job_id"], + "processed_jobs": row["processed_jobs"], + "queued_jobs": row["queued_jobs"], + "started_at": row["started_at"], + "heartbeat_at": row["heartbeat_at"], + "updated_at": row["updated_at"], + "last_error": row["last_error"], + "metadata": json.loads(row["metadata_json"]), + } + ) + + def get_job(self, job_id: str) -> JobRecord | None: + with self._connect() as connection: + row = connection.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() + if row is None: + return None + return self._deserialize_job(connection, row) + + def list_jobs(self, limit: int = 20) -> list[JobRecord]: + with self._connect() as connection: + rows = connection.execute( + "SELECT * FROM jobs ORDER BY created_at DESC LIMIT ?", + (limit,), + ).fetchall() + return [self._deserialize_job(connection, row) for row in rows] + + def list_jobs_by_status( + self, + status: JobStatus, + *, + limit: int = 100, + ) -> list[JobRecord]: + with self._connect() as connection: + rows = connection.execute( + "SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC LIMIT ?", + (status, limit), + ).fetchall() + return [self._deserialize_job(connection, row) for row in rows] + + def count_jobs_by_status(self, status: JobStatus) -> int: + with self._connect() as connection: + row = connection.execute( + "SELECT COUNT(*) AS total FROM jobs WHERE status = ?", + (status,), + ).fetchone() + return int(row["total"]) if row else 0 + + def cancel_job(self, job_id: str) -> bool: + with self._connect() as connection: + result = connection.execute( + """ + UPDATE jobs + SET status = ?, updated_at = ?, finished_at = ? + WHERE id = ? AND status IN (?, ?) + """, + ( + JobStatus.CANCELED, + utcnow().isoformat(), + utcnow().isoformat(), + job_id, + JobStatus.QUEUED, + JobStatus.RUNNING, + ), + ) + return result.rowcount > 0 + + def claim_next_queued_job(self) -> JobRecord | None: + with self._connect() as connection: + connection.execute("BEGIN IMMEDIATE") + row = connection.execute( + "SELECT * FROM jobs WHERE status = ? 
ORDER BY created_at ASC LIMIT 1", + (JobStatus.QUEUED,), + ).fetchone() + if row is None: + connection.commit() + return None + now = utcnow().isoformat() + connection.execute( + """ + UPDATE jobs + SET status = ?, + updated_at = ?, + started_at = COALESCE(started_at, ?), + attempts = attempts + 1 + WHERE id = ? + """, + (JobStatus.RUNNING, now, now, row["id"]), + ) + updated = connection.execute("SELECT * FROM jobs WHERE id = ?", (row["id"],)).fetchone() + connection.commit() + return self._deserialize_job(connection, updated) + + def requeue_running_jobs(self) -> int: + with self._connect() as connection: + now = utcnow().isoformat() + result = connection.execute( + """ + UPDATE jobs + SET status = ?, + updated_at = ?, + queued_at = ?, + started_at = NULL, + finished_at = NULL + WHERE status = ? + """, + (JobStatus.QUEUED, now, now, JobStatus.RUNNING), + ) + return result.rowcount + + def requeue_job(self, job_id: str) -> bool: + with self._connect() as connection: + now = utcnow().isoformat() + result = connection.execute( + """ + UPDATE jobs + SET status = ?, + updated_at = ?, + queued_at = ?, + started_at = NULL, + finished_at = NULL + WHERE id = ? AND status = ? + """, + (JobStatus.QUEUED, now, now, job_id, JobStatus.RUNNING), + ) + return result.rowcount > 0 + + def _deserialize_job(self, connection: sqlite3.Connection, row: sqlite3.Row) -> JobRecord: + artifacts = self._artifacts_for_job(connection, row["id"]) + stage_events = self._stage_events_for_job(connection, row["id"]) + return JobRecord.model_validate( + { + "id": row["id"], + "status": row["status"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "queued_at": row["queued_at"], + "started_at": row["started_at"], + "finished_at": row["finished_at"], + "queue_mode": row["queue_mode"], + "request": json.loads(row["request_json"]), + "route_decision": json.loads(row["route_json"]) if row["route_json"] else None, + "result": json.loads(row["result_json"]) if row["result_json"] else None, + "error": json.loads(row["error_json"]) if row["error_json"] else None, + "attempts": row["attempts"], + "artifacts": artifacts, + "stage_events": stage_events, + } + ) + + def _artifacts_for_job( + self, connection: sqlite3.Connection, job_id: str + ) -> list[ArtifactRecord]: + rows = connection.execute( + "SELECT * FROM artifacts WHERE job_id = ? ORDER BY created_at ASC", + (job_id,), + ).fetchall() + return [ + ArtifactRecord.model_validate( + { + "id": row["id"], + "job_id": row["job_id"], + "artifact_type": row["artifact_type"], + "relative_path": row["relative_path"], + "created_at": row["created_at"], + "metadata": json.loads(row["metadata_json"]), + } + ) + for row in rows + ] + + def _stage_events_for_job( + self, connection: sqlite3.Connection, job_id: str + ) -> list[StageEventRecord]: + rows = connection.execute( + "SELECT * FROM stage_events WHERE job_id = ? 
ORDER BY created_at ASC", + (job_id,), + ).fetchall() + return [ + StageEventRecord.model_validate( + { + "id": row["id"], + "job_id": row["job_id"], + "stage": row["stage"], + "event_type": row["event_type"], + "message": row["message"], + "created_at": row["created_at"], + "model_id": row["model_id"], + "provider_id": row["provider_id"], + "metadata": json.loads(row["metadata_json"]), + } + ) + for row in rows + ] + + def _connect(self) -> sqlite3.Connection: + connection = sqlite3.connect(self.database_path) + connection.row_factory = sqlite3.Row + return connection diff --git a/src/lai/layout.py b/src/lai/layout.py index 86a8b8f..f18f223 100644 --- a/src/lai/layout.py +++ b/src/lai/layout.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path RUNTIME_DIRECTORIES = ( @@ -5,6 +7,7 @@ Path("data/models/airllm-shards"), Path("data/models/raw"), Path("data/artifacts"), + Path("data/state"), Path("logs"), ) @@ -12,3 +15,10 @@ def runtime_directories(root_dir: Path) -> list[Path]: """Resolve the runtime directories relative to the repository root.""" return [root_dir / path for path in RUNTIME_DIRECTORIES] + + +def ensure_runtime_directories(root_dir: Path) -> list[Path]: + resolved = runtime_directories(root_dir) + for path in resolved: + path.mkdir(parents=True, exist_ok=True) + return resolved diff --git a/src/lai/observability.py b/src/lai/observability.py new file mode 100644 index 0000000..344594c --- /dev/null +++ b/src/lai/observability.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +from .domain import ( + ArtifactRecord, + JobRecord, + StageEventRecord, + StageExecutionSummary, + StaleJobSummary, + UsageStats, +) + +EXECUTION_STAGES = ("planner", "executor", "reviewer") +STALE_ACTIVITY_EVENT_TYPES = { + "running", + "started", + "progress", + "heartbeat", + "completed", + "failed", + "blocked", + "retry-queued", + "recovered", + "succeeded", + "canceled", +} + + +def summarize_job_execution(job: JobRecord) -> list[StageExecutionSummary]: + summaries: list[StageExecutionSummary] = [] + for stage in EXECUTION_STAGES: + stage_events = [event for event in job.stage_events if event.stage == stage] + if not stage_events: + continue + + progress_events = [event for event in stage_events if event.event_type == "progress"] + latest_progress = progress_events[-1] if progress_events else None + started = _first_event(stage_events, "started") + completed = _last_event(stage_events, "completed") + failed = _last_event(stage_events, "failed") + blocked = _last_event(stage_events, "blocked") + final_event = completed or failed or blocked or stage_events[-1] + request_metadata = started.metadata if started is not None else {} + final_metadata = final_event.metadata if final_event is not None else {} + status = _resolved_status(started, completed, failed, blocked) + usage_payload = final_metadata.get("usage") + + stage_artifacts = _stage_artifacts(job.artifacts, stage) + summaries.append( + StageExecutionSummary( + stage=stage, + status=status, + model_id=(final_event.model_id if final_event is not None else None) + or (started.model_id if started is not None else None), + provider_id=(final_event.provider_id if final_event is not None else None) + or (started.provider_id if started is not None else None), + started_at=started.created_at if started is not None else None, + finished_at=( + final_event.created_at + if final_event is not None and status in {"completed", "failed", "blocked"} + else None + ), + 
duration_seconds=final_metadata.get("duration_seconds"), + finish_reason=final_metadata.get("finish_reason"), + prompt_characters=request_metadata.get("prompt_characters"), + system_prompt_characters=request_metadata.get("system_prompt_characters"), + output_characters=final_metadata.get("output_characters"), + output_preview=final_metadata.get("output_preview"), + temperature=request_metadata.get("temperature"), + max_output_tokens=request_metadata.get("max_output_tokens"), + timeout_seconds=request_metadata.get("timeout_seconds"), + usage=( + UsageStats.model_validate(usage_payload) + if isinstance(usage_payload, dict) + else None + ), + artifact_ids=[artifact.id for artifact in stage_artifacts], + artifact_types=[artifact.artifact_type for artifact in stage_artifacts], + health_reasons=list(final_metadata.get("health_reasons") or []), + error_type=final_metadata.get("error_type"), + last_message=( + latest_progress.message + if status == "running" and latest_progress is not None + else (final_event.message if final_event is not None else None) + ), + progress_event_count=len(progress_events), + latest_progress_message=( + latest_progress.message if latest_progress is not None else None + ), + latest_progress_phase=( + str(latest_progress.metadata.get("phase")) + if latest_progress is not None + and latest_progress.metadata.get("phase") is not None + else None + ), + latest_progress_at=( + latest_progress.created_at if latest_progress is not None else None + ), + latest_progress_percent=( + float(latest_progress.metadata["progress"]) + if latest_progress is not None + and latest_progress.metadata.get("progress") is not None + else None + ), + ) + ) + + return summaries + + +def summarize_stale_jobs( + jobs: list[JobRecord], + *, + stale_after_seconds: int, + now: datetime | None = None, +) -> list[StaleJobSummary]: + now = now or datetime.now(tz=timezone.utc) + summaries: list[StaleJobSummary] = [] + for job in jobs: + activity_events = [ + event for event in job.stage_events if event.event_type in STALE_ACTIVITY_EVENT_TYPES + ] + last_event = activity_events[-1] if activity_events else ( + job.stage_events[-1] if job.stage_events else None + ) + last_activity = latest_job_activity(job, stage_events=activity_events) + age_seconds = (now - last_activity).total_seconds() + if age_seconds < stale_after_seconds: + continue + summaries.append( + StaleJobSummary( + job_id=job.id, + queue_mode=job.queue_mode, + attempts=job.attempts, + executor_model_id=( + job.route_decision.executor_model_id if job.route_decision else None + ), + current_stage=last_event.stage if last_event is not None else None, + last_event_type=last_event.event_type if last_event is not None else None, + last_message=last_event.message if last_event is not None else None, + last_activity_at=last_activity, + stale_after_seconds=stale_after_seconds, + age_seconds=round(age_seconds, 2), + ) + ) + return summaries + + +def latest_job_activity( + job: JobRecord, + *, + stage_events: list[StageEventRecord] | None = None, +) -> datetime: + candidates = [job.updated_at, job.started_at, job.finished_at] + events = stage_events if stage_events is not None else job.stage_events + candidates.extend(event.created_at for event in events) + activity = [candidate for candidate in candidates if candidate is not None] + return max(activity) if activity else job.created_at + + +def _stage_artifacts(artifacts: list[ArtifactRecord], stage: str) -> list[ArtifactRecord]: + target_types = {f"{stage}-request", f"{stage}-response"} + if stage == 
"reviewer": + target_types.add("reviewer-notes") + if stage == "executor": + target_types.add("final-output") + return [artifact for artifact in artifacts if artifact.artifact_type in target_types] + + +def _first_event(events: list[StageEventRecord], event_type: str) -> StageEventRecord | None: + for event in events: + if event.event_type == event_type: + return event + return None + + +def _last_event(events: list[StageEventRecord], event_type: str) -> StageEventRecord | None: + for event in reversed(events): + if event.event_type == event_type: + return event + return None + + +def _resolved_status( + started: StageEventRecord | None, + completed: StageEventRecord | None, + failed: StageEventRecord | None, + blocked: StageEventRecord | None, +) -> str: + if completed is not None: + return "completed" + if failed is not None: + return "failed" + if blocked is not None: + return "blocked" + if started is not None: + return "running" + return "pending" diff --git a/src/lai/providers/__init__.py b/src/lai/providers/__init__.py index 42d72b4..a5001c2 100644 --- a/src/lai/providers/__init__.py +++ b/src/lai/providers/__init__.py @@ -1 +1,4 @@ -"""Runtime provider adapters live here.""" +from .base import Provider +from .registry import ProviderRegistry + +__all__ = ["Provider", "ProviderRegistry"] diff --git a/src/lai/providers/base.py b/src/lai/providers/base.py new file mode 100644 index 0000000..83362be --- /dev/null +++ b/src/lai/providers/base.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from time import perf_counter +from typing import Any, Callable + +from ..domain import ( + ExecutionResult, + FinishReason, + ModelSpec, + ProviderHealth, + ProviderProgressUpdate, + ProviderRequest, +) +from ..settings import Settings +from ..system import SystemSnapshot + +ProgressCallback = Callable[[ProviderProgressUpdate], None] + + +def serialize_raw_object(value: Any) -> dict[str, Any]: + if hasattr(value, "model_dump"): + try: + return value.model_dump() + except Exception: + pass + if hasattr(value, "dict"): + try: + return value.dict() + except Exception: + pass + return {"repr": repr(value)} + + +class Provider(ABC): + provider_id: str + + def __init__(self, settings: Settings, system_snapshot: SystemSnapshot) -> None: + self.settings = settings + self.system_snapshot = system_snapshot + + @abstractmethod + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + raise NotImplementedError + + def available(self, model: ModelSpec) -> ProviderHealth: + return self.healthcheck(model) + + @abstractmethod + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ) -> ExecutionResult: + raise NotImplementedError + + @abstractmethod + def describe_capabilities(self) -> dict[str, Any]: + raise NotImplementedError + + def _result( + self, + *, + text: str, + provider_id: str, + model_id: str, + duration_seconds: float, + finish_reason: FinishReason = FinishReason.STOP, + usage: dict[str, Any] | None = None, + raw: dict[str, Any] | None = None, + stage: str = "executor", + ) -> ExecutionResult: + return ExecutionResult( + text=text, + provider_id=provider_id, + model_id=model_id, + duration_seconds=duration_seconds, + finish_reason=finish_reason, + usage=usage, + raw=raw or {}, + stage=stage, + ) + + def _missing_dependency_health(self, provider_id: str, dependency_name: str) -> ProviderHealth: + return ProviderHealth( + provider_id=provider_id, + available=False, + 
healthy=False, + reasons=[f"Missing optional dependency {dependency_name!r}."], + ) + + def _credential_health(self, provider_id: str, credential_name: str) -> ProviderHealth: + return ProviderHealth( + provider_id=provider_id, + available=False, + healthy=False, + reasons=[f"Missing credential {credential_name!r}."], + ) + + def _ok_health(self, provider_id: str, **metadata: Any) -> ProviderHealth: + return ProviderHealth( + provider_id=provider_id, available=True, healthy=True, metadata=metadata + ) + + def _timer(self) -> float: + return perf_counter() + + def _emit_progress( + self, + progress_callback: ProgressCallback | None, + *, + phase: str, + message: str, + progress: float | None = None, + **metadata: Any, + ) -> None: + if progress_callback is None: + return + progress_callback( + ProviderProgressUpdate( + phase=phase, + message=message, + progress=progress, + metadata=metadata, + ) + ) diff --git a/src/lai/providers/implementations.py b/src/lai/providers/implementations.py new file mode 100644 index 0000000..831687b --- /dev/null +++ b/src/lai/providers/implementations.py @@ -0,0 +1,493 @@ +from __future__ import annotations + +from time import perf_counter +from typing import Any + +from ..domain import ( + FinishReason, + ModelRuntime, + ModelSpec, + ProviderHealth, + ProviderRequest, + UsageStats, +) +from ..settings import Settings +from ..system import available_disk_gb +from .base import ProgressCallback, Provider, serialize_raw_object + + +class TransformersProvider(Provider): + provider_id = ModelRuntime.TRANSFORMERS.value + + def __init__(self, settings: Settings, system_snapshot) -> None: # type: ignore[override] + super().__init__(settings, system_snapshot) + self._pipeline_cache: dict[str, Any] = {} + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + import transformers # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "transformers") + + if model.model_ref.startswith("meta-llama/") and not self.settings.huggingface_token_value: + return self._credential_health(self.provider_id, "LAI_HF_TOKEN") + + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return { + "supports_local_generation": True, + "supports_streaming": False, + "requires_network_once": True, + } + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): + from transformers import pipeline + + started = perf_counter() + pipe = self._pipeline_cache.get(model.id) + if pipe is None: + self._emit_progress( + progress_callback, + phase="loading-pipeline", + message=f"Loading Transformers pipeline for {model.id}.", + progress=0.1, + cached=False, + ) + model_kwargs: dict[str, Any] = { + "torch_dtype": "auto", + } + pipe = pipeline( + "text-generation", + model=model.model_ref, + tokenizer=model.model_ref, + device_map="auto" if self.system_snapshot.has_gpu else None, + trust_remote_code=True, + token=self.settings.huggingface_token_value, + model_kwargs=model_kwargs, + ) + self._pipeline_cache[model.id] = pipe + else: + self._emit_progress( + progress_callback, + phase="pipeline-ready", + message=f"Reusing cached Transformers pipeline for {model.id}.", + progress=0.2, + cached=True, + ) + + prompt = _render_prompt(request.system_prompt, request.user_prompt) + self._emit_progress( + progress_callback, + phase="generating", + message=f"Generating text with {model.id}.", + progress=0.55, + 
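+            # Progress fractions are fixed per phase (load, generate, parse) rather than measured.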
prompt_characters=len(prompt), + ) + result = pipe( + prompt, + max_new_tokens=request.max_output_tokens, + temperature=request.temperature, + return_full_text=False, + ) + self._emit_progress( + progress_callback, + phase="parsing-output", + message=f"Parsing model output for {model.id}.", + progress=0.9, + result_count=len(result), + ) + text = result[0]["generated_text"].strip() if result else "" + duration = perf_counter() - started + return self._result( + text=text, + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + raw={"result_count": len(result)}, + ) + + +class AirLLMProvider(Provider): + provider_id = ModelRuntime.AIRLLM.value + + def __init__(self, settings: Settings, system_snapshot) -> None: # type: ignore[override] + super().__init__(settings, system_snapshot) + self._model_cache: dict[str, Any] = {} + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + from airllm import AutoModel # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "airllm") + + if not self.settings.huggingface_token_value: + return self._credential_health(self.provider_id, "LAI_HF_TOKEN") + + free_disk = available_disk_gb(self.settings.resolved_airllm_shards_dir) + expected_disk = model.hardware.expected_disk_gb or 0 + if expected_disk and free_disk < expected_disk: + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=[ + "Insufficient disk for AirLLM shards. " + f"Need about {expected_disk} GB, found {free_disk:.2f} GB." + ], + metadata={"free_disk_gb": free_disk}, + ) + + if not self.system_snapshot.has_gpu and not model.runtime_hints.allow_cpu_fallback: + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=["GPU unavailable and CPU fallback is disabled for this model."], + ) + + return self._ok_health( + self.provider_id, + model_ref=model.model_ref, + free_disk_gb=free_disk, + gpu_name=self.system_snapshot.gpu_name, + ) + + def describe_capabilities(self) -> dict[str, Any]: + return { + "supports_local_generation": True, + "supports_streaming": False, + "supports_layer_sharding": True, + "supports_cpu_fallback": True, + } + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): + from airllm import AutoModel + + started = perf_counter() + loaded_model = self._model_cache.get(model.id) + if loaded_model is None: + self._emit_progress( + progress_callback, + phase="loading-shards", + message=f"Loading AirLLM model shards for {model.id}.", + progress=0.1, + cached=False, + ) + loaded_model = AutoModel.from_pretrained( + model.model_ref, + hf_token=self.settings.huggingface_token_value, + layer_shards_saving_path=str(self.settings.resolved_airllm_shards_dir / model.id), + ) + self._model_cache[model.id] = loaded_model + else: + self._emit_progress( + progress_callback, + phase="model-ready", + message=f"Reusing cached AirLLM model for {model.id}.", + progress=0.2, + cached=True, + ) + + prompt = _render_prompt(request.system_prompt, request.user_prompt) + self._emit_progress( + progress_callback, + phase="tokenizing", + message=f"Tokenizing prompt for {model.id}.", + progress=0.35, + prompt_characters=len(prompt), + ) + tokenized = loaded_model.tokenizer( + [prompt], + return_tensors="pt", + return_attention_mask=False, + truncation=True, + max_length=min(model.context_window, 
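+            # Truncate to the smaller of the model context window and twice the router token budget.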
self.settings.max_router_tokens * 2), + padding=False, + ) + + input_ids = tokenized["input_ids"] + if self.system_snapshot.has_gpu: + self._emit_progress( + progress_callback, + phase="transferring-input", + message=f"Moving prompt tensors to GPU for {model.id}.", + progress=0.45, + ) + input_ids = input_ids.cuda() + + self._emit_progress( + progress_callback, + phase="generating", + message=f"Generating response with AirLLM model {model.id}.", + progress=0.7, + ) + generation = loaded_model.generate( + input_ids, + max_new_tokens=request.max_output_tokens, + use_cache=True, + return_dict_in_generate=True, + ) + self._emit_progress( + progress_callback, + phase="decoding", + message=f"Decoding generated tokens for {model.id}.", + progress=0.9, + generated_token_count=len(generation.sequences[0]), + ) + decoded = loaded_model.tokenizer.decode(generation.sequences[0], skip_special_tokens=True) + text = _strip_prompt_prefix(decoded, prompt) + duration = perf_counter() - started + return self._result( + text=text.strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + raw={"generated_tokens": len(generation.sequences[0])}, + ) + + +class OpenAIProvider(Provider): + provider_id = ModelRuntime.OPENAI.value + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + import openai # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "openai") + if not self.settings.openai_api_key_value: + return self._credential_health(self.provider_id, "LAI_OPENAI_API_KEY") + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return {"supports_remote_generation": True, "supports_streaming": False} + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): + from openai import OpenAI + + started = perf_counter() + self._emit_progress( + progress_callback, + phase="creating-client", + message=f"Creating OpenAI client for {model.id}.", + progress=0.1, + ) + client = OpenAI(api_key=self.settings.openai_api_key_value, timeout=request.timeout_seconds) + self._emit_progress( + progress_callback, + phase="request-sent", + message=f"Sending OpenAI request for {model.id}.", + progress=0.45, + prompt_characters=len(request.user_prompt), + ) + response = client.responses.create( + model=request.model_override or model.model_ref, + instructions=request.system_prompt or None, + input=request.user_prompt, + max_output_tokens=request.max_output_tokens, + temperature=request.temperature, + ) + self._emit_progress( + progress_callback, + phase="response-received", + message=f"OpenAI response received for {model.id}.", + progress=0.9, + ) + duration = perf_counter() - started + usage = getattr(response, "usage", None) + usage_stats = ( + UsageStats( + prompt_tokens=getattr(usage, "input_tokens", None), + completion_tokens=getattr(usage, "output_tokens", None), + total_tokens=getattr(usage, "total_tokens", None), + ) + if usage + else None + ) + return self._result( + text=(getattr(response, "output_text", "") or "").strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + usage=usage_stats.model_dump() if usage_stats else None, + raw=serialize_raw_object(response), + ) + + +class AnthropicProvider(Provider): + provider_id = ModelRuntime.ANTHROPIC.value + + def healthcheck(self, model: 
ModelSpec) -> ProviderHealth: + try: + import anthropic # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "anthropic") + if not self.settings.anthropic_api_key_value: + return self._credential_health(self.provider_id, "LAI_ANTHROPIC_API_KEY") + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return {"supports_remote_generation": True, "supports_streaming": False} + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): + from anthropic import Anthropic + + started = perf_counter() + self._emit_progress( + progress_callback, + phase="creating-client", + message=f"Creating Anthropic client for {model.id}.", + progress=0.1, + ) + client = Anthropic( + api_key=self.settings.anthropic_api_key_value, timeout=request.timeout_seconds + ) + self._emit_progress( + progress_callback, + phase="request-sent", + message=f"Sending Anthropic request for {model.id}.", + progress=0.45, + prompt_characters=len(request.user_prompt), + ) + response = client.messages.create( + model=request.model_override or model.model_ref, + system=request.system_prompt or "", + messages=[{"role": "user", "content": request.user_prompt}], + max_tokens=request.max_output_tokens, + temperature=request.temperature, + ) + self._emit_progress( + progress_callback, + phase="response-received", + message=f"Anthropic response received for {model.id}.", + progress=0.9, + ) + text_parts = [ + block.text for block in response.content if getattr(block, "type", None) == "text" + ] + usage = getattr(response, "usage", None) + usage_stats = ( + UsageStats( + prompt_tokens=getattr(usage, "input_tokens", None), + completion_tokens=getattr(usage, "output_tokens", None), + total_tokens=None, + ) + if usage + else None + ) + duration = perf_counter() - started + return self._result( + text="".join(text_parts).strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + usage=usage_stats.model_dump() if usage_stats else None, + raw=serialize_raw_object(response), + ) + + +class GeminiProvider(Provider): + provider_id = ModelRuntime.GEMINI.value + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + try: + from google import genai # noqa: F401 + except ImportError: + return self._missing_dependency_health(self.provider_id, "google-genai") + if not self.settings.gemini_api_key_value: + return self._credential_health(self.provider_id, "LAI_GEMINI_API_KEY") + return self._ok_health(self.provider_id, model_ref=model.model_ref) + + def describe_capabilities(self) -> dict[str, Any]: + return {"supports_remote_generation": True, "supports_streaming": False} + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ): + from google import genai + + started = perf_counter() + self._emit_progress( + progress_callback, + phase="creating-client", + message=f"Creating Gemini client for {model.id}.", + progress=0.1, + ) + client = genai.Client(api_key=self.settings.gemini_api_key_value) + config: dict[str, Any] = { + "temperature": request.temperature, + "max_output_tokens": request.max_output_tokens, + } + if request.system_prompt: + config["system_instruction"] = request.system_prompt + self._emit_progress( + progress_callback, + phase="request-sent", + message=f"Sending Gemini request for {model.id}.", + 
progress=0.45, + prompt_characters=len(request.user_prompt), + ) + response = client.models.generate_content( + model=request.model_override or model.model_ref, + contents=request.user_prompt, + config=config, + ) + self._emit_progress( + progress_callback, + phase="response-received", + message=f"Gemini response received for {model.id}.", + progress=0.9, + ) + duration = perf_counter() - started + return self._result( + text=(getattr(response, "text", "") or "").strip(), + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=duration, + finish_reason=FinishReason.STOP, + raw=serialize_raw_object(response), + ) + + +def _render_prompt(system_prompt: str | None, user_prompt: str) -> str: + if system_prompt: + return f"System:\n{system_prompt}\n\nUser:\n{user_prompt}\n\nAssistant:\n" + return user_prompt + + +def _strip_prompt_prefix(decoded_text: str, prompt: str) -> str: + if decoded_text.startswith(prompt): + return decoded_text[len(prompt) :] + return decoded_text diff --git a/src/lai/providers/registry.py b/src/lai/providers/registry.py new file mode 100644 index 0000000..456b43f --- /dev/null +++ b/src/lai/providers/registry.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Mapping + +from ..config import AppConfig +from ..domain import ModelRuntime, ModelSpec, ProviderHealth +from ..settings import Settings +from ..system import SystemSnapshot +from .base import Provider +from .implementations import ( + AirLLMProvider, + AnthropicProvider, + GeminiProvider, + OpenAIProvider, + TransformersProvider, +) + + +class ProviderRegistry: + def __init__( + self, + settings: Settings, + system_snapshot: SystemSnapshot, + providers: Mapping[str, Provider] | None = None, + ) -> None: + self.settings = settings + self.system_snapshot = system_snapshot + self._providers = dict(providers or {}) + if not self._providers: + self._providers = { + ModelRuntime.TRANSFORMERS.value: TransformersProvider(settings, system_snapshot), + ModelRuntime.AIRLLM.value: AirLLMProvider(settings, system_snapshot), + ModelRuntime.OPENAI.value: OpenAIProvider(settings, system_snapshot), + ModelRuntime.ANTHROPIC.value: AnthropicProvider(settings, system_snapshot), + ModelRuntime.GEMINI.value: GeminiProvider(settings, system_snapshot), + } + + def get(self, provider_id: str) -> Provider: + return self._providers[provider_id] + + def provider_for_model(self, model: ModelSpec) -> Provider: + return self.get(model.provider_id) + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + return self.provider_for_model(model).healthcheck(model) + + def model_healthchecks(self, config: AppConfig) -> dict[str, ProviderHealth]: + return {model.id: self.healthcheck(model) for model in config.model_catalog.models} diff --git a/src/lai/routing/__init__.py b/src/lai/routing/__init__.py index 967ec34..3898fb8 100644 --- a/src/lai/routing/__init__.py +++ b/src/lai/routing/__init__.py @@ -1 +1,3 @@ -"""Routing and policy evaluation lives here.""" +from .engine import RoutingEngine + +__all__ = ["RoutingEngine"] diff --git a/src/lai/routing/engine.py b/src/lai/routing/engine.py new file mode 100644 index 0000000..c2b919e --- /dev/null +++ b/src/lai/routing/engine.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import json +from enum import Enum +from typing import Iterable + +from ..config import AppConfig +from ..domain import ( + ExecutionRequest, + ModelRuntime, + ModelSpec, + ProviderRequest, + RouteReason, + RoutingDecision, + RoutingTier, +) +from ..providers import 
ProviderRegistry +from ..settings import Settings +from ..system import SystemSnapshot +from .heuristics import build_route_context + + +class RoutingEngine: + def __init__( + self, + settings: Settings, + config: AppConfig, + providers: ProviderRegistry, + system_snapshot: SystemSnapshot, + ) -> None: + self.settings = settings + self.config = config + self.providers = providers + self.system_snapshot = system_snapshot + + def route(self, request: ExecutionRequest) -> RoutingDecision: + reasons: list[RouteReason] = [] + context = build_route_context(request) + + candidate_models = [model for model in self.config.model_catalog.enabled_models()] + reasons.append( + RouteReason(stage="constraints", message="Started from enabled catalog models.") + ) + + if request.provider_override: + candidate_models = [ + model + for model in candidate_models + if model.provider_id == request.provider_override + ] + reasons.append( + RouteReason( + stage="constraints", + message=f"Applied provider override {request.provider_override!r}.", + ) + ) + + if request.model_override: + model = self.config.model_catalog.get_model(request.model_override) + candidate_models = [model] + reasons.append( + RouteReason( + stage="constraints", + message=f"Applied model override {request.model_override!r}.", + ) + ) + + if not self.settings.allow_overnight_jobs and context.overnight_requested: + reasons.append( + RouteReason( + stage="constraints", + message=( + "Overnight execution requested but disabled by settings; " + "deep-work routing will be avoided." + ), + ) + ) + + available_models = self._availability_filter(candidate_models, reasons) + tier = self._resolve_tier(request, context, reasons) + + planner_model = self._resolve_stage_model( + stage="planner", + preferred_model_id=tier.planner_model_id or self.config.routing_policy.router.model_id, + candidate_models=available_models, + required_capabilities=["planning"], + reasons=reasons, + ) + executor_capabilities = list(context.required_capabilities) + executor_model = self._resolve_stage_model( + stage="executor", + preferred_model_id=request.model_override or tier.executor_model_id, + candidate_models=available_models, + required_capabilities=executor_capabilities, + reasons=reasons, + ) + + reviewer_model_id: str | None = None + should_review = ( + request.reviewer_enabled + if request.reviewer_enabled is not None + else tier.resolved_reviewer_enabled + ) + if should_review: + reviewer_model = self._resolve_stage_model( + stage="reviewer", + preferred_model_id=tier.reviewer_model_id, + candidate_models=available_models, + required_capabilities=["critique", "validation"], + reasons=reasons, + ) + reviewer_model_id = reviewer_model.id if reviewer_model else None + if reviewer_model_id is None: + should_review = False + reasons.append( + RouteReason( + stage="fallback", + message=( + "Reviewer stage disabled because no available " + "reviewer-capable model was found." 
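+                            # Reviewer selection is best-effort: the decision proceeds
+                            # without review rather than failing the whole route.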
+ ), + ) + ) + + fallback_model_ids = self._fallback_chain( + preferred_model_id=executor_model.id, + candidate_models=available_models, + required_capabilities=executor_capabilities, + ) + queue_mode = _enum_value(request.queue_mode) + queue_recommended = queue_mode == "queued" or ( + queue_mode == "auto" + and (tier.allow_overnight or context.expected_duration == "long") + ) + + return RoutingDecision( + matched_tier_id=tier.id, + planner_model_id=planner_model.id if planner_model else None, + executor_model_id=executor_model.id, + reviewer_model_id=reviewer_model_id, + fallback_model_ids=fallback_model_ids, + reasons=reasons, + queue_recommended=queue_recommended, + should_plan=tier.should_plan and planner_model is not None, + should_review=should_review, + context=context, + ) + + def _availability_filter( + self, + candidate_models: list[ModelSpec], + reasons: list[RouteReason], + ) -> list[ModelSpec]: + available: list[ModelSpec] = [] + for model in candidate_models: + if model.runtime == ModelRuntime.AIRLLM and not self.system_snapshot.has_gpu: + fallback = self.config.routing_policy.fallbacks.when_gpu_unavailable + if fallback and fallback.executor_model_id != model.id: + reasons.append( + RouteReason( + stage="availability", + message=f"Skipped {model.id!r} because GPU is unavailable.", + ) + ) + continue + + health = self.providers.healthcheck(model) + if health.available: + available.append(model) + continue + reasons.append( + RouteReason( + stage="availability", + message=f"Skipped {model.id!r}: {'; '.join(health.reasons)}", + ) + ) + if not available and candidate_models: + reasons.append( + RouteReason( + stage="availability", + message=( + "No models passed provider availability checks. " + "Continuing with configured models for route explanation only." + ), + ) + ) + return candidate_models + if not available: + raise RuntimeError("No models remain after provider and environment filtering.") + return available + + def _resolve_tier( + self, + request: ExecutionRequest, + context, + reasons: list[RouteReason], + ) -> RoutingTier: + scored_tiers = sorted( + self.config.routing_policy.tiers, + key=lambda tier: (tier.match.score(context), tier.allow_overnight, tier.id), + reverse=True, + ) + best_tier = scored_tiers[0] + reasons.append( + RouteReason( + stage="heuristics", + message=( + f"Heuristics classified complexity={context.complexity}, " + f"expected_duration={context.expected_duration}, urgency={context.urgency}." + ), + ) + ) + + if ( + self.config.routing_policy.router.use_model_router + and best_tier.match.score(context) < 2 + and not request.model_override + ): + llm_suggestion = self._router_model_suggestion(request) + if llm_suggestion: + suggested_tier = next( + ( + tier + for tier in self.config.routing_policy.tiers + if tier.id == llm_suggestion + ), + None, + ) + if suggested_tier is not None: + best_tier = suggested_tier + reasons.append( + RouteReason( + stage="model-router", + message=f"Router model suggested tier {suggested_tier.id!r}.", + ) + ) + + if best_tier.allow_overnight and not self.settings.allow_overnight_jobs: + non_overnight = [ + tier for tier in self.config.routing_policy.tiers if not tier.allow_overnight + ] + if non_overnight: + best_tier = non_overnight[0] + reasons.append( + RouteReason( + stage="constraints", + message=( + f"Selected fallback tier {best_tier.id!r} " + "because overnight jobs are disabled." 
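+                            # non_overnight preserves policy order, so the first
+                            # non-overnight tier defined in the routing policy is chosen.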
+ ), + ) + ) + + reasons.append(RouteReason(stage="tier", message=f"Matched routing tier {best_tier.id!r}.")) + return best_tier + + def _resolve_stage_model( + self, + *, + stage: str, + preferred_model_id: str | None, + candidate_models: list[ModelSpec], + required_capabilities: list[str], + reasons: list[RouteReason], + ) -> ModelSpec | None: + if preferred_model_id: + preferred = next( + (model for model in candidate_models if model.id == preferred_model_id), None + ) + if preferred and preferred.supports_capabilities(required_capabilities): + reasons.append( + RouteReason( + stage=stage, + message=f"Selected preferred {stage} model {preferred.id!r}.", + ) + ) + return preferred + + for model in candidate_models: + if model.supports_capabilities(required_capabilities): + reasons.append( + RouteReason( + stage="fallback", + message=f"Selected {stage} fallback model {model.id!r}.", + ) + ) + return model + + if stage == "reviewer": + return None + raise RuntimeError( + f"No available {stage} model can satisfy capabilities {required_capabilities!r}." + ) + + def _fallback_chain( + self, + *, + preferred_model_id: str, + candidate_models: Iterable[ModelSpec], + required_capabilities: list[str], + ) -> list[str]: + fallback_ids: list[str] = [] + for model in candidate_models: + if model.id == preferred_model_id: + continue + if model.supports_capabilities(required_capabilities): + fallback_ids.append(model.id) + return fallback_ids + + def _router_model_suggestion(self, request: ExecutionRequest) -> str | None: + router_model = self.config.model_catalog.get_model( + self.config.routing_policy.router.model_id + ) + health = self.providers.healthcheck(router_model) + if not health.available: + return None + + provider = self.providers.provider_for_model(router_model) + prompt = ( + "Classify the request into one routing tier and return JSON only with a single key " + '"tier_id" whose value is one of: instant, standard, deep-work.\n\n' + f"Request:\n{request.user_prompt}" + ) + try: + response = provider.generate( + router_model, + ProviderRequest( + system_prompt="You are a strict router. 
Return JSON only.", + user_prompt=prompt, + temperature=0, + max_output_tokens=64, + timeout_seconds=min(request.timeout_seconds, 30), + ), + ) + except Exception: + return None + + try: + payload = json.loads(response.text) + except json.JSONDecodeError: + return None + tier_id = payload.get("tier_id") + if tier_id in {"instant", "standard", "deep-work"}: + return str(tier_id) + return None + + +def _enum_value(value: object) -> str: + if isinstance(value, Enum): + return str(value.value) + return str(value) diff --git a/src/lai/routing/heuristics.py b/src/lai/routing/heuristics.py new file mode 100644 index 0000000..e0044d2 --- /dev/null +++ b/src/lai/routing/heuristics.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from ..domain import ExecutionRequest, RouteContext + +DEEP_KEYWORDS = ( + "advanced", + "comprehensive", + "research", + "full", + "complete", + "detailed", + "analyze", + "architecture", + "plan", + "strategy", +) +URGENT_KEYWORDS = ("urgent", "asap", "immediately", "quick", "fast", "today") +HIGH_RISK_KEYWORDS = ("security", "medical", "legal", "finance", "exploit", "breach") +OVERNIGHT_KEYWORDS = ("overnight", "whole night", "background", "long-running") + + +def build_route_context(request: ExecutionRequest) -> RouteContext: + lowered = request.user_prompt.lower() + prompt_length = len(request.user_prompt) + + complexity = "low" + if prompt_length > 1800 or any(keyword in lowered for keyword in DEEP_KEYWORDS): + complexity = "high" + elif prompt_length > 600: + complexity = "medium" + + urgency = "low" + if any(keyword in lowered for keyword in URGENT_KEYWORDS): + urgency = "high" + elif "soon" in lowered: + urgency = "medium" + + cost_sensitivity = "low" + if any(keyword in lowered for keyword in ("cheap", "low cost", "fast", "small model")): + cost_sensitivity = "high" + elif "balanced" in lowered: + cost_sensitivity = "medium" + + safety_risk = "low" + if any(keyword in lowered for keyword in HIGH_RISK_KEYWORDS): + safety_risk = "high" + elif any(keyword in lowered for keyword in ("privacy", "compliance", "policy")): + safety_risk = "medium" + + expected_duration = "short" + overnight_requested = any(keyword in lowered for keyword in OVERNIGHT_KEYWORDS) + if overnight_requested or complexity == "high": + expected_duration = "long" + elif complexity == "medium": + expected_duration = "medium" + + required_capabilities = list(request.required_capabilities) + if complexity == "high": + required_capabilities.extend(["deep-reasoning", "long-form-generation"]) + elif "summary" in lowered or "summarize" in lowered: + required_capabilities.append("summarization") + + return RouteContext( + prompt_length=prompt_length, + complexity=complexity, + urgency=urgency, + cost_sensitivity=cost_sensitivity, + safety_risk=safety_risk, + expected_duration=expected_duration, + requires_high_quality=complexity == "high", + overnight_requested=overnight_requested, + required_capabilities=sorted(set(required_capabilities)), + model_override=request.model_override, + provider_override=request.provider_override, + ) diff --git a/src/lai/serialization.py b/src/lai/serialization.py new file mode 100644 index 0000000..6a3288e --- /dev/null +++ b/src/lai/serialization.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + + +def to_jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): to_jsonable(item) for key, item in value.items()} + if isinstance(value, list): 
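+        # Recurse into elements so nested Path and datetime values are converted as well.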
+ return [to_jsonable(item) for item in value] + if isinstance(value, tuple): + return [to_jsonable(item) for item in value] + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + return value + + +def dumps_pretty(data: Any) -> str: + return json.dumps(to_jsonable(data), indent=2, sort_keys=True, ensure_ascii=True) diff --git a/src/lai/settings.py b/src/lai/settings.py index 24ea9f2..f7962c4 100644 --- a/src/lai/settings.py +++ b/src/lai/settings.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from functools import cached_property from pathlib import Path -from pydantic import Field +from pydantic import Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict @@ -10,7 +12,7 @@ def _default_root_dir() -> Path: class Settings(BaseSettings): - """Application settings shared across the future platform services.""" + """Application settings shared across the platform services.""" model_config = SettingsConfigDict( env_file=".env", @@ -21,17 +23,43 @@ class Settings(BaseSettings): environment: str = "local" project_name: str = "LAI" root_dir: Path = Field(default_factory=_default_root_dir) + model_catalog: Path = Path("configs/models/catalog.yaml") routing_policy: Path = Path("configs/routing/policies.yaml") prompt_root: Path = Path("configs/prompts") + huggingface_cache_dir: Path = Path("data/cache/huggingface") airllm_shards_dir: Path = Path("data/models/airllm-shards") + raw_models_dir: Path = Path("data/models/raw") artifacts_dir: Path = Path("data/artifacts") + state_dir: Path = Path("data/state") + database_path: Path = Path("data/state/lai.db") logs_dir: Path = Path("logs") + worker_service_lock_path: Path = Path("data/state/worker-service.lock") + worker_service_stop_path: Path = Path("data/state/worker-service.stop") + worker_service_log_path: Path = Path("logs/worker-service.log") + allow_overnight_jobs: bool = True enable_gpu: bool = True max_router_tokens: int = 2048 + default_timeout_seconds: int = 120 + default_max_output_tokens: int = 1024 + default_temperature: float = 0.2 + queue_poll_interval_seconds: int = 5 + worker_idle_sleep_seconds: float = 2.0 + worker_service_poll_interval_seconds: float = 10.0 + worker_service_recovery_stale_after_seconds: int = 120 + worker_service_recovery_limit: int = 100 + job_heartbeat_interval_seconds: float = 30.0 + max_retry_attempts: int = 1 + stale_running_job_timeout_seconds: int = 900 + + hf_token: SecretStr | None = None + openai_api_key: SecretStr | None = None + anthropic_api_key: SecretStr | None = None + gemini_api_key: SecretStr | None = None + @cached_property def resolved_model_catalog(self) -> Path: return self.root_dir / self.model_catalog @@ -39,3 +67,67 @@ def resolved_model_catalog(self) -> Path: @cached_property def resolved_routing_policy(self) -> Path: return self.root_dir / self.routing_policy + + @cached_property + def resolved_prompt_root(self) -> Path: + return self.root_dir / self.prompt_root + + @cached_property + def resolved_huggingface_cache_dir(self) -> Path: + return self.root_dir / self.huggingface_cache_dir + + @cached_property + def resolved_airllm_shards_dir(self) -> Path: + return self.root_dir / self.airllm_shards_dir + + @cached_property + def resolved_raw_models_dir(self) -> Path: + return self.root_dir / self.raw_models_dir + + @cached_property + def resolved_artifacts_dir(self) -> Path: + return self.root_dir / self.artifacts_dir + + @cached_property + def resolved_state_dir(self) -> Path: + return self.root_dir / 
self.state_dir + + @cached_property + def resolved_database_path(self) -> Path: + return self.root_dir / self.database_path + + @cached_property + def resolved_logs_dir(self) -> Path: + return self.root_dir / self.logs_dir + + @cached_property + def resolved_worker_service_lock_path(self) -> Path: + return self.root_dir / self.worker_service_lock_path + + @cached_property + def resolved_worker_service_stop_path(self) -> Path: + return self.root_dir / self.worker_service_stop_path + + @cached_property + def resolved_worker_service_log_path(self) -> Path: + return self.root_dir / self.worker_service_log_path + + @staticmethod + def _secret_value(secret: SecretStr | None) -> str | None: + return secret.get_secret_value() if secret else None + + @property + def huggingface_token_value(self) -> str | None: + return self._secret_value(self.hf_token) + + @property + def openai_api_key_value(self) -> str | None: + return self._secret_value(self.openai_api_key) + + @property + def anthropic_api_key_value(self) -> str | None: + return self._secret_value(self.anthropic_api_key) + + @property + def gemini_api_key_value(self) -> str | None: + return self._secret_value(self.gemini_api_key) diff --git a/src/lai/smoke.py b/src/lai/smoke.py new file mode 100644 index 0000000..cb61590 --- /dev/null +++ b/src/lai/smoke.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal, Mapping + +from pydantic import BaseModel, Field + +from .domain import ModelSpec, ProviderRequest +from .providers import Provider, ProviderRegistry +from .settings import Settings +from .system import collect_system_snapshot + +SMOKE_PROMPT = "Reply with READY and nothing else." + + +class SmokeCheckResult(BaseModel): + provider_id: str + runtime: str + model_ref: str + mode: Literal["readiness", "live"] + status: Literal["ready", "passed", "blocked", "failed", "skipped"] + available: bool + healthy: bool + executed: bool = False + success: bool = False + message: str + reasons: list[str] = Field(default_factory=list) + capabilities: dict[str, Any] = Field(default_factory=dict) + duration_seconds: float | None = None + output_preview: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class SmokeSuiteResult(BaseModel): + suite_id: str + mode: Literal["readiness", "live"] + executed_at: str + prompt: str + passed: bool + total: int + success_count: int + blocked_count: int + failed_count: int + checks: list[SmokeCheckResult] + + +def run_smoke_suite( + settings: Settings, + *, + suite_id: Literal["providers", "local"], + live: bool = False, + include_airllm: bool = False, + prompt: str = SMOKE_PROMPT, + providers: Mapping[str, Provider] | None = None, +) -> SmokeSuiteResult: + snapshot = collect_system_snapshot( + settings.root_dir, + settings.resolved_huggingface_cache_dir, + settings.resolved_airllm_shards_dir, + enable_gpu=settings.enable_gpu, + ) + registry = ProviderRegistry(settings, snapshot, providers=providers) + checks: list[SmokeCheckResult] = [] + + for model in _suite_models(suite_id, include_airllm=include_airllm): + provider = registry.get(model.provider_id) + health = provider.healthcheck(model) + capabilities = provider.describe_capabilities() + mode = "live" if live else "readiness" + + if not live: + ready = health.available and health.healthy + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + 
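+                    # Readiness checks record provider health and capabilities without a live request.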
status="ready" if ready else "blocked", + available=health.available, + healthy=health.healthy, + success=ready, + message=( + "Provider is ready for live smoke execution." + if ready + else "; ".join(health.reasons) or "Provider is blocked." + ), + reasons=health.reasons, + capabilities=capabilities, + metadata={"provider_metadata": health.metadata}, + ) + ) + continue + + if not health.available: + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="blocked", + available=health.available, + healthy=health.healthy, + message="Live smoke blocked by provider healthcheck.", + reasons=health.reasons, + capabilities=capabilities, + metadata={"provider_metadata": health.metadata}, + ) + ) + continue + + try: + result = provider.generate( + model, + ProviderRequest( + user_prompt=prompt, + max_output_tokens=32, + temperature=0, + timeout_seconds=min(settings.default_timeout_seconds, 60), + ), + ) + except Exception as exc: + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="failed", + available=health.available, + healthy=health.healthy, + executed=True, + success=False, + message=f"Live smoke failed with {type(exc).__name__}.", + reasons=[str(exc)], + capabilities=capabilities, + metadata={"provider_metadata": health.metadata}, + ) + ) + continue + + preview = (result.text or "").strip() + preview = preview[:160] if preview else None + success = bool((result.text or "").strip()) + checks.append( + SmokeCheckResult( + provider_id=model.provider_id, + runtime=model.runtime, + model_ref=model.model_ref, + mode=mode, + status="passed" if success else "failed", + available=health.available, + healthy=health.healthy, + executed=True, + success=success, + message=( + f"Live smoke completed in {result.duration_seconds:.2f}s." + if success + else "Live smoke returned an empty response." 
+ ), + reasons=health.reasons, + capabilities=capabilities, + duration_seconds=result.duration_seconds, + output_preview=preview, + metadata={ + "provider_metadata": health.metadata, + "finish_reason": result.finish_reason, + "usage": result.usage.model_dump() if result.usage else None, + "sentinel_detected": "READY" in (result.text or "").upper(), + }, + ) + ) + + success_count = sum(1 for check in checks if check.success) + blocked_count = sum(1 for check in checks if check.status == "blocked") + failed_count = sum(1 for check in checks if check.status == "failed") + return SmokeSuiteResult( + suite_id=suite_id, + mode="live" if live else "readiness", + executed_at=_utcnow_iso(), + prompt=prompt, + passed=blocked_count == 0 and failed_count == 0 and success_count == len(checks), + total=len(checks), + success_count=success_count, + blocked_count=blocked_count, + failed_count=failed_count, + checks=checks, + ) + + +def save_smoke_result(results_dir: Path, result: SmokeSuiteResult) -> Path: + results_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = results_dir / f"smoke-{result.suite_id}-{result.mode}-{timestamp}.json" + path.write_text(result.model_dump_json(indent=2), encoding="utf-8") + return path + + +def load_smoke_result(path: Path) -> SmokeSuiteResult: + return SmokeSuiteResult.model_validate_json(path.read_text(encoding="utf-8")) + + +def resolve_smoke_result_path(results_dir: Path, result_id: str) -> Path | None: + candidate_name = Path(result_id).name + if candidate_name != result_id or not candidate_name.endswith(".json"): + return None + + candidate = (results_dir / candidate_name).resolve() + try: + candidate.relative_to(results_dir.resolve()) + except ValueError: + return None + + if not candidate.exists() or not candidate.is_file(): + return None + if not candidate.name.startswith("smoke-"): + return None + return candidate + + +def list_smoke_results( + results_dir: Path, + *, + suite_id: Literal["providers", "local"] | None = None, + limit: int = 10, +) -> list[Path]: + if not results_dir.exists(): + return [] + pattern = f"smoke-{suite_id}-*.json" if suite_id else "smoke-*.json" + candidates = sorted(results_dir.glob(pattern), reverse=True) + return candidates[:limit] + + +def latest_smoke_result( + results_dir: Path, + *, + suite_id: Literal["providers", "local"] | None = None, +) -> Path | None: + candidates = list_smoke_results(results_dir, suite_id=suite_id, limit=1) + return candidates[0] if candidates else None + + +def _suite_models( + suite_id: Literal["providers", "local"], + *, + include_airllm: bool, +) -> list[ModelSpec]: + if suite_id == "providers": + return [ + _build_model("openai", "openai", "gpt-5.4-mini"), + _build_model("anthropic", "anthropic", "claude-sonnet-4-20250514"), + _build_model("gemini", "gemini", "gemini-2.5-flash"), + ] + + models = [ + _build_model("transformers", "transformers", "sshleifer/tiny-gpt2"), + ] + if include_airllm: + models.append( + ModelSpec.model_validate( + { + "id": "airllm-smoke", + "role": "executor", + "runtime": "airllm", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "hardware": {"expected_disk_gb": 1}, + "runtime_hints": { + "allow_cpu_fallback": True, + "allow_layer_sharding": True, + }, + } + ) + ) + return models + + +def _build_model(provider_id: str, runtime: str, model_ref: str) -> ModelSpec: + return ModelSpec.model_validate( + { + "id": f"{provider_id}-smoke", + "role": "executor", + "runtime": runtime, + "model": model_ref, + } + ) 
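+        # Smoke model specs are built inline, so diagnostics do not depend on catalog entries.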
+ + +def _utcnow_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat() diff --git a/src/lai/system.py b/src/lai/system.py new file mode 100644 index 0000000..8a92879 --- /dev/null +++ b/src/lai/system.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import ctypes +import os +import platform +import shutil +import subprocess +import sys +from ctypes import Structure, byref, c_ulong, c_ulonglong, sizeof +from pathlib import Path + +from pydantic import BaseModel, Field + + +class SystemSnapshot(BaseModel): + has_gpu: bool + gpu_name: str | None = None + gpu_total_memory_gb: float | None = None + gpu_driver_version: str | None = None + total_ram_gb: float | None = None + available_ram_gb: float | None = None + platform_name: str = Field(default_factory=platform.platform) + python_version: str = Field(default_factory=platform.python_version) + free_disk_gb: dict[str, float] = Field(default_factory=dict) + + +def collect_system_snapshot(*paths: Path, enable_gpu: bool = True) -> SystemSnapshot: + free_disk_gb: dict[str, float] = {} + for path in paths: + free_disk_gb[str(path)] = round(available_disk_gb(path), 2) + has_gpu, gpu_name, gpu_total_memory_gb, gpu_driver_version = detect_gpu_details( + enable_gpu=enable_gpu + ) + total_ram_gb, available_ram_gb = detect_memory_gb() + return SystemSnapshot( + has_gpu=has_gpu, + gpu_name=gpu_name, + gpu_total_memory_gb=gpu_total_memory_gb, + gpu_driver_version=gpu_driver_version, + total_ram_gb=total_ram_gb, + available_ram_gb=available_ram_gb, + free_disk_gb=free_disk_gb, + ) + + +def detect_gpu(enable_gpu: bool = True) -> tuple[bool, str | None]: + has_gpu, gpu_name, _, _ = detect_gpu_details(enable_gpu=enable_gpu) + return has_gpu, gpu_name + + +def detect_gpu_details( + enable_gpu: bool = True, +) -> tuple[bool, str | None, float | None, str | None]: + if not enable_gpu: + return False, None, None, None + + try: + import torch + + if torch.cuda.is_available(): + properties = torch.cuda.get_device_properties(0) + total_memory_gb = round(properties.total_memory / (1024**3), 2) + return True, str(torch.cuda.get_device_name(0)), total_memory_gb, None + except Exception: + pass + + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=name,memory.total,driver_version", + "--format=csv,noheader,nounits", + ], + check=False, + capture_output=True, + text=True, + ) + except FileNotFoundError: + return False, None, None, None + + if result.returncode != 0: + return False, None, None, None + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + if not lines: + return False, None, None, None + + parts = [part.strip() for part in lines[0].split(",")] + gpu_name = parts[0] if parts else None + memory_total_gb = None + if len(parts) > 1: + try: + memory_total_gb = round(float(parts[1]) / 1024, 2) + except ValueError: + memory_total_gb = None + driver_version = parts[2] if len(parts) > 2 else None + return True, gpu_name, memory_total_gb, driver_version + + +def detect_memory_gb() -> tuple[float | None, float | None]: + if sys.platform == "win32": + return _detect_windows_memory_gb() + + if hasattr(os, "sysconf"): + try: + page_size = os.sysconf("SC_PAGE_SIZE") + total_pages = os.sysconf("SC_PHYS_PAGES") + available_pages = os.sysconf("SC_AVPHYS_PAGES") + except (OSError, ValueError): + return None, None + total_gb = round((page_size * total_pages) / (1024**3), 2) + available_gb = round((page_size * available_pages) / (1024**3), 2) + return total_gb, available_gb + + return None, None + + +def 
available_disk_gb(path: Path) -> float: + target = _nearest_existing_path(path) + usage = shutil.disk_usage(target) + return usage.free / (1024**3) + + +def _nearest_existing_path(path: Path) -> Path: + candidate = path + while not candidate.exists(): + if candidate.parent == candidate: + break + candidate = candidate.parent + return candidate + + +def _detect_windows_memory_gb() -> tuple[float | None, float | None]: + try: + + class MEMORYSTATUSEX(Structure): + _fields_ = [ + ("dwLength", c_ulong), + ("dwMemoryLoad", c_ulong), + ("ullTotalPhys", c_ulonglong), + ("ullAvailPhys", c_ulonglong), + ("ullTotalPageFile", c_ulonglong), + ("ullAvailPageFile", c_ulonglong), + ("ullTotalVirtual", c_ulonglong), + ("ullAvailVirtual", c_ulonglong), + ("ullAvailExtendedVirtual", c_ulonglong), + ] + + status = MEMORYSTATUSEX() + status.dwLength = sizeof(MEMORYSTATUSEX) + if not ctypes.windll.kernel32.GlobalMemoryStatusEx(byref(status)): # type: ignore[attr-defined] + return None, None + total_gb = round(status.ullTotalPhys / (1024**3), 2) + available_gb = round(status.ullAvailPhys / (1024**3), 2) + return total_gb, available_gb + except Exception: + return None, None diff --git a/src/lai/worker/__init__.py b/src/lai/worker/__init__.py new file mode 100644 index 0000000..0cb44e3 --- /dev/null +++ b/src/lai/worker/__init__.py @@ -0,0 +1,17 @@ +"""Dedicated worker service entrypoints and helpers.""" + +from .service import ( + WorkerServiceConfig, + WorkerServiceHost, + clear_worker_service_stop, + request_worker_service_stop, + worker_service_stop_requested, +) + +__all__ = [ + "WorkerServiceConfig", + "WorkerServiceHost", + "clear_worker_service_stop", + "request_worker_service_stop", + "worker_service_stop_requested", +] diff --git a/src/lai/worker/cli.py b/src/lai/worker/cli.py new file mode 100644 index 0000000..c5a5f0f --- /dev/null +++ b/src/lai/worker/cli.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import typer +from rich.console import Console +from rich.table import Table + +from ..application import create_application +from ..settings import Settings +from .service import WorkerServiceConfig, WorkerServiceHost, request_worker_service_stop + +app = typer.Typer(help="Dedicated local worker service entrypoint.", no_args_is_help=True) +console = Console() + + +@app.command("run") +def run_service( + poll_interval: float | None = typer.Option( + None, + min=0.1, + help="Seconds to wait between daemon polling cycles.", + ), + max_idle_cycles: int = typer.Option( + 1, + min=1, + help="Number of empty polls before a drain cycle becomes idle.", + ), + max_jobs_per_cycle: int | None = typer.Option( + None, + min=1, + help="Optional cap for jobs processed in one drain cycle.", + ), + max_cycles: int | None = typer.Option( + None, + min=1, + help="Optional cap for daemon cycles. 
Useful for tests or supervised runs.", + ), + worker_id: str = typer.Option( + "local-worker", + help="Persisted worker id used for service status and ownership.", + ), + replace_lock: bool = typer.Option( + False, + "--replace-lock", + help="Replace an existing lock file if you are certain the prior worker is gone.", + ), +) -> None: + """Run the dedicated long-lived local worker service.""" + settings = Settings() + config = WorkerServiceConfig( + worker_id=worker_id, + poll_interval_seconds=( + poll_interval + if poll_interval is not None + else settings.worker_service_poll_interval_seconds + ), + max_idle_cycles_per_drain=max_idle_cycles, + max_jobs_per_cycle=max_jobs_per_cycle, + max_cycles=max_cycles, + replace_existing_lock=replace_lock, + ) + host = WorkerServiceHost(settings=settings, config=config) + processed = host.run() + console.print(f"Worker service exited after processing {processed} job(s).") + + +@app.command("stop") +def stop_service() -> None: + """Request a graceful stop for the dedicated worker service.""" + path = request_worker_service_stop(Settings()) + console.print(f"Requested worker service stop via {path}") + + +@app.command("status") +def service_status( + worker_id: str = typer.Option( + "local-worker", + help="Persisted worker id to inspect.", + ), +) -> None: + """Show dedicated worker service status, lock, and stop-signal state.""" + settings = Settings() + application = create_application(settings=settings) + worker = application.orchestration.get_worker_state(worker_id) + + table = Table(title=f"Worker Service {worker_id}") + table.add_column("Field") + table.add_column("Value") + table.add_row("Status", worker.status) + table.add_row("Mode", worker.mode) + table.add_row("Current job", worker.current_job_id or "n/a") + table.add_row("Last job", worker.last_job_id or "n/a") + table.add_row("Processed jobs", str(worker.processed_jobs)) + table.add_row("Queued jobs", str(worker.queued_jobs)) + table.add_row("Heartbeat", worker.heartbeat_at.isoformat(timespec="seconds")) + table.add_row( + "Lock file", + "present" if settings.resolved_worker_service_lock_path.exists() else "missing", + ) + table.add_row( + "Stop signal", + "present" if settings.resolved_worker_service_stop_path.exists() else "missing", + ) + table.add_row("Log file", str(settings.resolved_worker_service_log_path)) + table.add_row("Last error", worker.last_error or "none") + console.print(table) + + +if __name__ == "__main__": + app() diff --git a/src/lai/worker/service.py b/src/lai/worker/service.py new file mode 100644 index 0000000..024c86d --- /dev/null +++ b/src/lai/worker/service.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +import json +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Mapping + +from ..application import LAIApplication, create_application +from ..domain import JobStatus, WorkerStatus, utcnow +from ..layout import ensure_runtime_directories +from ..providers import Provider +from ..settings import Settings + + +@dataclass +class WorkerServiceConfig: + worker_id: str = "local-worker" + poll_interval_seconds: float = 10.0 + max_idle_cycles_per_drain: int = 1 + max_jobs_per_cycle: int | None = None + max_cycles: int | None = None + replace_existing_lock: bool = False + + +def read_worker_service_lock_metadata(lock_path: Path) -> dict[str, object] | None: + if not lock_path.exists(): + return None + try: + return json.loads(lock_path.read_text(encoding="utf-8")) + except 
(OSError, json.JSONDecodeError): + return None + + +def request_worker_service_stop(settings: Settings | None = None) -> Path: + settings = settings or Settings() + stop_path = settings.resolved_worker_service_stop_path + stop_path.parent.mkdir(parents=True, exist_ok=True) + stop_path.write_text("stop requested\n", encoding="utf-8") + return stop_path + + +def clear_worker_service_stop(settings: Settings | None = None) -> None: + settings = settings or Settings() + settings.resolved_worker_service_stop_path.unlink(missing_ok=True) + + +def worker_service_stop_requested(settings: Settings | None = None) -> bool: + settings = settings or Settings() + return settings.resolved_worker_service_stop_path.exists() + + +class WorkerServiceLock: + def __init__(self, lock_path: Path, *, replace_existing: bool = False) -> None: + self.lock_path = lock_path + self.replace_existing = replace_existing + self.acquired = False + self.replaced_metadata: dict[str, object] | None = None + self.current_metadata: dict[str, object] | None = None + + def __enter__(self) -> "WorkerServiceLock": + self.acquire() + return self + + def __exit__(self, *_args: object) -> None: + self.release() + + def acquire(self) -> None: + self.lock_path.parent.mkdir(parents=True, exist_ok=True) + if self.replace_existing and self.lock_path.exists(): + self.replaced_metadata = read_worker_service_lock_metadata(self.lock_path) + self.lock_path.unlink() + try: + descriptor = os.open( + str(self.lock_path), + os.O_CREAT | os.O_EXCL | os.O_WRONLY, + ) + except FileExistsError as exc: + existing = read_worker_service_lock_metadata(self.lock_path) + details = "" + if existing: + pid = existing.get("pid") + created_at = existing.get("created_at") + if isinstance(created_at, (int, float)): + age_seconds = max(time.time() - created_at, 0.0) + details = f" Existing lock pid={pid}, age={age_seconds:.1f}s." + else: + details = f" Existing lock pid={pid}." + raise RuntimeError( + f"Worker service lock already exists at {self.lock_path}. " + "Use replace_existing_lock only when you are certain the prior worker is gone." 
+ f"{details}" + ) from exc + + payload = { + "pid": os.getpid(), + "created_at": time.time(), + } + with os.fdopen(descriptor, "w", encoding="utf-8") as handle: + json.dump(payload, handle) + self.acquired = True + self.current_metadata = payload + + def release(self) -> None: + if self.acquired: + self.lock_path.unlink(missing_ok=True) + self.acquired = False + + +class WorkerServiceHost: + def __init__( + self, + *, + settings: Settings | None = None, + providers: Mapping[str, Provider] | None = None, + config: WorkerServiceConfig | None = None, + application_factory: Callable[[], LAIApplication] | None = None, + sleep_fn: Callable[[float], None] = time.sleep, + logger: logging.Logger | None = None, + ) -> None: + self.settings = settings or Settings() + ensure_runtime_directories(self.settings.root_dir) + self.providers = providers + self.config = config or WorkerServiceConfig( + poll_interval_seconds=self.settings.worker_service_poll_interval_seconds + ) + self.application_factory = application_factory or ( + lambda: create_application(settings=self.settings, providers=self.providers) + ) + self.sleep_fn = sleep_fn + self.logger = logger or build_worker_service_logger( + self.settings.resolved_worker_service_log_path + ) + + def run(self) -> int: + total_processed = 0 + cycles = 0 + clear_worker_service_stop(self.settings) + + lock = WorkerServiceLock( + self.settings.resolved_worker_service_lock_path, + replace_existing=self.config.replace_existing_lock, + ) + with lock: + self.logger.info("LAI worker service started for worker_id=%s", self.config.worker_id) + app = self.application_factory() + previous_state = app.orchestration.get_worker_state(self.config.worker_id) + recovery_summary = self._recover_startup_jobs( + app=app, + previous_state=previous_state, + lock=lock, + ) + self._mark_service_state( + status=WorkerStatus.IDLE, + phase="starting", + processed_jobs=0, + metadata={ + "service_cycles": 0, + **({"last_recovery": recovery_summary} if recovery_summary else {}), + }, + ) + try: + while True: + if worker_service_stop_requested(self.settings): + self.logger.info("Worker service stop signal received.") + self._mark_service_state( + status=WorkerStatus.STOPPED, + phase="stopped", + processed_jobs=total_processed, + metadata={"service_cycles": cycles}, + ) + break + + app = self.application_factory() + cycle_recovery = self._recover_cycle_stale_jobs(app, cycle=cycles + 1) + processed = app.orchestration.run_worker_until_idle( + max_idle_cycles=self.config.max_idle_cycles_per_drain, + max_jobs=self.config.max_jobs_per_cycle, + requeue_running=False, + worker_id=self.config.worker_id, + mode="daemon", + ) + total_processed += processed + cycles += 1 + self.logger.info( + "Worker service cycle %s processed %s job(s), total=%s", + cycles, + processed, + total_processed, + ) + self._mark_service_state( + status=WorkerStatus.IDLE, + phase="sleeping", + processed_jobs=total_processed, + metadata={ + "service_cycles": cycles, + "last_cycle_processed": processed, + **( + {"last_recovery": cycle_recovery} + if cycle_recovery is not None + else {} + ), + }, + ) + + if self.config.max_cycles is not None and cycles >= self.config.max_cycles: + self.logger.info( + "Worker service reached max_cycles=%s", + self.config.max_cycles, + ) + break + + self.sleep_fn(self.config.poll_interval_seconds) + except Exception as exc: + self.logger.exception("Worker service failed.") + self._mark_service_state( + status=WorkerStatus.ERROR, + phase="error", + processed_jobs=total_processed, + 
last_error=f"{type(exc).__name__}: {exc}", + metadata={"service_cycles": cycles}, + ) + raise + finally: + clear_worker_service_stop(self.settings) + + self.logger.info("LAI worker service exited after processing %s job(s).", total_processed) + return total_processed + + def _recover_startup_jobs( + self, + *, + app: LAIApplication, + previous_state, + lock: WorkerServiceLock, + ) -> dict[str, object] | None: + running_jobs = app.job_store.list_jobs_by_status( + JobStatus.RUNNING, + limit=self.settings.worker_service_recovery_limit, + ) + if not running_jobs: + return None + + heartbeat_age_seconds = round( + max((utcnow() - previous_state.heartbeat_at).total_seconds(), 0.0), + 2, + ) + replaced_lock = lock.replaced_metadata is not None + previous_status = previous_state.status + stale_previous_worker = ( + previous_status in {WorkerStatus.RUNNING, WorkerStatus.ERROR} + and heartbeat_age_seconds >= self.settings.worker_service_recovery_stale_after_seconds + ) + if not replaced_lock and not stale_previous_worker: + return None + + reason = "startup-lock-replaced" if replaced_lock else "startup-stale-worker" + recovered = app.orchestration.recover_running_jobs( + limit=self.settings.worker_service_recovery_limit, + message="Recovered an interrupted running job during worker service startup.", + metadata={ + "recovery_reason": reason, + "previous_worker_status": previous_status, + "previous_worker_heartbeat_at": previous_state.heartbeat_at.isoformat(), + "previous_worker_current_job_id": previous_state.current_job_id, + "replaced_lock": replaced_lock, + "replaced_lock_pid": ( + lock.replaced_metadata.get("pid") if lock.replaced_metadata else None + ), + }, + ) + if not recovered: + return None + + summary = { + "reason": reason, + "count": len(recovered), + "job_ids": [job.id for job in recovered], + "at": utcnow().isoformat(), + "previous_worker_status": previous_status, + "previous_worker_heartbeat_at": previous_state.heartbeat_at.isoformat(), + "previous_worker_heartbeat_age_seconds": heartbeat_age_seconds, + "replaced_lock": replaced_lock, + "replaced_lock_pid": ( + lock.replaced_metadata.get("pid") if lock.replaced_metadata else None + ), + } + self.logger.warning( + "Worker service startup recovered %s interrupted running job(s) via %s.", + len(recovered), + reason, + ) + return summary + + def _recover_cycle_stale_jobs( + self, + app: LAIApplication, + *, + cycle: int, + ) -> dict[str, object] | None: + recovered = app.orchestration.recover_stale_jobs( + stale_after_seconds=self.settings.worker_service_recovery_stale_after_seconds, + limit=self.settings.worker_service_recovery_limit, + ) + if not recovered: + return None + + summary = { + "reason": "cycle-stale-recovery", + "count": len(recovered), + "job_ids": [job.id for job in recovered], + "at": utcnow().isoformat(), + "cycle": cycle, + "stale_after_seconds": self.settings.worker_service_recovery_stale_after_seconds, + } + self.logger.warning( + "Worker service cycle %s recovered %s stale running job(s).", + cycle, + len(recovered), + ) + return summary + + def _mark_service_state( + self, + *, + status: WorkerStatus, + phase: str, + processed_jobs: int, + metadata: dict[str, object] | None = None, + last_error: str | None = None, + ) -> None: + app = self.application_factory() + current_state = app.orchestration.get_worker_state(self.config.worker_id) + merged_metadata = dict(current_state.metadata) + merged_metadata.update(metadata or {}) + merged_metadata["phase"] = phase + if status != WorkerStatus.RUNNING: + for key in ( + 
"current_stage", + "current_stage_phase", + "current_stage_progress", + "current_stage_message", + "current_stage_updated_at", + "current_model_id", + "current_provider_id", + "elapsed_seconds", + ): + merged_metadata[key] = None + app.orchestration.update_worker_state( + worker_id=self.config.worker_id, + status=status, + mode="daemon", + current_job_id=None, + processed_jobs=processed_jobs, + queued_jobs=app.job_store.count_jobs_by_status(JobStatus.QUEUED), + started_at=current_state.started_at or current_state.heartbeat_at, + last_error=last_error, + metadata=merged_metadata, + ) + + +def build_worker_service_logger(log_path: Path) -> logging.Logger: + log_path.parent.mkdir(parents=True, exist_ok=True) + logger = logging.getLogger("lai.worker.service") + logger.setLevel(logging.INFO) + logger.propagate = False + + handler_paths = { + getattr(handler, "baseFilename", None) + for handler in logger.handlers + if hasattr(handler, "baseFilename") + } + if str(log_path) not in handler_paths: + handler = logging.FileHandler(log_path, encoding="utf-8") + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(message)s") + ) + logger.addHandler(handler) + + return logger diff --git a/src/lai/workstation.py b/src/lai/workstation.py new file mode 100644 index 0000000..91d4405 --- /dev/null +++ b/src/lai/workstation.py @@ -0,0 +1,677 @@ +from __future__ import annotations + +import sys +from datetime import datetime, timezone +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any, Literal, Mapping + +from pydantic import Field + +from .config import AppConfig +from .domain import LAIModel, ModelRuntime, ModelSpec +from .layout import runtime_directories +from .settings import Settings +from .system import SystemSnapshot, collect_system_snapshot + +SUPPORTED_PYTHON = ((3, 11), (3, 12)) + + +class WorkstationCheck(LAIModel): + id: str + status: Literal["pass", "warn", "fail"] + title: str + summary: str + remediation: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class WorkstationProfileReport(LAIModel): + profile_id: str + title: str + summary: str + ready: bool + warning_count: int + failure_count: int + checks: list[WorkstationCheck] + + +class WorkstationReadinessReport(LAIModel): + executed_at: str + system: dict[str, Any] + profiles: list[WorkstationProfileReport] + + def profile(self, profile_id: str) -> WorkstationProfileReport: + for profile in self.profiles: + if profile.profile_id == profile_id: + return profile + raise KeyError(f"Unknown workstation profile: {profile_id}") + + +def build_workstation_readiness( + settings: Settings, + config: AppConfig, + *, + snapshot: SystemSnapshot | None = None, + dependency_versions: Mapping[str, str | None] | None = None, +) -> WorkstationReadinessReport: + snapshot = snapshot or collect_system_snapshot( + *runtime_directories(settings.root_dir), enable_gpu=settings.enable_gpu + ) + dependencies = dict(dependency_versions or _detect_dependency_versions()) + + local_models = [ + model + for model in config.model_catalog.enabled_models() + if model.runtime in {ModelRuntime.TRANSFORMERS, ModelRuntime.AIRLLM} + ] + transformer_models = [ + model for model in local_models if model.runtime == ModelRuntime.TRANSFORMERS + ] + airllm_models = [model for model in local_models if model.runtime == ModelRuntime.AIRLLM] + + profiles = [ + _build_local_profile( + settings, + config, + snapshot, + dependencies, + local_models=local_models, + 
transformer_models=transformer_models, + airllm_models=airllm_models, + ), + ] + if airllm_models: + profiles.append( + _build_airllm_profile( + settings, + config, + snapshot, + dependencies, + airllm_models=airllm_models, + ) + ) + + return WorkstationReadinessReport( + executed_at=_utcnow_iso(), + system={ + "platform": snapshot.platform_name, + "python_version": snapshot.python_version, + "gpu_name": snapshot.gpu_name, + "gpu_total_memory_gb": snapshot.gpu_total_memory_gb, + "gpu_driver_version": snapshot.gpu_driver_version, + "total_ram_gb": snapshot.total_ram_gb, + "available_ram_gb": snapshot.available_ram_gb, + "catalog_disk_headroom_gb": config.model_catalog.defaults.require_disk_headroom_gb, + "huggingface_cache_dir": str(settings.resolved_huggingface_cache_dir), + "airllm_shards_dir": str(settings.resolved_airllm_shards_dir), + "paths": { + str(path): snapshot.free_disk_gb.get(str(path)) + for path in runtime_directories(settings.root_dir) + }, + "local_model_ids": [model.id for model in local_models], + "airllm_model_ids": [model.id for model in airllm_models], + }, + profiles=profiles, + ) + + +def _build_local_profile( + settings: Settings, + config: AppConfig, + snapshot: SystemSnapshot, + dependencies: Mapping[str, str | None], + *, + local_models: list[ModelSpec], + transformer_models: list[ModelSpec], + airllm_models: list[ModelSpec], +) -> WorkstationProfileReport: + checks = [ + _python_version_check(), + _path_headroom_check( + check_id="runtime-paths", + title="Runtime paths and disk headroom", + path=settings.resolved_huggingface_cache_dir, + free_disk_gb=snapshot.free_disk_gb.get( + str(settings.resolved_huggingface_cache_dir) + ), + required_free_gb=min(config.model_catalog.defaults.require_disk_headroom_gb, 80), + remediation=( + "Move the Hugging Face cache to a drive with more free space before " + "large local downloads." + ), + ), + _dependency_stack_check( + check_id="local-dependencies", + title="Local inference dependency stack", + required_dependencies={ + "torch": dependencies.get("torch"), + "transformers": dependencies.get("transformers"), + "accelerate": dependencies.get("accelerate"), + "huggingface-hub": dependencies.get("huggingface-hub"), + }, + remediation="Install the local extra with `python -m pip install -e .[local]`.", + ), + ] + + if local_models: + max_recommended_ram = max( + (model.hardware.recommended_ram_gb or 0 for model in local_models), + default=0, + ) + checks.append( + _ram_check( + check_id="local-ram", + title="System RAM for local runtimes", + total_ram_gb=snapshot.total_ram_gb, + recommended_ram_gb=max_recommended_ram or None, + remediation=( + "Close other heavy applications or move larger execution stages " + "to remote providers when RAM is constrained." 
+ ), + ) + ) + + if transformer_models or airllm_models: + checks.append( + _huggingface_token_check( + has_token=bool(settings.huggingface_token_value), + local_models=local_models, + ) + ) + + return _build_profile( + profile_id="local", + title="Local runtime baseline", + checks=checks, + ) + + +def _build_airllm_profile( + settings: Settings, + config: AppConfig, + snapshot: SystemSnapshot, + dependencies: Mapping[str, str | None], + *, + airllm_models: list[ModelSpec], +) -> WorkstationProfileReport: + heaviest_model = max( + airllm_models, + key=lambda model: ( + model.hardware.expected_disk_gb or 0, + model.hardware.recommended_ram_gb or 0, + model.hardware.recommended_vram_gb or 0, + ), + ) + required_free_disk_gb = max( + heaviest_model.hardware.expected_disk_gb or 0, + config.model_catalog.defaults.require_disk_headroom_gb, + ) + checks = [ + _python_version_check(), + _dependency_stack_check( + check_id="airllm-dependencies", + title="AirLLM dependency stack", + required_dependencies={ + "torch": dependencies.get("torch"), + "transformers": dependencies.get("transformers"), + "accelerate": dependencies.get("accelerate"), + "huggingface-hub": dependencies.get("huggingface-hub"), + "airllm": dependencies.get("airllm"), + }, + remediation=( + "Install the heavy local extras with `python -m pip install -e " + ".[local]` and verify AirLLM imports cleanly." + ), + ), + _optional_dependency_check( + check_id="bitsandbytes", + title="Optional bitsandbytes compression", + dependency_name="bitsandbytes", + dependency_version=dependencies.get("bitsandbytes"), + remediation=( + "Install `bitsandbytes` if you want lower-memory quantized local " + "experiments." + ), + ), + _huggingface_token_check( + has_token=bool(settings.huggingface_token_value), + local_models=airllm_models, + ), + _gpu_airllm_check(snapshot=snapshot, model=heaviest_model), + _vram_check(snapshot=snapshot, model=heaviest_model), + _ram_check( + check_id="airllm-ram", + title="System RAM for AirLLM orchestration", + total_ram_gb=snapshot.total_ram_gb, + recommended_ram_gb=heaviest_model.hardware.recommended_ram_gb, + remediation=( + "Use a workstation with more system RAM, or avoid the heaviest " + "local executor until RAM is expanded." + ), + ), + _path_headroom_check( + check_id="airllm-shards-disk", + title="AirLLM shard storage", + path=settings.resolved_airllm_shards_dir, + free_disk_gb=snapshot.free_disk_gb.get(str(settings.resolved_airllm_shards_dir)), + required_free_gb=required_free_disk_gb, + remediation=( + "Move `LAI_AIRLLM_SHARDS_DIR` to a larger drive or clear space before preparing " + f"{heaviest_model.id}." + ), + metadata={ + "model_id": heaviest_model.id, + "expected_disk_gb": heaviest_model.hardware.expected_disk_gb, + }, + ), + _path_headroom_check( + check_id="huggingface-cache-disk", + title="Hugging Face cache staging space", + path=settings.resolved_huggingface_cache_dir, + free_disk_gb=snapshot.free_disk_gb.get( + str(settings.resolved_huggingface_cache_dir) + ), + required_free_gb=min(required_free_disk_gb, 160), + remediation=( + "Move `LAI_HUGGINGFACE_CACHE_DIR` to a drive with more free space " + "for large model downloads." 
+ ), + metadata={"model_id": heaviest_model.id}, + ), + ] + + return _build_profile( + profile_id="airllm", + title="AirLLM heavy local executor", + checks=checks, + ) + + +def _python_version_check() -> WorkstationCheck: + minimum, maximum = SUPPORTED_PYTHON + current = sys.version_info[:2] + if minimum <= current <= maximum: + return WorkstationCheck( + id="python-version", + status="pass", + title="Python runtime", + summary=( + f"Python {platform_version()} is within the supported range for LAI " + "local tooling." + ), + metadata={"python_version": platform_version()}, + ) + + return WorkstationCheck( + id="python-version", + status="fail", + title="Python runtime", + summary=( + f"Python {platform_version()} is outside the supported range {minimum[0]}.{minimum[1]} " + f"to {maximum[0]}.{maximum[1]}." + ), + remediation=( + "Create a Python 3.11 virtual environment before running local model " + "workflows." + ), + metadata={"python_version": platform_version()}, + ) + + +def _dependency_stack_check( + *, + check_id: str, + title: str, + required_dependencies: Mapping[str, str | None], + remediation: str, +) -> WorkstationCheck: + missing = [ + name for name, package_version in required_dependencies.items() if not package_version + ] + if not missing: + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary="All required packages are installed.", + metadata={"versions": required_dependencies}, + ) + + return WorkstationCheck( + id=check_id, + status="fail", + title=title, + summary=f"Missing required packages: {', '.join(missing)}.", + remediation=remediation, + metadata={"versions": required_dependencies}, + ) + + +def _optional_dependency_check( + *, + check_id: str, + title: str, + dependency_name: str, + dependency_version: str | None, + remediation: str, +) -> WorkstationCheck: + if dependency_version: + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary=f"{dependency_name} {dependency_version} is installed.", + metadata={"version": dependency_version}, + ) + + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary=f"{dependency_name} is not installed.", + remediation=remediation, + ) + + +def _huggingface_token_check( + *, + has_token: bool, + local_models: list[ModelSpec], +) -> WorkstationCheck: + gated_models = [model.id for model in local_models if model.model_ref.startswith("meta-llama/")] + if not gated_models: + return WorkstationCheck( + id="huggingface-token", + status="pass", + title="Hugging Face credentials", + summary="No gated local models currently require a Hugging Face token.", + ) + + if has_token: + return WorkstationCheck( + id="huggingface-token", + status="pass", + title="Hugging Face credentials", + summary="A Hugging Face token is configured for gated local models.", + metadata={"model_ids": gated_models}, + ) + + return WorkstationCheck( + id="huggingface-token", + status="fail", + title="Hugging Face credentials", + summary=f"Gated local models need `LAI_HF_TOKEN`: {', '.join(gated_models)}.", + remediation=( + "Set `LAI_HF_TOKEN` in `.env` before downloading or running gated local " + "models." 
+ ), + metadata={"model_ids": gated_models}, + ) + + +def _gpu_airllm_check(*, snapshot: SystemSnapshot, model: ModelSpec) -> WorkstationCheck: + if snapshot.has_gpu: + return WorkstationCheck( + id="airllm-gpu", + status="pass", + title="GPU availability", + summary=f"Detected GPU {snapshot.gpu_name or 'unknown'} for AirLLM workloads.", + metadata={ + "gpu_name": snapshot.gpu_name, + "gpu_total_memory_gb": snapshot.gpu_total_memory_gb, + "gpu_driver_version": snapshot.gpu_driver_version, + }, + ) + + if model.runtime_hints.allow_cpu_fallback: + return WorkstationCheck( + id="airllm-gpu", + status="warn", + title="GPU availability", + summary=( + "No GPU detected. AirLLM can still fall back to CPU for this model, but expect " + "very slow overnight execution." + ), + remediation=( + "Use a CUDA-capable NVIDIA GPU if you want practical turnaround " + "for the large local executor." + ), + ) + + return WorkstationCheck( + id="airllm-gpu", + status="fail", + title="GPU availability", + summary="No GPU detected and CPU fallback is disabled for the selected AirLLM model.", + remediation=( + "Enable a CUDA-capable GPU or switch the model policy to a provider " + "that supports CPU fallback." + ), + ) + + +def _vram_check(*, snapshot: SystemSnapshot, model: ModelSpec) -> WorkstationCheck: + minimum_vram_gb = model.hardware.minimum_vram_gb + recommended_vram_gb = model.hardware.recommended_vram_gb + gpu_vram = snapshot.gpu_total_memory_gb + + if not snapshot.has_gpu: + return WorkstationCheck( + id="airllm-vram", + status="warn", + title="GPU VRAM budget", + summary="VRAM checks are skipped because no GPU is available.", + ) + + if gpu_vram is None: + return WorkstationCheck( + id="airllm-vram", + status="warn", + title="GPU VRAM budget", + summary="GPU detected, but VRAM could not be measured automatically.", + remediation=( + "Check `nvidia-smi` manually to confirm the GPU can host the " + "AirLLM working set." + ), + ) + + if minimum_vram_gb and gpu_vram < minimum_vram_gb: + return WorkstationCheck( + id="airllm-vram", + status="fail", + title="GPU VRAM budget", + summary=( + f"Detected {gpu_vram:.2f} GB VRAM, below the minimum " + f"{minimum_vram_gb} GB for {model.id}." + ), + remediation=( + "Use a larger GPU or route the executor to a remote provider " + "until more VRAM is available." + ), + metadata={"gpu_total_memory_gb": gpu_vram, "minimum_vram_gb": minimum_vram_gb}, + ) + + if recommended_vram_gb and gpu_vram < recommended_vram_gb: + return WorkstationCheck( + id="airllm-vram", + status="warn", + title="GPU VRAM budget", + summary=( + f"Detected {gpu_vram:.2f} GB VRAM, below the recommended " + f"{recommended_vram_gb} GB for smoother AirLLM execution." + ), + remediation=( + "Expect more disk sharding and slower layer movement until a " + "larger GPU is available." 
+ ), + metadata={"gpu_total_memory_gb": gpu_vram, "recommended_vram_gb": recommended_vram_gb}, + ) + + return WorkstationCheck( + id="airllm-vram", + status="pass", + title="GPU VRAM budget", + summary=f"Detected {gpu_vram:.2f} GB VRAM, which satisfies the configured AirLLM target.", + metadata={ + "gpu_total_memory_gb": gpu_vram, + "recommended_vram_gb": recommended_vram_gb, + "minimum_vram_gb": minimum_vram_gb, + }, + ) + + +def _ram_check( + *, + check_id: str, + title: str, + total_ram_gb: float | None, + recommended_ram_gb: int | None, + remediation: str, +) -> WorkstationCheck: + if recommended_ram_gb is None: + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary="No explicit RAM recommendation is configured for the current model set.", + ) + + if total_ram_gb is None: + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary="System RAM could not be measured automatically.", + remediation=( + "Check the host RAM budget manually before starting long-running " + "local jobs." + ), + ) + + if total_ram_gb < recommended_ram_gb: + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary=( + f"Detected {total_ram_gb:.2f} GB RAM, below the recommended " + f"{recommended_ram_gb} GB." + ), + remediation=remediation, + metadata={"total_ram_gb": total_ram_gb, "recommended_ram_gb": recommended_ram_gb}, + ) + + return WorkstationCheck( + id=check_id, + status="pass", + title=title, + summary=f"Detected {total_ram_gb:.2f} GB RAM, meeting the configured recommendation.", + metadata={"total_ram_gb": total_ram_gb, "recommended_ram_gb": recommended_ram_gb}, + ) + + +def _path_headroom_check( + *, + check_id: str, + title: str, + path: Path, + free_disk_gb: float | None, + required_free_gb: int, + remediation: str, + metadata: Mapping[str, Any] | None = None, +) -> WorkstationCheck: + if free_disk_gb is None: + return WorkstationCheck( + id=check_id, + status="warn", + title=title, + summary=f"Could not measure free disk space for {path}.", + remediation=remediation, + metadata={"path": str(path), **(metadata or {})}, + ) + + status: Literal["pass", "warn", "fail"] + if free_disk_gb >= required_free_gb: + status = "pass" + summary = ( + f"{path} has {free_disk_gb:.2f} GB free, meeting the {required_free_gb} GB target." + ) + elif free_disk_gb >= required_free_gb * 0.7: + status = "warn" + summary = ( + f"{path} has {free_disk_gb:.2f} GB free, below the target " + f"{required_free_gb} GB but still close enough for careful staging." + ) + else: + status = "fail" + summary = ( + f"{path} has {free_disk_gb:.2f} GB free, well below the target {required_free_gb} GB." + ) + + return WorkstationCheck( + id=check_id, + status=status, + title=title, + summary=summary, + remediation=None if status == "pass" else remediation, + metadata={ + "path": str(path), + "free_disk_gb": free_disk_gb, + "required_free_gb": required_free_gb, + **(metadata or {}), + }, + ) + + +def _build_profile( + *, + profile_id: str, + title: str, + checks: list[WorkstationCheck], +) -> WorkstationProfileReport: + failure_count = sum(1 for check in checks if check.status == "fail") + warning_count = sum(1 for check in checks if check.status == "warn") + ready = failure_count == 0 + if ready and warning_count == 0: + summary = "Ready for this profile with no blocking issues detected." + elif ready: + summary = f"Ready with {warning_count} warning(s)." + else: + summary = f"Blocked by {failure_count} failing check(s) and {warning_count} warning(s)." 
+ + return WorkstationProfileReport( + profile_id=profile_id, + title=title, + summary=summary, + ready=ready, + warning_count=warning_count, + failure_count=failure_count, + checks=checks, + ) + + +def _detect_dependency_versions() -> dict[str, str | None]: + return { + "torch": _package_version("torch"), + "transformers": _package_version("transformers"), + "accelerate": _package_version("accelerate"), + "huggingface-hub": _package_version("huggingface-hub"), + "airllm": _package_version("airllm"), + "bitsandbytes": _package_version("bitsandbytes"), + } + + +def _package_version(package_name: str) -> str | None: + try: + return version(package_name) + except PackageNotFoundError: + return None + + +def platform_version() -> str: + return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + +def _utcnow_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..60fee60 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Test package for LAI. diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..94b7195 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +@pytest.fixture() +def repo_root() -> Path: + return REPO_ROOT diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..998cf16 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Any + +from lai.domain import ExecutionResult, FinishReason, ModelSpec, ProviderHealth, ProviderRequest +from lai.providers.base import ProgressCallback, Provider +from lai.settings import Settings +from lai.system import SystemSnapshot + + +class FakeProvider(Provider): + def __init__( + self, + settings: Settings, + system_snapshot: SystemSnapshot, + provider_id: str, + *, + available: bool = True, + ) -> None: + super().__init__(settings, system_snapshot) + self.provider_id = provider_id + self._available = available + + def healthcheck(self, model: ModelSpec) -> ProviderHealth: + if self._available: + return ProviderHealth(provider_id=self.provider_id, available=True, healthy=True) + return ProviderHealth( + provider_id=self.provider_id, + available=False, + healthy=False, + reasons=[f"{self.provider_id} unavailable in test harness"], + ) + + def generate( + self, + model: ModelSpec, + request: ProviderRequest, + *, + progress_callback: ProgressCallback | None = None, + ) -> ExecutionResult: + stage = str(request.metadata.get("stage", "executor")) + self._emit_progress( + progress_callback, + phase="prepare", + message=f"{self.provider_id} preparing {stage} stage.", + progress=0.2, + stage=stage, + ) + self._emit_progress( + progress_callback, + phase="generate", + message=f"{self.provider_id} generating {stage} stage output.", + progress=0.7, + stage=stage, + ) + if stage == "planner": + text = "plan: inspect, execute, verify" + elif stage == "reviewer": + text = "review: output looks good" + elif stage == "executor": + text = f"executed by {model.id}: {request.user_prompt}" + elif model.role == "router": + text = '{"tier_id": "standard"}' + else: + text = f"generated by {model.id}: {request.user_prompt}" + self._emit_progress( + progress_callback, + phase="finalize", + message=f"{self.provider_id} finalized {stage} stage output.", + progress=0.95, + stage=stage, + ) + return 
ExecutionResult( + text=text, + finish_reason=FinishReason.STOP, + provider_id=self.provider_id, + model_id=model.id, + duration_seconds=0.01, + raw={"stage": stage}, + stage=stage, + ) + + def describe_capabilities(self) -> dict[str, Any]: + return {"test_provider": True} diff --git a/tests/integration/test_orchestration.py b/tests/integration/test_orchestration.py new file mode 100644 index 0000000..1da621f --- /dev/null +++ b/tests/integration/test_orchestration.py @@ -0,0 +1,149 @@ +from pathlib import Path + +from lai.application import create_application +from lai.domain import ExecutionRequest, JobStatus, QueueMode +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _make_application(tmp_path, repo_root): + settings = Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + enable_gpu=True, + ) + snapshot = collect_system_snapshot( + settings.root_dir, + enable_gpu=False, + ) + providers = { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + return create_application(settings=settings, providers=providers) + + +def test_inline_execution_persists_artifacts(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + assert job.status == JobStatus.SUCCEEDED + assert job.result is not None + final_output = Path(app.settings.resolved_artifacts_dir) / job.id / "final_output.txt" + assert final_output.exists() + assert "executed by" in final_output.read_text(encoding="utf-8") + assert job.stage_events + assert any( + event.stage == "executor" and event.event_type == "completed" + for event in job.stage_events + ) + assert any( + event.stage == "executor" and event.event_type == "progress" + for event in job.stage_events + ) + assert any( + event.stage == "job" and event.event_type == "succeeded" + for event in job.stage_events + ) + + +def test_queued_job_survives_restart(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Create a comprehensive advanced architecture for this system.", + queue_mode=QueueMode.QUEUED, + ) + ) + + restarted = _make_application(tmp_path, repo_root) + processed = restarted.orchestration.run_worker(once=True) + persisted = restarted.job_store.get_job(job.id) + + assert processed == 1 + assert persisted is not None + assert persisted.status == JobStatus.SUCCEEDED + assert persisted.result is not None + + +def test_replay_job_creates_a_new_persisted_request(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + source_job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a structured review of the platform design.", + queue_mode=QueueMode.INLINE, + ) + ) + + replayed_job = app.orchestration.replay_job(source_job.id, queue_mode=QueueMode.QUEUED) + + assert replayed_job.id != source_job.id + assert replayed_job.status == JobStatus.QUEUED + assert replayed_job.request.user_prompt == source_job.request.user_prompt + assert 
replayed_job.request.metadata["replayed_from_job_id"] == source_job.id + assert any( + event.stage == "job" and event.event_type == "replayed" + for event in replayed_job.stage_events + ) + + +def test_worker_batch_processes_only_requested_number_of_jobs(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + first_job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="First queued system design pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + second_job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Second queued system design pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + + processed = app.orchestration.run_worker_batch(max_jobs=1) + refreshed_first = app.job_store.get_job(first_job.id) + refreshed_second = app.job_store.get_job(second_job.id) + + assert len(processed) == 1 + assert refreshed_first is not None + assert refreshed_second is not None + assert [refreshed_first.status, refreshed_second.status].count(JobStatus.SUCCEEDED) == 1 + assert [refreshed_first.status, refreshed_second.status].count(JobStatus.QUEUED) == 1 + + +def test_worker_until_idle_drains_queue_and_updates_worker_state(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="First overnight architecture request.", + queue_mode=QueueMode.QUEUED, + ) + ) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Second overnight architecture request.", + queue_mode=QueueMode.QUEUED, + ) + ) + + processed = app.orchestration.run_worker_until_idle(max_idle_cycles=1) + worker = app.orchestration.get_worker_state() + + assert processed == 2 + assert worker.status == "idle" + assert worker.mode == "until-idle" + assert worker.processed_jobs == 2 + assert worker.queued_jobs == 0 diff --git a/tests/live/test_local_smoke.py b/tests/live/test_local_smoke.py new file mode 100644 index 0000000..b4522c8 --- /dev/null +++ b/tests/live/test_local_smoke.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import os + +import pytest +from pydantic import SecretStr + +from lai.domain import ModelSpec, ProviderRequest +from lai.providers.implementations import AirLLMProvider, TransformersProvider +from lai.settings import Settings +from lai.system import collect_system_snapshot + + +def _settings(tmp_path, **overrides) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=tmp_path / "unused-models.yaml", + routing_policy=tmp_path / "unused-routing.yaml", + **overrides, + ) + + +def test_transformers_live_smoke(tmp_path) -> None: + if os.getenv("LAI_RUN_LOCAL_SMOKE", "").lower() != "true": + pytest.skip("Set LAI_RUN_LOCAL_SMOKE=true to run local model smoke tests.") + pytest.importorskip("transformers") + + settings = _settings(tmp_path) + provider = TransformersProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "transformers-smoke", + "role": "executor", + "runtime": "transformers", + "model": "sshleifer/tiny-gpt2", + "capabilities": ["summarization"], + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY.", max_output_tokens=12), + ) + assert result.text + + +def test_airllm_manual_smoke(tmp_path) -> None: + if os.getenv("LAI_RUN_AIRLLM_SMOKE", "").lower() != "true": + pytest.skip("Set LAI_RUN_AIRLLM_SMOKE=true to run the AirLLM smoke test.") + pytest.importorskip("airllm") + + hf_token = os.getenv("LAI_HF_TOKEN") + if not hf_token: + 
pytest.skip("LAI_HF_TOKEN is required for the AirLLM smoke test.") + + model_name = os.getenv("LAI_AIRLLM_SMOKE_MODEL", "meta-llama/Llama-3.1-8B-Instruct") + settings = _settings(tmp_path, hf_token=SecretStr(hf_token)) + provider = AirLLMProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=True)) + model = ModelSpec.model_validate( + { + "id": "airllm-smoke", + "role": "executor", + "runtime": "airllm", + "repo_id": model_name, + "hardware": {"expected_disk_gb": 1}, + "runtime_hints": {"allow_cpu_fallback": True, "allow_layer_sharding": True}, + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY.", max_output_tokens=12), + ) + assert result.text diff --git a/tests/live/test_provider_smoke.py b/tests/live/test_provider_smoke.py new file mode 100644 index 0000000..dcd34c0 --- /dev/null +++ b/tests/live/test_provider_smoke.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import os + +import pytest +from pydantic import SecretStr + +from lai.domain import ModelSpec, ProviderRequest +from lai.providers.implementations import AnthropicProvider, GeminiProvider, OpenAIProvider +from lai.settings import Settings +from lai.system import collect_system_snapshot + + +def _settings(tmp_path, **overrides) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=tmp_path / "unused-models.yaml", + routing_policy=tmp_path / "unused-routing.yaml", + **overrides, + ) + + +def test_openai_live_smoke(tmp_path) -> None: + api_key = os.getenv("LAI_OPENAI_API_KEY") + if not api_key: + pytest.skip("LAI_OPENAI_API_KEY is not configured.") + pytest.importorskip("openai") + + settings = _settings(tmp_path, openai_api_key=SecretStr(api_key)) + provider = OpenAIProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "openai-smoke", + "role": "executor", + "runtime": "openai", + "model": "gpt-5.4-mini", + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY and nothing else.", max_output_tokens=32), + ) + assert "READY" in result.text.upper() + + +def test_anthropic_live_smoke(tmp_path) -> None: + api_key = os.getenv("LAI_ANTHROPIC_API_KEY") + if not api_key: + pytest.skip("LAI_ANTHROPIC_API_KEY is not configured.") + pytest.importorskip("anthropic") + + settings = _settings(tmp_path, anthropic_api_key=SecretStr(api_key)) + provider = AnthropicProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "anthropic-smoke", + "role": "executor", + "runtime": "anthropic", + "model": "claude-sonnet-4-20250514", + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY and nothing else.", max_output_tokens=32), + ) + assert "READY" in result.text.upper() + + +def test_gemini_live_smoke(tmp_path) -> None: + api_key = os.getenv("LAI_GEMINI_API_KEY") + if not api_key: + pytest.skip("LAI_GEMINI_API_KEY is not configured.") + pytest.importorskip("google.genai") + + settings = _settings(tmp_path, gemini_api_key=SecretStr(api_key)) + provider = GeminiProvider(settings, collect_system_snapshot(tmp_path, enable_gpu=False)) + model = ModelSpec.model_validate( + { + "id": "gemini-smoke", + "role": "executor", + "runtime": "gemini", + "model": "gemini-2.5-flash", + } + ) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Reply with READY and nothing else.", max_output_tokens=32), + ) + assert "READY" in result.text.upper() diff --git 
a/tests/unit/test_api.py b/tests/unit/test_api.py new file mode 100644 index 0000000..7c650ee --- /dev/null +++ b/tests/unit/test_api.py @@ -0,0 +1,360 @@ +from datetime import timedelta +from uuid import uuid4 + +from fastapi.testclient import TestClient + +from lai.api import create_api +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode, StageEventRecord, utcnow +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_api_exposes_expected_routes() -> None: + app = create_api() + routes = {route.path for route in app.routes} + + assert "/" in routes + assert "/dashboard" in routes + assert "/health" in routes + assert "/models" in routes + assert "/workstation/readiness" in routes + assert "/smoke/latest" in routes + assert "/smoke/results" in routes + assert "/smoke/results/{result_id}" in routes + assert "/smoke/run" in routes + assert "/worker/status" in routes + assert "/route/explain" in routes + assert "/jobs" in routes + assert "/jobs/stale" in routes + assert "/jobs/recover" in routes + assert "/jobs/{job_id}" in routes + assert "/jobs/{job_id}/timeline" in routes + assert "/jobs/{job_id}/execution" in routes + assert "/jobs/{job_id}/replay" in routes + assert "/jobs/{job_id}/artifacts" in routes + assert "/jobs/{job_id}/artifacts/{artifact_id}" in routes + assert "/jobs/{job_id}/cancel" in routes + assert "/worker/run" in routes + + +def test_dashboard_routes_serve_html_and_assets(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + client = TestClient(create_api(settings=settings)) + + root_response = client.get("/", follow_redirects=False) + dashboard_response = client.get("/dashboard") + asset_response = client.get("/dashboard/assets/dashboard.js") + + assert root_response.status_code == 307 + assert root_response.headers["location"] == "/dashboard" + assert dashboard_response.status_code == 200 + assert "LAI CONTROL ROOM" in dashboard_response.text + assert "Workstation readiness" in dashboard_response.text + assert "Smoke diagnostics" in dashboard_response.text + assert asset_response.status_code == 200 + assert "async function loadWorkstationReadiness" in asset_response.text + assert "async function loadExecution" in asset_response.text + assert "async function recoverStaleJobs" in asset_response.text + assert "function formatProgressPercent" in asset_response.text + + +def 
test_workstation_readiness_route_returns_profiles(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + client = TestClient(create_api(settings=settings, providers=providers)) + + response = client.get("/workstation/readiness") + + assert response.status_code == 200 + payload = response.json() + assert payload["system"]["python_version"] + assert payload["profiles"] + assert any(profile["profile_id"] == "local" for profile in payload["profiles"]) + + +def test_job_artifact_routes_return_persisted_content(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + artifacts_response = client.get(f"/jobs/{job.id}/artifacts") + assert artifacts_response.status_code == 200 + artifacts = artifacts_response.json()["artifacts"] + assert artifacts + + final_artifact = next( + artifact for artifact in artifacts if artifact["artifact_type"] == "final-output" + ) + artifact_response = client.get(f"/jobs/{job.id}/artifacts/{final_artifact['id']}") + + assert artifact_response.status_code == 200 + payload = artifact_response.json() + assert payload["content_type"] == "text/plain" + assert "executed by" in payload["content"] + + +def test_job_timeline_route_returns_stage_events(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + timeline_response = client.get(f"/jobs/{job.id}/timeline") + + assert timeline_response.status_code == 200 + payload = timeline_response.json() + assert payload["stage_events"] + assert any( + event["stage"] == "executor" and event["event_type"] == "completed" + for event in payload["stage_events"] + ) + assert any( + event["stage"] == "executor" and event["event_type"] == "progress" + for event in payload["stage_events"] + ) + + +def test_job_execution_route_returns_stage_summaries(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + execution_response = client.get(f"/jobs/{job.id}/execution") + + assert execution_response.status_code == 200 + payload = execution_response.json() + assert payload["stages"] + executor = next(stage for stage in payload["stages"] if stage["stage"] == "executor") + assert executor["status"] == "completed" + assert executor["provider_id"] == "transformers" + assert executor["artifact_types"] + assert executor["progress_event_count"] > 0 + assert executor["latest_progress_phase"] + assert executor["latest_progress_message"] + + +def test_stale_job_routes_list_and_recover_jobs(repo_root, tmp_path) -> 
None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + stale_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Recover this interrupted overnight execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = seeded_app.job_store.claim_next_queued_job() + assert running is not None + stale_at = utcnow() - timedelta(minutes=30) + running.updated_at = stale_at + running.started_at = stale_at + seeded_app.job_store.save_job(running) + seeded_app.job_store.add_stage_event( + StageEventRecord( + id=str(uuid4()), + job_id=stale_job.id, + stage="executor", + event_type="started", + message="Executor stage started long ago.", + created_at=stale_at, + model_id=running.route_decision.executor_model_id if running.route_decision else None, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + stale_response = client.get("/jobs/stale?stale_after_seconds=60") + assert stale_response.status_code == 200 + stale_payload = stale_response.json() + assert stale_payload["jobs"] + assert stale_payload["jobs"][0]["job_id"] == stale_job.id + + recover_response = client.post( + "/jobs/recover", + json={"stale_after_seconds": 60, "limit": 10}, + ) + assert recover_response.status_code == 200 + recovered = recover_response.json()["recovered"] + assert recovered + assert recovered[0]["id"] == stale_job.id + assert recovered[0]["status"] == "queued" + + +def test_smoke_routes_run_and_return_latest_result(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + client = TestClient(create_api(settings=settings, providers=providers)) + + run_response = client.post( + "/smoke/run", + json={"suite_id": "providers", "save": True}, + ) + + assert run_response.status_code == 200 + payload = run_response.json() + assert payload["result"]["suite_id"] == "providers" + assert payload["result"]["mode"] == "readiness" + assert payload["id"] + assert payload["path"] + + latest_response = client.get("/smoke/latest?suite_id=providers") + assert latest_response.status_code == 200 + latest_payload = latest_response.json() + assert latest_payload["result"]["suite_id"] == "providers" + assert latest_payload["id"] + + history_response = client.get("/smoke/results?limit=3") + assert history_response.status_code == 200 + history_payload = history_response.json() + assert history_payload["results"] + detail_response = client.get(f"/smoke/results/{payload['id']}") + assert detail_response.status_code == 200 + assert detail_response.json()["id"] == payload["id"] + + missing_response = client.get("/smoke/results/not-a-smoke-result.json") + assert missing_response.status_code == 404 + + +def test_worker_status_and_until_idle_routes(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a queued architecture pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare another queued architecture pass.", + queue_mode=QueueMode.QUEUED, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + status_before = client.get("/worker/status") + assert status_before.status_code == 200 + assert 
status_before.json()["queued_jobs"] == 2 + assert "service_lock_present" in status_before.json() + assert "stop_signal_present" in status_before.json() + assert "log_file" in status_before.json() + + run_response = client.post( + "/worker/run", + json={"until_idle": True, "max_idle_cycles": 1}, + ) + + assert run_response.status_code == 200 + payload = run_response.json() + assert payload["processed"] == 2 + assert payload["worker"]["status"] == "idle" + assert payload["worker"]["mode"] == "until-idle" + assert payload["worker"]["queued_jobs"] == 0 + + +def test_job_replay_route_creates_new_job(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + source_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Draft an internal plan for the queue system.", + queue_mode=QueueMode.INLINE, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + replay_response = client.post( + f"/jobs/{source_job.id}/replay", + json={"queue_mode": "queued"}, + ) + + assert replay_response.status_code == 200 + replayed_job = replay_response.json() + assert replayed_job["id"] != source_job.id + assert replayed_job["status"] == "queued" + assert replayed_job["request"]["metadata"]["replayed_from_job_id"] == source_job.id + + +def test_worker_run_route_processes_bounded_batch(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + seeded_app = create_application(settings=settings, providers=providers) + first_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a long-form architecture review.", + queue_mode=QueueMode.QUEUED, + ) + ) + second_job = seeded_app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare another queued execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + + client = TestClient(create_api(settings=settings, providers=providers)) + run_response = client.post("/worker/run", json={"max_jobs": 1}) + + assert run_response.status_code == 200 + payload = run_response.json() + assert payload["processed"] == 1 + assert len(payload["jobs"]) == 1 + assert payload["jobs"][0]["status"] == "succeeded" + assert payload["worker"]["status"] == "idle" + assert payload["worker"]["processed_jobs"] == 1 + + refreshed_first = seeded_app.job_store.get_job(first_job.id) + refreshed_second = seeded_app.job_store.get_job(second_job.id) + assert refreshed_first is not None + assert refreshed_second is not None + assert [refreshed_first.status, refreshed_second.status].count("succeeded") == 1 + assert [refreshed_first.status, refreshed_second.status].count("queued") == 1 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..23b6515 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,13 @@ +from lai.config import load_app_config + + +def test_app_config_loads_and_validates(repo_root) -> None: + config = load_app_config( + repo_root / "configs/models/catalog.yaml", + repo_root / "configs/routing/policies.yaml", + ) + + assert config.model_catalog.version == 1 + assert len(config.model_catalog.models) == 6 + assert config.routing_policy.router.model_id == "router-small" + assert config.routing_policy.fallbacks.when_gpu_unavailable is not None diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py new file mode 100644 index 0000000..4648bdb --- 
/dev/null +++ b/tests/unit/test_evals.py @@ -0,0 +1,29 @@ +from lai.evals import load_route_eval_suite, run_route_eval_suite +from lai.settings import Settings + + +def test_route_eval_suite_loads(repo_root) -> None: + suite = load_route_eval_suite(repo_root / "evals/scenarios/routing-basics.yaml") + + assert suite.version == 1 + assert len(suite.scenarios) == 3 + assert suite.scenarios[2].unavailable_providers == ["airllm"] + + +def test_route_eval_suite_passes(repo_root, tmp_path) -> None: + settings = Settings( + root_dir=repo_root, + database_path=tmp_path / "lai.db", + state_dir=tmp_path, + artifacts_dir=tmp_path / "artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "hf-cache", + airllm_shards_dir=tmp_path / "airllm", + raw_models_dir=tmp_path / "raw", + ) + + result = run_route_eval_suite(settings, repo_root / "evals/scenarios/routing-basics.yaml") + + assert result.passed is True + assert result.total == 3 + assert result.failed_count == 0 diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py index c5dc44f..fe2a39a 100644 --- a/tests/unit/test_layout.py +++ b/tests/unit/test_layout.py @@ -7,7 +7,7 @@ def test_runtime_directories_are_repo_relative() -> None: paths = runtime_directories(settings.root_dir) - assert len(paths) == 5 + assert len(paths) == 6 assert all(str(path).startswith(str(settings.root_dir)) for path in paths) @@ -16,3 +16,12 @@ def test_resolved_config_paths_are_under_repo_root() -> None: assert settings.resolved_model_catalog == settings.root_dir / "configs/models/catalog.yaml" assert settings.resolved_routing_policy == settings.root_dir / "configs/routing/policies.yaml" + assert settings.resolved_worker_service_lock_path == ( + settings.root_dir / "data/state/worker-service.lock" + ) + assert settings.resolved_worker_service_stop_path == ( + settings.root_dir / "data/state/worker-service.stop" + ) + assert settings.resolved_worker_service_log_path == ( + settings.root_dir / "logs/worker-service.log" + ) diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py new file mode 100644 index 0000000..47648bc --- /dev/null +++ b/tests/unit/test_observability.py @@ -0,0 +1,62 @@ +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode +from lai.observability import summarize_job_execution +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def 
test_summarize_job_execution_returns_stage_details(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + application = create_application(settings=settings, providers=_test_providers(settings)) + job = application.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a detailed implementation strategy for this platform.", + queue_mode=QueueMode.INLINE, + ) + ) + + summaries = summarize_job_execution(job) + + assert summaries + assert [summary.stage for summary in summaries] == ["planner", "executor", "reviewer"] + + executor = next(summary for summary in summaries if summary.stage == "executor") + assert executor.status == "completed" + assert executor.prompt_characters is not None + assert executor.output_preview + assert "final-output" in executor.artifact_types + assert executor.progress_event_count > 0 + assert executor.latest_progress_phase + assert executor.latest_progress_message + + reviewer = next(summary for summary in summaries if summary.stage == "reviewer") + assert reviewer.status == "completed" + assert "reviewer-notes" in reviewer.artifact_types + assert reviewer.progress_event_count > 0 diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py new file mode 100644 index 0000000..e06a84f --- /dev/null +++ b/tests/unit/test_providers.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import sys +import types + +from lai.domain import ModelRole, ModelRuntime, ModelSpec, ProviderRequest +from lai.providers.implementations import TransformersProvider +from lai.settings import Settings +from lai.system import collect_system_snapshot + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def test_transformers_provider_passes_token_once_to_pipeline( + repo_root, + tmp_path, + monkeypatch, +) -> None: + settings = _test_settings(repo_root, tmp_path, hf_token="fake-token") + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + provider = TransformersProvider(settings, snapshot) + model = ModelSpec( + id="router-small", + role=ModelRole.ROUTER, + runtime=ModelRuntime.TRANSFORMERS, + model_ref="Qwen/Qwen2.5-3B-Instruct", + capabilities=["summarization"], + ) + + captured: dict[str, object] = {} + + def fake_pipeline(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + def _pipe(*_pipe_args, **_pipe_kwargs): + return [{"generated_text": "hello"}] + + return _pipe + + fake_transformers = types.ModuleType("transformers") + fake_transformers.pipeline = fake_pipeline + monkeypatch.setitem(sys.modules, "transformers", fake_transformers) + + result = provider.generate( + model, + ProviderRequest(user_prompt="Say hello."), + ) + + kwargs = captured["kwargs"] + assert kwargs["token"] == "fake-token" + assert "token" not in kwargs["model_kwargs"] + assert result.text == "hello" diff --git a/tests/unit/test_recovery.py b/tests/unit/test_recovery.py new file mode 100644 index 0000000..d03400f --- /dev/null 
+++ b/tests/unit/test_recovery.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import time +from datetime import timedelta +from uuid import uuid4 + +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode, StageEventRecord, utcnow +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +class SlowFakeProvider(FakeProvider): + def __init__(self, *args, delay_seconds: float = 0.05, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.delay_seconds = delay_seconds + + def generate(self, model, request, *, progress_callback=None): + time.sleep(self.delay_seconds) + return super().generate(model, request, progress_callback=progress_callback) + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def _test_providers( + settings: Settings, + *, + slow_transformers: bool = False, +) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + transformers_provider: FakeProvider + if slow_transformers: + transformers_provider = SlowFakeProvider( + settings, + snapshot, + "transformers", + delay_seconds=0.05, + ) + else: + transformers_provider = FakeProvider(settings, snapshot, "transformers") + return { + "transformers": transformers_provider, + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_slow_stage_persists_heartbeat_events(repo_root, tmp_path) -> None: + settings = _test_settings( + repo_root, + tmp_path, + job_heartbeat_interval_seconds=0.01, + ) + application = create_application( + settings=settings, + providers=_test_providers(settings, slow_transformers=True), + ) + + job = application.orchestration.submit_request( + ExecutionRequest( + user_prompt="Summarize this short note.", + queue_mode=QueueMode.INLINE, + ) + ) + + assert any( + event.stage == "executor" and event.event_type == "heartbeat" + for event in job.stage_events + ) + + +def test_stale_job_detection_and_recovery(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + application = create_application( + settings=settings, + providers=_test_providers(settings), + ) + job = application.orchestration.submit_request( + ExecutionRequest( + user_prompt="Recover this interrupted overnight execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = application.job_store.claim_next_queued_job() + assert running is not None + + stale_at = utcnow() - timedelta(minutes=20) + running.updated_at = stale_at + running.started_at = stale_at + application.job_store.save_job(running) + application.job_store.add_stage_event( + StageEventRecord( + id=str(uuid4()), + job_id=job.id, + stage="executor", + event_type="started", + message="Executor stage started long ago.", + 
created_at=stale_at, + model_id=running.route_decision.executor_model_id if running.route_decision else None, + ) + ) + + stale_jobs = application.orchestration.list_stale_jobs(stale_after_seconds=60) + recovered = application.orchestration.recover_stale_jobs(stale_after_seconds=60) + refreshed = application.job_store.get_job(job.id) + + assert stale_jobs + assert stale_jobs[0].job_id == job.id + assert recovered + assert refreshed is not None + assert refreshed.status == "queued" + assert refreshed.started_at is None + assert any( + event.stage == "job" and event.event_type == "recovered" + for event in refreshed.stage_events + ) diff --git a/tests/unit/test_routing.py b/tests/unit/test_routing.py new file mode 100644 index 0000000..d8e5c25 --- /dev/null +++ b/tests/unit/test_routing.py @@ -0,0 +1,60 @@ +from lai.application import create_application +from lai.domain import ExecutionRequest +from lai.settings import Settings +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _make_application(tmp_path, repo_root, *, airllm_available: bool = True): + settings = Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + enable_gpu=True, + ) + snapshot = collect_system_snapshot( + settings.root_dir, + enable_gpu=False, + ) + providers = { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm", available=airllm_available), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + return create_application(settings=settings, providers=providers) + + +def test_routing_prefers_deep_work_for_complex_prompt(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root) + + decision = app.routing_engine.route( + ExecutionRequest( + user_prompt=( + "Create a comprehensive advanced architecture and complete implementation " + "strategy for a long-running research platform." + ) + ) + ) + + assert decision.matched_tier_id == "deep-work" + assert decision.executor_model_id == "execution-large" + assert decision.queue_recommended is True + + +def test_routing_falls_back_when_airllm_unavailable(tmp_path, repo_root) -> None: + app = _make_application(tmp_path, repo_root, airllm_available=False) + + decision = app.routing_engine.route( + ExecutionRequest( + user_prompt=( + "Create a comprehensive advanced architecture and complete implementation " + "strategy for a long-running research platform." 
+ ) + ) + ) + + assert decision.matched_tier_id == "deep-work" + assert decision.executor_model_id != "execution-large" + assert decision.executor_model_id in {"openai-general", "anthropic-general", "gemini-general"} diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py new file mode 100644 index 0000000..91bfbe4 --- /dev/null +++ b/tests/unit/test_smoke.py @@ -0,0 +1,122 @@ +from lai.settings import Settings +from lai.smoke import ( + latest_smoke_result, + list_smoke_results, + resolve_smoke_result_path, + run_smoke_suite, + save_smoke_result, +) +from lai.system import collect_system_snapshot +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path) -> Settings: + return Settings( + root_dir=tmp_path, + model_catalog=repo_root / "configs/models/catalog.yaml", + routing_policy=repo_root / "configs/routing/policies.yaml", + database_path=tmp_path / "data/state/lai.db", + state_dir=tmp_path / "data/state", + artifacts_dir=tmp_path / "data/artifacts", + logs_dir=tmp_path / "logs", + huggingface_cache_dir=tmp_path / "data/cache/huggingface", + airllm_shards_dir=tmp_path / "data/models/airllm-shards", + raw_models_dir=tmp_path / "data/models/raw", + ) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_provider_smoke_readiness_with_fake_providers(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + result = run_smoke_suite( + settings, + suite_id="providers", + providers=_test_providers(settings), + ) + + assert result.passed + assert result.mode == "readiness" + assert result.total == 3 + assert all(check.status == "ready" for check in result.checks) + assert all(not check.executed for check in result.checks) + + +def test_provider_smoke_live_executes_and_saves_result(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + result = run_smoke_suite( + settings, + suite_id="providers", + live=True, + providers=_test_providers(settings), + ) + path = save_smoke_result(tmp_path / "evals/results/smoke", result) + + assert result.passed + assert result.mode == "live" + assert all(check.executed for check in result.checks) + assert path.exists() + assert latest_smoke_result(tmp_path / "evals/results/smoke") == path + + +def test_smoke_result_listing_and_resolution(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + results_dir = tmp_path / "evals/results/smoke" + + provider_path = save_smoke_result( + results_dir, + run_smoke_suite( + settings, + suite_id="providers", + providers=_test_providers(settings), + ), + ) + local_path = save_smoke_result( + results_dir, + run_smoke_suite( + settings, + suite_id="local", + providers=_test_providers(settings), + ), + ) + + listed = list_smoke_results(results_dir, suite_id="providers", limit=5) + + assert provider_path in listed + assert local_path not in listed + assert resolve_smoke_result_path(results_dir, provider_path.name) == provider_path + assert resolve_smoke_result_path(results_dir, "..\\outside.json") is None + + +def test_local_smoke_excludes_airllm_by_default_and_can_include_it(repo_root, tmp_path) -> None: + 
settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + + default_result = run_smoke_suite( + settings, + suite_id="local", + providers=providers, + ) + expanded_result = run_smoke_suite( + settings, + suite_id="local", + include_airllm=True, + providers=providers, + ) + + assert default_result.total == 1 + assert default_result.checks[0].provider_id == "transformers" + assert expanded_result.total == 2 + assert {check.provider_id for check in expanded_result.checks} == { + "transformers", + "airllm", + } diff --git a/tests/unit/test_worker_service.py b/tests/unit/test_worker_service.py new file mode 100644 index 0000000..7e8ab68 --- /dev/null +++ b/tests/unit/test_worker_service.py @@ -0,0 +1,209 @@ +import json +import time +from datetime import timedelta + +from lai.application import create_application +from lai.domain import ExecutionRequest, QueueMode, WorkerStatus, utcnow +from lai.settings import Settings +from lai.system import collect_system_snapshot +from lai.worker.service import ( + WorkerServiceConfig, + WorkerServiceHost, + request_worker_service_stop, +) +from tests.helpers import FakeProvider + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def _test_providers(settings: Settings) -> dict[str, FakeProvider]: + snapshot = collect_system_snapshot(settings.root_dir, enable_gpu=False) + return { + "transformers": FakeProvider(settings, snapshot, "transformers"), + "airllm": FakeProvider(settings, snapshot, "airllm"), + "openai": FakeProvider(settings, snapshot, "openai"), + "anthropic": FakeProvider(settings, snapshot, "anthropic"), + "gemini": FakeProvider(settings, snapshot, "gemini"), + } + + +def test_worker_service_processes_queue_and_releases_lock(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a queued architecture run for the daemon.", + queue_mode=QueueMode.QUEUED, + ) + ) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig(max_cycles=1, poll_interval_seconds=0.01), + sleep_fn=lambda _seconds: None, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert not settings.resolved_worker_service_lock_path.exists() + assert worker.mode == "daemon" + assert worker.status == "idle" + assert worker.processed_jobs == 1 + + +def test_worker_service_stop_signal_exits_gracefully(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Prepare a queued architecture run for the 
daemon.", + queue_mode=QueueMode.QUEUED, + ) + ) + + def _sleep_and_request_stop(_seconds: float) -> None: + request_worker_service_stop(settings) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig(poll_interval_seconds=0.01, max_idle_cycles_per_drain=1), + sleep_fn=_sleep_and_request_stop, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert worker.status == "stopped" + assert worker.mode == "daemon" + assert not settings.resolved_worker_service_stop_path.exists() + + +def test_worker_service_recovers_interrupted_running_job_after_lock_replace( + repo_root, + tmp_path, +) -> None: + settings = _test_settings(repo_root, tmp_path) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Resume this interrupted overnight daemon execution.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = app.job_store.claim_next_queued_job() + assert running is not None + app.orchestration.update_worker_state( + worker_id="local-worker", + status=WorkerStatus.RUNNING, + mode="daemon", + current_job_id=job.id, + processed_jobs=0, + queued_jobs=0, + metadata={"phase": "processing"}, + ) + settings.resolved_worker_service_lock_path.write_text( + json.dumps({"pid": 4242, "created_at": time.time() - 300}), + encoding="utf-8", + ) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig( + max_cycles=1, + poll_interval_seconds=0.01, + replace_existing_lock=True, + ), + sleep_fn=lambda _seconds: None, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + persisted = refreshed.job_store.get_job(job.id) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert persisted is not None + assert persisted.status == "succeeded" + assert any( + event.stage == "job" and event.event_type == "recovered" + for event in persisted.stage_events + ) + last_recovery = worker.metadata.get("last_recovery") + assert isinstance(last_recovery, dict) + assert last_recovery["count"] == 1 + assert last_recovery["reason"] == "startup-lock-replaced" + assert last_recovery["replaced_lock"] is True + assert last_recovery["replaced_lock_pid"] == 4242 + + +def test_worker_service_recovers_stale_running_jobs_during_cycle(repo_root, tmp_path) -> None: + settings = _test_settings( + repo_root, + tmp_path, + worker_service_recovery_stale_after_seconds=60, + ) + providers = _test_providers(settings) + app = create_application(settings=settings, providers=providers) + job = app.orchestration.submit_request( + ExecutionRequest( + user_prompt="Recover this stale interrupted run during daemon sweep.", + queue_mode=QueueMode.QUEUED, + ) + ) + running = app.job_store.claim_next_queued_job() + assert running is not None + stale_at = utcnow() - timedelta(minutes=10) + running.updated_at = stale_at + running.started_at = stale_at + app.job_store.save_job(running) + app.orchestration.update_worker_state( + worker_id="local-worker", + status=WorkerStatus.IDLE, + mode="daemon", + current_job_id=None, + processed_jobs=0, + queued_jobs=0, + metadata={"phase": "sleeping"}, + ) + + host = WorkerServiceHost( + settings=settings, + providers=providers, + config=WorkerServiceConfig(max_cycles=1, poll_interval_seconds=0.01), + 
sleep_fn=lambda _seconds: None, + ) + processed = host.run() + refreshed = create_application(settings=settings, providers=providers) + persisted = refreshed.job_store.get_job(job.id) + worker = refreshed.orchestration.get_worker_state() + + assert processed == 1 + assert persisted is not None + assert persisted.status == "succeeded" + last_recovery = worker.metadata.get("last_recovery") + assert isinstance(last_recovery, dict) + assert last_recovery["count"] == 1 + assert last_recovery["reason"] == "cycle-stale-recovery" diff --git a/tests/unit/test_workstation.py b/tests/unit/test_workstation.py new file mode 100644 index 0000000..397b387 --- /dev/null +++ b/tests/unit/test_workstation.py @@ -0,0 +1,105 @@ +from lai.config import load_app_config +from lai.settings import Settings +from lai.system import SystemSnapshot +from lai.workstation import build_workstation_readiness + + +def _test_settings(repo_root, tmp_path, **overrides) -> Settings: + base = { + "root_dir": tmp_path, + "model_catalog": repo_root / "configs/models/catalog.yaml", + "routing_policy": repo_root / "configs/routing/policies.yaml", + "database_path": tmp_path / "data/state/lai.db", + "state_dir": tmp_path / "data/state", + "artifacts_dir": tmp_path / "data/artifacts", + "logs_dir": tmp_path / "logs", + "huggingface_cache_dir": tmp_path / "data/cache/huggingface", + "airllm_shards_dir": tmp_path / "data/models/airllm-shards", + "raw_models_dir": tmp_path / "data/models/raw", + } + base.update(overrides) + return Settings(**base) + + +def test_workstation_readiness_blocks_missing_airllm_prerequisites(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path) + config = load_app_config(settings.resolved_model_catalog, settings.resolved_routing_policy) + snapshot = SystemSnapshot( + has_gpu=False, + total_ram_gb=16.0, + available_ram_gb=10.0, + free_disk_gb={ + str(settings.resolved_huggingface_cache_dir): 20.0, + str(settings.resolved_airllm_shards_dir): 40.0, + }, + ) + + report = build_workstation_readiness( + settings, + config, + snapshot=snapshot, + dependency_versions={ + "torch": None, + "transformers": "4.55.0", + "accelerate": None, + "huggingface-hub": "0.34.0", + "airllm": None, + "bitsandbytes": None, + }, + ) + + local_profile = report.profile("local") + airllm_profile = report.profile("airllm") + + assert not local_profile.ready + assert not airllm_profile.ready + assert any( + check.id == "airllm-dependencies" and check.status == "fail" + for check in airllm_profile.checks + ) + assert any( + check.id == "huggingface-token" and check.status == "fail" + for check in airllm_profile.checks + ) + assert any( + check.id == "airllm-shards-disk" and check.status == "fail" + for check in airllm_profile.checks + ) + + +def test_workstation_readiness_allows_cpu_fallback_with_warnings(repo_root, tmp_path) -> None: + settings = _test_settings(repo_root, tmp_path, hf_token="test-token") + config = load_app_config(settings.resolved_model_catalog, settings.resolved_routing_policy) + snapshot = SystemSnapshot( + has_gpu=False, + total_ram_gb=48.0, + available_ram_gb=30.0, + free_disk_gb={ + str(settings.resolved_huggingface_cache_dir): 220.0, + str(settings.resolved_airllm_shards_dir): 320.0, + }, + ) + + report = build_workstation_readiness( + settings, + config, + snapshot=snapshot, + dependency_versions={ + "torch": "2.7.0", + "transformers": "4.55.0", + "accelerate": "1.9.0", + "huggingface-hub": "0.34.0", + "airllm": "2.1.0", + "bitsandbytes": "0.45.0", + }, + ) + + local_profile = 
report.profile("local") + airllm_profile = report.profile("airllm") + + assert local_profile.ready + assert airllm_profile.ready + assert airllm_profile.warning_count >= 1 + assert any( + check.id == "airllm-gpu" and check.status == "warn" for check in airllm_profile.checks + )
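
Note: every test module added in this diff takes a `repo_root` pytest fixture and builds its provider map from the `FakeProvider` double in `tests/helpers`; neither of those files is part of this change. A minimal sketch of what such a fixture could look like, assuming `tests/` sits directly under the repository checkout root, is shown below; the actual `tests/conftest.py` in the repository may differ.

```python
# Hypothetical tests/conftest.py sketch -- not part of this diff.
from pathlib import Path

import pytest


@pytest.fixture(scope="session")
def repo_root() -> Path:
    """Repository checkout root, assuming tests/ lives one level below it."""
    return Path(__file__).resolve().parents[1]
```

With a fixture like this, expressions such as `repo_root / "configs/models/catalog.yaml"` in the tests above resolve to the checked-in config files regardless of the pytest working directory.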