From 40d536cce628c3a70127ee2763cb65d3e6ce5ef9 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 26 May 2026 06:02:55 +0200 Subject: [PATCH 1/2] feat(data,repo): add local demo tooling + seeder window fix (#297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bundles three carryover concerns from prior local demo work into one PR. * fix(data) — PriceHistoryGenerator could emit a row with valid_to < valid_from when a change roll fired on the window's first day. That violates ck_price_history_valid_dates and crashed the seeder during ingest. The fix skips the degenerate row. * feat(data) — three new local-host scripts that drive the public API to enrich the demo DB without raw SQL writes: - seed_phase2_only: re-runs Phase 2 generators (replenishment, exogenous, returns, lifecycle) against existing dimensions - seed_historical_activity: submits varied train/predict/backtest jobs across 2024-Q4 -> 2026-Q1 cutoffs through /jobs - seed_registry_from_jobs: walks completed train jobs, runs the canonical pending -> running -> success transition + alias stamps * chore(repo) — uv.lock refreshes forecastlabai 0.2.18 -> 0.2.19 to match the release-please-merged version bump. Excluded intentionally: alembic/a2b3c4d5e6f7 + rag/models.py — the migration is self-marked "local-only demo" (truncates document_chunk, drops HNSW index, hardcodes 2560 for qwen3) and would wipe any non- qwen3 user's RAG corpus on upgrade. Stays uncommitted locally. 
--- app/shared/seeder/generators/facts.py | 40 ++--- scripts/seed_historical_activity.py | 199 ++++++++++++++++++++++ scripts/seed_phase2_only.py | 227 +++++++++++++++++++++++++ scripts/seed_registry_from_jobs.py | 229 ++++++++++++++++++++++++++ uv.lock | 2 +- 5 files changed, 677 insertions(+), 20 deletions(-) create mode 100644 scripts/seed_historical_activity.py create mode 100644 scripts/seed_phase2_only.py create mode 100644 scripts/seed_registry_from_jobs.py diff --git a/app/shared/seeder/generators/facts.py b/app/shared/seeder/generators/facts.py index 30c191fc..68438b7b 100644 --- a/app/shared/seeder/generators/facts.py +++ b/app/shared/seeder/generators/facts.py @@ -615,25 +615,27 @@ def generate( while current <= end_date: # Check for price change (monthly probability) if self.rng.random() < self.price_change_probability / 30: - # End previous price window - records.append( - { - "product_id": product_id, - "store_id": store_id, - "price": current_price, - "valid_from": current_valid_from, - "valid_to": current - timedelta(days=1), - } - ) - - # Generate new price - change_pct = self.rng.uniform( - -self.max_price_change_pct, self.max_price_change_pct - ) - current_price = (current_price * Decimal(str(1 + change_pct))).quantize( - Decimal("0.01") - ) - current_valid_from = current + valid_to = current - timedelta(days=1) + # Skip degenerate window when a change fires on start_date + # itself: valid_to would precede valid_from and violate + # ck_price_history_valid_dates. 
+ if valid_to >= current_valid_from: + records.append( + { + "product_id": product_id, + "store_id": store_id, + "price": current_price, + "valid_from": current_valid_from, + "valid_to": valid_to, + } + ) + change_pct = self.rng.uniform( + -self.max_price_change_pct, self.max_price_change_pct + ) + current_price = (current_price * Decimal(str(1 + change_pct))).quantize( + Decimal("0.01") + ) + current_valid_from = current current += timedelta(days=1) diff --git a/scripts/seed_historical_activity.py b/scripts/seed_historical_activity.py new file mode 100644 index 00000000..0f66c27e --- /dev/null +++ b/scripts/seed_historical_activity.py @@ -0,0 +1,199 @@ +"""Backfill historical model activity through the public API. + +Creates a realistic spread of train/predict/backtest jobs over the seeded +date range so the Registry, Jobs, and Forecasts dashboards have meaningful +content. All rows have created_at=NOW (pure API flow, no SQL writes); +the historical FEEL comes from varied train_end_date / cutoff values +across 2024-2026. + +Optionally finishes by creating a small batch job through /batch/forecasting. + +Usage: + uv run python scripts/seed_historical_activity.py --base http://localhost:8123 +""" + +from __future__ import annotations + +import argparse +import asyncio +import sys +from datetime import date + +import httpx + +# (store_id, product_id) pairs hand-picked from high-volume series. +PAIRS: list[tuple[int, int]] = [ + (11, 67), + (13, 86), + (15, 86), + (20, 67), +] + +# train_end_date cutoffs spanning 2024-Q4 → 2026-Q1 — gives the registry +# "as_of" spread without backdating created_at. 
+CUTOFFS: list[date] = [ + date(2024, 12, 31), + date(2025, 6, 30), + date(2025, 12, 31), +] + +BASELINES: list[str] = ["naive", "seasonal_naive", "moving_average"] + + +async def submit_job( + client: httpx.AsyncClient, job_type: str, params: dict[str, object] +) -> dict[str, object]: + r = await client.post("/jobs", json={"job_type": job_type, "params": params}) + r.raise_for_status() + return r.json() + + +async def poll_job( + client: httpx.AsyncClient, job_id: str, timeout_s: float = 60.0 +) -> dict[str, object]: + deadline = asyncio.get_event_loop().time() + timeout_s + while asyncio.get_event_loop().time() < deadline: + r = await client.get(f"/jobs/{job_id}") + r.raise_for_status() + body = r.json() + if body.get("status") in {"completed", "failed", "cancelled"}: + return body + await asyncio.sleep(0.3) + raise TimeoutError(f"Job {job_id} did not complete within {timeout_s}s") + + +async def train_one( + client: httpx.AsyncClient, + store_id: int, + product_id: int, + model_type: str, + cutoff: date, +) -> dict[str, object]: + params = { + "model_type": model_type, + "store_id": store_id, + "product_id": product_id, + "start_date": "2024-01-01", + "end_date": cutoff.isoformat(), + } + submitted = await submit_job(client, "train", params) + return await poll_job(client, str(submitted["job_id"])) + + +async def predict_for_run( + client: httpx.AsyncClient, run_id: str, horizon: int = 14 +) -> dict[str, object] | None: + submitted = await submit_job(client, "predict", {"run_id": run_id, "horizon": horizon}) + return await poll_job(client, str(submitted["job_id"])) + + +async def backtest_one( + client: httpx.AsyncClient, + store_id: int, + product_id: int, + model_type: str, +) -> dict[str, object]: + submitted = await submit_job( + client, + "backtest", + { + "model_type": model_type, + "store_id": store_id, + "product_id": product_id, + "start_date": "2024-01-01", + "end_date": "2026-05-01", + "n_splits": 3, + "test_size": 14, + }, + ) + return await 
poll_job(client, str(submitted["job_id"]), timeout_s=120.0) + + +async def main(base_url: str) -> int: + async with httpx.AsyncClient(base_url=base_url, timeout=60.0) as client: + # Phase 1: train across (pair x cutoff x baseline) + train_results: list[dict[str, object]] = [] + for pair in PAIRS: + for cutoff in CUTOFFS: + for model_type in BASELINES: + res = await train_one(client, pair[0], pair[1], model_type, cutoff) + train_results.append(res) + status = res.get("status") + run = res.get("run_id") + print( + f" train store={pair[0]:>3} prod={pair[1]:>3} " + f"model={model_type:<16} cutoff={cutoff} → {status} run_id={run}" + ) + print(f" ✅ trained {len(train_results)} models") + + # Phase 2: predict for every successful run at the latest cutoff + successful_runs = [ + r for r in train_results if r.get("status") == "completed" and r.get("run_id") + ] + # only fan-predict the latest cutoff (one predict per pair x model) + latest = CUTOFFS[-1].isoformat() + latest_runs = [ + r for r in successful_runs if str(r.get("params", {}).get("end_date")) == latest + ] + predict_results = [] + for r in latest_runs: + run_id = str(r["run_id"]) + pred = await predict_for_run(client, run_id, horizon=14) + predict_results.append(pred) + status = pred.get("status") if pred else "skip" + print(f" predict run_id={run_id[:8]}… → {status}") + print(f" ✅ predicted {len(predict_results)} horizons") + + # Phase 3: 2 backtests for variety (one fast baseline per pair) + bt_results = [] + for pair in PAIRS[:2]: + bt = await backtest_one(client, pair[0], pair[1], "seasonal_naive") + bt_results.append(bt) + print( + f" backtest store={pair[0]} prod={pair[1]} model=seasonal_naive → {bt.get('status')}" + ) + print(f" ✅ ran {len(bt_results)} backtests") + + # Phase 4: small batch through /batch/forecasting (variety, second batch_job row) + try: + batch_payload = { + "operation": "train", + "scope": { + "kind": "manual", + "store_ids": [11, 13], + "product_ids": [67, 86], + }, + "model_configs": 
[ + {"model_type": "naive"}, + {"model_type": "seasonal_naive"}, + ], + "start_date": "2024-01-01", + "end_date": "2025-12-31", + "max_parallel": 2, + } + r = await client.post("/batch/forecasting", json=batch_payload) + r.raise_for_status() + bj = r.json() + print(f" ✅ submitted batch_id={bj.get('batch_id')} items={bj.get('item_count')}") + except httpx.HTTPStatusError as e: + print( + f" ⚠️ batch submit failed (non-fatal): {e.response.status_code} {e.response.text[:120]}" + ) + + # Summary numbers + print() + print("Summary:") + print(f" train jobs : {len(train_results)}") + print(f" predict jobs : {len(predict_results)}") + print(f" backtest jobs : {len(bt_results)}") + return 0 + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--base", default="http://localhost:8123") + return p.parse_args() + + +if __name__ == "__main__": + sys.exit(asyncio.run(main(_parse_args().base))) diff --git a/scripts/seed_phase2_only.py b/scripts/seed_phase2_only.py new file mode 100644 index 00000000..996a7507 --- /dev/null +++ b/scripts/seed_phase2_only.py @@ -0,0 +1,227 @@ +"""Phase 2 retail-data enrichment — additive only. + +Runs only the Phase 2 generators (replenishment, exogenous, returns, lifecycle) +against the EXISTING seeded dimensions and calendar. Does NOT touch Phase 1 +fact rows (sales_daily, price_history, promotion, inventory_snapshot_daily). + +Skipped Phase 2 generators: bundles + markdowns. Both require coordinated +writes to promotion/price_history/inventory in lock-step with Phase 1 facts, +which falls outside the additive scope. + +Usage: + uv run python scripts/seed_phase2_only.py --seed 42 + +Refuses to run unless DATABASE_URL points at localhost / 127.0.0.1. 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import random +import sys +from collections.abc import Iterable, Iterator +from datetime import date as date_type +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.features.data_platform.models import ( + Calendar, + ExogenousSignal, + Product, + ReplenishmentEvent, + SalesDaily, + SalesReturn, + Store, +) +from app.shared.seeder.config import ( + ExogenousSignalConfig, + LeadTimeConfig, + LifecycleConfig, + ReturnsConfig, +) +from app.shared.seeder.generators.exogenous import ExogenousSignalGenerator +from app.shared.seeder.generators.lifecycle import LifecycleGenerator +from app.shared.seeder.generators.replenishment import ReplenishmentGenerator +from app.shared.seeder.generators.returns import ReturnsGenerator + +if TYPE_CHECKING: + pass + + +def chunked[U](items: list[U], size: int) -> Iterator[list[U]]: + for i in range(0, len(items), size): + yield items[i : i + size] + + +def _assign_lifecycle( + rng: random.Random, + product_ids: list[int], + seed_start: date_type, + seed_end: date_type, + discontinue_probability: float, +) -> dict[int, tuple[date_type, date_type | None, str]]: + """Assign launch_date / discontinue_date / lifecycle_stage per product. + + launch_date is drawn uniformly across the first ~70% of the seeded range + so most products have plenty of post-launch sales history. A small + fraction get a discontinue_date in the last 20% of the range. 
+ """ + span_days = (seed_end - seed_start).days + if span_days <= 0: + raise SystemExit("Seeded calendar must span at least 1 day.") + launch_window_days = max(1, int(span_days * 0.7)) + out: dict[int, tuple[date_type, date_type | None, str]] = {} + lc_cfg = LifecycleConfig(enable=True) # default ramps suit a 877-day range + lc_gen = LifecycleGenerator(lc_cfg) + for pid in product_ids: + offset = rng.randint(0, launch_window_days) + launch = seed_start.fromordinal(seed_start.toordinal() + offset) + disc: date_type | None = None + if rng.random() < discontinue_probability: + disc_offset = rng.randint(int(span_days * 0.8), span_days) + disc_candidate = seed_start.fromordinal(seed_start.toordinal() + disc_offset) + if disc_candidate > launch: + disc = disc_candidate + stage = lc_gen.stage_for(seed_end, launch, disc) + out[pid] = (launch, disc, stage) + return out + + +async def main(seed: int, returns_probability: float) -> int: + settings = get_settings() + db_url = settings.database_url + if not any(token in db_url for token in ("localhost", "127.0.0.1")): + print(f"REFUSING: database_url does not look local: {db_url}", file=sys.stderr) + return 2 + + engine = create_async_engine(db_url) + Session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + rng = random.Random(seed) + + async with Session() as db: + store_ids = sorted(r[0] for r in (await db.execute(select(Store.id))).fetchall()) + product_ids = sorted(r[0] for r in (await db.execute(select(Product.id))).fetchall()) + cal_rows = (await db.execute(select(Calendar.date).order_by(Calendar.date))).fetchall() + dates = [r[0] for r in cal_rows] + if not store_ids or not product_ids or not dates: + print("REFUSING: empty dimensions/calendar. 
Run seed_random.py first.", file=sys.stderr) + return 3 + start_date, end_date = dates[0], dates[-1] + print(f"Phase 2 enrichment (seed={seed})") + print( + f" scope: {len(store_ids)} stores x {len(product_ids)} products x " + f"{len(dates)} days ({start_date} → {end_date})" + ) + + # ---- 1) Lifecycle: UPDATE product.launch_date / discontinue_date / lifecycle_stage + lifecycle_map = _assign_lifecycle( + rng, product_ids, start_date, end_date, discontinue_probability=0.10 + ) + update_count = 0 + for pid, (launch, disc, stage) in lifecycle_map.items(): + await db.execute( + update(Product) + .where(Product.id == pid) + .values(launch_date=launch, discontinue_date=disc, lifecycle_stage=stage) + ) + update_count += 1 + await db.commit() + print(f" ✅ product (lifecycle UPDATE): {update_count:,} rows") + + # ---- 2) Replenishment events + lt_cfg = LeadTimeConfig( + enable=True, + mean_lead_time_days=7, + lead_time_sigma_days=1.5, + safety_stock_days=3, + order_frequency_days=14, + fill_rate_mean=0.97, + fill_rate_sigma=0.05, + ) + rep_gen = ReplenishmentGenerator(rng, lt_cfg) + rep_records = rep_gen.generate(store_ids, product_ids, dates, base_demand=100) + for chunk in chunked(rep_records, 2000): + await db.execute(ReplenishmentEvent.__table__.insert(), chunk) + await db.commit() + print(f" ✅ replenishment_event INSERT: {len(rep_records):,} rows") + + # ---- 3) Exogenous signals (weather + macro) + ex_cfg = ExogenousSignalConfig( + enable_weather=True, + enable_macro=True, + enable_events=False, + weather_climatology_mean_c=15.0, + weather_amplitude_c=12.0, + weather_noise_sigma_c=2.0, + macro_initial_value=100.0, + macro_step_sigma=0.5, + ) + ex_gen = ExogenousSignalGenerator(rng, ex_cfg) + ex_records = ex_gen.generate(dates, store_ids) + for chunk in chunked(ex_records, 2000): + await db.execute(ExogenousSignal.__table__.insert(), chunk) + await db.commit() + print(f" ✅ exogenous_signal INSERT: {len(ex_records):,} rows") + + # ---- 4) Sales returns (sampled from 
existing sales_daily) + ret_cfg = ReturnsConfig( + enable=True, + return_probability=returns_probability, + return_lag_days_min=1, + return_lag_days_max=14, + return_quantity_fraction=0.5, + ) + ret_gen = ReturnsGenerator(rng, ret_cfg) + sales_rows = ( + await db.execute( + select( + SalesDaily.date, + SalesDaily.store_id, + SalesDaily.product_id, + SalesDaily.quantity, + ).where(SalesDaily.quantity > 0) + ) + ).fetchall() + sales_records: list[dict[str, date_type | int | Decimal]] = [ + { + "date": r[0], + "store_id": r[1], + "product_id": r[2], + "quantity": int(r[3]), + } + for r in sales_rows + ] + ret_records = ret_gen.generate(sales_records, end_date) + for chunk in chunked(ret_records, 2000): + await db.execute(SalesReturn.__table__.insert(), chunk) + await db.commit() + print( + f" ✅ sales_returns INSERT: {len(ret_records):,} rows " + f"(sampled from {len(sales_records):,} positive-qty sales)" + ) + + await engine.dispose() + print("Done.") + return 0 + + +def _parse_args(argv: Iterable[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Phase 2 additive seeder (local only).") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--returns-probability", + type=float, + default=0.02, + help="Per-sale return probability (default 0.02 → ~2%% of sales).", + ) + return parser.parse_args(list(argv) if argv is not None else None) + + +if __name__ == "__main__": + args = _parse_args() + sys.exit(asyncio.run(main(args.seed, args.returns_probability))) diff --git a/scripts/seed_registry_from_jobs.py b/scripts/seed_registry_from_jobs.py new file mode 100644 index 00000000..9e761f9d --- /dev/null +++ b/scripts/seed_registry_from_jobs.py @@ -0,0 +1,229 @@ +"""Populate the registry from previously-completed train jobs. 
+ +``/jobs/train`` produces a forecast artifact but does NOT create a +``model_run`` row — the canonical registry flow lives in +``scripts/run_demo.py:step_register`` and goes: + + /forecasting/train → artifact at forecast_model_artifacts_dir + POST /registry/runs (pending) + PATCH /registry/runs/{id} status=running + PATCH /registry/runs/{id} status=success + metrics + artifact_uri + +This script walks every completed train job and runs steps 2-4 against +the registry, then picks per-(store, product) winners and stamps aliases. + +Metrics are deterministic-stub values keyed off the job's `run_id` so the +dashboard surfaces meaningful spread without re-running backtests. + +Usage: + uv run python scripts/seed_registry_from_jobs.py --base http://localhost:8123 +""" + +from __future__ import annotations + +import argparse +import asyncio +import hashlib +import random +import shutil +import sys +from collections import defaultdict +from pathlib import Path + +import httpx + +from app.core.config import get_settings + + +def _stub_metrics(model_type: str, key: str) -> dict[str, float]: + """Deterministic-but-varied metrics derived from ``key`` (e.g. job run_id).""" + digest = hashlib.sha256(f"{model_type}:{key}".encode()).hexdigest() + rng = random.Random(int(digest, 16)) + # Bands chosen so seasonal_naive usually wins, regression sometimes beats it. 
+ bands = { + "naive": (0.20, 0.28), + "seasonal_naive": (0.12, 0.18), + "moving_average": (0.15, 0.22), + "regression": (0.10, 0.20), + "lightgbm": (0.10, 0.18), + "xgboost": (0.10, 0.18), + "prophet_like": (0.13, 0.20), + } + lo, hi = bands.get(model_type, (0.15, 0.25)) + wape = rng.uniform(lo, hi) + mae = wape * rng.uniform(80, 120) # base demand ≈ 100 + return { + "mae": round(mae, 4), + "wape": round(wape, 4), + "smape": round(wape * rng.uniform(0.9, 1.1), 4), + "bias": round(rng.uniform(-3, 3), 4), + } + + +def _model_config_payload(model_type: str) -> dict[str, object]: + if model_type == "naive": + return {"model_type": "naive"} + if model_type == "seasonal_naive": + return {"model_type": "seasonal_naive", "season_length": 7} + if model_type == "moving_average": + return {"model_type": "moving_average", "window_size": 7} + raise ValueError(f"Unsupported model_type: {model_type}") + + +async def fetch_completed_train_jobs(client: httpx.AsyncClient) -> list[dict[str, object]]: + """Fetch every completed train job through the public API.""" + out: list[dict[str, object]] = [] + page = 1 + while True: + r = await client.get( + "/jobs", + params={"page": page, "page_size": 100, "job_type": "train", "status": "completed"}, + ) + r.raise_for_status() + body = r.json() + jobs = body.get("jobs") or [] + out.extend(jobs) + if page * len(jobs) >= int(body.get("total", 0)) or not jobs: + break + page += 1 + return out + + +async def register_one( + client: httpx.AsyncClient, job: dict[str, object], registry_root: Path +) -> dict[str, str] | None: + params = job.get("params") or {} + result = job.get("result") or {} + if not isinstance(params, dict) or not isinstance(result, dict): + return None + model_type = str(params.get("model_type", "")) + if model_type not in {"naive", "seasonal_naive", "moving_average"}: + return None # only baselines for this backfill + source_path = Path(str(result.get("model_path", ""))) + if not source_path.exists(): + # try relative-to-cwd 
+ rel = Path.cwd() / source_path + if rel.exists(): + source_path = rel + else: + return None + forecast_run_id = str(result.get("run_id", "")) + artifact_uri = f"backfill/{model_type}-{source_path.stem}.joblib" + dest = registry_root / artifact_uri + dest.parent.mkdir(parents=True, exist_ok=True) + if not dest.exists(): + shutil.copy2(source_path, dest) + raw = dest.read_bytes() + artifact_hash = hashlib.sha256(raw).hexdigest() + + # (a) create + r = await client.post( + "/registry/runs", + json={ + "model_type": model_type, + "model_config": _model_config_payload(model_type), + "feature_config": None, + "data_window_start": str(params.get("start_date")), + "data_window_end": str(params.get("end_date")), + "store_id": int(params["store_id"]), + "product_id": int(params["product_id"]), + "agent_context": None, + "git_sha": None, + }, + ) + if r.status_code >= 400: + # duplicate config_hash → idempotent skip + return None + run_id = str(r.json().get("run_id")) + + # (b) running + r = await client.patch(f"/registry/runs/{run_id}", json={"status": "running"}) + r.raise_for_status() + + # (c) success + metrics + artifact info + metrics = _stub_metrics(model_type, forecast_run_id) + r = await client.patch( + f"/registry/runs/{run_id}", + json={ + "status": "success", + "metrics": metrics, + "artifact_uri": artifact_uri, + "artifact_hash": artifact_hash, + "artifact_size_bytes": len(raw), + }, + ) + r.raise_for_status() + return { + "run_id": run_id, + "store_id": str(params["store_id"]), + "product_id": str(params["product_id"]), + "model_type": model_type, + "wape": str(metrics["wape"]), + "data_window_end": str(params.get("end_date")), + } + + +async def main(base_url: str) -> int: + settings = get_settings() + registry_root = Path(settings.registry_artifact_root).resolve() + registry_root.mkdir(parents=True, exist_ok=True) + + async with httpx.AsyncClient(base_url=base_url, timeout=60.0) as client: + jobs = await fetch_completed_train_jobs(client) + print(f"Found 
{len(jobs)} completed train jobs") + registered: list[dict[str, str]] = [] + for j in jobs: + row = await register_one(client, j, registry_root) + if row: + registered.append(row) + print( + f" ✅ registered store={row['store_id']:>3} prod={row['product_id']:>3} " + f"model={row['model_type']:<16} cutoff={row['data_window_end']} " + f"wape={row['wape']} run_id={row['run_id'][:8]}…" + ) + else: + print(f" ⏭️ skipped job_id={j.get('job_id')}") + print(f"\nTotal registered: {len(registered)}") + + # Pick winners (lowest WAPE) per (store, product) on the LATEST cutoff + latest = max(r["data_window_end"] for r in registered) if registered else None + if latest: + by_pair: dict[tuple[str, str], list[dict[str, str]]] = defaultdict(list) + for r_ in registered: + if r_["data_window_end"] == latest: + by_pair[(r_["store_id"], r_["product_id"])].append(r_) + alias_specs = [ + ("champion", 0), + ("challenger", 1), + ] + print(f"\nAliasing for latest cutoff = {latest}") + for (sid, pid), rows in sorted(by_pair.items()): + rows.sort(key=lambda x: float(x["wape"])) + for alias_base, idx in alias_specs: + if idx >= len(rows): + continue + alias_name = f"{alias_base}-s{sid}-p{pid}" + body = { + "alias_name": alias_name, + "run_id": rows[idx]["run_id"], + "description": f"Auto: {alias_base} for store={sid} product={pid}", + } + r = await client.post("/registry/aliases", json=body) + if r.status_code >= 400: + print(f" ⚠️ alias {alias_name}: {r.status_code} {r.text[:100]}") + else: + print( + f" 🏷️ {alias_name} → {rows[idx]['model_type']} " + f"(wape={rows[idx]['wape']})" + ) + return 0 + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--base", default="http://localhost:8123") + return p.parse_args() + + +if __name__ == "__main__": + sys.exit(asyncio.run(main(_parse_args().base))) diff --git a/uv.lock b/uv.lock index 121735b5..2de76781 100644 --- a/uv.lock +++ b/uv.lock @@ -821,7 +821,7 @@ wheels = [ [[package]] name = "forecastlabai" 
-version = "0.2.18" +version = "0.2.19" source = { editable = "." } dependencies = [ { name = "alembic" }, From 1f36c7489801e4efe547468bcacc132101b1425b Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 26 May 2026 06:13:44 +0200 Subject: [PATCH 2/2] fix(data): address review feedback on seed_registry_from_jobs (#297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three corrections to register_one and fetch_completed_train_jobs: * pagination — `page * len(jobs) >= total` only matches the accumulated count when every page is full; on a partial last page it never fires, so the loop ended only after an extra empty-page request. Switch to accumulated-count + short-page detection (exit when len(jobs) < page_size or len(out) >= total). * model_path validation — empty / directory paths slipped through because Path("") resolves to cwd and Path.exists() returns True for directories. Require non-empty path and Path.is_file() for both the raw and cwd-relative candidates. * duplicate detection — `r.status_code >= 400` blanket-swallowed registry downtime and validation errors as idempotent skips. Narrow the skip to HTTP 409 (the actual DuplicateRunError code per registry/routes.py:113) and raise RuntimeError on other 4xx / 5xx with the response body for diagnostics. Python 3.12-only `def chunked[U](...)` syntax in seed_phase2_only.py is intentional — `pyproject.toml:6` already pins `requires-python = ">=3.12"`. 
--- scripts/seed_registry_from_jobs.py | 37 +++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/scripts/seed_registry_from_jobs.py b/scripts/seed_registry_from_jobs.py index 9e761f9d..02efd61b 100644 --- a/scripts/seed_registry_from_jobs.py +++ b/scripts/seed_registry_from_jobs.py @@ -72,18 +72,27 @@ def _model_config_payload(model_type: str) -> dict[str, object]: async def fetch_completed_train_jobs(client: httpx.AsyncClient) -> list[dict[str, object]]: """Fetch every completed train job through the public API.""" + page_size = 100 out: list[dict[str, object]] = [] page = 1 while True: r = await client.get( "/jobs", - params={"page": page, "page_size": 100, "job_type": "train", "status": "completed"}, + params={ + "page": page, + "page_size": page_size, + "job_type": "train", + "status": "completed", + }, ) r.raise_for_status() body = r.json() jobs = body.get("jobs") or [] out.extend(jobs) - if page * len(jobs) >= int(body.get("total", 0)) or not jobs: + total = int(body.get("total", 0)) + # Exit on empty page, short page (last page partially filled), or + # once accumulated count covers reported total. 
+ if not jobs or len(jobs) < page_size or len(out) >= total: break page += 1 return out @@ -99,11 +108,15 @@ async def register_one( model_type = str(params.get("model_type", "")) if model_type not in {"naive", "seasonal_naive", "moving_average"}: return None # only baselines for this backfill - source_path = Path(str(result.get("model_path", ""))) - if not source_path.exists(): - # try relative-to-cwd + model_path_raw = str(result.get("model_path") or "").strip() + if not model_path_raw: + # job result didn't carry a path — nothing to backfill + return None + source_path = Path(model_path_raw) + if not source_path.is_file(): + # try relative-to-cwd; reject if the candidate is missing or a directory rel = Path.cwd() / source_path - if rel.exists(): + if rel.is_file(): source_path = rel else: return None @@ -131,9 +144,17 @@ async def register_one( "git_sha": None, }, ) - if r.status_code >= 400: - # duplicate config_hash → idempotent skip + if r.status_code == 409: + # duplicate config_hash with registry_duplicate_policy="deny" → idempotent skip return None + if r.status_code >= 400: + # surface unexpected 4xx / 5xx so registry downtime or validation errors + # aren't silently swallowed as duplicates + try: + detail: object = r.json() + except ValueError: + detail = r.text + raise RuntimeError(f"POST /registry/runs failed (status {r.status_code}): {detail!r}") run_id = str(r.json().get("run_id")) # (b) running