diff --git a/.github/workflows/daily_pytest_slack.yml b/.github/workflows/daily_pytest_slack.yml new file mode 100644 index 000000000..d8285bf6b --- /dev/null +++ b/.github/workflows/daily_pytest_slack.yml @@ -0,0 +1,83 @@ +name: Daily Pytest + Slack (IL 01:00) + +on: + schedule: + # 01:00 Israel time — 22:00 UTC (summer), 23:00 UTC (winter) + - cron: "0 22 * * *" + - cron: "0 23 * * *" + workflow_dispatch: + +jobs: + run_pytests_and_notify: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Run pytest (and keep log) + run: | + pytest -q --maxfail=50 --disable-warnings -rA \ + --junitxml=pytest-report.xml > pytest.log 2>&1 || true + + - name: Parse results + id: results + run: | + python - <<'PY' + import xml.etree.ElementTree as ET + import os + counts = dict(tests=0, failures=0, errors=0, skipped=0) + try: + tree = ET.parse("pytest-report.xml") + root = tree.getroot() + for suite in root.findall(".//testsuite"): + counts["tests"] += int(suite.attrib.get("tests", 0)) + counts["failures"] += int(suite.attrib.get("failures", 0)) + counts["errors"] += int(suite.attrib.get("errors", 0)) + counts["skipped"] += int(suite.attrib.get("skipped", 0)) + except Exception as e: + print("Parse error:", e) + counts["passed"] = counts["tests"] - counts["failures"] - counts["errors"] - counts["skipped"] + with open(os.environ["GITHUB_OUTPUT"], "a") as f: + for k,v in counts.items(): + f.write(f"{k}={v}\n") + f.write(f"has_failures={'true' if (counts['failures']>0 or counts['errors']>0) else 'false'}\n") + PY + + - name: Send Slack notification (if failures) + if: steps.results.outputs.has_failures == 'true' + uses: slackapi/slack-github-action@v1.25.0 + with: + payload: | + { + "channel": "#vast", + "username": "GitHub Actions", + "icon_emoji": ":rotating_light:", + "text": "🚨 *Pytest Failures Detected!*\n\nRepository: ${{ github.repository }}\nRun: <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}>\n\n*Passed:* ${{ steps.results.outputs.passed }} / ${{ steps.results.outputs.tests }}\n*Failed:* ${{ steps.results.outputs.failures }}\n*Errors:* ${{ steps.results.outputs.errors }}\n*Skipped:* ${{ steps.results.outputs.skipped }}" + } + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + + - name: Send Slack success notification + if: steps.results.outputs.has_failures == 'false' + uses: slackapi/slack-github-action@v1.25.0 + with: + payload: | + { + "channel": "#vast", + "username": "GitHub Actions", + "icon_emoji": ":white_check_mark:", + "text": "✅ All tests passed successfully!\n\nRepository: ${{ github.repository }}\nRun: <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}>\n\nTotal tests: ${{ steps.results.outputs.tests }}" + } + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} diff --git a/.github/workflows/soak.yaml b/.github/workflows/soak.yaml index ef38a55c4..b3c02dd11 100644 --- a/.github/workflows/soak.yaml +++ b/.github/workflows/soak.yaml @@ -48,9 +48,19 @@ jobs: 'REFRESH_TTL_DAYS=14' \ 'DEV_SA_NAME=ci-service' \ > services/db_api_service/.env - - - + - name: Prepare env for plant_stress + run: | + mkdir -p services/plant_stress + cat > services/plant_stress/.env <<'EOF' + ADDR=0.0.0.0 + PORT=8001 + MINIO_ENDPOINT=minio:9000 + MINIO_ACCESS_KEY=minioadmin + MINIO_SECRET_KEY=minioadmin123 + MINIO_BUCKET=audio + MINIO_PREFIX=samples/ + WINDOW_MIN=5 + EOF - name: Start core stack run: docker compose up -d kafka mosquitto connect diff --git a/.gitignore b/.gitignore index fe188cd93..5379d1f7c 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ __pycache__/ *.pytest_cache/ .venv/ venv/ +.coverage # --- VSCode / Editor --- .vscode/ @@ -35,5 +36,4 @@ venv/ # --- OS files --- .DS_Store -Thumbs.db - +Thumbs.db \ No newline at end of file diff --git a/GUI/requirements.txt b/GUI/requirements.txt index bd9efb1b7..96080bc7a 100644 --- a/GUI/requirements.txt +++ b/GUI/requirements.txt @@ -1,18 +1,37 @@ -PyQt6==6.9.1 -PyQt6-WebEngine==6.9.0 -# Web/API +PyQt6==6.7.1 +PyQt6-WebEngine==6.7.0 +python-vlc + + +# ───── Web/API ───── fastapi>=0.110 uvicorn[standard]>=0.29 flask -# Metrics & HTTP + +# ───── Metrics & HTTP ───── prometheus-client>=0.20 requests>=2.31 httpx==0.27.0 # only needed when you switch to real Flink REST -# gRPC & Protobuf + +# ───── gRPC & Protobuf ───── grpcio>=1.56,<2 grpcio-tools>=1.56,<2 protobuf>=6,<7 -# Validation / crypto + +# ───── Validation / crypto / auth ───── pydantic>=2.9,<3 argon2-cffi +PyJWT>=2.9.0 + +# ───── Geospatial / Math ───── +shapely + +# ───── Async / misc ───── +aiohttp +plotly +shapely +PyJWT>=2.9.0 +sip + + diff --git a/GUI/src/vast/alerts/alert_client.py b/GUI/src/vast/alerts/alert_client.py new file mode 100644 index 000000000..7212a10c8 --- /dev/null +++ b/GUI/src/vast/alerts/alert_client.py @@ -0,0 +1,57 @@ + +from PyQt6.QtCore import QObject, pyqtSignal, QUrl, QTimer +from PyQt6.QtWebSockets import QWebSocket +from PyQt6.QtNetwork import QAbstractSocket # ✅ add this +import json + + +class AlertClient(QObject): + """ + Connects to the alerts WebSocket gateway and emits signals + when new alerts or snapshots arrive. + """ + snapshotReceived = pyqtSignal(list) + alertReceived = pyqtSignal(dict) + connectionLost = pyqtSignal() + + def __init__(self, ws_url: str, parent=None): + super().__init__(parent) + self.url = QUrl(ws_url) + self.socket = QWebSocket() + self.socket.connected.connect(self._on_connected) + self.socket.disconnected.connect(self._on_disconnected) + self.socket.textMessageReceived.connect(self._on_message) + self.reconnect_timer = QTimer() + self.reconnect_timer.timeout.connect(self._try_reconnect) + self.reconnect_interval_ms = 5000 # retry every 5s + self._connect() + + def _connect(self): + print(f"[AlertClient] Connecting to {self.url.toString()}") + self.socket.open(self.url) + + def _try_reconnect(self): + # ✅ Use QAbstractSocket.SocketState instead of QWebSocket.SocketState + if self.socket.state() == QAbstractSocket.SocketState.ConnectedState: + self.reconnect_timer.stop() + return + print("[AlertClient] Attempting reconnect...") + self._connect() + + + def _on_connected(self): + print("[AlertClient] Connected to alerts gateway.") + self.reconnect_timer.stop() + + def _on_disconnected(self): + print("[AlertClient] Disconnected from alerts gateway.") + self.connectionLost.emit() + self.reconnect_timer.start(self.reconnect_interval_ms) + + def _on_message(self, msg: str): + try: + payload = json.loads(msg) + if payload["type"] == "alert": + self.alertReceived.emit(payload["data"]) + except Exception as e: + print("[AlertClient] Invalid message:", e, msg) diff --git a/GUI/src/vast/alerts/alert_service.py b/GUI/src/vast/alerts/alert_service.py new file mode 100644 index 000000000..c6c28e50f --- /dev/null +++ b/GUI/src/vast/alerts/alert_service.py @@ -0,0 +1,209 @@ +import yaml +from string import Template +from PyQt6.QtCore import QObject, pyqtSignal +from vast.alerts.alert_client import AlertClient +from concurrent.futures import ThreadPoolExecutor + +class AlertService(QObject): + alertsUpdated = pyqtSignal(list) + alertAdded = pyqtSignal(dict) + alertRemoved = pyqtSignal(str) + + def __init__(self, ws_url, api, templates_path="/app/templates/templates.yml"): + super().__init__() + self.api = api + self.device_locations = {} + self.templates = self._load_templates(templates_path) + self.load_devices() + + self.client = AlertClient(ws_url) + self.client.alertReceived.connect(self._on_realtime) + + self.alerts = [] + + # ──────────────────────────────── + # Load YAML templates + # ──────────────────────────────── + def _load_templates(self, path): + try: + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + print(f"[AlertService] Loaded templates from {path}") + return data.get("templates", {}) + except Exception as e: + print("[AlertService] Failed to load templates:", e) + return {} + + # ──────────────────────────────── + # Fetch devices from DB + # ──────────────────────────────── + def load_devices(self): + try: + url = f"{self.api.base}/api/tables/devices" + r = self.api.http.get(url, timeout=10) + r.raise_for_status() + data = r.json() + devices = data.get("rows", data) + + self.device_locations = { + d["device_id"]: (d.get("location_lat"), d.get("location_lon")) + for d in devices if d.get("device_id") + } + print(f"[AlertService] Cached {len(self.device_locations)} device locations.") + except Exception as e: + print("[AlertService] Failed to fetch devices:", e) + + # ──────────────────────────────── + # Fetch alerts from DB and enrich with templates + # ──────────────────────────────── + def load_initial(self): + try: + url = f"{self.api.base}/api/tables/alerts" + r = self.api.http.get(url, timeout=10) + r.raise_for_status() + data = r.json() + alerts = data.get("rows", data) + + for a in alerts: + device_id = a.get("device_id") + alert_type = a.get("alert_type") + + # Add lat/lon if missing + if device_id in self.device_locations: + lat, lon = self.device_locations[device_id] + if not a.get("lat") and lat: + a["lat"] = lat + if not a.get("lon") and lon: + a["lon"] = lon + + # Apply template enrichment + tmpl = self.templates.get(alert_type) + if tmpl: + a["category"] = tmpl.get("category") + context = { + "device_id": device_id, + "area": a.get("area", "unknown area"), + "confidence": a.get("confidence", "?"), + "timestamp": a.get("started_at", ""), + } + # Use Template.safe_substitute to avoid KeyErrors + a["summary"] = Template(tmpl.get("summary", "")).safe_substitute(context) + a["recommendation"] = Template(tmpl.get("recommendation", "")).safe_substitute(context) + + self.alerts = alerts + self.alertsUpdated.emit(self.alerts) + print(f"[AlertService] Loaded {len(alerts)} enriched alerts.") + except Exception as e: + print("[AlertService] Failed to fetch alerts:", e) + + # ──────────────────────────────── + # Handle incoming WebSocket alerts + # ──────────────────────────────── + def _on_realtime(self, alert_msg): + alerts = alert_msg.get("alerts", []) + print("[AlertService] Realtime message:", alert_msg) + + for a in alerts: + labels = a.get("labels", {}) + ann = a.get("annotations", {}) + alert_id = labels.get("alert_id") + device_id = labels.get("device") + alert_type = labels.get("alertname") + ends_at = a.get("endsAt") + is_resolved = ends_at and not ends_at.startswith("0001-01-01") + + # Find existing alert in memory + existing = next((al for al in self.alerts if al.get("alert_id") == alert_id), None) + + if is_resolved: + # ✅ Don't delete — update existing alert with endedAt timestamp + if existing: + existing["endedAt"] = ends_at + self.alertRemoved.emit(alert_id) + else: + # If not in memory (e.g. loaded from DB earlier) + # create a minimal record so the UI can update + fake_alert = {"alert_id": alert_id, "endedAt": ends_at} + self.alerts.append(fake_alert) + self.alertRemoved.emit(alert_id) + continue + + # ──────────────────────────────── + # ACTIVE alert (new or ongoing) + # ──────────────────────────────── + lat = ann.get("lat") + lon = ann.get("lon") + + # Fill missing coordinates + if (not lat or not lon) and device_id in self.device_locations: + lat, lon = self.device_locations[device_id] + print(f"[AlertService] Filled missing coords for {device_id}: ({lat}, {lon})") + + # Enrich with template + tmpl = self.templates.get(alert_type, {}) + summary = Template(tmpl.get("summary", "")).safe_substitute( + device_id=device_id, + area=ann.get("area", ""), + confidence=ann.get("confidence", ""), + ) + recommendation = tmpl.get("recommendation", "") + category = tmpl.get("category") + + normalized = { + "alert_id": alert_id, + "alert_type": alert_type, + "device_id": device_id, + "lat": lat, + "lon": lon, + "severity": int(ann.get("severity", 1)), + "summary": summary, + "recommendation": recommendation, + "category": category, + "hls": ann.get("hls"), + "vod": ann.get("vod"), + "image_url": ann.get("image_url"), + "startsAt": a.get("startsAt"), + } + + # Update if it already exists, else append + if existing: + existing.update(normalized) + else: + self.alerts.append(normalized) + + self.alertAdded.emit(normalized) + + + def mark_all_acknowledged(self): + """Mark all alerts as acknowledged both locally and in DB (PATCH /api/tables/alerts).""" + unacked = [a for a in self.alerts if not a.get("ack", False)] + if not unacked: + return + + # Update local memory first + for a in unacked: + a["ack"] = True + + # Push updates asynchronously to DB + def _patch_ack(alert): + try: + url = f"{self.api.base}/api/tables/alerts" + payload = { + "keys": {"alert_id": alert["alert_id"]}, + "data": {"ack": True}, + } + r = self.api.http.patch(url, json=payload, timeout=5) + r.raise_for_status() + except Exception as e: + print(f"[AlertService] Failed to PATCH ack for {alert['alert_id']}: {e}") + + with ThreadPoolExecutor(max_workers=4) as pool: + for a in unacked: + pool.submit(_patch_ack, a) + + self.alertsUpdated.emit(self.alerts) + print(f"[AlertService] Marked {len(unacked)} alerts as acknowledged.") + + + + diff --git a/GUI/src/vast/dashboard_api.py b/GUI/src/vast/dashboard_api.py index bf73c37d0..00a3b2de8 100644 --- a/GUI/src/vast/dashboard_api.py +++ b/GUI/src/vast/dashboard_api.py @@ -1,55 +1,97 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations +import os import json import time +import base64 import pathlib -from typing import Dict, List +from typing import Dict, List, Optional, Tuple, Union + import requests -from urllib.parse import quote from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry +# ---- Optional deps (do not crash if missing) ---- +try: + from minio import Minio + from minio.error import S3Error +except Exception: # pragma: no cover + Minio = None # type: ignore + S3Error = Exception # type: ignore + +try: + from vast.rel_db import RelDB +except Exception: # pragma: no cover + RelDB = None # type: ignore + + +# ========================= +# CONFIG +# ========================= +# --- HTTP API --- +DB_API_BASE = os.getenv("DB_API_BASE", "http://db_api_service:8001") +DB_API_AUTH_MODE = os.getenv("DB_API_AUTH_MODE", "service") # "service" | "bearer" +DB_API_TOKEN_FILE = os.getenv("DB_API_TOKEN_FILE", "/app/secrets/db_api_token") +DB_API_TOKEN = os.getenv("DB_API_TOKEN", "auto") +DB_API_SERVICE_NAME = os.getenv("DB_API_SERVICE_NAME", "GUI_H") + +# --- RelDB (used inside RelDB class; here only for reference/env) --- +DB_HOST = os.getenv("DB_HOST", "127.0.0.1") +DB_PORT = int(os.getenv("DB_PORT", "5432")) +DB_USER = os.getenv("DB_USER", "missions_user") +DB_PASS = os.getenv("DB_PASS", "pg123") +DB_NAME = os.getenv("DB_NAME", "missions_db") + +# --- MinIO --- +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "127.0.0.1:9001") # host:exposed_port +MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "minioadmin") +MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "minioadmin") +MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true" -# ---------- CONFIG ---------- -DB_API_BASE = "http://host.docker.internal:8001" -DB_API_AUTH_MODE = "service" -DB_API_TOKEN_FILE = "/app/secrets/db_api_token" -DB_API_TOKEN = "auto" -DB_API_SERVICE_NAME = "GUI_H" +DEFAULT_GROUND_BUCKET = os.getenv("GROUND_BUCKET", "ground") +DEFAULT_GROUND_PREFIX = os.getenv("GROUND_PREFIX", "") -# ---------- TOKEN BOOTSTRAP ---------- +# ========================= +# TOKEN BOOTSTRAP HELPERS +# ========================= def _safe_join_url(base: str, path: str) -> str: return f"{base.rstrip('/')}/{path.lstrip('/')}" -def _read_token_from_file(path: str) -> str | None: +def _read_token_from_file(path: str) -> Optional[str]: p = pathlib.Path(path) if p.exists(): token = p.read_text(encoding="utf-8").strip() return token or None return None -def _fetch_token_via_dev_bootstrap(base: str, retries: int = 3, backoff: float = 0.8) -> str | None: +def _fetch_token_via_dev_bootstrap(base: str, retries: int = 3, backoff: float = 0.8) -> Optional[str]: + """ + Calls /auth/_dev_bootstrap to mint/rotate a service token for this client. + """ url = _safe_join_url(base, "/auth/_dev_bootstrap") payload = {"service_name": DB_API_SERVICE_NAME, "rotate_if_exists": True} + last_exc: Optional[Exception] = None for attempt in range(1, retries + 1): try: r = requests.post(url, json=payload, timeout=10) if r.status_code in (200, 201): - data = r.json() + data = r.json() if r.content else {} raw = (data.get("service_account", {}) or {}).get("raw_token") \ - or (data.get("service_account", {}) or {}).get("token") + or (data.get("service_account", {}) or {}).get("token") if raw and isinstance(raw, str) and "***" not in raw: return raw.strip() - except Exception: - time.sleep(backoff * attempt) + except Exception as e: + last_exc = e + time.sleep(backoff * attempt) + if last_exc: + print(f"[BOOTSTRAP][WARN] last error: {last_exc}") return None - -def get_or_bootstrap_token() -> str | None: - print(f"[DEBUG] Checking for existing token file at: {DB_API_TOKEN_FILE}", flush=True) - +def get_or_bootstrap_token() -> Optional[str]: if DB_API_TOKEN and DB_API_TOKEN.lower() != "auto": - print(f"[DEBUG] Using static token from config", flush=True) + print("[DEBUG] Using static token from DB_API_TOKEN", flush=True) return DB_API_TOKEN token = _read_token_from_file(DB_API_TOKEN_FILE) @@ -57,11 +99,12 @@ def get_or_bootstrap_token() -> str | None: print(f"[DEBUG] Loaded token from {DB_API_TOKEN_FILE}", flush=True) return token - print(f"[DEBUG] No existing token found, bootstrapping via {DB_API_BASE}/auth/_dev_bootstrap", flush=True) + print(f"[DEBUG] No token found, bootstrapping via {DB_API_BASE}/auth/_dev_bootstrap", flush=True) token = _fetch_token_via_dev_bootstrap(DB_API_BASE) if token: - pathlib.Path(DB_API_TOKEN_FILE).parent.mkdir(parents=True, exist_ok=True) - pathlib.Path(DB_API_TOKEN_FILE).write_text(token, encoding="utf-8") + p = pathlib.Path(DB_API_TOKEN_FILE) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(token, encoding="utf-8") print(f"[BOOTSTRAP] wrote token to {DB_API_TOKEN_FILE}", flush=True) return token @@ -69,68 +112,333 @@ def get_or_bootstrap_token() -> str | None: return None +# ========================= +# UTILITIES +# ========================= +def _image_id_from_object_key(object_key: str) -> str: + """ + 'some/prefix/image (3).jpg' -> 'image (3)' + """ + base = os.path.basename(object_key or "") + return base.rsplit(".", 1)[0] if "." in base else base -# ---------- API CLIENT ---------- +# ========================= +# DASHBOARD API +# ========================= class DashboardApi: - def __init__(self): + """ + Unified client: + - REST to DB-API (with token bootstrap/refresh) + - Optional MinIO helper + - Optional RelDB helper + """ + + def __init__(self) -> None: + # ---- HTTP session ---- self.base = DB_API_BASE.rstrip("/") self.http = requests.Session() + + # Attach robust retries + retry = Retry( + total=5, + backoff_factor=0.5, + status_forcelist=[500, 502, 503, 504], + allowed_methods=frozenset(["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"]) + ) + self.http.mount("http://", HTTPAdapter(max_retries=retry)) + self.http.mount("https://", HTTPAdapter(max_retries=retry)) + self.http.headers.update({"Content-Type": "application/json"}) + + # ---- Auth ---- token = get_or_bootstrap_token() + self.token: Optional[str] = token + self.token_type = "service" if DB_API_AUTH_MODE == "service" else "bearer" + self._apply_auth_header(token) + + # ---- MinIO (optional) ---- + self.minio: Optional[Minio] = None + if Minio is not None: + try: + self.minio = Minio( + MINIO_ENDPOINT, + access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, + secure=MINIO_SECURE, + ) + except Exception as e: # pragma: no cover + print(f"[MINIO][INIT][WARN] {e}") + + # ---- RelDB (optional) ---- + self.rdb: Optional[RelDB] = None + if RelDB is not None: + try: + self.rdb = RelDB() + except Exception as e: # pragma: no cover + print(f"[RelDB][INIT][WARN] {e}") + + # --------------------------- + # Auth helpers + # --------------------------- + def _apply_auth_header(self, token: Optional[str]) -> None: + # Clean previous header variants + for h in ["X-Service-Token", "Authorization"]: + if h in self.http.headers: + del self.http.headers[h] if token: if DB_API_AUTH_MODE == "service": self.http.headers.update({"X-Service-Token": token}) else: self.http.headers.update({"Authorization": f"Bearer {token}"}) - self.http.headers.update({"Content-Type": "application/json"}) - self.http.mount("http://", HTTPAdapter(max_retries=Retry(total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]))) - self.http.mount("https://", HTTPAdapter(max_retries=Retry(total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]))) - # ---------- METHODS ---------- + def get_token_info(self) -> dict: + """ + Tries to decode JWT payload. If not a JWT, returns basic info. + """ + t = self.token + if not t: + return {"type": self.token_type, "status": "missing"} - def list_devices(self, model: str | None = None) -> list[dict]: - - url = f"{self.base}/api/devices" - if model: - url += f"?model={model}" - try: - r = self.http.get(url, timeout=10) - if r.status_code == 200: - return r.json() - print(f"[API ERROR] {r.status_code}: {r.text[:100]}") - except Exception as e: - print(f"[API FAIL] {e}") + if "." in t: + try: + payload_b64 = t.split(".")[1] + padded = payload_b64 + "=" * (-len(payload_b64) % 4) + data = json.loads(base64.urlsafe_b64decode(padded)) + exp = data.get("exp") + secs_left = exp - int(time.time()) if exp else None + return {"type": "jwt", "exp": exp, "secs_left": secs_left, "payload": data} + except Exception: + pass + return {"type": self.token_type, "token_length": len(t)} + + def refresh_token(self) -> bool: + """ + Fetches a new service token via dev bootstrap and updates headers + file. + """ + new_token = _fetch_token_via_dev_bootstrap(self.base) + if new_token: + try: + pathlib.Path(DB_API_TOKEN_FILE).parent.mkdir(parents=True, exist_ok=True) + pathlib.Path(DB_API_TOKEN_FILE).write_text(new_token, encoding="utf-8") + except Exception as e: + print(f"[TOKEN][WARN] Could not persist new token: {e}") + self.token = new_token + self._apply_auth_header(new_token) + print("[TOKEN] refreshed", flush=True) + return True + print("[TOKEN][ERROR] refresh failed", flush=True) + return False + + # --------------------------- + # REST: examples / utilities + # --------------------------- + def list_devices(self, model: Optional[str] = None) -> List[dict]: + """ + Tries modern path /api/devices; falls back to /api/tables/devices for older servers. + """ + paths = ["/api/devices", "/api/tables/devices"] + last_err: Optional[str] = None + for path in paths: + url = f"{self.base}{path}" + if model: + sep = "&" if "?" in url else "?" + url = f"{url}{sep}model={model}" + try: + r = self.http.get(url, timeout=10) + if r.status_code == 200: + try: + return r.json() + except Exception: + print("[API WARN] devices response is not JSON", flush=True) + return [] + if r.status_code in (404, 405): + last_err = f"http-{r.status_code}" + continue + print(f"[API ERROR] {r.status_code}: {r.text[:200]}") + return [] + except Exception as e: + last_err = str(e) + continue + if last_err: + print(f"[API FAIL] list_devices: {last_err}") return [] - # ---------- THRESHOLDS ---------- def bulk_set_task_thresholds_labeled( self, - mapping: dict[tuple[str, str], float] | list[dict], + mapping: Dict[Tuple[str, str], float] | List[dict], updated_by: str = "gui", ) -> dict: - - if isinstance(mapping, dict): - items = [ + """ + Unified + fallback: + 1) POST /api/task_thresholds/batch + 2) if 404/405 -> POST /api/thresholds/batch + Body shape is normalized to: {"task": str, "label": str, "threshold": float, "updated_by": str} + """ + items = ( + [ {"task": t, "label": l or "", "threshold": thr, "updated_by": updated_by} for (t, l), thr in mapping.items() ] - else: - items = mapping + if isinstance(mapping, dict) else mapping + ) - url = f"{self.base}/api/thresholds/batch" - try: - - r = self.http.post(url, json=items, timeout=20) - if r.status_code in (200, 201): - data = r.json() - # ודאי שמבנה ok/fail תואם + paths = ["/api/task_thresholds/batch", "/api/thresholds/batch"] + last_err: Optional[str] = None + for path in paths: + url = f"{self.base}{path}" + try: + r = self.http.post(url, json=items, timeout=20) + if r.status_code in (200, 201): + data = r.json() if r.content else {} + return {"ok": list(data.get("ok", [])), "fail": list(data.get("fail", []))} + if r.status_code in (404, 405): + last_err = f"http-{r.status_code}" + continue return { - "ok": list(data.get("ok", [])), - "fail": list(data.get("fail", [])), + "ok": [], + "fail": [[[i.get("task"), i.get("label","")], f"http-{r.status_code} {r.text[:200]}"] for i in items], } - return { - "ok": [], - "fail": [[ [i.get("task"), i.get("label","")], f"http-{r.status_code} {r.text[:200]}"] for i in items], - } + except Exception as e: + last_err = str(e) + continue + return {"ok": [], "fail": [[[i.get("task"), i.get("label","")], last_err or "unknown"] for i in items]} + + # --------------------------- + # MinIO helpers (optional) + # --------------------------- + def list_minio_objects(self, bucket: str, prefix: str = "", limit: int = 100) -> List[dict]: + """ + Returns: [{'key': 'path/file.jpg', 'size': int, 'last_modified': iso}, ...] + """ + if not self.minio: + print("[MINIO][WARN] MinIO client not available") + return [] + out: List[dict] = [] + try: + for i, obj in enumerate(self.minio.list_objects(bucket, prefix=prefix, recursive=True)): + if i >= limit: + break + lm = getattr(obj, "last_modified", None) + out.append({ + "key": getattr(obj, "object_name", None) or getattr(obj, "name", None), + "size": getattr(obj, "size", None), + "last_modified": lm.isoformat() if lm else None, + }) except Exception as e: - return {"ok": [], "fail": [[ [i.get("task"), i.get("label","")], str(e)] for i in items]} \ No newline at end of file + print(f"[MINIO LIST FAIL] {e}") + return out + + def get_latest_minio_key(self, bucket: str, prefix: str = "") -> Optional[str]: + objs = self.list_minio_objects(bucket, prefix=prefix, limit=200) + if not objs: + return None + objs_sorted = sorted(objs, key=lambda o: o.get("last_modified") or "", reverse=True) + key = objs_sorted[0].get("key") + return key if isinstance(key, str) and key.strip() else None + + def get_image_bytes_from_minio(self, key: str, bucket: Optional[str] = None) -> Optional[bytes]: + if not self.minio: + print("[MINIO][WARN] MinIO client not available") + return None + bucket_name = bucket or DEFAULT_GROUND_BUCKET + try: + response = self.minio.get_object(bucket_name, key) + data = response.read() + response.close() + response.release_conn() + print(f"[DEBUG] Got {len(data)} bytes from {bucket_name}/{key}") + return data + except Exception as e: + print(f"[MINIO GET FAIL] {e}") + return None + + # --------------------------- + # RelDB delegates (optional) + # --------------------------- + def _rdb_guard(self) -> bool: + if not self.rdb: + print("[RelDB][WARN] RelDB client not available") + return False + return True + + def get_weekly_phi(self) -> dict: + if not self._rdb_guard(): return {} + return self.rdb.get_weekly_phi() + + def get_latest_rows(self, limit: int = 20) -> List[dict]: + if not self._rdb_guard(): return [] + return self.rdb.get_latest_anomalies(limit=limit) + + def get_latest_detections(self, limit: int = 20) -> List[dict]: + if not self._rdb_guard(): return [] + return self.rdb.get_latest_anomalies(limit=limit) + + def get_rows_by_image(self, image_name: str, limit: int = 50) -> List[dict]: + """ + image_name is image_id without extension. + """ + if not self._rdb_guard(): return [] + return self.rdb.get_anomalies_by_image(image_name, limit=limit) + + def get_last_row_by_image(self, image_name: str) -> Optional[dict]: + if not self._rdb_guard(): return None + return self.rdb.get_last_anomaly_by_image(image_name) + + def get_rows_by_day(self, date_iso: str, limit: int = 1000) -> List[dict]: + if not self._rdb_guard(): return [] + return self.rdb.get_anomalies_by_day(date_iso, limit=limit) + + # --------------------------- + # Image-centric (MinIO→image_id→RelDB) + # --------------------------- + def get_latest_image_key(self) -> Optional[str]: + """ + Prefer the newest in MinIO; if none—fallback to DB (if available). + """ + key = None + if self.minio: + key = self.get_latest_minio_key(DEFAULT_GROUND_BUCKET, DEFAULT_GROUND_PREFIX) + if key: + return key + if self.rdb: + try: + return self.rdb.get_latest_image_key() + except Exception as e: + print(f"[RelDB][WARN] get_latest_image_key fallback failed: {e}") + return None + + def get_anomalies_for_image_key(self, object_key: str, limit: int = 50) -> List[dict]: + if not self._rdb_guard(): return [] + image_id = _image_id_from_object_key(object_key) + return self.rdb.get_anomalies_by_image(image_id, limit=limit) + + def get_anomalies_for_current_image(self, limit: int = 100) -> List[dict]: + if not self._rdb_guard(): return [] + key = self.get_latest_image_key() + if not key: + return [] + image_id = _image_id_from_object_key(key) + return self.rdb.get_anomalies_by_image(image_id, limit=limit) + + def get_last_anomaly_for_current_image(self) -> Optional[dict]: + if not self._rdb_guard(): return None + key = self.get_latest_image_key() + if not key: + return None + image_id = _image_id_from_object_key(key) + return self.rdb.get_last_anomaly_by_image(image_id) + + def get_phi_for_image(self, image_name_or_key: str) -> dict: + if not self._rdb_guard(): + return {"phi": None, "severity_avg": None, "density": None, "coverage": None, "trend": None} + image_id = _image_id_from_object_key(image_name_or_key) + return self.rdb.get_phi_for_image(image_id) + + def get_phi_for_current_image(self) -> dict: + if not self._rdb_guard(): + return {"phi": None, "severity_avg": None, "density": None, "coverage": None, "trend": None} + key = self.get_latest_image_key() + if not key: + return {"phi": None, "severity_avg": None, "density": None, "coverage": None, "trend": None} + image_id = _image_id_from_object_key(key) + return self.rdb.get_phi_for_image(image_id) diff --git a/GUI/src/vast/desktop/Dockerfile b/GUI/src/vast/desktop/Dockerfile index 10d413e67..ec0cc9972 100644 --- a/GUI/src/vast/desktop/Dockerfile +++ b/GUI/src/vast/desktop/Dockerfile @@ -1,6 +1,8 @@ FROM python:3.11-slim ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 WORKDIR /app + +# ───────── system dependencies ───────── RUN apt-get update && apt-get install -y --no-install-recommends \ libgl1 libegl1 libx11-6 libxcomposite1 libxext6 libxi6 libxtst6 libsm6 \ libxkbcommon0 libxkbcommon-x11-0 libxkbfile1 libxrender1 libxrandr2 \ @@ -10,24 +12,28 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libpango-1.0-0 libharfbuzz0b libatk1.0-0 libatk-bridge2.0-0 libnss3 \ libnspr4 libdbus-1-3 libkrb5-3 libgssapi-krb5-2 libasound2 libpulse0 \ fluxbox x11vnc xvfb wget net-tools python3-tk ca-certificates \ - procps iproute2 xauth git \ + procps iproute2 xauth git vlc libvlc5 libvlccore9 \ + fonts-dejavu-core fonts-noto-core fonts-noto-color-emoji\ && rm -rf /var/lib/apt/lists/* + +# (optional) minimal extra XCB deps for PyQt RUN apt-get update && apt-get install -y --no-install-recommends \ libxcb-xinerama0 libxcb-cursor0 libxcb-keysyms1 libxcb-render-util0 \ - libxcb-randr0 \ - && rm -rf /var/lib/apt/lists/* + libxcb-randr0 && rm -rf /var/lib/apt/lists/* +# ───────── optional CA certs ───────── COPY certs /app/certs RUN if [ -d ./certs ] && [ "$(ls ./certs/*.crt 2>/dev/null)" ]; then \ - echo "Configuring NetFree certificates..."; \ - cp ./certs/*.crt /usr/local/share/ca-certificates/; \ - update-ca-certificates; \ + echo "Configuring NetFree certificates..."; \ + cp ./certs/*.crt /usr/local/share/ca-certificates/; \ + update-ca-certificates; \ fi ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt ENV PIP_CERT=/etc/ssl/certs/ca-certificates.crt +# ───────── noVNC for remote GUI ───────── RUN mkdir -p /opt && \ wget --tries=3 --timeout=30 -O /tmp/novnc.tar.gz https://github.com/novnc/noVNC/archive/refs/tags/v1.4.0.tar.gz && \ tar xzf /tmp/novnc.tar.gz -C /opt && \ @@ -35,22 +41,44 @@ RUN mkdir -p /opt && \ rm /tmp/novnc.tar.gz && \ git clone --depth 1 https://github.com/novnc/websockify /opt/noVNC/utils/websockify +# ───────── Python deps ───────── COPY requirements.txt /app/requirements.txt RUN pip install --no-cache-dir -r requirements.txt +RUN pip install --no-cache-dir --upgrade pip \ + && pip install --no-cache-dir \ + "PyQt6==6.8.0" \ + "PyQt6-WebEngine==6.8.0" \ + "argon2-cffi" \ + "requests" \ + "numpy" \ + --extra-index-url https://pypi.org/simple \ + --prefer-binary \ + --break-system-packages \ + && pip show PyQt6 PyQt6-WebEngine argon2-cffi +RUN pip install plotly +RUN pip install PyJWT +# ───────── app setup ───────── RUN useradd -m -s /bin/bash appuser \ && mkdir -p /app /tmp/.X11-unix \ && chown -R appuser:appuser /app /tmp /opt/noVNC /var/tmp + RUN apt-get update && apt-get install -y --no-install-recommends gosu && rm -rf /var/lib/apt/lists/* + COPY src/vast /app/src/vast COPY src/vast/desktop/start.sh /app/start.sh -RUN sed -i 's/\r$//' /app/start.sh && \ - chmod +x /app/start.sh && \ - chown -R appuser:appuser /app -# RUN chmod +x /app/start.sh && chown -R appuser:appuser /app +RUN sed -i 's/\r$//' /app/start.sh && chmod +x /app/start.sh && chown -R appuser:appuser /app + RUN mkdir -p /app/secrets && chmod -R 777 /app/secrets + USER appuser -EXPOSE 5900 6080 +EXPOSE 5900 6080 ENV PYTHONPATH=/app/src:/app ENV DISPLAY=:0 ENV NO_VNC_PORT=6080 +ENV PORT=19100 +ENV MEDIA_BASE=http://media-proxy:8080 + CMD ["/app/start.sh"] + + + diff --git a/GUI/src/vast/desktop/start.sh b/GUI/src/vast/desktop/start.sh index cd4ca25f9..460dfcafe 100644 --- a/GUI/src/vast/desktop/start.sh +++ b/GUI/src/vast/desktop/start.sh @@ -22,3 +22,10 @@ echo "[INFO] Starting noVNC..." echo "[INFO] Starting PyQt application..." exec python /app/src/vast/main.py + + + +# # ------------------------------ +# # 🚀 Launch the main PyQt application +# # ------------------------------ +# exec /opt/venv/bin/python /app/src/vast/main.py diff --git a/GUI/src/vast/gateway/Dockerfile b/GUI/src/vast/gateway/Dockerfile index 27486d095..745b58831 100644 --- a/GUI/src/vast/gateway/Dockerfile +++ b/GUI/src/vast/gateway/Dockerfile @@ -3,7 +3,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 WORKDIR /app # build arg -ARG USE_NETFREE=true +# ARG USE_NETFREE=true RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && rm -rf /var/lib/apt/lists/* COPY certs /app/certs @@ -22,12 +22,12 @@ ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ # # System CA + add NetFree certs -# RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && rm -rf /var/lib/apt/lists/* -# COPY certs/*.crt /usr/local/share/ca-certificates/ -# RUN update-ca-certificates || true -# ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ -# REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ -# PIP_CERT=/etc/ssl/certs/ca-certificates.crt +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && rm -rf /var/lib/apt/lists/* +COPY certs/*.crt /usr/local/share/ca-certificates/ +RUN update-ca-certificates || true +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_CERT=/etc/ssl/certs/ca-certificates.crt # Python deps COPY requirements.txt /app/requirements.txt @@ -49,4 +49,4 @@ RUN python -m grpc_tools.protoc -I./vast/proto \ ENV PYTHONPATH=/app/vast/proto/generated:/app ENV RUNNER_ADDR=runner:50051 EXPOSE 8000 -CMD ["uvicorn", "vast.gateway.gateway_main:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "vast.gateway.gateway_main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/GUI/src/vast/home_view.py b/GUI/src/vast/home_view.py index c7a5a6cc6..ec2b41df2 100644 --- a/GUI/src/vast/home_view.py +++ b/GUI/src/vast/home_view.py @@ -1,28 +1,56 @@ from __future__ import annotations from PyQt6.QtWebEngineWidgets import QWebEngineView from PyQt6.QtCore import QUrl, pyqtSignal -from PyQt6.QtWidgets import QWidget, QGridLayout, QVBoxLayout, QLabel, QSizePolicy, QPushButton +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, + QSizePolicy, QPushButton +) from orthophoto_canvas.ui.viewer_factory import create_orthophoto_viewer from vast.orthophoto_canvas.ui.sensors_layer import SensorLayer, add_sensors_by_gps_bulk from orthophoto_canvas.ag_io import sensors_api import os +from vast.orthophoto_canvas.ui.alert_layer import AlertLayer + class HomeView(QWidget): openSensorsRequested = pyqtSignal() - def __init__(self, api, parent: QWidget | None = None): + def __init__(self, api, alert_service, parent: QWidget | None = None): super().__init__(parent) + self.api = api + self.alert_service = alert_service + # ───────────────────────────── + # Root vertical layout + # ───────────────────────────── root = QVBoxLayout(self) + root.setContentsMargins(12, 12, 12, 12) + root.setSpacing(10) + + # Header header = QLabel("Sensors Dashboard (Grafana)") - header.setStyleSheet("font-size: 20px; font-weight: 600;") + header.setStyleSheet("font-size: 20px; font-weight: 600; margin-bottom: 8px;") root.addWidget(header) - grid = QGridLayout() - grid.setHorizontalSpacing(12) - grid.setVerticalSpacing(12) - root.addLayout(grid) + # ───────────────────────────── + # Main content: Map (left) + Panels (right) + # ───────────────────────────── + main_layout = QHBoxLayout() + main_layout.setSpacing(12) + root.addLayout(main_layout, stretch=1) + + # ───── Map on the left ───── + tiles_root = "./src/vast/orthophoto_canvas/data/tiles" + self.viewer = create_orthophoto_viewer(tiles_root, forced_scheme=None, parent=self) + self.viewer.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + self.viewer.setMinimumSize(700, 700) # Ensures it's visibly large and square + main_layout.addWidget(self.viewer, stretch=3) + + # ───── Grafana panels on the right ───── + right_box = QVBoxLayout() + right_box.setSpacing(10) + main_layout.addLayout(right_box, stretch=2) grafana_host = os.getenv("GRAFANA_HOST", "grafana") base = f"http://{grafana_host}:3000" @@ -31,20 +59,20 @@ def __init__(self, api, parent: QWidget | None = None): QUrl(f"{base}/d-solo/agcloud-sensors/sensors?orgId=1&panelId=2&from=now-6h&to=now&refresh=10s&theme=light"), ] - for i, url in enumerate(panel_urls): + for url in panel_urls: view = QWebEngineView(self) - view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) + view.setFixedHeight(300) view.setUrl(url) - r, c = divmod(i, 2) - grid.addWidget(view, r, c) - - tiles_root = "./src/vast/orthophoto_canvas/data/tiles" - self.viewer = create_orthophoto_viewer(tiles_root, forced_scheme=None, parent=self) - grid.addWidget(self.viewer, 1, 0, 1, 2) + right_box.addWidget(view) + # ───────────────────────────── + # Load sensors layer + # ───────────────────────────── gateway_url = os.getenv("GATEWAY_URL", "http://gateway:8000") sensors_api.GATEWAY_URL = gateway_url rows = sensors_api.get_sensors() + self.sensor_layer = SensorLayer(self.viewer) add_sensors_by_gps_bulk( self.sensor_layer, @@ -53,6 +81,100 @@ def __init__(self, api, parent: QWidget | None = None): default_radius_px=0.2 ) + # ───────────────────────────── + # Alerts layer setup + # ───────────────────────────── + self.alert_layer = AlertLayer(self.viewer) + self.alert_service.alertsUpdated.connect(self._on_alerts_updated) + self.alert_service.alertAdded.connect(self._on_alert_added) + self.alert_service.alertRemoved.connect(self._on_alert_removed) + self.alert_service.load_initial() + + # ───────────────────────────── + # Footer button + # ───────────────────────────── self.sensor_types_btn = QPushButton("Sensor Types") + self.sensor_types_btn.setStyleSheet("padding: 8px 12px; font-weight: 500;") self.sensor_types_btn.clicked.connect(self.openSensorsRequested.emit) root.addWidget(self.sensor_types_btn) + + # ───────────────────────────── + # Keep the map square on resize + # ───────────────────────────── + def resizeEvent(self, event): + super().resizeEvent(event) + if self.viewer: + # Square size = min(available height, available width fraction) + left_width = int(self.width() * 0.6) + height = self.height() - 100 + size = min(left_width, height) + if size > 400: + self.viewer.setFixedSize(size-50, size-50) + + # ───────────────────────────── + # Alerts Handlers + # ───────────────────────────── + def _on_alerts_updated(self, alerts: list): + print(f"[HomeView] Full alert update: {len(alerts)} alerts") + + active_alerts = [a for a in alerts if not a.get("ended_at") and not a.get("endedAt")] + print(f"[HomeView] Displaying {len(active_alerts)} active alerts on map") + + self.alert_layer.clear_alerts() + for alert in active_alerts: + self.alert_layer.add_or_update_alert(alert) + + def _on_alert_added(self, alert: dict): + print(f"[HomeView] New alert added: {alert.get('alert_id')}") + self.alert_layer.add_or_update_alert(alert) + + def _on_alert_removed(self, alert_id: str): + print(f"[HomeView] Removing alert: {alert_id}") + self.alert_layer.remove_alert(alert_id) + + # ───────────────────────────── + # Real-time alert normalization + # ───────────────────────────── + def _on_alert_realtime(self, alert: dict): + alerts = alert.get("alerts", []) + if not alerts: + print("[HomeView] No alerts in payload.") + return + + for a in alerts: + labels = a.get("labels", {}) + ann = a.get("annotations", {}) + + normalized = { + "alert_id": labels.get("alert_id"), + "alert_type": labels.get("alertname"), + "device_id": labels.get("device"), + "lat": float(ann.get("lat")) if ann.get("lat") else None, + "lon": float(ann.get("lon")) if ann.get("lon") else None, + "severity": int(ann.get("severity", 1)), + "confidence": float(ann.get("confidence", 0)), + "area": ann.get("area"), + "summary": ann.get("summary"), + "category": ann.get("category"), + "recommendation": ann.get("recommendation"), + "meta": ann.get("meta"), + "startsAt": a.get("startsAt"), + "endsAt": a.get("endsAt"), + } + + alert_id = normalized.get("alert_id") + ended_at = normalized.get("endsAt") + is_resolved = ended_at and not ended_at.startswith("0001-01-01") + + if is_resolved: + print(f"[HomeView] Removing resolved alert: {alert_id}") + self.alert_layer.remove_alert(alert_id) + continue + + print(f"[HomeView] Active alert: {normalized['alert_type']} " + f"from {normalized['device_id']} ({normalized['lat']}, {normalized['lon']})") + self.alert_layer.add_or_update_alert(normalized) + + + + diff --git a/GUI/src/vast/main.py b/GUI/src/vast/main.py index 394e23859..f302d3b81 100644 --- a/GUI/src/vast/main.py +++ b/GUI/src/vast/main.py @@ -100,14 +100,14 @@ def main() -> int: print("[main] starting QApplication") app = QApplication(sys.argv) - # 1) show auth shell first + # 1) create the auth shell but do NOT show it shell = AuthShell() shell.setWindowTitle("Sign in") - shell.show() + # shell.show() # disabled to skip the login window # 2) when login succeeds -> open MainWindow def open_main(user): - api = DashboardApi() # pass user if needed + api = DashboardApi() # create API instance (user not required) win = MainWindow(api) # connect logout back to login @@ -116,7 +116,6 @@ def open_main(user): win.show() shell.hide() - def on_logout(win): win.close() shell.reset() @@ -125,6 +124,9 @@ def on_logout(win): # wire callback shell.on_login_success = open_main + # open the main window directly (skip login) + open_main(None) + print("[main] window shown, entering event loop") rc = app.exec() print(f"[main] event loop exited with code {rc}") diff --git a/GUI/src/vast/main_window.py b/GUI/src/vast/main_window.py index df8594822..f5e3fc80b 100644 --- a/GUI/src/vast/main_window.py +++ b/GUI/src/vast/main_window.py @@ -1,88 +1,343 @@ -from __future__ import annotations from PyQt6.QtCore import Qt, pyqtSignal, QSize from PyQt6.QtWidgets import ( - QMainWindow, QDockWidget, QListWidget, QListWidgetItem, QStatusBar, QStackedWidget, - QVBoxLayout, QWidget + QMainWindow, QDockWidget, QListWidget, QListWidgetItem, QStatusBar, + QStackedWidget, QToolButton, QLabel, QWidget, QHBoxLayout, QVBoxLayout, + QGraphicsDropShadowEffect, QPushButton ) -from PyQt6.QtGui import QAction, QIcon -from PyQt6.QtWebEngineWidgets import QWebEngineView -from PyQt6.QtCore import QUrl +from PyQt6.QtGui import QAction, QIcon, QFont, QColor +import os + from home_view import HomeView from views.sensors_view import SensorsView +from views.alerts_panel import AlertsPanel from views.notification_view import NotificationView -from dashboard_api import DashboardApi from views.fruits_view import FruitsView +from views.ground_view import GroundView +from views.auth_status_view import AuthStatusView +from dashboard_api import DashboardApi +from vast.alerts.alert_service import AlertService + +# === New Sensors GUI imports === +from views.sensorsMainView import SensorsMainView +from views.sensorsMapView import SensorsMapView +from views.sensorDetailsTab import SensorDetailsTab +from views.sensors_status_summary import SensorsStatusSummary + class MainWindow(QMainWindow): logoutRequested = pyqtSignal() def __init__(self, api: DashboardApi, parent=None): super().__init__(parent) - self.setWindowTitle("VAST – Dashboard") - self.resize(1100, 700) + self.setWindowTitle("AgCloud – Dashboard") + self.resize(1280, 760) self.api = api - # ---------- Menu ---------- + # ─────────────────────────────── + # GLOBAL STYLE + # ─────────────────────────────── + self.setStyleSheet(""" + QMainWindow { background-color: #f9fafb; } + QMenuBar { background-color: #e5e7eb; font-size: 11.5pt; padding: 4px 10px; } + QToolBar { + background: qlineargradient(x1:0, y1:0, x2:0, y2:1, stop:0 #ffffff, stop:1 #f3f4f6); + border-bottom: 1px solid #d1d5db; padding: 2px 10px; min-height: 42px; + } + QToolButton { background-color: transparent; border: none; padding: 4px; border-radius: 8px; font-size: 20px; } + QToolButton:hover { background-color: #e5e7eb; } + QListWidget { background-color: #ffffff; border: none; font-size: 12pt; color: #111827; } + QListWidget::item { padding: 10px; border-radius: 6px; } + QListWidget::item:selected { background-color: #10b981; color: white; } + QStatusBar { background-color: #f3f4f6; font-size: 10pt; } + """) + + # ─────────────────────────────── + # MENU + # ─────────────────────────────── file_menu = self.menuBar().addMenu("&File") - - # ---------- Dock navigation ---------- + self.back_action = QAction(QIcon.fromTheme("go-previous"), "Back", self) + self.back_action.setShortcut("Alt+Left") + self.back_action.triggered.connect(self.go_back) + file_menu.addAction(self.back_action) + self.logout_action = QAction("Log out", self) + self.logout_action.triggered.connect(self._logout) + file_menu.addAction(self.logout_action) + + # ─────────────────────────────── + # TOP BAR (toolbar) + # ─────────────────────────────── + toolbar = self.addToolBar("Main Toolbar") + toolbar.setMovable(False) + toolbar.setFloatable(False) + toolbar.setIconSize(QSize(32, 32)) + + top_bar = QWidget() + top_bar_layout = QHBoxLayout(top_bar) + top_bar_layout.setContentsMargins(8, 0, 8, 0) + top_bar_layout.setSpacing(10) + + # Logout button + logout_btn = QPushButton("Logout") + logout_btn.setToolTip("Log out") + logout_btn.setCursor(Qt.CursorShape.PointingHandCursor) + logout_btn.setStyleSheet(""" + QPushButton { + background-color: #10b981; + color: white; + border: none; + border-radius: 8px; + padding: 6px 16px; + font-size: 11pt; + font-weight: 600; + } + QPushButton:hover { background-color: #059669; } + QPushButton:pressed { background-color: #047857; } + """) + logout_btn.clicked.connect(self._logout) + + # Alert bell + self.alert_button = QToolButton() + self.alert_button.setToolTip("Show alerts") + self.alert_button.setText("🔔") + self.alert_button.setIconSize(QSize(40, 40)) + self.alert_button.setStyleSheet(""" + QToolButton { + font-size: 30px; + border: none; + background: transparent; + padding: 4px; + border-radius: 8px; + } + QToolButton:hover { background-color: #e5e7eb; } + """) + + # Alert badge + self.alert_badge = QLabel("0", self.alert_button) + self.alert_badge.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.alert_badge.setFixedSize(24, 24) + self.alert_badge.setStyleSheet(""" + QLabel { + background-color: #3b82f6; + color: white; + font-size: 10pt; + font-weight: bold; + border-radius: 12px; + border: 2px solid white; + } + """) + self.alert_badge.hide() + + def reposition_badge(): + btn_w = self.alert_button.width() + self.alert_badge.move(btn_w - 22, 2) + self.alert_badge.raise_() + + self.alert_button.resizeEvent = lambda e: ( + QToolButton.resizeEvent(self.alert_button, e), + reposition_badge() + ) + reposition_badge() + + # ─────────────────────────────── + # TITLE AREA (Updated) + # ─────────────────────────────── + title_container = QWidget() + title_layout = QVBoxLayout(title_container) + title_layout.setContentsMargins(0, 0, 0, 0) + title_layout.setSpacing(0) + + main_title = QLabel("AgCloud") + main_title.setAlignment(Qt.AlignmentFlag.AlignCenter) + main_title.setStyleSheet(""" + QLabel { + font-size: 22pt; + font-weight: 700; + color: #047857; + letter-spacing: 1px; + } + """) + + subtitle = QLabel("The Smart Platform that Protects and Optimizes Your Field") + subtitle.setAlignment(Qt.AlignmentFlag.AlignCenter) + subtitle.setStyleSheet(""" + QLabel { + font-size: 11pt; + font-weight: 500; + color: #374151; + margin-top: 2px; + } + """) + + title_layout.addWidget(main_title) + title_layout.addWidget(subtitle) + + shadow = QGraphicsDropShadowEffect() + shadow.setBlurRadius(8) + shadow.setColor(QColor(0, 0, 0, 35)) + shadow.setOffset(0, 2) + top_bar.setGraphicsEffect(shadow) + + top_bar_layout.addWidget(logout_btn) + top_bar_layout.addWidget(self.alert_button) + top_bar_layout.addStretch() + top_bar_layout.addWidget(title_container) + top_bar_layout.addStretch() + toolbar.addWidget(top_bar) + + # ─────────────────────────────── + # NAVIGATION + # ─────────────────────────────── self.nav_dock = QDockWidget("Navigation", self) self.nav_dock.setFeatures(QDockWidget.DockWidgetFeature.NoDockWidgetFeatures) self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea, self.nav_dock) - self.nav_list = QListWidget(self.nav_dock) self.nav_dock.setWidget(self.nav_list) + self.nav_dock.setMinimumWidth(220) + + font = QFont(); font.setPointSize(12) + self.nav_list.setFont(font) + + for main_item in ["Home", "Sensors", "Sound", "Ground Image", "Aerial Image", "Fruits", "Security", "Settings", "Notifications", "Auth"]: + item = QListWidgetItem(main_item) + item.setData(Qt.ItemDataRole.UserRole, {"type": "main"}) + self.nav_list.addItem(item) + if main_item == "Sensors": + for sub in ["Live Data", "Sensor Health", "Location Map"]: + sub_item = QListWidgetItem(f" ↳ {sub}") + sub_item.setData(Qt.ItemDataRole.UserRole, {"type": "sub", "parent": main_item, "name": sub}) + sub_item.setHidden(True) + self.nav_list.addItem(sub_item) - for name in [ - "Home", "Sensors", "Sound", "Ground Image", - "Aerial Image", "Fruits", "Security", "Settings", "Notifications" - ]: - QListWidgetItem(name, self.nav_list) - - self.nav_list.setCurrentRow(0) self.nav_list.currentRowChanged.connect(self._on_nav_change) + self.nav_list.itemClicked.connect(self._on_nav_click) + + # ─────────────────────────────── + # ALERT SERVICE + PANEL + # ─────────────────────────────── + ws_url = os.getenv("ALERTS_WS", "ws://alerts-gateway:8000/ws/alerts") + self.alert_service = AlertService(ws_url, api) + self.alert_service.alertsUpdated.connect(self.update_alert_badge) + self.alert_service.alertAdded.connect(lambda _: self.update_alert_badge()) - # ---------- Views ---------- - self.home = HomeView(api, self) + self.alerts_panel = AlertsPanel(self.alert_service) + self.alerts_panel.setWindowFlags(Qt.WindowType.FramelessWindowHint | Qt.WindowType.Tool) + self.alerts_panel.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground) + self.alerts_panel.setStyleSheet(""" + QWidget { + background-color: #ffffff; + border: 1px solid #d1d5db; + border-radius: 10px; + } + """) + self.alerts_panel.hide() + self.alert_button.clicked.connect(self.toggle_alert_panel) + + # ─────────────────────────────── + # CENTRAL STACKED VIEWS + # ─────────────────────────────── + self.home = HomeView(api, self.alert_service, self) self.sensors_view = SensorsView(api, self) self.notification_view = NotificationView(self) - self.fruits_view = FruitsView(api,self) + self.fruits_view = FruitsView(api, self) + self.ground_view = GroundView(api, self) + self.auth_status = AuthStatusView(api, self) + + self.sensors_status_summary = SensorsStatusSummary(api, self) + self.sensors_health = SensorsView(api, self) + self.sensors_main = SensorsMainView(api, self) - # Stack for switching between views self.stack = QStackedWidget() self.setCentralWidget(self.stack) - self.views = { "Home": self.home, "Sensors": self.sensors_view, + "Sensors - Live Data": self.sensors_status_summary, + "Sensors - Sensor Health": self.sensors_health, + "Sensors - Location Map": self.sensors_main, "Notifications": self.notification_view, - "Fruits": self.fruits_view + "Fruits": self.fruits_view, + "Ground Image": self.ground_view, + "Auth": self.auth_status } - + for view in self.views.values(): self.stack.addWidget(view) - self.stack.setCurrentWidget(self.home) - - # ---------- History for Back ---------- self.history = [] - # ---------- Status bar ---------- + # ─────────────────────────────── + # STATUS BAR + # ─────────────────────────────── sb = QStatusBar(self) + sb.setStyleSheet("QStatusBar { background-color: #f3f4f6; color: #374151; font-size: 10.5pt; }") self.setStatusBar(sb) sb.showMessage("Ready") + # ─────────────────────────────── + # ALERT BADGE + # ─────────────────────────────── + def update_alert_badge(self): + unacked = sum(1 for a in self.alert_service.alerts if not a.get("ack", False)) + if unacked > 0: + self.alert_badge.setText(str(unacked)) + self.alert_badge.show() + else: + self.alert_badge.hide() + + def toggle_alert_panel(self): + if self.alerts_panel.isVisible(): + self.alerts_panel.hide() + return + + panel_width, panel_height = 420, 540 + self.alerts_panel.resize(panel_width, panel_height) + rect = self.alert_button.geometry() + bottom_left = self.alert_button.mapToGlobal(rect.bottomLeft()) + bottom_right = self.alert_button.mapToGlobal(rect.bottomRight()) + center_x = (bottom_left.x() + bottom_right.x()) // 2 - (panel_width // 2) + pos_y = bottom_left.y() + 8 + self.alerts_panel.move(center_x, pos_y) + self.alerts_panel.show() + self.alerts_panel.raise_() + + if hasattr(self.alert_service, "mark_all_acknowledged"): + self.alert_service.mark_all_acknowledged() + self.update_alert_badge() + + # ─────────────────────────────── + # NAVIGATION + # ─────────────────────────────── def _on_nav_change(self, row: int) -> None: - name = self.nav_list.item(row).text() - print(f"[MainWindow] Navigation changed to: {name}") - + name = self.nav_list.item(row).text().strip() if name in self.views: self.navigate_to(self.views[name]) else: self.statusBar().showMessage(f"Section '{name}' not implemented yet.") + def _on_nav_click(self, item): + data = item.data(Qt.ItemDataRole.UserRole) + if data and data.get("type") == "main": + parent = item.text() + expanded = False + for i in range(self.nav_list.count()): + sub_item = self.nav_list.item(i) + sub_data = sub_item.data(Qt.ItemDataRole.UserRole) + if sub_data and sub_data.get("type") == "sub" and sub_data.get("parent") == parent: + expanded = sub_item.isHidden() + break + for i in range(self.nav_list.count()): + sub_item = self.nav_list.item(i) + sub_data = sub_item.data(Qt.ItemDataRole.UserRole) + if sub_data and sub_data.get("type") == "sub" and sub_data.get("parent") == parent: + sub_item.setHidden(not expanded) + elif data and data.get("type") == "sub": + parent = data.get("parent") + sub_name = data.get("name") + key = f"{parent} - {sub_name}" + if key in self.views: + self.stack.setCurrentWidget(self.views[key]) + def navigate_to(self, widget): - print(f"[MainWindow] Navigating to widget: {widget.__class__.__name__}") current = self.stack.currentWidget() if current not in self.history: self.history.append(current) @@ -97,4 +352,4 @@ def go_back(self): def _logout(self) -> None: self.statusBar().showMessage("Logged out (demo)") - self.logoutRequested.emit() + self.logoutRequested.emit() \ No newline at end of file diff --git a/GUI/src/vast/orthophoto_canvas/ui/alert_layer.py b/GUI/src/vast/orthophoto_canvas/ui/alert_layer.py new file mode 100644 index 000000000..defc7612e --- /dev/null +++ b/GUI/src/vast/orthophoto_canvas/ui/alert_layer.py @@ -0,0 +1,203 @@ +from PyQt6.QtWidgets import ( + QGraphicsTextItem, QLabel, QVBoxLayout, QWidget, QGraphicsDropShadowEffect +) +from PyQt6.QtCore import Qt, QPoint +from PyQt6.QtGui import QColor, QFont +from src.vast.orthophoto_canvas.ui.sensors_layer import _latlon_to_xy_at_max_zoom, TILE_SIZE + + +# ───────────────────────────────────────────── +# Frameless Popup Widget +# ───────────────────────────────────────────── +class AlertPopupWidget(QWidget): + """Frameless popup with rounded corners, colored border, and drop shadow.""" + + def __init__(self, html: str, border_color: str = "#444", parent=None): + super().__init__(parent) + self.setWindowFlags(Qt.WindowType.ToolTip | Qt.WindowType.FramelessWindowHint) + self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + label = QLabel() + label.setTextFormat(Qt.TextFormat.RichText) + label.setText(html) + label.setWordWrap(True) + label.setAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + label.setStyleSheet(f""" + QLabel {{ + background-color: #ffffff; + border: 2px solid {border_color}; + border-radius: 12px; + padding: 10px 12px; + font-family: 'Segoe UI', 'Roboto', 'Helvetica Neue', sans-serif; + font-size: 12px; + color: #111; + }} + """) + layout.addWidget(label) + + shadow = QGraphicsDropShadowEffect(self) + shadow.setBlurRadius(18) + shadow.setOffset(0, 4) + shadow.setColor(QColor(0, 0, 0, 70)) + self.setGraphicsEffect(shadow) + + self.adjustSize() + + def show_near(self, global_pos: QPoint): + """Show popup slightly above and to the right of the marker.""" + self.adjustSize() + self.move(global_pos + QPoint(12, -self.height() - 12)) + self.show() + + +# ───────────────────────────────────────────── +# Marker Item +# ───────────────────────────────────────────── +class _AlertMarker(QGraphicsTextItem): + """A single alert marker (emoji icon) that shows a modern popup on hover.""" + + def __init__(self, alert_id, alert_data, *args, **kwargs): + severity = int(alert_data.get("severity", 1)) + icon = {1: "⚠️", 2: "🚨"}.get(severity, "🚨") + super().__init__(icon, *args, **kwargs) + + self.alert_id = alert_id + self.alert_data = alert_data + self._popup = None + + self.setZValue(1_000_000) + self.setFont(QFont("Noto Color Emoji", 12)) + self.setDefaultTextColor(QColor("#222")) + self.setFlag(QGraphicsTextItem.GraphicsItemFlag.ItemIgnoresTransformations, True) + self.setAcceptHoverEvents(True) + + def hoverEnterEvent(self, event): + alert = self.alert_data + severity = int(alert.get("severity", 1)) + alert_type = alert.get("alert_type", "Alert").replace("_", " ") + device_id = alert.get("device_id", "unknown") + summary = alert.get("summary") or "No additional details." + started_at = alert.get("startsAt", "") + + border_color = {1: "#f1c232", 2: "#f39c12", 3: "#e67e22", + 4: "#cc0000", 5: "#8b0000"}.get(severity, "#999") + + tooltip_html = f""" +
+
+ {self.toPlainText()} + {alert_type.capitalize()} detected +
+
+
+ 💬 + {summary} +
+ {f'
🕒 {started_at}
' if started_at else ''} +
+ """ + + view = self.scene().views()[0] if self.scene().views() else None + if view: + global_pos = view.mapToGlobal(view.mapFromScene(self.scenePos())) + self._popup = AlertPopupWidget(tooltip_html, border_color=border_color) + self._popup.show_near(global_pos) + + super().hoverEnterEvent(event) + + def hoverLeaveEvent(self, event): + if self._popup: + self._popup.close() + self._popup = None + super().hoverLeaveEvent(event) + + +# ───────────────────────────────────────────── +# Alert Layer +# ───────────────────────────────────────────── +class AlertLayer: + """Draws alert markers on the orthophoto scene, consistent with RegionLayer projection.""" + + def __init__(self, viewer): + self.viewer = viewer + self.scene = viewer.scene + self.alerts = {} + + # Use same base tile coordinates as RegionLayer (max zoom) + z = viewer.max_zoom_fs + self._x_min_base = viewer.ts.z_ranges[z][0] + self._y_min_base = viewer.ts.z_ranges[z][2] + + def add_or_update_alert(self, alert: dict): + """Add or update a marker for the given alert.""" + if not alert: + return + + alert_id = alert.get("alert_id") or alert.get("id") or alert.get("alertId") + if not alert_id: + print("[AlertLayer] ⚠️ Skipping alert without ID:", alert) + return + + # Parse coordinates + lat = alert.get("lat") or alert.get("latitude") or alert.get("location_lat") + lon = alert.get("lon") or alert.get("longitude") or alert.get("location_lon") + try: + lat = float(lat) + lon = float(lon) + except Exception: + print(f"[AlertLayer] ⚠️ Invalid lat/lon for {alert_id}: {lat}, {lon}") + return + + pos = _latlon_to_xy_at_max_zoom(self.viewer, lat, lon) + if not pos: + print(f"[AlertLayer] ⚠️ Alert {alert_id} outside dataset bounds") + return + + xb, yb = pos + scene_x = (xb - self._x_min_base) * TILE_SIZE + scene_y = (yb - self._y_min_base) * TILE_SIZE + print(f"[AlertLayer] Alert {alert_id}: scene=({scene_x:.1f}, {scene_y:.1f})") + + # Remove old marker if exists + if alert_id in self.alerts: + old_marker, _ = self.alerts.pop(alert_id) + self.scene.removeItem(old_marker) + + severity = int(alert.get("severity", 1)) + normalized = { + "alert_id": alert_id, + "alert_type": alert.get("alert_type") or "alert", + "device_id": alert.get("device_id") or "unknown", + "area": alert.get("area") or "", + "severity": severity, + "confidence": alert.get("confidence") or 0, + "summary": alert.get("summary") or alert.get("meta") or "", + "startsAt": alert.get("started_at") or alert.get("startsAt") or "", + } + + marker = _AlertMarker(alert_id, normalized) + marker.setPos(scene_x, scene_y) + self.scene.addItem(marker) + self.alerts[alert_id] = (marker, None) + + def clear_alerts(self): + print("[AlertLayer] Clearing all alert markers") + for marker, _ in self.alerts.values(): + self.scene.removeItem(marker) + self.alerts.clear() + + def remove_alert(self, alert_id: str): + """Remove a specific alert marker from the scene.""" + if alert_id not in self.alerts: + print(f"[AlertLayer] ⚠️ Tried to remove unknown alert_id: {alert_id}") + return + marker, _ = self.alerts.pop(alert_id) + if marker: + self.scene.removeItem(marker) + print(f"[AlertLayer] ❌ Removed alert marker: {alert_id}") diff --git a/GUI/src/vast/orthophoto_canvas/ui/sensors_layer.py b/GUI/src/vast/orthophoto_canvas/ui/sensors_layer.py index 854610688..4ef7808d6 100644 --- a/GUI/src/vast/orthophoto_canvas/ui/sensors_layer.py +++ b/GUI/src/vast/orthophoto_canvas/ui/sensors_layer.py @@ -311,6 +311,43 @@ def _latlon_to_base_xy_if_inside(viewer, lat: float, lon: float, z: int = None) y_min_base <= yb < y_min_base + y_max_base + 1: return xb, yb return None +import math +MERCATOR_MAX_LAT = 85.05112878 + +def _latlon_to_xy_at_max_zoom(viewer, lat: float, lon: float) -> Optional[tuple[float, float]]: + """ + Convert WGS84 (lat, lon) → fractional tile coordinates (x, y) + aligned with viewer.max_zoom_fs scene orientation. + """ + z = viewer.max_zoom_fs + if z not in viewer.z_ranges: + return None + + try: + lat = float(lat) + lon = float(lon) + except Exception: + return None + + lat = max(min(lat, MERCATOR_MAX_LAT), -MERCATOR_MAX_LAT) + n = 1 << z + lat_rad = math.radians(lat) + xtile = (lon + 180.0) / 360.0 * n + ytile = (1.0 - math.log(math.tan(lat_rad) + 1 / math.cos(lat_rad)) / math.pi) / 2.0 * n + + # 🔁 Flip if your tile store is TMS (bottom origin) + if getattr(viewer, "is_tms", False): + ytile = n - ytile - 1 + + # 🧭 XYZ tiles: (0,0) top-left, y increases downward + # But our scene uses top-left origin (same), so no additional flip! + + x_min, x_max, y_min, y_max = viewer.ts.z_ranges[z] + if not (x_min <= xtile <= x_max and y_min <= ytile <= y_max): + return None + + return xtile, ytile + def add_sensor_by_gps_strict(layer: SensorLayer, sensor_id: str, lat: float, lon: float, z: int = None, center: bool = False, **kwargs) -> Optional[SensorSpec]: @@ -368,4 +405,4 @@ def tile2lat(y): lon_min = tile2lon(x_min); lon_max = tile2lon(x_max + 1) lat_max = tile2lat(y_xyz_min); lat_min = tile2lat(y_xyz_max + 1) print(f"[COVERAGE z={z}] lon:[{lon_min:.6f}..{lon_max:.6f}] lat:[{lat_min:.6f}..{lat_max:.6f}]") - return (lat_min, lat_max, lon_min, lon_max) \ No newline at end of file + return (lat_min, lat_max, lon_min, lon_max) diff --git a/GUI/src/vast/orthophoto_canvas/ui/viewer.py b/GUI/src/vast/orthophoto_canvas/ui/viewer.py index c83ffb827..34f0fc4a2 100644 --- a/GUI/src/vast/orthophoto_canvas/ui/viewer.py +++ b/GUI/src/vast/orthophoto_canvas/ui/viewer.py @@ -1,122 +1,171 @@ -# agcloud/ui/viewer.py from __future__ import annotations from pathlib import Path - from ..utils.tiles import TileStore from .sensors_layer import SensorLayer, add_sensors_by_gps_bulk, dataset_bbox_latlon from ..ag_io.sensors_api import get_sensors import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Optional, Tuple, Union -from PyQt6.QtCore import Qt, QRectF, QPointF, QTimer -from PyQt6.QtGui import QPixmap, QPainter, QPen, QColor, QBrush +from PyQt6.QtCore import Qt, QTimer +from PyQt6.QtGui import QPixmap, QPainter, QPen, QColor from PyQt6.QtWidgets import ( - QGraphicsView, - QGraphicsScene, - QGraphicsPixmapItem, - QGraphicsRectItem, - QGraphicsEllipseItem, - QToolTip + QGraphicsView, QGraphicsScene, QGraphicsPixmapItem, QGraphicsRectItem ) # ==== Tunables ==== TILE_SIZE = 512 -INITIAL_TILE_PX = 640.0 # initial "tile-size on screen" (px) to compute first zoom -TARGET_TILE_PX_FOR_LOD = 512.0 # target tile size (px) used to pick which z to load -SNAP_CHOICES = (512.0, 384.0, 320.0, 256.0, 192.0, 128.0) # preferred tile sizes on screen to avoid blur +TARGET_TILE_PX_FOR_LOD = 512.0 +SNAP_CHOICES = (512.0, 384.0, 320.0, 256.0, 192.0, 128.0) class OrthophotoViewer(QGraphicsView): - """ - QGraphicsView that renders a pyramidal tile set with lazy loading + LOD, - using a provided TileSet (I/O abstraction). Optionally draws sensor markers. - """ + """Stable orthophoto tile viewer that perfectly fits its container.""" - # ---------- Construction / scene bootstrapping ---------- def __init__(self, tiles: Union[TileStore, str, Path]) -> None: - """ - tileset: object that exposes min_zoom, max_zoom, z_ranges[z], tile_path(z,x,y) - (see agcloud/utils/tiles) - """ super().__init__() + + # ───────────────────────────── + # Load tiles + # ───────────────────────────── + # if isinstance(tiles, TileStore): + # self.ts = tiles + # else: + # self.ts = TileStore(Path(tiles)) + + # self.min_zoom_fs = self.ts.min_zoom + # self.max_zoom_fs = self.ts.max_zoom + # self.z_ranges = self.ts.z_ranges + # self.is_tms = self.ts.is_tms + # ───────────────────────────── +# Load tiles +# ───────────────────────────── if isinstance(tiles, TileStore): self.ts = tiles else: - self.ts = TileStore(Path(tiles)) - - self.ts.existing_zooms = self.ts.existing_zooms - self.min_zoom_fs = self.ts.min_zoom - self.max_zoom_fs = self.ts.max_zoom - self.z_ranges = self.ts.z_ranges - self.is_tms = self.ts.is_tms + tiles_path = Path(tiles) + if not tiles_path.exists(): + raise FileNotFoundError(f"[OrthophotoViewer] Tile root not found: {tiles_path}") + self.ts = TileStore(tiles_path) + + # Safety: ensure scheme attribute exists + if not hasattr(self.ts, "scheme"): + self.ts.scheme = "XYZ" + + self.min_zoom_fs = self.ts.min_zoom + self.max_zoom_fs = self.ts.max_zoom + self.z_ranges = self.ts.z_ranges + self.is_tms = self.ts.is_tms + + print(f"[DEBUG] Tile root: {self.ts.root}") + print(f"[DEBUG] Tile scheme: {self.ts.scheme}") + print(f"[DEBUG] is_tms: {self.ts.is_tms}") + print(f"[DEBUG] Zoom levels: {self.ts.existing_zooms or 'none found'}") + print(f"[DEBUG] z_ranges: {self.ts.z_ranges or 'empty'}") - self.scene = QGraphicsScene(self) - self.setScene(self.scene) - # Crisp rendering (no smoothing) - - self.setRenderHint(QPainter.RenderHint.SmoothPixmapTransform, False) - self.setRenderHint(QPainter.RenderHint.Antialiasing, False) - self.setRenderHint(QPainter.RenderHint.TextAntialiasing, False) - # Interaction / performance + # ───────────────────────────── + # Scene setup + # ───────────────────────────── + self.scene = QGraphicsScene(self) + self.setScene(self.scene) + self.setRenderHint(QPainter.RenderHint.SmoothPixmapTransform, True) + self.setRenderHint(QPainter.RenderHint.Antialiasing, True) self.setCacheMode(QGraphicsView.CacheModeFlag.CacheBackground) self.setOptimizationFlag(QGraphicsView.OptimizationFlag.DontSavePainterState, True) self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag) self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse) self.setViewportUpdateMode(QGraphicsView.ViewportUpdateMode.SmartViewportUpdate) - self.setBackgroundBrush(QColor(220, 220, 220)) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + + # 🔹 Light gray background (no border) + self.setBackgroundBrush(QColor("#d1d5db")) # soft gray background + self.setStyleSheet("background-color: #d1d5db; border: none;") + + # No scrollbars + self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) - - # State - self.current_zoom = self.ts.min_zoom - self.placeholder_color = Qt.GlobalColor.lightGray + + # ───────────────────────────── + # Internal state + # ───────────────────────────── + self.current_zoom = self.ts.max_zoom + self.placeholder_color = QColor("#d1d5db") self.tile_items: dict[Tuple[int, int, int], QGraphicsPixmapItem | QGraphicsRectItem] = {} - # Debounce loads after interaction + # ───────────────────────────── + # Timed updates + # ───────────────────────────── self.update_timer = QTimer(self) self.update_timer.setSingleShot(True) self.update_timer.timeout.connect(self.update_tiles) - self.sensor_layer = SensorLayer(self) + # ───────────────────────────── + # Optional sensor overlay + # ───────────────────────────── + self.sensor_layer = SensorLayer(self) dataset_bbox_latlon(self, z=self.max_zoom_fs) - add_sensors_by_gps_bulk(self.sensor_layer, get_sensors(), z=self.max_zoom_fs, default_radius_px=0.2) - - # Scene rect anchored to base zoom (min z) - self._init_scene_rect_from_min_zoom() - - self._did_initial_zoom = False - - # First load + + try: + add_sensors_by_gps_bulk(self.sensor_layer, get_sensors(), z=self.max_zoom_fs, default_radius_px=0.2) + except Exception as e: + print(f"[Sensors] skipped: {e}") + + # ───────────────────────────── + # Scene geometry from MAX zoom + # ───────────────────────────── + self._init_scene_rect_from_max_zoom() + + # ───────────────────────────── + # Initial zoom and centering + # ───────────────────────────── + self._fit_scene_exactly() + + # ───────────────────────────── + # Initial tile rendering + # ───────────────────────────── self.update_tiles() - - def _init_scene_rect_from_min_zoom(self) -> None: - """ - Build scene rect from the min zoom range (anchor for all z levels). - """ - z0 = self.ts.min_zoom - x_min, x_max, y_min, y_max = self.ts.z_ranges[z0] + + + # ───────────────────────────── + # Scene geometry + # ───────────────────────────── + def _init_scene_rect_from_max_zoom(self) -> None: + """Build scene rect from the max zoom level (actual dataset size).""" + z_max = self.ts.max_zoom + x_min, x_max, y_min, y_max = self.ts.z_ranges[z_max] width = (x_max - x_min + 1) * TILE_SIZE height = (y_max - y_min + 1) * TILE_SIZE self._x_min_base = x_min self._y_min_base = y_min - self.scene.setSceneRect(0, 0, width, height) - print(f"[BASE] z={z0} X:[{x_min}-{x_max}] Y:[{y_min}-{y_max}] scene={width}x{height}px") - # ---------- Sensors overlay (optional) ---------- - def set_sensors(self, sensors: list[dict]): - self.sensor_layer.clear() - add_sensors_by_gps_bulk(self.sensor_layer, sensors, z=self.max_zoom_fs, center_on_first=True) - - def clear_sensors(self) -> None: - """Remove all sensor markers from scene.""" - self.sensor_layer.clear() + # Add tiny margin to prevent borders + margin = 2 + self.scene.setSceneRect(-margin, -margin, width + margin * 2, height + margin * 2) + print(f"[BASE] z={z_max} scene={width}x{height}px") + + # ───────────────────────────── + # Fit helper + # ───────────────────────────── + def _fit_scene_exactly(self): + """Reset zoom and fit map exactly to the view size.""" + self.resetTransform() + self.fitInView(self.scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio) + self.centerOn(self.scene.sceneRect().center()) + + # ───────────────────────────── + # Events + # ───────────────────────────── + def resizeEvent(self, e): + """Fit perfectly when resized, no cumulative zoom.""" + super().resizeEvent(e) + self._fit_scene_exactly() + self._debounced_update() + - # ---------- Interaction ---------- def wheelEvent(self, event) -> None: + """Zoom with mouse wheel.""" factor = 1.25 if event.angleDelta().y() > 0: self.scale(factor, factor) @@ -124,97 +173,49 @@ def wheelEvent(self, event) -> None: self.scale(1.0 / factor, 1.0 / factor) self._debounced_update() - def resizeEvent(self, event) -> None: - super().resizeEvent(event) - self._debounced_update() - def mouseReleaseEvent(self, event) -> None: super().mouseReleaseEvent(event) if event.button() in (Qt.MouseButton.LeftButton, Qt.MouseButton.MiddleButton): self._debounced_update() - def keyPressEvent(self, event) -> None: - k = event.key() - if k == Qt.Key.Key_C: - self._snap_to_native_scale() - return - if k == Qt.Key.Key_F: - self._fit_data_width() - return - if k in (Qt.Key.Key_1, Qt.Key.Key_2, Qt.Key.Key_3, Qt.Key.Key_4, Qt.Key.Key_5): - choices = { - Qt.Key.Key_1: 192.0, - Qt.Key.Key_2: 256.0, - Qt.Key.Key_3: 320.0, - Qt.Key.Key_4: 384.0, - Qt.Key.Key_5: 512.0, - } - self._smart_initial_focus(target_tile_px=choices[k], found=self._best_focus_tile()) - return - super().keyPressEvent(event) - def _debounced_update(self) -> None: self.update_timer.start(50) - def showEvent(self, e): - super().showEvent(e) - if getattr(self, "_did_initial_fit", False): - return - self._did_initial_fit = True - QTimer.singleShot(0, lambda: self.fit_to_data("width", 0.98)) - - def resizeEvent(self, e): - super().resizeEvent(e) - if not getattr(self, "_did_initial_zoom", False): - self._did_initial_zoom = True - self.fit_to_data("width", 40.0) - # self.zoom_to_tile_px(768.0, at_z="max") - self._debounced_update() - - # ---------- LOD: choose z, load visible tiles ---------- + # ───────────────────────────── + # Level-of-detail (LOD) + # ───────────────────────────── def _calc_zoom_level(self) -> int: - """ - Decide which z to load: pick z so that (tile on screen) ~= TARGET_TILE_PX_FOR_LOD. - """ - scale = max(self.transform().m11(), 1e-6) - z_base = self.ts.min_zoom - zf = z_base + math.log2((scale * float(TILE_SIZE)) / TARGET_TILE_PX_FOR_LOD) - z = int(round(zf)) - return max(self.ts.min_zoom, min(self.ts.max_zoom, z)) + """Force max zoom for small coverage sets.""" + return self.ts.max_zoom def update_tiles(self) -> None: - """ - Core: compute visible keys (z,x,y) and ensure each has an item. - Placeholders are created first; then upgraded to true pixmaps (with parent fallback). - """ + """Compute visible tiles and render them.""" z = self._calc_zoom_level() self.current_zoom = z - eff_tile_scene = TILE_SIZE / float(1 << (z - self.ts.min_zoom)) + eff_tile_scene = TILE_SIZE / float(1 << (z - self.ts.max_zoom)) eff_tile_screen = eff_tile_scene * max(self.transform().m11(), 1e-6) print(f"[LODDBG] z={z} tile_on_screen≈{eff_tile_screen:.1f}px") view_rect = self.mapToScene(self.viewport().rect()).boundingRect() x_min_z, x_max_z, y_min_z, y_max_z = self.ts.z_ranges[z] - # scene → base indices (anchored to min z) - start_tx = int(math.floor(view_rect.left() / eff_tile_scene)) - end_tx = int(math.floor(view_rect.right() / eff_tile_scene)) - start_ty = int(math.floor(view_rect.top() / eff_tile_scene)) - end_ty = int(math.floor(view_rect.bottom() / eff_tile_scene)) + start_tx = int(math.floor(view_rect.left() / eff_tile_scene)) + end_tx = int(math.floor(view_rect.right() / eff_tile_scene)) + start_ty = int(math.floor(view_rect.top() / eff_tile_scene)) + end_ty = int(math.floor(view_rect.bottom() / eff_tile_scene)) - scale_factor = 1 << (z - self.ts.min_zoom) + scale_factor = 1 << (z - self.ts.max_zoom) want: set[Tuple[int, int, int]] = set() for tx in range(start_tx, end_tx + 1): for ty in range(start_ty, end_ty + 1): x_idx = self._x_min_base * scale_factor + tx y_idx = self._y_min_base * scale_factor + ty - # clamp to existing range on disk if x_idx < x_min_z or x_idx > x_max_z or y_idx < y_min_z or y_idx > y_max_z: continue want.add((z, x_idx, y_idx)) - # Create / upgrade + # Create or upgrade for key in want: if key not in self.tile_items: ph = self._create_placeholder_item_at(key, eff_tile_scene) @@ -222,42 +223,34 @@ def update_tiles(self) -> None: self.scene.addItem(ph) self._try_upgrade_tile_to_pixmap(key, eff_tile_scene) - # Unload others + # Unload tiles that are no longer visible for key in list(self.tile_items.keys()): if key not in want: self.scene.removeItem(self.tile_items.pop(key)) - # ---------- Tile placement / upgrade ---------- + # ───────────────────────────── + # Tile placement / upgrade + # ───────────────────────────── def _create_placeholder_item_at(self, key: Tuple[int, int, int], eff_tile_scene: float): - """ - Create a light-gray rect exactly where the tile will go (no rounding) to avoid seams. - """ z, x, y = key - scale_factor = 1 << (z - self.ts.min_zoom) + scale_factor = 1 << (z - self.ts.max_zoom) x0 = self._x_min_base * scale_factor y0 = self._y_min_base * scale_factor - tx = x - x0 ty = y - y0 sx = tx * eff_tile_scene sy = ty * eff_tile_scene rect = QGraphicsRectItem(sx, sy, eff_tile_scene, eff_tile_scene) - rect.setBrush(self.placeholder_color) + rect.setBrush(QColor("#d1d5db")) # same gray as background rect.setPen(QPen(Qt.PenStyle.NoPen)) - return rect def _place_pixmap_item(self, pm: QPixmap, key: Tuple[int, int, int], eff_tile_scene: float): - """ - Position the pixmap at exact scene coords (anchored to min z) and scale the item, - not the pixmap (keeps the source un-resampled). - """ z, x, y = key - scale_factor = 1 << (z - self.ts.min_zoom) + scale_factor = 1 << (z - self.ts.max_zoom) x0 = self._x_min_base * scale_factor y0 = self._y_min_base * scale_factor - tx = x - x0 ty = y - y0 sx = tx * eff_tile_scene @@ -267,18 +260,12 @@ def _place_pixmap_item(self, pm: QPixmap, key: Tuple[int, int, int], eff_tile_sc item.setPos(sx, sy) s = eff_tile_scene / float(pm.width()) item.setScale(s) - item.setTransformationMode(Qt.TransformationMode.FastTransformation) + item.setTransformationMode(Qt.TransformationMode.SmoothTransformation) return item def _try_upgrade_tile_to_pixmap(self, key: Tuple[int, int, int], eff_tile_scene: float) -> None: - """ - Replace placeholder by the right pixmap: - - Try native z/x/y - - If missing, climb to parent(s) until min z, crop the relevant quadrant. - """ z, x, y = key - - # climb parents if needed + # print(f"[TRY] tile z={z} x={x} y={y}") zz, xx, yy = z, x, y pm: Optional[QPixmap] = None while zz >= self.ts.min_zoom: @@ -296,7 +283,9 @@ def _try_upgrade_tile_to_pixmap(self, key: Tuple[int, int, int], eff_tile_scene: v = (y % seg) * h pm = pm.copy(u, v, w, h) break - xx //= 2; yy //= 2; zz -= 1 + xx //= 2 + yy //= 2 + zz -= 1 if not pm: return @@ -307,177 +296,3 @@ def _try_upgrade_tile_to_pixmap(self, key: Tuple[int, int, int], eff_tile_scene: item = self._place_pixmap_item(pm, key, eff_tile_scene) self.scene.addItem(item) self.tile_items[key] = item - - # ---------- Smart focus / fit / snap ---------- - def _smart_initial_focus(self, target_tile_px: float, found: Optional[Tuple[int, int, int]], snap=True) -> None: - """ - Pick a zoom/transform so that one tile would appear ~target_tile_px wide on screen, - and center the view around an existing tile (found). - """ - if not found: - # fallback: fit entire base scene width - self.fitInView(self.scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio) - return - - z, x, y = found - eff_scene = TILE_SIZE / float(1 << (z - self.ts.min_zoom)) - s = max(target_tile_px / eff_scene, 1e-6) - - self.resetTransform() - self.scale(s, s) - - if snap: - self._snap_to_native_scale() - - # center on that tile - scale_factor = 1 << (z - self.ts.min_zoom) - x0 = self._x_min_base * scale_factor - y0 = self._y_min_base * scale_factor - tx = x - x0 - ty = y - y0 - sx = tx * eff_scene - sy = ty * eff_scene - self.centerOn(sx + eff_scene * 0.5, sy + eff_scene * 0.5) - - self._debounced_update() - print(f"[FOCUS] center @ {z}/{x}/{y} tile≈{target_tile_px:.0f}px") - - def _snap_to_native_scale(self) -> None: - """ - Snap current transform so that tile size on screen equals one of SNAP_CHOICES, - minimizing resampling blur. - """ - center_scene = self.mapToScene(self.viewport().rect().center()) - z = self._calc_zoom_level() - eff_scene = TILE_SIZE / float(1 << (z - self.ts.min_zoom)) - cur_tile = eff_scene * max(self.transform().m11(), 1e-6) - target = min(SNAP_CHOICES, key=lambda t: abs(t - cur_tile)) - s = max(target / eff_scene, 1e-6) - self.resetTransform() - self.scale(s, s) - self.centerOn(center_scene) - self._debounced_update() - - def _fit_data_width(self, margin: float = 0.95) -> None: - """ - Fit the visible data extent at current z to the viewport width (keep slight margin). - """ - z = self._calc_zoom_level() - x_min, x_max, y_min, y_max = self.ts.z_ranges[z] - eff = TILE_SIZE / float(1 << (z - self.ts.min_zoom)) - - scale_factor = 1 << (z - self.ts.min_zoom) - x0 = self._x_min_base * scale_factor - y0 = self._y_min_base * scale_factor - - left = (x_min - x0) * eff - right = (x_max - x0 + 1) * eff - top = (y_min - y0) * eff - bottom = (y_max - y0 + 1) * eff - rect = QRectF(left, top, right - left, bottom - top) - - self.resetTransform() - if rect.width() > 0: - s_w = (self.viewport().width() / rect.width()) * float(margin) - self.scale(s_w, s_w) - self.centerOn(rect.center()) - self._debounced_update() - print(f"[FIT] width={rect.width():.1f}px scene, scale={s_w:.3f}") - - def fit_to_data(self, how="width", margin=0.98): - """Zoom so the dataset fills the viewport (width/height/all).""" - z = self._calc_zoom_level() - self.current_zoom_level = z - if z not in self.ts.z_ranges: - return - - x_min_z, x_max_z, y_min_z, y_max_z = self.ts.z_ranges[z] - - eff = TILE_SIZE / float(1 << (z - self.ts.min_zoom)) - - base_x_min, _, base_y_min, _ = self.ts.z_ranges[self.ts.min_zoom] - scale_factor = 1 << (z - self.ts.min_zoom) - x0 = base_x_min * scale_factor - y0 = base_y_min * scale_factor - - - left = (x_min_z - x0) * eff - top = (y_min_z - y0) * eff - width = (x_max_z - x_min_z + 1) * eff - height = (y_max_z - y_min_z + 1) * eff - rect = QRectF(left, top, width, height) - if rect.width() <= 0 or rect.height() <= 0: - return - - self.resetTransform() - if how == "width": - s = (self.viewport().width() / rect.width()) * margin - elif how == "height": - s = (self.viewport().height() / rect.height()) * margin - else: # "all" - s = min(self.viewport().width()/rect.width(), - self.viewport().height()/rect.height()) * margin - - self.scale(s, s) - self.centerOn(rect.center()) - self._debounced_update() - - # ---------- Find a good starting tile ---------- - def _best_focus_tile(self, prefer_z: Optional[int] = None, max_x_check: int = 64) -> Optional[Tuple[int,int,int]]: - """ - Heuristic: at the highest available z, choose x that is closest to mid-range, - then choose the y closest to mid-range that actually exists. If that x has no y, - try other x (still ordered by proximity to center). If nothing found, fall back to - the first available (z,x,y). - """ - zs = list(range(self.ts.min_zoom, self.ts.max_zoom + 1)) - if not zs: - return None - z = prefer_z if prefer_z is not None else zs[-1] - - try: - x_min, x_max, y_min, y_max = self.ts.ranges(z) - except Exception: - return None - - x0 = (x_min + x_max) // 2 - y0 = (y_min + y_max) // 2 - - # Prefer X near the center - xs = list(range(x_min, x_max + 1)) - xs.sort(key=lambda xv: abs(xv - x0)) - - # helper to list ys for an x (if tileset doesn’t implement list_y, we try a few probes) - def list_y_for_x(x: int) -> List[int]: - if hasattr(self.ts, "list_y"): - return list(getattr(self.ts, "list_y")(z, x)) # type: ignore - # minimal probe: try a small window around y0 - win = 256 - candidates = [] - for yy in (y0, y0-1, y0+1, y0-2, y0+2, y0-4, y0+4, y0-win, y0+win): - p = self.ts.tile_path(z, x, yy) - if p: - candidates.append(yy) - return sorted(set(candidates)) - - # Try up to max_x_check x-folders near center - for x in xs[:max_x_check]: - ys = list_y_for_x(x) - if not ys: - continue - y = min(ys, key=lambda yv: abs(yv - y0)) - return (z, x, y) - - # Fallback: brute probe a small grid near center - for x in xs: - for y in (y0, y0-1, y0+1, y0-2, y0+2): - if self.ts.tile_path(z, x, y): - return (z, x, y) - - # Last resort: scan all ranges (can be slower on huge sets) - for x in range(x_min, x_max + 1): - for y in range(y_min, y_max + 1): - if self.ts.tile_path(z, x, y): - return (z, x, y) - - return None diff --git a/GUI/src/vast/rel_db.py b/GUI/src/vast/rel_db.py new file mode 100644 index 000000000..8e92c49c1 --- /dev/null +++ b/GUI/src/vast/rel_db.py @@ -0,0 +1,324 @@ +# rel_db.py +from __future__ import annotations +import os +import datetime as dt +from contextlib import contextmanager +from typing import Optional, List, Dict, Tuple +from functools import lru_cache + +import psycopg2 +from psycopg2.extras import RealDictCursor + + +# ---- ENV (Docker Compose defaults) ---- +DB_HOST = os.getenv("DB_HOST", "127.0.0.1") +DB_PORT = int(os.getenv("DB_PORT", "5432")) +DB_USER = os.getenv("DB_USER", "missions_user") +DB_PASS = os.getenv("DB_PASS", "pg123") +DB_NAME = os.getenv("DB_NAME", "missions_db") + + +@contextmanager +def _pg_conn(): + conn = psycopg2.connect( + host=DB_HOST, port=DB_PORT, user=DB_USER, password=DB_PASS, dbname=DB_NAME + ) + try: + yield conn + finally: + conn.close() + + +def _query(sql: str, params: tuple = ()) -> List[Dict]: + try: + with _pg_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute(sql, params) + return [dict(r) for r in cur.fetchall()] + except Exception as e: + print(f"[RelDB][QUERY FAIL] {e}\n | SQL={sql!r} | params={params!r}") + return [] + + +# ---------- Dynamic schema ---------- +@lru_cache(maxsize=1) +def _anomalies_cols() -> set[str]: + rows = _query( + "SELECT column_name FROM information_schema.columns " + "WHERE table_schema='public' AND table_name='anomalies'" + ) + return {r["column_name"] for r in rows} if rows else set() + +def _has_col(name: str) -> bool: + return name in _anomalies_cols() + +def _img_expr() -> str: + """ + Adaptive image column: + image_id (if exists) -> else tile_id -> else details->>'image_id' + """ + if _has_col("image_id"): + return "image_id" + if _has_col("tile_id"): + return "tile_id" + return "(details->>'image_id')" + +def _bbox_expr() -> str: + """Adaptive bbox column: bbox or JSON.""" + if _has_col("bbox"): + return "bbox" + return "(details->'bbox')" + +def _select_projection() -> str: + """ + Returns SELECT column list with aliases to always include: + anomaly_id, mission_id, device_id, ts, anomaly_type_id, severity, + bbox, area, label, image_id, confidence, geom, details + Even if some do not physically exist (extracted from JSON). + """ + cols = [ + "anomaly_id", + "mission_id", + "device_id", + "ts", + "anomaly_type_id", + "severity", + f"{_bbox_expr()} AS bbox", + # Derived from JSON (even if they exist physically, it’s fine; but we avoid name conflicts) + "(details->>'area')::float AS area", + "(details->>'label') AS label", + f"{_img_expr()} AS image_id", + "(details->>'confidence')::float AS confidence", + "geom", + "details", + ] + # If a physical column with the same name exists, prefer it (remove alias to avoid collision) + if _has_col("area"): + cols[cols.index("(details->>'area')::float AS area")] = "area" + if _has_col("label"): + cols[cols.index("(details->>'label') AS label")] = "label" + if _has_col("confidence"): + cols[cols.index("(details->>'confidence')::float AS confidence")] = "confidence" + return ", ".join(cols) + + +class RelDB: + """ + Thin data-access layer for the anomalies table. + Works even when image_id/bbox are missing by using details(JSONB). + """ + + # ---------- Utilities ---------- + @staticmethod + def _split_object_key(object_key: str) -> Tuple[str, str]: + if not isinstance(object_key, str): + return "", "" + name = object_key.replace("\\", "/").split("/")[-1] + if "." in name: + base = ".".join(name.split(".")[:-1]) + ext = name.split(".")[-1] + return base, ext + return name, "" + + @staticmethod + def _image_name_from_object_key(object_key: str) -> str: + base, _ = RelDB._split_object_key(object_key) + return base.strip() + + # ---------- Latest N ---------- + def get_latest_anomalies(self, limit: int = 20) -> List[Dict]: + limit = max(1, min(int(limit or 20), 1000)) + cols = _select_projection() + q_ts = f"SELECT {cols} FROM public.anomalies ORDER BY ts DESC LIMIT %s" + rows = _query(q_ts, (limit,)) + if rows: + return rows + q_id = f"SELECT {cols} FROM public.anomalies ORDER BY anomaly_id DESC LIMIT %s" + return _query(q_id, (limit,)) + + # ---------- By image ---------- + def get_anomalies_by_image(self, image_name: str, limit: int = 50) -> List[Dict]: + if not image_name: + return [] + limit = max(1, min(int(limit or 50), 1000)) + cols = _select_projection() + img_col = _img_expr() + q_ts = f""" + SELECT {cols} + FROM public.anomalies + WHERE {img_col} = %s + ORDER BY ts DESC + LIMIT %s + """ + rows = _query(q_ts, (image_name, limit)) + if rows: + return rows + q_id = f""" + SELECT {cols} + FROM public.anomalies + WHERE {img_col} = %s + ORDER BY anomaly_id DESC + LIMIT %s + """ + return _query(q_id, (image_name, limit)) + + def get_last_anomaly_by_image(self, image_name: str) -> Optional[Dict]: + rows = self.get_anomalies_by_image(image_name, limit=1) + return rows[0] if rows else None + + # ---------- From object key ---------- + def get_anomalies_for_image_key(self, object_key: str, limit: int = 50) -> List[Dict]: + image_name = self._image_name_from_object_key(object_key) + if not image_name: + return [] + return self.get_anomalies_by_image(image_name, limit=limit) + + # ---------- Latest image present in DB ---------- + def get_latest_image_key(self) -> Optional[str]: + img_col = _img_expr() + if img_col.startswith("(") and "details" in img_col: + # Can still filter based on the expression + pass + q_ts = f""" + SELECT {img_col} AS img + FROM public.anomalies + WHERE {img_col} IS NOT NULL AND {img_col} <> '' + ORDER BY ts DESC + LIMIT 50 + """ + rows = _query(q_ts) + if not rows: + q_id = f""" + SELECT {img_col} AS img + FROM public.anomalies + WHERE {img_col} IS NOT NULL AND {img_col} <> '' + ORDER BY anomaly_id DESC + LIMIT 50 + """ + rows = _query(q_id) + for r in rows or []: + v = r.get("img") + if isinstance(v, str) and v.strip(): + return v.strip() + return None + + # ---------- By day ---------- + def get_anomalies_by_day(self, date_iso: str, limit: int = 1000) -> List[Dict]: + try: + day = dt.date.fromisoformat(date_iso) + except Exception: + print(f"[RelDB][DAY WARN] invalid date {date_iso!r}") + return [] + start = dt.datetime.combine(day, dt.time.min) + end = start + dt.timedelta(days=1) + cols = _select_projection() + q = f""" + SELECT {cols} + FROM public.anomalies + WHERE ts >= %s AND ts < %s + ORDER BY ts DESC + LIMIT %s + """ + rows = _query(q, (start, end, limit)) + if rows: + return rows + return self.get_latest_anomalies(limit=limit) + + # ---------- PHI helpers ---------- + @staticmethod + def _sev_norm(x) -> Optional[float]: + try: + s = float(x) + except Exception: + return None + if s < 0: + return None + return s if s <= 1.0 else min(s, 10.0) / 10.0 + + @staticmethod + def _phi_from(sev_avg_norm: Optional[float]) -> Optional[float]: + if sev_avg_norm is None: + return None + return max(0.0, min(100.0, 100.0 * (1.0 - max(0.0, min(1.0, sev_avg_norm))))) + + # --- PHI per image --- + def get_phi_for_image(self, image_name: str) -> Dict[str, Optional[float | str]]: + if not image_name: + return {"phi": None, "severity_avg": None, "image_id": None} + img_col = _img_expr() + q = f""" + SELECT + AVG( + CASE + WHEN severity <= 1.0 THEN severity + WHEN severity > 1.0 THEN LEAST(severity, 10.0)/10.0 + ELSE NULL + END + ) AS sev_avg_norm, + COUNT(*) AS n_rows + FROM public.anomalies + WHERE {img_col} = %s + """ + rows = _query(q, (image_name,)) + sev_avg = rows[0].get("sev_avg_norm") if rows else None + phi = self._phi_from(sev_avg) + return { + "phi": phi, + "severity_avg": float(sev_avg) if sev_avg is not None else None, + "image_id": image_name, + } + + def get_phi_for_current_image(self) -> Dict[str, Optional[float | str]]: + image_name = self.get_latest_image_key() + if not image_name: + return {"phi": None, "severity_avg": None, "image_id": None} + return self.get_phi_for_image(image_name) + + # --- Weekly PHI (backward compatibility) --- + def get_weekly_phi(self) -> Dict[str, Optional[float | str]]: + today = dt.date.today() + week_start = today - dt.timedelta(days=today.weekday()) # Monday + prev_week_start = week_start - dt.timedelta(days=7) + week_end = week_start + dt.timedelta(days=7) + prev_week_end = week_start + + def _week_stats(a: dt.date, b: dt.date): + q = """ + SELECT + AVG( + CASE + WHEN severity <= 1.0 THEN severity + WHEN severity > 1.0 THEN LEAST(severity, 10.0)/10.0 + ELSE NULL + END + ) AS sev_avg_norm, + COUNT(*) AS n_rows + FROM public.anomalies + WHERE ts >= %s AND ts < %s + """ + rows = _query(q, ( + dt.datetime.combine(a, dt.time.min), + dt.datetime.combine(b, dt.time.min), + )) + return rows[0] if rows else {"sev_avg_norm": None, "n_rows": 0} + + cur = _week_stats(week_start, week_end) + prev = _week_stats(prev_week_start, prev_week_end) + + sev_avg = cur.get("sev_avg_norm") + phi = self._phi_from(sev_avg) + + n_rows = (cur.get("n_rows") or 0) + density = (n_rows / 7.0) if n_rows else None + + prev_phi = self._phi_from(prev.get("sev_avg_norm")) + trend = (phi - prev_phi) if (phi is not None and prev_phi is not None) else None + + return { + "phi": phi, + "severity_avg": float(sev_avg) if sev_avg is not None else None, + "density": float(density) if density is not None else None, + "coverage": None, + "trend": float(trend) if trend is not None else None, + "week_start": str(week_start), + } diff --git a/GUI/src/vast/runner/Dockerfile b/GUI/src/vast/runner/Dockerfile index 7b6c3330f..1307a1d59 100644 --- a/GUI/src/vast/runner/Dockerfile +++ b/GUI/src/vast/runner/Dockerfile @@ -9,9 +9,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates COPY certs /app/certs # System CA + add NetFree certs RUN if [ "$USE_NETFREE" = "true" ] && [ -d ./certs ] && [ "$(ls ./certs/*.crt 2>/dev/null)" ]; then \ - echo "Configuring NetFree certificates..."; \ - cp ./certs/*.crt /usr/local/share/ca-certificates/; \ - update-ca-certificates; \ + echo "Configuring NetFree certificates..."; \ + cp ./certs/*.crt /usr/local/share/ca-certificates/; \ + update-ca-certificates; \ fi # SSL certs env @@ -38,4 +38,4 @@ ENV RUNNER_MODE=real SQLITE_DB=/data/app.db LOG_LEVEL=INFO PYTHONPATH=/app ENV PYTHONPATH=/app/vast/proto/generated:/app EXPOSE 50051 -CMD ["python", "vast/runner/runner_server.py"] +CMD ["python", "vast/runner/runner_server.py"] \ No newline at end of file diff --git a/GUI/src/vast/services/Dockerfile b/GUI/src/vast/services/Dockerfile index 7c7c7d4f0..3b3c03a35 100644 --- a/GUI/src/vast/services/Dockerfile +++ b/GUI/src/vast/services/Dockerfile @@ -19,5 +19,4 @@ ENV PYTHONPATH=/app/vast/proto/generated:/app # Expose port if metrics serve HTTP (optional) EXPOSE 8001 # Run the metrics app -CMD ["python", "-m", "vast.services.sensors_metrics_app"] - +CMD ["python", "-m", "vast.services.sensors_metrics_app"] \ No newline at end of file diff --git a/GUI/src/vast/views/alerts_panel.py b/GUI/src/vast/views/alerts_panel.py new file mode 100644 index 000000000..e9cfd549d --- /dev/null +++ b/GUI/src/vast/views/alerts_panel.py @@ -0,0 +1,250 @@ +# views/alerts_panel.py +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QLabel, QScrollArea, QFrame, QHBoxLayout +) +from PyQt6.QtCore import Qt, QTimer +from PyQt6.QtGui import QFont +from datetime import datetime, timezone +import re + + +# ──────────────────────────────────────────────── +# Helper: parse timestamps from DB or realtime +# ──────────────────────────────────────────────── +def _parse_time(value: str): + """Safely parse a timestamp from DB or Alertmanager format.""" + if not value: + return None + + v = value.strip().replace("Z", "+00:00") + + # Try ISO format first + try: + return datetime.fromisoformat(v) + except Exception: + pass + + # Try common fallback formats (Postgres or plain) + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): + try: + return datetime.strptime(v.split("+")[0], fmt) + except Exception: + continue + return None + + +# ──────────────────────────────────────────────── +# AlertItem Widget +# ──────────────────────────────────────────────── +class AlertItem(QFrame): + """Compact alert box with one-line layout that expands for longer text.""" + + def __init__(self, alert): + super().__init__() + self.alert = alert + self._build_ui() + + def _build_ui(self): + color = "#FFC107" # default amber tone + + layout = QHBoxLayout(self) + layout.setContentsMargins(10, 6, 10, 6) + layout.setSpacing(10) + + # Left colored bar + bar = QFrame() + bar.setFixedWidth(5) + bar.setStyleSheet(f"background-color: {color}; border-radius: 2px;") + layout.addWidget(bar) + + # Alert details + alert_type = self.alert.get("alert_type", "Unknown") + device = self.alert.get("device_id", "") + summary = self.alert.get("summary", "No summary") + + # Remove ISO timestamps from summary text + summary = re.sub( + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:\+\d{2}:\d{2}|Z)?", + "", + summary + ).strip() + + # --- Parse and format time --- + start_raw = ( + self.alert.get("startsAt") + or self.alert.get("started_at") + or self.alert.get("startedAt") + ) + dt = _parse_time(start_raw) + time_str = dt.strftime("%Y-%m-%d %H:%M") if dt else "–" + + # --- Alert text --- + is_unack = not self.alert.get("ack", False) + font_weight = "font-weight:600;" if is_unack else "font-weight:normal;" + text = QLabel( + f"{alert_type} " + f"on {device} — {summary} " + f"🕒 {time_str}" + ) + text.setWordWrap(True) + text.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + text.setFont(QFont("Segoe UI", 9)) + layout.addWidget(text, 1) + + # Right status label + self.status_label = QLabel("ACTIVE") + self.status_label.setFont(QFont("Segoe UI", 9, QFont.Weight.Bold)) + self.status_label.setStyleSheet(f"color:{color};") + layout.addWidget(self.status_label, alignment=Qt.AlignmentFlag.AlignRight) + + # Allow box to expand vertically if needed + self.setMinimumHeight(65) + self.setMaximumHeight(130) + + # Style + self.setStyleSheet(""" + QFrame { + background-color: #ffffff; + border: 1px solid #ddd; + border-radius: 8px; + } + """) + + # ──────────────────────────────────────────────── + # Mark alert as resolved + # ──────────────────────────────────────────────── + def mark_resolved(self, ended_at): + """Change color and show duration when resolved.""" + try: + start_str = ( + self.alert.get("startsAt") + or self.alert.get("started_at") + or self.alert.get("startedAt") + ) + end_str = ended_at or self.alert.get("endedAt") or self.alert.get("ended_at") + start = _parse_time(start_str) + end = _parse_time(end_str) + + if start and end: + dur = end - start + mins = int(dur.total_seconds() // 60) + secs = int(dur.total_seconds() % 60) + duration = f"{mins}m {secs}s" + else: + duration = "" + except Exception: + duration = "" + + self.status_label.setText(f"✓ {duration}") + self.status_label.setStyleSheet("color:#2E7D32; font-weight:bold;") + self.setStyleSheet(""" + QFrame { + background-color: #f6fff6; + border: 1px solid #b8e5b8; + border-radius: 8px; + } + """) + + +# ──────────────────────────────────────────────── +# AlertsPanel Widget +# ──────────────────────────────────────────────── +class AlertsPanel(QWidget): + """Floating list of alert boxes (like a modern notification dropdown).""" + + def __init__(self, alert_service): + super().__init__() + self.alert_service = alert_service + self.items = {} + + layout = QVBoxLayout(self) + layout.setContentsMargins(10, 10, 10, 10) + layout.setSpacing(8) + + # Scrollable area + scroll = QScrollArea() + scroll.setWidgetResizable(True) + scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + scroll.setStyleSheet(""" + QScrollArea { + border: none; + background: transparent; + } + QScrollBar:vertical { + width: 8px; + background: #f0f0f0; + margin: 2px; + border-radius: 4px; + } + QScrollBar::handle:vertical { + background: #bbb; + border-radius: 4px; + } + QScrollBar::handle:vertical:hover { + background: #999; + } + """) + layout.addWidget(scroll) + + # Inner container + container = QWidget() + self.vbox = QVBoxLayout(container) + self.vbox.setContentsMargins(6, 6, 6, 6) + self.vbox.setSpacing(8) + self.vbox.setAlignment(Qt.AlignmentFlag.AlignTop) + scroll.setWidget(container) + + # Connect signals + self.alert_service.alertsUpdated.connect(self._populate) + self.alert_service.alertAdded.connect(self._add_alert) + self.alert_service.alertRemoved.connect(self._mark_resolved) + + # Load initial alerts + QTimer.singleShot(500, self.alert_service.load_initial) + + # ──────────────────────────────────────────────── + # Populate panel + # ──────────────────────────────────────────────── + def _populate(self, alerts): + # Remove all existing widgets + for i in reversed(range(self.vbox.count())): + widget = self.vbox.itemAt(i).widget() + if widget: + widget.deleteLater() + self.items.clear() + + # Add in reverse chronological order + for a in reversed(alerts): + self._add_alert(a) + + # ──────────────────────────────────────────────── + # Add single alert + # ──────────────────────────────────────────────── + def _add_alert(self, alert): + alert_id = alert.get("alert_id") + if not alert_id or alert_id in self.items: + return + + item = AlertItem(alert) + self.vbox.insertWidget(0, item) + self.items[alert_id] = item + + # ✅ If alert is resolved already, mark as resolved + ended_at = alert.get("ended_at") or alert.get("endedAt") + if ended_at: + item.mark_resolved(ended_at) + + # ──────────────────────────────────────────────── + # Mark resolved by ID + # ──────────────────────────────────────────────── + def _mark_resolved(self, alert_id): + item = self.items.get(alert_id) + if item: + for a in self.alert_service.alerts: + if a.get("alert_id") == alert_id: + ended_at = a.get("endedAt") or a.get("ended_at") + break + else: + ended_at = datetime.now(timezone.utc).isoformat() + item.mark_resolved(ended_at) diff --git a/GUI/src/vast/views/assets/fields.png b/GUI/src/vast/views/assets/fields.png new file mode 100644 index 000000000..b26f5ecda Binary files /dev/null and b/GUI/src/vast/views/assets/fields.png differ diff --git a/GUI/src/vast/views/assets/leaflet/leaflet-heat.js b/GUI/src/vast/views/assets/leaflet/leaflet-heat.js new file mode 100644 index 000000000..aa8031ab5 --- /dev/null +++ b/GUI/src/vast/views/assets/leaflet/leaflet-heat.js @@ -0,0 +1,11 @@ +/* + (c) 2014, Vladimir Agafonkin + simpleheat, a tiny JavaScript library for drawing heatmaps with Canvas + https://github.com/mourner/simpleheat +*/ +!function(){"use strict";function t(i){return this instanceof t?(this._canvas=i="string"==typeof i?document.getElementById(i):i,this._ctx=i.getContext("2d"),this._width=i.width,this._height=i.height,this._max=1,void this.clear()):new t(i)}t.prototype={defaultRadius:25,defaultGradient:{.4:"blue",.6:"cyan",.7:"lime",.8:"yellow",1:"red"},data:function(t,i){return this._data=t,this},max:function(t){return this._max=t,this},add:function(t){return this._data.push(t),this},clear:function(){return this._data=[],this},radius:function(t,i){i=i||15;var a=this._circle=document.createElement("canvas"),s=a.getContext("2d"),e=this._r=t+i;return a.width=a.height=2*e,s.shadowOffsetX=s.shadowOffsetY=200,s.shadowBlur=i,s.shadowColor="black",s.beginPath(),s.arc(e-200,e-200,t,0,2*Math.PI,!0),s.closePath(),s.fill(),this},gradient:function(t){var i=document.createElement("canvas"),a=i.getContext("2d"),s=a.createLinearGradient(0,0,0,256);i.width=1,i.height=256;for(var e in t)s.addColorStop(e,t[e]);return a.fillStyle=s,a.fillRect(0,0,1,256),this._grad=a.getImageData(0,0,1,256).data,this},draw:function(t){this._circle||this.radius(this.defaultRadius),this._grad||this.gradient(this.defaultGradient);var i=this._ctx;i.clearRect(0,0,this._width,this._height);for(var a,s=0,e=this._data.length;e>s;s++)a=this._data[s],i.globalAlpha=Math.max(a[2]/this._max,t||.05),i.drawImage(this._circle,a[0]-this._r,a[1]-this._r);var n=i.getImageData(0,0,this._width,this._height);return this._colorize(n.data,this._grad),i.putImageData(n,0,0),this},_colorize:function(t,i){for(var a,s=3,e=t.length;e>s;s+=4)a=4*t[s],a&&(t[s-3]=i[a],t[s-2]=i[a+1],t[s-1]=i[a+2])}},window.simpleheat=t}(),/* + (c) 2014, Vladimir Agafonkin + Leaflet.heat, a tiny and fast heatmap plugin for Leaflet. + https://github.com/Leaflet/Leaflet.heat +*/ +L.HeatLayer=(L.Layer?L.Layer:L.Class).extend({initialize:function(t,i){this._latlngs=t,L.setOptions(this,i)},setLatLngs:function(t){return this._latlngs=t,this.redraw()},addLatLng:function(t){return this._latlngs.push(t),this.redraw()},setOptions:function(t){return L.setOptions(this,t),this._heat&&this._updateOptions(),this.redraw()},redraw:function(){return!this._heat||this._frame||this._map._animating||(this._frame=L.Util.requestAnimFrame(this._redraw,this)),this},onAdd:function(t){this._map=t,this._canvas||this._initCanvas(),t._panes.overlayPane.appendChild(this._canvas),t.on("moveend",this._reset,this),t.options.zoomAnimation&&L.Browser.any3d&&t.on("zoomanim",this._animateZoom,this),this._reset()},onRemove:function(t){t.getPanes().overlayPane.removeChild(this._canvas),t.off("moveend",this._reset,this),t.options.zoomAnimation&&t.off("zoomanim",this._animateZoom,this)},addTo:function(t){return t.addLayer(this),this},_initCanvas:function(){var t=this._canvas=L.DomUtil.create("canvas","leaflet-heatmap-layer leaflet-layer"),i=L.DomUtil.testProp(["transformOrigin","WebkitTransformOrigin","msTransformOrigin"]);t.style[i]="50% 50%";var a=this._map.getSize();t.width=a.x,t.height=a.y;var s=this._map.options.zoomAnimation&&L.Browser.any3d;L.DomUtil.addClass(t,"leaflet-zoom-"+(s?"animated":"hide")),this._heat=simpleheat(t),this._updateOptions()},_updateOptions:function(){this._heat.radius(this.options.radius||this._heat.defaultRadius,this.options.blur),this.options.gradient&&this._heat.gradient(this.options.gradient),this.options.max&&this._heat.max(this.options.max)},_reset:function(){var t=this._map.containerPointToLayerPoint([0,0]);L.DomUtil.setPosition(this._canvas,t);var i=this._map.getSize();this._heat._width!==i.x&&(this._canvas.width=this._heat._width=i.x),this._heat._height!==i.y&&(this._canvas.height=this._heat._height=i.y),this._redraw()},_redraw:function(){var t,i,a,s,e,n,h,o,r,d=[],_=this._heat._r,l=this._map.getSize(),m=new L.Bounds(L.point([-_,-_]),l.add([_,_])),c=void 0===this.options.max?1:this.options.max,u=void 0===this.options.maxZoom?this._map.getMaxZoom():this.options.maxZoom,f=1/Math.pow(2,Math.max(0,Math.min(u-this._map.getZoom(),12))),g=_/2,p=[],v=this._map._getMapPanePos(),w=v.x%g,y=v.y%g;for(t=0,i=this._latlngs.length;i>t;t++)if(a=this._map.latLngToContainerPoint(this._latlngs[t]),m.contains(a)){e=Math.floor((a.x-w)/g)+2,n=Math.floor((a.y-y)/g)+2;var x=void 0!==this._latlngs[t].alt?this._latlngs[t].alt:void 0!==this._latlngs[t][2]?+this._latlngs[t][2]:1;r=x*f,p[n]=p[n]||[],s=p[n][e],s?(s[0]=(s[0]*s[2]+a.x*r)/(s[2]+r),s[1]=(s[1]*s[2]+a.y*r)/(s[2]+r),s[2]+=r):p[n][e]=[a.x,a.y,r]}for(t=0,i=p.length;i>t;t++)if(p[t])for(h=0,o=p[t].length;o>h;h++)s=p[t][h],s&&d.push([Math.round(s[0]),Math.round(s[1]),Math.min(s[2],c)]);this._heat.data(d).draw(this.options.minOpacity),this._frame=null},_animateZoom:function(t){var i=this._map.getZoomScale(t.zoom),a=this._map._getCenterOffset(t.center)._multiplyBy(-i).subtract(this._map._getMapPanePos());L.DomUtil.setTransform?L.DomUtil.setTransform(this._canvas,a,i):this._canvas.style[L.DomUtil.TRANSFORM]=L.DomUtil.getTranslateString(a)+" scale("+i+")"}}),L.heatLayer=function(t,i){return new L.HeatLayer(t,i)}; \ No newline at end of file diff --git a/GUI/src/vast/views/assets/leaflet/leaflet.css b/GUI/src/vast/views/assets/leaflet/leaflet.css new file mode 100644 index 000000000..9ade8dc49 --- /dev/null +++ b/GUI/src/vast/views/assets/leaflet/leaflet.css @@ -0,0 +1,661 @@ +/* required styles */ + +.leaflet-pane, +.leaflet-tile, +.leaflet-marker-icon, +.leaflet-marker-shadow, +.leaflet-tile-container, +.leaflet-pane > svg, +.leaflet-pane > canvas, +.leaflet-zoom-box, +.leaflet-image-layer, +.leaflet-layer { + position: absolute; + left: 0; + top: 0; + } +.leaflet-container { + overflow: hidden; + } +.leaflet-tile, +.leaflet-marker-icon, +.leaflet-marker-shadow { + -webkit-user-select: none; + -moz-user-select: none; + user-select: none; + -webkit-user-drag: none; + } +/* Prevents IE11 from highlighting tiles in blue */ +.leaflet-tile::selection { + background: transparent; +} +/* Safari renders non-retina tile on retina better with this, but Chrome is worse */ +.leaflet-safari .leaflet-tile { + image-rendering: -webkit-optimize-contrast; + } +/* hack that prevents hw layers "stretching" when loading new tiles */ +.leaflet-safari .leaflet-tile-container { + width: 1600px; + height: 1600px; + -webkit-transform-origin: 0 0; + } +.leaflet-marker-icon, +.leaflet-marker-shadow { + display: block; + } +/* .leaflet-container svg: reset svg max-width decleration shipped in Joomla! (joomla.org) 3.x */ +/* .leaflet-container img: map is broken in FF if you have max-width: 100% on tiles */ +.leaflet-container .leaflet-overlay-pane svg { + max-width: none !important; + max-height: none !important; + } +.leaflet-container .leaflet-marker-pane img, +.leaflet-container .leaflet-shadow-pane img, +.leaflet-container .leaflet-tile-pane img, +.leaflet-container img.leaflet-image-layer, +.leaflet-container .leaflet-tile { + max-width: none !important; + max-height: none !important; + width: auto; + padding: 0; + } + +.leaflet-container img.leaflet-tile { + /* See: https://bugs.chromium.org/p/chromium/issues/detail?id=600120 */ + mix-blend-mode: plus-lighter; +} + +.leaflet-container.leaflet-touch-zoom { + -ms-touch-action: pan-x pan-y; + touch-action: pan-x pan-y; + } +.leaflet-container.leaflet-touch-drag { + -ms-touch-action: pinch-zoom; + /* Fallback for FF which doesn't support pinch-zoom */ + touch-action: none; + touch-action: pinch-zoom; +} +.leaflet-container.leaflet-touch-drag.leaflet-touch-zoom { + -ms-touch-action: none; + touch-action: none; +} +.leaflet-container { + -webkit-tap-highlight-color: transparent; +} +.leaflet-container a { + -webkit-tap-highlight-color: rgba(51, 181, 229, 0.4); +} +.leaflet-tile { + filter: inherit; + visibility: hidden; + } +.leaflet-tile-loaded { + visibility: inherit; + } +.leaflet-zoom-box { + width: 0; + height: 0; + -moz-box-sizing: border-box; + box-sizing: border-box; + z-index: 800; + } +/* workaround for https://bugzilla.mozilla.org/show_bug.cgi?id=888319 */ +.leaflet-overlay-pane svg { + -moz-user-select: none; + } + +.leaflet-pane { z-index: 400; } + +.leaflet-tile-pane { z-index: 200; } +.leaflet-overlay-pane { z-index: 400; } +.leaflet-shadow-pane { z-index: 500; } +.leaflet-marker-pane { z-index: 600; } +.leaflet-tooltip-pane { z-index: 650; } +.leaflet-popup-pane { z-index: 700; } + +.leaflet-map-pane canvas { z-index: 100; } +.leaflet-map-pane svg { z-index: 200; } + +.leaflet-vml-shape { + width: 1px; + height: 1px; + } +.lvml { + behavior: url(#default#VML); + display: inline-block; + position: absolute; + } + + +/* control positioning */ + +.leaflet-control { + position: relative; + z-index: 800; + pointer-events: visiblePainted; /* IE 9-10 doesn't have auto */ + pointer-events: auto; + } +.leaflet-top, +.leaflet-bottom { + position: absolute; + z-index: 1000; + pointer-events: none; + } +.leaflet-top { + top: 0; + } +.leaflet-right { + right: 0; + } +.leaflet-bottom { + bottom: 0; + } +.leaflet-left { + left: 0; + } +.leaflet-control { + float: left; + clear: both; + } +.leaflet-right .leaflet-control { + float: right; + } +.leaflet-top .leaflet-control { + margin-top: 10px; + } +.leaflet-bottom .leaflet-control { + margin-bottom: 10px; + } +.leaflet-left .leaflet-control { + margin-left: 10px; + } +.leaflet-right .leaflet-control { + margin-right: 10px; + } + + +/* zoom and fade animations */ + +.leaflet-fade-anim .leaflet-popup { + opacity: 0; + -webkit-transition: opacity 0.2s linear; + -moz-transition: opacity 0.2s linear; + transition: opacity 0.2s linear; + } +.leaflet-fade-anim .leaflet-map-pane .leaflet-popup { + opacity: 1; + } +.leaflet-zoom-animated { + -webkit-transform-origin: 0 0; + -ms-transform-origin: 0 0; + transform-origin: 0 0; + } +svg.leaflet-zoom-animated { + will-change: transform; +} + +.leaflet-zoom-anim .leaflet-zoom-animated { + -webkit-transition: -webkit-transform 0.25s cubic-bezier(0,0,0.25,1); + -moz-transition: -moz-transform 0.25s cubic-bezier(0,0,0.25,1); + transition: transform 0.25s cubic-bezier(0,0,0.25,1); + } +.leaflet-zoom-anim .leaflet-tile, +.leaflet-pan-anim .leaflet-tile { + -webkit-transition: none; + -moz-transition: none; + transition: none; + } + +.leaflet-zoom-anim .leaflet-zoom-hide { + visibility: hidden; + } + + +/* cursors */ + +.leaflet-interactive { + cursor: pointer; + } +.leaflet-grab { + cursor: -webkit-grab; + cursor: -moz-grab; + cursor: grab; + } +.leaflet-crosshair, +.leaflet-crosshair .leaflet-interactive { + cursor: crosshair; + } +.leaflet-popup-pane, +.leaflet-control { + cursor: auto; + } +.leaflet-dragging .leaflet-grab, +.leaflet-dragging .leaflet-grab .leaflet-interactive, +.leaflet-dragging .leaflet-marker-draggable { + cursor: move; + cursor: -webkit-grabbing; + cursor: -moz-grabbing; + cursor: grabbing; + } + +/* marker & overlays interactivity */ +.leaflet-marker-icon, +.leaflet-marker-shadow, +.leaflet-image-layer, +.leaflet-pane > svg path, +.leaflet-tile-container { + pointer-events: none; + } + +.leaflet-marker-icon.leaflet-interactive, +.leaflet-image-layer.leaflet-interactive, +.leaflet-pane > svg path.leaflet-interactive, +svg.leaflet-image-layer.leaflet-interactive path { + pointer-events: visiblePainted; /* IE 9-10 doesn't have auto */ + pointer-events: auto; + } + +/* visual tweaks */ + +.leaflet-container { + background: #ddd; + outline-offset: 1px; + } +.leaflet-container a { + color: #0078A8; + } +.leaflet-zoom-box { + border: 2px dotted #38f; + background: rgba(255,255,255,0.5); + } + + +/* general typography */ +.leaflet-container { + font-family: "Helvetica Neue", Arial, Helvetica, sans-serif; + font-size: 12px; + font-size: 0.75rem; + line-height: 1.5; + } + + +/* general toolbar styles */ + +.leaflet-bar { + box-shadow: 0 1px 5px rgba(0,0,0,0.65); + border-radius: 4px; + } +.leaflet-bar a { + background-color: #fff; + border-bottom: 1px solid #ccc; + width: 26px; + height: 26px; + line-height: 26px; + display: block; + text-align: center; + text-decoration: none; + color: black; + } +.leaflet-bar a, +.leaflet-control-layers-toggle { + background-position: 50% 50%; + background-repeat: no-repeat; + display: block; + } +.leaflet-bar a:hover, +.leaflet-bar a:focus { + background-color: #f4f4f4; + } +.leaflet-bar a:first-child { + border-top-left-radius: 4px; + border-top-right-radius: 4px; + } +.leaflet-bar a:last-child { + border-bottom-left-radius: 4px; + border-bottom-right-radius: 4px; + border-bottom: none; + } +.leaflet-bar a.leaflet-disabled { + cursor: default; + background-color: #f4f4f4; + color: #bbb; + } + +.leaflet-touch .leaflet-bar a { + width: 30px; + height: 30px; + line-height: 30px; + } +.leaflet-touch .leaflet-bar a:first-child { + border-top-left-radius: 2px; + border-top-right-radius: 2px; + } +.leaflet-touch .leaflet-bar a:last-child { + border-bottom-left-radius: 2px; + border-bottom-right-radius: 2px; + } + +/* zoom control */ + +.leaflet-control-zoom-in, +.leaflet-control-zoom-out { + font: bold 18px 'Lucida Console', Monaco, monospace; + text-indent: 1px; + } + +.leaflet-touch .leaflet-control-zoom-in, .leaflet-touch .leaflet-control-zoom-out { + font-size: 22px; + } + + +/* layers control */ + +.leaflet-control-layers { + box-shadow: 0 1px 5px rgba(0,0,0,0.4); + background: #fff; + border-radius: 5px; + } +.leaflet-control-layers-toggle { + background-image: url(images/layers.png); + width: 36px; + height: 36px; + } +.leaflet-retina .leaflet-control-layers-toggle { + background-image: url(images/layers-2x.png); + background-size: 26px 26px; + } +.leaflet-touch .leaflet-control-layers-toggle { + width: 44px; + height: 44px; + } +.leaflet-control-layers .leaflet-control-layers-list, +.leaflet-control-layers-expanded .leaflet-control-layers-toggle { + display: none; + } +.leaflet-control-layers-expanded .leaflet-control-layers-list { + display: block; + position: relative; + } +.leaflet-control-layers-expanded { + padding: 6px 10px 6px 6px; + color: #333; + background: #fff; + } +.leaflet-control-layers-scrollbar { + overflow-y: scroll; + overflow-x: hidden; + padding-right: 5px; + } +.leaflet-control-layers-selector { + margin-top: 2px; + position: relative; + top: 1px; + } +.leaflet-control-layers label { + display: block; + font-size: 13px; + font-size: 1.08333em; + } +.leaflet-control-layers-separator { + height: 0; + border-top: 1px solid #ddd; + margin: 5px -10px 5px -6px; + } + +/* Default icon URLs */ +.leaflet-default-icon-path { /* used only in path-guessing heuristic, see L.Icon.Default */ + background-image: url(images/marker-icon.png); + } + + +/* attribution and scale controls */ + +.leaflet-container .leaflet-control-attribution { + background: #fff; + background: rgba(255, 255, 255, 0.8); + margin: 0; + } +.leaflet-control-attribution, +.leaflet-control-scale-line { + padding: 0 5px; + color: #333; + line-height: 1.4; + } +.leaflet-control-attribution a { + text-decoration: none; + } +.leaflet-control-attribution a:hover, +.leaflet-control-attribution a:focus { + text-decoration: underline; + } +.leaflet-attribution-flag { + display: inline !important; + vertical-align: baseline !important; + width: 1em; + height: 0.6669em; + } +.leaflet-left .leaflet-control-scale { + margin-left: 5px; + } +.leaflet-bottom .leaflet-control-scale { + margin-bottom: 5px; + } +.leaflet-control-scale-line { + border: 2px solid #777; + border-top: none; + line-height: 1.1; + padding: 2px 5px 1px; + white-space: nowrap; + -moz-box-sizing: border-box; + box-sizing: border-box; + background: rgba(255, 255, 255, 0.8); + text-shadow: 1px 1px #fff; + } +.leaflet-control-scale-line:not(:first-child) { + border-top: 2px solid #777; + border-bottom: none; + margin-top: -2px; + } +.leaflet-control-scale-line:not(:first-child):not(:last-child) { + border-bottom: 2px solid #777; + } + +.leaflet-touch .leaflet-control-attribution, +.leaflet-touch .leaflet-control-layers, +.leaflet-touch .leaflet-bar { + box-shadow: none; + } +.leaflet-touch .leaflet-control-layers, +.leaflet-touch .leaflet-bar { + border: 2px solid rgba(0,0,0,0.2); + background-clip: padding-box; + } + + +/* popup */ + +.leaflet-popup { + position: absolute; + text-align: center; + margin-bottom: 20px; + } +.leaflet-popup-content-wrapper { + padding: 1px; + text-align: left; + border-radius: 12px; + } +.leaflet-popup-content { + margin: 13px 24px 13px 20px; + line-height: 1.3; + font-size: 13px; + font-size: 1.08333em; + min-height: 1px; + } +.leaflet-popup-content p { + margin: 17px 0; + margin: 1.3em 0; + } +.leaflet-popup-tip-container { + width: 40px; + height: 20px; + position: absolute; + left: 50%; + margin-top: -1px; + margin-left: -20px; + overflow: hidden; + pointer-events: none; + } +.leaflet-popup-tip { + width: 17px; + height: 17px; + padding: 1px; + + margin: -10px auto 0; + pointer-events: auto; + + -webkit-transform: rotate(45deg); + -moz-transform: rotate(45deg); + -ms-transform: rotate(45deg); + transform: rotate(45deg); + } +.leaflet-popup-content-wrapper, +.leaflet-popup-tip { + background: white; + color: #333; + box-shadow: 0 3px 14px rgba(0,0,0,0.4); + } +.leaflet-container a.leaflet-popup-close-button { + position: absolute; + top: 0; + right: 0; + border: none; + text-align: center; + width: 24px; + height: 24px; + font: 16px/24px Tahoma, Verdana, sans-serif; + color: #757575; + text-decoration: none; + background: transparent; + } +.leaflet-container a.leaflet-popup-close-button:hover, +.leaflet-container a.leaflet-popup-close-button:focus { + color: #585858; + } +.leaflet-popup-scrolled { + overflow: auto; + } + +.leaflet-oldie .leaflet-popup-content-wrapper { + -ms-zoom: 1; + } +.leaflet-oldie .leaflet-popup-tip { + width: 24px; + margin: 0 auto; + + -ms-filter: "progid:DXImageTransform.Microsoft.Matrix(M11=0.70710678, M12=0.70710678, M21=-0.70710678, M22=0.70710678)"; + filter: progid:DXImageTransform.Microsoft.Matrix(M11=0.70710678, M12=0.70710678, M21=-0.70710678, M22=0.70710678); + } + +.leaflet-oldie .leaflet-control-zoom, +.leaflet-oldie .leaflet-control-layers, +.leaflet-oldie .leaflet-popup-content-wrapper, +.leaflet-oldie .leaflet-popup-tip { + border: 1px solid #999; + } + + +/* div icon */ + +.leaflet-div-icon { + background: #fff; + border: 1px solid #666; + } + + +/* Tooltip */ +/* Base styles for the element that has a tooltip */ +.leaflet-tooltip { + position: absolute; + padding: 6px; + background-color: #fff; + border: 1px solid #fff; + border-radius: 3px; + color: #222; + white-space: nowrap; + -webkit-user-select: none; + -moz-user-select: none; + -ms-user-select: none; + user-select: none; + pointer-events: none; + box-shadow: 0 1px 3px rgba(0,0,0,0.4); + } +.leaflet-tooltip.leaflet-interactive { + cursor: pointer; + pointer-events: auto; + } +.leaflet-tooltip-top:before, +.leaflet-tooltip-bottom:before, +.leaflet-tooltip-left:before, +.leaflet-tooltip-right:before { + position: absolute; + pointer-events: none; + border: 6px solid transparent; + background: transparent; + content: ""; + } + +/* Directions */ + +.leaflet-tooltip-bottom { + margin-top: 6px; +} +.leaflet-tooltip-top { + margin-top: -6px; +} +.leaflet-tooltip-bottom:before, +.leaflet-tooltip-top:before { + left: 50%; + margin-left: -6px; + } +.leaflet-tooltip-top:before { + bottom: 0; + margin-bottom: -12px; + border-top-color: #fff; + } +.leaflet-tooltip-bottom:before { + top: 0; + margin-top: -12px; + margin-left: -6px; + border-bottom-color: #fff; + } +.leaflet-tooltip-left { + margin-left: -6px; +} +.leaflet-tooltip-right { + margin-left: 6px; +} +.leaflet-tooltip-left:before, +.leaflet-tooltip-right:before { + top: 50%; + margin-top: -6px; + } +.leaflet-tooltip-left:before { + right: 0; + margin-right: -12px; + border-left-color: #fff; + } +.leaflet-tooltip-right:before { + left: 0; + margin-left: -12px; + border-right-color: #fff; + } + +/* Printing */ + +@media print { + /* Prevent printers from removing background-images of controls. */ + .leaflet-control { + -webkit-print-color-adjust: exact; + print-color-adjust: exact; + } + } diff --git a/GUI/src/vast/views/assets/leaflet/leaflet.js b/GUI/src/vast/views/assets/leaflet/leaflet.js new file mode 100644 index 000000000..a3bf693d0 --- /dev/null +++ b/GUI/src/vast/views/assets/leaflet/leaflet.js @@ -0,0 +1,6 @@ +/* @preserve + * Leaflet 1.9.4, a JS library for interactive maps. https://leafletjs.com + * (c) 2010-2023 Vladimir Agafonkin, (c) 2010-2011 CloudMade + */ +!function(t,e){"object"==typeof exports&&"undefined"!=typeof module?e(exports):"function"==typeof define&&define.amd?define(["exports"],e):e((t="undefined"!=typeof globalThis?globalThis:t||self).leaflet={})}(this,function(t){"use strict";function l(t){for(var e,i,n=1,o=arguments.length;n=this.min.x&&i.x<=this.max.x&&e.y>=this.min.y&&i.y<=this.max.y},intersects:function(t){t=_(t);var e=this.min,i=this.max,n=t.min,t=t.max,o=t.x>=e.x&&n.x<=i.x,t=t.y>=e.y&&n.y<=i.y;return o&&t},overlaps:function(t){t=_(t);var e=this.min,i=this.max,n=t.min,t=t.max,o=t.x>e.x&&n.xe.y&&n.y=n.lat&&i.lat<=o.lat&&e.lng>=n.lng&&i.lng<=o.lng},intersects:function(t){t=g(t);var e=this._southWest,i=this._northEast,n=t.getSouthWest(),t=t.getNorthEast(),o=t.lat>=e.lat&&n.lat<=i.lat,t=t.lng>=e.lng&&n.lng<=i.lng;return o&&t},overlaps:function(t){t=g(t);var e=this._southWest,i=this._northEast,n=t.getSouthWest(),t=t.getNorthEast(),o=t.lat>e.lat&&n.late.lng&&n.lng","http://www.w3.org/2000/svg"===(Wt.firstChild&&Wt.firstChild.namespaceURI));function y(t){return 0<=navigator.userAgent.toLowerCase().indexOf(t)}var b={ie:pt,ielt9:mt,edge:n,webkit:ft,android:gt,android23:vt,androidStock:yt,opera:xt,chrome:wt,gecko:bt,safari:Pt,phantom:Lt,opera12:o,win:Tt,ie3d:Mt,webkit3d:zt,gecko3d:_t,any3d:Ct,mobile:Zt,mobileWebkit:St,mobileWebkit3d:Et,msPointer:kt,pointer:Ot,touch:Bt,touchNative:At,mobileOpera:It,mobileGecko:Rt,retina:Nt,passiveEvents:Dt,canvas:jt,svg:Ht,vml:!Ht&&function(){try{var t=document.createElement("div"),e=(t.innerHTML='',t.firstChild);return e.style.behavior="url(#default#VML)",e&&"object"==typeof e.adj}catch(t){return!1}}(),inlineSvg:Wt,mac:0===navigator.platform.indexOf("Mac"),linux:0===navigator.platform.indexOf("Linux")},Ft=b.msPointer?"MSPointerDown":"pointerdown",Ut=b.msPointer?"MSPointerMove":"pointermove",Vt=b.msPointer?"MSPointerUp":"pointerup",qt=b.msPointer?"MSPointerCancel":"pointercancel",Gt={touchstart:Ft,touchmove:Ut,touchend:Vt,touchcancel:qt},Kt={touchstart:function(t,e){e.MSPOINTER_TYPE_TOUCH&&e.pointerType===e.MSPOINTER_TYPE_TOUCH&&O(e);ee(t,e)},touchmove:ee,touchend:ee,touchcancel:ee},Yt={},Xt=!1;function Jt(t,e,i){return"touchstart"!==e||Xt||(document.addEventListener(Ft,$t,!0),document.addEventListener(Ut,Qt,!0),document.addEventListener(Vt,te,!0),document.addEventListener(qt,te,!0),Xt=!0),Kt[e]?(i=Kt[e].bind(this,i),t.addEventListener(Gt[e],i,!1),i):(console.warn("wrong event specified:",e),u)}function $t(t){Yt[t.pointerId]=t}function Qt(t){Yt[t.pointerId]&&(Yt[t.pointerId]=t)}function te(t){delete Yt[t.pointerId]}function ee(t,e){if(e.pointerType!==(e.MSPOINTER_TYPE_MOUSE||"mouse")){for(var i in e.touches=[],Yt)e.touches.push(Yt[i]);e.changedTouches=[e],t(e)}}var ie=200;function ne(t,i){t.addEventListener("dblclick",i);var n,o=0;function e(t){var e;1!==t.detail?n=t.detail:"mouse"===t.pointerType||t.sourceCapabilities&&!t.sourceCapabilities.firesTouchEvents||((e=Ne(t)).some(function(t){return t instanceof HTMLLabelElement&&t.attributes.for})&&!e.some(function(t){return t instanceof HTMLInputElement||t instanceof HTMLSelectElement})||((e=Date.now())-o<=ie?2===++n&&i(function(t){var e,i,n={};for(i in t)e=t[i],n[i]=e&&e.bind?e.bind(t):e;return(t=n).type="dblclick",n.detail=2,n.isTrusted=!1,n._simulated=!0,n}(t)):n=1,o=e))}return t.addEventListener("click",e),{dblclick:i,simDblclick:e}}var oe,se,re,ae,he,le,ue=we(["transform","webkitTransform","OTransform","MozTransform","msTransform"]),ce=we(["webkitTransition","transition","OTransition","MozTransition","msTransition"]),de="webkitTransition"===ce||"OTransition"===ce?ce+"End":"transitionend";function _e(t){return"string"==typeof t?document.getElementById(t):t}function pe(t,e){var i=t.style[e]||t.currentStyle&&t.currentStyle[e];return"auto"===(i=i&&"auto"!==i||!document.defaultView?i:(t=document.defaultView.getComputedStyle(t,null))?t[e]:null)?null:i}function P(t,e,i){t=document.createElement(t);return t.className=e||"",i&&i.appendChild(t),t}function T(t){var e=t.parentNode;e&&e.removeChild(t)}function me(t){for(;t.firstChild;)t.removeChild(t.firstChild)}function fe(t){var e=t.parentNode;e&&e.lastChild!==t&&e.appendChild(t)}function ge(t){var e=t.parentNode;e&&e.firstChild!==t&&e.insertBefore(t,e.firstChild)}function ve(t,e){return void 0!==t.classList?t.classList.contains(e):0<(t=xe(t)).length&&new RegExp("(^|\\s)"+e+"(\\s|$)").test(t)}function M(t,e){var i;if(void 0!==t.classList)for(var n=F(e),o=0,s=n.length;othis.options.maxZoom)?this.setZoom(t):this},panInsideBounds:function(t,e){this._enforcingBounds=!0;var i=this.getCenter(),t=this._limitCenter(i,this._zoom,g(t));return i.equals(t)||this.panTo(t,e),this._enforcingBounds=!1,this},panInside:function(t,e){var i=m((e=e||{}).paddingTopLeft||e.padding||[0,0]),n=m(e.paddingBottomRight||e.padding||[0,0]),o=this.project(this.getCenter()),t=this.project(t),s=this.getPixelBounds(),i=_([s.min.add(i),s.max.subtract(n)]),s=i.getSize();return i.contains(t)||(this._enforcingBounds=!0,n=t.subtract(i.getCenter()),i=i.extend(t).getSize().subtract(s),o.x+=n.x<0?-i.x:i.x,o.y+=n.y<0?-i.y:i.y,this.panTo(this.unproject(o),e),this._enforcingBounds=!1),this},invalidateSize:function(t){if(!this._loaded)return this;t=l({animate:!1,pan:!0},!0===t?{animate:!0}:t);var e=this.getSize(),i=(this._sizeChanged=!0,this._lastCenter=null,this.getSize()),n=e.divideBy(2).round(),o=i.divideBy(2).round(),n=n.subtract(o);return n.x||n.y?(t.animate&&t.pan?this.panBy(n):(t.pan&&this._rawPanBy(n),this.fire("move"),t.debounceMoveend?(clearTimeout(this._sizeTimer),this._sizeTimer=setTimeout(a(this.fire,this,"moveend"),200)):this.fire("moveend")),this.fire("resize",{oldSize:e,newSize:i})):this},stop:function(){return this.setZoom(this._limitZoom(this._zoom)),this.options.zoomSnap||this.fire("viewreset"),this._stop()},locate:function(t){var e,i;return t=this._locateOptions=l({timeout:1e4,watch:!1},t),"geolocation"in navigator?(e=a(this._handleGeolocationResponse,this),i=a(this._handleGeolocationError,this),t.watch?this._locationWatchId=navigator.geolocation.watchPosition(e,i,t):navigator.geolocation.getCurrentPosition(e,i,t)):this._handleGeolocationError({code:0,message:"Geolocation not supported."}),this},stopLocate:function(){return navigator.geolocation&&navigator.geolocation.clearWatch&&navigator.geolocation.clearWatch(this._locationWatchId),this._locateOptions&&(this._locateOptions.setView=!1),this},_handleGeolocationError:function(t){var e;this._container._leaflet_id&&(e=t.code,t=t.message||(1===e?"permission denied":2===e?"position unavailable":"timeout"),this._locateOptions.setView&&!this._loaded&&this.fitWorld(),this.fire("locationerror",{code:e,message:"Geolocation error: "+t+"."}))},_handleGeolocationResponse:function(t){if(this._container._leaflet_id){var e,i,n=new v(t.coords.latitude,t.coords.longitude),o=n.toBounds(2*t.coords.accuracy),s=this._locateOptions,r=(s.setView&&(e=this.getBoundsZoom(o),this.setView(n,s.maxZoom?Math.min(e,s.maxZoom):e)),{latlng:n,bounds:o,timestamp:t.timestamp});for(i in t.coords)"number"==typeof t.coords[i]&&(r[i]=t.coords[i]);this.fire("locationfound",r)}},addHandler:function(t,e){return e&&(e=this[t]=new e(this),this._handlers.push(e),this.options[t]&&e.enable()),this},remove:function(){if(this._initEvents(!0),this.options.maxBounds&&this.off("moveend",this._panInsideMaxBounds),this._containerId!==this._container._leaflet_id)throw new Error("Map container is being reused by another instance");try{delete this._container._leaflet_id,delete this._containerId}catch(t){this._container._leaflet_id=void 0,this._containerId=void 0}for(var t in void 0!==this._locationWatchId&&this.stopLocate(),this._stop(),T(this._mapPane),this._clearControlPos&&this._clearControlPos(),this._resizeRequest&&(r(this._resizeRequest),this._resizeRequest=null),this._clearHandlers(),this._loaded&&this.fire("unload"),this._layers)this._layers[t].remove();for(t in this._panes)T(this._panes[t]);return this._layers=[],this._panes=[],delete this._mapPane,delete this._renderer,this},createPane:function(t,e){e=P("div","leaflet-pane"+(t?" leaflet-"+t.replace("Pane","")+"-pane":""),e||this._mapPane);return t&&(this._panes[t]=e),e},getCenter:function(){return this._checkIfLoaded(),this._lastCenter&&!this._moved()?this._lastCenter.clone():this.layerPointToLatLng(this._getCenterLayerPoint())},getZoom:function(){return this._zoom},getBounds:function(){var t=this.getPixelBounds();return new s(this.unproject(t.getBottomLeft()),this.unproject(t.getTopRight()))},getMinZoom:function(){return void 0===this.options.minZoom?this._layersMinZoom||0:this.options.minZoom},getMaxZoom:function(){return void 0===this.options.maxZoom?void 0===this._layersMaxZoom?1/0:this._layersMaxZoom:this.options.maxZoom},getBoundsZoom:function(t,e,i){t=g(t),i=m(i||[0,0]);var n=this.getZoom()||0,o=this.getMinZoom(),s=this.getMaxZoom(),r=t.getNorthWest(),t=t.getSouthEast(),i=this.getSize().subtract(i),t=_(this.project(t,n),this.project(r,n)).getSize(),r=b.any3d?this.options.zoomSnap:1,a=i.x/t.x,i=i.y/t.y,t=e?Math.max(a,i):Math.min(a,i),n=this.getScaleZoom(t,n);return r&&(n=Math.round(n/(r/100))*(r/100),n=e?Math.ceil(n/r)*r:Math.floor(n/r)*r),Math.max(o,Math.min(s,n))},getSize:function(){return this._size&&!this._sizeChanged||(this._size=new p(this._container.clientWidth||0,this._container.clientHeight||0),this._sizeChanged=!1),this._size.clone()},getPixelBounds:function(t,e){t=this._getTopLeftPoint(t,e);return new f(t,t.add(this.getSize()))},getPixelOrigin:function(){return this._checkIfLoaded(),this._pixelOrigin},getPixelWorldBounds:function(t){return this.options.crs.getProjectedBounds(void 0===t?this.getZoom():t)},getPane:function(t){return"string"==typeof t?this._panes[t]:t},getPanes:function(){return this._panes},getContainer:function(){return this._container},getZoomScale:function(t,e){var i=this.options.crs;return e=void 0===e?this._zoom:e,i.scale(t)/i.scale(e)},getScaleZoom:function(t,e){var i=this.options.crs,t=(e=void 0===e?this._zoom:e,i.zoom(t*i.scale(e)));return isNaN(t)?1/0:t},project:function(t,e){return e=void 0===e?this._zoom:e,this.options.crs.latLngToPoint(w(t),e)},unproject:function(t,e){return e=void 0===e?this._zoom:e,this.options.crs.pointToLatLng(m(t),e)},layerPointToLatLng:function(t){t=m(t).add(this.getPixelOrigin());return this.unproject(t)},latLngToLayerPoint:function(t){return this.project(w(t))._round()._subtract(this.getPixelOrigin())},wrapLatLng:function(t){return this.options.crs.wrapLatLng(w(t))},wrapLatLngBounds:function(t){return this.options.crs.wrapLatLngBounds(g(t))},distance:function(t,e){return this.options.crs.distance(w(t),w(e))},containerPointToLayerPoint:function(t){return m(t).subtract(this._getMapPanePos())},layerPointToContainerPoint:function(t){return m(t).add(this._getMapPanePos())},containerPointToLatLng:function(t){t=this.containerPointToLayerPoint(m(t));return this.layerPointToLatLng(t)},latLngToContainerPoint:function(t){return this.layerPointToContainerPoint(this.latLngToLayerPoint(w(t)))},mouseEventToContainerPoint:function(t){return De(t,this._container)},mouseEventToLayerPoint:function(t){return this.containerPointToLayerPoint(this.mouseEventToContainerPoint(t))},mouseEventToLatLng:function(t){return this.layerPointToLatLng(this.mouseEventToLayerPoint(t))},_initContainer:function(t){t=this._container=_e(t);if(!t)throw new Error("Map container not found.");if(t._leaflet_id)throw new Error("Map container is already initialized.");S(t,"scroll",this._onScroll,this),this._containerId=h(t)},_initLayout:function(){var t=this._container,e=(this._fadeAnimated=this.options.fadeAnimation&&b.any3d,M(t,"leaflet-container"+(b.touch?" leaflet-touch":"")+(b.retina?" leaflet-retina":"")+(b.ielt9?" leaflet-oldie":"")+(b.safari?" leaflet-safari":"")+(this._fadeAnimated?" leaflet-fade-anim":"")),pe(t,"position"));"absolute"!==e&&"relative"!==e&&"fixed"!==e&&"sticky"!==e&&(t.style.position="relative"),this._initPanes(),this._initControlPos&&this._initControlPos()},_initPanes:function(){var t=this._panes={};this._paneRenderers={},this._mapPane=this.createPane("mapPane",this._container),Z(this._mapPane,new p(0,0)),this.createPane("tilePane"),this.createPane("overlayPane"),this.createPane("shadowPane"),this.createPane("markerPane"),this.createPane("tooltipPane"),this.createPane("popupPane"),this.options.markerZoomAnimation||(M(t.markerPane,"leaflet-zoom-hide"),M(t.shadowPane,"leaflet-zoom-hide"))},_resetView:function(t,e,i){Z(this._mapPane,new p(0,0));var n=!this._loaded,o=(this._loaded=!0,e=this._limitZoom(e),this.fire("viewprereset"),this._zoom!==e);this._moveStart(o,i)._move(t,e)._moveEnd(o),this.fire("viewreset"),n&&this.fire("load")},_moveStart:function(t,e){return t&&this.fire("zoomstart"),e||this.fire("movestart"),this},_move:function(t,e,i,n){void 0===e&&(e=this._zoom);var o=this._zoom!==e;return this._zoom=e,this._lastCenter=t,this._pixelOrigin=this._getNewPixelOrigin(t),n?i&&i.pinch&&this.fire("zoom",i):((o||i&&i.pinch)&&this.fire("zoom",i),this.fire("move",i)),this},_moveEnd:function(t){return t&&this.fire("zoomend"),this.fire("moveend")},_stop:function(){return r(this._flyToFrame),this._panAnim&&this._panAnim.stop(),this},_rawPanBy:function(t){Z(this._mapPane,this._getMapPanePos().subtract(t))},_getZoomSpan:function(){return this.getMaxZoom()-this.getMinZoom()},_panInsideMaxBounds:function(){this._enforcingBounds||this.panInsideBounds(this.options.maxBounds)},_checkIfLoaded:function(){if(!this._loaded)throw new Error("Set map center and zoom first.")},_initEvents:function(t){this._targets={};var e=t?k:S;e((this._targets[h(this._container)]=this)._container,"click dblclick mousedown mouseup mouseover mouseout mousemove contextmenu keypress keydown keyup",this._handleDOMEvent,this),this.options.trackResize&&e(window,"resize",this._onResize,this),b.any3d&&this.options.transform3DLimit&&(t?this.off:this.on).call(this,"moveend",this._onMoveEnd)},_onResize:function(){r(this._resizeRequest),this._resizeRequest=x(function(){this.invalidateSize({debounceMoveend:!0})},this)},_onScroll:function(){this._container.scrollTop=0,this._container.scrollLeft=0},_onMoveEnd:function(){var t=this._getMapPanePos();Math.max(Math.abs(t.x),Math.abs(t.y))>=this.options.transform3DLimit&&this._resetView(this.getCenter(),this.getZoom())},_findEventTargets:function(t,e){for(var i,n=[],o="mouseout"===e||"mouseover"===e,s=t.target||t.srcElement,r=!1;s;){if((i=this._targets[h(s)])&&("click"===e||"preclick"===e)&&this._draggableMoved(i)){r=!0;break}if(i&&i.listens(e,!0)){if(o&&!We(s,t))break;if(n.push(i),o)break}if(s===this._container)break;s=s.parentNode}return n=n.length||r||o||!this.listens(e,!0)?n:[this]},_isClickDisabled:function(t){for(;t&&t!==this._container;){if(t._leaflet_disable_click)return!0;t=t.parentNode}},_handleDOMEvent:function(t){var e,i=t.target||t.srcElement;!this._loaded||i._leaflet_disable_events||"click"===t.type&&this._isClickDisabled(i)||("mousedown"===(e=t.type)&&Me(i),this._fireDOMEvent(t,e))},_mouseEvents:["click","dblclick","mouseover","mouseout","contextmenu"],_fireDOMEvent:function(t,e,i){"click"===t.type&&((a=l({},t)).type="preclick",this._fireDOMEvent(a,a.type,i));var n=this._findEventTargets(t,e);if(i){for(var o=[],s=0;sthis.options.zoomAnimationThreshold)return!1;var n=this.getZoomScale(e),n=this._getCenterOffset(t)._divideBy(1-1/n);if(!0!==i.animate&&!this.getSize().contains(n))return!1;x(function(){this._moveStart(!0,i.noMoveStart||!1)._animateZoom(t,e,!0)},this)}return!0},_animateZoom:function(t,e,i,n){this._mapPane&&(i&&(this._animatingZoom=!0,this._animateToCenter=t,this._animateToZoom=e,M(this._mapPane,"leaflet-zoom-anim")),this.fire("zoomanim",{center:t,zoom:e,noUpdate:n}),this._tempFireZoomEvent||(this._tempFireZoomEvent=this._zoom!==this._animateToZoom),this._move(this._animateToCenter,this._animateToZoom,void 0,!0),setTimeout(a(this._onZoomTransitionEnd,this),250))},_onZoomTransitionEnd:function(){this._animatingZoom&&(this._mapPane&&z(this._mapPane,"leaflet-zoom-anim"),this._animatingZoom=!1,this._move(this._animateToCenter,this._animateToZoom,void 0,!0),this._tempFireZoomEvent&&this.fire("zoom"),delete this._tempFireZoomEvent,this.fire("move"),this._moveEnd(!0))}});function Ue(t){return new B(t)}var B=et.extend({options:{position:"topright"},initialize:function(t){c(this,t)},getPosition:function(){return this.options.position},setPosition:function(t){var e=this._map;return e&&e.removeControl(this),this.options.position=t,e&&e.addControl(this),this},getContainer:function(){return this._container},addTo:function(t){this.remove(),this._map=t;var e=this._container=this.onAdd(t),i=this.getPosition(),t=t._controlCorners[i];return M(e,"leaflet-control"),-1!==i.indexOf("bottom")?t.insertBefore(e,t.firstChild):t.appendChild(e),this._map.on("unload",this.remove,this),this},remove:function(){return this._map&&(T(this._container),this.onRemove&&this.onRemove(this._map),this._map.off("unload",this.remove,this),this._map=null),this},_refocusOnMap:function(t){this._map&&t&&0",e=document.createElement("div");return e.innerHTML=t,e.firstChild},_addItem:function(t){var e,i=document.createElement("label"),n=this._map.hasLayer(t.layer),n=(t.overlay?((e=document.createElement("input")).type="checkbox",e.className="leaflet-control-layers-selector",e.defaultChecked=n):e=this._createRadioElement("leaflet-base-layers_"+h(this),n),this._layerControlInputs.push(e),e.layerId=h(t.layer),S(e,"click",this._onInputClick,this),document.createElement("span")),o=(n.innerHTML=" "+t.name,document.createElement("span"));return i.appendChild(o),o.appendChild(e),o.appendChild(n),(t.overlay?this._overlaysList:this._baseLayersList).appendChild(i),this._checkDisabledLayers(),i},_onInputClick:function(){if(!this._preventClick){var t,e,i=this._layerControlInputs,n=[],o=[];this._handlingClick=!0;for(var s=i.length-1;0<=s;s--)t=i[s],e=this._getLayer(t.layerId).layer,t.checked?n.push(e):t.checked||o.push(e);for(s=0;se.options.maxZoom},_expandIfNotCollapsed:function(){return this._map&&!this.options.collapsed&&this.expand(),this},_expandSafely:function(){var t=this._section,e=(this._preventClick=!0,S(t,"click",O),this.expand(),this);setTimeout(function(){k(t,"click",O),e._preventClick=!1})}})),qe=B.extend({options:{position:"topleft",zoomInText:'',zoomInTitle:"Zoom in",zoomOutText:'',zoomOutTitle:"Zoom out"},onAdd:function(t){var e="leaflet-control-zoom",i=P("div",e+" leaflet-bar"),n=this.options;return this._zoomInButton=this._createButton(n.zoomInText,n.zoomInTitle,e+"-in",i,this._zoomIn),this._zoomOutButton=this._createButton(n.zoomOutText,n.zoomOutTitle,e+"-out",i,this._zoomOut),this._updateDisabled(),t.on("zoomend zoomlevelschange",this._updateDisabled,this),i},onRemove:function(t){t.off("zoomend zoomlevelschange",this._updateDisabled,this)},disable:function(){return this._disabled=!0,this._updateDisabled(),this},enable:function(){return this._disabled=!1,this._updateDisabled(),this},_zoomIn:function(t){!this._disabled&&this._map._zoomthis._map.getMinZoom()&&this._map.zoomOut(this._map.options.zoomDelta*(t.shiftKey?3:1))},_createButton:function(t,e,i,n,o){i=P("a",i,n);return i.innerHTML=t,i.href="#",i.title=e,i.setAttribute("role","button"),i.setAttribute("aria-label",e),Ie(i),S(i,"click",Re),S(i,"click",o,this),S(i,"click",this._refocusOnMap,this),i},_updateDisabled:function(){var t=this._map,e="leaflet-disabled";z(this._zoomInButton,e),z(this._zoomOutButton,e),this._zoomInButton.setAttribute("aria-disabled","false"),this._zoomOutButton.setAttribute("aria-disabled","false"),!this._disabled&&t._zoom!==t.getMinZoom()||(M(this._zoomOutButton,e),this._zoomOutButton.setAttribute("aria-disabled","true")),!this._disabled&&t._zoom!==t.getMaxZoom()||(M(this._zoomInButton,e),this._zoomInButton.setAttribute("aria-disabled","true"))}}),Ge=(A.mergeOptions({zoomControl:!0}),A.addInitHook(function(){this.options.zoomControl&&(this.zoomControl=new qe,this.addControl(this.zoomControl))}),B.extend({options:{position:"bottomleft",maxWidth:100,metric:!0,imperial:!0},onAdd:function(t){var e="leaflet-control-scale",i=P("div",e),n=this.options;return this._addScales(n,e+"-line",i),t.on(n.updateWhenIdle?"moveend":"move",this._update,this),t.whenReady(this._update,this),i},onRemove:function(t){t.off(this.options.updateWhenIdle?"moveend":"move",this._update,this)},_addScales:function(t,e,i){t.metric&&(this._mScale=P("div",e,i)),t.imperial&&(this._iScale=P("div",e,i))},_update:function(){var t=this._map,e=t.getSize().y/2,t=t.distance(t.containerPointToLatLng([0,e]),t.containerPointToLatLng([this.options.maxWidth,e]));this._updateScales(t)},_updateScales:function(t){this.options.metric&&t&&this._updateMetric(t),this.options.imperial&&t&&this._updateImperial(t)},_updateMetric:function(t){var e=this._getRoundNum(t);this._updateScale(this._mScale,e<1e3?e+" m":e/1e3+" km",e/t)},_updateImperial:function(t){var e,i,t=3.2808399*t;5280'+(b.inlineSvg?' ':"")+"Leaflet"},initialize:function(t){c(this,t),this._attributions={}},onAdd:function(t){for(var e in(t.attributionControl=this)._container=P("div","leaflet-control-attribution"),Ie(this._container),t._layers)t._layers[e].getAttribution&&this.addAttribution(t._layers[e].getAttribution());return this._update(),t.on("layeradd",this._addAttribution,this),this._container},onRemove:function(t){t.off("layeradd",this._addAttribution,this)},_addAttribution:function(t){t.layer.getAttribution&&(this.addAttribution(t.layer.getAttribution()),t.layer.once("remove",function(){this.removeAttribution(t.layer.getAttribution())},this))},setPrefix:function(t){return this.options.prefix=t,this._update(),this},addAttribution:function(t){return t&&(this._attributions[t]||(this._attributions[t]=0),this._attributions[t]++,this._update()),this},removeAttribution:function(t){return t&&this._attributions[t]&&(this._attributions[t]--,this._update()),this},_update:function(){if(this._map){var t,e=[];for(t in this._attributions)this._attributions[t]&&e.push(t);var i=[];this.options.prefix&&i.push(this.options.prefix),e.length&&i.push(e.join(", ")),this._container.innerHTML=i.join(' ')}}}),n=(A.mergeOptions({attributionControl:!0}),A.addInitHook(function(){this.options.attributionControl&&(new Ke).addTo(this)}),B.Layers=Ve,B.Zoom=qe,B.Scale=Ge,B.Attribution=Ke,Ue.layers=function(t,e,i){return new Ve(t,e,i)},Ue.zoom=function(t){return new qe(t)},Ue.scale=function(t){return new Ge(t)},Ue.attribution=function(t){return new Ke(t)},et.extend({initialize:function(t){this._map=t},enable:function(){return this._enabled||(this._enabled=!0,this.addHooks()),this},disable:function(){return this._enabled&&(this._enabled=!1,this.removeHooks()),this},enabled:function(){return!!this._enabled}})),ft=(n.addTo=function(t,e){return t.addHandler(e,this),this},{Events:e}),Ye=b.touch?"touchstart mousedown":"mousedown",Xe=it.extend({options:{clickTolerance:3},initialize:function(t,e,i,n){c(this,n),this._element=t,this._dragStartTarget=e||t,this._preventOutline=i},enable:function(){this._enabled||(S(this._dragStartTarget,Ye,this._onDown,this),this._enabled=!0)},disable:function(){this._enabled&&(Xe._dragging===this&&this.finishDrag(!0),k(this._dragStartTarget,Ye,this._onDown,this),this._enabled=!1,this._moved=!1)},_onDown:function(t){var e,i;this._enabled&&(this._moved=!1,ve(this._element,"leaflet-zoom-anim")||(t.touches&&1!==t.touches.length?Xe._dragging===this&&this.finishDrag():Xe._dragging||t.shiftKey||1!==t.which&&1!==t.button&&!t.touches||((Xe._dragging=this)._preventOutline&&Me(this._element),Le(),re(),this._moving||(this.fire("down"),i=t.touches?t.touches[0]:t,e=Ce(this._element),this._startPoint=new p(i.clientX,i.clientY),this._startPos=Pe(this._element),this._parentScale=Ze(e),i="mousedown"===t.type,S(document,i?"mousemove":"touchmove",this._onMove,this),S(document,i?"mouseup":"touchend touchcancel",this._onUp,this)))))},_onMove:function(t){var e;this._enabled&&(t.touches&&1e&&(i.push(t[n]),o=n);oe.max.x&&(i|=2),t.ye.max.y&&(i|=8),i}function ri(t,e,i,n){var o=e.x,e=e.y,s=i.x-o,r=i.y-e,a=s*s+r*r;return 0this._layersMaxZoom&&this.setZoom(this._layersMaxZoom),void 0===this.options.minZoom&&this._layersMinZoom&&this.getZoom()t.y!=n.y>t.y&&t.x<(n.x-i.x)*(t.y-i.y)/(n.y-i.y)+i.x&&(l=!l);return l||yi.prototype._containsPoint.call(this,t,!0)}});var wi=ci.extend({initialize:function(t,e){c(this,e),this._layers={},t&&this.addData(t)},addData:function(t){var e,i,n,o=d(t)?t:t.features;if(o){for(e=0,i=o.length;es.x&&(r=i.x+a-s.x+o.x),i.x-r-n.x<(a=0)&&(r=i.x-n.x),i.y+e+o.y>s.y&&(a=i.y+e-s.y+o.y),i.y-a-n.y<0&&(a=i.y-n.y),(r||a)&&(this.options.keepInView&&(this._autopanning=!0),t.fire("autopanstart").panBy([r,a]))))},_getAnchor:function(){return m(this._source&&this._source._getPopupAnchor?this._source._getPopupAnchor():[0,0])}})),Ii=(A.mergeOptions({closePopupOnClick:!0}),A.include({openPopup:function(t,e,i){return this._initOverlay(Bi,t,e,i).openOn(this),this},closePopup:function(t){return(t=arguments.length?t:this._popup)&&t.close(),this}}),o.include({bindPopup:function(t,e){return this._popup=this._initOverlay(Bi,this._popup,t,e),this._popupHandlersAdded||(this.on({click:this._openPopup,keypress:this._onKeyPress,remove:this.closePopup,move:this._movePopup}),this._popupHandlersAdded=!0),this},unbindPopup:function(){return this._popup&&(this.off({click:this._openPopup,keypress:this._onKeyPress,remove:this.closePopup,move:this._movePopup}),this._popupHandlersAdded=!1,this._popup=null),this},openPopup:function(t){return this._popup&&(this instanceof ci||(this._popup._source=this),this._popup._prepareOpen(t||this._latlng)&&this._popup.openOn(this._map)),this},closePopup:function(){return this._popup&&this._popup.close(),this},togglePopup:function(){return this._popup&&this._popup.toggle(this),this},isPopupOpen:function(){return!!this._popup&&this._popup.isOpen()},setPopupContent:function(t){return this._popup&&this._popup.setContent(t),this},getPopup:function(){return this._popup},_openPopup:function(t){var e;this._popup&&this._map&&(Re(t),e=t.layer||t.target,this._popup._source!==e||e instanceof fi?(this._popup._source=e,this.openPopup(t.latlng)):this._map.hasLayer(this._popup)?this.closePopup():this.openPopup(t.latlng))},_movePopup:function(t){this._popup.setLatLng(t.latlng)},_onKeyPress:function(t){13===t.originalEvent.keyCode&&this._openPopup(t)}}),Ai.extend({options:{pane:"tooltipPane",offset:[0,0],direction:"auto",permanent:!1,sticky:!1,opacity:.9},onAdd:function(t){Ai.prototype.onAdd.call(this,t),this.setOpacity(this.options.opacity),t.fire("tooltipopen",{tooltip:this}),this._source&&(this.addEventParent(this._source),this._source.fire("tooltipopen",{tooltip:this},!0))},onRemove:function(t){Ai.prototype.onRemove.call(this,t),t.fire("tooltipclose",{tooltip:this}),this._source&&(this.removeEventParent(this._source),this._source.fire("tooltipclose",{tooltip:this},!0))},getEvents:function(){var t=Ai.prototype.getEvents.call(this);return this.options.permanent||(t.preclick=this.close),t},_initLayout:function(){var t="leaflet-tooltip "+(this.options.className||"")+" leaflet-zoom-"+(this._zoomAnimated?"animated":"hide");this._contentNode=this._container=P("div",t),this._container.setAttribute("role","tooltip"),this._container.setAttribute("id","leaflet-tooltip-"+h(this))},_updateLayout:function(){},_adjustPan:function(){},_setPosition:function(t){var e,i=this._map,n=this._container,o=i.latLngToContainerPoint(i.getCenter()),i=i.layerPointToContainerPoint(t),s=this.options.direction,r=n.offsetWidth,a=n.offsetHeight,h=m(this.options.offset),l=this._getAnchor(),i="top"===s?(e=r/2,a):"bottom"===s?(e=r/2,0):(e="center"===s?r/2:"right"===s?0:"left"===s?r:i.xthis.options.maxZoom||nthis.options.maxZoom||void 0!==this.options.minZoom&&oi.max.x)||!e.wrapLat&&(t.yi.max.y))return!1}return!this.options.bounds||(e=this._tileCoordsToBounds(t),g(this.options.bounds).overlaps(e))},_keyToBounds:function(t){return this._tileCoordsToBounds(this._keyToTileCoords(t))},_tileCoordsToNwSe:function(t){var e=this._map,i=this.getTileSize(),n=t.scaleBy(i),i=n.add(i);return[e.unproject(n,t.z),e.unproject(i,t.z)]},_tileCoordsToBounds:function(t){t=this._tileCoordsToNwSe(t),t=new s(t[0],t[1]);return t=this.options.noWrap?t:this._map.wrapLatLngBounds(t)},_tileCoordsToKey:function(t){return t.x+":"+t.y+":"+t.z},_keyToTileCoords:function(t){var t=t.split(":"),e=new p(+t[0],+t[1]);return e.z=+t[2],e},_removeTile:function(t){var e=this._tiles[t];e&&(T(e.el),delete this._tiles[t],this.fire("tileunload",{tile:e.el,coords:this._keyToTileCoords(t)}))},_initTile:function(t){M(t,"leaflet-tile");var e=this.getTileSize();t.style.width=e.x+"px",t.style.height=e.y+"px",t.onselectstart=u,t.onmousemove=u,b.ielt9&&this.options.opacity<1&&C(t,this.options.opacity)},_addTile:function(t,e){var i=this._getTilePos(t),n=this._tileCoordsToKey(t),o=this.createTile(this._wrapCoords(t),a(this._tileReady,this,t));this._initTile(o),this.createTile.length<2&&x(a(this._tileReady,this,t,null,o)),Z(o,i),this._tiles[n]={el:o,coords:t,current:!0},e.appendChild(o),this.fire("tileloadstart",{tile:o,coords:t})},_tileReady:function(t,e,i){e&&this.fire("tileerror",{error:e,tile:i,coords:t});var n=this._tileCoordsToKey(t);(i=this._tiles[n])&&(i.loaded=+new Date,this._map._fadeAnimated?(C(i.el,0),r(this._fadeFrame),this._fadeFrame=x(this._updateOpacity,this)):(i.active=!0,this._pruneTiles()),e||(M(i.el,"leaflet-tile-loaded"),this.fire("tileload",{tile:i.el,coords:t})),this._noTilesToLoad()&&(this._loading=!1,this.fire("load"),b.ielt9||!this._map._fadeAnimated?x(this._pruneTiles,this):setTimeout(a(this._pruneTiles,this),250)))},_getTilePos:function(t){return t.scaleBy(this.getTileSize()).subtract(this._level.origin)},_wrapCoords:function(t){var e=new p(this._wrapX?H(t.x,this._wrapX):t.x,this._wrapY?H(t.y,this._wrapY):t.y);return e.z=t.z,e},_pxBoundsToTileRange:function(t){var e=this.getTileSize();return new f(t.min.unscaleBy(e).floor(),t.max.unscaleBy(e).ceil().subtract([1,1]))},_noTilesToLoad:function(){for(var t in this._tiles)if(!this._tiles[t].loaded)return!1;return!0}});var Di=Ni.extend({options:{minZoom:0,maxZoom:18,subdomains:"abc",errorTileUrl:"",zoomOffset:0,tms:!1,zoomReverse:!1,detectRetina:!1,crossOrigin:!1,referrerPolicy:!1},initialize:function(t,e){this._url=t,(e=c(this,e)).detectRetina&&b.retina&&0')}}catch(t){}return function(t){return document.createElement("<"+t+' xmlns="urn:schemas-microsoft.com:vml" class="lvml">')}}(),zt={_initContainer:function(){this._container=P("div","leaflet-vml-container")},_update:function(){this._map._animatingZoom||(Wi.prototype._update.call(this),this.fire("update"))},_initPath:function(t){var e=t._container=Vi("shape");M(e,"leaflet-vml-shape "+(this.options.className||"")),e.coordsize="1 1",t._path=Vi("path"),e.appendChild(t._path),this._updateStyle(t),this._layers[h(t)]=t},_addPath:function(t){var e=t._container;this._container.appendChild(e),t.options.interactive&&t.addInteractiveTarget(e)},_removePath:function(t){var e=t._container;T(e),t.removeInteractiveTarget(e),delete this._layers[h(t)]},_updateStyle:function(t){var e=t._stroke,i=t._fill,n=t.options,o=t._container;o.stroked=!!n.stroke,o.filled=!!n.fill,n.stroke?(e=e||(t._stroke=Vi("stroke")),o.appendChild(e),e.weight=n.weight+"px",e.color=n.color,e.opacity=n.opacity,n.dashArray?e.dashStyle=d(n.dashArray)?n.dashArray.join(" "):n.dashArray.replace(/( *, *)/g," "):e.dashStyle="",e.endcap=n.lineCap.replace("butt","flat"),e.joinstyle=n.lineJoin):e&&(o.removeChild(e),t._stroke=null),n.fill?(i=i||(t._fill=Vi("fill")),o.appendChild(i),i.color=n.fillColor||n.color,i.opacity=n.fillOpacity):i&&(o.removeChild(i),t._fill=null)},_updateCircle:function(t){var e=t._point.round(),i=Math.round(t._radius),n=Math.round(t._radiusY||i);this._setPath(t,t._empty()?"M0 0":"AL "+e.x+","+e.y+" "+i+","+n+" 0,23592600")},_setPath:function(t,e){t._path.v=e},_bringToFront:function(t){fe(t._container)},_bringToBack:function(t){ge(t._container)}},qi=b.vml?Vi:ct,Gi=Wi.extend({_initContainer:function(){this._container=qi("svg"),this._container.setAttribute("pointer-events","none"),this._rootGroup=qi("g"),this._container.appendChild(this._rootGroup)},_destroyContainer:function(){T(this._container),k(this._container),delete this._container,delete this._rootGroup,delete this._svgSize},_update:function(){var t,e,i;this._map._animatingZoom&&this._bounds||(Wi.prototype._update.call(this),e=(t=this._bounds).getSize(),i=this._container,this._svgSize&&this._svgSize.equals(e)||(this._svgSize=e,i.setAttribute("width",e.x),i.setAttribute("height",e.y)),Z(i,t.min),i.setAttribute("viewBox",[t.min.x,t.min.y,e.x,e.y].join(" ")),this.fire("update"))},_initPath:function(t){var e=t._path=qi("path");t.options.className&&M(e,t.options.className),t.options.interactive&&M(e,"leaflet-interactive"),this._updateStyle(t),this._layers[h(t)]=t},_addPath:function(t){this._rootGroup||this._initContainer(),this._rootGroup.appendChild(t._path),t.addInteractiveTarget(t._path)},_removePath:function(t){T(t._path),t.removeInteractiveTarget(t._path),delete this._layers[h(t)]},_updatePath:function(t){t._project(),t._update()},_updateStyle:function(t){var e=t._path,t=t.options;e&&(t.stroke?(e.setAttribute("stroke",t.color),e.setAttribute("stroke-opacity",t.opacity),e.setAttribute("stroke-width",t.weight),e.setAttribute("stroke-linecap",t.lineCap),e.setAttribute("stroke-linejoin",t.lineJoin),t.dashArray?e.setAttribute("stroke-dasharray",t.dashArray):e.removeAttribute("stroke-dasharray"),t.dashOffset?e.setAttribute("stroke-dashoffset",t.dashOffset):e.removeAttribute("stroke-dashoffset")):e.setAttribute("stroke","none"),t.fill?(e.setAttribute("fill",t.fillColor||t.color),e.setAttribute("fill-opacity",t.fillOpacity),e.setAttribute("fill-rule",t.fillRule||"evenodd")):e.setAttribute("fill","none"))},_updatePoly:function(t,e){this._setPath(t,dt(t._parts,e))},_updateCircle:function(t){var e=t._point,i=Math.max(Math.round(t._radius),1),n="a"+i+","+(Math.max(Math.round(t._radiusY),1)||i)+" 0 1,0 ",e=t._empty()?"M0 0":"M"+(e.x-i)+","+e.y+n+2*i+",0 "+n+2*-i+",0 ";this._setPath(t,e)},_setPath:function(t,e){t._path.setAttribute("d",e)},_bringToFront:function(t){fe(t._path)},_bringToBack:function(t){ge(t._path)}});function Ki(t){return b.svg||b.vml?new Gi(t):null}b.vml&&Gi.include(zt),A.include({getRenderer:function(t){t=(t=t.options.renderer||this._getPaneRenderer(t.options.pane)||this.options.renderer||this._renderer)||(this._renderer=this._createRenderer());return this.hasLayer(t)||this.addLayer(t),t},_getPaneRenderer:function(t){var e;return"overlayPane"!==t&&void 0!==t&&(void 0===(e=this._paneRenderers[t])&&(e=this._createRenderer({pane:t}),this._paneRenderers[t]=e),e)},_createRenderer:function(t){return this.options.preferCanvas&&Ui(t)||Ki(t)}});var Yi=xi.extend({initialize:function(t,e){xi.prototype.initialize.call(this,this._boundsToLatLngs(t),e)},setBounds:function(t){return this.setLatLngs(this._boundsToLatLngs(t))},_boundsToLatLngs:function(t){return[(t=g(t)).getSouthWest(),t.getNorthWest(),t.getNorthEast(),t.getSouthEast()]}});Gi.create=qi,Gi.pointsToPath=dt,wi.geometryToLayer=bi,wi.coordsToLatLng=Li,wi.coordsToLatLngs=Ti,wi.latLngToCoords=Mi,wi.latLngsToCoords=zi,wi.getFeature=Ci,wi.asFeature=Zi,A.mergeOptions({boxZoom:!0});var _t=n.extend({initialize:function(t){this._map=t,this._container=t._container,this._pane=t._panes.overlayPane,this._resetStateTimeout=0,t.on("unload",this._destroy,this)},addHooks:function(){S(this._container,"mousedown",this._onMouseDown,this)},removeHooks:function(){k(this._container,"mousedown",this._onMouseDown,this)},moved:function(){return this._moved},_destroy:function(){T(this._pane),delete this._pane},_resetState:function(){this._resetStateTimeout=0,this._moved=!1},_clearDeferredResetState:function(){0!==this._resetStateTimeout&&(clearTimeout(this._resetStateTimeout),this._resetStateTimeout=0)},_onMouseDown:function(t){if(!t.shiftKey||1!==t.which&&1!==t.button)return!1;this._clearDeferredResetState(),this._resetState(),re(),Le(),this._startPoint=this._map.mouseEventToContainerPoint(t),S(document,{contextmenu:Re,mousemove:this._onMouseMove,mouseup:this._onMouseUp,keydown:this._onKeyDown},this)},_onMouseMove:function(t){this._moved||(this._moved=!0,this._box=P("div","leaflet-zoom-box",this._container),M(this._container,"leaflet-crosshair"),this._map.fire("boxzoomstart")),this._point=this._map.mouseEventToContainerPoint(t);var t=new f(this._point,this._startPoint),e=t.getSize();Z(this._box,t.min),this._box.style.width=e.x+"px",this._box.style.height=e.y+"px"},_finish:function(){this._moved&&(T(this._box),z(this._container,"leaflet-crosshair")),ae(),Te(),k(document,{contextmenu:Re,mousemove:this._onMouseMove,mouseup:this._onMouseUp,keydown:this._onKeyDown},this)},_onMouseUp:function(t){1!==t.which&&1!==t.button||(this._finish(),this._moved&&(this._clearDeferredResetState(),this._resetStateTimeout=setTimeout(a(this._resetState,this),0),t=new s(this._map.containerPointToLatLng(this._startPoint),this._map.containerPointToLatLng(this._point)),this._map.fitBounds(t).fire("boxzoomend",{boxZoomBounds:t})))},_onKeyDown:function(t){27===t.keyCode&&(this._finish(),this._clearDeferredResetState(),this._resetState())}}),Ct=(A.addInitHook("addHandler","boxZoom",_t),A.mergeOptions({doubleClickZoom:!0}),n.extend({addHooks:function(){this._map.on("dblclick",this._onDoubleClick,this)},removeHooks:function(){this._map.off("dblclick",this._onDoubleClick,this)},_onDoubleClick:function(t){var e=this._map,i=e.getZoom(),n=e.options.zoomDelta,i=t.originalEvent.shiftKey?i-n:i+n;"center"===e.options.doubleClickZoom?e.setZoom(i):e.setZoomAround(t.containerPoint,i)}})),Zt=(A.addInitHook("addHandler","doubleClickZoom",Ct),A.mergeOptions({dragging:!0,inertia:!0,inertiaDeceleration:3400,inertiaMaxSpeed:1/0,easeLinearity:.2,worldCopyJump:!1,maxBoundsViscosity:0}),n.extend({addHooks:function(){var t;this._draggable||(t=this._map,this._draggable=new Xe(t._mapPane,t._container),this._draggable.on({dragstart:this._onDragStart,drag:this._onDrag,dragend:this._onDragEnd},this),this._draggable.on("predrag",this._onPreDragLimit,this),t.options.worldCopyJump&&(this._draggable.on("predrag",this._onPreDragWrap,this),t.on("zoomend",this._onZoomEnd,this),t.whenReady(this._onZoomEnd,this))),M(this._map._container,"leaflet-grab leaflet-touch-drag"),this._draggable.enable(),this._positions=[],this._times=[]},removeHooks:function(){z(this._map._container,"leaflet-grab"),z(this._map._container,"leaflet-touch-drag"),this._draggable.disable()},moved:function(){return this._draggable&&this._draggable._moved},moving:function(){return this._draggable&&this._draggable._moving},_onDragStart:function(){var t,e=this._map;e._stop(),this._map.options.maxBounds&&this._map.options.maxBoundsViscosity?(t=g(this._map.options.maxBounds),this._offsetLimit=_(this._map.latLngToContainerPoint(t.getNorthWest()).multiplyBy(-1),this._map.latLngToContainerPoint(t.getSouthEast()).multiplyBy(-1).add(this._map.getSize())),this._viscosity=Math.min(1,Math.max(0,this._map.options.maxBoundsViscosity))):this._offsetLimit=null,e.fire("movestart").fire("dragstart"),e.options.inertia&&(this._positions=[],this._times=[])},_onDrag:function(t){var e,i;this._map.options.inertia&&(e=this._lastTime=+new Date,i=this._lastPos=this._draggable._absPos||this._draggable._newPos,this._positions.push(i),this._times.push(e),this._prunePositions(e)),this._map.fire("move",t).fire("drag",t)},_prunePositions:function(t){for(;1e.max.x&&(t.x=this._viscousLimit(t.x,e.max.x)),t.y>e.max.y&&(t.y=this._viscousLimit(t.y,e.max.y)),this._draggable._newPos=this._draggable._startPos.add(t))},_onPreDragWrap:function(){var t=this._worldWidth,e=Math.round(t/2),i=this._initialWorldOffset,n=this._draggable._newPos.x,o=(n-e+i)%t+e-i,n=(n+e+i)%t-e-i,t=Math.abs(o+i)e.getMaxZoom()&&1= 0) + handlers.splice(idx, 1); + if (handlers.length === 0) { + delete this.__objectSignals__[signalIdx]; + this.__transport__.send(JSON.stringify({ + type: QWebChannelMessageTypes.disconnectFromSignal, + object: this.__id__, + signal: signalIdx + })); + } +}; + +QObject.prototype.__signalEmitted__ = function(signalName, args) { + var handlers = this.__objectSignals__[signalName]; + if (handlers) { + handlers.forEach(function(cb) { cb.apply(null, args); }); + } +}; + +function QWebChannel(transport, initCallback) { + this.transport = transport; + this.objects = {}; + + var channel = this; + this.transport.onmessage = function(message) { + var data = JSON.parse(message.data); + switch (data.type) { + case QWebChannelMessageTypes.init: + Object.keys(data.data).forEach(function(name) { + channel.objects[name] = new QObject(name, data.data[name], transport); + }); + if (initCallback) + initCallback(channel); + break; + case QWebChannelMessageTypes.signal: + var object = channel.objects[data.object]; + if (object) + object.__signalEmitted__(data.signal, data.args); + break; + case QWebChannelMessageTypes.propertyUpdate: + Object.keys(data.data).forEach(function(objName) { + var obj = channel.objects[objName]; + var props = data.data[objName]; + Object.keys(props).forEach(function(propName) { + obj[propName] = props[propName]; + }); + }); + break; + case QWebChannelMessageTypes.response: + break; + } + }; + + this.exec = function(data) { + this.transport.send(JSON.stringify(data)); + }; +} + +// Export +if (typeof module !== "undefined" && module.exports) { + module.exports = QWebChannel; +} else { + window.QWebChannel = QWebChannel; +} +})(); diff --git a/GUI/src/vast/views/assets/sensors_map.html b/GUI/src/vast/views/assets/sensors_map.html new file mode 100644 index 000000000..3a711258b --- /dev/null +++ b/GUI/src/vast/views/assets/sensors_map.html @@ -0,0 +1,238 @@ + + + + + AgCloud – Sensor Map + + + + + +
+
+
Sensor Dashboard
+
+
Normal
+
Soil Moisture
+
Temperature
+
Humidity
+
+
+ +
+
+
+
+ + + + + + + + + + + + + diff --git a/GUI/src/vast/views/assets/zones.geojson b/GUI/src/vast/views/assets/zones.geojson new file mode 100644 index 000000000..bddcf348b --- /dev/null +++ b/GUI/src/vast/views/assets/zones.geojson @@ -0,0 +1,24 @@ + + + +{ + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": {"name": "Zone A"}, + "geometry": { + "type": "Polygon", + "coordinates": [[[34.75, 32.00], [34.90, 32.00], [34.90, 32.10], [34.75, 32.10], [34.75, 32.00]]] + } + }, + { + "type": "Feature", + "properties": {"name": "Zone B"}, + "geometry": { + "type": "Polygon", + "coordinates": [[[34.90, 31.95], [35.05, 31.95], [35.05, 32.05], [34.90, 32.05], [34.90, 31.95]]] + } + } + ] +} diff --git a/GUI/src/vast/views/auth_status_view.py b/GUI/src/vast/views/auth_status_view.py new file mode 100644 index 000000000..1a4ea8563 --- /dev/null +++ b/GUI/src/vast/views/auth_status_view.py @@ -0,0 +1,316 @@ + +from __future__ import annotations +import os, time, jwt, requests, json +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, + QComboBox, QTableWidget, QTableWidgetItem, QTextEdit, QFrame, + QMessageBox, QProgressDialog +) +from PyQt6.QtCore import Qt, QTimer +from PyQt6 import sip + + +def _line(): + line = QFrame() + line.setFrameShape(QFrame.Shape.HLine) + line.setFrameShadow(QFrame.Shadow.Sunken) + return line + + +class AuthStatusView(QWidget): + def __init__(self, api, parent=None): + super().__init__(parent) + self.api = api + self.access_token = None + self.refresh_token = None + self.expiry_ts = None + self.all_data = [] + + self.setStyleSheet(""" + QWidget { + background-color: #101010; + color: #e6e6e6; + font-family: 'Segoe UI', sans-serif; + font-size: 14px; + } + QLineEdit, QComboBox { + background-color: #1a1a1a; + color: #e6e6e6; + border: 1px solid #333; + border-radius: 4px; + padding: 6px; + } + QPushButton { + background-color: #2d89ef; + color: white; + border: none; + padding: 8px 14px; + border-radius: 6px; + font-weight: 600; + } + QPushButton:hover { background-color: #1e5fb4; } + QTableWidget { + background-color: #1a1a1a; + gridline-color: #333; + color: #e6e6e6; + border: 1px solid #333; + border-radius: 6px; + } + QTextEdit { + background-color: #181818; + border: 1px solid #333; + color: #cccccc; + font-family: Consolas, monospace; + font-size: 12px; + } + QLabel#Title { + font-size: 22px; + font-weight: 700; + color: #00bcd4; + } + QFrame#Card { + background-color: #141414; + border: 1px solid #333; + border-radius: 10px; + padding: 14px; + } + """) + + layout = QVBoxLayout(self) + layout.setContentsMargins(20, 20, 20, 20) + layout.setSpacing(15) + + title = QLabel("User Data Dashboard") + title.setAlignment(Qt.AlignmentFlag.AlignCenter) + title.setObjectName("Title") + layout.addWidget(title) + layout.addWidget(_line()) + + login_card = QFrame() + login_card.setObjectName("Card") + login_layout = QHBoxLayout(login_card) + self.user_edit = QLineEdit() + self.user_edit.setPlaceholderText("Username") + self.pass_edit = QLineEdit() + self.pass_edit.setEchoMode(QLineEdit.EchoMode.Password) + self.pass_edit.setPlaceholderText("Password") + self.btn_login = QPushButton("Login") + login_layout.addWidget(self.user_edit) + login_layout.addWidget(self.pass_edit) + login_layout.addWidget(self.btn_login) + layout.addWidget(login_card) + + token_card = QFrame() + token_card.setObjectName("Card") + token_layout = QVBoxLayout(token_card) + self.tokens_display = QTextEdit() + self.tokens_display.setReadOnly(True) + token_layout.addWidget(self.tokens_display) + layout.addWidget(token_card) + layout.addWidget(_line()) + + tables_env = os.getenv("TABLES_LIST", "devices") + self.tables = [t.strip() for t in tables_env.split(",") if t.strip()] + select_card = QFrame() + select_card.setObjectName("Card") + select_layout = QHBoxLayout(select_card) + self.table_combo = QComboBox() + self.table_combo.addItems(self.tables) + self.btn_load = QPushButton("Load Table Data") + select_layout.addWidget(QLabel("Select Table:")) + select_layout.addWidget(self.table_combo, 1) + select_layout.addWidget(self.btn_load) + layout.addWidget(select_card) + + search_card = QFrame() + search_card.setObjectName("Card") + search_layout = QHBoxLayout(search_card) + self.search_edit = QLineEdit() + self.search_edit.setPlaceholderText("Search in table...") + search_layout.addWidget(self.search_edit) + layout.addWidget(search_card) + + self.table_widget = QTableWidget() + layout.addWidget(self.table_widget, 1) + + self.progress = QProgressDialog("Loading data...", None, 0, 0, self) + self.progress.setWindowTitle("Please Wait") + self.progress.setCancelButton(None) + self.progress.setWindowModality(Qt.WindowModality.ApplicationModal) + self.progress.setStyleSheet(""" + QProgressDialog { + background-color: #222; + color: white; + border: 2px solid #00bcd4; + border-radius: 10px; + font-size: 16px; + padding: 15px; + } + """) + self.progress.close() + + self.btn_login.clicked.connect(self._login) + self.btn_load.clicked.connect(self._load_table) + self.search_edit.textChanged.connect(self._filter_table) + + self.timer = QTimer(self) + self.timer.timeout.connect(self._update_expiry_timer) + self.timer.start(1000) + + def _login(self): + user = self.user_edit.text().strip() + password = self.pass_edit.text().strip() + if not user or not password: + QMessageBox.warning(self, "Missing Data", "Please enter both username and password.") + return + try: + url = f"{self.api.base}/auth/login" + data = {"username": user, "password": password} + r = requests.post(url, data=data, timeout=10) + if r.status_code == 200: + js = r.json() + old_token = self.access_token + self.access_token = js.get("access_token") + self.refresh_token = js.get("refresh_token") + self.api.http.headers.update({"Authorization": f"Bearer {self.access_token}"}) + try: + payload = jwt.decode(self.access_token, options={"verify_signature": False}) + self.expiry_ts = payload.get("exp") + except Exception: + self.expiry_ts = None + msg_prefix = "✅ Access Token updated!\n\n" if old_token and self.access_token != old_token else "" + self.tokens_display.setPlainText( + f"{msg_prefix}" + f"Access Token:\n{self.access_token}\n\n" + f"Refresh Token:\n{self.refresh_token}" + ) + QMessageBox.information(self, "Login Successful", "User authenticated successfully.") + else: + QMessageBox.warning(self, "Login Failed", f"Error {r.status_code}: {r.text[:200]}") + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to login:\n{e}") + + def _update_expiry_timer(self): + if not self.expiry_ts or sip.isdeleted(self.tokens_display): + return + now = int(time.time()) + secs_left = self.expiry_ts - now + if secs_left < 0: + msg = "⚠ Token expired." + else: + mins, secs = divmod(secs_left, 60) + msg = f"Token expires in {mins:02d}:{secs:02d}" + self.tokens_display.setToolTip(msg) + + def _load_table(self): + if not self.access_token: + if not sip.isdeleted(self): + QMessageBox.warning(self, "Not Authenticated", "Please login first.") + return + + table_name = self.table_combo.currentText() + url = f"{self.api.base}/api/tables/{table_name}" + + try: + if sip.isdeleted(self): + return + self.progress.show() + self.repaint() + + r = self.api.http.get(url, timeout=20) + + if r.status_code == 200: + data = r.json() + if not sip.isdeleted(self) and self.isVisible(): + self._populate_table(data) + elif not sip.isdeleted(self): + QMessageBox.warning(self, "Request Failed", f"{r.status_code}: {r.text[:200]}") + + except Exception as e: + if not sip.isdeleted(self): + QMessageBox.critical(self, "Error", f"Request failed:\n{e}") + finally: + if hasattr(self, "progress") and not sip.isdeleted(self.progress): + self.progress.close() + + + def _populate_table(self, data): + if sip.isdeleted(self) or sip.isdeleted(self.table_widget) or not self.isVisible(): + return + + # --- normalize input --- + if isinstance(data, str): + try: + data = json.loads(data) + except Exception: + data = [{"value": data}] + if isinstance(data, dict) and "rows" in data: + data = data["rows"] + if not isinstance(data, list): + data = [data] if data else [] + + # --- handle empty --- + if not data: + self.table_widget.clear() + self.table_widget.setRowCount(0) + self.table_widget.setColumnCount(0) + if not sip.isdeleted(self): + QMessageBox.information(self, "Empty", "No data found for this table.") + return + + # --- normalize rows to dicts --- + normalized = [] + for row in data: + if not isinstance(row, dict): + try: + row = dict(row) + except Exception: + row = {"value": str(row)} + normalized.append(row) + data = normalized + + # --- build header keys --- + keys = sorted({k for row in data for k in row.keys()}) + if sip.isdeleted(self.table_widget): + return + + self.table_widget.setColumnCount(len(keys)) + self.table_widget.setRowCount(len(data)) + self.table_widget.setHorizontalHeaderLabels(keys) + + # --- fill cells safely --- + for i, row in enumerate(data): + if sip.isdeleted(self.table_widget): + return + for j, key in enumerate(keys): + val = row.get(key, "") + try: + if isinstance(val, (dict, list)): + val = json.dumps(val, ensure_ascii=False, indent=2) + elif val is None: + val = "" + else: + val = str(val) + + if len(val) > 1000: + val = val[:997] + "..." + item = QTableWidgetItem(val) + item.setToolTip(val[:3000]) + self.table_widget.setItem(i, j, item) + except Exception as e: + if not sip.isdeleted(self.table_widget): + self.table_widget.setItem(i, j, QTableWidgetItem(f"[error: {e}]")) + + # --- finish up --- + self.table_widget.resizeColumnsToContents() + self.all_data = data + + def _filter_table(self): + if sip.isdeleted(self.table_widget): + return + text = self.search_edit.text().lower().strip() + if not text: + self._populate_table(self.all_data) + return + filtered = [r for r in self.all_data if any(text in str(v).lower() for v in r.values())] + self._populate_table(filtered) diff --git a/GUI/src/vast/views/fruits_view.py b/GUI/src/vast/views/fruits_view.py index 8e9e14e1c..c77ded97d 100644 --- a/GUI/src/vast/views/fruits_view.py +++ b/GUI/src/vast/views/fruits_view.py @@ -1,60 +1,118 @@ # views/fruits_view.py from __future__ import annotations from typing import Optional, Tuple, Dict, List -from PyQt6.QtCore import Qt, pyqtSignal # type: ignore -from PyQt6.QtWidgets import ( # type: ignore +from PyQt6.QtCore import Qt, pyqtSignal +from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QComboBox, QTableWidget, QTableWidgetItem, QAbstractItemView, QDoubleSpinBox, - QMessageBox, QHeaderView, + QMessageBox, QHeaderView, QDialog, QLineEdit, QFrame ) from dashboard_api import DashboardApi -class FruitsView(QWidget): - """ - Thresholds editor per ( task, label). - Replaces the previous 'client thresholds' view. - """ - thresholdsSaved = pyqtSignal(dict) # {( task, label): threshold} +# ---------- thresholds ---------- +class ThresholdsEditorDialog(QDialog): + thresholdsSaved = pyqtSignal(dict) # {(task,label): threshold} - def __init__(self, api: DashboardApi, parent=None): + TASK_OPTIONS = ["ripeness", "disease", "size", "color"] + + def __init__(self, api: DashboardApi, parent: QWidget | None = None): super().__init__(parent) self.api = api + self.setWindowTitle("Fruits — Task Thresholds") + self.setModal(True) + self.resize(820, 560) + + + self.setStyleSheet(""" + +QLineEdit#search { + padding: 10px 12px; border: 1px solid #e8dccc; border-radius: 10px; background: #ffffff; +} + + +/* ====== status====== */ +QLabel#status { color: #6b7280; } +QLabel.status-ok { color: #17803a; } +QLabel.status-warn { color: #b25a00; } +QLabel.status-err { color: #cc0022; } + + +QPushButton, QToolButton { + padding: 10px 16px; border-radius: 12px; color: white; border: none; font-weight: 700; +} +QPushButton:disabled { background: #c8c8c8; color: #f5f5f5; } + +/* Add (🍌) */ +QPushButton#btn_add { + background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #f8e27a, stop:1 #d8c94a); + color: #3a3a00; +} +QPushButton#btn_add:hover { background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #ffef87, stop:1 #e3d65a); } + +/* Delete (🍒) */ +QPushButton#btn_delete { + background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #ff6a7a, stop:1 #e03d4f); +} +QPushButton#btn_delete:hover { background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #ff7f8d, stop:1 #ea5666); } + +/* Save (🥝) */ +QPushButton#btn_save { + background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #4bd27c, stop:1 #2fb765); +} +QPushButton#btn_save:hover { background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #5fe08b, stop:1 #3fcb75); } + +/* Close (🫐) */ +QPushButton#btn_close { + background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #6a7bff, stop:1 #4757e6); +} +QPushButton#btn_close:hover { background: qlineargradient(x1:0,y1:0,x2:1,y2:0, stop:0 #7d8bff, stop:1 #5b6cf0); } +""") - # --- Layout scaffolding --- root = QVBoxLayout(self) - root.setSpacing(10) + root.setSpacing(12) + # Title title = QLabel("Fruits — Task Thresholds (per task/label)") - title.setStyleSheet("font-size: 18px; font-weight: 600;") + title.setObjectName("title") root.addWidget(title) - # Action buttons - btns = QHBoxLayout() - self.btn_add = QPushButton("Add row") - self.btn_delete = QPushButton("Delete selected") - self.btn_save = QPushButton("Save all") - btns.addWidget(self.btn_add) - btns.addWidget(self.btn_delete) - btns.addStretch(1) - btns.addWidget(self.btn_save) - root.addLayout(btns) - - # Table: 4 columns - - # 0: Task (editable text) - # 1: Label (editable text, optional; empty string means default bucket) - # 2: Threshold (0..1) as spinbox - # 3: Updated By (editable text, optional) + # Toolbar: search + actions + toolbar = QFrame() + toolbar.setObjectName("toolbar") + tl = QHBoxLayout(toolbar) + tl.setContentsMargins(12, 12, 12, 12) + tl.setSpacing(8) + + self.txt_search = QLineEdit(placeholderText="Search by task or label…") + self.txt_search.setObjectName("search") + + self.btn_add = QPushButton("🍌 Add row") + self.btn_delete = QPushButton("🍒 Delete selected") + self.btn_save = QPushButton("🥝 Save all") + + self.btn_add.setObjectName("btn_add") + self.btn_delete.setObjectName("btn_delete") + self.btn_save.setObjectName("btn_save") + + tl.addWidget(self.txt_search, 1) + tl.addStretch(0) + tl.addWidget(self.btn_add) + tl.addWidget(self.btn_delete) + tl.addWidget(self.btn_save) + + root.addWidget(toolbar) + + # Table self.tbl = QTableWidget(0, 4, self) + self.tbl.setAlternatingRowColors(True) self.tbl.setHorizontalHeaderLabels([ - "Task", "Label (optional)", "Threshold (0..1)", "Updated By" + "Task", "Label (optional)", "Threshold (0..1)", "Updated By" ]) - self.tbl.horizontalHeader().setStretchLastSection(True) - self.tbl.horizontalHeader().setSectionResizeMode( - QHeaderView.ResizeMode.Interactive - ) + hdr = self.tbl.horizontalHeader() + hdr.setStretchLastSection(True) + hdr.setSectionResizeMode(QHeaderView.ResizeMode.Interactive) self.tbl.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) self.tbl.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) self.tbl.setEditTriggers( @@ -62,21 +120,39 @@ def __init__(self, api: DashboardApi, parent=None): | QAbstractItemView.EditTrigger.SelectedClicked | QAbstractItemView.EditTrigger.EditKeyPressed ) - root.addWidget(self.tbl) + + self.tbl.verticalHeader().setDefaultSectionSize(36) + + root.addWidget(self.tbl, 1) - # Status line + # Status + Close + bottom = QHBoxLayout() self.lbl_status = QLabel("Add rows and click Save.") - self.lbl_status.setStyleSheet("color:#555;") - root.addWidget(self.lbl_status) + self.lbl_status.setObjectName("status") + bottom.addWidget(self.lbl_status) + bottom.addStretch(1) + self.btn_close = QPushButton("🫐 Close") + self.btn_close.setObjectName("btn_close") + bottom.addWidget(self.btn_close) + root.addLayout(bottom) # Signals self.btn_add.clicked.connect(self.add_row) self.btn_delete.clicked.connect(self.delete_selected) self.btn_save.clicked.connect(self.save_all) + self.btn_close.clicked.connect(self.accept) + self.txt_search.textChanged.connect(self._apply_filter) - # Start with an empty demo row + # Start with one empty row self.add_row() + def load_rows(self, rows: List[Tuple[str, str, float, str]]): + + self.tbl.setRowCount(0) + for t, l, thr, upd in rows: + self.add_row(t, l, thr, upd) + self.lbl_status.setText(f"Loaded {len(rows)} rows.") + # -------- Row ops -------- def add_row( self, @@ -88,59 +164,51 @@ def add_row( r = self.tbl.rowCount() self.tbl.insertRow(r) - - # Task + # Task (combobox) cmb = QComboBox(self.tbl) - TASK_OPTIONS = ["ripeness", "disease", "size", "color"] - cmb.addItems(TASK_OPTIONS) - if task in TASK_OPTIONS: + cmb.addItems(self.TASK_OPTIONS) + if task in self.TASK_OPTIONS: cmb.setCurrentText(task) - else: - cmb.setCurrentIndex(0) self.tbl.setCellWidget(r, 0, cmb) - # Label (optional) + + # Label (editable) self._set_text_cell(r, 1, label) - # Threshold spinbox + # Threshold (spinbox) spn = QDoubleSpinBox(self.tbl) spn.setRange(0.0, 1.0) spn.setSingleStep(0.01) spn.setDecimals(2) spn.setValue(float(threshold)) + spn.setAlignment(Qt.AlignmentFlag.AlignRight) self.tbl.setCellWidget(r, 2, spn) # Updated By - self._set_text_cell(r, 3, updated_by) + self._set_text_cell(r, 3, updated_by or "gui") self.lbl_status.setText("Row added.") def delete_selected(self): sel = self.tbl.selectionModel().selectedRows() if not sel: - self.lbl_status.setText("No row selected.") + self._set_status("No row selected.", "warn") return for m in sel: self.tbl.removeRow(m.row()) - self.lbl_status.setText("Row deleted.") + self._set_status("Row deleted.", "ok") # -------- Helpers -------- def _set_text_cell(self, row: int, col: int, text: str): item = QTableWidgetItem(text or "") - # Editable text item item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) self.tbl.setItem(row, col, item) - def _read_row(self, r: int) -> Tuple[ str, str, float, str]: - + def _read_row(self, r: int) -> Tuple[str, str, float, str]: # Task cmb = self.tbl.cellWidget(r, 0) - if isinstance(cmb, QComboBox): - task = cmb.currentText() - else: - task = "" - + task = cmb.currentText() if isinstance(cmb, QComboBox) else "" - # Label (optional; empty string is allowed) + # Label (optional) label_item = self.tbl.item(r, 1) label = (label_item.text().strip() if label_item else "") @@ -152,56 +220,79 @@ def _read_row(self, r: int) -> Tuple[ str, str, float, str]: updated_item = self.tbl.item(r, 3) updated_by = (updated_item.text().strip() if updated_item else "") - return task, label, threshold, updated_by + return task, label, threshold, updated_by - def _validate(self) -> Tuple[bool, str]: + def _apply_filter(self): + q = self.txt_search.text().strip().lower() + for r in range(self.tbl.rowCount()): + t, l, _, _ = self._read_row(r) + show = (q in t.lower()) or (q in (l or "").lower()) or (q == "") + self.tbl.setRowHidden(r, not show) + + def _set_status(self, text: str, level: str = "info"): + # level: ok|warn|err|info + self.lbl_status.setText(text) + for cls in ["status-ok", "status-warn", "status-err"]: + self.lbl_status.setProperty("class", "") + if level == "ok": + self.lbl_status.setProperty("class", "status-ok") + elif level == "warn": + self.lbl_status.setProperty("class", "status-warn") + elif level == "err": + self.lbl_status.setProperty("class", "status-err") + self.lbl_status.style().unpolish(self.lbl_status) + self.lbl_status.style().polish(self.lbl_status) + + def _validate(self) -> Tuple[bool, str, List[int]]: """ Rules: - Task: required - - Label: optional (dedup still enforced using the triple) + - Label: optional (dedup by (task,label)) - Threshold: 0..1 - - No duplicate (mission_id, task, label) + - No duplicate (task, label) - At least one row + Returns: (ok, msg, bad_rows_indices) """ seen = set() + bad_rows = [] + for r in range(self.tbl.rowCount()): t, l, thr, _ = self._read_row(r) if not t: - return False, f"Row {r+1}: Task is empty." + bad_rows.append(r) + return False, f"Row {r+1}: Task is empty.", bad_rows if not (0.0 <= thr <= 1.0): - return False, f"Row {r+1}: Threshold must be between 0 and 1." + bad_rows.append(r) + return False, f"Row {r+1}: Threshold must be between 0 and 1.", bad_rows key = (t, l or "") if key in seen: - return False, ( - f"Row {r+1}: Duplicate ( task, label) = " - f"( {t}, {l or '∅'})." - ) + bad_rows.append(r) + return False, f"Row {r+1}: Duplicate (task,label)=({t},{l or '∅'}).", bad_rows seen.add(key) if self.tbl.rowCount() == 0: - return False, "No rows to save." + return False, "No rows to save.", [] - return True, "" + return True, "", [] # -------- Save -------- def save_all(self): - ok, msg = self._validate() + ok, msg, bad_rows = self._validate() + if not ok: + if bad_rows: + self.tbl.selectRow(bad_rows[0]) QMessageBox.warning(self, "Validation error", msg) - self.lbl_status.setStyleSheet("color:#b00020;") - self.lbl_status.setText(msg) + self._set_status(msg, "err") return - # Build mapping: {( task, label): threshold} - mapping: Dict[Tuple[ str, str], float] = {} - updated_by_for_row: Dict[Tuple[ str, str], str] = {} - + # Build mapping: {(task,label): threshold} + mapping: Dict[Tuple[str, str], float] = {} for r in range(self.tbl.rowCount()): - t, l, thr, updated_by = self._read_row(r) - key = ( t, l or "") + t, l, thr, _ = self._read_row(r) + key = (t, l or "") mapping[key] = thr - updated_by_for_row[key] = updated_by or "gui" # Disable buttons during save self.btn_save.setEnabled(False) @@ -231,13 +322,6 @@ def _normalize_fail_list(fail_raw): return pairs try: - # Flatten mapping into the API’s expected structure - # API helper that we assume exists (as agreed earlier): - # bulk_set_task_thresholds_labeled(mapping, updated_by="gui") - # If your API expects a list of dicts, it should convert internally. - # To respect per-row updated_by, we send the majority value, - # and rely on server-side to accept it. Alternatively, expose a - # dedicated bulk that accepts per-row updated_by list. report = self.api.bulk_set_task_thresholds_labeled(mapping, updated_by="gui") ok_keys = _normalize_ok_set(report.get("ok", [])) fail_pairs = _normalize_fail_list(report.get("fail", [])) @@ -247,25 +331,69 @@ def _normalize_fail_list(fail_raw): succeeded = len(ok_keys) if ok_keys else (total - failed) if failed == 0: - self.lbl_status.setStyleSheet("color:#0a7d00;") - self.lbl_status.setText(f"Saved {succeeded}/{total} thresholds ✓") + self._set_status(f"Saved {succeeded}/{total} thresholds ✓", "ok") QMessageBox.information(self, "Saved", f"All {total} thresholds saved successfully.") - # Emit a normalized dict for listeners self.thresholdsSaved.emit({k: v for k, v in mapping.items()}) else: - self.lbl_status.setStyleSheet("color:#cc7a00;") - # Show first 10 failures neatly lines = "\n".join(f"- {k}: {reason}" for k, reason in fail_pairs[:10]) more = "" if failed <= 10 else f"\n(+{failed-10} more...)" - self.lbl_status.setText(f"Partial save: {succeeded}/{total} saved, {failed} failed.") + self._set_status(f"Partial save: {succeeded}/{total} saved, {failed} failed.", "warn") QMessageBox.warning(self, "Partial save", f"Saved {succeeded}/{total}.\nFailed:\n{lines}{more}") except Exception as e: import traceback traceback.print_exc() QMessageBox.critical(self, "Error", f"Failed to save thresholds:\n{type(e).__name__}: {e}") - + self._set_status("Failed to save thresholds.", "err") finally: self.btn_save.setEnabled(True) self.btn_add.setEnabled(True) self.btn_delete.setEnabled(True) + + + +class FruitsView(QWidget): + thresholdsSaved = pyqtSignal(dict) # {(task,label): threshold} + + def __init__(self, api: DashboardApi, parent=None): + super().__init__(parent) + self.api = api + + root = QVBoxLayout(self) + root.setSpacing(10) + + title = QLabel("Fruits") + title.setStyleSheet("font-size: 22px; font-weight: 700;") + root.addWidget(title) + + + row = QHBoxLayout() + lbl = QLabel("Manage task thresholds per task/label.") + self.btn_open_editor = QPushButton("Change thresholds…") + row.addWidget(lbl, 1) + row.addStretch(0) + row.addWidget(self.btn_open_editor) + root.addLayout(row) + + line = QFrame() + line.setFrameShape(QFrame.Shape.HLine) + line.setStyleSheet("color:#e5e7eb;") + root.addWidget(line) + + + self.lbl_status = QLabel("Click “Change thresholds…” to edit.") + self.lbl_status.setStyleSheet("color:#555;") + root.addWidget(self.lbl_status) + + + self.btn_open_editor.clicked.connect(self.open_thresholds_dialog) + + def open_thresholds_dialog(self): + dlg = ThresholdsEditorDialog(self.api, self) + + # rows = self.api.get_current_thresholds() -> List[Tuple[str,str,float,str]] + # dlg.load_rows(rows) + + dlg.thresholdsSaved.connect(self.thresholdsSaved.emit) + dlg.exec() # modal + self.lbl_status.setText("Threshold editor closed.") diff --git a/GUI/src/vast/views/ground_view.py b/GUI/src/vast/views/ground_view.py new file mode 100644 index 000000000..09db68187 --- /dev/null +++ b/GUI/src/vast/views/ground_view.py @@ -0,0 +1,589 @@ +from __future__ import annotations +import os +from dataclasses import dataclass +from typing import Optional, Any, Dict, List + +from PyQt6.QtCore import Qt, QTimer, QSize, QRectF +from PyQt6.QtGui import QPixmap, QKeyEvent, QPainter, QColor, QPen, QFont +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, + QProgressBar, QMessageBox, QSizePolicy, QFrame +) + +# GUI never touches MinIO directly – it uses DashboardApi only. +from vast.dashboard_api import DashboardApi + +GROUND_BUCKET = os.getenv("GROUND_BUCKET", "ground") +GROUND_PREFIX = os.getenv("GROUND_PREFIX", "") + +# ---------------------------- +# PHI data model +# ---------------------------- +@dataclass +class PhiSnapshot: + phi: Optional[float] # 0..100 or None + density: Optional[float] + coverage: Optional[float] + severity_avg: Optional[float] # usually 0..1 (clamped) + trend: Optional[float] + week_start: Optional[str] + source: str = "" # textual hint of data source (NOT shown in UI) + + +def _phi_band_color(v: float) -> str: + # 80..100 = green, 50..79 = amber, else red + if v >= 80: + return "#16a34a" # green-600 + if v >= 50: + return "#f59e0b" # amber-500 + return "#dc2626" # red-600 + + +def _safe_float(x) -> Optional[float]: + try: + if x is None: + return None + return float(x) + except Exception: + return None + + +# ---------------------------- +# Visual PHI circle (pie) +# ---------------------------- +class PhiCircleWidget(QWidget): + """ + Draws a pie: + - red slice = severity in [0..1] + - green slice = 1 - severity (healthy remainder) + Always draws red on top so it's never hidden. + Also draws the severity percentage text centered on the pie. + """ + def __init__(self, parent=None): + super().__init__(parent) + self._severity = 0.0 # 0..1 + self.setAttribute(Qt.WidgetAttribute.WA_OpaquePaintEvent, True) + self.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + self.setToolTip("Red slice = severity (0..1). Green = healthy remainder.") + + def sizeHint(self): + return QSize(120, 120) + + def minimumSizeHint(self): + return QSize(100, 100) + + def setSeverity(self, value: float) -> None: + try: + v = float(value) + except Exception: + v = 0.0 + self._severity = max(0.0, min(1.0, v)) + self.update() # ensure repaint + + def paintEvent(self, event) -> None: + try: + painter = QPainter(self) + painter.setRenderHint(QPainter.RenderHint.Antialiasing, True) + painter.setPen(Qt.PenStyle.NoPen) + + pad = 10 + size = min(self.width(), self.height()) - 2 * pad + if size <= 0: + return + cx = (self.width() - size) / 2 + cy = (self.height() - size) / 2 + rect = QRectF(cx, cy, size, size) + + start = 90 * 16 # 12 o'clock (Qt: 0° is 3 o'clock) + full = -360 * 16 + + s = self._severity + # Degenerate cases + if s <= 1e-6: + painter.setBrush(QColor("#16a34a")) + painter.drawEllipse(rect) + elif s >= 1 - 1e-6: + painter.setBrush(QColor("#dc2626")) + painter.drawEllipse(rect) + else: + span_red = int(round(full * s)) # negative (clockwise) + span_green = full - span_red # the remainder + + # Draw green remainder first + painter.setBrush(QColor("#16a34a")) + painter.drawPie(rect, start + span_red, span_green) + + # Draw red slice on top (so it's always visible) + painter.setBrush(QColor("#dc2626")) + painter.drawPie(rect, start, span_red) + + # Outline + pen = QPen(QColor("#334155")) + pen.setWidth(2) + painter.setPen(pen) + painter.setBrush(Qt.BrushStyle.NoBrush) + painter.drawEllipse(rect) + + # Percentage text (centered) + percent_text = f"{int(round(s * 100))}%" + font = QFont() + font.setBold(True) + font.setPointSize(int(size * 0.22)) # responsive sizing + painter.setFont(font) + + # Soft shadow for readability + painter.setPen(QColor(0, 0, 0, 160)) + painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, percent_text) + # Foreground text + painter.setPen(QColor("#ffffff")) + painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, percent_text) + + except Exception as e: + print(f"[PhiCircleWidget] paintEvent error: {e}") + + +# ---------------------------- +# GroundView +# ---------------------------- +class GroundView(QWidget): + """ + Gallery mode: + - Loads all object keys from MinIO via DashboardApi. + - Keeps current index; supports Prev/Next buttons and keyboard arrows. + - On image change, fetches bytes and refreshes PHI for that key. + """ + + def __init__(self, api: DashboardApi, parent=None): + super().__init__(parent) + self.api = api + + # State for gallery + self._keys: List[str] = [] + self._idx: int = -1 + + # ---------- UI ---------- + root = QVBoxLayout(self) + root.setContentsMargins(12, 12, 12, 12) + root.setSpacing(10) + + title = QLabel("🌿 Ground — Gallery & PHI") + title.setStyleSheet("font-size:20px;font-weight:800;color:#0f172a;") + root.addWidget(title) + + # Toolbar + toolbar = QHBoxLayout() + self.btn_refresh_list = QPushButton("Reload list") + self.btn_refresh_list.clicked.connect(self.reload_keys) + + self.btn_prev = QPushButton("◀ Prev") + self.btn_prev.clicked.connect(self.prev_image) + self.btn_next = QPushButton("Next ▶") + self.btn_next.clicked.connect(self.next_image) + + self.btn_show_phi = QPushButton("Show PHI") + self.btn_show_phi.clicked.connect(self.refresh_phi_current) + + self.counter_label = QLabel("(0 / 0)") + self.counter_label.setStyleSheet("color:#475569;font-size:12px;") + + toolbar.addWidget(self.btn_refresh_list) + toolbar.addSpacing(8) + toolbar.addWidget(self.btn_prev) + toolbar.addWidget(self.btn_next) + toolbar.addSpacing(16) + toolbar.addWidget(self.btn_show_phi) + toolbar.addStretch(1) + toolbar.addWidget(self.counter_label) + root.addLayout(toolbar) + + # Image frame + img_frame = QFrame() + img_frame.setStyleSheet("background:#f8fafc;border:1px solid #cbd5e1;border-radius:10px;") + img_layout = QVBoxLayout(img_frame) + img_layout.setContentsMargins(8, 8, 8, 8) + img_layout.setSpacing(6) + + self.image_label = QLabel("(No image loaded yet)") + self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.image_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + self.image_label.setMinimumHeight(260) + img_layout.addWidget(self.image_label) + + self.img_meta = QLabel("") + self.img_meta.setStyleSheet("color:#475569;font-size:12px;") + img_layout.addWidget(self.img_meta) + + root.addWidget(img_frame, stretch=2) + + # PHI area + phi_frame = QFrame() + phi_frame.setStyleSheet("background:#ffffff;border:1px solid #cbd5e1;border-radius:10px;") + phi_layout = QVBoxLayout(phi_frame) + phi_layout.setContentsMargins(12, 12, 12, 12) + phi_layout.setSpacing(8) + + # Row with headline + (trimmed) details + row = QHBoxLayout() + self.phi_label = QLabel("PHI: –") + self.phi_label.setStyleSheet("font-size:16px;font-weight:700;color:#0f172a;") + row.addWidget(self.phi_label) + row.addStretch(1) + self.phi_details = QLabel("") # will show severity/coverage/trend (without src) + self.phi_details.setStyleSheet("color:#475569;font-size:12px;") + row.addWidget(self.phi_details) + phi_layout.addLayout(row) + + # PHI progress (axis-like) + pie at its side + self.phi_bar = QProgressBar() + self.phi_bar.setRange(0, 100) + self.phi_bar.setValue(0) + self.phi_bar.setFormat("%v") + self._style_phi_bar(None) + + phi_row2 = QHBoxLayout() + phi_row2.setContentsMargins(0, 0, 0, 0) + phi_row2.setSpacing(10) + phi_row2.addWidget(self.phi_bar, stretch=1) + + self.phi_circle = PhiCircleWidget() + self.phi_circle.setFixedSize(120, 120) + phi_row2.addWidget(self.phi_circle) + + phi_layout.addLayout(phi_row2) + + legend = QLabel("אדום = Severity | ירוק = Healthy") + legend.setStyleSheet("color:#64748b;font-size:11px;") + phi_layout.addWidget(legend) + + root.addWidget(phi_frame, stretch=1) + + # Auto-refresh PHI every 2 min + self.timer = QTimer(self) + self.timer.setInterval(120_000) + self.timer.timeout.connect(self.refresh_phi_current) + self.timer.start() + + # Initial load + QTimer.singleShot(200, self.reload_keys) + + # So arrow keys work even without inner focus + self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) + + # ---------------------------- + # Styling helpers + # ---------------------------- + def _style_phi_bar(self, value: Optional[float]) -> None: + color = "#64748b" if value is None else _phi_band_color(float(value)) + self.phi_bar.setStyleSheet( + f"QProgressBar {{ border:1px solid #cbd5e1;border-radius:6px;height:18px; }} " + f"QProgressBar::chunk {{ background:{color}; border-radius:6px; }}" + ) + + def _warn(self, msg: str) -> None: + # Non-blocking warning box + try: + def _show(): + try: + box = QMessageBox(self) + box.setIcon(QMessageBox.Icon.Warning) + box.setWindowTitle("Ground") + box.setText(str(msg)) + box.setStandardButtons(QMessageBox.StandardButton.Ok) + box.setWindowModality(Qt.WindowModality.NonModal) + box.show() + except BaseException: + print(f"[GroundView] WARN(fallback): {msg}") + QTimer.singleShot(0, _show) + except BaseException: + print(f"[GroundView] WARN: {msg}") + + def _try_api(self, names: List[str], *args, **kwargs) -> Any: + # Try a list of API method names until one succeeds. + for name in names: + fn = getattr(self.api, name, None) + if callable(fn): + try: + return fn(*args, **kwargs) + except Exception as e: + print(f"[GroundView] API call {name} failed: {e}") + return None + + # ---------------------------- + # Gallery: load keys & navigation + # ---------------------------- + def reload_keys(self) -> None: + """Load all object keys from MinIO (sorted newest→oldest).""" + try: + objs = self._try_api( + ["list_minio_objects", "list_objects"], + bucket=GROUND_BUCKET, prefix=GROUND_PREFIX, limit=1000 + ) + keys: List[str] = [] + if isinstance(objs, list): + # Sort by last_modified/LastModified desc when available + def _lm(o): + if not isinstance(o, dict): + return "" + return o.get("last_modified") or o.get("LastModified") or "" + try: + objs = sorted(objs, key=_lm, reverse=True) + except Exception: + pass + for o in objs: + if isinstance(o, dict): + for f in ("key", "name", "object_name", "path"): + v = o.get(f) + if isinstance(v, str) and v.strip(): + keys.append(v.strip()) + break + + self._keys = keys + self._idx = 0 if self._keys else -1 + self._update_counter() + self._update_nav_buttons() + if self._idx >= 0: + self.load_current_image() + else: + self._set_image(None) + self.img_meta.setText("No objects found in MinIO.") + self._render_phi_none() + + except Exception as e: + self._warn(f"reload_keys error: {e}") + + def _update_counter(self) -> None: + total = len(self._keys) + pos = (self._idx + 1) if self._idx >= 0 else 0 + self.counter_label.setText(f"({pos} / {total})") + + def _update_nav_buttons(self) -> None: + has = bool(self._keys) + for b in (self.btn_prev, self.btn_next, self.btn_show_phi): + b.setEnabled(has) + + def prev_image(self) -> None: + if not self._keys: + return + self._idx = (self._idx - 1) % len(self._keys) + self._update_counter() + self.load_current_image() + + def next_image(self) -> None: + if not self._keys: + return + self._idx = (self._idx + 1) % len(self._keys) + self._update_counter() + self.load_current_image() + + def keyPressEvent(self, event: QKeyEvent) -> None: + if event.key() in (Qt.Key.Key_Left, Qt.Key.Key_A): + self.prev_image() + event.accept() + return + if event.key() in (Qt.Key.Key_Right, Qt.Key.Key_D): + self.next_image() + event.accept() + return + super().keyPressEvent(event) + + # ---------------------------- + # Image load + PHI for current key + # ---------------------------- + def _set_image(self, pix: Optional[QPixmap]) -> None: + if pix is None or pix.isNull(): + self.image_label.setText("(No image)") + self.image_label.setPixmap(QPixmap()) + return + target_size: QSize = self.image_label.size() + if target_size.width() <= 4 or target_size.height() <= 4: + self.image_label.setPixmap(pix) + self.image_label.setText("") + return + scaled = pix.scaled( + target_size.width(), + target_size.height(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + self.image_label.setPixmap(scaled) + self.image_label.setText("") + + def resizeEvent(self, e): + super().resizeEvent(e) + pix = self.image_label.pixmap() + if pix is not None and not pix.isNull(): + self._set_image(pix) + + def load_current_image(self) -> None: + """Load image bytes for current key and refresh PHI.""" + try: + if self._idx < 0 or self._idx >= len(self._keys): + self._set_image(None) + self.img_meta.setText("No selection.") + self._render_phi_none() + return + + key = self._keys[self._idx] + getter = getattr(self.api, "get_image_bytes_from_minio", None) + if not callable(getter): + self._warn("DashboardApi.get_image_bytes_from_minio is missing.") + self._set_image(None) + self._render_phi_none() + return + + data = None + try: + data = getter(key, bucket=GROUND_BUCKET) + except TypeError: + data = getter(key) + except Exception as e: + self._warn(f"Failed fetching image bytes: {e}") + data = None + + if not data: + self._set_image(None) + self.img_meta.setText(f"Failed to read: {GROUND_BUCKET}/{key}") + self._render_phi_none() + return + + pix = QPixmap() + if not pix.loadFromData(data): + self._set_image(None) + self.img_meta.setText(f"Unsupported bytes: {GROUND_BUCKET}/{key}") + self._render_phi_none() + return + + self._set_image(pix) + self.img_meta.setText(f"{GROUND_BUCKET}/{key}") + + # After image displayed, refresh PHI + self._refresh_phi_for_key(key) + + except Exception as e: + self._warn(f"load_current_image error: {e}") + self._render_phi_none() + + # ---------------------------- + # PHI flow + # ---------------------------- + def _map_phi_dict(self, d: Dict[str, Any], source: str) -> PhiSnapshot: + return PhiSnapshot( + phi=_safe_float(d.get("phi")), + density=_safe_float(d.get("density")), + coverage=_safe_float(d.get("coverage")), + severity_avg=_safe_float(d.get("severity_avg")), + trend=_safe_float(d.get("trend")), + week_start=str(d.get("week_start")) if d.get("week_start") is not None else None, + source=source, + ) + + def _render_phi_none(self) -> None: + self.phi_label.setText("PHI: –") + self.phi_details.setText("No PHI available.") + self.phi_bar.setValue(0) + self._style_phi_bar(None) + self.phi_circle.setSeverity(0.0) + + def _refresh_phi_for_key(self, key: str) -> None: + """Best-effort PHI for the specific image key. Fully guarded.""" + try: + # 1) Exact by key + d = self._try_api(["get_phi_for_image"], key) + if isinstance(d, dict) and (d.get("phi") is not None or d.get("severity_avg") is not None): + return self._render_phi(self._map_phi_dict(d, "phi_by_key")) + + # 2) Current image + d = self._try_api(["get_phi_for_current_image"]) + if isinstance(d, dict) and (d.get("phi") is not None or d.get("severity_avg") is not None): + return self._render_phi(self._map_phi_dict(d, "phi_current")) + + # 3) Weekly/global + d = self._try_api(["get_weekly_phi"]) + if isinstance(d, dict) and (d.get("phi") is not None or d.get("severity_avg") is not None): + return self._render_phi(self._map_phi_dict(d, "weekly")) + + # 4) Derive roughly from latest rows as last resort + rows = self._try_api(["get_latest_rows", "get_latest_detections", "get_latest_ground_rows"], limit=1) or [] + if rows and isinstance(rows, list) and isinstance(rows[0], dict): + sev = None + cov = None + for k in ("severity_avg", "severity", "mean_severity"): + sev = _safe_float(rows[0].get(k)) + if sev is not None: + break + for k in ("coverage", "plant_coverage"): + cov = _safe_float(rows[0].get(k)) + if cov is not None: + break + + phi_val = None + if sev is not None: + s = sev if sev <= 1.0 else min(sev, 10.0) / 10.0 + phi_val = max(0.0, min(100.0, 100.0 * (1.0 - s))) + elif cov is not None: + c = max(0.0, min(1.0, cov)) + phi_val = 100.0 * c + + if phi_val is not None: + snap = PhiSnapshot( + phi=phi_val, density=None, coverage=cov, severity_avg=sev, + trend=None, week_start=None, source="derived_from_rows" + ) + return self._render_phi(snap) + + # Nothing available + self._render_phi_none() + self._warn("No PHI available for this image.") + except Exception as e: + self._warn(f"_refresh_phi_for_key error: {e}") + self._render_phi_none() + + def _render_phi(self, snap: PhiSnapshot) -> None: + if snap is None or snap.phi is None: + self._render_phi_none() + return + + # Progress bar + label + val = max(0, min(100, int(round(snap.phi)))) + self.phi_label.setText(f"PHI: {val}") + parts = [] + # keep useful metrics, but DO NOT show 'src=' anymore + if snap.density is not None: + parts.append(f"density={snap.density:.2f}") + if snap.coverage is not None: + parts.append(f"coverage={snap.coverage:.2f}") + if snap.severity_avg is not None: + parts.append(f"severity={snap.severity_avg:.2f}") + if snap.trend is not None: + parts.append(f"trend={snap.trend:+.2f}") + if snap.week_start: + parts.append(f"week={snap.week_start}") + self.phi_details.setText(" | ".join(parts)) # no src here + self.phi_bar.setValue(val) + self._style_phi_bar(val) + + # Circle severity (normalized to 0..1) + sev = snap.severity_avg + try: + sev = float(sev) if sev is not None else 0.0 + except Exception: + sev = 0.0 + sev_norm = sev if sev <= 1.0 else min(sev, 10.0) / 10.0 + self.phi_circle.setSeverity(max(0.0, min(1.0, sev_norm))) + + def refresh_phi_current(self) -> None: + """Called by the 'Show PHI' button; uses current image key safely.""" + try: + if 0 <= self._idx < len(self._keys): + key = self._keys[self._idx] + if not isinstance(key, str) or not key.strip(): + self._render_phi_none() + self._warn("No valid image key selected.") + return + self._refresh_phi_for_key(key) + else: + self._render_phi_none() + self._warn("No image selected yet. Click 'Reload list' or 'Next'.") + except Exception as e: + self._render_phi_none() + self._warn(f"Show PHI failed: {e}") diff --git a/GUI/src/vast/views/security/events_history_page.py b/GUI/src/vast/views/security/events_history_page.py new file mode 100644 index 000000000..f6ec017fb --- /dev/null +++ b/GUI/src/vast/views/security/events_history_page.py @@ -0,0 +1,516 @@ +from PyQt6 import QtWidgets, QtGui, QtCore +import os, sys, vlc +from datetime import datetime + + +class EventsHistoryPage(QtWidgets.QWidget): + """AgGuard Security Events History — visual-only severity bar with sorting and fixed filters (with debug prints).""" + + def __init__(self, api, parent=None): + super().__init__(parent) + self.api = api + self.setContentsMargins(24, 24, 24, 24) + + print("[INIT] EventsHistoryPage initialized") + + # ───────────── GLOBAL STYLE ───────────── + self.setStyleSheet(""" + QWidget { + background-color: #f9fafb; + font-family: 'Segoe UI', 'DejaVu Sans', Arial, sans-serif; + color: #111827; + font-size: 14px; + } + QHeaderView::section { + background-color: #f3f4f6; + color: #111827; + font-weight: 600; + border: none; + padding: 8px; + border-bottom: 1px solid #e5e7eb; + } + QTableWidget { + gridline-color: #e5e7eb; + background-color: #ffffff; + border: 1px solid #d1d5db; + border-radius: 10px; + selection-background-color: #bbf7d0; + selection-color: #065f46; + font-size: 13px; + } + QTableWidget::item { padding: 10px; } + QScrollBar:vertical { + background: transparent; + width: 10px; + margin: 2px; + } + QScrollBar::handle:vertical { + background: #9ca3af; + border-radius: 5px; + min-height: 20px; + } + QScrollBar::handle:vertical:hover { background: #6b7280; } + QComboBox, QDateEdit { + background-color: #ffffff; + border: 1px solid #d1d5db; + border-radius: 8px; + padding: 4px 10px; + font-size: 13px; + height: 32px; + min-width: 120px; + color: #111827; + } + QComboBox:hover, QDateEdit:hover { + border-color: #9ca3af; + background-color: #f9fafb; + } + QComboBox:focus, QDateEdit:focus { + border: 1px solid #10b981; + background-color: #ffffff; + } + QComboBox QAbstractItemView { + border: none; + background-color: #ffffff; + padding: 6px 4px; + outline: none; + font-size: 14px; + selection-background-color: #10b981; + selection-color: white; + } + QPushButton { + border: none; + border-radius: 6px; + font-weight: 500; + padding: 6px 12px; + } + QPushButton#reload_btn { + background-color: #10b981; + color: white; + font-weight: 600; + } + QPushButton#reload_btn:hover { background-color: #059669; } + QPushButton#clear_btn { + background-color: #f3f4f6; + color: #374151; + border: 1px solid #d1d5db; + } + QPushButton#clear_btn:hover { background-color: #e5e7eb; } + QPushButton.view_btn { + background-color: #10b981; + color: white; + padding: 6px 16px; + font-weight: 700; + font-size: 13px; + } + QPushButton.view_btn:hover { background-color: #059669; } + """) + + # ───────────── CONSTANTS ───────────── + self.media_proxy_base = os.getenv("MEDIA_PROXY_BASE", "http://media-proxy:8080").rstrip("/") + self.proxy_local_base = "http://127.0.0.1:19100" + self.all_rows = [] + + main_layout = QtWidgets.QVBoxLayout(self) + main_layout.setSpacing(18) + + # ───────────── HEADER ───────────── + header = QtWidgets.QHBoxLayout() + title = QtWidgets.QLabel("🧾 Security Events History") + title.setStyleSheet("font-size:22px;font-weight:700;color:#0f172a;") + header.addWidget(title) + header.addStretch(1) + + reload_btn = QtWidgets.QPushButton("Reload") + reload_btn.setObjectName("reload_btn") + reload_btn.setCursor(QtCore.Qt.CursorShape.PointingHandCursor) + reload_btn.clicked.connect(self.load_from_api) + header.addWidget(reload_btn) + main_layout.addLayout(header) + + # ───────────── TOOLBAR ───────────── + toolbar = QtWidgets.QFrame() + toolbar.setStyleSheet(""" + QFrame { + background-color: #ffffff; + border: 1px solid #d1d5db; + border-radius: 14px; + padding: 10px 14px; + } + """) + tl = QtWidgets.QHBoxLayout(toolbar) + tl.setContentsMargins(8, 6, 8, 6) + tl.setSpacing(8) + + self.device_filter = QtWidgets.QComboBox() + self.device_filter.addItem("All Devices") + self.device_filter.currentIndexChanged.connect(self.apply_filters) + + self.anomaly_filter = QtWidgets.QComboBox() + self.anomaly_filter.addItem("All Anomalies") + self.anomaly_filter.currentIndexChanged.connect(self.apply_filters) + + self.severity_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.severity_slider.setRange(0, 6) + self.severity_slider.setFixedWidth(110) + self._update_slider_style(0) + self.severity_slider.valueChanged.connect(self._update_slider_style) + self.severity_slider.valueChanged.connect(self.apply_filters) + + self.from_date = QtWidgets.QDateEdit(QtCore.QDate.currentDate().addMonths(-1)) + self.from_date.setDisplayFormat("yyyy-MM-dd") + self.from_date.setCalendarPopup(True) + self.from_date.dateChanged.connect(self.apply_filters) + + self.to_date = QtWidgets.QDateEdit(QtCore.QDate.currentDate()) + self.to_date.setDisplayFormat("yyyy-MM-dd") + self.to_date.setCalendarPopup(True) + self.to_date.dateChanged.connect(self.apply_filters) + + self.sort_combo = QtWidgets.QComboBox() + self.sort_combo.addItems([ + "No Sorting", + "Severity (High → Low)", + "Severity (Low → High)", + "Start Time (Newest)", + "Start Time (Oldest)", + "End Time (Newest)", + "End Time (Oldest)", + "Anomaly (A → Z)", + "Anomaly (Z → A)" + ]) + self.sort_combo.currentIndexChanged.connect(self.apply_filters) + + clear_btn = QtWidgets.QPushButton("Clear") + clear_btn.setObjectName("clear_btn") + clear_btn.setCursor(QtCore.Qt.CursorShape.PointingHandCursor) + clear_btn.clicked.connect(self.clear_filters) + + for w in [ + self.device_filter, self.anomaly_filter, + self.severity_slider, self.from_date, self.to_date, + self.sort_combo + ]: + tl.addWidget(w) + tl.addStretch(1) + tl.addWidget(clear_btn) + main_layout.addWidget(toolbar) + + # ───────────── TABLE ───────────── + self.table = QtWidgets.QTableWidget() + self.table.setColumnCount(8) + self.table.setHorizontalHeaderLabels([ + "ID", "Device", "Anomaly", "Start Time", + "End Time", "Duration (s)", "Severity", "View" + ]) + self.table.verticalHeader().setVisible(False) + self.table.setEditTriggers(QtWidgets.QAbstractItemView.EditTrigger.NoEditTriggers) + self.table.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectionBehavior.SelectRows) + self.table.horizontalHeader().setStretchLastSection(True) + self.table.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.ResizeMode.Stretch) + self.table.verticalHeader().setDefaultSectionSize(48) + main_layout.addWidget(self.table, 1) + + QtCore.QTimer.singleShot(300, self.load_from_api) + + # ───────────── SLIDER STYLE ───────────── + def _update_slider_style(self, value): + percent = value / 6 if value else 0 + self.severity_slider.setStyleSheet(f""" + QSlider::groove:horizontal {{ + border: 1px solid #d1d5db; + height: 6px; + border-radius: 3px; + background: qlineargradient( + x1:0, y1:0, x2:1, y2:0, + stop:0 #10b981, + stop:{percent} #10b981, + stop:{percent} white, + stop:1 white + ); + }} + QSlider::handle:horizontal {{ + width: 16px; + height: 16px; + background: #10b981; + border-radius: 8px; + margin: -5px 0; + border: 1px solid #10b981; + }} + """) + + # ───────────── LOGIC ───────────── + def _safe_int(self, val): + try: + return int(val) + except Exception: + return 0 + + def _parse_time(self, t): + try: + if not t: + return None + dt = datetime.fromisoformat(t.replace("Z", "+00:00")) + return dt.replace(tzinfo=None) + except Exception: + return None + + def _fmt_time(self, t): + return t.replace("T", " ").split(".")[0] if t else "-" + + def load_from_api(self): + print("[API] Fetching alerts from:", f"{self.api.base}/api/tables/alerts") + try: + url = f"{self.api.base}/api/tables/alerts" + resp = self.api.http.get(url, timeout=8) + resp.raise_for_status() + data = resp.json() + + # Expect structure: {"rows": [...], "count": N} + if isinstance(data, dict) and "rows" in data: + rows = data["rows"] + count = data.get("count", len(rows)) + print(f"[API] Loaded {count} total alerts.") + else: + rows = data if isinstance(data, list) else [] + print(f"[API][WARN] Unexpected format, using raw list of {len(rows)} items.") + + # ───── Filter only relevant alert types ───── + allowed_types = {"climbing_fence", "masked_person", "intruding_animal"} + filtered = [r for r in rows if (r.get("alert_type") or "").strip() in allowed_types] + + print(f"[API] Filtered {len(filtered)} / {len(rows)} alerts matching allowed types {allowed_types}.") + + self.all_rows = filtered + + except Exception as e: + print("[API][ERROR]", e) + QtWidgets.QMessageBox.warning(self, "Error", f"Failed to fetch alerts:\n{e}") + return + + # Update the table and filters + self.populate_table(self.all_rows) + self.populate_filters() + + + + def populate_filters(self): + devices = sorted({it.get("device_id") or "-" for it in self.all_rows}) + anomalies = sorted({it.get("anomaly") or "-" for it in self.all_rows}) + print(f"[FILTERS] Available devices={devices}") + print(f"[FILTERS] Available anomalies={anomalies}") + + # devices + self.device_filter.blockSignals(True) + self.device_filter.clear() + self.device_filter.addItem("All Devices", None) + for d in devices: + self.device_filter.addItem(d, d) + self.device_filter.blockSignals(False) + + # anomalies (friendly display) + self.anomaly_filter.blockSignals(True) + self.anomaly_filter.clear() + self.anomaly_filter.addItem("All Anomalies", None) + for a in anomalies: + label = a.replace("_", " ").title() if a and a != "-" else a + self.anomaly_filter.addItem(label, a) + self.anomaly_filter.blockSignals(False) + + print("[FILTERS] Filters populated.") + + def apply_filters(self): + if not self.all_rows: + print("[FILTER] No rows loaded yet.") + return + + device = self.device_filter.currentData() + anomaly = self.anomaly_filter.currentData() + min_sev = self._safe_int(self.severity_slider.value()) + from_dt = datetime.combine(self.from_date.date().toPyDate(), datetime.min.time()) + to_dt = datetime.combine(self.to_date.date().toPyDate(), datetime.max.time()) + + print(f"\n[FILTER] Applying filters:") + print(f" device={device}, anomaly={anomaly}, min_sev={min_sev},") + print(f" from={from_dt}, to={to_dt}") + print(f" total rows={len(self.all_rows)}") + + filtered = [] + for idx, it in enumerate(self.all_rows): + dev = it.get("device_id") or "-" + anom = it.get("anomaly") or "-" + sev = self._safe_int(it.get("severity")) + started = self._parse_time(it.get("started_at")) + + include = True + reasons = [] + + # Device filter + if device and dev != device: + include = False + reasons.append(f"device mismatch ({dev} ≠ {device})") + + # Anomaly filter + if anomaly and anom != anomaly: + include = False + reasons.append(f"anomaly mismatch ({anom} ≠ {anomaly})") + + # Severity filter + if sev < min_sev: + include = False + reasons.append(f"severity too low ({sev} < {min_sev})") + + # Date filter + if started: + if not (from_dt <= started <= to_dt): + include = False + reasons.append(f"date {started} out of range [{from_dt}, {to_dt}]") + else: + reasons.append("no start date parsed") + + if include: + filtered.append(it) + else: + print(f"[FILTER][X] Row {idx} excluded — {', '.join(reasons)}") + + print(f"[FILTER] {len(filtered)} / {len(self.all_rows)} rows matched filters.\n") + + # Sorting + i = self.sort_combo.currentIndex() + keymap = { + 1: lambda x: self._safe_int(x.get("severity")), + 2: lambda x: self._safe_int(x.get("severity")), + 3: lambda x: self._parse_time(x.get("started_at")) or datetime.min, + 4: lambda x: self._parse_time(x.get("started_at")) or datetime.min, + 5: lambda x: self._parse_time(x.get("ended_at")) or datetime.min, + 6: lambda x: self._parse_time(x.get("ended_at")) or datetime.min, + 7: lambda x: (x.get("anomaly") or "").lower(), + 8: lambda x: (x.get("anomaly") or "").lower(), + } + + if i in keymap: + reverse = i in (1, 3, 5, 8) + print(f"[SORT] Sorting index={i}, reverse={reverse}") + filtered.sort(key=keymap[i], reverse=reverse) + else: + print("[SORT] No sorting applied.") + + self.populate_table(filtered) + + + def clear_filters(self): + print("[FILTER] Clearing filters to defaults.") + self.device_filter.setCurrentIndex(0) + self.anomaly_filter.setCurrentIndex(0) + self.sort_combo.setCurrentIndex(0) + self.severity_slider.setValue(0) + self.from_date.setDate(QtCore.QDate.currentDate().addMonths(-1)) + self.to_date.setDate(QtCore.QDate.currentDate()) + self.apply_filters() + + def _severity_color(self, sev: int) -> str: + """Return green intensity from white (low) to dark green (high).""" + sev = max(1, min(sev, 9)) + # interpolate white (#ffffff) → dark green (#059669) + def lerp_color(c1, c2, t): + c1, c2 = [int(c1[i:i+2], 16) for i in (1, 3, 5)], [int(c2[i:i+2], 16) for i in (1, 3, 5)] + mix = [round(c1[j] + (c2[j]-c1[j])*t) for j in range(3)] + return f"#{mix[0]:02x}{mix[1]:02x}{mix[2]:02x}" + return lerp_color("#ffffff", "#059669", sev / 9) + + def _severity_label(self, sev: int) -> str: + if sev <= 3: + return f"Low ({sev})" + elif sev <= 6: + return f"Medium ({sev})" + else: + return f"Critical ({sev})" + + + + def populate_table(self, rows): + print(f"[TABLE] Populating table with {len(rows)} alerts.") + self.table.setRowCount(len(rows)) + + for r, it in enumerate(rows): + sid = (str(it.get("alert_id") or "")[:8] + "...") if it.get("alert_id") else "-" + self.table.setItem(r, 0, QtWidgets.QTableWidgetItem(sid)) + self.table.setItem(r, 1, QtWidgets.QTableWidgetItem(it.get("device_id") or "-")) + self.table.setItem(r, 2, QtWidgets.QTableWidgetItem(it.get("alert_type") or "-")) + self.table.setItem(r, 3, QtWidgets.QTableWidgetItem(self._fmt_time(it.get("started_at")))) + self.table.setItem(r, 4, QtWidgets.QTableWidgetItem(self._fmt_time(it.get("ended_at")))) + self.table.setItem(r, 5, QtWidgets.QTableWidgetItem(f"{it.get('confidence') or 0:.2f}")) + + # ACK Checkbox indicator + ack_value = it.get("ack", False) + ack_label = QtWidgets.QLabel("✅" if ack_value else "❌") + ack_label.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + self.table.setCellWidget(r, 6, ack_label) + + # Centered “View” button for vod + btn = QtWidgets.QPushButton("View") + btn.setCursor(QtCore.Qt.CursorShape.PointingHandCursor) + btn.setFixedHeight(26) + btn.setFixedWidth(65) + btn.setStyleSheet(""" + QPushButton { + background-color: #10b981; + color: white; + border-radius: 6px; + font-size: 12px; + font-weight: 600; + padding: 3px 6px; + } + QPushButton:hover { + background-color: #059669; + } + """) + btn.clicked.connect(lambda _, info=it: self._open_video_player(info)) + + btn_container = QtWidgets.QWidget() + btn_layout = QtWidgets.QHBoxLayout(btn_container) + btn_layout.setContentsMargins(0, 0, 0, 0) + btn_layout.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + btn_layout.addWidget(btn) + self.table.setCellWidget(r, 7, btn_container) + + print("[TABLE] Done populating alerts table.") + + + + + + + + + + + + + def _open_video_player(self, info): + print(f"[VIDEO] Opening video player for alert={info.get('alert_id')}") + url = info.get("vod") + if not url: + QtWidgets.QMessageBox.warning(self, "No Video", "This alert has no VOD URL.") + return + self._show_vlc_popup(url) + + + def _show_vlc_popup(self, url): + print(f"[VIDEO] Playing URL: {url}") + popup = QtWidgets.QDialog(self) + popup.setWindowTitle("Incident Video Playback") + popup.setMinimumSize(640, 400) + vbox = QtWidgets.QVBoxLayout(popup) + player = QtWidgets.QFrame() + player.setStyleSheet("background:black;border-radius:8px;") + vbox.addWidget(player, 1) + inst = vlc.Instance(["--quiet", "--no-video-title-show"]) + mp = inst.media_player_new() + mp.set_media(inst.media_new(url)) + popup.show() + if sys.platform.startswith("win"): + mp.set_hwnd(int(player.winId())) + else: + mp.set_xwindow(int(player.winId())) + mp.play() + print("[VIDEO] Playback started.") diff --git a/GUI/src/vast/views/security/incident_player_vlc.py b/GUI/src/vast/views/security/incident_player_vlc.py new file mode 100644 index 000000000..52bda9b11 --- /dev/null +++ b/GUI/src/vast/views/security/incident_player_vlc.py @@ -0,0 +1,1947 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +AgGuard Incident Player — PyQt6 + python-vlc with a tiny DVR proxy. + +What’s new in this build: +- Dynamic live lag: on any segment 404/410, the proxy temporarily hides more + tail segments in /live.m3u8 so VLC never requests unavailable parts. + (Decay back to normal once stable.) +- No DVR freeze on resolve (removed items disappear; playback stops/advances). +- No-cache headers on HLS endpoints. +""" + +from __future__ import annotations +import sys, os, asyncio, threading, time, re, json +from dataclasses import dataclass +from typing import Optional, List, Tuple +from urllib.parse import urljoin, urlparse, urlunparse + +from PyQt6 import QtCore, QtWidgets, QtGui +from PyQt6.QtCore import Qt, QUrl, QTimer +from PyQt6.QtWebSockets import QWebSocket +from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest +import vlc # python-vlc +from aiohttp import web, ClientSession +from vast.views.security.events_history_page import EventsHistoryPage + + + + +# ────────────────────────────────────────────────────────────────────────────── +# Config +# ────────────────────────────────────────────────────────────────────────────── +class Config: + MEDIA_BASE = os.getenv("MEDIA_BASE", "http://media-proxy:8080") + INCIDENT = os.getenv("INCIDENT", "placeholder") + TOKEN = os.getenv("MEDIA_TOKEN", "CHANGE_ME") + BIND = os.getenv("BIND", "127.0.0.1") + PORT = int(os.getenv("PORT", "19100")) + + # Poll upstream playlist ~2–4x per segment (1.0s segments -> 300ms is good) + REFRESH_MS = int(os.getenv("REFRESH_MS", "20000")) + + # Show this many segments in the live window… + LIVE_EDGE_SEGMENTS = int(os.getenv("LIVE_EDGE_SEGMENTS", "3")) + # …but hide the freshest N (stay behind live edge to avoid stalls) + LIVE_LAG_SEGMENTS = int(os.getenv("LIVE_LAG_SEGMENTS", "1")) + + # VLC network caching (ms) + NETWORK_CACHING = int(os.getenv("NETWORK_CACHING", "320")) + + ALERTS_WS = os.getenv("ALERTS_WS", "ws://host.docker.internal:8010/ws/alerts") + ALERTS_SNAPSHOT_HTTP = os.getenv("ALERTS_SNAPSHOT_HTTP", "") + +# ────────────────────────────────────────────────────────────────────────────── +# Upstream fetcher + DVR state +# ────────────────────────────────────────────────────────────────────────────── +@dataclass +class Segment: + uri: str + duration: float + abs_url: str # absolute URL to fetch + +class DvrState: + def __init__(self, upstream_index_url: str, auth_token: str = "", refresh_ms: int = 800): + self.upstream_index_url = upstream_index_url + self.auth_token = auth_token + self.refresh_ms = refresh_ms + self.init_url: Optional[str] = None + self.target_duration: float = 1.0 + self.version: int = 6 + self.segments: List[Segment] = [] + self._last_playlist_text: Optional[str] = None + self._stop = False + self._ready_evt = threading.Event() + self._lock = threading.Lock() + + @staticmethod + def _absolutize(base: str, maybe_rel: str) -> str: + return urljoin(base, maybe_rel) + + async def _fetch_text(self, session: ClientSession, url: str) -> Tuple[int, str]: + headers = {} + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + async with session.get(url, headers=headers, timeout=10) as resp: + txt = await resp.text() + status = resp.status + if status == 200 and txt.lstrip().startswith("#EXTM3U"): + print(f"[DVR] fetched playlist {status}, {len(txt)} bytes") + else: + print(f"[DVR] upstream status={status}, body[:120]={txt[:120]!r}") + return status, txt + + def stop(self): + self._stop = True + self._ready_evt.set() + + async def run(self): + async with ClientSession() as session: + base = self.upstream_index_url + base_dir = base.rsplit("/", 1)[0] + "/" + while not self._stop: + try: + status, text = await self._fetch_text(session, base) + + # Hard-stop conditions: upstream removed/closed + if status in (404, 410): + print(f"[DVR] upstream gone (HTTP {status}); stop polling") + self.stop() + break + + # Always parse; de-dupe by URL prevents dupes + if text.lstrip().startswith("#EXTM3U"): + self._parse_and_update(text, base_dir) + self._last_playlist_text = text + self._ready_evt.set() + else: + if text != self._last_playlist_text: + self._last_playlist_text = text + print("[DVR] NOTE: got non-HLS body; will retry.") + except Exception as e: + print(f"[DVR] fetch error: {e!r}") + await asyncio.sleep(self.refresh_ms / 1000.0) + + def _parse_and_update(self, playlist_text: str, base_dir: str): + lines = [l.strip() for l in playlist_text.splitlines() if l.strip()] + + target_from_tag: Optional[float] = None + max_seen_extinf = 0.0 + for l in lines: + if l.startswith('#EXT-X-TARGETDURATION:'): + try: + target_from_tag = float(l.split(':', 1)[1]) + except Exception: + pass + elif l.startswith('#EXT-X-VERSION:'): + try: + self.version = int(l.split(':', 1)[1]) + except Exception: + pass + elif l.startswith('#EXT-X-MAP:'): + m = re.search(r'URI="([^"]+)"', l) + if m: + self.init_url = self._absolutize(base_dir, m.group(1)) + elif l.startswith('#EXTINF:'): + try: + d = float(l.split(':', 1)[1].split(',')[0]) + max_seen_extinf = max(max_seen_extinf, d) + except Exception: + pass + + new_segments: List[Segment] = [] + i = 0 + while i < len(lines): + l = lines[i] + if l.startswith('#EXTINF:'): + try: + dur = float(l.split(':', 1)[1].split(',')[0]) + except Exception: + dur = self.target_duration or 1.0 + j = i + 1 + while j < len(lines) and lines[j].startswith('#'): + j += 1 + if j < len(lines): + uri = lines[j] + absu = self._absolutize(base_dir, uri) + new_segments.append(Segment(uri=uri, duration=dur, abs_url=absu)) + i = j + 1 + continue + i += 1 + + if target_from_tag is None or target_from_tag <= 0: + self.target_duration = max(1.0, max_seen_extinf or self.target_duration or 1.0) + else: + self.target_duration = target_from_tag + + added = 0 + with self._lock: + seen_urls = {s.abs_url for s in self.segments} + for s in new_segments: + if s.abs_url not in seen_urls: + self.segments.append(s) + seen_urls.add(s.abs_url) + added += 1 + if added: + print(f"[DVR] +{added} segments (total={len(self.segments)})") + + def render_dvr_vod_playlist(self, *, endlist: bool = False) -> Tuple[str, float]: + with self._lock: + segs = list(self.segments) + init_url = self.init_url + target = int(max(1.0, self.target_duration)) + version = self.version + + total = sum(s.duration for s in segs) + + out: List[str] = [] + out.append('#EXTM3U') + out.append(f'#EXT-X-VERSION:{version}') + out.append('#EXT-X-PLAYLIST-TYPE:EVENT') + out.append('#EXT-X-INDEPENDENT-SEGMENTS') + out.append(f'#EXT-X-TARGETDURATION:{target}') + out.append(f'#EXT-X-MEDIA-SEQUENCE:0') + + if init_url: + out.append(f'#EXT-X-MAP:URI="/seg?u={init_url}"') + + for s in segs: + out.append(f'#EXTINF:{s.duration:.3f},') + out.append(f'/seg?u={s.abs_url}') + + if endlist: + out.append('#EXT-X-ENDLIST') + + return "\n".join(out) + "\n", float(total) + +# ────────────────────────────────────────────────────────────────────────────── +# Aiohttp proxy app +# ────────────────────────────────────────────────────────────────────────────── +import socket + +def is_port_in_use(port=19090, host="127.0.0.1"): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex((host, port)) == 0 + +class ProxyServer: + def __init__(self, media_base: str, camera: Optional[str], incident: Optional[str], + token: str, refresh_ms: int, bind: str, port: int): + self.media_base = media_base.rstrip('/') + self.camera = camera + self.incident = incident + self.token = token + self.refresh_ms = refresh_ms + self.bind = bind + self.port = port + + self.upstream_index: Optional[str] = None + self.dvr: Optional[DvrState] = None + self.resolved: bool = False + + # Dynamic lag control + self._last_seg_404_ts: float = 0.0 # monotonic timestamp of last 404/410 + self._extra_lag_floor: int = 0 # can be bumped to 1–2 and decays over time + + self._app = web.Application() + self._app.router.add_get('/dvr.m3u8', self.handle_dvr) + self._app.router.add_get('/live.m3u8', self.handle_live) + self._app.router.add_get('/seg', self.handle_seg) + self._app.router.add_get('/', self.handle_root) + self._app.router.add_get('/dvr_seek.m3u8', self.handle_dvr_seek) + self._app.router.add_get("/vod", self.handle_vod) + + + # DEBUG routes + self._app.router.add_get('/debug/upstream', self.handle_debug_upstream) + self._app.router.add_get('/debug/dvr', self.handle_debug_dvr) + self._app.router.add_get('/debug/state', self.handle_debug_state) + + self._runner: Optional[web.AppRunner] = None + self._thread: Optional[threading.Thread] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + + # no-cache headers helper + def _nocache_headers(self) -> dict: + return { + "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0", + "Pragma": "no-cache", + "Expires": "0", + } + + # quick helper so UI knows totals + def get_durations_ms(self) -> Tuple[int, int]: + if not self.dvr: + return (0, 0) + with self.dvr._lock: + segs = list(self.dvr.segments) + total_ms = int(sum(s.duration for s in segs) * 1000) + edge = max(1, int(getattr(Config, "LIVE_EDGE_SEGMENTS", 3))) + lag = max(0, int(getattr(Config, "LIVE_LAG_SEGMENTS", 0))) + # Apply dynamic lag here too so UI stays coherent with playlist + lag += self._current_extra_lag() + keep = min(len(segs), max(1, edge + lag)) + last = segs[-keep:] if keep <= len(segs) else segs + live_win_ms = int(sum(s.duration for s in last) * 1000) + return (total_ms, live_win_ms) + async def handle_vod(self, request): + vod_url = request.query.get("u") + if not vod_url: + raise web.HTTPBadRequest(text="missing u") + + if not vod_url.startswith(("http://", "https://")): + vod_url = f"http://{vod_url.lstrip('/')}" + + headers = {} + if self.token: + headers["Authorization"] = f"Bearer {self.token}" + + range_hdr = request.headers.get("Range") + if range_hdr: + headers["Range"] = range_hdr + + try: + async with ClientSession() as session: + async with session.get(vod_url, headers=headers, timeout=None) as resp: + response_headers = { + "Content-Type": resp.headers.get("Content-Type", "video/mp4"), + "Accept-Ranges": resp.headers.get("Accept-Ranges", "bytes"), + **self._nocache_headers(), + } + if "Content-Length" in resp.headers: + response_headers["Content-Length"] = resp.headers["Content-Length"] + if "Content-Range" in resp.headers: + response_headers["Content-Range"] = resp.headers["Content-Range"] + + print(f"[HTTP] vod {resp.status} -> {vod_url} " + f"({resp.headers.get('Content-Length', '?')} bytes, range={range_hdr})") + + proxy_resp = web.StreamResponse(status=resp.status, headers=response_headers) + await proxy_resp.prepare(request) + + try: + async for chunk in resp.content.iter_chunked(8192): + await proxy_resp.write(chunk) + await proxy_resp.write_eof() + except (asyncio.CancelledError, ConnectionResetError, ClientConnectionError, ClientPayloadError) as e: + # ✅ harmless — VLC moved to another range + print(f"[HTTP] client disconnected early ({type(e).__name__}) — OK") + except Exception as e: + print(f"[HTTP] stream write error: {type(e).__name__}: {e}") + finally: + await proxy_resp.write_eof() + + return proxy_resp + + except Exception as e: + print(f"[HTTP] vod fetch error: {e!r} <- {vod_url}") + return web.Response( + text=f"vod fetch error: {type(e).__name__}: {e}", + content_type="text/plain", + status=502, + headers=self._nocache_headers(), + ) + + + # Dynamic lag amount based on recent 404s + def _current_extra_lag(self) -> int: + now = time.monotonic() + extra = 0 + if self._last_seg_404_ts > 0: + dt = now - self._last_seg_404_ts + # immediately after a 404, be conservative with +2; + # after 10s, ease to +1; after 30s, back to +0 + if dt < 10: + extra = 2 + elif dt < 30: + extra = 1 + else: + extra = 0 + # floor in case we had repeated issues and want to hold higher lag briefly + extra = max(extra, self._extra_lag_floor) + # decay the floor gently + if self._extra_lag_floor and (now - self._last_seg_404_ts) > 20: + self._extra_lag_floor = max(0, self._extra_lag_floor - 1) + return extra + + def _bump_extra_lag(self, floor_to: int): + self._last_seg_404_ts = time.monotonic() + self._extra_lag_floor = max(self._extra_lag_floor, floor_to) + print(f"[LIVE] segment 404/410 observed → increasing effective lag (floor={self._extra_lag_floor})") + + # DEBUG HANDLERS + async def handle_debug_upstream(self, _request: web.Request): + if not self.upstream_index: + return web.Response(text="(no upstream_index yet)\n", content_type="text/plain") + headers = {} + if self.token: + headers['Authorization'] = f'Bearer {self.token}' + try: + async with ClientSession() as session: + async with session.get(self.upstream_index, headers=headers, timeout=10) as resp: + body = await resp.text() + out = [ + f"URL: {self.upstream_index}", + f"HTTP {resp.status}", + "", + body + ] + print(f"[HTTP] debug_upstream {resp.status}") + return web.Response(text="\n".join(out), content_type='text/plain', status=resp.status, headers=self._nocache_headers()) + except Exception as e: + return web.Response(text=f"fetch error: {type(e).__name__}: {e}\n", content_type="text/plain", status=500, headers=self._nocache_headers()) + + async def handle_debug_dvr(self, _request: web.Request): + if not self.dvr: + return web.Response(text="(no DVR yet)\n", content_type="text/plain", headers=self._nocache_headers()) + m3u8, total = self.dvr.render_dvr_vod_playlist(endlist=self.resolved) + hdr = f"# segment_count={len(self.dvr.segments)} total_duration_seconds={total:.3f} resolved={self.resolved}\n" + print(f"[HTTP] debug_dvr segments={len(self.dvr.segments)} total_s={total:.3f} endlist={self.resolved}") + return web.Response(text=hdr + m3u8, content_type="text/plain", headers=self._nocache_headers()) + + async def handle_debug_state(self, _request: web.Request): + info = { + "camera": self.camera, + "incident": self.incident, + "upstream_index": self.upstream_index, + "have_dvr": bool(self.dvr), + "segment_count": len(self.dvr.segments) if self.dvr else 0, + "target_duration": getattr(self.dvr, "target_duration", None) if self.dvr else None, + "have_init": bool(getattr(self.dvr, "init_url", None)) if self.dvr else False, + "resolved": self.resolved, + "extra_lag": self._current_extra_lag(), + } + print(f"[HTTP] state: {info}") + return web.json_response(info, headers=self._nocache_headers()) + + # URL helpers + def _rewrite_to_media_base(self, any_hls_url: str) -> str: + if not any_hls_url: + return any_hls_url + mb = urlparse(self.media_base) + if any_hls_url.startswith('/') and not any_hls_url.startswith('//'): + return f"{mb.scheme}://{mb.netloc}{any_hls_url}" + u = urlparse(any_hls_url) + if not u.scheme or not u.netloc: + return f"{mb.scheme}://{mb.netloc}/{any_hls_url.lstrip('/')}" + return urlunparse(u._replace(scheme=mb.scheme, netloc=mb.netloc)) + + def _normalize_live_playlist(self, upstream_text: str, upstream_index_url: str) -> str: + base_dir = upstream_index_url.rsplit("/", 1)[0] + "/" + lines = [l.strip() for l in upstream_text.splitlines() if l.strip()] + + version = 6 + media_seq = 0 + + segments = [] + init_map_abs = None + max_extinf = 1.0 + + i = 0 + while i < len(lines): + l = lines[i] + if l.startswith("#EXT-X-VERSION:"): + try: version = int(l.split(":", 1)[1]) + except: pass + elif l.startswith("#EXT-X-MEDIA-SEQUENCE:"): + try: media_seq = int(l.split(":", 1)[1]) + except: media_seq = 0 + elif l.startswith("#EXT-X-MAP:"): + m = re.search(r'URI="([^"]+)"', l) + if m: + init_map_abs = urljoin(base_dir, m.group(1)) + elif l.startswith("#EXTINF:"): + try: + dur = float(l.split(':', 1)[1].split(',')[0]) + except Exception: + dur = 1.0 + max_extinf = max(max_extinf, dur) + attached = [] + j = i + 1 + while j < len(lines) and lines[j].startswith("#"): + attached.append(lines[j]); j += 1 + if j < len(lines): + uri = lines[j] + segments.append((dur, attached, uri)) + i = j + else: + i = j + i += 1 + continue + i += 1 + + base_edge = max(1, int(getattr(Config, "LIVE_EDGE_SEGMENTS", 3))) + base_lag = max(0, int(getattr(Config, "LIVE_LAG_SEGMENTS", 0))) + # Add dynamic lag derived from recent 404s + effective_lag = base_lag + self._current_extra_lag() + + total = len(segments) + keep = min(total, max(1, base_edge + effective_lag)) + start_index = max(0, total - keep) + end_index = max(0, total - effective_lag) + trimmed = segments[start_index:end_index] + new_media_seq = media_seq + start_index + + out = [ + "#EXTM3U", + f"#EXT-X-VERSION:{version}", + "#EXT-X-PLAYLIST-TYPE:LIVE", + f"#EXT-X-TARGETDURATION:{int(max(1, round(max_extinf + 0.0001)))}", + "#EXT-X-INDEPENDENT-SEGMENTS", + f"#EXT-X-MEDIA-SEQUENCE:{new_media_seq}", + ] + + if init_map_abs: + out.append(f'#EXT-X-MAP:URI="/seg?u={init_map_abs}"') + + for dur, attached_tags, uri in trimmed: + out.append(f"#EXTINF:{dur:.3f},") + for t in attached_tags: + out.append(t) + seg_abs = urljoin(base_dir, uri) + out.append(f'/seg?u={seg_abs}') + + print(f"[LIVE] served {len(trimmed)} segs (edge={base_edge}, lag={effective_lag}, seq={new_media_seq})") + return "\n".join(out) + "\n" + + # Source switching + def switch_source(self, *, camera: Optional[str] = None, + incident: Optional[str] = None, + upstream_hls: Optional[str] = None): + if camera: + self.camera = camera + if incident: + self.incident = incident + + self.resolved = False + self._last_seg_404_ts = 0.0 + self._extra_lag_floor = 0 + + if upstream_hls: + self.upstream_index = self._rewrite_to_media_base(upstream_hls) + else: + if not (self.camera and self.incident): + return + self.upstream_index = f"{self.media_base}/hls/{self.camera}/{self.incident}/index.m3u8" + + print(f"[SRC] switch to upstream={self.upstream_index}") + + if self.dvr: + try: + self.dvr.stop() + except Exception: + pass + self.dvr = DvrState(self.upstream_index, auth_token=self.token, refresh_ms=self.refresh_ms) + + if self._loop and self._loop.is_running(): + def _start(): + print("[SRC] starting DVR loop") + self._loop.create_task(self.dvr.run()) + self._loop.call_soon_threadsafe(_start) + + def mark_resolved(self): + if self.resolved: + return + self.resolved = True + if self.dvr: + try: + self.dvr.stop() + except Exception: + pass + self.upstream_index = None + print("[SRC] incident resolved; upstream disabled; no DVR freeze") + + # HTTP handlers + async def handle_root(self, _request: web.Request): + return web.Response(text='OK', content_type='text/plain', headers=self._nocache_headers()) + + async def handle_dvr(self, _request: web.Request): + # No DVR freeze behavior + return web.Response(text="#EXTM3U\n#EXT-X-ENDLIST\n", content_type='application/vnd.apple.mpegurl', status=410, headers=self._nocache_headers()) + + async def handle_live(self, _request: web.Request): + if self.resolved or not self.upstream_index: + return web.Response(text="#EXTM3U\n#EXT-X-ENDLIST\n", content_type='application/vnd.apple.mpegurl', status=410, headers=self._nocache_headers()) + + headers = {} + if self.token: + headers['Authorization'] = f'Bearer {self.token}' + try: + async with ClientSession() as session: + async with session.get(self.upstream_index, headers=headers, timeout=10) as resp: + text = await resp.text() + if resp.status >= 400: + print(f"[HTTP] live.m3u8 upstream {resp.status}") + if resp.status in (404, 410): + self.mark_resolved() + return web.Response(text="#EXTM3U\n#EXT-X-ENDLIST\n", content_type='application/vnd.apple.mpegurl', status=410, headers=self._nocache_headers()) + return web.Response(text=f"# upstream {resp.status}\n{text}", content_type='text/plain', status=resp.status, headers=self._nocache_headers()) + except Exception as e: + print(f"[HTTP] live.m3u8 fetch error: {e!r}") + return web.Response(text=f"# fetch error: {type(e).__name__}: {e}\n", content_type='text/plain', status=502, headers=self._nocache_headers()) + + text = self._normalize_live_playlist(text, self.upstream_index) + return web.Response(text=text, content_type='application/vnd.apple.mpegurl', headers=self._nocache_headers()) + + async def handle_seg(self, request: web.Request): + url = request.query.get('u') + if not url: + raise web.HTTPBadRequest(text='missing u') + headers = {} + if self.token: + headers['Authorization'] = f'Bearer {self.token}' + try: + async with ClientSession() as session: + async with session.get(url, headers=headers, timeout=20) as resp: + body = await resp.read() + ctype = resp.headers.get('Content-Type', 'application/octet-stream') + status = resp.status + print(f"[HTTP] seg {status} {ctype} {len(body)} bytes <- {url}") + # On 404/410, bump lag so subsequent /live.m3u8 hides fresher segs + if status in (404, 410): + self._bump_extra_lag(floor_to=2) + return web.Response(body=body, content_type=ctype, status=status, headers=self._nocache_headers()) + except Exception as e: + print(f"[HTTP] seg fetch error: {e!r} <- {url}") + return web.Response(text=f"segment fetch error: {type(e).__name__}: {e}", content_type="text/plain", status=502, headers=self._nocache_headers()) + + async def handle_dvr_seek(self, request: web.Request): + if self.resolved or not self.dvr: + return web.Response(text="#EXTM3U\n#EXT-X-ENDLIST\n", + content_type='application/vnd.apple.mpegurl', + status=410, + headers=self._nocache_headers()) + + t_ms_str = request.query.get('t', '0') + try: + t_ms = max(0, int(float(t_ms_str))) + except Exception: + t_ms = 0 + + with self.dvr._lock: + segs = list(self.dvr.segments) + init_url = self.dvr.init_url + version = self.dvr.version + target = int(max(1.0, self.dvr.target_duration)) + + # Compute which segment contains t_ms and how far into it we need to start. + acc_ms = 0.0 + start_idx = 0 + intra_ms = 0.0 + for i, s in enumerate(segs): + next_acc = acc_ms + s.duration * 1000.0 + if next_acc > t_ms: + start_idx = i + intra_ms = max(0.0, t_ms - acc_ms) + break + acc_ms = next_acc + else: + # Past the end → start at the last segment, no intra offset + start_idx = max(0, len(segs) - 1) + intra_ms = 0.0 + + trimmed = segs[start_idx:] + media_seq = start_idx + + out = [] + out.append('#EXTM3U') + out.append(f'#EXT-X-VERSION:{version}') + out.append('#EXT-X-PLAYLIST-TYPE:EVENT') + out.append('#EXT-X-INDEPENDENT-SEGMENTS') + out.append(f'#EXT-X-TARGETDURATION:{max(1, target)}') + out.append(f'#EXT-X-MEDIA-SEQUENCE:{media_seq}') + + # PRECISE intra-segment start (many players honor this; helps VLC too) + # Start "intra_ms" seconds *into* the first segment of this playlist. + out.append(f'#EXT-X-START:TIME-OFFSET={intra_ms/1000.0:.3f},PRECISE=YES') + + if init_url: + out.append(f'#EXT-X-MAP:URI="/seg?u={init_url}"') + + for s in trimmed: + out.append(f'#EXTINF:{s.duration:.3f},') + out.append(f'/seg?u={s.abs_url}') + + body = "\n".join(out) + "\n" + print(f"[HTTP] dvr_seek.m3u8 t={t_ms}ms -> start_idx={start_idx} intra={int(intra_ms)}ms segs={len(trimmed)} resolved={self.resolved}") + # Optional debug header — handy to confirm behavior in logs/curl: + headers = self._nocache_headers() | {"X-Start-Offset-Ms": str(int(intra_ms))} + return web.Response(text=body, + content_type='application/vnd.apple.mpegurl', + headers=headers) + + # Lifecycle + def start(self): + def _run_loop(): + loop = asyncio.new_event_loop() + self._loop = loop + asyncio.set_event_loop(loop) + self._runner = web.AppRunner(self._app) + loop.run_until_complete(self._runner.setup()) + site = web.TCPSite(self._runner, self.bind, self.port) + loop.run_until_complete(site.start()) + print(f"[HTTP] proxy listening on http://{self.bind}:{self.port}") + try: + loop.run_forever() + finally: + loop.run_until_complete(self._runner.cleanup()) + loop.stop() + if is_port_in_use(19090): + print("[INFO] DVR proxy already running on port 19090, reusing it.") + else: + self._thread = threading.Thread(target=_run_loop, daemon=True) + self._thread.start() + + def stop(self): + if self.dvr: + self.dvr.stop() + +# ────────────────────────────────────────────────────────────────────────────── +# LEFT PANE + UI — unchanged except: no DVR freeze on resolve +# ────────────────────────────────────────────────────────────────────────────── +class AlertsModel(QtCore.QAbstractListModel): + def __init__(self): + super().__init__() + self._items: list[dict] = [] + + def rowCount(self, parent=None): + return len(self._items) + + def data(self, idx, role): + if not idx.isValid(): + return None + if role == QtCore.Qt.ItemDataRole.DisplayRole: + a = self._items[idx.row()] + status = (a.get("status") or "firing").lower() + return f'[{status}] {a.get("camera")} {a.get("anomaly")} ({a.get("incident_id")})' + return None + + def is_empty(self) -> bool: + return len(self._items) == 0 + + def set_alerts(self, items: list[dict]): + self.beginResetModel() + self._items = list(items or []) + self.endResetModel() + + def add_alerts(self, items): + if not items: + return + start = len(self._items) + self.beginInsertRows(QtCore.QModelIndex(), start, start + len(items) - 1) + self._items.extend(items) + self.endInsertRows() + + def get(self, row: int): + return self._items[row] + + def _key(self, it: dict) -> tuple[str, str]: + return (str(it.get("camera") or ""), str(it.get("incident_id") or "")) + + def as_dict(self) -> dict[tuple[str, str], dict]: + return { self._key(it): it for it in self._items } + + def replace_with(self, merged: dict[tuple[str,str], dict]): + self.set_alerts(list(merged.values())) + + def remove_by_key(self, camera: str, incident_id: str): + k = (str(camera or ""), str(incident_id or "")) + for i, it in enumerate(self._items): + if (str(it.get("camera") or ""), str(it.get("incident_id") or "")) == k: + self.beginRemoveRows(QtCore.QModelIndex(), i, i) + self._items.pop(i) + self.endRemoveRows() + return True + return False + +class AlertItemDelegate(QtWidgets.QStyledItemDelegate): + def paint(self, painter: QtGui.QPainter, option: QtWidgets.QStyleOptionViewItem, index: QtCore.QModelIndex): + model: AlertsModel = index.model() # type: ignore + a = model.get(index.row()) + r = option.rect + painter.save() + + if option.state & QtWidgets.QStyle.StateFlag.State_Selected: + painter.fillRect(r, QtGui.QColor("#eef8ff")) + elif option.state & QtWidgets.QStyle.StateFlag.State_MouseOver: + painter.fillRect(r, QtGui.QColor("#f6fafc")) + + status = (a.get("status") or "firing").lower() + color = {"firing": "#16a34a", "resolved": "#94a3b8", "warning": "#f59e0b"}.get(status, "#16a34a") + chip = QtCore.QRect(r.left() + 10, r.center().y() - 5, 10, 10) + painter.setBrush(QtGui.QColor(color)) + painter.setPen(QtCore.Qt.PenStyle.NoPen) + painter.drawEllipse(chip) + + x = chip.right() + 10 + cam = str(a.get("camera") or "") + anom = str(a.get("anomaly") or "") + inc = str(a.get("incident_id") or "")[:8] + + title_font = QtGui.QFont(option.font); title_font.setPointSizeF(option.font.pointSizeF() + 1); title_font.setBold(True) + sub_font = QtGui.QFont(option.font); sub_font.setPointSizeF(option.font.pointSizeF() - 1) + + painter.setPen(QtGui.QColor("#111827")) + painter.setFont(title_font) + painter.drawText(QtCore.QRect(x, r.top() + 4, r.width() - 20, 18), + QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignVCenter, + f"{cam} • {anom}") + + painter.setPen(QtGui.QColor("#6b7280")) + painter.setFont(sub_font) + painter.drawText(QtCore.QRect(x, r.top() + 22, r.width() - 20, 16), + QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignVCenter, + f"Incident: {inc}… • Status: {status}") + + painter.restore() + + def sizeHint(self, option: QtWidgets.QStyleOptionViewItem, _index: QtCore.QModelIndex) -> QtCore.QSize: + return QtCore.QSize(220, 42) + +LEFT_LIST_QSS = """ +QListView { + padding: 6px; + background: #ffffff; + border: 1px solid #e5e7eb; + border-radius: 12px; +} +QListView::item { padding: 4px 8px; } +QListView::item:selected { background: #eef8ff; border-radius: 8px; } +QScrollBar:vertical { background: transparent; width: 10px; margin: 8px 2px 8px 2px; border-radius: 5px; } +QScrollBar::handle:vertical { background: #cbd5e1; min-height: 32px; border-radius: 5px; } +QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical { height: 0; } +#LeftHeader { color: #6b7280; font-weight: 600; letter-spacing: 0.4px; margin: 0 6px 6px 6px; } +""" + +class SeekSlider(QtWidgets.QSlider): + hovered = QtCore.pyqtSignal(int) + clickedTo = QtCore.pyqtSignal(int) + draggedTo = QtCore.pyqtSignal(int) + + def __init__(self, orientation, parent=None): + super().__init__(orientation, parent) + self._press_x: Optional[float] = None + self._moved: bool = False + self._CLICK_EPS = 4.0 + self._EDGE_SNAP_PX = 8 # ← new: snap zone near the ends + + def mousePressEvent(self, ev: QtGui.QMouseEvent): + if ev.button() == Qt.MouseButton.LeftButton: + self._press_x = float(ev.position().x()) + self._moved = False + self.setSliderDown(True) + ev.accept() + return + super().mousePressEvent(ev) + + def mouseMoveEvent(self, ev: QtGui.QMouseEvent): + x = float(ev.position().x()) + if self._press_x is not None and abs(x - self._press_x) > self._CLICK_EPS: + self._moved = True + val = self._value_for_x(x) + self.hovered.emit(val) + if self._moved: + self.setValue(val) + super().mouseMoveEvent(ev) + + def mouseReleaseEvent(self, ev: QtGui.QMouseEvent): + if ev.button() == Qt.MouseButton.LeftButton and self._press_x is not None: + x = float(ev.position().x()) + val = self._value_for_x(x) + self.setSliderDown(False) + self.setValue(val) + if self._moved: + self.draggedTo.emit(val) + else: + self.clickedTo.emit(val) + self._press_x = None + ev.accept() + return + super().mouseReleaseEvent(ev) + + def _value_for_x(self, x: float) -> int: + opt = QtWidgets.QStyleOptionSlider() + self.initStyleOption(opt) + groove = self.style().subControlRect( + QtWidgets.QStyle.ComplexControl.CC_Slider, + opt, + QtWidgets.QStyle.SubControl.SC_SliderGroove, + self + ) + if groove.width() <= 0: + return self.value() + + # NEW: snap to exact min/max if you're near the ends + if x <= groove.left() + self._EDGE_SNAP_PX: + return self.minimum() + if x >= groove.right() - self._EDGE_SNAP_PX: + return self.maximum() + + ratio = max(0.0, min(1.0, (x - groove.left()) / groove.width())) + return int(self.minimum() + ratio * (self.maximum() - self.minimum())) + + +class VideoSurface(QtWidgets.QStackedWidget): + def __init__(self, vlc_widget: QtWidgets.QWidget, parent=None): + super().__init__(parent) + self.vlcw = vlc_widget + self.loading = QtWidgets.QLabel("Loading…") + self.loading.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.loading.setStyleSheet("color:#b9c0c7; font-size:18px;") + self.addWidget(self.vlcw) + self.addWidget(self.loading) + self.setCurrentIndex(1) + + def show_loading(self, on: bool): + self.setCurrentIndex(1 if on else 0) + +class VlcWidget(QtWidgets.QFrame): + positionChanged = QtCore.pyqtSignal(float) + timeChanged = QtCore.pyqtSignal(int) + + def __init__(self, instance: vlc.Instance, parent=None): + super().__init__(parent) + self.instance = instance + self.mediaplayer = self.instance.media_player_new() + self.setMinimumSize(640, 360) + self._timer = QtCore.QTimer(self) + self._timer.setInterval(200) + self._timer.timeout.connect(self._on_tick) + self._timer.start() + + def _on_tick(self): + if self.mediaplayer: + try: + pos = self.mediaplayer.get_position() + t = self.mediaplayer.get_time() + if pos >= 0: + self.positionChanged.emit(pos) + if t >= 0: + self.timeChanged.emit(t) + except Exception: + pass + + def set_media(self, mrl: str, options: Optional[List[str]] = None): + print(f"[VLC] set_media {mrl} opts={options or []}") + media = self.instance.media_new(mrl) + for opt in (options or []): + media.add_option(opt) + self.mediaplayer.set_media(media) + + def play(self): + if sys.platform.startswith('linux'): + self.mediaplayer.set_xwindow(int(self.winId())) + elif sys.platform.startswith('win'): + self.mediaplayer.set_hwnd(int(self.winId())) + else: + self.mediaplayer.set_nsobject(int(self.winId())) + print("[VLC] play()") + self.mediaplayer.play() + + def pause(self): + print("[VLC] pause()") + self.mediaplayer.pause() + + def set_position(self, pos01: float): + p = max(0.0, min(1.0, float(pos01))) + print(f"[VLC] set_position {p:.3f}") + self.mediaplayer.set_position(p) + + def set_time_ms(self, t_ms: int): + t = int(max(0, t_ms)) + print(f"[VLC] set_time {t}ms") + self.mediaplayer.set_time(t) + +class IncidentPlayerVLC(QtWidgets.QWidget): + def __init__(self, api,alert_service, parent=None): + super().__init__(parent) + self.api = api + self.alert_service = alert_service + self.cfg = Config() + self.proxy = ProxyServer( + media_base=self.cfg.MEDIA_BASE, + camera=None, + incident=self.cfg.INCIDENT, + token=self.cfg.TOKEN, + refresh_ms=self.cfg.REFRESH_MS, + bind=self.cfg.BIND, + port=self.cfg.PORT, + ) + self.proxy.start() + self.setWindowTitle("AgGuard — Live Incidents") + + self.setMinimumSize(1100, 620) + self.resize(1180, 680) + self.setContentsMargins(6, 6, 6, 6) + + THEME_QSS = """ + QWidget { background:#fafbfc; color:#1f2937; font-size:13px; } + QGroupBox { background:#ffffff; border:1px solid #e5e7eb; border-radius:10px; margin-top:14px; } + QGroupBox::title { subcontrol-origin: margin; left: 12px; top:-6px; padding:0 4px; color:#0f172a; font-weight:600; } + QPushButton { border-radius:10px; padding:7px 12px; background:#10b981; color:white; font-weight:600; border:0; } + QPushButton:hover { background:#0ea371; } + QLabel#timeLabel { color:#6b7280; font-weight:600; } + QLabel#liveBadge { background:#10b981; color:white; padding:3px 8px; border-radius:12px; font-weight:700; } + QLabel#liveBadge.off { background:#9ca3af; } + QSlider::groove:horizontal { height:8px; background:#e7f6ef; border-radius:4px; } + QSlider::handle:horizontal { background:#10b981; width:14px; height:14px; margin:-3px 0; border-radius:7px; } + """ + LEFT_LIST_QSS + self.setStyleSheet(THEME_QSS) + + os.environ.setdefault("VDPAU_DRIVER", "") + os.environ.setdefault("LIBVA_DRIVER_NAME", "") + vlc_opts = [ + f'--network-caching={max(200, int(self.cfg.NETWORK_CACHING))}', + '--live-caching=300', + '--file-caching=300', + '--no-video-title-show', + '--quiet', + '--aout=dummy', + '--avcodec-hw=none', + '--drop-late-frames', + '--skip-frames', + '--clock-jitter=0', + ] + self.vlc_instance = vlc.Instance(*vlc_opts) + self.vlcw = VlcWidget(self.vlc_instance) + self.videoSurface = VideoSurface(self.vlcw) + self.videoSurface.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, + QtWidgets.QSizePolicy.Policy.Expanding) + + # Controls + self.btnLive = QtWidgets.QPushButton('Go Live') + self.btnLive.setObjectName("btnLive") + + self.timeLeft = QtWidgets.QLabel('00:00') + self.timeLeft.setObjectName("timeLabel") + self.slider = SeekSlider(QtCore.Qt.Orientation.Horizontal) + self.slider.setRange(0, 0) + self.liveBadge = QtWidgets.QLabel('LIVE') + self.liveBadge.setObjectName("liveBadge") + + # LEFT PANE + leftContainer = QtWidgets.QGroupBox("Alerts") + leftContainer.setSizePolicy(QtWidgets.QSizePolicy.Policy.Fixed, + QtWidgets.QSizePolicy.Policy.Expanding) + leftContainer.setMinimumWidth(300) + leftContainer.setMaximumWidth(340) + + leftLayout = QtWidgets.QVBoxLayout(leftContainer) + leftLayout.setContentsMargins(10, 10, 10, 10) + leftLayout.setSpacing(8) + + self.alertList = QtWidgets.QListView() + self.alertList.setMouseTracking(True) + self.alertList.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.SingleSelection) + self.alertList.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectionBehavior.SelectItems) + self.alertList.setVerticalScrollMode(QtWidgets.QAbstractItemView.ScrollMode.ScrollPerPixel) + self.alertList.setEditTriggers(QtWidgets.QAbstractItemView.EditTrigger.NoEditTriggers) + self.alertList.setUniformItemSizes(True) + self.alertList.setSpacing(4) + self.alertList.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + + self.alertModel = AlertsModel() + self.alertList.setModel(self.alertModel) + self.alertList.setItemDelegate(AlertItemDelegate(self.alertList)) + leftLayout.addWidget(self.alertList) + + # Details inside player pane + self.detailGroup = QtWidgets.QGroupBox("Details") + self.detailGroup.setSizePolicy(QtWidgets.QSizePolicy.Policy.Preferred, + QtWidgets.QSizePolicy.Policy.Fixed) + grid = QtWidgets.QGridLayout(self.detailGroup) + grid.setContentsMargins(12, 8, 12, 12) + grid.setHorizontalSpacing(24) + grid.setVerticalSpacing(6) + + labels = ["Camera:", "Anomaly:", "Incident ID:", "Status:", "Start Time:"] + self.lblVals = [] + for i, title in enumerate(labels): + k = QtWidgets.QLabel(title) + v = QtWidgets.QLabel("–") + v.setStyleSheet("color:#6b7280;") + grid.addWidget(k, i, 0, 1, 1) + grid.addWidget(v, i, 1, 1, 1) + self.lblVals.append(v) + self.detailGroup.setMaximumHeight(160) + + # Right stack + self.rightStack = QtWidgets.QStackedWidget() + self.rightStack.setContentsMargins(0, 0, 0, 0) + + self.emptyPane = QtWidgets.QWidget() + ep_layout = QtWidgets.QVBoxLayout(self.emptyPane) + ep_layout.setContentsMargins(10, 10, 10, 10) + ep_layout.setSpacing(0) + + noTitle = QtWidgets.QLabel("No alerts") + noTitle.setAlignment(Qt.AlignmentFlag.AlignCenter) + noTitle.setStyleSheet("font-size:22px; font-weight:800; color:#111827;") + + noSub = QtWidgets.QLabel("Alerts will appear here.") + noSub.setAlignment(Qt.AlignmentFlag.AlignCenter) + noSub.setWordWrap(True) + noSub.setStyleSheet("color:#6b7280;") + + ep_layout.addStretch(1) + ep_layout.addWidget(noTitle) + ep_layout.addWidget(noSub) + ep_layout.addStretch(3) + + self.playerPane = QtWidgets.QGroupBox("") + rightLayout = QtWidgets.QVBoxLayout(self.playerPane) + rightLayout.setContentsMargins(10, 10, 10, 10) + rightLayout.setSpacing(10) + + titleRow = QtWidgets.QHBoxLayout() + titleRow.setContentsMargins(0, 0, 0, 0) + titleRow.setSpacing(10) + title = QtWidgets.QLabel("AgGuard — Security Alerts") + title.setStyleSheet("font-size:20px; font-weight:800; color:#111827;") + dotLive = QtWidgets.QLabel("• LIVE") + dotLive.setStyleSheet("color:#10b981; font-weight:700;") + titleRow.addWidget(title) + titleRow.addStretch(1) + titleRow.addWidget(dotLive) + + ctrls = QtWidgets.QHBoxLayout() + ctrls.setContentsMargins(0, 0, 0, 0) + ctrls.setSpacing(10) + ctrls.addWidget(self.btnLive) + ctrls.addSpacing(10) + ctrls.addWidget(self.timeLeft) + ctrls.addSpacing(8) + ctrls.addWidget(self.slider, 1) + ctrls.addSpacing(8) + ctrls.addWidget(self.liveBadge) + + rightLayout.addLayout(titleRow, 0) + rightLayout.addWidget(self.videoSurface, 1) + rightLayout.addLayout(ctrls, 0) + rightLayout.addWidget(self.detailGroup, 0) + + + self.rightStack.addWidget(self.emptyPane) + self.rightStack.addWidget(self.playerPane) + self.rightStack.setCurrentIndex(0) + + splitter = QtWidgets.QSplitter(Qt.Orientation.Horizontal) + splitter.setChildrenCollapsible(False) + splitter.setHandleWidth(6) + splitter.addWidget(leftContainer) + splitter.addWidget(self.rightStack) + splitter.setStretchFactor(0, 0) + splitter.setStretchFactor(1, 1) + splitter.setSizes([320, 900]) + + outer = QtWidgets.QVBoxLayout(self) + outer.setContentsMargins(6, 6, 6, 6) + outer.setSpacing(6) + # outer.addWidget(splitter) + # --- Navigation bar --- + navBar = QtWidgets.QHBoxLayout() + navBar.setContentsMargins(6, 6, 6, 6) + navBar.setSpacing(8) + + btnLiveView = QtWidgets.QPushButton("Live Incidents") + btnLiveView.setCheckable(True) + btnLiveView.setChecked(True) + btnHistory = QtWidgets.QPushButton("Events History") + btnHistory.setCheckable(True) + + btnStyle = """ + QPushButton { + background:#e5e7eb; border:none; border-radius:8px; + padding:6px 12px; font-weight:600; + } + QPushButton:checked { background:#10b981; color:white; } + """ + btnLiveView.setStyleSheet(btnStyle) + btnHistory.setStyleSheet(btnStyle) + + navBar.addWidget(btnLiveView) + navBar.addWidget(btnHistory) + navBar.addStretch(1) + + # --- Main content stack --- + self.stack = QtWidgets.QStackedWidget() + self.livePage = QtWidgets.QWidget() + self.liveLayout = QtWidgets.QVBoxLayout(self.livePage) + self.liveLayout.setContentsMargins(0, 0, 0, 0) + self.liveLayout.addWidget(splitter) + + self.historyPage = EventsHistoryPage(api=self.api) + + self.stack.addWidget(self.livePage) + self.stack.addWidget(self.historyPage) + + # --- Combine all together --- + outer.addLayout(navBar) + outer.addWidget(self.stack) + + # --- Navigation logic --- + btnLiveView.clicked.connect(lambda: self._switch_page(0, btnLiveView, btnHistory)) + btnHistory.clicked.connect(lambda: self._switch_page(1, btnHistory, btnLiveView)) + + + # State + self.mode_live = False + self.dvr_duration_ms = 0 + self._dragging = False + self.current_camera: Optional[str] = None + self.current_incident: Optional[str] = None + self.current_status: str = "firing" + + self._last_abs_t_ms: int = 0 + self._playlist_offset_ms: int = 0 + + self._ui_freeze_deadline: float = 0.0 + self._seek_guard_deadline: float = 0.0 + + self._live_sync = QTimer(self) + self._live_sync.setInterval(800) + self._live_sync.timeout.connect(self._maybe_sync_live_timeline) + + #new + self._dvr_growth = QTimer(self) + self._dvr_growth.setInterval(1200) + self._dvr_growth.timeout.connect(self._maybe_grow_dvr_range_only) + + # --- Subscribe to alert service --- + self.alert_service.alertsUpdated.connect(self._on_alerts_updated) + self.alert_service.alertAdded.connect(self._on_alert_added) + self.alert_service.alertRemoved.connect(self._on_alert_removed) + + # Trigger initial load + if not self.alert_service.alerts: + print("[IncidentPlayer] No cached alerts yet — calling load_initial()") + self.alert_service.load_initial() + else: + print("[IncidentPlayer] Using cached alerts:", len(self.alert_service.alerts)) + self._on_alerts_updated(self.alert_service.alerts) + # WebSocket + snapshot + # self.ws: Optional[QWebSocket] = None + # self.ws_url: Optional[QUrl] = QUrl(self.cfg.ALERTS_WS) if self.cfg.ALERTS_WS else None + # self._ws_backoff_sec = 1 + # self._ws_ping = QTimer(self) + # self._ws_ping.setInterval(15000) + # self._ws_ping.timeout.connect(self._ws_send_ping) + # self._got_initial_snapshot = False + # self._snapshot_resends = 0 + # self._snapshot_retry_timer = QTimer(self) + # self._snapshot_retry_timer.setInterval(1200) + # self._snapshot_retry_timer.timeout.connect(self._on_snapshot_retry_tick) + # self.net = QNetworkAccessManager(self) + # self.net.finished.connect(self._on_http_finished) + # self._awaiting_http_snapshot = False + + # if self.ws_url: + # self._ws_connect() + + # Connections + self.btnLive.clicked.connect(self._go_live) + self.slider.hovered.connect(self._on_slider_hover) + self.slider.clickedTo.connect(self._on_slider_clicked) + self.slider.draggedTo.connect(self._on_slider_drag_released) + self.vlcw.positionChanged.connect(self._on_vlc_pos) + self.vlcw.timeChanged.connect(self._on_vlc_time) + self.alertList.clicked.connect(self._on_pick_alert_from_list) + + self._show_player(False) + self._set_idle() + + def _on_alerts_updated(self, alerts: list): + """Called when AlertService emits full list (on initial load).""" + print(f"[AlertService] Full update: {len(alerts)} alerts") + print("[DEBUG] alerts from AlertService:", alerts) + self._apply_firing_list(alerts) + + def _on_alert_added(self, alert: dict): + """Called when a new alert arrives in real-time.""" + print(f"[AlertService] New alert added: {alert.get('alert_id')}") + self._merge_firing_deltas([alert]) + + def _on_alert_removed(self, alert_id: str): + """Called when an alert is resolved/removed.""" + print(f"[AlertService] Alert removed: {alert_id}") + self.alertModel.set_alerts([ + a for a in self.alertModel._items if a.get("alert_id") != alert_id + ]) + self._update_right_pane_visibility() + + + def _fetch_active_alerts_from_db(self): + """Fetch current active alerts directly from the DB API.""" + try: + print("[DB] Fetching active alerts from dashboard API...") + url = f"{self.api.base}/api/tables/alerts" + resp = self.api.http.get(url, timeout=10) + if resp.status_code != 200: + print(f"[DB] Failed to fetch alerts: {resp.status_code}") + return [] + + data = resp.json() + alerts = data.get("rows", data) if isinstance(data, dict) else data + print(f"[DB] Loaded {len(alerts)} active alerts from DB.") + return alerts + except Exception as e: + print(f"[DB] Error fetching alerts: {e}") + return [] + + + # ───── NO-ALERTS helpers ───── + def _show_player(self, on: bool): + self.rightStack.setCurrentIndex(1 if on else 0) + print(f"[UI] right pane -> {'PLAYER' if on else 'NO-ALERTS'}") + + def _update_right_pane_visibility(self): + have_any = not self.alertModel.is_empty() + print("_update_right_pane_visibility called have any",have_any) + if not have_any: + try: + self.vlcw.mediaplayer.stop() + except Exception: + pass + self._set_idle() + self._show_player(False) + else: + self._show_player(True) + + def _switch_page(self, index: int, active_btn: QtWidgets.QPushButton, inactive_btn: QtWidgets.QPushButton): + self.stack.setCurrentIndex(index) + active_btn.setChecked(True) + inactive_btn.setChecked(False) + print(f"[UI] switched to page index={index}") + + + # ───── alerts helpers ───── + def _key(self, it: dict) -> tuple[str, str]: + return (str(it.get('camera') or ''), str(it.get('incident_id') or '')) + + # def _normalize_alert(self, it: dict) -> dict: + # if not isinstance(it, dict): + # return {} + + # labels = it.get("labels", {}) or {} + # ann = it.get("annotations", {}) or {} + + # # Normalize field names + # flat = { + # "camera": labels.get("device") or ann.get("device") or "unknown", + # "incident_id": labels.get("alert_id") or ann.get("alert_id"), + # "anomaly": labels.get("alertname") or ann.get("category") or "unknown", + # "hls": ann.get("hls"), + # "vod": ann.get("vod"), + # "image_url": ann.get("image_url"), + # "lat": ann.get("lat"), + # "lon": ann.get("lon"), + # "severity": ann.get("severity"), + # "summary": ann.get("summary"), + # "recommendation": ann.get("recommendation"), + # "category": ann.get("category"), + # "startsAt": it.get("startsAt"), + # "endsAt": it.get("endsAt"), + # } + + # # Status inference (Alertmanager has endsAt → resolved) + # ends_at = it.get("endsAt") + # flat["status"] = "resolved" if ends_at else "firing" + + # return flat + def _normalize_alert(self, it: dict) -> dict: + return { + "camera": it.get("device_id") or it.get("camera"), + "incident_id": it.get("alert_id") or it.get("incident_id"), + "anomaly": it.get("alert_type") or it.get("anomaly"), + "hls": it.get("hls"), + "vod": it.get("vod"), + "image_url": it.get("image_url"), + "summary": it.get("summary"), + "severity": it.get("severity"), + "started_at": it.get("started_at") or it.get("startsAt"), + "ended_at": it.get("ended_at") or it.get("endsAt"), + "status": "firing" if not (it.get("ended_at") or it.get("endsAt")) else "resolved", + } + + + + + ##new + def _maybe_grow_dvr_range_only(self): + # Only expand the slider max while paused/seeked (DVR mode). Never move the thumb. + if self.mode_live: + self._dvr_growth.stop() + return + if self.proxy.dvr and not self.proxy.resolved: + _, total = self.proxy.dvr.render_dvr_vod_playlist() + new_max = int(total * 1000) + if new_max > self.dvr_duration_ms: + self.dvr_duration_ms = new_max + self.slider.setRange(0, self.dvr_duration_ms) + + + def _apply_firing_list(self, firing: list[dict]): + firing = [self._normalize_alert(it) for it in (firing or []) if it] + print("[DEBUG] normalized firing list:", firing) + firing = [it for it in firing if (it.get("status") or "firing").lower() == "firing"] + + sel = self.alertList.selectionModel().currentIndex() if self.alertList.selectionModel() else QtCore.QModelIndex() + selected_inc = selected_cam = None + if sel.isValid(): + try: + cur = self.alertModel.get(sel.row()) + selected_inc = cur.get('incident_id') + selected_cam = cur.get('camera') + except Exception: + pass + + self.alertModel.set_alerts(firing) + self._update_right_pane_visibility() + + if selected_inc is not None: + for row, it in enumerate(firing): + if it.get('incident_id') == selected_inc and it.get('camera') == selected_cam: + idx = self.alertModel.index(row, 0) + self.alertList.selectionModel().select(idx, QtCore.QItemSelectionModel.SelectionFlag.ClearAndSelect) + self.alertList.setCurrentIndex(idx) + break + + cur_cam = self.current_camera + cur_inc = self.current_incident or self.cfg.INCIDENT + still_there = any( + it.get('camera') == cur_cam and it.get('incident_id') == cur_inc + for it in firing + ) if (cur_cam and cur_inc) else False + + if (cur_cam and cur_inc) and not still_there: + if self.current_camera and self.current_incident: + self.alertModel.remove_by_key(self.current_camera, self.current_incident) + self.current_status = "resolved" + self.proxy.mark_resolved() + try: + self.vlcw.mediaplayer.stop() + except Exception: + pass + + if firing: + self._show_player(True) + self._play_alert(firing[0]) + else: + self._set_idle() + self._show_player(False) + return + + if firing and not still_there: + self._show_player(True) + self._play_alert(firing[0]) + + def _merge_firing_deltas(self, deltas: list[dict]): + current = self.alertModel.as_dict() + changed = False + + for raw in (deltas or []): + it = self._normalize_alert(raw) + print("[DEBUG] normalized:", it) + k = self._key(it) + + if it.get('status') == 'firing': + if current.get(k) != it: + current[k] = it + changed = True + else: + if k in current: + current.pop(k, None) + changed = True + + if (self.current_camera, self.current_incident) == k and it.get('status') != 'firing': + self.current_status = "resolved" + self.proxy.mark_resolved() + if self.current_camera and self.current_incident: + self.alertModel.remove_by_key(self.current_camera, self.current_incident) + try: + self.vlcw.mediaplayer.stop() + except Exception: + pass + + if not changed: + return + + sel = self.alertList.selectionModel().currentIndex() if self.alertList.selectionModel() else QtCore.QModelIndex() + selected_key = None + if sel.isValid(): + try: + cur = self.alertModel.get(sel.row()) + selected_key = self._key(cur) + except Exception: + pass + + self.alertModel.replace_with(current) + self._update_right_pane_visibility() + + if selected_key: + items = list(current.values()) + for row, it in enumerate(items): + if self._key(it) == selected_key: + idx = self.alertModel.index(row, 0) + self.alertList.selectionModel().select(idx, QtCore.QItemSelectionModel.SelectionFlag.ClearAndSelect) + self.alertList.setCurrentIndex(idx) + break + + cur_cam = self.current_camera + cur_inc = self.current_incident or self.cfg.INCIDENT + has_current = (cur_cam and cur_inc and (cur_cam, cur_inc) in current) + if not has_current: + items = list(current.values()) + if items: + self._show_player(True) + self._play_alert(items[0]) + else: + try: self.vlcw.mediaplayer.stop() + except Exception: pass + self._set_idle() + self._show_player(False) + + # ───── helpers ───── + def _freeze_ui(self, seconds: float = 0.8): + self._ui_freeze_deadline = time.monotonic() + max(0.1, seconds) + + def _maybe_sync_live_timeline(self): + if not self.mode_live: + self._live_sync.stop() + return + + total_ms, live_win_ms = self.proxy.get_durations_ms() + if total_ms <= 0 or live_win_ms <= 0: + return + + target_offset = max(0, total_ms - live_win_ms) + t_rel = self.vlcw.mediaplayer.get_time() + if t_rel < 0: + t_rel = 0 + abs_t = min(target_offset + t_rel, total_ms) + + changed = (self.dvr_duration_ms != total_ms) or (abs(self._playlist_offset_ms - target_offset) > 250) + if changed: + self.dvr_duration_ms = total_ms + self._playlist_offset_ms = target_offset + + self.slider.setRange(0, total_ms) + self.slider.blockSignals(True) + self.slider.setValue(abs_t) + self.slider.blockSignals(False) + + self._last_abs_t_ms = abs_t + self._update_time_label(abs_t) + + # WebSocket & snapshot (same as before) … + def _ws_connect(self): + if not self.ws_url: + print("[WS] ALERTS_WS not set; skipping alerts websocket.") + return + if self.ws: + try: + self.ws.abort() + except Exception: + pass + self.ws = QWebSocket() + self.ws.connected.connect(self._on_ws_connected) + self.ws.textMessageReceived.connect(self._on_ws_msg) + self.ws.disconnected.connect(self._on_ws_disconnected) + self.ws.errorOccurred.connect(self._on_ws_error) + print(f"[WS] connecting to {self.ws_url.toString()}") + self.ws.open(self.ws_url) + + def _send_ws_snapshot_request(self): + try: + if self.ws and self.ws.isValid(): + self.ws.sendTextMessage('{"type":"get_snapshot"}') + self.ws.sendTextMessage('{"type":"snapshot_request"}') + except Exception: + pass + + def _on_ws_connected(self): + print("[WS] connected") + self._ws_backoff_sec = 1 + self._ws_ping.start() + + # Instead of waiting for a snapshot message, immediately fetch from DB + alerts = self._fetch_active_alerts_from_db() + if alerts: + print(f"[WS] Initial load: {len(alerts)} alerts fetched from DB") + self._apply_firing_list(alerts) + else: + print("[WS] No active alerts found in DB") + + # Continue listening for WebSocket deltas + self._got_initial_snapshot = True + self._snapshot_retry_timer.stop() + + + def _on_snapshot_retry_tick(self): + if self._got_initial_snapshot: + self._snapshot_retry_timer.stop() + return + if self._snapshot_resends < 3: + print(f"[WS] requesting snapshot again (attempt {self._snapshot_resends+2}/4)") + self._send_ws_snapshot_request() + self._snapshot_resends += 1 + if self.cfg.ALERTS_SNAPSHOT_HTTP and not self._awaiting_http_snapshot: + self._request_http_snapshot() + if self._snapshot_resends >= 3 and not self.cfg.ALERTS_SNAPSHOT_HTTP: + self._snapshot_retry_timer.stop() + + def _request_http_snapshot(self): + try: + url = QUrl(self.cfg.ALERTS_SNAPSHOT_HTTP) + if not url.isValid() or url.isEmpty(): + return + req = QNetworkRequest(url) + if self.cfg.TOKEN and self.cfg.TOKEN != "CHANGE_ME": + req.setRawHeader(b"Authorization", f"Bearer {self.cfg.TOKEN}".encode("utf-8")) + self._awaiting_http_snapshot = True + print(f"[HTTP] requesting snapshot from {url.toString()}") + self.net.get(req) + except Exception as e: + print(f"[HTTP] snapshot request error: {e!r}") + + def _on_http_finished(self, reply): + try: + if not self._awaiting_http_snapshot: + reply.deleteLater() + return + self._awaiting_http_snapshot = False + data = bytes(reply.readAll()) + try: + payload = json.loads(data.decode("utf-8")) + except Exception: + payload = [] + if isinstance(payload, dict): + items = payload.get("items") or payload.get("alerts") or payload.get("data") or [] + elif isinstance(payload, list): + items = payload + else: + items = [] + firing = [it for it in items if (it or {}).get("status", "").lower() == "firing"] + print(f"[HTTP] snapshot received: {len(firing)} firing") + self._apply_firing_list(firing) + self._got_initial_snapshot = True + self._snapshot_retry_timer.stop() + finally: + reply.deleteLater() + + def _on_ws_disconnected(self): + print("[WS] disconnected") + self._ws_ping.stop() + self._schedule_ws_reconnect() + + def _on_ws_error(self, err): + print(f"[WS] error: {self.ws.errorString()}") + if not self._ws_ping.isActive(): + self._schedule_ws_reconnect() + + def _ws_send_ping(self): + try: + if self.ws and self.ws.isValid(): + self.ws.sendTextMessage('{"type":"ping"}') + except Exception: + pass + + def _schedule_ws_reconnect(self): + delay = min(self._ws_backoff_sec, 30) + print(f"[WS] reconnecting in {delay}s...") + QtCore.QTimer.singleShot(int(delay * 1000), self._ws_connect) + self._ws_backoff_sec = min(self._ws_backoff_sec * 2, 30) + + def _on_ws_msg(self, text: str): + print("=" * 80) + print("[WS] RAW MESSAGE:", text[:600].replace("\n", " ")) + + # Try to parse JSON + try: + msg = json.loads(text) + except Exception as e: + print(f"[WS] non-JSON message ({type(e).__name__}):", text[:120]) + return + + t = (msg.get("type") or "").lower() + items = msg.get("items") or msg.get("alerts") or msg.get("data") or [] + print(f"[WS] Type='{t}' | items={len(items)}") + + # Log short summary of each alert item + for i, it in enumerate(items): + print(f" [{i}] camera={it.get('camera')} " + f"incident={it.get('incident_id')} " + f"status={it.get('status')} " + f"endsAt={it.get('endsAt')} " + f"summary={it.get('summary')}") + + # Normalize alerts + norm = [self._normalize_alert(it) for it in (items or []) if it] + print(f"[WS] After normalize → {[it.get('status') for it in norm]}") + + # Infer missing statuses + for it in norm: + st = (it.get("status") or "").lower() + if not st: + it["status"] = "resolved" if it.get("endsAt") else "firing" + + # Compute "currently firing" view + firing_now = [it for it in norm if (it.get("status") or "firing").lower() == "firing"] + print(f"[WS] Firing after filter: {len(firing_now)} / {len(norm)} → " + f"{[it.get('status') for it in firing_now]}") + + # Decide what to do based on type + if t == 'am_alerts': + print("[WS] Handling am_alerts as delta update (can include resolved)") + self._merge_firing_deltas(norm) + return + + if t in ('snapshot', 'update', 'delta', 'patch'): + print(f"[WS] Handling message type '{t}' as full state replace") + self._apply_firing_list(firing_now) + if t == 'snapshot': + self._got_initial_snapshot = True + self._snapshot_retry_timer.stop() + return + + + # Fallback for unknown message types + if isinstance(items, list): + print(f"[WS] Unknown type '{t}' → applying default rule ({len(firing_now)} firing)") + self._apply_firing_list(firing_now) + + + + # List click → play + def _on_pick_alert_from_list(self, idx: QtCore.QModelIndex): + if not idx.isValid(): + return + it = self.alertModel.get(idx.row()) + print(f"[UI] picked alert: {it}") + self._show_player(True) + self._play_alert(it) + + def _play_alert(self, it: dict): + cam = it.get('camera') + inc = it.get('incident_id') or self.cfg.INCIDENT + hls_url = it.get('hls') or None + + self.current_camera = cam + self.current_incident = inc + self.current_status = (it.get('status') or 'firing').lower() + + self.proxy.switch_source(camera=cam, incident=inc, upstream_hls=hls_url) + self.setWindowTitle("AgGuard — Live Incidents") + self._update_details(it) + self.videoSurface.show_loading(True) + + QtCore.QTimer.singleShot(150, self._go_live) + + def _update_details(self, it: dict): + vals = [ + it.get('camera') or '–', + it.get('anomaly') or '–', + it.get('incident_id') or '–', + (it.get('status') or self.current_status or '–'), + it.get('startsAt') or '–', + ] + for lbl, v in zip(self.lblVals, vals): + lbl.setText(v) + + # ───── slider / playback helpers ───── + def _fmt(self, ms: int) -> str: + s = max(0, ms // 1000) + h, s = divmod(s, 3600) + m, s = divmod(s, 60) + if h: + return f"{h:d}:{m:02d}:{s:02d}" + return f"{m:d}:{s:02d}" + + def _set_live_badge(self, live: bool): + if live: + self.liveBadge.setText("LIVE") + self.liveBadge.setStyleSheet("background:#10b981; color:white; padding:3px 8px; border-radius:12px; font-weight:700;") + else: + self.liveBadge.setText("DVR") + self.liveBadge.setStyleSheet("background:#9ca3af; color:white; padding:3px 8px; border-radius:12px; font-weight:700;") + + def _set_idle(self): + self.mode_live = False + self._set_live_badge(False) + self.dvr_duration_ms = 0 + self._last_abs_t_ms = 0 + self._playlist_offset_ms = 0 + self.timeLeft.setText("00:00") + self.videoSurface.show_loading(True) + for v in self.lblVals: + v.setText("–") + print("[MODE] IDLE") + + def _go_live(self): + # resume live sync; stop DVR growth + if not self._live_sync.isActive(): + self._live_sync.start() + self._dvr_growth.stop() + + if self.current_status == "resolved": + print("[MODE] resolved; not going live") + self._set_idle() + return + + if not self.proxy.upstream_index: + return + + total_ms, live_win_ms = self.proxy.get_durations_ms() + self.dvr_duration_ms = max(self.dvr_duration_ms, total_ms) + self._playlist_offset_ms = max(0, total_ms - live_win_ms) + + self.mode_live = True + self._set_live_badge(True) + self.videoSurface.show_loading(True) + + self._freeze_ui(1.0) + self._update_time_label(self.dvr_duration_ms) + + live_url = f"http://{self.cfg.BIND}:{self.cfg.PORT}/live.m3u8" + + try: + self.vlcw.mediaplayer.stop() + except Exception: + pass + + live_edge_total = max(2, int(self.cfg.LIVE_EDGE_SEGMENTS) + int(self.cfg.LIVE_LAG_SEGMENTS)) + self.vlcw.set_media( + live_url, + options=[ + "--demux=hls", + ":no-audio", + ":http-reconnect=true", + ":hls-keep-live-session", + f":hls-live-edge={min(3, max(2, live_edge_total))}", + ":hls-segment-threads=2", + f":network-caching={max(200, int(self.cfg.NETWORK_CACHING))}", + ], + ) + self.vlcw.play() + self._live_sync.start() + + self.slider.setEnabled(True) + self.slider.setRange(0, self.dvr_duration_ms) + self.slider.blockSignals(True) + self.slider.setValue(self.dvr_duration_ms) + self.slider.blockSignals(False) + + QtCore.QTimer.singleShot(300, lambda: self.videoSurface.show_loading(False)) + print(f"[MODE] LIVE offset={self._playlist_offset_ms}ms total={self.dvr_duration_ms}ms") + + def _load_dvr(self): + print("[DVR] _load_dvr called but DVR freeze is disabled.") + self._set_idle() + + def _on_slider_clicked(self, value: int): + print(f"[SEEK] click -> {value}ms (mode_live={self.mode_live})") + self._freeze_ui(0.8) + if self.mode_live: + self.mode_live = False + self._set_live_badge(False) + self._seek_via_playlist(value) + + def _on_slider_drag_released(self, value: int): + print(f"[SEEK] drag-release -> {value}ms (mode_live={self.mode_live})") + self._freeze_ui(0.8) + if self.mode_live: + self.mode_live = False + self._set_live_badge(False) + self._seek_via_playlist(value) + + def _seek_via_playlist(self, t_ms: int): + # kill live sync right away so it cannot pull the thumb toward live + self._live_sync.stop() + if not self._dvr_growth.isActive(): + self._dvr_growth.start() + + if self.current_status == "resolved" or not self.proxy.dvr: + print("[SEEK] ignored (no DVR while resolved)") + return + + t_ms = max(0, min(int(t_ms), max(0, self.dvr_duration_ms))) + seek_url = f"http://{self.cfg.BIND}:{self.cfg.PORT}/dvr_seek.m3u8?t={t_ms}" + print(f"[SEEK] switching media to seek playlist: {seek_url}") + try: + self.vlcw.mediaplayer.stop() + except Exception: + pass + + self._playlist_offset_ms = t_ms + + if self.slider.maximum() < max(self.dvr_duration_ms, t_ms): + self.slider.setRange(0, max(self.dvr_duration_ms, t_ms)) + + self.vlcw.set_media(seek_url, options=["--demux=hls", ":no-audio"]) + self.vlcw.play() + self.mode_live = False + self._set_live_badge(False) + + self._update_time_label(t_ms) + self.slider.blockSignals(True) + self.slider.setValue(t_ms) + self.slider.blockSignals(False) + + self._last_abs_t_ms = t_ms + self._seek_guard_deadline = time.monotonic() + 2.0 + + QTimer.singleShot(700, lambda: setattr(self, "_ui_freeze_deadline", 0.0)) + + def _on_slider_hover(self, value: int): + if self.dvr_duration_ms > 0: + self._update_time_label(int(value)) + + def _on_vlc_pos(self, _pos01: float): + pass + + def _on_vlc_time(self, t_ms: int): + if t_ms < 0: + return + if time.monotonic() < self._ui_freeze_deadline: + return + + absolute_ms = self._playlist_offset_ms + t_ms + + now = time.monotonic() + if absolute_ms < self._last_abs_t_ms: + if now < self._seek_guard_deadline: + self._last_abs_t_ms = absolute_ms + else: + absolute_ms = self._last_abs_t_ms + else: + self._last_abs_t_ms = absolute_ms + + self._update_time_label(absolute_ms) + + if self.dvr_duration_ms > 0: + self.slider.blockSignals(True) + self.slider.setValue(min(absolute_ms, self.dvr_duration_ms)) + self.slider.blockSignals(False) + + if self.proxy.dvr and not self.proxy.resolved and (int(time.time()) % 2 == 0): + _, total = self.proxy.dvr.render_dvr_vod_playlist() + new_dur = int(total * 1000) + if new_dur > self.dvr_duration_ms: + self.dvr_duration_ms = new_dur + self.slider.setRange(0, self.dvr_duration_ms) + + def _update_time_label(self, t_ms: int): + s = max(0, t_ms // 1000) + h, s = divmod(s, 3600) + m, s = divmod(s, 60) + txt = f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}" + self.timeLeft.setText(txt) + + def closeEvent(self, event: QtGui.QCloseEvent): + try: + self.proxy.stop() + except Exception: + pass + super().closeEvent(event) + + diff --git a/GUI/src/vast/views/sensorDetailsTab.py b/GUI/src/vast/views/sensorDetailsTab.py new file mode 100644 index 000000000..9d78abe0a --- /dev/null +++ b/GUI/src/vast/views/sensorDetailsTab.py @@ -0,0 +1,228 @@ +import traceback +import plotly.graph_objects as go +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QComboBox +) +from PyQt6.QtWebEngineWidgets import QWebEngineView +from PyQt6.QtCore import QTimer + + +class SensorDetailsTab(QWidget): + """Sensor Details Tab – compact, clean, and fully in English.""" + + def __init__(self, api, parent=None): + super().__init__(parent) + self.api = api + self.sensor_id = None + self.sensor_names = [] + + main_layout = QVBoxLayout(self) + main_layout.setContentsMargins(10, 10, 10, 10) + main_layout.setSpacing(6) + + # --- Sensor selection area --- + self.input_layout = QHBoxLayout() + self.label = QLabel("Select sensor:") + self.label.setStyleSheet("font-weight:600;font-size:12px;") + + self.sensor_dropdown = QComboBox() + self.sensor_dropdown.setStyleSheet(""" + QComboBox { + padding:4px 8px; + border:1px solid #cbd5e1; + border-radius:4px; + font-size:12px; + background:white; + min-width:150px; + } + QComboBox:hover { border:1px solid #2563eb; } + """) + + self.load_button = QPushButton("Show Data") + self.load_button.setStyleSheet(""" + QPushButton { + background:#2563eb; + color:white; + border:none; + border-radius:4px; + padding:4px 10px; + font-size:12px; + font-weight:600; + } + QPushButton:hover { background:#1d4ed8; } + """) + self.load_button.clicked.connect(self._on_load_clicked) + + self.input_layout.addWidget(self.label) + self.input_layout.addWidget(self.sensor_dropdown) + self.input_layout.addWidget(self.load_button) + main_layout.addLayout(self.input_layout) + + # --- Web view area --- + self.web = QWebEngineView() + main_layout.addWidget(self.web) + + # --- Auto-refresh timer --- + self.timer = QTimer(self) + self.timer.timeout.connect(self.refresh_data) + self.timer.start(15000) + + # Load available sensors list + self._load_sensor_list() + self.web.setHtml("

Please select a sensor to view details

") + + # -------------------------------------------------------- + def _load_sensor_list(self): + """Load sensor names from the API.""" + try: + r = self.api.http.get(f"{self.api.base}/api/tables/sensors") + data = r.json().get("rows", []) + self.sensor_names = [s["sensor_name"] for s in data if "sensor_name" in s] + self.sensor_dropdown.clear() + self.sensor_dropdown.addItem("-- Select Sensor --") + for name in self.sensor_names: + self.sensor_dropdown.addItem(name) + except Exception as e: + print(f"[SensorDetailsTab] Failed to load sensors list: {e}") + + # -------------------------------------------------------- + def _on_load_clicked(self): + """Triggered when user clicks 'Show Data'.""" + selected = self.sensor_dropdown.currentText().strip() + if not selected or selected == "-- Select Sensor --": + self.web.setHtml("

Please select a sensor from the list

") + return + self.load_sensor(selected) + + # -------------------------------------------------------- + def load_sensor(self, sensor_id: str): + """Called when a sensor is selected manually or from the map.""" + self.sensor_id = sensor_id + self.refresh_data() + + # -------------------------------------------------------- + def refresh_data(self): + """Fetch data from API and refresh the dashboard.""" + if not self.sensor_id: + return + try: + # Sensors + r_sensor = self.api.http.get(f"{self.api.base}/api/tables/sensors?sensor_name={self.sensor_id}") + sensors = r_sensor.json().get("rows", []) + sensor_data = sensors[0] if sensors else {} + + # Logs + r_logs = self.api.http.get(f"{self.api.base}/api/tables/event_logs_sensors?device_id={self.sensor_id}&order_by=start_ts&order_dir=desc") + logs = r_logs.json().get("rows", []) + + # Modal anomalies + r_modal = self.api.http.get(f"{self.api.base}/api/tables/sensors_anomalies_modal?sensor_id={self.sensor_id}&order_by=ts&order_dir=desc") + modal = r_modal.json().get("rows", []) + + # Sensor anomalies + r_anoms = self.api.http.get(f"{self.api.base}/api/tables/sensor_anomalies?sensor={self.sensor_id}&limit=50&order_by=ts&order_dir=desc") + anoms = r_anoms.json().get("rows", []) + + # Active alert + active_alert = next((a for a in logs if a.get("end_ts") is None), None) + + chart_html = self._build_plot(anoms) + page_html = self._build_html(sensor_data, logs, modal, active_alert, chart_html) + self.web.setHtml(page_html) + except Exception as e: + traceback.print_exc() + self.web.setHtml(f"

Error: {e}

") + + # -------------------------------------------------------- + def _build_plot(self, anoms): + """Build the Plotly chart.""" + if not anoms: + return "
No data available for this sensor
" + + timestamps = [a.get("ts") for a in anoms] + values = [a.get("value") for a in anoms] + fig = go.Figure() + fig.add_trace(go.Scatter( + x=timestamps, y=values, mode="lines+markers", + line=dict(color="#2563eb", width=2), + marker=dict(size=4) + )) + fig.update_layout( + template="plotly_white", + height=240, + margin=dict(l=20, r=20, t=20, b=20), + xaxis_title="Timestamp", + yaxis_title="Value", + font=dict(family="Inter,Segoe UI,sans-serif", size=10) + ) + return fig.to_html(include_plotlyjs="cdn", full_html=False) + + # -------------------------------------------------------- + def _build_html(self, sensor_data, logs, modal, active_alert, chart_html): + """Generate the full HTML layout.""" + sensor_name = sensor_data.get("sensor_name", self.sensor_id) + active_html = "" + if active_alert: + sev = active_alert.get("severity", "warn").capitalize() + issue = active_alert.get("issue_type", "Unknown") + started = active_alert.get("start_ts", "")[:19] + active_html = f""" +
+ Active Alert: {issue} | Severity: {sev} | Started: {started} +
+ """ + + combined = [] + for l in logs: + combined.append({ + "time": l.get("start_ts"), + "issue": l.get("issue_type"), + "severity": l.get("severity"), + "source": "event_logs_sensors" + }) + for m in modal: + is_anomaly = m.get("anomaly") not in (0, "0", False, "false", None) + combined.append({ + "time": m.get("ts"), + "issue": "Model anomaly detected" if is_anomaly else "Model normal", + "severity": "critical" if is_anomaly else "info", + "source": "sensors_anomalies_modal" + }) + combined.sort(key=lambda x: x.get("time") or "", reverse=True) + + rows = "".join([ + f"{r['time'][:19]}{r['issue']}" + f"{r['severity'].capitalize()}{r['source']}" + for r in combined + ]) or "No alerts found" + + return f""" + + + + + +

Sensor: {sensor_name}

+{active_html} +

Sensor Readings

{chart_html}
+

Alerts History

+ +{rows}
TimeIssueSeveritySource
+ +""" diff --git a/GUI/src/vast/views/sensorsMainView.py b/GUI/src/vast/views/sensorsMainView.py new file mode 100644 index 000000000..1c8176d7e --- /dev/null +++ b/GUI/src/vast/views/sensorsMainView.py @@ -0,0 +1,84 @@ +from PyQt6.QtWidgets import QWidget, QVBoxLayout, QLabel, QTabWidget +from PyQt6.QtCore import Qt +from views.sensorsMapView import SensorsMapView +from views.sensorDetailsTab import SensorDetailsTab + + +class SensorsMainView(QWidget): + """ + Main container for the sensors module. + Contains two tabs: + 1. Map view (SensorsMapView) + 2. Sensor details (SensorDetailsTab) + """ + def __init__(self, api, parent=None): + super().__init__(parent) + self.api = api + self.setWindowTitle("🌾 Sensors Dashboard") + self.setMinimumSize(1100, 750) + + # --- Layout --- # + layout = QVBoxLayout(self) + layout.setContentsMargins(12, 12, 12, 12) + layout.setSpacing(10) + + # --- Header --- # + title = QLabel("📡 Sensors Dashboard") + title.setStyleSheet(""" + font-size:22px; + font-weight:800; + color:#0f172a; + margin-bottom:4px; + """) + layout.addWidget(title, alignment=Qt.AlignmentFlag.AlignLeft) + + # --- Tabs --- # + self.tabs = QTabWidget() + self.tabs.setTabPosition(QTabWidget.TabPosition.North) + self.tabs.setStyleSheet(""" + QTabWidget::pane { + border: 1px solid #cbd5e1; + border-radius: 10px; + background: #f8fafc; + } + QTabBar::tab { + padding: 8px 16px; + margin-right: 2px; + background: #e2e8f0; + border-radius: 6px 6px 0 0; + font-weight: 600; + color: #0f172a; + } + QTabBar::tab:selected { + background: #2563eb; + color: white; + } + """) + + # --- Map tab --- # + self.map_tab = SensorsMapView(api, self) + self.tabs.addTab(self.map_tab, "🗺️ Map") + + # --- Details tab --- # + self.details_tab = SensorDetailsTab(api, self) + self.tabs.addTab(self.details_tab, "📊 Sensor Details") + + # Add tabs to layout + layout.addWidget(self.tabs) + + # ========================================================== + # === Navigation between tabs + # ========================================================== + def show_sensor_details(self, sensor_id: str): + """ + Called by the map (via JS bridge) when user clicks 'view details' on a sensor. + Loads the details tab and switches to it. + """ + print(f"[SensorsMainView] Showing details for sensor: {sensor_id}") + self.details_tab.load_sensor(sensor_id) + self.tabs.setCurrentIndex(1) + + def back_to_map(self): + """Switch back to the map tab.""" + print("[SensorsMainView] Returning to map tab") + self.tabs.setCurrentIndex(0) diff --git a/GUI/src/vast/views/sensorsMapView.py b/GUI/src/vast/views/sensorsMapView.py new file mode 100644 index 000000000..9c5c515e3 --- /dev/null +++ b/GUI/src/vast/views/sensorsMapView.py @@ -0,0 +1,142 @@ +import os, json +from PyQt6.QtWidgets import QWidget, QVBoxLayout, QLabel, QTableWidget, QPushButton, QTableWidgetItem, QSizePolicy, QFrame +from PyQt6.QtWebEngineWidgets import QWebEngineView +from PyQt6.QtCore import QUrl, QTimer, Qt, pyqtSlot, QObject +from PyQt6.QtWebChannel import QWebChannel +from pathlib import Path +from dashboard_api import DashboardApi + +# Disable GPU (useful in Docker) +os.environ["QTWEBENGINE_DISABLE_GPU"] = "1" +os.environ["QTWEBENGINE_CHROMIUM_FLAGS"] = "--disable-gpu --disable-software-rasterizer --disable-webgl" + +class JsBridge(QObject): + def __init__(self, parent): + super().__init__() + self.parent = parent + + @pyqtSlot(str) + def openSensorDetail(self, sensor_id): + """Called from JS when clicking 'view details'.""" + print(f"[JsBridge] openSensorDetail({sensor_id})") + # נסרוק עד לחלון העליון (MainWindow) בצורה בטוחה + try: + main_window = self.parent.window() + except Exception: + main_window = None + + if main_window and hasattr(main_window, "show_sensor_details"): + main_window.show_sensor_details(sensor_id) + +class SensorsMapView(QWidget): + def __init__(self, api: DashboardApi, parent=None): + super().__init__(parent) + self.api = api + self._map_ready = False + self._visible = False + self._closing = False + + layout = QVBoxLayout(self) + layout.setContentsMargins(10, 10, 10, 10) + layout.setSpacing(10) + + title = QLabel("🗺️ Sensor Map") + title.setStyleSheet("font-size:20px;font-weight:700;color:#0f172a;") + layout.addWidget(title) + + # Map frame + map_frame = QFrame() + map_layout = QVBoxLayout(map_frame) + map_layout.setContentsMargins(0, 0, 0, 0) + self.web = QWebEngineView() + map_layout.addWidget(self.web) + layout.addWidget(map_frame, stretch=2) + + # Load map HTML + html_path = Path(__file__).resolve().parent / "assets" / "sensors_map.html" + self.web.setUrl(QUrl.fromLocalFile(str(html_path))) + + # Stats table + self.table = QTableWidget() + self.table.setAlternatingRowColors(True) + layout.addWidget(self.table, stretch=1) + + btn = QPushButton("⟳ Load Zone Stats") + btn.setStyleSheet("background:#2563eb;color:white;font-weight:700;padding:8px 14px;border-radius:8px;") + btn.clicked.connect(self.load_zone_stats) + layout.addWidget(btn, alignment=Qt.AlignmentFlag.AlignRight) + + # JS bridge + self.channel = QWebChannel() + self.bridge = JsBridge(self) + self.channel.registerObject("pyObj", self.bridge) + self.web.page().setWebChannel(self.channel) + self.web.loadFinished.connect(self._on_map_ready) + + self.timer = QTimer(self) + self.timer.timeout.connect(self.refresh_all) + + def _on_map_ready(self): + self._map_ready = True + print("[SensorsMapView] Map ready") + QTimer.singleShot(1000, self._inject_data) + + def _inject_data(self): + if not self._map_ready: + return + try: + r = self.api.http.get(f"{self.api.base}/api/tables/sensor_anomalies") + data = r.json() + js_data = json.dumps(data.get("rows", data)) + js = f"window.SENSOR_DATA={js_data};if(typeof renderSensors==='function')renderSensors(window.SENSOR_DATA);" + self.web.page().runJavaScript(js) + except Exception as e: + print("[SensorsMapView] Error:", e) + + def load_zone_stats(self): + try: + r = self.api.http.get(f"{self.api.base}/api/tables/sensor_zone_stats?limit=10&order_by=inserted_at&order_dir=desc") + rows = r.json().get("rows", []) + except Exception as e: + print("[SensorsMapView] API error:", e) + return + + self.table.clear() + if not rows: + self.table.setRowCount(0) + self.table.setColumnCount(1) + self.table.setHorizontalHeaderLabels(["No data"]) + return + + + exclude_keys = {"max", "min", "std", "median","mean"} + + keys = [k for k in rows[0].keys() if k not in exclude_keys] + + self.table.setColumnCount(len(keys)) + self.table.setHorizontalHeaderLabels(keys) + self.table.setRowCount(len(rows)) + + for i, row in enumerate(rows): + for j, key in enumerate(keys): + item = QTableWidgetItem(str(row.get(key, ""))) + item.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + self.table.setItem(i, j, item) + + self.table.resizeColumnsToContents() + self.table.horizontalHeader().setStretchLastSection(True) + + def refresh_all(self): + if self._closing or not self._visible: + return + self.load_zone_stats() + self._inject_data() + + def showEvent(self, event): + super().showEvent(event) + self._visible = True + QTimer.singleShot(1000, self._inject_data) + + def hideEvent(self, event): + self._visible = False + super().hideEvent(event) diff --git a/GUI/src/vast/views/sensors_status_summary.py b/GUI/src/vast/views/sensors_status_summary.py new file mode 100644 index 000000000..79d20a46f --- /dev/null +++ b/GUI/src/vast/views/sensors_status_summary.py @@ -0,0 +1,521 @@ +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QFrame, QTableWidget, + QTableWidgetItem, QHeaderView, QPushButton +) +from PyQt6.QtCore import Qt, QTimer +from PyQt6.QtGui import QColor +from datetime import datetime, timedelta + + +class SensorsStatusSummary(QWidget): + def __init__(self, api, parent=None): + super().__init__(parent) + self.api = api + + # Cache for performance optimization + self._last_sensors_fetch = None + self._sensors_cache = [] + self._last_events_check = None + self._events_cache = [] + self._last_event_id = 0 # Track last processed event ID + + # Cache duration (5 minutes for sensors, 1 minute for events) + self._sensors_cache_duration = timedelta(minutes=5) + self._events_cache_duration = timedelta(minutes=1) + + self._build_ui() + self.load_data() + + # Auto-refresh timer for events only (every 30 seconds) + self._refresh_timer = QTimer() + self._refresh_timer.timeout.connect(self._refresh_events_only) + self._refresh_timer.start(30000) # 30 seconds + + def _build_ui(self): + main_layout = QVBoxLayout(self) + main_layout.setContentsMargins(30, 30, 30, 30) + main_layout.setSpacing(25) + + # -------- MODERN HEADER -------- + header_layout = QVBoxLayout() + + title = QLabel("🌾 Sensors Status Dashboard") + title.setStyleSheet(""" + QLabel { + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-size: 32px; + font-weight: 800; + color: #1a1a1a; + margin-bottom: 8px; + letter-spacing: -0.5px; + } + """) + + subtitle = QLabel("Real-time monitoring of agricultural sensors") + subtitle.setStyleSheet(""" + QLabel { + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-size: 16px; + font-weight: 400; + color: #6B7280; + margin-bottom: 15px; + } + """) + + header_layout.addWidget(title, alignment=Qt.AlignmentFlag.AlignLeft) + header_layout.addWidget(subtitle, alignment=Qt.AlignmentFlag.AlignLeft) + main_layout.addLayout(header_layout) + + # -------- MODERN STATUS CARDS -------- + cards_row = QHBoxLayout() + cards_row.setSpacing(20) + self.active_card = self._create_status_card("Active Sensors", "●", "#10B981", "#F0FDF4") + self.inactive_card = self._create_status_card("Inactive Sensors", "●", "#EF4444", "#FEF2F2") + cards_row.addWidget(self.active_card) + cards_row.addWidget(self.inactive_card) + main_layout.addLayout(cards_row) + + # -------- MODERN TABLE -------- + self.table = QTableWidget(0, 5) + self.table.setHorizontalHeaderLabels(["ID", "Sensor Type", "Plant", "Plant ID", "Status"]) + self.table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Stretch) + self.table.verticalHeader().setVisible(False) + self.table.setAlternatingRowColors(True) + self.table.setStyleSheet(""" + QTableWidget { + background-color: #ffffff; + alternate-background-color: #F9FAFB; + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-size: 14px; + border: 2px solid #E5E7EB; + border-radius: 12px; + gridline-color: #F3F4F6; + selection-background-color: #EEF2FF; + } + QHeaderView::section { + background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1, + stop: 0 #F8FAFC, stop: 1 #F1F5F9); + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-weight: 700; + font-size: 15px; + color: #1F2937; + border: none; + border-bottom: 2px solid #E5E7EB; + padding: 12px 8px; + text-align: left; + } + QTableWidget::item { + padding: 12px 8px; + border-bottom: 1px solid #F3F4F6; + } + QTableWidget::item:selected { + background-color: #EEF2FF; + color: #1E40AF; + } + QTableWidget::item:hover { + background-color: #F8FAFC; + } + """) + main_layout.addWidget(self.table) + + # -------- MODERN REFRESH BUTTON -------- + button_layout = QHBoxLayout() + + refresh_btn = QPushButton("↻ Refresh Data") + refresh_btn.setFixedWidth(150) + refresh_btn.setFixedHeight(45) + refresh_btn.clicked.connect(self.refresh_all) + refresh_btn.setStyleSheet(""" + QPushButton { + background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1, + stop: 0 #3B82F6, stop: 1 #1D4ED8); + color: white; + border-radius: 12px; + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-size: 15px; + font-weight: 600; + padding: 0px 16px; + border: none; + letter-spacing: 0.3px; + } + QPushButton:hover { + background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1, + stop: 0 #2563EB, stop: 1 #1E40AF); + } + QPushButton:pressed { + background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1, + stop: 0 #1D4ED8, stop: 1 #1E3A8A); + } + """) + + button_layout.addStretch() + button_layout.addWidget(refresh_btn) + main_layout.addLayout(button_layout) + + self.setLayout(main_layout) + self.setStyleSheet(""" + QWidget { + background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1, + stop: 0 #F8FAFC, stop: 1 #F1F5F9); + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + } + """) + + # -------- MODERN CARD CREATOR -------- + def _create_status_card(self, title_text, icon, accent_color, bg_color): + frame = QFrame() + frame.setFixedHeight(120) + frame.setStyleSheet(f""" + QFrame {{ + background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1, + stop: 0 {bg_color}, stop: 1 #ffffff); + border-radius: 16px; + border: none; + padding: 0px; + }} + """) + + layout = QHBoxLayout(frame) + layout.setContentsMargins(24, 20, 24, 20) + layout.setSpacing(18) + + # Icon section + icon_label = QLabel(icon) + icon_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + icon_label.setFixedSize(50, 50) + icon_label.setStyleSheet(f""" + QLabel {{ + color: {accent_color}; + font-size: 36px; + font-weight: 900; + background: transparent; + border-radius: 25px; + font-family: 'Segoe UI Symbol', 'Arial'; + }} + """) + layout.addWidget(icon_label) + + # Text section + text_layout = QVBoxLayout() + text_layout.setSpacing(5) + + title = QLabel(title_text) + title.setStyleSheet(f""" + QLabel {{ + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-size: 16px; + font-weight: 600; + color: #374151; + letter-spacing: 0.2px; + }} + """) + + count = QLabel("0") + count.setObjectName(title_text.lower().replace(" ", "_")) + count.setStyleSheet(f""" + QLabel {{ + font-family: 'Segoe UI', 'Roboto', 'Inter', sans-serif; + font-size: 32px; + font-weight: 800; + color: {accent_color}; + letter-spacing: -1px; + }} + """) + + text_layout.addWidget(title) + text_layout.addWidget(count) + layout.addLayout(text_layout) + + return frame + + # -------- LOAD DATA (OPTIMIZED) -------- + def load_data(self, force_sensors_refresh=False): + """Load sensors and events data with caching optimization.""" + try: + # Load sensors (with caching) + sensors = self._get_sensors_cached(force_sensors_refresh) + + # Load recent keepalive events only (last 2 hours for performance) + events = self._get_recent_keepalive_events() + + except Exception as e: + print("[SensorsStatusSummary] Error loading data:", e) + return + + # identify inactive sensors by looking at the LATEST record for each device + inactive_ids = set() + + # Group events by device_id and issue_type + device_latest = {} + + for e in events: + if (e.get("issue_type") in ["missing_keepalive", "prolonged_silence"] + and str(e.get("device_id", "")).isdigit()): + device_id = str(e["device_id"]) + issue_type = e.get("issue_type") + key = f"{device_id}_{issue_type}" + + # Keep only the latest record for each device+issue_type combination + if key not in device_latest: + device_latest[key] = e + else: + # Compare by ID (higher ID = more recent) since start_ts can be identical + current_id = device_latest[key].get("id", 0) + new_id = e.get("id", 0) + if new_id > current_id: + device_latest[key] = e + + # Now check which devices have open issues based on latest records + for key, latest_event in device_latest.items(): + if latest_event.get("end_ts") is None: # Latest record is still open + device_id = str(latest_event["device_id"]) + inactive_ids.add(device_id) + + print(f"[SensorsStatusSummary] Found {len(inactive_ids)} sensors with active keepalive issues (based on latest records)") + + active = [s for s in sensors if s["id"] not in inactive_ids] + inactive = [s for s in sensors if s["id"] in inactive_ids] + + print(f"[SensorsStatusSummary] Status: {len(active)} active, {len(inactive)} inactive sensors") + print(f"[SensorsStatusSummary] DEBUG: inactive_ids = {sorted(inactive_ids)}") + print(f"[SensorsStatusSummary] DEBUG: sensor IDs = {[s['id'] for s in sensors]}") + print(f"[SensorsStatusSummary] DEBUG: active sensor IDs = {[s['id'] for s in active]}") + print(f"[SensorsStatusSummary] DEBUG: inactive sensor IDs = {[s['id'] for s in inactive]}") + + # Debug: show which sensors are inactive + if inactive_ids: + print(f"[SensorsStatusSummary] Inactive sensor IDs: {sorted(inactive_ids)}") + + self._update_cards(active, inactive) + self._update_table(active, inactive) + + def _get_sensors_cached(self, force_refresh=False): + """Get sensors list with caching (sensors don't change often).""" + now = datetime.now() + + if (not force_refresh and + self._last_sensors_fetch and + self._sensors_cache and + (now - self._last_sensors_fetch) < self._sensors_cache_duration): + return self._sensors_cache + + # Fetch fresh sensors data + res_sensors = self.api.http.get(f"{self.api.base}/api/tables/devices_sensor") + self._sensors_cache = res_sensors.json().get("rows", []) + self._last_sensors_fetch = now + print(f"[SensorsStatusSummary] Refreshed sensors cache: {len(self._sensors_cache)} sensors") + return self._sensors_cache + + def _get_recent_keepalive_events(self): + """Get events with smart caching - only fetch new events since last check.""" + now = datetime.now() + + # Check if cache is still valid + if (self._last_events_check and + self._events_cache and + (now - self._last_events_check) < self._events_cache_duration): + return self._events_cache + + try: + # Strategy: Get only NEW events since last check (incremental loading) + if self._last_event_id > 0: + # Get only events with ID > last processed ID + url = f"{self.api.base}/api/tables/event_logs_sensors?limit=200&order_by=id&order_dir=desc" + res_events = self.api.http.get(url) + new_events = res_events.json().get("rows", []) + + # Filter only truly new events + really_new = [e for e in new_events if e.get("id", 0) > self._last_event_id] + + if really_new: + print(f"[SensorsStatusSummary] Found {len(really_new)} new events") + # Update cache with new events + self._update_events_cache(really_new) + else: + print("[SensorsStatusSummary] No new events since last check") + else: + # First load - get recent events + print("[SensorsStatusSummary] First load - fetching initial events") + url = f"{self.api.base}/api/tables/event_logs_sensors?limit=500&order_by=id&order_dir=desc" + res_events = self.api.http.get(url) + all_events = res_events.json().get("rows", []) + self._initialize_events_cache(all_events) + + self._last_events_check = now + return self._events_cache + + except Exception as e: + print(f"[SensorsStatusSummary] Error loading events: {e}") + return self._events_cache or [] + + def _initialize_events_cache(self, all_events): + """Initialize cache on first load.""" + # Filter relevant events and store in cache + two_hours_ago = datetime.now() - timedelta(hours=2) + + filtered_events = [] + max_id = 0 + + for event in all_events: + event_id = event.get("id", 0) + if event_id > max_id: + max_id = event_id + + # Check issue type + if event.get("issue_type") not in ["missing_keepalive", "prolonged_silence"]: + continue + + # Keep both open events AND recently closed events (for cache invalidation) + is_open = event.get("end_ts") is None + is_recently_closed = False + + if not is_open: + # Check if closed recently (last 5 minutes) + end_ts_str = event.get("end_ts") + if end_ts_str: + try: + end_ts = datetime.fromisoformat(end_ts_str.replace('Z', '+00:00')) + five_min_ago = datetime.now() - timedelta(minutes=5) + is_recently_closed = end_ts.replace(tzinfo=None) >= five_min_ago + except (ValueError, AttributeError): + pass + + # Keep only open events or recently closed ones + if not (is_open or is_recently_closed): + continue + + # Check if recent (within last 2 hours) + start_ts_str = event.get("start_ts") + if start_ts_str: + try: + start_ts = datetime.fromisoformat(start_ts_str.replace('Z', '+00:00')) + if start_ts.replace(tzinfo=None) < two_hours_ago: + continue # Too old + except (ValueError, AttributeError): + continue # Invalid timestamp + + filtered_events.append(event) + + self._events_cache = filtered_events + self._last_event_id = max_id + print(f"[SensorsStatusSummary] Initialized cache with {len(filtered_events)} relevant events") + + def _update_events_cache(self, new_events): + """Update cache with new events (incremental).""" + two_hours_ago = datetime.now() - timedelta(hours=2) + + # Process new events + new_relevant = [] + max_id = self._last_event_id + + for event in new_events: + event_id = event.get("id", 0) + if event_id > max_id: + max_id = event_id + + # Apply same filtering (include recent closures) + if event.get("issue_type") in ["missing_keepalive", "prolonged_silence"]: + is_open = event.get("end_ts") is None + is_recently_closed = False + + if not is_open: + end_ts_str = event.get("end_ts") + if end_ts_str: + try: + end_ts = datetime.fromisoformat(end_ts_str.replace('Z', '+00:00')) + five_min_ago = datetime.now() - timedelta(minutes=5) + is_recently_closed = end_ts.replace(tzinfo=None) >= five_min_ago + except (ValueError, AttributeError): + pass + + if is_open or is_recently_closed: + # Check if recent + start_ts_str = event.get("start_ts") + if start_ts_str: + try: + start_ts = datetime.fromisoformat(start_ts_str.replace('Z', '+00:00')) + if start_ts.replace(tzinfo=None) >= two_hours_ago: + new_relevant.append(event) + except (ValueError, AttributeError): + pass + + # Update cache: add new events and remove old ones + self._events_cache.extend(new_relevant) + + # Clean old events from cache (older than 2 hours) + self._events_cache = [ + e for e in self._events_cache + if self._is_event_recent(e, two_hours_ago) + ] + + self._last_event_id = max_id + print(f"[SensorsStatusSummary] Added {len(new_relevant)} new events, cache now has {len(self._events_cache)} events") + + def _is_event_recent(self, event, threshold): + """Check if event is recent enough to keep in cache.""" + start_ts_str = event.get("start_ts") + if not start_ts_str: + return True # Keep if no timestamp + + try: + start_ts = datetime.fromisoformat(start_ts_str.replace('Z', '+00:00')) + return start_ts.replace(tzinfo=None) >= threshold + except (ValueError, AttributeError): + return True # Keep if invalid timestamp + + def _refresh_events_only(self): + """Auto-refresh only events data (called by timer).""" + try: + self.load_data(force_sensors_refresh=False) + except Exception as e: + print(f"[SensorsStatusSummary] Auto-refresh error: {e}") + + def refresh_all(self): + """Force refresh all data (sensors + events) - clear all caches.""" + # Clear all caches + self._events_cache = [] + self._last_event_id = 0 + self._last_events_check = None + self._sensors_cache = [] + self._last_sensors_fetch = None + + print("[SensorsStatusSummary] Cleared all caches - doing full refresh") + self.load_data(force_sensors_refresh=True) + + def _update_cards(self, active, inactive): + self.active_card.findChild(QLabel, "active_sensors").setText(str(len(active))) + self.inactive_card.findChild(QLabel, "inactive_sensors").setText(str(len(inactive))) + + def _update_table(self, active, inactive): + all_data = [(s, "Active") for s in active] + [(s, "Inactive") for s in inactive] + self.table.setRowCount(len(all_data)) + + for r, (sensor, status) in enumerate(all_data): + sid = QTableWidgetItem(str(sensor.get("id", ""))) + typ = QTableWidgetItem(sensor.get("sensor_type", "")) + # devices_sensor table doesn't have owner_name, use plant_id instead + plant_id = QTableWidgetItem(f"Plant {sensor.get('plant_id', '—')}") + # devices_sensor table doesn't have location, show plant_id info instead + location = QTableWidgetItem(f"Plant ID: {sensor.get('plant_id', 'N/A')}") + # Modern status with colored badges + if status == "Active": + stat = QTableWidgetItem("● ONLINE") + stat.setForeground(Qt.GlobalColor.darkGreen) + else: + stat = QTableWidgetItem("● OFFLINE") + stat.setForeground(Qt.GlobalColor.darkRed) + + # Style inactive rows with subtle background + if status == "Inactive": + gray_bg = QColor(248, 250, 252) + gray_text = QColor(107, 114, 128) + + for item in (sid, typ, plant_id, location): + item.setBackground(gray_bg) + item.setForeground(gray_text) + + self.table.setItem(r, 0, sid) + self.table.setItem(r, 1, typ) + self.table.setItem(r, 2, plant_id) + self.table.setItem(r, 3, location) + self.table.setItem(r, 4, stat) diff --git a/GUI/src/vast/views/sensors_view.py b/GUI/src/vast/views/sensors_view.py index fe980a31d..5a15f40e6 100644 --- a/GUI/src/vast/views/sensors_view.py +++ b/GUI/src/vast/views/sensors_view.py @@ -1,44 +1,366 @@ -from PyQt6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QTableWidget, QTableWidgetItem -from dashboard_api import DashboardApi +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QLineEdit, + QScrollArea, QGridLayout, QFrame, QDialog, QDialogButtonBox, QFormLayout, QComboBox +) +from PyQt6.QtCore import Qt, QTimer, QDateTime +import traceback -class SensorsView(QWidget): - def __init__(self, api: DashboardApi, parent=None): - super().__init__(parent) - self.api = api +# ============================================================ +# CONSTANTS +# ============================================================ +SEVERITY_RANK = { + "info": 0, + "ok": 0, + "normal": 0, + "warn": 1, + "warning": 1, + "error": 2, + "critical": 3, +} + +# ============================================================ +# SENSOR CARD +# ============================================================ +class SensorCard(QFrame): + """Modern compact sensor card""" + def __init__(self, sensor_data: dict, on_click): + super().__init__() + self.data = sensor_data + self.on_click = on_click + self.setObjectName("card") + self._build_ui() + self.mousePressEvent = self._on_click + + def _on_click(self, event): + self.on_click(self.data) + + def _build_ui(self): layout = QVBoxLayout(self) - title = QLabel("Sensor Types") - title.setStyleSheet("font-size: 18px; font-weight: 600;") + layout.setSpacing(5) + layout.setContentsMargins(12, 12, 12, 10) + + title = QLabel(self.data.get("sensor_name", "Unknown Sensor")) + title.setStyleSheet("font-weight:600; font-size:15px; color:#111;") layout.addWidget(title) - self.table = QTableWidget() - layout.addWidget(self.table) + stype = QLabel(f"Type: {self.data.get('sensor_type', 'N/A')}") + stype.setStyleSheet("color:#555; font-size:12px;") + layout.addWidget(stype) - refresh_btn = QPushButton("Load Sensor Types") - refresh_btn.clicked.connect(self.load_sensors) - layout.addWidget(refresh_btn) + issue = QLabel(f"Issue: {self.data.get('Issue', 'No alerts')}") + issue.setStyleSheet("font-size:12px; color:#444;") + layout.addWidget(issue) layout.addStretch() + sev = self.data.get("Severity", "info").lower() + sev_label = QLabel(sev.capitalize()) + sev_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + sev_label.setFixedHeight(20) + + if sev == "info": + sev_label.setStyleSheet("background-color:#D9FAD3; color:#1B5E20; border-radius:6px; font-weight:600;") + elif sev in ("warn", "warning"): + sev_label.setStyleSheet("background-color:#FFF5BA; color:#8B8000; border-radius:6px; font-weight:600;") + elif sev in ("error", "critical"): + sev_label.setStyleSheet("background-color:#FFD5D5; color:#B71C1C; border-radius:6px; font-weight:600;") + else: + sev_label.setStyleSheet("background-color:#EEE; color:#333; border-radius:6px; font-weight:600;") + + layout.addWidget(sev_label) + + self.setFixedSize(230, 110) + self.setStyleSheet(""" + QFrame#card { + border-radius: 12px; + border: 1px solid #DDD; + background-color: #FFFFFF; + transition: 200ms; + } + QFrame#card:hover { + border: 1px solid #0078D7; + background-color: #F6F9FF; + } + """) + + +# ============================================================ +# ALERT DETAILS DIALOG +# ============================================================ +class AlertDialog(QDialog): + def __init__(self, sensor): + super().__init__() + self.setWindowTitle(f"Alert Details – {sensor.get('sensor_name')}") + self.setMinimumSize(480, 360) + self.setStyleSheet(""" + QDialog { background-color: #FAFAFA; border-radius: 10px; } + QLabel { font-size: 13px; color: #222; } + QPushButton { + background-color: #0078D7; color: white; border-radius: 6px; + padding: 6px 12px; font-weight: 600; + } + QPushButton:hover { background-color: #005FA3; } + """) + + layout = QVBoxLayout(self) + title = QLabel(f"Sensor: {sensor.get('sensor_name')}
" + f"Type: {sensor.get('sensor_type')}
" + f"Current Issue: {sensor.get('Issue')}
" + f"Severity: {sensor.get('Severity')}
") + title.setWordWrap(True) + layout.addWidget(title) + + alerts = sensor.get("All Alerts", []) + layout.addSpacing(10) + + body = QWidget() + body_layout = QVBoxLayout(body) + body_layout.setSpacing(8) + + if alerts: + for a in sorted(alerts, key=lambda x: x.get("start_ts", ""), reverse=True): + card = QFrame() + severity = a.get("severity", "info") + border_color = { + "critical": "#FF4444", + "error": "#FF8800", + "warn": "#FFCC00", + }.get(severity, "#44AA44") + + card.setStyleSheet(f""" + QFrame {{ + border-radius: 8px; + border: 2px solid {border_color}; + background-color: #FFF; + padding: 6px; + margin: 2px; + }} + """) + card_layout = QFormLayout(card) + start_time = a.get("start_ts", "")[:19] if a.get("start_ts") else "" + end_time = a.get("end_ts") + end_display = end_time[:19] if end_time else "[ACTIVE]" + card_layout.addRow("Start Time:", QLabel(start_time)) + card_layout.addRow("End Time:", QLabel(end_display)) + card_layout.addRow("Issue:", QLabel(a.get("issue_type", ""))) + card_layout.addRow("Severity:", QLabel(a.get("severity", ""))) + details = a.get("details", {}) + if details: + for key, val in details.items(): + card_layout.addRow(f"{key.title()}:", QLabel(str(val))) + body_layout.addWidget(card) + else: + body_layout.addWidget(QLabel("No previous alerts or anomalies.")) + + scroll = QScrollArea() + scroll.setWidgetResizable(True) + scroll.setWidget(body) + layout.addWidget(scroll) + + btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok) + btns.accepted.connect(self.accept) + layout.addWidget(btns) + + +# ============================================================ +# MAIN VIEW +# ============================================================ +class SensorsView(QWidget): + """Unified sensors view merging anomalies + alerts""" + def __init__(self, api, parent=None): + super().__init__(parent) + self.api = api + self.all_sensors = [] + self._build_ui() + self.load_sensors() + + self.timer = QTimer(self) + self.timer.timeout.connect(self.load_sensors) + self.timer.start(30000) + + def _build_ui(self): + main_layout = QVBoxLayout(self) + main_layout.setContentsMargins(20, 20, 20, 20) + main_layout.setSpacing(10) + + # ---------- Header ---------- + header = QHBoxLayout() + title = QLabel("🌡️ Unified Sensor Alerts Dashboard") + title.setStyleSheet("font-size:22px; font-weight:700; color:#111;") + header.addWidget(title) + header.addStretch() + + self.filter_box = QComboBox() + self.filter_box.addItems(["All", "Info", "Warning", "Error", "Critical"]) + self.filter_box.currentTextChanged.connect(self._apply_filters) + header.addWidget(self.filter_box) + + self.search_box = QLineEdit() + self.search_box.setPlaceholderText("Search sensors...") + self.search_box.textChanged.connect(self._apply_filters) + self.search_box.setFixedWidth(220) + header.addWidget(self.search_box) + + self.refresh_btn = QPushButton("⟳ Refresh") + self.refresh_btn.clicked.connect(self.load_sensors) + self.refresh_btn.setStyleSheet(""" + QPushButton { + background-color: #0078D7; + color: white; + border-radius: 6px; + padding: 6px 12px; + font-weight: 600; + } + QPushButton:hover { background-color: #005FA3; } + """) + header.addWidget(self.refresh_btn) + main_layout.addLayout(header) + + # ---------- Scroll with cards ---------- + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.container = QWidget() + self.grid = QGridLayout(self.container) + self.grid.setSpacing(12) + self.scroll.setWidget(self.container) + main_layout.addWidget(self.scroll) + self.setLayout(main_layout) + + # ---------------------------- def load_sensors(self): + self.refresh_btn.setEnabled(False) + self.refresh_btn.setText("⟳ Loading...") + try: - data = self.api.http.get(f"{self.api.base}/api/files").json() + res_sensors = self.api.http.get(f"{self.api.base}/api/tables/sensors", timeout=10).json() + res_anoms = self.api.http.get(f"{self.api.base}/api/tables/sensors_anomalies_modal", timeout=10).json() + res_logs = self.api.http.get(f"{self.api.base}/api/tables/event_logs_sensors", timeout=10).json() + sensors = res_sensors.get("rows", []) + anomalies = res_anoms.get("rows", []) + alerts = res_logs.get("rows", []) except Exception as e: - print(f"[SensorsView] API error: {e}") - data = [] + traceback.print_exc() + self.refresh_btn.setEnabled(True) + self.refresh_btn.setText("↻ Refresh") + return + + # map anomalies by sensor + anomaly_latest = {} + for a in anomalies: + sid = a.get("sensor_id") + if not sid: + continue + prev = anomaly_latest.get(sid) + if prev is None or a.get("ts", "") > prev.get("ts", ""): + anomaly_latest[sid] = a + + # map alerts (event_logs_sensors) + alerts_by_sensor = {} + for alert in alerts: + dev_id = alert.get("device_id") + if not dev_id: + continue + if alert.get("end_ts"): + continue # closed alert + alerts_by_sensor.setdefault(dev_id, []).append(alert) + + merged = [] + for s in sensors: + sid = s.get("sensor_name") + s_type = s.get("sensor_type", "Unknown") - if not data: - self.table.setRowCount(0) - self.table.setColumnCount(1) - self.table.setHorizontalHeaderLabels(["No data"]) + alerts_for_s = alerts_by_sensor.get(sid, []) + active_alerts = [a for a in alerts_for_s if not a.get("end_ts")] + + # determine severity from alerts + if active_alerts: + latest_alert = sorted(active_alerts, key=lambda x: x.get("start_ts", ""), reverse=True)[0] + sev_alert = latest_alert.get("severity", "info").lower() + issue_alert = latest_alert.get("issue_type", "alert") + else: + sev_alert = "info" + issue_alert = None + + # determine severity from anomalies + anom = anomaly_latest.get(sid) + if anom and anom.get("anomaly", 0) > 0: + sev_anom = "error" # you can adjust mapping of numeric to severity + issue_anom = "Anomaly detected" + else: + sev_anom = "info" + issue_anom = None + + # pick the most severe + sev_final = sev_alert + issue_final = issue_alert or "No active alerts" + if SEVERITY_RANK[sev_anom] > SEVERITY_RANK[sev_alert]: + sev_final = sev_anom + issue_final = issue_anom + + all_alerts = alerts_for_s.copy() + if anom: + all_alerts.append({ + "issue_type": "anomaly_modal", + "severity": sev_anom, + "start_ts": anom.get("ts"), + "details": {"anomaly": anom.get("anomaly")} + }) + + merged.append({ + "sensor_name": sid, + "sensor_type": s_type, + "Issue": issue_final, + "Severity": sev_final, + "All Alerts": all_alerts + }) + + self.all_sensors = merged + self._apply_filters() + self.refresh_btn.setEnabled(True) + self.refresh_btn.setText("↻ Refresh") + + # ---------------------------- + def _render_cards(self, sensors): + for i in reversed(range(self.grid.count())): + w = self.grid.itemAt(i).widget() + if w: + w.setParent(None) + + if not sensors: + no_data = QLabel("No sensors found matching your criteria") + no_data.setAlignment(Qt.AlignmentFlag.AlignCenter) + no_data.setStyleSheet("color:#666; font-size:16px; padding:40px;") + self.grid.addWidget(no_data, 0, 0, 1, 3) return - keys = list(data[0].keys()) - self.table.setColumnCount(len(keys)) - self.table.setHorizontalHeaderLabels(keys) - self.table.setRowCount(len(data)) + cols = 3 + for idx, s in enumerate(sensors): + card = SensorCard(s, self._show_alert_history) + r, c = divmod(idx, cols) + self.grid.addWidget(card, r, c, Qt.AlignmentFlag.AlignTop) + + # ---------------------------- + def _apply_filters(self): + text = self.search_box.text().strip().lower() + sev_filter = self.filter_box.currentText().lower() + filtered = [] + + for s in self.all_sensors: + sid = str(s.get("sensor_name", "")).lower() + stype = str(s.get("sensor_type", "")).lower() + sev = s.get("Severity", "").lower() + + if text and text not in sid and text not in stype: + continue + if sev_filter != "all" and sev_filter not in sev: + continue + filtered.append(s) + + self._render_cards(filtered) - for r, row in enumerate(data): - for c, key in enumerate(keys): - self.table.setItem(r, c, QTableWidgetItem(str(row[key]))) + # ---------------------------- + def _show_alert_history(self, sensor): + dlg = AlertDialog(sensor) + dlg.exec() diff --git a/RelDB/build_tables/loader.sql b/RelDB/build_tables/loader.sql index a15bbe3d3..15b5f1435 100644 --- a/RelDB/build_tables/loader.sql +++ b/RelDB/build_tables/loader.sql @@ -9,26 +9,7 @@ INSERT INTO devices (device_id, model, owner, active) VALUES ('dev-e','sensor-z','TeamC',true), ('dev-f','sensor-z','TeamC',true) ON CONFLICT DO NOTHING; --- Insert synthetic sensors -INSERT INTO sensors ( - sensor_name, - sensor_type, - owner_name, - location_lat, - location_lon, - install_date, - status, - description, - last_maintenance -) -VALUES - ('SoilMoistureSensor_A1', 'moisture', 'TeamA', 32.051, 34.871, NOW() - INTERVAL '120 days', 'active', 'Soil probe at north field section A1', NOW() - INTERVAL '20 days'), - ('TempSensor_B2', 'temperature', 'TeamA', 32.057, 34.885, NOW() - INTERVAL '95 days', 'active', 'Temperature monitor - greenhouse B2', NOW() - INTERVAL '15 days'), - ('HumiditySensor_C1', 'humidity', 'TeamB', 31.982, 34.945, NOW() - INTERVAL '200 days', 'maintenance', 'Humidity node C1 (low battery)', NOW() - INTERVAL '3 days'), - ('NDVI_Camera_01', 'NDVI', 'TeamB', 32.015, 34.980, NOW() - INTERVAL '60 days', 'active', 'Multispectral NDVI drone-mounted camera', NOW() - INTERVAL '10 days'), - ('WeatherStation_Main', 'weather', 'TeamC', 32.000, 34.760, NOW() - INTERVAL '365 days', 'active', 'Main weather station at south field', NOW() - INTERVAL '30 days'), - ('SoilProbe_Edge', 'moisture', 'TeamC', 32.010, 34.910, NOW() - INTERVAL '40 days', 'inactive', 'Edge field soil probe - disconnected', NOW() - INTERVAL '60 days') -ON CONFLICT (sensor_name) DO NOTHING; + -- Insert some regions INSERT INTO regions (name, geom) @@ -46,6 +27,11 @@ VALUES ('COMM_LOSS','Communication lost') ON CONFLICT DO NOTHING; +-- Seed leaf disease types +INSERT INTO leaf_disease_types (name) +VALUES ('Blight'), ('Mildew'), ('Rust') +ON CONFLICT DO NOTHING; + -- Insert 5 missions WITH params AS ( SELECT 34.75::double precision AS min_lon, 35.05 AS max_lon, @@ -142,7 +128,6 @@ SELECT CASE WHEN random()<0.3 THEN (100+g) ELSE -1 END FROM generate_series(1,100) g; - -- Insert 1000 random embeddings INSERT INTO embeddings (vec) SELECT ARRAY( @@ -165,3 +150,13 @@ DO UPDATE SET threshold = EXCLUDED.threshold, updated_by = EXCLUDED.updated_by, updated_at = NOW(); + +-- Seed sample leaf reports +INSERT INTO leaf_reports (device_id, leaf_disease_type_id, ts, confidence, sick) +SELECT d.device_id, t.id, now() - ((g % 2000) || ' seconds')::interval, + (random()*0.5 + 0.5)::double precision, -- 0.5..1.0 + (random() < 0.5) +FROM devices d, leaf_disease_types t, generate_series(1,20) g +LIMIT 30; + + diff --git a/RelDB/build_tables/schema.sql b/RelDB/build_tables/schema.sql index ade97f4a0..adfb954a4 100644 --- a/RelDB/build_tables/schema.sql +++ b/RelDB/build_tables/schema.sql @@ -11,7 +11,9 @@ CREATE TABLE IF NOT EXISTS devices ( device_id text PRIMARY KEY, model text, owner text, - active boolean DEFAULT true + active boolean DEFAULT true, + location_lat DOUBLE PRECISION, + location_lon DOUBLE PRECISION ); -- Predefined regions (optional: for missions crossing multiple regions) @@ -28,8 +30,23 @@ CREATE TABLE IF NOT EXISTS anomaly_types ( description text NOT NULL ); +--Types of leaf diseases +CREATE TABLE IF NOT EXISTS leaf_disease_types ( + id SERIAL PRIMARY KEY, + name TEXT UNIQUE NOT NULL +); -- === Core entities === +CREATE TABLE IF NOT EXISTS leaf_reports ( + id BIGSERIAL PRIMARY KEY, + device_id TEXT NOT NULL REFERENCES devices(device_id), + leaf_disease_type_id INT NOT NULL REFERENCES leaf_disease_types(id), + ts TIMESTAMPTZ NOT NULL, + confidence DOUBLE PRECISION CHECK (confidence >= 0 AND confidence <= 1), + sick BOOLEAN NOT NULL +); + + -- Missions table CREATE TABLE IF NOT EXISTS missions ( mission_id BIGSERIAL PRIMARY KEY, @@ -128,13 +145,13 @@ CREATE TABLE IF NOT EXISTS users ( ); CREATE TABLE IF NOT EXISTS clients ( - schedule_id BIGSERIAL PRIMARY KEY, - client_id BIGINT NOT NULL, - team VARCHAR(150), - cron_expr TEXT, - active_days TEXT, - time_window TEXT, - last_updated TIMESTAMPTZ NOT NULL DEFAULT now() + schedule_id BIGSERIAL PRIMARY KEY, + client_id BIGINT NOT NULL, + team VARCHAR(150), + cron_expr TEXT, + active_days TEXT, + time_window TEXT, + last_updated TIMESTAMPTZ NOT NULL DEFAULT now() ); -- CREATE TABLE IF NOT EXISTS ultrasonic_plant_predictions ( @@ -210,6 +227,50 @@ CREATE TABLE IF NOT EXISTS inference_logs ( image_url TEXT ); +-- Ripeness predictions table +CREATE TABLE IF NOT EXISTS ripeness_predictions ( + id BIGSERIAL PRIMARY KEY, + inference_log_id BIGINT NOT NULL REFERENCES inference_logs(id) ON DELETE CASCADE, + ts TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ripeness_label TEXT NOT NULL CHECK (ripeness_label IN ('ripe', 'unripe', 'overripe')), + ripeness_score DOUBLE PRECISION NOT NULL, + model_name TEXT NOT NULL, + run_id UUID NOT NULL, + device_id TEXT REFERENCES devices(device_id), + UNIQUE (inference_log_id) +); + +-- Create indexes for ripeness_predictions +CREATE INDEX IF NOT EXISTS ix_ripeness_inflog ON ripeness_predictions(inference_log_id); +CREATE INDEX IF NOT EXISTS ix_ripeness_ts ON ripeness_predictions(ts); +CREATE INDEX IF NOT EXISTS ix_ripeness_device ON ripeness_predictions(device_id); +CREATE INDEX IF NOT EXISTS ix_ripeness_run ON ripeness_predictions(run_id); +CREATE INDEX IF NOT EXISTS ix_leaf_reports_ts_brin ON leaf_reports USING BRIN (ts); +CREATE INDEX IF NOT EXISTS ix_leaf_reports_device_ts ON leaf_reports (device_id, ts); +CREATE INDEX IF NOT EXISTS ix_leaf_reports_type_ts ON leaf_reports (leaf_disease_type_id, ts); + +-- Weekly ripeness rollups table +CREATE TABLE IF NOT EXISTS ripeness_weekly_rollups_ts ( + id BIGSERIAL PRIMARY KEY, + ts TIMESTAMPTZ NOT NULL DEFAULT NOW(), + window_start TIMESTAMPTZ NOT NULL, + window_end TIMESTAMPTZ NOT NULL, + fruit_type TEXT NOT NULL, + device_id TEXT REFERENCES devices(device_id), + run_id UUID NOT NULL, + cnt_total INTEGER NOT NULL, + cnt_ripe INTEGER NOT NULL, + cnt_unripe INTEGER NOT NULL, + cnt_overripe INTEGER NOT NULL, + pct_ripe DOUBLE PRECISION NOT NULL +); + +-- Create indexes for ripeness_weekly_rollups_ts +CREATE INDEX IF NOT EXISTS ix_rwrt_ts ON ripeness_weekly_rollups_ts(ts); +CREATE INDEX IF NOT EXISTS ix_rwrt_fruit_ts ON ripeness_weekly_rollups_ts(fruit_type, ts); +CREATE INDEX IF NOT EXISTS ix_rwrt_device ON ripeness_weekly_rollups_ts(device_id); +CREATE INDEX IF NOT EXISTS ix_rwrt_run ON ripeness_weekly_rollups_ts(run_id); + -- Sensor event logs table. CREATE TABLE IF NOT EXISTS event_logs_sensors( id bigserial PRIMARY KEY, @@ -266,28 +327,11 @@ CREATE TABLE IF NOT EXISTS public.sensor_zone_stats ( anomalies INT, inserted_at TIMESTAMPTZ NOT NULL DEFAULT now() ); - ---- Alerts table - -CREATE TABLE IF NOT EXISTS public.alerts ( - id bigserial PRIMARY KEY, - entity_id text NOT NULL, - rule text NOT NULL, - window_start timestamptz NOT NULL, - window_end timestamptz NOT NULL, - score double precision NOT NULL, - first_seen timestamptz NOT NULL, - last_seen timestamptz NOT NULL, - status text NOT NULL CHECK (status IN ('OPEN','ACK','RESOLVED')), - meta_json jsonb -); - - --- === Soil moisture irrigation tables === CREATE TABLE IF NOT EXISTS soil_moisture_events ( id SERIAL PRIMARY KEY, - zone_id TEXT NOT NULL, + device_id TEXT NOT NULL REFERENCES devices(device_id), ts TIMESTAMPTZ NOT NULL DEFAULT NOW(), dry_ratio REAL NOT NULL, decision TEXT NOT NULL, @@ -300,7 +344,8 @@ CREATE TABLE IF NOT EXISTS soil_moisture_events ( CREATE UNIQUE INDEX IF NOT EXISTS idx_events_idem ON soil_moisture_events (idempotency_key); CREATE TABLE IF NOT EXISTS irrigation_schedule ( - zone_id TEXT PRIMARY KEY, + device_id TEXT PRIMARY KEY REFERENCES devices(device_id), + next_run_at TIMESTAMPTZ NOT NULL, duration_min INT NOT NULL, updated_by TEXT NOT NULL, @@ -310,7 +355,7 @@ CREATE TABLE IF NOT EXISTS irrigation_schedule ( CREATE TABLE IF NOT EXISTS irrigation_schedule_audit ( id SERIAL PRIMARY KEY, - zone_id TEXT NOT NULL, + device_id TEXT NOT NULL, prev_next_run_at TIMESTAMPTZ, prev_duration_min INT, next_run_at TIMESTAMPTZ NOT NULL, @@ -320,6 +365,52 @@ CREATE TABLE IF NOT EXISTS irrigation_schedule_audit ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); +CREATE TABLE irrigation_policies ( + device_id TEXT NOT NULL, + prev_state TEXT, + dry_ratio_high REAL, + dry_ratio_low REAL, + min_patches INT, + duration_min INT, + updated_at TIMESTAMP DEFAULT NOW(), + PRIMARY KEY (device_id), + CONSTRAINT fk_device + FOREIGN KEY (device_id) REFERENCES devices(device_id) + ON DELETE CASCADE +); + + + +CREATE TABLE IF NOT EXISTS alerts ( + + -- Required fields + alert_id TEXT PRIMARY KEY, + alert_type TEXT, + device_id TEXT, + started_at TIMESTAMPTZ, + + -- Optional / dynamic fields + ended_at TIMESTAMPTZ, + confidence DOUBLE PRECISION, + area TEXT, + lat DOUBLE PRECISION, + lon DOUBLE PRECISION, + severity INT DEFAULT 1, + image_url TEXT, + vod TEXT, + hls TEXT, + + -- Acknowledgment field + ack BOOLEAN DEFAULT FALSE, -- TRUE when the alert was acknowledged + + -- Flexible metadata for anything else + meta JSONB, + + -- System fields + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() +); + -- === Task thresholds (enum + table) === DO $$ BEGIN @@ -344,6 +435,155 @@ CREATE TABLE IF NOT EXISTS task_thresholds ( CONSTRAINT ux_task_thresholds_task_label UNIQUE (task, label) ); +CREATE TABLE public.image_new_aerial_connections ( + id BIGSERIAL PRIMARY KEY, + file_name VARCHAR(255), + key TEXT, + linked_time TIMESTAMPTZ +); + +CREATE TABLE IF NOT EXISTS public.aerial_images_metadata ( + id SERIAL PRIMARY KEY, + + -- File and drone metadata + file_name TEXT NOT NULL, + drone_id TEXT NOT NULL, + capture_time TIMESTAMP WITH TIME ZONE NOT NULL, + + -- Raw JSON as received (latitude/longitude) + gis_origin JSONB NOT NULL, + + -- Geometry point auto-generated from JSON + geom_point geometry(Point, 4326) + GENERATED ALWAYS AS ( + ST_SetSRID( + ST_MakePoint( + (gis_origin->>'longitude')::double precision, + (gis_origin->>'latitude')::double precision + ), + 4326 + ) + ) STORED, + + -- Flight attributes + altitude_m DOUBLE PRECISION, + done BOOLEAN DEFAULT FALSE, + created_at TIMESTAMP DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS ix_aerial_geom_point_gist +ON public.aerial_images_metadata USING GIST (geom_point); + + +CREATE TABLE IF NOT EXISTS public.aerial_image_object_detections ( + id SERIAL PRIMARY KEY, + img_key TEXT NOT NULL, + label TEXT NOT NULL, + confidence DOUBLE PRECISION NOT NULL, + bbox_x1 DOUBLE PRECISION NOT NULL, + bbox_y1 DOUBLE PRECISION NOT NULL, + bbox_x2 DOUBLE PRECISION NOT NULL, + bbox_y2 DOUBLE PRECISION NOT NULL, + detected_at TIMESTAMP DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_image_object_detections_key + ON public.aerial_image_object_detections (img_key); + + +CREATE TABLE IF NOT EXISTS public.aerial_image_anomaly_detections ( + id SERIAL PRIMARY KEY, + img_key TEXT NOT NULL, + label TEXT NOT NULL, + confidence DOUBLE PRECISION NOT NULL, + bbox_x1 DOUBLE PRECISION NOT NULL, + bbox_y1 DOUBLE PRECISION NOT NULL, + bbox_x2 DOUBLE PRECISION NOT NULL, + bbox_y2 DOUBLE PRECISION NOT NULL, + detected_at TIMESTAMP DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_image_anomaly_detections_key + ON public.aerial_image_anomaly_detections (img_key); + + +CREATE TABLE IF NOT EXISTS public.aerial_images_complete_metadata ( + id SERIAL PRIMARY KEY, + file_name TEXT NOT NULL, + device_id TEXT NOT NULL, + gis_origin JSONB, + gis geometry(Point, 4326) + GENERATED ALWAYS AS ( + ST_SetSRID( + ST_MakePoint( + (gis_origin->>'longitude')::double precision, + (gis_origin->>'latitude')::double precision + ), + 4326 + ) + ) STORED, + img_key TEXT NOT NULL UNIQUE, + timestamp_utc TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_aerial_metadata_device_id + ON public.aerial_images_complete_metadata (device_id); + +CREATE INDEX IF NOT EXISTS idx_aerial_metadata_timestamp + ON public.aerial_images_complete_metadata (timestamp_utc); + +CREATE INDEX IF NOT EXISTS idx_aerial_metadata_gis + ON public.aerial_images_complete_metadata USING GIST (gis); + + +CREATE TABLE IF NOT EXISTS public.field_polygons ( + id SERIAL PRIMARY KEY, + gis geometry(Point, 4326) NOT NULL, + boundary geometry(Polygon, 4326) NOT NULL, + area_sq_m DOUBLE PRECISION GENERATED ALWAYS AS ( + ST_Area(geography(boundary)) + ) STORED, + created_at TIMESTAMP DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_field_polygons_gis + ON public.field_polygons USING GIST (gis); + + +CREATE TABLE IF NOT EXISTS public.aerial_image_segmentation ( + id SERIAL PRIMARY KEY, + img_key TEXT NOT NULL, + mask_path TEXT, + other FLOAT DEFAULT 0, + bareland FLOAT DEFAULT 0, + rangeland FLOAT DEFAULT 0, + developed_space FLOAT DEFAULT 0, + road FLOAT DEFAULT 0, + tree FLOAT DEFAULT 0, + water FLOAT DEFAULT 0, + agriculture FLOAT DEFAULT 0, + building FLOAT DEFAULT 0, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_segmentation_img_key + ON public.aerial_image_segmentation (img_key); + + +CREATE TABLE public.sound_new_sounds_connections ( + id BIGSERIAL PRIMARY KEY, + file_name VARCHAR(255), + key TEXT, + linked_time TIMESTAMPTZ +); + +CREATE TABLE public.sound_new_plants_connections ( + id BIGSERIAL PRIMARY KEY, + file_name VARCHAR(255), + key TEXT, + linked_time TIMESTAMPTZ +); CREATE INDEX IF NOT EXISTS ix_task_thresholds_task ON task_thresholds (task); CREATE INDEX IF NOT EXISTS ix_task_thresholds_updated_at ON task_thresholds (updated_at); @@ -412,5 +652,250 @@ CREATE INDEX IF NOT EXISTS ix_event_logs_sensors_start_brin ON event_logs_sens CREATE INDEX IF NOT EXISTS ix_event_logs_sensors_details_gin ON event_logs_sensors USING GIN (details jsonb_path_ops); -CREATE INDEX IF NOT EXISTS ix_alerts_entity_rule ON public.alerts(entity_id, rule); -CREATE INDEX IF NOT EXISTS ix_alerts_status ON public.alerts(status); + + +/* =========================== + ADDED: Incidents schema v1 + =========================== */ + +-- ========================= +-- Incidents: one event row +-- ========================= +-- anomaly_type_id int REFERENCES anomaly_types(anomaly_type_id) ON DELETE SET NULL, +CREATE TABLE IF NOT EXISTS incidents ( -- [ADDED] + incident_id uuid PRIMARY KEY, -- [ADDED] + mission_id bigint REFERENCES missions(mission_id) ON DELETE SET NULL, -- [ADDED] + device_id text REFERENCES devices(device_id) ON DELETE SET NULL, -- [ADDED] + anomaly text, -- [ADDED] + started_at timestamptz NOT NULL, -- [ADDED] + ended_at timestamptz, -- [ADDED] + duration_sec double precision, -- [ADDED] + frame_start int, -- [ADDED] + frame_end int, -- [ADDED] + -- [ADDED] + -- NEW: aggregate severity (mean tracks/frame over the incident) -- [ADDED] + severity real, -- [ADDED] + -- [ADDED] + -- image-space ROI; keep flexible -- [ADDED] + roi_pixels jsonb, -- [ADDED] + -- optional map footprint (camera frustum, area of interest, etc.) -- [ADDED] + footprint geometry(Polygon,4326), -- [ADDED] + -- [ADDED] + -- canonical media for playback (referencing your existing files table) -- [ADDED] + clip_file_id bigint REFERENCES files(file_id) ON DELETE SET NULL, -- [ADDED] + poster_file_id bigint REFERENCES files(file_id) ON DELETE SET NULL, -- [ADDED] + -- [ADDED] + -- optional pre-baked UI timeline (array of {frame,ts,box,conf,url,...}) -- [ADDED] + frames_manifest jsonb, -- [ADDED] + is_real boolean, + ack boolean DEFAULT false, -- [ADDED] + meta jsonb DEFAULT '{}'::jsonb -- [ADDED] +); -- [ADDED] + +-- Helpful indexes -- [ADDED] +CREATE INDEX IF NOT EXISTS ix_incidents_device_time ON incidents (device_id, started_at DESC); -- [ADDED] +CREATE INDEX IF NOT EXISTS ix_incidents_mission_time ON incidents (mission_id, started_at DESC); -- [ADDED] + +-- ========================================== +-- Per-frame timeline: one row per frame +-- Store ALL detections (bbox + conf + track_id) in JSONB +-- ========================================== +DROP TABLE IF EXISTS incident_frames CASCADE; -- [ADDED] + +CREATE TABLE incident_frames ( -- [ADDED] + incident_id uuid NOT NULL REFERENCES incidents(incident_id) ON DELETE CASCADE, -- [ADDED] + frame_idx int NOT NULL, -- [ADDED] + ts timestamptz NOT NULL, -- [ADDED] + -- [ADDED] + -- List of detection objects: -- [ADDED] + -- [{"x1":int,"y1":int,"x2":int,"y2":int,"conf":float|null,"track_id":int|null}, ...] -- [ADDED] + detections jsonb NOT NULL DEFAULT '[]'::jsonb, -- [ADDED] + -- [ADDED] + cls_name text, -- [ADDED] + cls_id text, -- [ADDED] + -- [ADDED] + -- Annotated (or raw) frame stored in files -- [ADDED] + file_id bigint REFERENCES files(file_id) ON DELETE SET NULL, -- [ADDED] + -- [ADDED] + meta jsonb DEFAULT '{}'::jsonb, -- [ADDED] + PRIMARY KEY (incident_id, frame_idx) -- [ADDED] +); -- [ADDED] + +-- Useful indexes -- [ADDED] +CREATE INDEX IF NOT EXISTS ix_incident_frames_ts -- [ADDED] + ON incident_frames (incident_id, ts); -- [ADDED] + +-- JSONB GIN index for detection queries (by bbox/conf/track_id) -- [ADDED] +CREATE INDEX IF NOT EXISTS ix_incident_frames_detections_gin -- [ADDED] + ON incident_frames USING GIN (detections jsonb_path_ops); -- [ADDED] + +-- (Optional) denormalized count for quick metrics -- [ADDED] +ALTER TABLE incident_frames -- [ADDED] + ADD COLUMN IF NOT EXISTS num_tracks int -- [ADDED] + GENERATED ALWAYS AS (jsonb_array_length(detections)) STORED; -- [ADDED] + +-- CREATE INDEX IF NOT EXISTS ix_alerts_entity_rule ON public.alerts(entity_id, rule); +-- CREATE INDEX IF NOT EXISTS ix_alerts_status ON public.alerts(status); + +-- ============================================ +-- 🔹 MISSING TABLES AND INDEXES FROM FIRST SCHEMA +-- ============================================ + +-- Devices sensor mapping +CREATE TABLE IF NOT EXISTS devices_sensor ( + id TEXT UNIQUE NOT NULL, + plant_id INT NOT NULL, + sensor_type TEXT NOT NULL, + PRIMARY KEY (plant_id, id) +); + +-- Zones table (for linking sensors to geographic areas) +CREATE TABLE IF NOT EXISTS public.zones ( + id SERIAL PRIMARY KEY, + name VARCHAR(128) NOT NULL, + geom geometry(POLYGON, 4326) NOT NULL +); + +-- Extended sensors table with all environmental metrics +DROP TABLE IF EXISTS public.sensors CASCADE; +CREATE TABLE IF NOT EXISTS public.sensors ( + id SERIAL PRIMARY KEY, + sensor_name TEXT UNIQUE NOT NULL, + sensor_type TEXT NOT NULL, + owner_name TEXT, + location_lat DOUBLE PRECISION, + location_lon DOUBLE PRECISION, + install_date TIMESTAMP DEFAULT NOW(), + status TEXT DEFAULT 'active', + description TEXT, + last_maintenance TIMESTAMP, + value DOUBLE PRECISION, + humidity DOUBLE PRECISION, + temperature DOUBLE PRECISION, + ph DOUBLE PRECISION, + rainfall DOUBLE PRECISION, + soil_moisture DOUBLE PRECISION, + co2_concentration DOUBLE PRECISION, + n DOUBLE PRECISION, + p DOUBLE PRECISION, + k DOUBLE PRECISION, + label TEXT, + timestamp TIMESTAMPTZ NOT NULL, + msg_type TEXT, + plant_id INT, + soil_type INT, + sunlight_exposure DOUBLE PRECISION, + wind_speed DOUBLE PRECISION, + organic_matter DOUBLE PRECISION, + irrigation_frequency DOUBLE PRECISION, + crop_density DOUBLE PRECISION, + pest_pressure DOUBLE PRECISION, + fertilizer_usage DOUBLE PRECISION, + growth_stage INT, + urban_area_proximity DOUBLE PRECISION, + water_source_type INT, + frost_risk DOUBLE PRECISION, + water_usage_efficiency DOUBLE PRECISION +); + +-- Sensor anomalies table with full structure and JSONB result +DROP TABLE IF EXISTS public.sensor_anomalies CASCADE; +CREATE TABLE IF NOT EXISTS public.sensor_anomalies ( + id BIGSERIAL PRIMARY KEY, + idSensor INT NOT NULL, + plant_id INT NOT NULL, + sensor VARCHAR(64) NOT NULL, + ts TIMESTAMPTZ NOT NULL, + value DOUBLE PRECISION, + lat DOUBLE PRECISION, + lon DOUBLE PRECISION, + zone VARCHAR(128), + result JSONB NOT NULL, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Sensors anomalies modal (aggregated anomaly detection model) +CREATE TABLE IF NOT EXISTS public.sensors_anomalies_modal ( + id BIGSERIAL PRIMARY KEY, + sensor_id TEXT NOT NULL REFERENCES sensors(sensor_name) ON DELETE CASCADE, + ts TIMESTAMPTZ NOT NULL, + anomaly REAL NOT NULL CHECK (anomaly >= 0), + inserted_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Updated event_logs_sensors referencing devices_sensor +DROP TABLE IF EXISTS event_logs_sensors CASCADE; +CREATE TABLE IF NOT EXISTS event_logs_sensors( + id bigserial PRIMARY KEY, + device_id TEXT NOT NULL REFERENCES devices_sensor(id), + issue_type text NOT NULL, + severity text NOT NULL CHECK (severity IN ('info','warn','error','critical')), + start_ts timestamptz NOT NULL DEFAULT now(), + end_ts timestamptz NULL, + details jsonb NOT NULL DEFAULT '{}'::jsonb, + CONSTRAINT event_logs_sensors_end_after_start + CHECK (end_ts IS NULL OR end_ts >= start_ts) +); + +-- Sensor zone statistics (for per-region summaries) +CREATE TABLE IF NOT EXISTS public.sensor_zone_stats ( + id BIGSERIAL PRIMARY KEY, + zone VARCHAR(128) NOT NULL, + window_start TIMESTAMPTZ NOT NULL, + window_end TIMESTAMPTZ NOT NULL, + count INT NOT NULL, + mean DOUBLE PRECISION, + median DOUBLE PRECISION, + min DOUBLE PRECISION, + max DOUBLE PRECISION, + std DOUBLE PRECISION, + anomalies INT, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- ============================================ +-- 🔹 INDEXES FOR SENSOR TABLES +-- ============================================ + +CREATE INDEX IF NOT EXISTS ix_sensors_anomalies_modal_sensor_ts + ON sensors_anomalies_modal (sensor_id, ts); + +CREATE INDEX IF NOT EXISTS ix_sensor_anomalies_ts_brin + ON public.sensor_anomalies USING BRIN (ts); + +CREATE INDEX IF NOT EXISTS ix_sensor_anomalies_zone + ON public.sensor_anomalies (zone); + +CREATE INDEX IF NOT EXISTS ix_sensor_anomalies_sensor + ON public.sensor_anomalies (sensor); + +CREATE INDEX IF NOT EXISTS ix_sensor_zone_stats_zone_window + ON public.sensor_zone_stats (zone, window_start, window_end); + +CREATE INDEX IF NOT EXISTS ix_sensor_zone_stats_anomalies + ON public.sensor_zone_stats (anomalies); + +CREATE INDEX IF NOT EXISTS ix_sensors_name ON sensors (sensor_name); +CREATE INDEX IF NOT EXISTS ix_sensors_type ON sensors (sensor_type); +CREATE INDEX IF NOT EXISTS ix_sensors_status ON sensors (status); +CREATE INDEX IF NOT EXISTS ix_sensors_location ON sensors (location_lat, location_lon); + + +--- Alerts_leaves table + +CREATE TABLE IF NOT EXISTS public.alerts_leaves ( + id bigserial PRIMARY KEY, + entity_id text NOT NULL, + rule text NOT NULL, + window_start timestamptz NOT NULL, + window_end timestamptz NOT NULL, + score double precision NOT NULL, + first_seen timestamptz NOT NULL, + last_seen timestamptz NOT NULL, + status text NOT NULL CHECK (status IN ('OPEN','ACK','RESOLVED')), + meta_json jsonb +); + +CREATE INDEX IF NOT EXISTS ix_alerts_leaves_entity_rule ON public.alerts_leaves(entity_id, rule); +CREATE INDEX IF NOT EXISTS ix_alerts_leaves_status ON public.alerts_leaves(status); + diff --git a/airflow_bundle/leaf-pipeline/.gitattributes b/airflow_bundle/leaf-pipeline/.gitattributes new file mode 100644 index 000000000..ada57005e --- /dev/null +++ b/airflow_bundle/leaf-pipeline/.gitattributes @@ -0,0 +1,5 @@ +projects/leaf-counting/weights/*.pt filter=lfs diff=lfs merge=lfs -text +projects/leaf-counting/weights/*.pth filter=lfs diff=lfs merge=lfs -text +projects/leaf-counting/weights/*.safetensors filter=lfs diff=lfs merge=lfs -text + +projects/Detection_Jobs/**/models/* filter=lfs diff=lfs merge=lfs -text diff --git a/airflow_bundle/leaf-pipeline/.gitignore b/airflow_bundle/leaf-pipeline/.gitignore new file mode 100644 index 000000000..fa3f26096 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/.gitignore @@ -0,0 +1,89 @@ +# === OS / Editors === +.DS_Store +*.swp +*.swo +.idea/ +.vscode/ +*.code-workspace + +# === Python === +__pycache__/ +*.py[cod] +*.pyo +*.pyd +*.egg-info/ +.eggs/ +.build/ +build/ +dist/ +.mypy_cache/ +.pytest_cache/ +.coverage +.coverage.* +.cache/ +.ipynb_checkpoints/ + +# Virtual envs +.venv/ +venv/ +env/ + +# === Docker / Compose === +docker.env +.env +.env.* +.env.local +.envrc + +# === Airflow === +airflow/logs/ +airflow/airflow.db +airflow/airflow.db-journal +airflow/*.pid +airflow/*webserver*.log +airflow/*webserver*.out +airflow/*webserver*.err +airflow/*scheduler*.log +airflow/*scheduler*.out +airflow/*scheduler*.err +airflow/dags/*.bak.* +airflow/staging/ + +# === Projects: leaf-counting === +projects/leaf-counting/out_detect/ +projects/leaf-counting/out_crops/ +projects/leaf-counting/out_pwb/ +projects/leaf-counting/runs_local/ +projects/leaf-counting/.venv/ +projects/leaf-counting/staging/ + +# === Projects: Detection_Jobs / disease-monitor === +projects/Detection_Jobs/**/__pycache__/ +projects/Detection_Jobs/**/.mypy_cache/ +projects/Detection_Jobs/**/.pytest_cache/ +projects/disease-monitor/**/__pycache__/ +projects/disease-monitor/**/.mypy_cache/ +projects/disease-monitor/**/.pytest_cache/ + +# === Secrets / Certs === +*.key +*.pem +*.crt +*.p12 +*credentials*.json +*service_account*.json +*.secrets.* +.secrets/ +secrets/ +projects/Detection_Jobs/.git.backup-*.tar.gz +airflow/dags/leaf-counting/runs_local/ +airflow/dags/leaf-counting/demo_images/ +airflow/dags/leaf-counting/out_*/ +!projects/Detection_Jobs/**/models/ +projects/disease-monitor/disease-monitor/alerts.db +projects/**/.git.backup-*.tar.gz +.gitignore.bak* +airflow/dags/leaf-counting/ + + +dags_leaf_counting_backup.tgz \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/Dockerfile b/airflow_bundle/leaf-pipeline/Dockerfile new file mode 100644 index 000000000..53af135cf --- /dev/null +++ b/airflow_bundle/leaf-pipeline/Dockerfile @@ -0,0 +1,42 @@ + +FROM mcr.microsoft.com/devcontainers/python:1-3.10-bullseye +# RUN apt-get update && apt-get install -y --no-install-recommends \ +# libgl1 libglib2.0-0 ffmpeg curl ca-certificates && \ +# rm -rf /var/lib/apt/lists/* +# RUN apt-get update && apt-get install -y --no-install-recommends \ +# libgl1 libglib2.0-0 ffmpeg curl ca-certificates \ +# util-linux procps \ +# && rm -rf /var/lib/apt/lists/* + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libgl1 libglib2.0-0 ffmpeg curl ca-certificates \ + util-linux procps \ + && rm -rf /var/lib/apt/lists/* + +RUN python -m pip install --no-cache-dir --upgrade pip wheel setuptools + +ENV AIRFLOW_VERSION=2.9.3 +ENV PYTHON_VERSION=3.10 +ENV CONSTRAINT_URL=https://raw.githubusercontent.com/apache/airflow/constraints-${AIRFLOW_VERSION}/constraints-${PYTHON_VERSION}.txt +RUN pip install --no-cache-dir "apache-airflow==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" + +RUN pip install --no-cache-dir \ + "apache-airflow-providers-docker" \ + --constraint "${CONSTRAINT_URL}" + + +# === PyTorch CPU wheels === +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cpu \ + torch==2.3.1+cpu torchvision==0.18.1+cpu torchaudio==2.3.1+cpu + +# === YOLO=== +RUN pip install --no-cache-dir \ + numpy==1.26.4 opencv-python-headless==4.9.0.80 ultralytics==8.2.10 \ + onnx==1.16.1 onnxruntime==1.18.1 \ + boto3 minio awscli requests tqdm + +#root +RUN useradd -ms /bin/bash airflow +USER airflow +WORKDIR /opt/airflow diff --git a/airflow_bundle/leaf-pipeline/README b/airflow_bundle/leaf-pipeline/README new file mode 100644 index 000000000..7d2cd8bd5 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/README @@ -0,0 +1 @@ +docker compose --profile images up -d --build \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/airflow/airflow.cfg b/airflow_bundle/leaf-pipeline/airflow/airflow.cfg new file mode 100644 index 000000000..8f1408c24 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/airflow/airflow.cfg @@ -0,0 +1,2420 @@ +[core] +# The folder where your airflow pipelines live, most likely a +# subfolder in a code repository. This path must be absolute. +# +# Variable: AIRFLOW__CORE__DAGS_FOLDER +# +dags_folder = /opt/airflow/dags + +# Hostname by providing a path to a callable, which will resolve the hostname. +# The format is "package.function". +# +# For example, default value ``airflow.utils.net.getfqdn`` means that result from patched +# version of `socket.getfqdn() `__, +# see related `CPython Issue `__. +# +# No argument should be required in the function specified. +# If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address`` +# +# Variable: AIRFLOW__CORE__HOSTNAME_CALLABLE +# +hostname_callable = airflow.utils.net.getfqdn + +# A callable to check if a python file has airflow dags defined or not and should +# return ``True`` if it has dags otherwise ``False``. +# If this is not provided, Airflow uses its own heuristic rules. +# +# The function should have the following signature +# +# .. code-block:: python +# +# def func_name(file_path: str, zip_file: zipfile.ZipFile | None = None) -> bool: ... +# +# Variable: AIRFLOW__CORE__MIGHT_CONTAIN_DAG_CALLABLE +# +might_contain_dag_callable = airflow.utils.file.might_contain_dag_via_default_heuristic + +# Default timezone in case supplied date times are naive +# can be `UTC` (default), `system`, or any `IANA ` +# timezone string (e.g. Europe/Amsterdam) +# +# Variable: AIRFLOW__CORE__DEFAULT_TIMEZONE +# +default_timezone = utc + +# The executor class that airflow should use. Choices include +# ``SequentialExecutor``, ``LocalExecutor``, ``CeleryExecutor``, +# ``KubernetesExecutor``, ``CeleryKubernetesExecutor``, ``LocalKubernetesExecutor`` or the +# full import path to the class when using a custom executor. +# +# Variable: AIRFLOW__CORE__EXECUTOR +# +executor = SequentialExecutor + +# The auth manager class that airflow should use. Full import path to the auth manager class. +# +# Variable: AIRFLOW__CORE__AUTH_MANAGER +# +auth_manager = airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager + +# This defines the maximum number of task instances that can run concurrently per scheduler in +# Airflow, regardless of the worker count. Generally this value, multiplied by the number of +# schedulers in your cluster, is the maximum number of task instances with the running +# state in the metadata database. +# +# Variable: AIRFLOW__CORE__PARALLELISM +# +parallelism = 32 + +# The maximum number of task instances allowed to run concurrently in each DAG. To calculate +# the number of tasks that is running concurrently for a DAG, add up the number of running +# tasks for all DAG runs of the DAG. This is configurable at the DAG level with ``max_active_tasks``, +# which is defaulted as ``[core] max_active_tasks_per_dag``. +# +# An example scenario when this would be useful is when you want to stop a new dag with an early +# start date from stealing all the executor slots in a cluster. +# +# Variable: AIRFLOW__CORE__MAX_ACTIVE_TASKS_PER_DAG +# +max_active_tasks_per_dag = 16 + +# Are DAGs paused by default at creation +# +# Variable: AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION +# +dags_are_paused_at_creation = True + +# The maximum number of active DAG runs per DAG. The scheduler will not create more DAG runs +# if it reaches the limit. This is configurable at the DAG level with ``max_active_runs``, +# which is defaulted as ``[core] max_active_runs_per_dag``. +# +# Variable: AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG +# +max_active_runs_per_dag = 16 + +# (experimental) The maximum number of consecutive DAG failures before DAG is automatically paused. +# This is also configurable per DAG level with ``max_consecutive_failed_dag_runs``, +# which is defaulted as ``[core] max_consecutive_failed_dag_runs_per_dag``. +# If not specified, then the value is considered as 0, +# meaning that the dags are never paused out by default. +# +# Variable: AIRFLOW__CORE__MAX_CONSECUTIVE_FAILED_DAG_RUNS_PER_DAG +# +max_consecutive_failed_dag_runs_per_dag = 0 + +# The name of the method used in order to start Python processes via the multiprocessing module. +# This corresponds directly with the options available in the Python docs: +# `multiprocessing.set_start_method +# `__ +# must be one of the values returned by `multiprocessing.get_all_start_methods() +# `__. +# +# Example: mp_start_method = fork +# +# Variable: AIRFLOW__CORE__MP_START_METHOD +# +# mp_start_method = + +# Whether to load the DAG examples that ship with Airflow. It's good to +# get started, but you probably want to set this to ``False`` in a production +# environment +# +# Variable: AIRFLOW__CORE__LOAD_EXAMPLES +# +load_examples = True + +# Path to the folder containing Airflow plugins +# +# Variable: AIRFLOW__CORE__PLUGINS_FOLDER +# +plugins_folder = /opt/airflow/plugins + +# Should tasks be executed via forking of the parent process +# +# * ``False``: Execute via forking of the parent process +# * ``True``: Spawning a new python process, slower than fork, but means plugin changes picked +# up by tasks straight away +# +# Variable: AIRFLOW__CORE__EXECUTE_TASKS_NEW_PYTHON_INTERPRETER +# +execute_tasks_new_python_interpreter = False + +# Secret key to save connection passwords in the db +# +# Variable: AIRFLOW__CORE__FERNET_KEY +# +fernet_key = + +# Whether to disable pickling dags +# +# Variable: AIRFLOW__CORE__DONOT_PICKLE +# +donot_pickle = True + +# How long before timing out a python file import +# +# Variable: AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT +# +dagbag_import_timeout = 30.0 + +# Should a traceback be shown in the UI for dagbag import errors, +# instead of just the exception message +# +# Variable: AIRFLOW__CORE__DAGBAG_IMPORT_ERROR_TRACEBACKS +# +dagbag_import_error_tracebacks = True + +# If tracebacks are shown, how many entries from the traceback should be shown +# +# Variable: AIRFLOW__CORE__DAGBAG_IMPORT_ERROR_TRACEBACK_DEPTH +# +dagbag_import_error_traceback_depth = 2 + +# How long before timing out a DagFileProcessor, which processes a dag file +# +# Variable: AIRFLOW__CORE__DAG_FILE_PROCESSOR_TIMEOUT +# +dag_file_processor_timeout = 50 + +# The class to use for running task instances in a subprocess. +# Choices include StandardTaskRunner, CgroupTaskRunner or the full import path to the class +# when using a custom task runner. +# +# Variable: AIRFLOW__CORE__TASK_RUNNER +# +task_runner = StandardTaskRunner + +# If set, tasks without a ``run_as_user`` argument will be run with this user +# Can be used to de-elevate a sudo user running Airflow when executing tasks +# +# Variable: AIRFLOW__CORE__DEFAULT_IMPERSONATION +# +default_impersonation = + +# What security module to use (for example kerberos) +# +# Variable: AIRFLOW__CORE__SECURITY +# +security = + +# Turn unit test mode on (overwrites many configuration options with test +# values at runtime) +# +# Variable: AIRFLOW__CORE__UNIT_TEST_MODE +# +unit_test_mode = False + +# Whether to enable pickling for xcom (note that this is insecure and allows for +# RCE exploits). +# +# Variable: AIRFLOW__CORE__ENABLE_XCOM_PICKLING +# +enable_xcom_pickling = False + +# What classes can be imported during deserialization. This is a multi line value. +# The individual items will be parsed as a pattern to a glob function. +# Python built-in classes (like dict) are always allowed. +# +# Variable: AIRFLOW__CORE__ALLOWED_DESERIALIZATION_CLASSES +# +allowed_deserialization_classes = airflow.* + +# What classes can be imported during deserialization. This is a multi line value. +# The individual items will be parsed as regexp patterns. +# This is a secondary option to ``[core] allowed_deserialization_classes``. +# +# Variable: AIRFLOW__CORE__ALLOWED_DESERIALIZATION_CLASSES_REGEXP +# +allowed_deserialization_classes_regexp = + +# When a task is killed forcefully, this is the amount of time in seconds that +# it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED +# +# Variable: AIRFLOW__CORE__KILLED_TASK_CLEANUP_TIME +# +killed_task_cleanup_time = 60 + +# Whether to override params with dag_run.conf. If you pass some key-value pairs +# through ``airflow dags backfill -c`` or +# ``airflow dags trigger -c``, the key-value pairs will override the existing ones in params. +# +# Variable: AIRFLOW__CORE__DAG_RUN_CONF_OVERRIDES_PARAMS +# +dag_run_conf_overrides_params = True + +# If enabled, Airflow will only scan files containing both ``DAG`` and ``airflow`` (case-insensitive). +# +# Variable: AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE +# +dag_discovery_safe_mode = True + +# The pattern syntax used in the +# `.airflowignore +# `__ +# files in the DAG directories. Valid values are ``regexp`` or ``glob``. +# +# Variable: AIRFLOW__CORE__DAG_IGNORE_FILE_SYNTAX +# +dag_ignore_file_syntax = regexp + +# The number of retries each task is going to have by default. Can be overridden at dag or task level. +# +# Variable: AIRFLOW__CORE__DEFAULT_TASK_RETRIES +# +default_task_retries = 0 + +# The number of seconds each task is going to wait by default between retries. Can be overridden at +# dag or task level. +# +# Variable: AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY +# +default_task_retry_delay = 300 + +# The maximum delay (in seconds) each task is going to wait by default between retries. +# This is a global setting and cannot be overridden at task or DAG level. +# +# Variable: AIRFLOW__CORE__MAX_TASK_RETRY_DELAY +# +max_task_retry_delay = 86400 + +# The weighting method used for the effective total priority weight of the task +# +# Variable: AIRFLOW__CORE__DEFAULT_TASK_WEIGHT_RULE +# +default_task_weight_rule = downstream + +# The default task execution_timeout value for the operators. Expected an integer value to +# be passed into timedelta as seconds. If not specified, then the value is considered as None, +# meaning that the operators are never timed out by default. +# +# Variable: AIRFLOW__CORE__DEFAULT_TASK_EXECUTION_TIMEOUT +# +default_task_execution_timeout = + +# Updating serialized DAG can not be faster than a minimum interval to reduce database write rate. +# +# Variable: AIRFLOW__CORE__MIN_SERIALIZED_DAG_UPDATE_INTERVAL +# +min_serialized_dag_update_interval = 30 + +# If ``True``, serialized DAGs are compressed before writing to DB. +# +# .. note:: +# +# This will disable the DAG dependencies view +# +# Variable: AIRFLOW__CORE__COMPRESS_SERIALIZED_DAGS +# +compress_serialized_dags = False + +# Fetching serialized DAG can not be faster than a minimum interval to reduce database +# read rate. This config controls when your DAGs are updated in the Webserver +# +# Variable: AIRFLOW__CORE__MIN_SERIALIZED_DAG_FETCH_INTERVAL +# +min_serialized_dag_fetch_interval = 10 + +# Maximum number of Rendered Task Instance Fields (Template Fields) per task to store +# in the Database. +# All the template_fields for each of Task Instance are stored in the Database. +# Keeping this number small may cause an error when you try to view ``Rendered`` tab in +# TaskInstance view for older tasks. +# +# Variable: AIRFLOW__CORE__MAX_NUM_RENDERED_TI_FIELDS_PER_TASK +# +max_num_rendered_ti_fields_per_task = 30 + +# On each dagrun check against defined SLAs +# +# Variable: AIRFLOW__CORE__CHECK_SLAS +# +check_slas = True + +# Path to custom XCom class that will be used to store and resolve operators results +# +# Example: xcom_backend = path.to.CustomXCom +# +# Variable: AIRFLOW__CORE__XCOM_BACKEND +# +xcom_backend = airflow.models.xcom.BaseXCom + +# By default Airflow plugins are lazily-loaded (only loaded when required). Set it to ``False``, +# if you want to load plugins whenever 'airflow' is invoked via cli or loaded from module. +# +# Variable: AIRFLOW__CORE__LAZY_LOAD_PLUGINS +# +lazy_load_plugins = True + +# By default Airflow providers are lazily-discovered (discovery and imports happen only when required). +# Set it to ``False``, if you want to discover providers whenever 'airflow' is invoked via cli or +# loaded from module. +# +# Variable: AIRFLOW__CORE__LAZY_DISCOVER_PROVIDERS +# +lazy_discover_providers = True + +# Hide sensitive **Variables** or **Connection extra json keys** from UI +# and task logs when set to ``True`` +# +# .. note:: +# +# Connection passwords are always hidden in logs +# +# Variable: AIRFLOW__CORE__HIDE_SENSITIVE_VAR_CONN_FIELDS +# +hide_sensitive_var_conn_fields = True + +# A comma-separated list of extra sensitive keywords to look for in variables names or connection's +# extra JSON. +# +# Variable: AIRFLOW__CORE__SENSITIVE_VAR_CONN_NAMES +# +sensitive_var_conn_names = + +# Task Slot counts for ``default_pool``. This setting would not have any effect in an existing +# deployment where the ``default_pool`` is already created. For existing deployments, users can +# change the number of slots using Webserver, API or the CLI +# +# Variable: AIRFLOW__CORE__DEFAULT_POOL_TASK_SLOT_COUNT +# +default_pool_task_slot_count = 128 + +# The maximum list/dict length an XCom can push to trigger task mapping. If the pushed list/dict has a +# length exceeding this value, the task pushing the XCom will be failed automatically to prevent the +# mapped tasks from clogging the scheduler. +# +# Variable: AIRFLOW__CORE__MAX_MAP_LENGTH +# +max_map_length = 1024 + +# The default umask to use for process when run in daemon mode (scheduler, worker, etc.) +# +# This controls the file-creation mode mask which determines the initial value of file permission bits +# for newly created files. +# +# This value is treated as an octal-integer. +# +# Variable: AIRFLOW__CORE__DAEMON_UMASK +# +daemon_umask = 0o077 + +# Class to use as dataset manager. +# +# Example: dataset_manager_class = airflow.datasets.manager.DatasetManager +# +# Variable: AIRFLOW__CORE__DATASET_MANAGER_CLASS +# +# dataset_manager_class = + +# Kwargs to supply to dataset manager. +# +# Example: dataset_manager_kwargs = {"some_param": "some_value"} +# +# Variable: AIRFLOW__CORE__DATASET_MANAGER_KWARGS +# +# dataset_manager_kwargs = + +# Dataset URI validation should raise an exception if it is not compliant with AIP-60. +# By default this configuration is false, meaning that Airflow 2.x only warns the user. +# In Airflow 3, this configuration will be enabled by default. +# +# Variable: AIRFLOW__CORE__STRICT_DATASET_URI_VALIDATION +# +strict_dataset_uri_validation = False + +# (experimental) Whether components should use Airflow Internal API for DB connectivity. +# +# Variable: AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION +# +database_access_isolation = False + +# (experimental) Airflow Internal API url. +# Only used if ``[core] database_access_isolation`` is ``True``. +# +# Example: internal_api_url = http://localhost:8080 +# +# Variable: AIRFLOW__CORE__INTERNAL_API_URL +# +# internal_api_url = + +# The ability to allow testing connections across Airflow UI, API and CLI. +# Supported options: ``Disabled``, ``Enabled``, ``Hidden``. Default: Disabled +# Disabled - Disables the test connection functionality and disables the Test Connection button in UI. +# Enabled - Enables the test connection functionality and shows the Test Connection button in UI. +# Hidden - Disables the test connection functionality and hides the Test Connection button in UI. +# Before setting this to Enabled, make sure that you review the users who are able to add/edit +# connections and ensure they are trusted. Connection testing can be done maliciously leading to +# undesired and insecure outcomes. +# See `Airflow Security Model: Capabilities of authenticated UI users +# `__ +# for more details. +# +# Variable: AIRFLOW__CORE__TEST_CONNECTION +# +test_connection = Disabled + +# The maximum length of the rendered template field. If the value to be stored in the +# rendered template field exceeds this size, it's redacted. +# +# Variable: AIRFLOW__CORE__MAX_TEMPLATED_FIELD_LENGTH +# +max_templated_field_length = 4096 + +[database] +# Path to the ``alembic.ini`` file. You can either provide the file path relative +# to the Airflow home directory or the absolute path if it is located elsewhere. +# +# Variable: AIRFLOW__DATABASE__ALEMBIC_INI_FILE_PATH +# +alembic_ini_file_path = alembic.ini + +# The SQLAlchemy connection string to the metadata database. +# SQLAlchemy supports many different database engines. +# See: `Set up a Database Backend: Database URI +# `__ +# for more details. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_CONN +# +sql_alchemy_conn = sqlite:////opt/airflow/airflow.db + +# Extra engine specific keyword args passed to SQLAlchemy's create_engine, as a JSON-encoded value +# +# Example: sql_alchemy_engine_args = {"arg1": true} +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_ENGINE_ARGS +# +# sql_alchemy_engine_args = + +# The encoding for the databases +# +# Variable: AIRFLOW__DATABASE__SQL_ENGINE_ENCODING +# +sql_engine_encoding = utf-8 + +# Collation for ``dag_id``, ``task_id``, ``key``, ``external_executor_id`` columns +# in case they have different encoding. +# By default this collation is the same as the database collation, however for ``mysql`` and ``mariadb`` +# the default is ``utf8mb3_bin`` so that the index sizes of our index keys will not exceed +# the maximum size of allowed index when collation is set to ``utf8mb4`` variant, see +# `GitHub Issue Comment `__ +# for more details. +# +# Variable: AIRFLOW__DATABASE__SQL_ENGINE_COLLATION_FOR_IDS +# +# sql_engine_collation_for_ids = + +# If SQLAlchemy should pool database connections. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_POOL_ENABLED +# +sql_alchemy_pool_enabled = True + +# The SQLAlchemy pool size is the maximum number of database connections +# in the pool. 0 indicates no limit. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_POOL_SIZE +# +sql_alchemy_pool_size = 5 + +# The maximum overflow size of the pool. +# When the number of checked-out connections reaches the size set in pool_size, +# additional connections will be returned up to this limit. +# When those additional connections are returned to the pool, they are disconnected and discarded. +# It follows then that the total number of simultaneous connections the pool will allow +# is **pool_size** + **max_overflow**, +# and the total number of "sleeping" connections the pool will allow is pool_size. +# max_overflow can be set to ``-1`` to indicate no overflow limit; +# no limit will be placed on the total number of concurrent connections. Defaults to ``10``. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_OVERFLOW +# +sql_alchemy_max_overflow = 10 + +# The SQLAlchemy pool recycle is the number of seconds a connection +# can be idle in the pool before it is invalidated. This config does +# not apply to sqlite. If the number of DB connections is ever exceeded, +# a lower config value will allow the system to recover faster. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_POOL_RECYCLE +# +sql_alchemy_pool_recycle = 1800 + +# Check connection at the start of each connection pool checkout. +# Typically, this is a simple statement like "SELECT 1". +# See `SQLAlchemy Pooling: Disconnect Handling - Pessimistic +# `__ +# for more details. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_POOL_PRE_PING +# +sql_alchemy_pool_pre_ping = True + +# The schema to use for the metadata database. +# SQLAlchemy supports databases with the concept of multiple schemas. +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_SCHEMA +# +sql_alchemy_schema = + +# Import path for connect args in SQLAlchemy. Defaults to an empty dict. +# This is useful when you want to configure db engine args that SQLAlchemy won't parse +# in connection string. This can be set by passing a dictionary containing the create engine parameters. +# For more details about passing create engine parameters (keepalives variables, timeout etc) +# in Postgres DB Backend see `Setting up a PostgreSQL Database +# `__ +# e.g ``connect_args={"timeout":30}`` can be defined in ``airflow_local_settings.py`` and +# can be imported as shown below +# +# Example: sql_alchemy_connect_args = airflow_local_settings.connect_args +# +# Variable: AIRFLOW__DATABASE__SQL_ALCHEMY_CONNECT_ARGS +# +# sql_alchemy_connect_args = + +# Whether to load the default connections that ship with Airflow when ``airflow db init`` is called. +# It's good to get started, but you probably want to set this to ``False`` in a production environment. +# +# Variable: AIRFLOW__DATABASE__LOAD_DEFAULT_CONNECTIONS +# +load_default_connections = True + +# Number of times the code should be retried in case of DB Operational Errors. +# Not all transactions will be retried as it can cause undesired state. +# Currently it is only used in ``DagFileProcessor.process_file`` to retry ``dagbag.sync_to_db``. +# +# Variable: AIRFLOW__DATABASE__MAX_DB_RETRIES +# +max_db_retries = 3 + +# Whether to run alembic migrations during Airflow start up. Sometimes this operation can be expensive, +# and the users can assert the correct version through other means (e.g. through a Helm chart). +# Accepts ``True`` or ``False``. +# +# Variable: AIRFLOW__DATABASE__CHECK_MIGRATIONS +# +check_migrations = True + +[logging] +# The folder where airflow should store its log files. +# This path must be absolute. +# There are a few existing configurations that assume this is set to the default. +# If you choose to override this you may need to update the +# ``[logging] dag_processor_manager_log_location`` and +# ``[logging] child_process_log_directory settings`` as well. +# +# Variable: AIRFLOW__LOGGING__BASE_LOG_FOLDER +# +base_log_folder = /opt/airflow/logs +processor_log_folder = /opt/airflow/logs/scheduler + +# Airflow can store logs remotely in AWS S3, Google Cloud Storage or Elastic Search. +# Set this to ``True`` if you want to enable remote logging. +# +# Variable: AIRFLOW__LOGGING__REMOTE_LOGGING +# +remote_logging = False + +# Users must supply an Airflow connection id that provides access to the storage +# location. Depending on your remote logging service, this may only be used for +# reading logs, not writing them. +# +# Variable: AIRFLOW__LOGGING__REMOTE_LOG_CONN_ID +# +remote_log_conn_id = + +# Whether the local log files for GCS, S3, WASB and OSS remote logging should be deleted after +# they are uploaded to the remote location. +# +# Variable: AIRFLOW__LOGGING__DELETE_LOCAL_LOGS +# +delete_local_logs = False + +# Path to Google Credential JSON file. If omitted, authorization based on `the Application Default +# Credentials +# `__ will +# be used. +# +# Variable: AIRFLOW__LOGGING__GOOGLE_KEY_PATH +# +google_key_path = + +# Storage bucket URL for remote logging +# S3 buckets should start with **s3://** +# Cloudwatch log groups should start with **cloudwatch://** +# GCS buckets should start with **gs://** +# WASB buckets should start with **wasb** just to help Airflow select correct handler +# Stackdriver logs should start with **stackdriver://** +# +# Variable: AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER +# +remote_base_log_folder = + +# The remote_task_handler_kwargs param is loaded into a dictionary and passed to the ``__init__`` +# of remote task handler and it overrides the values provided by Airflow config. For example if you set +# ``delete_local_logs=False`` and you provide ``{"delete_local_copy": true}``, then the local +# log files will be deleted after they are uploaded to remote location. +# +# Example: remote_task_handler_kwargs = {"delete_local_copy": true} +# +# Variable: AIRFLOW__LOGGING__REMOTE_TASK_HANDLER_KWARGS +# +remote_task_handler_kwargs = + +# Use server-side encryption for logs stored in S3 +# +# Variable: AIRFLOW__LOGGING__ENCRYPT_S3_LOGS +# +encrypt_s3_logs = False + +# Logging level. +# +# Supported values: ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``. +# +# Variable: AIRFLOW__LOGGING__LOGGING_LEVEL +# +logging_level = INFO + +# Logging level for celery. If not set, it uses the value of logging_level +# +# Supported values: ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``. +# +# Variable: AIRFLOW__LOGGING__CELERY_LOGGING_LEVEL +# +celery_logging_level = + +# Logging level for Flask-appbuilder UI. +# +# Supported values: ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``. +# +# Variable: AIRFLOW__LOGGING__FAB_LOGGING_LEVEL +# +fab_logging_level = WARNING + +# Logging class +# Specify the class that will specify the logging configuration +# This class has to be on the python classpath +# +# Example: logging_config_class = my.path.default_local_settings.LOGGING_CONFIG +# +# Variable: AIRFLOW__LOGGING__LOGGING_CONFIG_CLASS +# +logging_config_class = + +# Flag to enable/disable Colored logs in Console +# Colour the logs when the controlling terminal is a TTY. +# +# Variable: AIRFLOW__LOGGING__COLORED_CONSOLE_LOG +# +colored_console_log = True + +# Log format for when Colored logs is enabled +# +# Variable: AIRFLOW__LOGGING__COLORED_LOG_FORMAT +# +colored_log_format = [%%(blue)s%%(asctime)s%%(reset)s] {%%(blue)s%%(filename)s:%%(reset)s%%(lineno)d} %%(log_color)s%%(levelname)s%%(reset)s - %%(log_color)s%%(message)s%%(reset)s + +# Specifies the class utilized by Airflow to implement colored logging +# +# Variable: AIRFLOW__LOGGING__COLORED_FORMATTER_CLASS +# +colored_formatter_class = airflow.utils.log.colored_log.CustomTTYColoredFormatter + +# Format of Log line +# +# Variable: AIRFLOW__LOGGING__LOG_FORMAT +# +log_format = [%%(asctime)s] {%%(filename)s:%%(lineno)d} %%(levelname)s - %%(message)s + +# Defines the format of log messages for simple logging configuration +# +# Variable: AIRFLOW__LOGGING__SIMPLE_LOG_FORMAT +# +simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s + +# Where to send dag parser logs. If "file", logs are sent to log files defined by child_process_log_directory. +# +# Variable: AIRFLOW__LOGGING__DAG_PROCESSOR_LOG_TARGET +# +dag_processor_log_target = file + +# Format of Dag Processor Log line +# +# Variable: AIRFLOW__LOGGING__DAG_PROCESSOR_LOG_FORMAT +# +dag_processor_log_format = [%%(asctime)s] [SOURCE:DAG_PROCESSOR] {%%(filename)s:%%(lineno)d} %%(levelname)s - %%(message)s + +# Determines the formatter class used by Airflow for structuring its log messages +# The default formatter class is timezone-aware, which means that timestamps attached to log entries +# will be adjusted to reflect the local timezone of the Airflow instance +# +# Variable: AIRFLOW__LOGGING__LOG_FORMATTER_CLASS +# +log_formatter_class = airflow.utils.log.timezone_aware.TimezoneAware + +# An import path to a function to add adaptations of each secret added with +# ``airflow.utils.log.secrets_masker.mask_secret`` to be masked in log messages. The given function +# is expected to require a single parameter: the secret to be adapted. It may return a +# single adaptation of the secret or an iterable of adaptations to each be masked as secrets. +# The original secret will be masked as well as any adaptations returned. +# +# Example: secret_mask_adapter = urllib.parse.quote +# +# Variable: AIRFLOW__LOGGING__SECRET_MASK_ADAPTER +# +secret_mask_adapter = + +# Specify prefix pattern like mentioned below with stream handler ``TaskHandlerWithCustomFormatter`` +# +# Example: task_log_prefix_template = {{ti.dag_id}}-{{ti.task_id}}-{{execution_date}}-{{ti.try_number}} +# +# Variable: AIRFLOW__LOGGING__TASK_LOG_PREFIX_TEMPLATE +# +task_log_prefix_template = + +# Formatting for how airflow generates file names/paths for each task run. +# +# Variable: AIRFLOW__LOGGING__LOG_FILENAME_TEMPLATE +# +log_filename_template = dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/{%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% endif %%}attempt={{ try_number }}.log + +# Formatting for how airflow generates file names for log +# +# Variable: AIRFLOW__LOGGING__LOG_PROCESSOR_FILENAME_TEMPLATE +# +log_processor_filename_template = {{ filename }}.log + +# Full path of dag_processor_manager logfile. +# +# Variable: AIRFLOW__LOGGING__DAG_PROCESSOR_MANAGER_LOG_LOCATION +# +dag_processor_manager_log_location = /opt/airflow/logs/dag_processor_manager/dag_processor_manager.log + +# Whether DAG processor manager will write logs to stdout +# +# Variable: AIRFLOW__LOGGING__DAG_PROCESSOR_MANAGER_LOG_STDOUT +# +dag_processor_manager_log_stdout = False + +# Name of handler to read task instance logs. +# Defaults to use ``task`` handler. +# +# Variable: AIRFLOW__LOGGING__TASK_LOG_READER +# +task_log_reader = task + +# A comma\-separated list of third-party logger names that will be configured to print messages to +# consoles\. +# +# Example: extra_logger_names = connexion,sqlalchemy +# +# Variable: AIRFLOW__LOGGING__EXTRA_LOGGER_NAMES +# +extra_logger_names = + +# When you start an Airflow worker, Airflow starts a tiny web server +# subprocess to serve the workers local log files to the airflow main +# web server, who then builds pages and sends them to users. This defines +# the port on which the logs are served. It needs to be unused, and open +# visible from the main web server to connect into the workers. +# +# Variable: AIRFLOW__LOGGING__WORKER_LOG_SERVER_PORT +# +worker_log_server_port = 8793 + +# Port to serve logs from for triggerer. +# See ``[logging] worker_log_server_port`` description for more info. +# +# Variable: AIRFLOW__LOGGING__TRIGGER_LOG_SERVER_PORT +# +trigger_log_server_port = 8794 + +# We must parse timestamps to interleave logs between trigger and task. To do so, +# we need to parse timestamps in log files. In case your log format is non-standard, +# you may provide import path to callable which takes a string log line and returns +# the timestamp (datetime.datetime compatible). +# +# Example: interleave_timestamp_parser = path.to.my_func +# +# Variable: AIRFLOW__LOGGING__INTERLEAVE_TIMESTAMP_PARSER +# +# interleave_timestamp_parser = + +# Permissions in the form or of octal string as understood by chmod. The permissions are important +# when you use impersonation, when logs are written by a different user than airflow. The most secure +# way of configuring it in this case is to add both users to the same group and make it the default +# group of both users. Group-writeable logs are default in airflow, but you might decide that you are +# OK with having the logs other-writeable, in which case you should set it to ``0o777``. You might +# decide to add more security if you do not use impersonation and change it to ``0o755`` to make it +# only owner-writeable. You can also make it just readable only for owner by changing it to ``0o700`` +# if all the access (read/write) for your logs happens from the same user. +# +# Example: file_task_handler_new_folder_permissions = 0o775 +# +# Variable: AIRFLOW__LOGGING__FILE_TASK_HANDLER_NEW_FOLDER_PERMISSIONS +# +file_task_handler_new_folder_permissions = 0o775 + +# Permissions in the form or of octal string as understood by chmod. The permissions are important +# when you use impersonation, when logs are written by a different user than airflow. The most secure +# way of configuring it in this case is to add both users to the same group and make it the default +# group of both users. Group-writeable logs are default in airflow, but you might decide that you are +# OK with having the logs other-writeable, in which case you should set it to ``0o666``. You might +# decide to add more security if you do not use impersonation and change it to ``0o644`` to make it +# only owner-writeable. You can also make it just readable only for owner by changing it to ``0o600`` +# if all the access (read/write) for your logs happens from the same user. +# +# Example: file_task_handler_new_file_permissions = 0o664 +# +# Variable: AIRFLOW__LOGGING__FILE_TASK_HANDLER_NEW_FILE_PERMISSIONS +# +file_task_handler_new_file_permissions = 0o664 + +# By default Celery sends all logs into stderr. +# If enabled any previous logging handlers will get *removed*. +# With this option AirFlow will create new handlers +# and send low level logs like INFO and WARNING to stdout, +# while sending higher severity logs to stderr. +# +# Variable: AIRFLOW__LOGGING__CELERY_STDOUT_STDERR_SEPARATION +# +celery_stdout_stderr_separation = False + +# If enabled, Airflow may ship messages to task logs from outside the task run context, e.g. from +# the scheduler, executor, or callback execution context. This can help in circumstances such as +# when there's something blocking the execution of the task and ordinarily there may be no task +# logs at all. +# This is set to ``True`` by default. If you encounter issues with this feature +# (e.g. scheduler performance issues) it can be disabled. +# +# Variable: AIRFLOW__LOGGING__ENABLE_TASK_CONTEXT_LOGGER +# +enable_task_context_logger = True + +[metrics] +# `StatsD `__ integration settings. + +# If true, ``[metrics] metrics_allow_list`` and ``[metrics] metrics_block_list`` will use +# regex pattern matching anywhere within the metric name instead of only prefix matching +# at the start of the name. +# +# Variable: AIRFLOW__METRICS__METRICS_USE_PATTERN_MATCH +# +metrics_use_pattern_match = False + +# Configure an allow list (comma separated string) to send only certain metrics. +# If ``[metrics] metrics_use_pattern_match`` is ``false``, match only the exact metric name prefix. +# If ``[metrics] metrics_use_pattern_match`` is ``true``, provide regex patterns to match. +# +# Example: metrics_allow_list = "scheduler,executor,dagrun,pool,triggerer,celery" or "^scheduler,^executor,heartbeat|timeout" +# +# Variable: AIRFLOW__METRICS__METRICS_ALLOW_LIST +# +metrics_allow_list = + +# Configure a block list (comma separated string) to block certain metrics from being emitted. +# If ``[metrics] metrics_allow_list`` and ``[metrics] metrics_block_list`` are both configured, +# ``[metrics] metrics_block_list`` is ignored. +# +# If ``[metrics] metrics_use_pattern_match`` is ``false``, match only the exact metric name prefix. +# +# If ``[metrics] metrics_use_pattern_match`` is ``true``, provide regex patterns to match. +# +# Example: metrics_block_list = "scheduler,executor,dagrun,pool,triggerer,celery" or "^scheduler,^executor,heartbeat|timeout" +# +# Variable: AIRFLOW__METRICS__METRICS_BLOCK_LIST +# +metrics_block_list = + +# Enables sending metrics to StatsD. +# +# Variable: AIRFLOW__METRICS__STATSD_ON +# +statsd_on = False + +# Specifies the host address where the StatsD daemon (or server) is running +# +# Variable: AIRFLOW__METRICS__STATSD_HOST +# +statsd_host = localhost + +# Specifies the port on which the StatsD daemon (or server) is listening to +# +# Variable: AIRFLOW__METRICS__STATSD_PORT +# +statsd_port = 8125 + +# Defines the namespace for all metrics sent from Airflow to StatsD +# +# Variable: AIRFLOW__METRICS__STATSD_PREFIX +# +statsd_prefix = airflow + +# A function that validate the StatsD stat name, apply changes to the stat name if necessary and return +# the transformed stat name. +# +# The function should have the following signature +# +# .. code-block:: python +# +# def func_name(stat_name: str) -> str: ... +# +# Variable: AIRFLOW__METRICS__STAT_NAME_HANDLER +# +stat_name_handler = + +# To enable datadog integration to send airflow metrics. +# +# Variable: AIRFLOW__METRICS__STATSD_DATADOG_ENABLED +# +statsd_datadog_enabled = False + +# List of datadog tags attached to all metrics(e.g: ``key1:value1,key2:value2``) +# +# Variable: AIRFLOW__METRICS__STATSD_DATADOG_TAGS +# +statsd_datadog_tags = + +# Set to ``False`` to disable metadata tags for some of the emitted metrics +# +# Variable: AIRFLOW__METRICS__STATSD_DATADOG_METRICS_TAGS +# +statsd_datadog_metrics_tags = True + +# If you want to utilise your own custom StatsD client set the relevant +# module path below. +# Note: The module path must exist on your +# `PYTHONPATH ` +# for Airflow to pick it up +# +# Variable: AIRFLOW__METRICS__STATSD_CUSTOM_CLIENT_PATH +# +# statsd_custom_client_path = + +# If you want to avoid sending all the available metrics tags to StatsD, +# you can configure a block list of prefixes (comma separated) to filter out metric tags +# that start with the elements of the list (e.g: ``job_id,run_id``) +# +# Example: statsd_disabled_tags = job_id,run_id,dag_id,task_id +# +# Variable: AIRFLOW__METRICS__STATSD_DISABLED_TAGS +# +statsd_disabled_tags = job_id,run_id + +# To enable sending Airflow metrics with StatsD-Influxdb tagging convention. +# +# Variable: AIRFLOW__METRICS__STATSD_INFLUXDB_ENABLED +# +statsd_influxdb_enabled = False + +# Enables sending metrics to OpenTelemetry. +# +# Variable: AIRFLOW__METRICS__OTEL_ON +# +otel_on = False + +# Specifies the hostname or IP address of the OpenTelemetry Collector to which Airflow sends +# metrics and traces. +# +# Variable: AIRFLOW__METRICS__OTEL_HOST +# +otel_host = localhost + +# Specifies the port of the OpenTelemetry Collector that is listening to. +# +# Variable: AIRFLOW__METRICS__OTEL_PORT +# +otel_port = 8889 + +# The prefix for the Airflow metrics. +# +# Variable: AIRFLOW__METRICS__OTEL_PREFIX +# +otel_prefix = airflow + +# Defines the interval, in milliseconds, at which Airflow sends batches of metrics and traces +# to the configured OpenTelemetry Collector. +# +# Variable: AIRFLOW__METRICS__OTEL_INTERVAL_MILLISECONDS +# +otel_interval_milliseconds = 60000 + +# If ``True``, all metrics are also emitted to the console. Defaults to ``False``. +# +# Variable: AIRFLOW__METRICS__OTEL_DEBUGGING_ON +# +otel_debugging_on = False + +# If ``True``, SSL will be enabled. Defaults to ``False``. +# To establish an HTTPS connection to the OpenTelemetry collector, +# you need to configure the SSL certificate and key within the OpenTelemetry collector's +# ``config.yml`` file. +# +# Variable: AIRFLOW__METRICS__OTEL_SSL_ACTIVE +# +otel_ssl_active = False + +[secrets] +# Full class name of secrets backend to enable (will precede env vars and metastore in search path) +# +# Example: backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend +# +# Variable: AIRFLOW__SECRETS__BACKEND +# +backend = + +# The backend_kwargs param is loaded into a dictionary and passed to ``__init__`` +# of secrets backend class. See documentation for the secrets backend you are using. +# JSON is expected. +# +# Example for AWS Systems Manager ParameterStore: +# ``{"connections_prefix": "/airflow/connections", "profile_name": "default"}`` +# +# Variable: AIRFLOW__SECRETS__BACKEND_KWARGS +# +backend_kwargs = + +# .. note:: |experimental| +# +# Enables local caching of Variables, when parsing DAGs only. +# Using this option can make dag parsing faster if Variables are used in top level code, at the expense +# of longer propagation time for changes. +# Please note that this cache concerns only the DAG parsing step. There is no caching in place when DAG +# tasks are run. +# +# Variable: AIRFLOW__SECRETS__USE_CACHE +# +use_cache = False + +# .. note:: |experimental| +# +# When the cache is enabled, this is the duration for which we consider an entry in the cache to be +# valid. Entries are refreshed if they are older than this many seconds. +# It means that when the cache is enabled, this is the maximum amount of time you need to wait to see a +# Variable change take effect. +# +# Variable: AIRFLOW__SECRETS__CACHE_TTL_SECONDS +# +cache_ttl_seconds = 900 + +[cli] +# In what way should the cli access the API. The LocalClient will use the +# database directly, while the json_client will use the api running on the +# webserver +# +# Variable: AIRFLOW__CLI__API_CLIENT +# +api_client = airflow.api.client.local_client + +# If you set web_server_url_prefix, do NOT forget to append it here, ex: +# ``endpoint_url = http://localhost:8080/myroot`` +# So api will look like: ``http://localhost:8080/myroot/api/experimental/...`` +# +# Variable: AIRFLOW__CLI__ENDPOINT_URL +# +endpoint_url = http://localhost:8080 + +[debug] +# Used only with ``DebugExecutor``. If set to ``True`` DAG will fail with first +# failed task. Helpful for debugging purposes. +# +# Variable: AIRFLOW__DEBUG__FAIL_FAST +# +fail_fast = False + +[api] +# Enables the deprecated experimental API. Please note that these API endpoints do not have +# access control. An authenticated user has full access. +# +# .. warning:: +# +# This `Experimental REST API +# `__ is +# deprecated since version 2.0. Please consider using +# `the Stable REST API +# `__. +# For more information on migration, see +# `RELEASE_NOTES.rst `_ +# +# Variable: AIRFLOW__API__ENABLE_EXPERIMENTAL_API +# +enable_experimental_api = False + +# Comma separated list of auth backends to authenticate users of the API. See +# `Security: API +# `__ for possible values. +# ("airflow.api.auth.backend.default" allows all requests for historic reasons) +# +# Variable: AIRFLOW__API__AUTH_BACKENDS +# +auth_backends = airflow.api.auth.backend.session + +# Used to set the maximum page limit for API requests. If limit passed as param +# is greater than maximum page limit, it will be ignored and maximum page limit value +# will be set as the limit +# +# Variable: AIRFLOW__API__MAXIMUM_PAGE_LIMIT +# +maximum_page_limit = 100 + +# Used to set the default page limit when limit param is zero or not provided in API +# requests. Otherwise if positive integer is passed in the API requests as limit, the +# smallest number of user given limit or maximum page limit is taken as limit. +# +# Variable: AIRFLOW__API__FALLBACK_PAGE_LIMIT +# +fallback_page_limit = 100 + +# The intended audience for JWT token credentials used for authorization. This value must match on the client and server sides. If empty, audience will not be tested. +# +# Example: google_oauth2_audience = project-id-random-value.apps.googleusercontent.com +# +# Variable: AIRFLOW__API__GOOGLE_OAUTH2_AUDIENCE +# +google_oauth2_audience = + +# Path to Google Cloud Service Account key file (JSON). If omitted, authorization based on +# `the Application Default Credentials +# `__ will +# be used. +# +# Example: google_key_path = /files/service-account-json +# +# Variable: AIRFLOW__API__GOOGLE_KEY_PATH +# +google_key_path = + +# Used in response to a preflight request to indicate which HTTP +# headers can be used when making the actual request. This header is +# the server side response to the browser's +# Access-Control-Request-Headers header. +# +# Variable: AIRFLOW__API__ACCESS_CONTROL_ALLOW_HEADERS +# +access_control_allow_headers = + +# Specifies the method or methods allowed when accessing the resource. +# +# Variable: AIRFLOW__API__ACCESS_CONTROL_ALLOW_METHODS +# +access_control_allow_methods = + +# Indicates whether the response can be shared with requesting code from the given origins. +# Separate URLs with space. +# +# Variable: AIRFLOW__API__ACCESS_CONTROL_ALLOW_ORIGINS +# +access_control_allow_origins = + +# Indicates whether the **xcomEntries** endpoint supports the **deserialize** +# flag. If set to ``False``, setting this flag in a request would result in a +# 400 Bad Request error. +# +# Variable: AIRFLOW__API__ENABLE_XCOM_DESERIALIZE_SUPPORT +# +enable_xcom_deserialize_support = False + +[lineage] +# what lineage backend to use +# +# Variable: AIRFLOW__LINEAGE__BACKEND +# +backend = + +[operators] +# The default owner assigned to each new operator, unless +# provided explicitly or passed via ``default_args`` +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_OWNER +# +default_owner = airflow + +# The default value of attribute "deferrable" in operators and sensors. +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_DEFERRABLE +# +default_deferrable = false + +# Indicates the default number of CPU units allocated to each operator when no specific CPU request +# is specified in the operator's configuration +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_CPUS +# +default_cpus = 1 + +# Indicates the default number of RAM allocated to each operator when no specific RAM request +# is specified in the operator's configuration +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_RAM +# +default_ram = 512 + +# Indicates the default number of disk storage allocated to each operator when no specific disk request +# is specified in the operator's configuration +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_DISK +# +default_disk = 512 + +# Indicates the default number of GPUs allocated to each operator when no specific GPUs request +# is specified in the operator's configuration +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_GPUS +# +default_gpus = 0 + +# Default queue that tasks get assigned to and that worker listen on. +# +# Variable: AIRFLOW__OPERATORS__DEFAULT_QUEUE +# +default_queue = default + +# Is allowed to pass additional/unused arguments (args, kwargs) to the BaseOperator operator. +# If set to ``False``, an exception will be thrown, +# otherwise only the console message will be displayed. +# +# Variable: AIRFLOW__OPERATORS__ALLOW_ILLEGAL_ARGUMENTS +# +allow_illegal_arguments = False + +[webserver] +# The message displayed when a user attempts to execute actions beyond their authorised privileges. +# +# Variable: AIRFLOW__WEBSERVER__ACCESS_DENIED_MESSAGE +# +access_denied_message = Access is Denied + +# Path of webserver config file used for configuring the webserver parameters +# +# Variable: AIRFLOW__WEBSERVER__CONFIG_FILE +# +config_file = /opt/airflow/webserver_config.py + +# The base url of your website: Airflow cannot guess what domain or CNAME you are using. +# This is used to create links in the Log Url column in the Browse - Task Instances menu, +# as well as in any automated emails sent by Airflow that contain links to your webserver. +# +# Variable: AIRFLOW__WEBSERVER__BASE_URL +# +base_url = http://localhost:8080 + +# Default timezone to display all dates in the UI, can be UTC, system, or +# any IANA timezone string (e.g. **Europe/Amsterdam**). If left empty the +# default value of core/default_timezone will be used +# +# Example: default_ui_timezone = America/New_York +# +# Variable: AIRFLOW__WEBSERVER__DEFAULT_UI_TIMEZONE +# +default_ui_timezone = UTC + +# The ip specified when starting the web server +# +# Variable: AIRFLOW__WEBSERVER__WEB_SERVER_HOST +# +web_server_host = 0.0.0.0 + +# The port on which to run the web server +# +# Variable: AIRFLOW__WEBSERVER__WEB_SERVER_PORT +# +web_server_port = 8080 + +# Paths to the SSL certificate and key for the web server. When both are +# provided SSL will be enabled. This does not change the web server port. +# +# Variable: AIRFLOW__WEBSERVER__WEB_SERVER_SSL_CERT +# +web_server_ssl_cert = + +# Paths to the SSL certificate and key for the web server. When both are +# provided SSL will be enabled. This does not change the web server port. +# +# Variable: AIRFLOW__WEBSERVER__WEB_SERVER_SSL_KEY +# +web_server_ssl_key = + +# The type of backend used to store web session data, can be ``database`` or ``securecookie``. For the +# ``database`` backend, sessions are store in the database and they can be +# managed there (for example when you reset password of the user, all sessions for that user are +# deleted). For the ``securecookie`` backend, sessions are stored in encrypted cookies on the client +# side. The ``securecookie`` mechanism is 'lighter' than database backend, but sessions are not deleted +# when you reset password of the user, which means that other than waiting for expiry time, the only +# way to invalidate all sessions for a user is to change secret_key and restart webserver (which +# also invalidates and logs out all other user's sessions). +# +# When you are using ``database`` backend, make sure to keep your database session table small +# by periodically running ``airflow db clean --table session`` command, especially if you have +# automated API calls that will create a new session for each call rather than reuse the sessions +# stored in browser cookies. +# +# Example: session_backend = securecookie +# +# Variable: AIRFLOW__WEBSERVER__SESSION_BACKEND +# +session_backend = database + +# Number of seconds the webserver waits before killing gunicorn master that doesn't respond +# +# Variable: AIRFLOW__WEBSERVER__WEB_SERVER_MASTER_TIMEOUT +# +web_server_master_timeout = 120 + +# Number of seconds the gunicorn webserver waits before timing out on a worker +# +# Variable: AIRFLOW__WEBSERVER__WEB_SERVER_WORKER_TIMEOUT +# +web_server_worker_timeout = 120 + +# Number of workers to refresh at a time. When set to 0, worker refresh is +# disabled. When nonzero, airflow periodically refreshes webserver workers by +# bringing up new ones and killing old ones. +# +# Variable: AIRFLOW__WEBSERVER__WORKER_REFRESH_BATCH_SIZE +# +worker_refresh_batch_size = 1 + +# Number of seconds to wait before refreshing a batch of workers. +# +# Variable: AIRFLOW__WEBSERVER__WORKER_REFRESH_INTERVAL +# +worker_refresh_interval = 6000 + +# If set to ``True``, Airflow will track files in plugins_folder directory. When it detects changes, +# then reload the gunicorn. If set to ``True``, gunicorn starts without preloading, which is slower, +# uses more memory, and may cause race conditions. Avoid setting this to ``True`` in production. +# +# Variable: AIRFLOW__WEBSERVER__RELOAD_ON_PLUGIN_CHANGE +# +reload_on_plugin_change = False + +# Secret key used to run your flask app. It should be as random as possible. However, when running +# more than 1 instances of webserver, make sure all of them use the same ``secret_key`` otherwise +# one of them will error with "CSRF session token is missing". +# The webserver key is also used to authorize requests to Celery workers when logs are retrieved. +# The token generated using the secret key has a short expiry time though - make sure that time on +# ALL the machines that you run airflow components on is synchronized (for example using ntpd) +# otherwise you might get "forbidden" errors when the logs are accessed. +# +# Variable: AIRFLOW__WEBSERVER__SECRET_KEY +# +secret_key = xIcs2bnO5KyoXMwIyf88+g== + +# Number of workers to run the Gunicorn web server +# +# Variable: AIRFLOW__WEBSERVER__WORKERS +# +workers = 4 + +# The worker class gunicorn should use. Choices include +# ``sync`` (default), ``eventlet``, ``gevent``. +# +# .. warning:: +# +# When using ``gevent`` you might also want to set the ``_AIRFLOW_PATCH_GEVENT`` +# environment variable to ``"1"`` to make sure gevent patching is done as early as possible. +# +# See related Issues / PRs for more details: +# +# * https://github.com/benoitc/gunicorn/issues/2796 +# * https://github.com/apache/airflow/issues/8212 +# * https://github.com/apache/airflow/pull/28283 +# +# Variable: AIRFLOW__WEBSERVER__WORKER_CLASS +# +worker_class = sync + +# Log files for the gunicorn webserver. '-' means log to stderr. +# +# Variable: AIRFLOW__WEBSERVER__ACCESS_LOGFILE +# +access_logfile = - + +# Log files for the gunicorn webserver. '-' means log to stderr. +# +# Variable: AIRFLOW__WEBSERVER__ERROR_LOGFILE +# +error_logfile = - + +# Access log format for gunicorn webserver. +# default format is ``%%(h)s %%(l)s %%(u)s %%(t)s "%%(r)s" %%(s)s %%(b)s "%%(f)s" "%%(a)s"`` +# See `Gunicorn Settings: 'access_log_format' Reference +# `__ for more details +# +# Variable: AIRFLOW__WEBSERVER__ACCESS_LOGFORMAT +# +access_logformat = + +# Expose the configuration file in the web server. Set to ``non-sensitive-only`` to show all values +# except those that have security implications. ``True`` shows all values. ``False`` hides the +# configuration completely. +# +# Variable: AIRFLOW__WEBSERVER__EXPOSE_CONFIG +# +expose_config = False + +# Expose hostname in the web server +# +# Variable: AIRFLOW__WEBSERVER__EXPOSE_HOSTNAME +# +expose_hostname = False + +# Expose stacktrace in the web server +# +# Variable: AIRFLOW__WEBSERVER__EXPOSE_STACKTRACE +# +expose_stacktrace = False + +# Default DAG view. Valid values are: ``grid``, ``graph``, ``duration``, ``gantt``, ``landing_times`` +# +# Variable: AIRFLOW__WEBSERVER__DAG_DEFAULT_VIEW +# +dag_default_view = grid + +# Default DAG orientation. Valid values are: +# ``LR`` (Left->Right), ``TB`` (Top->Bottom), ``RL`` (Right->Left), ``BT`` (Bottom->Top) +# +# Variable: AIRFLOW__WEBSERVER__DAG_ORIENTATION +# +dag_orientation = LR + +# Sorting order in grid view. Valid values are: ``topological``, ``hierarchical_alphabetical`` +# +# Variable: AIRFLOW__WEBSERVER__GRID_VIEW_SORTING_ORDER +# +grid_view_sorting_order = topological + +# The amount of time (in secs) webserver will wait for initial handshake +# while fetching logs from other worker machine +# +# Variable: AIRFLOW__WEBSERVER__LOG_FETCH_TIMEOUT_SEC +# +log_fetch_timeout_sec = 5 + +# Time interval (in secs) to wait before next log fetching. +# +# Variable: AIRFLOW__WEBSERVER__LOG_FETCH_DELAY_SEC +# +log_fetch_delay_sec = 2 + +# Distance away from page bottom to enable auto tailing. +# +# Variable: AIRFLOW__WEBSERVER__LOG_AUTO_TAILING_OFFSET +# +log_auto_tailing_offset = 30 + +# Animation speed for auto tailing log display. +# +# Variable: AIRFLOW__WEBSERVER__LOG_ANIMATION_SPEED +# +log_animation_speed = 1000 + +# By default, the webserver shows paused DAGs. Flip this to hide paused +# DAGs by default +# +# Variable: AIRFLOW__WEBSERVER__HIDE_PAUSED_DAGS_BY_DEFAULT +# +hide_paused_dags_by_default = False + +# Consistent page size across all listing views in the UI +# +# Variable: AIRFLOW__WEBSERVER__PAGE_SIZE +# +page_size = 100 + +# Define the color of navigation bar +# +# Variable: AIRFLOW__WEBSERVER__NAVBAR_COLOR +# +navbar_color = #fff + +# Define the color of text in the navigation bar +# +# Variable: AIRFLOW__WEBSERVER__NAVBAR_TEXT_COLOR +# +navbar_text_color = #51504f + +# Define the color of navigation bar links when hovered +# +# Variable: AIRFLOW__WEBSERVER__NAVBAR_HOVER_COLOR +# +navbar_hover_color = #eee + +# Define the color of text in the navigation bar when hovered +# +# Variable: AIRFLOW__WEBSERVER__NAVBAR_TEXT_HOVER_COLOR +# +navbar_text_hover_color = #51504f + +# Define the color of the logo text +# +# Variable: AIRFLOW__WEBSERVER__NAVBAR_LOGO_TEXT_COLOR +# +navbar_logo_text_color = #51504f + +# Default dagrun to show in UI +# +# Variable: AIRFLOW__WEBSERVER__DEFAULT_DAG_RUN_DISPLAY_NUMBER +# +default_dag_run_display_number = 25 + +# Enable werkzeug ``ProxyFix`` middleware for reverse proxy +# +# Variable: AIRFLOW__WEBSERVER__ENABLE_PROXY_FIX +# +enable_proxy_fix = False + +# Number of values to trust for ``X-Forwarded-For``. +# See `Werkzeug: X-Forwarded-For Proxy Fix +# `__ for more details. +# +# Variable: AIRFLOW__WEBSERVER__PROXY_FIX_X_FOR +# +proxy_fix_x_for = 1 + +# Number of values to trust for ``X-Forwarded-Proto``. +# See `Werkzeug: X-Forwarded-For Proxy Fix +# `__ for more details. +# +# Variable: AIRFLOW__WEBSERVER__PROXY_FIX_X_PROTO +# +proxy_fix_x_proto = 1 + +# Number of values to trust for ``X-Forwarded-Host``. +# See `Werkzeug: X-Forwarded-For Proxy Fix +# `__ for more details. +# +# Variable: AIRFLOW__WEBSERVER__PROXY_FIX_X_HOST +# +proxy_fix_x_host = 1 + +# Number of values to trust for ``X-Forwarded-Port``. +# See `Werkzeug: X-Forwarded-For Proxy Fix +# `__ for more details. +# +# Variable: AIRFLOW__WEBSERVER__PROXY_FIX_X_PORT +# +proxy_fix_x_port = 1 + +# Number of values to trust for ``X-Forwarded-Prefix``. +# See `Werkzeug: X-Forwarded-For Proxy Fix +# `__ for more details. +# +# Variable: AIRFLOW__WEBSERVER__PROXY_FIX_X_PREFIX +# +proxy_fix_x_prefix = 1 + +# Set secure flag on session cookie +# +# Variable: AIRFLOW__WEBSERVER__COOKIE_SECURE +# +cookie_secure = False + +# Set samesite policy on session cookie +# +# Variable: AIRFLOW__WEBSERVER__COOKIE_SAMESITE +# +cookie_samesite = Lax + +# Default setting for wrap toggle on DAG code and TI log views. +# +# Variable: AIRFLOW__WEBSERVER__DEFAULT_WRAP +# +default_wrap = False + +# Allow the UI to be rendered in a frame +# +# Variable: AIRFLOW__WEBSERVER__X_FRAME_ENABLED +# +x_frame_enabled = True + +# Send anonymous user activity to your analytics tool +# choose from ``google_analytics``, ``segment``, ``metarouter``, or ``matomo`` +# +# Variable: AIRFLOW__WEBSERVER__ANALYTICS_TOOL +# +# analytics_tool = + +# Unique ID of your account in the analytics tool +# +# Variable: AIRFLOW__WEBSERVER__ANALYTICS_ID +# +# analytics_id = + +# Your instances url, only applicable to Matomo. +# +# Example: analytics_url = https://your.matomo.instance.com/ +# +# Variable: AIRFLOW__WEBSERVER__ANALYTICS_URL +# +# analytics_url = + +# 'Recent Tasks' stats will show for old DagRuns if set +# +# Variable: AIRFLOW__WEBSERVER__SHOW_RECENT_STATS_FOR_COMPLETED_RUNS +# +show_recent_stats_for_completed_runs = True + +# The UI cookie lifetime in minutes. User will be logged out from UI after +# ``[webserver] session_lifetime_minutes`` of non-activity +# +# Variable: AIRFLOW__WEBSERVER__SESSION_LIFETIME_MINUTES +# +session_lifetime_minutes = 43200 + +# Sets a custom page title for the DAGs overview page and site title for all pages +# +# Variable: AIRFLOW__WEBSERVER__INSTANCE_NAME +# +# instance_name = + +# Whether the custom page title for the DAGs overview page contains any Markup language +# +# Variable: AIRFLOW__WEBSERVER__INSTANCE_NAME_HAS_MARKUP +# +instance_name_has_markup = False + +# How frequently, in seconds, the DAG data will auto-refresh in graph or grid view +# when auto-refresh is turned on +# +# Variable: AIRFLOW__WEBSERVER__AUTO_REFRESH_INTERVAL +# +auto_refresh_interval = 3 + +# Boolean for displaying warning for publicly viewable deployment +# +# Variable: AIRFLOW__WEBSERVER__WARN_DEPLOYMENT_EXPOSURE +# +warn_deployment_exposure = True + +# Comma separated string of view events to exclude from dag audit view. +# All other events will be added minus the ones passed here. +# The audit logs in the db will not be affected by this parameter. +# +# Example: audit_view_excluded_events = cli_task_run,running,success +# +# Variable: AIRFLOW__WEBSERVER__AUDIT_VIEW_EXCLUDED_EVENTS +# +# audit_view_excluded_events = + +# Comma separated string of view events to include in dag audit view. +# If passed, only these events will populate the dag audit view. +# The audit logs in the db will not be affected by this parameter. +# +# Example: audit_view_included_events = dagrun_cleared,failed +# +# Variable: AIRFLOW__WEBSERVER__AUDIT_VIEW_INCLUDED_EVENTS +# +# audit_view_included_events = + +# Boolean for running SwaggerUI in the webserver. +# +# Variable: AIRFLOW__WEBSERVER__ENABLE_SWAGGER_UI +# +enable_swagger_ui = True + +# Boolean for running Internal API in the webserver. +# +# Variable: AIRFLOW__WEBSERVER__RUN_INTERNAL_API +# +run_internal_api = False + +# The caching algorithm used by the webserver. Must be a valid hashlib function name. +# +# Example: caching_hash_method = sha256 +# +# Variable: AIRFLOW__WEBSERVER__CACHING_HASH_METHOD +# +caching_hash_method = md5 + +# Behavior of the trigger DAG run button for DAGs without params. ``False`` to skip and trigger +# without displaying a form to add a **dag_run.conf**, ``True`` to always display the form. +# The form is displayed always if parameters are defined. +# +# Variable: AIRFLOW__WEBSERVER__SHOW_TRIGGER_FORM_IF_NO_PARAMS +# +show_trigger_form_if_no_params = False + +# Number of recent DAG run configurations in the selector on the trigger web form. +# +# Example: num_recent_configurations_for_trigger = 10 +# +# Variable: AIRFLOW__WEBSERVER__NUM_RECENT_CONFIGURATIONS_FOR_TRIGGER +# +num_recent_configurations_for_trigger = 5 + +# A DAG author is able to provide any raw HTML into ``doc_md`` or params description in +# ``description_md`` for text formatting. This is including potentially unsafe javascript. +# Displaying the DAG or trigger form in web UI provides the DAG author the potential to +# inject malicious code into clients browsers. To ensure the web UI is safe by default, +# raw HTML is disabled by default. If you trust your DAG authors, you can enable HTML +# support in markdown by setting this option to ``True``. +# +# This parameter also enables the deprecated fields ``description_html`` and +# ``custom_html_form`` in DAG params until the feature is removed in a future version. +# +# Example: allow_raw_html_descriptions = False +# +# Variable: AIRFLOW__WEBSERVER__ALLOW_RAW_HTML_DESCRIPTIONS +# +allow_raw_html_descriptions = False + +# The maximum size of the request payload (in MB) that can be sent. +# +# Variable: AIRFLOW__WEBSERVER__ALLOWED_PAYLOAD_SIZE +# +allowed_payload_size = 1.0 + +# Require confirmation when changing a DAG in the web UI. This is to prevent accidental changes +# to a DAG that may be running on sensitive environments like production. +# When set to ``True``, confirmation dialog will be shown when a user tries to Pause/Unpause, +# Trigger a DAG +# +# Variable: AIRFLOW__WEBSERVER__REQUIRE_CONFIRMATION_DAG_CHANGE +# +require_confirmation_dag_change = False + +[email] +# Configuration email backend and whether to +# send email alerts on retry or failure + +# Email backend to use +# +# Variable: AIRFLOW__EMAIL__EMAIL_BACKEND +# +email_backend = airflow.utils.email.send_email_smtp + +# Email connection to use +# +# Variable: AIRFLOW__EMAIL__EMAIL_CONN_ID +# +email_conn_id = smtp_default + +# Whether email alerts should be sent when a task is retried +# +# Variable: AIRFLOW__EMAIL__DEFAULT_EMAIL_ON_RETRY +# +default_email_on_retry = True + +# Whether email alerts should be sent when a task failed +# +# Variable: AIRFLOW__EMAIL__DEFAULT_EMAIL_ON_FAILURE +# +default_email_on_failure = True + +# File that will be used as the template for Email subject (which will be rendered using Jinja2). +# If not set, Airflow uses a base template. +# +# Example: subject_template = /path/to/my_subject_template_file +# +# Variable: AIRFLOW__EMAIL__SUBJECT_TEMPLATE +# +# subject_template = + +# File that will be used as the template for Email content (which will be rendered using Jinja2). +# If not set, Airflow uses a base template. +# +# Example: html_content_template = /path/to/my_html_content_template_file +# +# Variable: AIRFLOW__EMAIL__HTML_CONTENT_TEMPLATE +# +# html_content_template = + +# Email address that will be used as sender address. +# It can either be raw email or the complete address in a format ``Sender Name `` +# +# Example: from_email = Airflow +# +# Variable: AIRFLOW__EMAIL__FROM_EMAIL +# +# from_email = + +# ssl context to use when using SMTP and IMAP SSL connections. By default, the context is "default" +# which sets it to ``ssl.create_default_context()`` which provides the right balance between +# compatibility and security, it however requires that certificates in your operating system are +# updated and that SMTP/IMAP servers of yours have valid certificates that have corresponding public +# keys installed on your machines. You can switch it to "none" if you want to disable checking +# of the certificates, but it is not recommended as it allows MITM (man-in-the-middle) attacks +# if your infrastructure is not sufficiently secured. It should only be set temporarily while you +# are fixing your certificate configuration. This can be typically done by upgrading to newer +# version of the operating system you run Airflow components on,by upgrading/refreshing proper +# certificates in the OS or by updating certificates for your mail servers. +# +# Example: ssl_context = default +# +# Variable: AIRFLOW__EMAIL__SSL_CONTEXT +# +ssl_context = default + +[smtp] +# If you want airflow to send emails on retries, failure, and you want to use +# the airflow.utils.email.send_email_smtp function, you have to configure an +# smtp server here + +# Specifies the host server address used by Airflow when sending out email notifications via SMTP. +# +# Variable: AIRFLOW__SMTP__SMTP_HOST +# +smtp_host = localhost + +# Determines whether to use the STARTTLS command when connecting to the SMTP server. +# +# Variable: AIRFLOW__SMTP__SMTP_STARTTLS +# +smtp_starttls = True + +# Determines whether to use an SSL connection when talking to the SMTP server. +# +# Variable: AIRFLOW__SMTP__SMTP_SSL +# +smtp_ssl = False + +# Username to authenticate when connecting to smtp server. +# +# Example: smtp_user = airflow +# +# Variable: AIRFLOW__SMTP__SMTP_USER +# +# smtp_user = + +# Password to authenticate when connecting to smtp server. +# +# Example: smtp_password = airflow +# +# Variable: AIRFLOW__SMTP__SMTP_PASSWORD +# +# smtp_password = + +# Defines the port number on which Airflow connects to the SMTP server to send email notifications. +# +# Variable: AIRFLOW__SMTP__SMTP_PORT +# +smtp_port = 25 + +# Specifies the default **from** email address used when Airflow sends email notifications. +# +# Variable: AIRFLOW__SMTP__SMTP_MAIL_FROM +# +smtp_mail_from = airflow@example.com + +# Determines the maximum time (in seconds) the Apache Airflow system will wait for a +# connection to the SMTP server to be established. +# +# Variable: AIRFLOW__SMTP__SMTP_TIMEOUT +# +smtp_timeout = 30 + +# Defines the maximum number of times Airflow will attempt to connect to the SMTP server. +# +# Variable: AIRFLOW__SMTP__SMTP_RETRY_LIMIT +# +smtp_retry_limit = 5 + +[sentry] +# `Sentry `__ integration. Here you can supply +# additional configuration options based on the Python platform. +# See `Python / Configuration / Basic Options +# `__ for more details. +# Unsupported options: ``integrations``, ``in_app_include``, ``in_app_exclude``, +# ``ignore_errors``, ``before_breadcrumb``, ``transport``. + +# Enable error reporting to Sentry +# +# Variable: AIRFLOW__SENTRY__SENTRY_ON +# +sentry_on = false + +# +# Variable: AIRFLOW__SENTRY__SENTRY_DSN +# +sentry_dsn = + +# Dotted path to a before_send function that the sentry SDK should be configured to use. +# +# Variable: AIRFLOW__SENTRY__BEFORE_SEND +# +# before_send = + +[scheduler] +# Task instances listen for external kill signal (when you clear tasks +# from the CLI or the UI), this defines the frequency at which they should +# listen (in seconds). +# +# Variable: AIRFLOW__SCHEDULER__JOB_HEARTBEAT_SEC +# +job_heartbeat_sec = 5 + +# The scheduler constantly tries to trigger new tasks (look at the +# scheduler section in the docs for more information). This defines +# how often the scheduler should run (in seconds). +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULER_HEARTBEAT_SEC +# +scheduler_heartbeat_sec = 5 + +# The frequency (in seconds) at which the LocalTaskJob should send heartbeat signals to the +# scheduler to notify it's still alive. If this value is set to 0, the heartbeat interval will default +# to the value of ``[scheduler] scheduler_zombie_task_threshold``. +# +# Variable: AIRFLOW__SCHEDULER__LOCAL_TASK_JOB_HEARTBEAT_SEC +# +local_task_job_heartbeat_sec = 0 + +# The number of times to try to schedule each DAG file +# -1 indicates unlimited number +# +# Variable: AIRFLOW__SCHEDULER__NUM_RUNS +# +num_runs = -1 + +# Controls how long the scheduler will sleep between loops, but if there was nothing to do +# in the loop. i.e. if it scheduled something then it will start the next loop +# iteration straight away. +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULER_IDLE_SLEEP_TIME +# +scheduler_idle_sleep_time = 1 + +# Number of seconds after which a DAG file is parsed. The DAG file is parsed every +# ``[scheduler] min_file_process_interval`` number of seconds. Updates to DAGs are reflected after +# this interval. Keeping this number low will increase CPU usage. +# +# Variable: AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL +# +min_file_process_interval = 30 + +# How often (in seconds) to check for stale DAGs (DAGs which are no longer present in +# the expected files) which should be deactivated, as well as datasets that are no longer +# referenced and should be marked as orphaned. +# +# Variable: AIRFLOW__SCHEDULER__PARSING_CLEANUP_INTERVAL +# +parsing_cleanup_interval = 60 + +# How long (in seconds) to wait after we have re-parsed a DAG file before deactivating stale +# DAGs (DAGs which are no longer present in the expected files). The reason why we need +# this threshold is to account for the time between when the file is parsed and when the +# DAG is loaded. The absolute maximum that this could take is ``[core] dag_file_processor_timeout``, +# but when you have a long timeout configured, it results in a significant delay in the +# deactivation of stale dags. +# +# Variable: AIRFLOW__SCHEDULER__STALE_DAG_THRESHOLD +# +stale_dag_threshold = 50 + +# How often (in seconds) to scan the DAGs directory for new files. Default to 5 minutes. +# +# Variable: AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL +# +dag_dir_list_interval = 300 + +# How often should stats be printed to the logs. Setting to 0 will disable printing stats +# +# Variable: AIRFLOW__SCHEDULER__PRINT_STATS_INTERVAL +# +print_stats_interval = 30 + +# How often (in seconds) should pool usage stats be sent to StatsD (if statsd_on is enabled) +# +# Variable: AIRFLOW__SCHEDULER__POOL_METRICS_INTERVAL +# +pool_metrics_interval = 5.0 + +# If the last scheduler heartbeat happened more than ``[scheduler] scheduler_health_check_threshold`` +# ago (in seconds), scheduler is considered unhealthy. +# This is used by the health check in the **/health** endpoint and in ``airflow jobs check`` CLI +# for SchedulerJob. +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULER_HEALTH_CHECK_THRESHOLD +# +scheduler_health_check_threshold = 30 + +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check if this is set to ``True`` +# +# Variable: AIRFLOW__SCHEDULER__ENABLE_HEALTH_CHECK +# +enable_health_check = False + +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check on this host +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULER_HEALTH_CHECK_SERVER_HOST +# +scheduler_health_check_server_host = 0.0.0.0 + +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check on this port +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULER_HEALTH_CHECK_SERVER_PORT +# +scheduler_health_check_server_port = 8974 + +# How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs +# +# Variable: AIRFLOW__SCHEDULER__ORPHANED_TASKS_CHECK_INTERVAL +# +orphaned_tasks_check_interval = 300.0 + +# Determines the directory where logs for the child processes of the scheduler will be stored +# +# Variable: AIRFLOW__SCHEDULER__CHILD_PROCESS_LOG_DIRECTORY +# +child_process_log_directory = /opt/airflow/logs/scheduler + +# Local task jobs periodically heartbeat to the DB. If the job has +# not heartbeat in this many seconds, the scheduler will mark the +# associated task instance as failed and will re-schedule the task. +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULER_ZOMBIE_TASK_THRESHOLD +# +scheduler_zombie_task_threshold = 300 + +# How often (in seconds) should the scheduler check for zombie tasks. +# +# Variable: AIRFLOW__SCHEDULER__ZOMBIE_DETECTION_INTERVAL +# +zombie_detection_interval = 10.0 + +# Turn off scheduler catchup by setting this to ``False``. +# Default behavior is unchanged and +# Command Line Backfills still work, but the scheduler +# will not do scheduler catchup if this is ``False``, +# however it can be set on a per DAG basis in the +# DAG definition (catchup) +# +# Variable: AIRFLOW__SCHEDULER__CATCHUP_BY_DEFAULT +# +catchup_by_default = True + +# Setting this to ``True`` will make first task instance of a task +# ignore depends_on_past setting. A task instance will be considered +# as the first task instance of a task when there is no task instance +# in the DB with an execution_date earlier than it., i.e. no manual marking +# success will be needed for a newly added task to be scheduled. +# +# Variable: AIRFLOW__SCHEDULER__IGNORE_FIRST_DEPENDS_ON_PAST_BY_DEFAULT +# +ignore_first_depends_on_past_by_default = True + +# This changes the batch size of queries in the scheduling main loop. +# This should not be greater than ``[core] parallelism``. +# If this is too high, SQL query performance may be impacted by +# complexity of query predicate, and/or excessive locking. +# Additionally, you may hit the maximum allowable query length for your db. +# Set this to 0 to use the value of ``[core] parallelism`` +# +# Variable: AIRFLOW__SCHEDULER__MAX_TIS_PER_QUERY +# +max_tis_per_query = 16 + +# Should the scheduler issue ``SELECT ... FOR UPDATE`` in relevant queries. +# If this is set to ``False`` then you should not run more than a single +# scheduler at once +# +# Variable: AIRFLOW__SCHEDULER__USE_ROW_LEVEL_LOCKING +# +use_row_level_locking = True + +# Max number of DAGs to create DagRuns for per scheduler loop. +# +# Variable: AIRFLOW__SCHEDULER__MAX_DAGRUNS_TO_CREATE_PER_LOOP +# +max_dagruns_to_create_per_loop = 10 + +# How many DagRuns should a scheduler examine (and lock) when scheduling +# and queuing tasks. +# +# Variable: AIRFLOW__SCHEDULER__MAX_DAGRUNS_PER_LOOP_TO_SCHEDULE +# +max_dagruns_per_loop_to_schedule = 20 + +# Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the +# same DAG. Leaving this on will mean tasks in the same DAG execute quicker, but might starve out other +# dags in some circumstances +# +# Variable: AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION +# +schedule_after_task_execution = True + +# The scheduler reads dag files to extract the airflow modules that are going to be used, +# and imports them ahead of time to avoid having to re-do it for each parsing process. +# This flag can be set to ``False`` to disable this behavior in case an airflow module needs +# to be freshly imported each time (at the cost of increased DAG parsing time). +# +# Variable: AIRFLOW__SCHEDULER__PARSING_PRE_IMPORT_MODULES +# +parsing_pre_import_modules = True + +# The scheduler can run multiple processes in parallel to parse dags. +# This defines how many processes will run. +# +# Variable: AIRFLOW__SCHEDULER__PARSING_PROCESSES +# +parsing_processes = 2 + +# One of ``modified_time``, ``random_seeded_by_host`` and ``alphabetical``. +# The scheduler will list and sort the dag files to decide the parsing order. +# +# * ``modified_time``: Sort by modified time of the files. This is useful on large scale to parse the +# recently modified DAGs first. +# * ``random_seeded_by_host``: Sort randomly across multiple Schedulers but with same order on the +# same host. This is useful when running with Scheduler in HA mode where each scheduler can +# parse different DAG files. +# * ``alphabetical``: Sort by filename +# +# Variable: AIRFLOW__SCHEDULER__FILE_PARSING_SORT_MODE +# +file_parsing_sort_mode = modified_time + +# Whether the dag processor is running as a standalone process or it is a subprocess of a scheduler +# job. +# +# Variable: AIRFLOW__SCHEDULER__STANDALONE_DAG_PROCESSOR +# +standalone_dag_processor = False + +# Only applicable if ``[scheduler] standalone_dag_processor`` is true and callbacks are stored +# in database. Contains maximum number of callbacks that are fetched during a single loop. +# +# Variable: AIRFLOW__SCHEDULER__MAX_CALLBACKS_PER_LOOP +# +max_callbacks_per_loop = 20 + +# Only applicable if ``[scheduler] standalone_dag_processor`` is true. +# Time in seconds after which dags, which were not updated by Dag Processor are deactivated. +# +# Variable: AIRFLOW__SCHEDULER__DAG_STALE_NOT_SEEN_DURATION +# +dag_stale_not_seen_duration = 600 + +# Turn off scheduler use of cron intervals by setting this to ``False``. +# DAGs submitted manually in the web UI or with trigger_dag will still run. +# +# Variable: AIRFLOW__SCHEDULER__USE_JOB_SCHEDULE +# +use_job_schedule = True + +# Allow externally triggered DagRuns for Execution Dates in the future +# Only has effect if schedule_interval is set to None in DAG +# +# Variable: AIRFLOW__SCHEDULER__ALLOW_TRIGGER_IN_FUTURE +# +allow_trigger_in_future = False + +# How often to check for expired trigger requests that have not run yet. +# +# Variable: AIRFLOW__SCHEDULER__TRIGGER_TIMEOUT_CHECK_INTERVAL +# +trigger_timeout_check_interval = 15 + +# Amount of time a task can be in the queued state before being retried or set to failed. +# +# Variable: AIRFLOW__SCHEDULER__TASK_QUEUED_TIMEOUT +# +task_queued_timeout = 600.0 + +# How often to check for tasks that have been in the queued state for +# longer than ``[scheduler] task_queued_timeout``. +# +# Variable: AIRFLOW__SCHEDULER__TASK_QUEUED_TIMEOUT_CHECK_INTERVAL +# +task_queued_timeout_check_interval = 120.0 + +# The run_id pattern used to verify the validity of user input to the run_id parameter when +# triggering a DAG. This pattern cannot change the pattern used by scheduler to generate run_id +# for scheduled DAG runs or DAG runs triggered without changing the run_id parameter. +# +# Variable: AIRFLOW__SCHEDULER__ALLOWED_RUN_ID_PATTERN +# +allowed_run_id_pattern = ^[A-Za-z0-9_.~:+-]+$ + +# Whether to create DAG runs that span an interval or one single point in time for cron schedules, when +# a cron string is provided to ``schedule`` argument of a DAG. +# +# * ``True``: **CronDataIntervalTimetable** is used, which is suitable +# for DAGs with well-defined data interval. You get contiguous intervals from the end of the previous +# interval up to the scheduled datetime. +# * ``False``: **CronTriggerTimetable** is used, which is closer to the behavior of cron itself. +# +# Notably, for **CronTriggerTimetable**, the logical date is the same as the time the DAG Run will +# try to schedule, while for **CronDataIntervalTimetable**, the logical date is the beginning of +# the data interval, but the DAG Run will try to schedule at the end of the data interval. +# +# Variable: AIRFLOW__SCHEDULER__CREATE_CRON_DATA_INTERVALS +# +create_cron_data_intervals = True + +[triggerer] +# How many triggers a single Triggerer will run at once, by default. +# +# Variable: AIRFLOW__TRIGGERER__DEFAULT_CAPACITY +# +default_capacity = 1000 + +# How often to heartbeat the Triggerer job to ensure it hasn't been killed. +# +# Variable: AIRFLOW__TRIGGERER__JOB_HEARTBEAT_SEC +# +job_heartbeat_sec = 5 + +# If the last triggerer heartbeat happened more than ``[triggerer] triggerer_health_check_threshold`` +# ago (in seconds), triggerer is considered unhealthy. +# This is used by the health check in the **/health** endpoint and in ``airflow jobs check`` CLI +# for TriggererJob. +# +# Variable: AIRFLOW__TRIGGERER__TRIGGERER_HEALTH_CHECK_THRESHOLD +# +triggerer_health_check_threshold = 30 + +[kerberos] +# Location of your ccache file once kinit has been performed. +# +# Variable: AIRFLOW__KERBEROS__CCACHE +# +ccache = /tmp/airflow_krb5_ccache + +# gets augmented with fqdn +# +# Variable: AIRFLOW__KERBEROS__PRINCIPAL +# +principal = airflow + +# Determines the frequency at which initialization or re-initialization processes occur. +# +# Variable: AIRFLOW__KERBEROS__REINIT_FREQUENCY +# +reinit_frequency = 3600 + +# Path to the kinit executable +# +# Variable: AIRFLOW__KERBEROS__KINIT_PATH +# +kinit_path = kinit + +# Designates the path to the Kerberos keytab file for the Airflow user +# +# Variable: AIRFLOW__KERBEROS__KEYTAB +# +keytab = airflow.keytab + +# Allow to disable ticket forwardability. +# +# Variable: AIRFLOW__KERBEROS__FORWARDABLE +# +forwardable = True + +# Allow to remove source IP from token, useful when using token behind NATted Docker host. +# +# Variable: AIRFLOW__KERBEROS__INCLUDE_IP +# +include_ip = True + +[sensors] +# Sensor default timeout, 7 days by default (7 * 24 * 60 * 60). +# +# Variable: AIRFLOW__SENSORS__DEFAULT_TIMEOUT +# +default_timeout = 604800 + +[common.io] +# Common IO configuration section + +# Path to a location on object storage where XComs can be stored in url format. +# +# Example: xcom_objectstorage_path = s3://conn_id@bucket/path +# +# Variable: AIRFLOW__COMMON.IO__XCOM_OBJECTSTORAGE_PATH +# +xcom_objectstorage_path = + +# Threshold in bytes for storing XComs in object storage. -1 means always store in the +# database. 0 means always store in object storage. Any positive number means +# it will be stored in object storage if the size of the value is greater than the threshold. +# +# Example: xcom_objectstorage_threshold = 1000000 +# +# Variable: AIRFLOW__COMMON.IO__XCOM_OBJECTSTORAGE_THRESHOLD +# +xcom_objectstorage_threshold = -1 + +# Compression algorithm to use when storing XComs in object storage. Supported algorithms +# are a.o.: snappy, zip, gzip, bz2, and lzma. If not specified, no compression will be used. +# Note that the compression algorithm must be available in the Python installation (e.g. +# python-snappy for snappy). Zip, gz, bz2 are available by default. +# +# Example: xcom_objectstorage_compression = gz +# +# Variable: AIRFLOW__COMMON.IO__XCOM_OBJECTSTORAGE_COMPRESSION +# +xcom_objectstorage_compression = + +[fab] +# This section contains configs specific to FAB provider. + +# Boolean for enabling rate limiting on authentication endpoints. +# +# Variable: AIRFLOW__FAB__AUTH_RATE_LIMITED +# +auth_rate_limited = True + +# Rate limit for authentication endpoints. +# +# Variable: AIRFLOW__FAB__AUTH_RATE_LIMIT +# +auth_rate_limit = 5 per 40 second + +# Update FAB permissions and sync security manager roles +# on webserver startup +# +# Variable: AIRFLOW__FAB__UPDATE_FAB_PERMS +# +update_fab_perms = True + +[imap] +# Options for IMAP provider. + +# ssl_context = + +[smtp_provider] +# Options for SMTP provider. + +# ssl context to use when using SMTP and IMAP SSL connections. By default, the context is "default" +# which sets it to ``ssl.create_default_context()`` which provides the right balance between +# compatibility and security, it however requires that certificates in your operating system are +# updated and that SMTP/IMAP servers of yours have valid certificates that have corresponding public +# keys installed on your machines. You can switch it to "none" if you want to disable checking +# of the certificates, but it is not recommended as it allows MITM (man-in-the-middle) attacks +# if your infrastructure is not sufficiently secured. It should only be set temporarily while you +# are fixing your certificate configuration. This can be typically done by upgrading to newer +# version of the operating system you run Airflow components on,by upgrading/refreshing proper +# certificates in the OS or by updating certificates for your mail servers. +# +# If you do not set this option explicitly, it will use Airflow "email.ssl_context" configuration, +# but if this configuration is not present, it will use "default" value. +# +# Example: ssl_context = default +# +# Variable: AIRFLOW__SMTP_PROVIDER__SSL_CONTEXT +# +# ssl_context = + +# Allows overriding of the standard templated email subject line when the SmtpNotifier is used. +# Must provide a path to the template. +# +# Example: templated_email_subject_path = path/to/override/email_subject.html +# +# Variable: AIRFLOW__SMTP_PROVIDER__TEMPLATED_EMAIL_SUBJECT_PATH +# +# templated_email_subject_path = + +# Allows overriding of the standard templated email path when the SmtpNotifier is used. Must provide +# a path to the template. +# +# Example: templated_html_content_path = path/to/override/email.html +# +# Variable: AIRFLOW__SMTP_PROVIDER__TEMPLATED_HTML_CONTENT_PATH +# +# templated_html_content_path = + +processor_log_folder = /opt/airflow/logs/scheduler diff --git a/airflow_bundle/leaf-pipeline/airflow/dags/configs/config.docker.yaml b/airflow_bundle/leaf-pipeline/airflow/dags/configs/config.docker.yaml new file mode 100644 index 000000000..465a10ca7 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/airflow/dags/configs/config.docker.yaml @@ -0,0 +1,67 @@ +# io: +# # IMPORTANT: use the Docker service name of Postgres (from your compose): +# postgres_url: "postgresql+psycopg2://missions_user:pg123@postgres:5432/missions_db" +io: + postgres_url: "postgresql+psycopg2://postgres:postgres@agcloud-postgres:5432/postgres" + + +windows: + frequency: "D" + timezone: "UTC" + +source_mapping: + entity_dim: "mission" # or "region"/"device" + area_strategy: "none" # or "region_area" (requires regions table/geom) + filters: + start_time: null + end_time: null + anomaly_codes: null + +baseline: + method: "median" + lookback_periods: 28 + min_history: 7 + seasonality: null + +rules: + count_anomaly: + enabled: true + method: "zscore" + z_threshold: 3.0 + iqr_k: 1.5 + min_count: 3 + worsening: + enabled: true + method: "slope" + slope_lookback: 7 + slope_min: 0.02 + min_periods: 5 + ewma_span: 7 + ewma_threshold: 0.6 + +alerting: + dedup_cooldown_windows: 3 + resolve_after_no_anomaly: 3 + rate_limit_per_run: 100 + group_by_window: true + +delivery: + slack: + enabled: false + webhook_url: "" + webhook: + enabled: false + url: "" + headers: {} + email: + enabled: false + smtp_host: "" + smtp_port: 587 + username: "" + password_env: "SMTP_PASSWORD" + from_addr: "" + to_addrs: [] + +run: + dry_run: false + diff --git a/airflow_bundle/leaf-pipeline/airflow/dags/leaf_pipeline_dag.py b/airflow_bundle/leaf-pipeline/airflow/dags/leaf_pipeline_dag.py new file mode 100644 index 000000000..b66987628 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/airflow/dags/leaf_pipeline_dag.py @@ -0,0 +1,535 @@ +from __future__ import annotations +from datetime import datetime +import pendulum +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.providers.docker.operators.docker import DockerOperator + + + +PROJECT_ROOT = "/opt/leaf-pipeline/projects/leaf-counting" + +PYTHON_BIN = "python" +WEIGHTS = f"{PROJECT_ROOT}/weights/best.pt" + + +OUT_RUN = f"{PROJECT_ROOT}/runs_local/airflow_run" +STAGING_DIR = "/opt/airflow/staging/input" + +tz = pendulum.timezone("Asia/Jerusalem") + +with DAG( + dag_id="leaf_pipeline_v2", + start_date=datetime(2025, 10, 1, tzinfo=tz), + schedule=None, + catchup=False, + default_args={"owner": "leafcounting", "retries": 0}, + tags=["leaf-counting", "detect", "pwb", "crop", "minio"], +) as dag: + + + RUN_ID_DATE = "{{ dag_run.conf.get('run_id') or logical_date.in_timezone('Asia/Jerusalem').strftime('%Y/%m/%d/%H%M') }}" + + # ----------------------------- + # STAGE INPUT + # ----------------------------- + stage_input = BashOperator( + task_id="stage_input", + bash_command=""" +set -euo pipefail +python -m pip install --no-cache-dir -q \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org --trusted-host pypi.python.org \ + awscli \ +|| apt-get update && apt-get install -y -qq ca-certificates awscli \ +|| python -m pip install --no-cache-dir -q \ + --index-url http://pypi.org/simple \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org --trusted-host pypi.python.org \ + awscli +STAGING_DIR='{{ params.staging_dir }}' +INPUT_MODE='minio' +mkdir -p "$STAGING_DIR"; rm -rf "$STAGING_DIR"/* +if [ "$INPUT_MODE" = 'minio' ]; then + SRC_BUCKET='{{ dag_run.conf.get("src_bucket", var.value.leaf_minio_bucket | default("imagery")) }}' + SRC_PREFIX='leaves/examples' + ENDPOINT_URL='{{ conn.minio_s3.extra_dejson.endpoint_url | default("http://host.docker.internal:9001") }}' + export AWS_ACCESS_KEY_ID='{{ conn.minio_s3.login }}' + export AWS_SECRET_ACCESS_KEY='{{ conn.minio_s3.password }}' + export AWS_DEFAULT_REGION='{{ conn.minio_s3.extra_dejson.region_name or "us-east-1" }}' + export AWS_S3_FORCE_PATH_STYLE=true + export AWS_EC2_METADATA_DISABLED=true + echo "[stage] source=minio s3://$SRC_BUCKET/$SRC_PREFIX -> $STAGING_DIR (endpoint=$ENDPOINT_URL)" + python -m awscli s3 sync "s3://$SRC_BUCKET/$SRC_PREFIX" "$STAGING_DIR" --endpoint-url "$ENDPOINT_URL" +else + INPUT_DIR='{{ params.project_root }}/demo_images' + echo "[stage] source=local $INPUT_DIR -> $STAGING_DIR" + rsync -a --delete "$INPUT_DIR"/ "$STAGING_DIR"/ +fi +""", + params={"staging_dir": STAGING_DIR, "project_root": PROJECT_ROOT}, + env={"PYTHONUNBUFFERED": "1"}, + ) + + # ----------------------------- + # DETECT -> imagery/leaves///
//detect/ + # ----------------------------- + detect = BashOperator( + task_id="detect", + bash_command=""" +set -euo pipefail + +PROJECT_ROOT='{{ params.project_root }}' +PY='{{ params.python_bin }}'; if ! command -v "$PY" >/dev/null 2>&1; then PY='python'; fi + +export PYTHONEXECUTABLE="$PY" + +INPUT_DIR='{{ params.staging_dir }}' +OUT_LOCAL_DET='{{ params.out_run }}/detect' +WEIGHTS='{{ params.weights }}' + +DATE_ONLY='{{ dag_run.conf.get("run_id") or logical_date.in_timezone("Asia/Jerusalem").strftime("%Y/%m/%d/%H%M") }}' + +DEST_PREFIX="leaves/${DATE_ONLY}/detect" + +# MinIO: +# ל-SDK (minio-py) +ENDPOINT_HOSTPORT='{{ (conn.minio_s3.host or "host.docker.internal") }}:{{ (conn.minio_s3.port or 9001) }}' +# ל-awscli צריך URL מלא: +ENDPOINT_URL='{{ conn.minio_s3.extra_dejson.endpoint_url | default("http://host.docker.internal:9001") }}' +BUCKET='{{ var.value.leaf_minio_bucket | default("imagery") }}' +export AWS_ACCESS_KEY_ID='{{ conn.minio_s3.login }}' +export AWS_SECRET_ACCESS_KEY='{{ conn.minio_s3.password }}' +export AWS_DEFAULT_REGION='us-east-1' +export AWS_S3_FORCE_PATH_STYLE=true + +mkdir -p "$OUT_LOCAL_DET" + +cd "$PROJECT_ROOT" +$PY src/detect_only.py \ + --input "$INPUT_DIR" \ + --out "$OUT_LOCAL_DET" \ + --weights "$WEIGHTS" \ + --conf 0.25 --imgsz 896 --device cpu \ + --minio-endpoint "$ENDPOINT_HOSTPORT" \ + --minio-access "$AWS_ACCESS_KEY_ID" \ + --minio-secret "$AWS_SECRET_ACCESS_KEY" \ + --minio-bucket "$BUCKET" \ + --minio-prefix "leaves/${DATE_ONLY}" \ + --run-id "detect" + +# יישור קו לנתיב המדויק: +pip install -q awscli || true +python -m awscli s3 sync "$OUT_LOCAL_DET"/ "s3://$BUCKET/$DEST_PREFIX/" --endpoint-url "$ENDPOINT_URL" +python -m awscli s3 ls "s3://$BUCKET/$DEST_PREFIX/" --recursive --endpoint-url "$ENDPOINT_URL" || true +""", + params={ + "project_root": PROJECT_ROOT, + # "python_bin": PYTHON_BIN, + "python_bin": "/usr/local/bin/python", + "staging_dir": STAGING_DIR, + "out_run": OUT_RUN, + "weights": WEIGHTS, + "run_id_date": RUN_ID_DATE, + }, + env={"PYTHONUNBUFFERED": "1"}, + ) + + # ----------------------------- + # PREDICT_PWB -> imagery/leaves///
//pwb/ + # ----------------------------- + pwb = BashOperator( + task_id="predict_pwb", + bash_command=""" +set -euo pipefail + +PROJECT_ROOT='{{ params.project_root }}' +PY='{{ params.python_bin }}'; if ! command -v "$PY" >/dev/null 2>&1; then PY='python'; fi +INPUT_DIR='{{ params.staging_dir }}' +OUT_LOCAL_PWB='{{ params.out_run }}/pwb' +WEIGHTS='{{ params.weights }}' + +DATE_ONLY='{{ dag_run.conf.get("run_id") or logical_date.in_timezone("Asia/Jerusalem").strftime("%Y/%m/%d/%H%M") }}' + +DEST_PREFIX="leaves/${DATE_ONLY}/pwb" + +ENDPOINT_HOSTPORT='{{ (conn.minio_s3.host or "host.docker.internal") }}:{{ (conn.minio_s3.port or 9001) }}' +ENDPOINT_URL='{{ conn.minio_s3.extra_dejson.endpoint_url | default("http://host.docker.internal:9001") }}' +BUCKET='{{ var.value.leaf_minio_bucket | default("imagery") }}' +export AWS_ACCESS_KEY_ID='{{ conn.minio_s3.login }}' +export AWS_SECRET_ACCESS_KEY='{{ conn.minio_s3.password }}' +export AWS_DEFAULT_REGION='us-east-1' +export AWS_S3_FORCE_PATH_STYLE=true + +mkdir -p "$OUT_LOCAL_PWB" + +cd "$PROJECT_ROOT" +$PY src/predict_pyramid_wbf.py \ + --input "$INPUT_DIR" \ + --out "$OUT_LOCAL_PWB" \ + --weights "$WEIGHTS" \ + --scales 0.75,1.0,1.25 --conf 0.25 --iou 0.55 --imgsz 896 --device cpu \ + --minio-endpoint "$ENDPOINT_HOSTPORT" \ + --minio-access "$AWS_ACCESS_KEY_ID" \ + --minio-secret "$AWS_SECRET_ACCESS_KEY" \ + --minio-bucket "$BUCKET" \ + --minio-prefix "leaves/${DATE_ONLY}" \ + --run-id "pwb" + +pip install -q awscli || true +python -m awscli s3 sync "$OUT_LOCAL_PWB"/ "s3://$BUCKET/$DEST_PREFIX/" --endpoint-url "$ENDPOINT_URL" +python -m awscli s3 ls "s3://$BUCKET/$DEST_PREFIX/" --recursive --endpoint-url "$ENDPOINT_URL" || true +""", + params={ + "project_root": PROJECT_ROOT, + # "python_bin": PYTHON_BIN, + "python_bin": "/usr/local/bin/python", + + "staging_dir": STAGING_DIR, + "out_run": OUT_RUN, + "weights": WEIGHTS, + "run_id_date": RUN_ID_DATE, + }, + env={"PYTHONUNBUFFERED": "1"}, + ) + + + crop = BashOperator( + task_id="crop", + bash_command=""" + set -euo pipefail + + PROJECT_ROOT='{{ params.project_root }}' + PY='{{ params.python_bin }}'; if ! command -v "$PY" >/dev/null 2>&1; then PY='python'; fi + OUT_LOCAL_CROP='{{ params.out_run }}/crop' + PWB_LOCAL='{{ params.out_run }}/pwb' + RUN_ID_DATE='{{ dag_run.conf.get("run_id") or logical_date.in_timezone("Asia/Jerusalem").strftime("%Y/%m/%d/%H%M") }}' + + RUN_ID_CROP="${RUN_ID_DATE}/crop" + + ENDPOINT_URL='{{ conn.minio_s3.extra_dejson.endpoint_url | default("http://host.docker.internal:9001") }}' + BUCKET='{{ var.value.leaf_minio_bucket | default("imagery") }}' + export AWS_ACCESS_KEY_ID='{{ conn.minio_s3.login }}' + export AWS_SECRET_ACCESS_KEY='{{ conn.minio_s3.password }}' + export AWS_DEFAULT_REGION='us-east-1' + export AWS_S3_FORCE_PATH_STYLE=true + + + export DEVICE_ID="${DEVICE_ID:-dev1}" + + mkdir -p "$OUT_LOCAL_CROP" + + # 1)crop + if [ -f "$PROJECT_ROOT/src/crop_only.py" ]; then + cd "$PROJECT_ROOT" + $PY src/crop_only.py --input "$PWB_LOCAL" --out "$OUT_LOCAL_CROP" + elif [ -f "$PROJECT_ROOT/src/crop_from_meta.py" ]; then + cd "$PROJECT_ROOT" + $PY src/crop_from_meta.py --input "$PWB_LOCAL" --out "$OUT_LOCAL_CROP" + else + echo "[crop] No crop script found; will only sync if $OUT_LOCAL_CROP has files." + fi + + # 2) (: _TZ[ _suffix].ext) + + python -m pip install -q pillow piexif || true + export OUT_LOCAL_CROP + python - <<'PY' +import os, re, sys, time +from datetime import datetime, timezone +OUT = os.environ.get("OUT_LOCAL_CROP", "") +DEVICE = os.environ.get("DEVICE_ID", "dev1") + +if not OUT or not os.path.isdir(OUT): + sys.exit(0) + +IMG_EXT = {".jpg",".jpeg",".png",".webp",".tif",".tiff",".bmp"} +iso_re = re.compile(r"^[A-Za-z0-9\-]+_\d{8}T\d{6}Z(?:[ _][^/\\\\]+)?\\.[A-Za-z0-9]+$") + +def get_ts_from_exif(path): + try: + import piexif + from PIL import Image + with Image.open(path) as im: + exif = im.info.get("exif") + if not exif: + return None + exif_dict = piexif.load(exif) + dt = exif_dict["Exif"].get(piexif.ExifIFD.DateTimeOriginal) or \ + exif_dict["Exif"].get(piexif.ExifIFD.DateTimeDigitized) or \ + exif_dict["0th"].get(piexif.ImageIFD.DateTime) + if not dt: + return None + # EXIF: "YYYY:MM:DD HH:MM:SS" + s = dt.decode() if isinstance(dt, bytes) else dt + dt_obj = datetime.strptime(s, "%Y:%m:%d %H:%M:%S").replace(tzinfo=timezone.utc) + return dt_obj + except Exception: + return None + +def ts_for_file(path): + dt = get_ts_from_exif(path) + if dt is None: + # fallback: mtime כ-UTC + mt = os.path.getmtime(path) + dt = datetime.fromtimestamp(mt, tz=timezone.utc) + return dt + +renamed = 0 +skipped = 0 +for root, _, files in os.walk(OUT): + for f in files: + ext = os.path.splitext(f)[1].lower() + if ext not in IMG_EXT: + continue + if iso_re.match(f): + skipped += 1 + continue + old = os.path.join(root, f) + dt = ts_for_file(old) + ts = dt.strftime("%Y%m%dT%H%M%SZ") + # suffix + base = os.path.splitext(f)[0] + suffix = "" + if base and base.lower() not in {"img","image","photo","dsc","dscn"}: + + cleaned = re.sub(r"[^A-Za-z0-9._-]+", "-", base).strip("-_.") + if cleaned and cleaned != ts: + suffix = f"_{cleaned}" + new_name = f"{DEVICE}_{ts}{suffix}{ext}" + new = os.path.join(root, new_name) + if new == old: + skipped += 1 + continue + + i = 1 + new_final = new + while os.path.exists(new_final): + new_final = os.path.join(root, f"{os.path.splitext(new_name)[0]}_{i}{ext}") + i += 1 + os.rename(old, new_final) + print(f"[crop][rename] {f} -> {os.path.basename(new_final)}") + renamed += 1 + +print(f"[crop][rename] done: renamed={renamed}, already_ok={skipped}") +PY + + # 3)MinIO + pip install -q awscli || true + if [ -d "$OUT_LOCAL_CROP" ] && [ "$(ls -A "$OUT_LOCAL_CROP" || true)" ]; then + python -m awscli s3 sync "$OUT_LOCAL_CROP"/ "s3://$BUCKET/leaves/$RUN_ID_CROP/" --endpoint-url "$ENDPOINT_URL" + else + echo "[crop] WARNING: no local crops found to upload." + fi + + python -m awscli s3 ls "s3://$BUCKET/leaves/$RUN_ID_CROP/" --recursive --endpoint-url "$ENDPOINT_URL" || true + """, + params={ + "project_root": PROJECT_ROOT, + "python_bin": PYTHON_BIN, + "out_run": OUT_RUN, + "run_id_date": RUN_ID_DATE, + }, + env={"PYTHONUNBUFFERED": "1"}, +) + + + detection_jobs = DockerOperator( + task_id="detection_jobs", + image="detection-jobs:cpu-lts", + docker_url="unix://var/run/docker.sock", + api_version="auto", + auto_remove=True, + mount_tmp_dir=False, + working_dir="/app", + network_mode="ag_cloud", + environment={ + "MINIO_ENDPOINT": "{{ conn.minio_s3.extra_dejson.endpoint_url | default('http://minio-hot:9001') }}", + "AWS_ACCESS_KEY_ID": "{{ conn.minio_s3.login }}", + "AWS_SECRET_ACCESS_KEY": "{{ conn.minio_s3.password }}", + "AWS_S3_FORCE_PATH_STYLE": "true", + "AWS_DEFAULT_REGION": "us-east-1", + "DATABASE_URL": "postgresql+psycopg2://missions_user:pg123@postgres:5432/missions_db", + "USER": "root", + "HOME": "/root", + }, + command=[ + "/bin/bash","-lc", r''' +set -euo pipefail +echo "[DJ] START"; whoami; pwd; python3 -V +python3 -m pip install --no-cache-dir -q awscli || true + +RID='{{ dag_run.conf.get("run_id") or logical_date.in_timezone("Asia/Jerusalem").strftime("%Y/%m/%d/%H%M") }}' +BUCKET='{{ var.value.leaf_minio_bucket | default("imagery") }}' +SRC="s3://${BUCKET}/leaves/${RID}/crop/" +ENDPOINT="${MINIO_ENDPOINT:-http://minio-hot:9001}" + +mkdir -p /work/in /work/out +echo "[DJ] sync from ${SRC} via ${ENDPOINT}" +python3 -m awscli s3 cp --recursive "$SRC" /work/in --endpoint-url "$ENDPOINT" || true + +IN_DIR="/work/in" +READY_DIR="/work/in_ready" +DEVICE_ID="${DEVICE_ID:-dev1}" +rm -rf "$READY_DIR" && mkdir -p "$READY_DIR" + +while IFS= read -r -d '' f; do + base="$(basename "$f")" + stem="$(printf '%s\n' "$base" | sed -n 's/^\([A-Za-z0-9-]\+_[0-9]\{8\}T[0-9]\{6\}Z\).*/\1/p')" + [ -n "$stem" ] || { echo "[DJ][skip] no stem in $base"; continue; } + outdir="$READY_DIR/$stem" + mkdir -p "$outdir" + cp -p "$f" "$outdir/$base" +done < <(find "$IN_DIR" -type f -print0) + +echo "[DJ][ready] tree under: $READY_DIR" + +FLAT_DIR="/work/in_flat" +rm -rf "$FLAT_DIR" && mkdir -p "$FLAT_DIR" +find "$READY_DIR" -type f \( -iname '*.jpg' -o -iname '*.jpeg' -o -iname '*.png' -o -iname '*.webp' -o -iname '*.tif' -o -iname '*.tiff' -o -iname '*.bmp' \) -print0 \ +| while IFS= read -r -d '' f; do + base="$(basename "$f")" + + out="$FLAT_DIR/$base"; i=1 + while [ -e "$out" ]; do + ext="${base##*.}"; stem="${base%.*}" + out="$FLAT_DIR/${stem}_$i.$ext"; i=$((i+1)) + done + cp -p "$f" "$out" +done + +echo "[DJ][flat] files in $FLAT_DIR:" +ls -1 "$FLAT_DIR" | sed -n '1,50p' +export INPUT_DIR_FOR_RUNNER="$FLAT_DIR" +# === DB bootstrap: ensure required table exists === +python3 - <<'PY' +import os +from sqlalchemy import create_engine, text + + +ddl = """ +CREATE TABLE IF NOT EXISTS public.leaf_disease_types ( + id SERIAL PRIMARY KEY, + name TEXT UNIQUE NOT NULL +); + +CREATE TABLE IF NOT EXISTS public.leaf_reports ( + id BIGSERIAL PRIMARY KEY, + device_id TEXT NOT NULL, + leaf_disease_type_id INTEGER NOT NULL REFERENCES public.leaf_disease_types(id) ON DELETE RESTRICT, + ts TIMESTAMPTZ NOT NULL, + confidence DOUBLE PRECISION NOT NULL, + sick BOOLEAN NOT NULL +); + + +CREATE INDEX IF NOT EXISTS ix_leaf_reports_ts ON public.leaf_reports (ts); +CREATE INDEX IF NOT EXISTS ix_leaf_reports_type ON public.leaf_reports (leaf_disease_type_id); +CREATE INDEX IF NOT EXISTS ix_leaf_reports_device_ts ON public.leaf_reports (device_id, ts); +""" + +url = os.environ["DATABASE_URL"] +eng = create_engine(url, future=True) +with eng.begin() as conn: + conn.execute(text(ddl)) +print("[DJ][db] ensured table public.leaf_disease_types") +PY + + +export PYTHONPATH=/app +python3 - <<'PY' +import os, sys, importlib +os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL","")) +inp = os.environ.get('INPUT_DIR_FOR_RUNNER','/work/in_flat') +print("[DJ] runner input dir:", inp) + +try: + files = [f for f in os.listdir(inp) if os.path.isfile(os.path.join(inp,f))] + print(f"[DJ] flat file count: {len(files)}") + for f in files[:10]: + print("[DJ] sample:", f) +except Exception as e: + print("[DJ] listdir failed:", e) + +mod = importlib.import_module('agri_baseline.src.batch_runner') +sys.argv = ['batch_runner.py', '--input', inp, '--mission','1'] +exit_code = 0 +try: + mod.main() +except SystemExit as e: + exit_code = int(e.code) if isinstance(e.code, int) else 1 +sys.exit(exit_code) +PY + +echo "[DJ] DONE" + ''' + ], +) + + + disease_monitor = DockerOperator( + task_id="disease_monitor", + image="disease-monitor:cpu-lts", + entrypoint=["/bin/sh","-c"], + command=[r''' +set -eu +echo "[DM] START"; whoami; pwd; python3 -V || true + +echo "[DM][env] DATABASE_URL=$DATABASE_URL" +echo "[DM][env] Dropping PG* env if present (to avoid overrides)" +unset PGHOST PGPORT PGDATABASE PGPASSWORD PGUSER 2>/dev/null || true + +# ---- DDL via DATABASE_URL only ---- +python3 - <<'PY' +import os, sys, time +import psycopg2 + +DDL = """ +CREATE TABLE IF NOT EXISTS alerts_leaves ( + id bigserial PRIMARY KEY, + entity_id text NOT NULL, + rule text NOT NULL, + window_start timestamptz NOT NULL, + window_end timestamptz NOT NULL, + score double precision NOT NULL, + first_seen timestamptz NOT NULL, + last_seen timestamptz NOT NULL, + status text NOT NULL CHECK (status IN ('OPEN','ACK','RESOLVED')), + meta_json jsonb +); +CREATE INDEX IF NOT EXISTS ix_alerts_leaves_entity_rule ON alerts_leaves(entity_id, rule); +CREATE INDEX IF NOT EXISTS ix_alerts_leaves_status ON alerts_leaves(status); +""" + +dsn = os.environ["DATABASE_URL"].replace("postgresql+psycopg2://", "postgresql://", 1) +print("[DM][db] Using DSN:", dsn.replace(os.environ.get("DATABASE_URL",""), "***redacted***")) + +for i in range(12): + try: + with psycopg2.connect(dsn, connect_timeout=4) as conn: + with conn.cursor() as cur: + cur.execute(DDL) + print("[DM][db] DDL applied OK.") + break + except Exception as e: + print(f"[DM][db] retry {i+1}/12: {e}") + time.sleep(5) +else: + sys.exit("[DM][db] DDL failed after retries") +PY + +exec python -m disease_monitor.cli --config /app/configs/config.docker.yaml --log-level INFO + +'''], + environment={ + "DATABASE_URL": "postgresql://missions_user:pg123@postgres:5432/missions_db", + }, + working_dir="/app", + docker_url="unix://var/run/docker.sock", + api_version="auto", + auto_remove=True, + network_mode="ag_cloud", + dag=dag, +) + + + + stage_input >> detect >> pwb >> crop >> detection_jobs>>disease_monitor diff --git a/airflow_bundle/leaf-pipeline/airflow/webserver_config.py b/airflow_bundle/leaf-pipeline/airflow/webserver_config.py new file mode 100644 index 000000000..3048bb21f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/airflow/webserver_config.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default configuration for the Airflow webserver.""" + +from __future__ import annotations + +import os + +from flask_appbuilder.const import AUTH_DB + +# from airflow.www.fab_security.manager import AUTH_LDAP +# from airflow.www.fab_security.manager import AUTH_OAUTH +# from airflow.www.fab_security.manager import AUTH_OID +# from airflow.www.fab_security.manager import AUTH_REMOTE_USER + + +basedir = os.path.abspath(os.path.dirname(__file__)) + +# Flask-WTF flag for CSRF +WTF_CSRF_ENABLED = True +WTF_CSRF_TIME_LIMIT = None + +# ---------------------------------------------------- +# AUTHENTICATION CONFIG +# ---------------------------------------------------- +# For details on how to set up each of the following authentication, see +# http://flask-appbuilder.readthedocs.io/en/latest/security.html# authentication-methods +# for details. + +# The authentication type +# AUTH_OID : Is for OpenID +# AUTH_DB : Is for database +# AUTH_LDAP : Is for LDAP +# AUTH_REMOTE_USER : Is for using REMOTE_USER from web server +# AUTH_OAUTH : Is for OAuth +AUTH_TYPE = AUTH_DB + +# Uncomment to setup Full admin role name +# AUTH_ROLE_ADMIN = 'Admin' + +# Uncomment and set to desired role to enable access without authentication +# AUTH_ROLE_PUBLIC = 'Viewer' + +# Will allow user self registration +# AUTH_USER_REGISTRATION = True + +# The recaptcha it's automatically enabled for user self registration is active and the keys are necessary +# RECAPTCHA_PRIVATE_KEY = PRIVATE_KEY +# RECAPTCHA_PUBLIC_KEY = PUBLIC_KEY + +# Config for Flask-Mail necessary for user self registration +# MAIL_SERVER = 'smtp.gmail.com' +# MAIL_USE_TLS = True +# MAIL_USERNAME = 'yourappemail@gmail.com' +# MAIL_PASSWORD = 'passwordformail' +# MAIL_DEFAULT_SENDER = 'sender@gmail.com' + +# The default user self registration role +# AUTH_USER_REGISTRATION_ROLE = "Public" + +# When using OAuth Auth, uncomment to setup provider(s) info +# Google OAuth example: +# OAUTH_PROVIDERS = [{ +# 'name':'google', +# 'token_key':'access_token', +# 'icon':'fa-google', +# 'remote_app': { +# 'api_base_url':'https://www.googleapis.com/oauth2/v2/', +# 'client_kwargs':{ +# 'scope': 'email profile' +# }, +# 'access_token_url':'https://accounts.google.com/o/oauth2/token', +# 'authorize_url':'https://accounts.google.com/o/oauth2/auth', +# 'request_token_url': None, +# 'client_id': GOOGLE_KEY, +# 'client_secret': GOOGLE_SECRET_KEY, +# } +# }] + +# When using LDAP Auth, setup the ldap server +# AUTH_LDAP_SERVER = "ldap://ldapserver.new" + +# When using OpenID Auth, uncomment to setup OpenID providers. +# example for OpenID authentication +# OPENID_PROVIDERS = [ +# { 'name': 'Yahoo', 'url': 'https://me.yahoo.com' }, +# { 'name': 'AOL', 'url': 'http://openid.aol.com/' }, +# { 'name': 'Flickr', 'url': 'http://www.flickr.com/' }, +# { 'name': 'MyOpenID', 'url': 'https://www.myopenid.com' }] + +# ---------------------------------------------------- +# Theme CONFIG +# ---------------------------------------------------- +# Flask App Builder comes up with a number of predefined themes +# that you can use for Apache Airflow. +# http://flask-appbuilder.readthedocs.io/en/latest/customizing.html#changing-themes +# Please make sure to remove "navbar_color" configuration from airflow.cfg +# in order to fully utilize the theme. (or use that property in conjunction with theme) +# APP_THEME = "bootstrap-theme.css" # default bootstrap +# APP_THEME = "amelia.css" +# APP_THEME = "cerulean.css" +# APP_THEME = "cosmo.css" +# APP_THEME = "cyborg.css" +# APP_THEME = "darkly.css" +# APP_THEME = "flatly.css" +# APP_THEME = "journal.css" +# APP_THEME = "lumen.css" +# APP_THEME = "paper.css" +# APP_THEME = "readable.css" +# APP_THEME = "sandstone.css" +# APP_THEME = "simplex.css" +# APP_THEME = "slate.css" +# APP_THEME = "solar.css" +# APP_THEME = "spacelab.css" +# APP_THEME = "superhero.css" +# APP_THEME = "united.css" +# APP_THEME = "yeti.css" diff --git a/airflow_bundle/leaf-pipeline/dags/configs/disease_monitor.yaml b/airflow_bundle/leaf-pipeline/dags/configs/disease_monitor.yaml new file mode 100644 index 000000000..832d6daf7 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/dags/configs/disease_monitor.yaml @@ -0,0 +1,68 @@ +io: + # IMPORTANT: use the Docker service name of Postgres (from your compose): + postgres_url: "postgresql+psycopg2://missions_user:pg123@postgres:5432/missions_db" + +windows: + frequency: "D" + timezone: "UTC" + +source_mapping: + entity_dim: "device" + area_strategy: "none" # or "region_area" (requires regions table/geom) + filters: + start_time: null + end_time: null + anomaly_codes: null + +baseline: + method: "median" + lookback_periods: 28 + min_history: 7 + seasonality: null + +rules: + count_anomaly: + enabled: true + method: "zscore" + z_threshold: 3.0 + iqr_k: 1.5 + min_count: 3 + worsening: + enabled: true + method: "slope" + slope_lookback: 7 + slope_min: 0.02 + min_periods: 5 + ewma_span: 7 + ewma_threshold: 0.6 + +alerting: + dedup_cooldown_windows: 3 + resolve_after_no_anomaly: 3 + rate_limit_per_run: 100 + group_by_window: true + +delivery: + kafka: + enabled: false + brokers: "kafka:9092" + topic: "alerts" + slack: + enabled: false + webhook_url: "" + webhook: + enabled: false + url: "" + headers: {} + email: + enabled: false + smtp_host: "" + smtp_port: 587 + username: "" + password_env: "SMTP_PASSWORD" + from_addr: "" + to_addrs: [] + +run: + dry_run: false + diff --git a/airflow_bundle/leaf-pipeline/docker-compose.yml b/airflow_bundle/leaf-pipeline/docker-compose.yml new file mode 100644 index 000000000..5ededec6f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/docker-compose.yml @@ -0,0 +1,73 @@ + +version: "3.8" + +x-airflow-common: &airflow-common + image: leaf-airflow:2.9.3-fixed + build: + context: . + dockerfile: Dockerfile + environment: + AIRFLOW__CORE__LOAD_EXAMPLES: "False" + AIRFLOW__CORE__EXECUTOR: "SequentialExecutor" + AIRFLOW_HOME: /opt/airflow + LEAF_MINIO_ENDPOINT: "http://minio-hot:9000" + volumes: + - ./airflow:/opt/airflow + - ./projects/leaf-counting:/opt/leaf-pipeline/projects/leaf-counting + - /var/run/docker.sock:/var/run/docker.sock + user: "${AIRFLOW_UID:-50000}:${AIRFLOW_GID:-0}" + restart: unless-stopped + +services: + # --- Build-only images for DockerOperator tasks (exit 0 right away) --- + build_detection_jobs: + profiles: ["images"] + image: detection-jobs:cpu-lts + build: + context: ./projects/Detection_Jobs/Detection_Jobs + dockerfile: dockerfile + command: ["sh", "-c", "echo built detection-jobs && true"] + restart: "no" + networks: [agcloud_ag_cloud] + + build_disease_monitor: + profiles: ["images"] + image: disease-monitor:cpu-lts + build: + context: ./projects/disease-monitor/disease-monitor + dockerfile: Dockerfile + entrypoint: ["/bin/sh", "-c"] + command: ["sh", "-c", "echo built disease-monitor && true"] + restart: "no" + networks: [agcloud_ag_cloud] + + # --- Airflow runtime --- + scheduler: + <<: *airflow-common + command: ["airflow", "scheduler"] + user: "0:0" + depends_on: + build_detection_jobs: + condition: service_completed_successfully + build_disease_monitor: + condition: service_completed_successfully + networks: [agcloud_ag_cloud] + + webserver: + <<: *airflow-common + command: ["airflow", "webserver"] + user: "0:0" + ports: + - "8081:8080" + depends_on: + scheduler: + condition: service_started + build_detection_jobs: + condition: service_completed_successfully + build_disease_monitor: + condition: service_completed_successfully + networks: [agcloud_ag_cloud] + +networks: + agcloud_ag_cloud: + external: true diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/.gitignore b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/.gitignore new file mode 100644 index 000000000..06593e32e --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/.gitignore @@ -0,0 +1,55 @@ +# ==== OS / IDE ==== +.DS_Store +Thumbs.db +.vscode/ +.idea/ + +# ==== Node ==== +node_modules/ +dist/ + +# ==== Python ==== +__pycache__/ +*.py[cod] +*.pyc +*.pyo +*.so +*.dylib + +# ==== Virtual envs ==== +.venv/ +venv/ +ENV/ +env/ + +# ==== Packaging / build ==== +build/ +*.egg-info/ + +# ==== Environment / Secrets ==== +.env +.env.* + +# ==== Data / Notebooks / Logs ==== +*.log +*.ipynb +.ipynb_checkpoints/ + +# ==== Artifacts / Wheels / Models ==== +artifacts/ +.wheels/ +wheels/ +*.whl +*.pt +*.pth +*.bin + +# ==== Coverage reports ==== +.pytest_cache/ +.coverage +coverage.xml +htmlcov/ + +# ==== gRPC generated (נוצרים בבילד דוקר) ==== +server/embed_pb2.py +server/embed_pb2_grpc.py diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/.dockerignore b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/.dockerignore new file mode 100644 index 000000000..641f56876 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/.dockerignore @@ -0,0 +1,33 @@ +# Python cache +__pycache__/ +*.pyc +*.pyo + +# Virtual environments +.env +.venv/ +venv/ + +# IDE +.idea/ +.vscode/ + +# Node / Frontend +node_modules/ +dist/ + +# Test / Coverage +.pytest_cache/ +.coverage +coverage.xml +htmlcov/ + +# Local databases +*.db +agri.db + +# Data outputs +data/ +data_balanced/ +data_baseline/ +*.csv diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/.gitignore b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/.gitignore new file mode 100644 index 000000000..b0e9b0028 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/.gitignore @@ -0,0 +1,25 @@ +# === Python cache === +__pycache__/ +*.pyc +*.pyo + +# === Virtual environments === +.env +.venv/ +venv/ + +# === IDE / Editors === +.idea/ +.vscode/ + + +# === Test / Coverage === +.pytest_cache/ + +# === Local databases === +*.db +agri.db + +# === Data outputs === +data_balanced/ +data/ \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/README.md b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/README.md new file mode 100644 index 000000000..b2e46aec5 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/README.md @@ -0,0 +1,115 @@ +🌿 Agri Baseline – Disease Detection Pipeline + +This project runs an end-to-end disease detection pipeline for agricultural images. +It supports both local and MinIO-based storage backends, and processes entire folders of plant images using trained CNN models. + +🚀 Quick Start +1️⃣ Setup Environment +cp agri_baseline/.env.example agri_baseline/.env +pip install -r agri_baseline/requirements.txt + +2️⃣ Run the Pipeline + +Now the pipeline fetches images directly from MinIO, not from a local folder. + +docker compose up -d +docker compose logs -f app + + +The service automatically connects to your configured MinIO bucket, downloads the images to a cache directory, and processes them. + +3️⃣ Run Tests + +To verify the system: + +docker compose run --rm app pytest -q + +📂 Project Structure +Detection_Jobs/ +│ +├── agri_baseline/ +│ ├── scripts/ +│ │ └── run_batch.py # Run the pipeline on MinIO or local images +│ │ +│ ├── src/ +│ │ ├── detectors/ # CNN models and detectors +│ │ │ ├── base.py # Base Detector/Detection classes +│ │ │ ├── cnn_multi_classifier.py +│ │ │ ├── disease_model.py # Wraps CNN model as a Detector +│ │ │ ├── train/ +│ │ │ │ └── dictionary.py +│ │ │ +│ │ ├── pipeline/ +│ │ │ ├── config.py +│ │ │ ├── db.py # DB connection via SQLAlchemy +│ │ │ ├── logging_setup.py +│ │ │ └── utils.py # Helper functions (image loading, bbox, etc.) +│ │ │ +│ │ ├── storage/ +│ │ │ ├── minio_client.py +│ │ │ └── minio_sync.py # MinIO download helpers +│ │ │ +│ │ └── validator/ +│ │ ├── rules.py # Validation rules +│ │ └── validator.py # QA manager, writes to event logs +│ │ +│ ├── batch_runner.py # Orchestrates the full pipeline +│ ├── .env # Local config (not committed) +│ ├── .env.example # Example configuration file +│ ├── requirements.txt # Python dependencies +│ └── README.md +│ +├── models/ # Trained model weights (not in git) +│ ├── resnet18-f37072fd.pth +│ ├── cnn_multi_stage3.pth +│ └── multi_classes.pth +│ +├── docker-compose.yml # Runs pipeline + MinIO connection +├── dockerfile +├── tests/ # Unit and integration tests +│ ├── test_batch_runner.py +│ ├── test_disease_model.py +│ ├── test_run_detectors.py +│ ├── test_utils.py +│ └── test_validator.py +│ +└── ressearch/ # Experimental models and training + ├── detectors/ + │ ├── models/ + │ │ ├── cnn_binary.pth + │ │ ├── cnn_multi_finetuned.pth + │ │ └── cnn_multi.pth + │ ├── train/ + │ │ ├── disease.py + │ │ ├── eval_multi_levels.py + │ │ ├── finetune_multi_stage3.py + │ │ ├── finetune_multi.py + │ │ └── train_binary_multi.py + │ ├── cnn_binary_classifier.py + │ └── dataset_binary.py + +🧩 Models + +All trained models are stored under models/ and are not committed to Git: + +cnn_multi.pth – Base multi-class CNN + +cnn_multi_finetuned.pth – Fine-tuned on additional data + +cnn_multi_stage3.pth – Advanced fine-tuning with crop-specific data + +multi_classes.pth – Unified class mapping + +🧪 Testing + +Run all integration and unit tests using Docker: + +docker compose run --rm app pytest -q + +📌 Notes + +The pipeline now supports MinIO integration via environment variables in .env. + +Make sure your .env file includes all required MINIO_* variables (endpoint, bucket, credentials). + +Avoid committing .env or model files to the repository. \ No newline at end of file diff --git a/services/sounds/API-development/src/__init__.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/__init__.py similarity index 100% rename from services/sounds/API-development/src/__init__.py rename to airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/__init__.py diff --git a/services/sounds/API-development/src/backend/__init__.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/__init__.py similarity index 100% rename from services/sounds/API-development/src/backend/__init__.py rename to airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/__init__.py diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/requirements.txt b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/requirements.txt new file mode 100644 index 000000000..8df4df0ba --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/requirements.txt @@ -0,0 +1,47 @@ +# ---------------------------- +# Core scientific stack +# ---------------------------- +numpy==1.26.4 +pandas==2.0.3 +scipy==1.12.0 + +# ---------------------------- +# Image processing +# ---------------------------- +opencv-python-headless==4.9.0.80 +Pillow==10.4.0 +albumentations==1.4.3 + +# ---------------------------- +# Database & configuration +# ---------------------------- +SQLAlchemy==1.4.52 +psycopg2-binary==2.9.9 +python-dotenv==1.0.1 +minio==7.2.9 # MinIO SDK for connecting to object storage + +# ---------------------------- +# Testing +# ---------------------------- +pytest + +# ---------------------------- +# Typing helpers +# ---------------------------- +typing-extensions>=4.9.0 +# Deep learning frameworks +torch==2.2.0 +torchvision==0.17.0 +torchaudio==2.2.0 + + +# ---------------------------- +# Training & monitoring tools +# ---------------------------- +tensorboard>=2.16 + +# ---------------------------- +# Visualization & ML utilities +# ---------------------------- +matplotlib>=3.7 +scikit-learn>=1.3 diff --git a/services/sounds/API-development/tests/__init__.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/scripts/__init__.py similarity index 100% rename from services/sounds/API-development/tests/__init__.py rename to airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/scripts/__init__.py diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/scripts/run_batch.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/scripts/run_batch.py new file mode 100644 index 000000000..666ed466e --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/scripts/run_batch.py @@ -0,0 +1,137 @@ +""" +run_batch.py + +Purpose: +- Run the disease-detection batch pipeline either from a LOCAL folder of images + or from a MinIO bucket (objects are first downloaded to a local cache dir, + then processed exactly like local files). + +Usage examples: +1) Local folder (backward-compatible): + python -m agri_baseline.scripts.run_batch --storage local --images ./data/images + +2) MinIO (reads config from ENV and optional CLI flags): + python -m agri_baseline.scripts.run_batch --storage minio --minio-prefix "" + +Environment variables (typical .env): +- STORAGE_BACKEND=minio|local +- MINIO_ENDPOINT=127.0.0.1:9000 +- MINIO_ACCESS_KEY=minioadmin +- MINIO_SECRET_KEY=minioadmin +- MINIO_BUCKET=leaves +- MINIO_SECURE=false +- MINIO_PREFIX=mission-123/ (optional) +- MINIO_CACHE_DIR=./data/_minio_cache +""" + +import argparse +import os +from pathlib import Path + +from agri_baseline.src.pipeline.logging_setup import setup_logging +from agri_baseline.src.pipeline import config +from agri_baseline.src.batch_runner import BatchRunner + +# MinIO helpers provided in your project +from agri_baseline.src.storage.minio_client import load_minio_config # loads config from ENV +from agri_baseline.src.storage.minio_sync import download_prefix_to_dir, ensure_bucket + + +def run_local(images_dir: Path) -> None: + """ + LOCAL mode: + - Run the batch pipeline over a local folder of images. + - This preserves the original behavior for backward compatibility. + """ + runner = BatchRunner() + runner.run_folder(images_dir) + + +def run_minio(prefix: str, cache_dir: Path) -> None: + """ + MINIO mode: + - Pull objects from a MinIO bucket (based on ENV config). + - Download them to a local cache directory. + - Run the batch pipeline over the downloaded files. + """ + cfg = load_minio_config() + ensure_bucket(cfg) # Safety: create the bucket if it doesn't exist + + cache_dir.mkdir(parents=True, exist_ok=True) + + # Download objects under 'prefix' into the local cache folder + downloaded = download_prefix_to_dir(cfg, prefix=prefix, local_dir=cache_dir) + if not downloaded: + raise SystemExit( + f"No objects found in bucket '{cfg.bucket}' with prefix '{prefix}'." + ) + + runner = BatchRunner() + runner.run_folder(cache_dir) + + +def parse_args() -> argparse.Namespace: + """ + Parse CLI arguments and provide sensible defaults from ENV where applicable. + """ + ap = argparse.ArgumentParser(description="Run batch pipeline (local/minio).") + + # Backward-compatible local images folder + ap.add_argument( + "--images", + default=config.IMAGES_DIR, + help="Folder of input images (LOCAL mode)", + ) + + # Storage backend selector + ap.add_argument( + "--storage", + choices=["local", "minio"], + default=os.getenv("STORAGE_BACKEND", "local").lower(), + help="Where to read images from (local|minio).", + ) + + # MinIO options (with ENV fallbacks) + ap.add_argument( + "--minio-prefix", + default=os.getenv("MINIO_PREFIX", ""), + help="Object prefix inside the bucket (e.g. 'mission-123/').", + ) + ap.add_argument( + "--minio-cache", + default=os.getenv("MINIO_CACHE_DIR", "./data/_minio_cache"), + help="Local temp folder used to download MinIO objects before processing.", + ) + + return ap.parse_args() + + +def main() -> None: + """ + Entry point: + - Logs chosen backend. + - Dispatches to local/minio flows. + - Keeps logs concise and informative for CI/ops. + """ + log = setup_logging() + args = parse_args() + + log.info(f"Storage backend: {args.storage}") + + if args.storage == "local": + images_dir = Path(args.images) + log.info(f"Starting batch over LOCAL folder: {images_dir}") + run_local(images_dir) + log.info("Batch done (local).") + else: + cache_dir = Path(args.minio_cache) + log.info( + "Starting batch over MINIO: " + f"bucket from ENV, prefix='{args.minio_prefix}', cache='{cache_dir}'" + ) + run_minio(prefix=args.minio_prefix, cache_dir=cache_dir) + log.info("Batch done (minio).") + + +if __name__ == "__main__": + main() diff --git a/services/sounds/compression/scripts/__init__.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/__init__.py similarity index 100% rename from services/sounds/compression/scripts/__init__.py rename to airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/__init__.py diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/batch_runner.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/batch_runner.py new file mode 100644 index 000000000..6eee8f27d --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/batch_runner.py @@ -0,0 +1,683 @@ +# # agri_baseline/src/batch_runner.py +# # Max line length: 100 + +# from __future__ import annotations + +# import json +# from dataclasses import asdict, is_dataclass +# from datetime import datetime, timezone +# from pathlib import Path +# from typing import Tuple + +# from agri_baseline.src.pipeline.utils import ( +# load_image, +# image_id_from_path, +# clamp_bbox, +# ) +# from agri_baseline.src.pipeline.db import ( +# get_engine, +# INSERT_DET, +# INSERT_COUNT, +# INSERT_QA, +# ) +# from agri_baseline.src.detectors.disease_model import DiseaseDetector + + +# class BatchRunner: +# """ +# End-to-end runner: +# - Load image +# - Run disease detector +# - Normalize detections +# - Write anomalies / counts / QA to RelDB +# """ + +# def __init__(self, mission_id: int = 1, device_id: str = "device-1") -> None: +# self.mission_id = mission_id +# self.device_id = device_id # TEXT FK per schema v2 +# self.engine = get_engine() +# self.detector = DiseaseDetector() + +# # ---------------------------- +# # Public API +# # ---------------------------- + +# def run_folder(self, folder: Path | str) -> None: +# """ +# Run pipeline on all images within a folder (non-recursive). +# Skips non-image files; prints minimal info. +# """ +# folder = Path(folder) +# assert folder.exists(), f"Folder not found: {folder.resolve()}" + +# image_paths = sorted( +# p for p in folder.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"} +# ) + +# total = 0 +# total_dets = 0 +# for img_path in image_paths: +# try: +# n = self.process_image(img_path) +# total += 1 +# total_dets += n +# except Exception as ex: +# # Keep output tidy; prefer structured logging in production +# print(f"[WARN] Failed on {img_path.name}: {ex}") + +# # Record a small QA summary +# qa = { +# "images_processed": total, +# "detections_total": total_dets, +# "ts": datetime.now(timezone.utc).isoformat(timespec="seconds"), +# } +# with self.engine.begin() as conn: +# conn.execute(INSERT_QA, {"details": json.dumps(qa)}) + +# def process_image(self, img_path: Path | str) -> int: +# """ +# Run pipeline on a single image, write detections and a simple per-image score. +# Returns number of detections written. +# """ +# img_path = Path(img_path) +# img, W, H = load_image(img_path) + +# image_id = image_id_from_path(img_path) +# dets = self.detector.run(img) + +# print(f"{image_id}: found {len(dets)} disease spots") + +# # Write detections as anomalies +# written = 0 +# for d in dets: +# x, y, w, h = self._extract_bbox(d) +# x, y, w, h = clamp_bbox(int(x), int(y), int(w), int(h), W, H) +# cx = x + w / 2.0 +# cy = y + h / 2.0 + +# area = float(getattr(d, "area", w * h)) +# label = str(getattr(d, "label", "disease")) +# conf = float(getattr(d, "confidence", 1.0)) + +# details = { +# "image_id": image_id, +# "label": label, +# "bbox": [x, y, w, h], +# "area": area, +# "confidence": conf, +# } +# if is_dataclass(d): +# details["raw_detection"] = asdict(d) + +# with self.engine.begin() as conn: +# conn.execute( +# INSERT_DET, +# dict( +# mission_id=self.mission_id, +# device_id=self.device_id, # TEXT FK +# ts=datetime.now(timezone.utc), +# anomaly_type_id=1, # seeded below +# severity=conf, +# details=json.dumps(details), +# wkt_geom=f"POINT({cx} {cy})", +# ), +# ) +# written += 1 + +# # Per-image score → tile_stats (tile_id TEXT, geom POLYGON) +# if dets: +# anomaly_score = float(len(dets)) +# poly_wkt = self._make_square_polygon_wkt(W / 2.0, H / 2.0, size=1.0) +# with self.engine.begin() as conn: +# conn.execute( +# INSERT_COUNT, +# dict( +# mission_id=self.mission_id, +# tile_id=image_id, # TEXT per schema v2 +# anomaly_score=anomaly_score, +# wkt_geom=poly_wkt, # POLYGON +# ), +# ) + +# return written + +# # ---------------------------- +# # Internals +# # ---------------------------- + +# @staticmethod +# def _extract_bbox(d) -> Tuple[float, float, float, float]: +# """ +# Normalize bbox to (x, y, w, h). Supports: +# - d.x, d.y, d.w, d.h +# - d.bbox == (x, y, w, h) +# - d.xmin, d.ymin, d.xmax, d.ymax +# - d.left, d.top, d.width, d.height +# """ +# if all(hasattr(d, a) for a in ("x", "y", "w", "h")): +# return float(d.x), float(d.y), float(d.w), float(d.h) + +# if hasattr(d, "bbox"): +# bx = list(d.bbox) +# if len(bx) != 4: +# raise ValueError(f"Unexpected bbox length: {len(bx)} in {bx}") +# x, y, w, h = map(float, bx) +# return x, y, w, h + +# if all(hasattr(d, a) for a in ("xmin", "ymin", "xmax", "ymax")): +# x1, y1, x2, y2 = float(d.xmin), float(d.ymin), float(d.xmax), float(d.ymax) +# return x1, y1, max(0.0, x2 - x1), max(0.0, y2 - y1) + +# if all(hasattr(d, a) for a in ("left", "top", "width", "height")): +# return float(d.left), float(d.top), float(d.width), float(d.height) + +# raise AttributeError( +# "Detection bbox fields missing. Supported: " +# "(x,y,w,h) or bbox or (xmin,ymin,xmax,ymax) or (left,top,width,height)." +# ) + +# @staticmethod +# def _make_square_polygon_wkt(cx: float, cy: float, size: float = 1.0) -> str: +# """ +# Build a tiny square Polygon around (cx, cy) in WKT, closed ring. +# PostGIS expects Polygon for tile_stats.geom (SRID 4326). +# """ +# x1, y1 = cx - size, cy - size +# x2, y2 = cx + size, cy + size +# return f"POLYGON(({x1} {y1}, {x2} {y1}, {x2} {y2}, {x1} {y2}, {x1} {y1}))" + + +# # ------------- CLI helper ------------- + +# # def main() -> None: +# # """ +# # Local runner: +# # python -m agri_baseline.src.batch_runner --input +# # """ +# # import argparse + +# # parser = argparse.ArgumentParser(description="Run disease detection pipeline.") +# # parser.add_argument("--log-level", default="INFO", help="logging level (ignored by runner)") + +# # parser.add_argument("--input", type=str, required=True, help="Image file or folder") +# # parser.add_argument("--mission", type=int, default=1, help="Numeric mission ID") +# # parser.add_argument("--device", type=str, default="device-1", help="Text device ID") +# # args = parser.parse_args() + +# # runner = BatchRunner(mission_id=args.mission, device_id=args.device) +# # in_path = Path(args.input) +# # if in_path.is_dir(): +# # runner.run_folder(in_path) +# # else: +# # runner.process_image(in_path) + + +# # if __name__ == "__main__": +# # main() +# def main() -> None: +# """ +# Local runner: +# python -m agri_baseline.src.batch_runner --input +# """ +# import argparse + +# parser = argparse.ArgumentParser(description="Run disease detection pipeline.") +# parser.add_argument("--log-level", type=str, default="INFO", +# help="logging level (ignored by runner)") + +# parser.add_argument("--input", type=str, required=True, +# help="Image file or folder") +# # קולט גם מחרוזת וגם מספר, וממיר ל-int תקני +# parser.add_argument("--mission", type=str, default="baseline", +# help=f"Mission name/id ({', '.join(MISSION_ALIASES)} or numeric id)") +# parser.add_argument("--device", type=str, default="cpu", +# choices=["cpu", "cuda"], +# help="device to use") + +# args = parser.parse_args() + +# mission_id = parse_mission(args.mission) + +# in_path = Path(args.input) +# if not in_path.exists(): +# raise FileNotFoundError(f"input does not exist: {in_path}") +# if in_path.is_dir(): +# # אופציונלי: הגנה על תיקייה ריקה +# has_files = any(in_path.rglob("*")) +# if not has_files: +# raise RuntimeError(f"input folder is empty: {in_path}") + +# runner = BatchRunner(mission_id=mission_id, device_id=args.device) +# if in_path.is_dir(): +# runner.run_folder(in_path) +# else: +# runner.process_image(in_path) + +# if __name__ == "__main__": +# main() +# agri_baseline/src/batch_runner.py +# Max line length: 100 + +from __future__ import annotations + +from sqlalchemy import text +import os +import re +import json +from dataclasses import asdict, is_dataclass +from datetime import datetime, timezone, timedelta +from pathlib import Path +from typing import Tuple + +from agri_baseline.src.pipeline.utils import ( + load_image, + image_id_from_path, + clamp_bbox, +) +from agri_baseline.src.pipeline.db import ( + get_engine, +) +from agri_baseline.src.detectors.disease_model import DiseaseDetector + +# ----------------------------------- +# SQL +# ----------------------------------- + +# anomalies insert (unchanged) +INSERT_ANOMALY = text( + """ + INSERT INTO public.anomalies + (mission_id, device_id, ts, anomaly_type_id, severity, details, geom) + VALUES + ( + :mission_id, + :device_id, + :ts, + :anomaly_type_id, + :severity, + CAST(:details AS JSONB), + ST_SetSRID(ST_GeomFromText(:wkt_geom), 4326) + ) + """ +) + +# NEW: leaf_reports insert (always written) +INSERT_LEAF_REPORT = text( + """ + INSERT INTO public.leaf_reports + (device_id, leaf_disease_type_id, ts, confidence, sick) + VALUES + (:device_id, :leaf_disease_type_id, :ts, :confidence, :sick) + """ +) + +# NEW: upsert/get id for leaf_disease_types by name (case-insensitive) +UPSERT_LEAF_DISEASE_TYPE = text( + """ + WITH ins AS ( + INSERT INTO public.leaf_disease_types (name) + VALUES (:name) + ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name + RETURNING id + ) + SELECT id FROM ins + UNION ALL + SELECT id FROM public.leaf_disease_types WHERE name = :name + LIMIT 1 + """ +) + +INSERT_MISSION_FULL = text( + """ + INSERT INTO public.missions (mission_id, start_time, end_time, area_geom) + VALUES ( + :mission_id, + :start_time, + :end_time, + ST_SetSRID(ST_GeomFromText(:wkt_poly), 4326) + ) + ON CONFLICT (mission_id) DO NOTHING + """ +) + + +class BatchRunner: + """ + End-to-end runner: + - Parse device & timestamp from file name: _TZ[ _suffix].ext + - Run disease detector + - ALWAYS write a row into public.leaf_reports for each detection + - Write into public.anomalies ONLY if label is 'sick' (i.e., does NOT contain 'healthy') + - Ensure supporting FKs exist (devices:, missions: fixed 60, leaf_disease_types:) + + Notes: + * mission_id is fixed to 60 per requirement. + * geom is the pixel-center point of the detection bbox (WKT, SRID 4326). + """ + + # Fixed mission per request + FIXED_MISSION_ID = 60 + + def __init__(self, mission_id: int | None = None, device_id: str = "device-1") -> None: + # mission_id ignored; always use 60, but keep signature for CLI compatibility + self.mission_id = BatchRunner.FIXED_MISSION_ID + self.fallback_device_id = device_id # used only if filename parsing fails + self.engine = get_engine() + self.detector = DiseaseDetector() + + # anomaly_types entry for LEAF_DISEASE (used only for anomalies table) + self.leaf_anomaly_type_id = self._ensure_anomaly_type( + code="LEAF_DISEASE", description="Leaf disease detected" + ) + + # ---------------------------- + # Public API + # ---------------------------- + + @staticmethod + def _parse_device_and_ts_from_name(img_path: Path) -> tuple[str, datetime]: + """ + Accepts: + _TZ. + _TZ_. + Returns (device_id, ts_utc). Raises ValueError if the pattern doesn't match. + """ + stem = img_path.stem + parts = stem.split("_") + if len(parts) < 2: + raise ValueError( + f"Filename '{img_path.name}' must be '_TZ[ _suffix].ext'" + ) + device = parts[0] + ts_str = parts[1] + if not re.fullmatch(r"\d{8}T\d{6}Z", ts_str): + raise ValueError( + f"Filename '{img_path.name}' must include timestamp as TZ" + ) + ts = datetime.strptime(ts_str, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc) + return device, ts + + def run_folder(self, folder: Path | str) -> None: + """ + Run pipeline on all images within a folder (non-recursive). + """ + folder = Path(folder) + assert folder.exists(), f"Folder not found: {folder.resolve()}" + + image_paths = sorted( + p for p in folder.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"} + ) + + total, total_dets = 0, 0 + for img_path in image_paths: + try: + n = self.process_image(img_path) + total += 1 + total_dets += n + except Exception as ex: + print(f"[WARN] Failed on {img_path.name}: {ex}") + + print(f"Processed {total} images, wrote {total_dets} detections") + + def process_image(self, img_path: Path | str) -> int: + """ + Run pipeline on a single image and insert rows into leaf_reports (always) + and anomalies (only if sick). Returns number of detections processed. + """ + img_path = Path(img_path) + # img_path = Path(img_path) + +# Parse from filename (with fallback for your current crop file names) + try: + device_id, det_ts = self._parse_device_and_ts_from_name(img_path) + except Exception: + device_id = self.fallback_device_id + # timestamp: file mtime if available, otherwise now (UTC) + try: + det_ts = datetime.fromtimestamp(img_path.stat().st_mtime, tz=timezone.utc) + except Exception: + det_ts = datetime.now(timezone.utc) + + # Parse from filename + device_id, det_ts = self._parse_device_and_ts_from_name(img_path) + + # Ensure FKs exist + self._ensure_device(device_id) + self._ensure_mission_full(self.mission_id, det_ts) + + # Load image & run detector + img, W, H = load_image(img_path) + image_id = image_id_from_path(img_path) + dets = self.detector.run(img) + + print(f"{image_id}: found {len(dets)} detections") + + written = 0 + for d in dets: + x, y, w, h = self._extract_bbox(d) + x, y, w, h = clamp_bbox(int(x), int(y), int(w), int(h), W, H) + cx = x + w / 2.0 + cy = y + h / 2.0 + + area = float(getattr(d, "area", w * h)) + label = str(getattr(d, "label", "disease")) + conf = float(getattr(d, "confidence", 1.0)) + + # Build details JSON (used only in anomalies) + details = { + "image_id": image_id, + "label": label, + "bbox": [x, y, w, h], + "area": area, + "confidence": conf, + "device_id": device_id, + "ts": det_ts.isoformat(), + } + minio_url = self._minio_url(img_path) + if minio_url: + details["minio_url"] = minio_url + details.setdefault("crop_type", None) + details.setdefault("disease_type", label) + if is_dataclass(d): + details["raw_detection"] = asdict(d) + + # Decide sick/healthy by label + sick = not self._is_healthy_label(label) + + # Map label → disease_type_name (part after "__" if present) + disease_type_name = self._disease_type_from_label(label) + + with self.engine.begin() as conn: + # ensure disease type exists and get id + leaf_type_id = self._ensure_leaf_disease_type(conn, disease_type_name) + + # 1) ALWAYS insert a leaf report + conn.execute( + INSERT_LEAF_REPORT, + dict( + device_id=device_id, + leaf_disease_type_id=leaf_type_id, + ts=det_ts, + confidence=conf, + sick=sick, + ), + ) + + # 2) Insert anomaly ONLY if sick + if sick: + conn.execute( + INSERT_ANOMALY, + dict( + mission_id=self.mission_id, + device_id=device_id, + ts=det_ts, + anomaly_type_id=self.leaf_anomaly_type_id, + severity=conf, + details=json.dumps(details), + wkt_geom=f"POINT({cx} {cy})", + ), + ) + + written += 1 + + return written + + # ---------------------------- + # Internals + # ---------------------------- + + @staticmethod + def _is_healthy_label(label: str) -> bool: + """Return True if label contains 'healthy' (case-insensitive).""" + return "healthy" in label.lower() + + @staticmethod + def _disease_type_from_label(label: str) -> str: + """ + Extract disease type token from label. If label contains 'a__b', return 'b'; else return label. + Keeps underscores as-is for consistency with the model outputs. + """ + if "__" in label: + return label.split("__", 1)[1] + return label + + def _ensure_anomaly_type(self, code: str, description: str) -> int: + """Return anomaly_type_id for `code`, inserting if needed (idempotent).""" + with self.engine.begin() as conn: + row = conn.execute( + text("SELECT anomaly_type_id FROM public.anomaly_types WHERE code = :c"), + {"c": code}, + ).first() + if row: + return int(row[0]) + + row = conn.execute( + text( + """ + INSERT INTO public.anomaly_types (code, description) + VALUES (:c, :d) + ON CONFLICT (code) + DO UPDATE SET description = EXCLUDED.description + RETURNING anomaly_type_id + """ + ), + {"c": code, "d": description}, + ).first() + return int(row[0]) + + def _ensure_leaf_disease_type(self, conn, name: str) -> int: + """ + Ensure a row exists in public.leaf_disease_types for the given name and return its id. + Uses an upsert with RETURNING to be idempotent. + """ + row = conn.execute(UPSERT_LEAF_DISEASE_TYPE, {"name": name}).first() + return int(row[0]) + + def _ensure_device(self, device_id: str) -> None: + """Ensure a row exists in public.devices (TEXT PK/UNIQUE).""" + with self.engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO public.devices (device_id) + VALUES (:d) + ON CONFLICT (device_id) DO NOTHING + """ + ), + {"d": device_id}, + ) + + def _ensure_mission_full(self, mission_id: int, ts: datetime) -> None: + """ + Ensure mission row exists and matches your table shape. + If not exists: start_time=ts, end_time=ts+1h, area=default 1x1° square near (0,0). + """ + with self.engine.begin() as conn: + exists = conn.execute( + text("SELECT 1 FROM public.missions WHERE mission_id = :id"), + {"id": mission_id}, + ).first() + if exists: + return + start = ts + end = ts + timedelta(hours=1) + wkt_poly = "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))" + conn.execute( + INSERT_MISSION_FULL, + { + "mission_id": mission_id, + "start_time": start, + "end_time": end, + "wkt_poly": wkt_poly, + }, + ) + + @staticmethod + def _extract_bbox(d) -> Tuple[float, float, float, float]: + """ + Normalize bbox to (x, y, w, h). Supports multiple field layouts. + """ + if all(hasattr(d, a) for a in ("x", "y", "w", "h")): + return float(d.x), float(d.y), float(d.w), float(d.h) + + if hasattr(d, "bbox"): + bx = list(d.bbox) + if len(bx) != 4: + raise ValueError(f"Unexpected bbox length: {len(bx)} in {bx}") + x, y, w, h = map(float, bx) + return x, y, w, h + + if all(hasattr(d, a) for a in ("xmin", "ymin", "xmax", "ymax")): + x1, y1, x2, y2 = float(d.xmin), float(d.ymin), float(d.xmax), float(d.ymax) + return x1, y1, max(0.0, x2 - x1), max(0.0, y2 - y1) + + if all(hasattr(d, a) for a in ("left", "top", "width", "height")): + return float(d.left), float(d.top), float(d.width), float(d.height) + + raise AttributeError( + "Detection bbox fields missing. Supported: " + "(x,y,w,h) or bbox or (xmin,ymin,xmax,ymax) or (left,top,width,height)." + ) + + @staticmethod + def _minio_url(img_path: Path) -> str | None: + """ + Build a MinIO object URL if MINIO_* env vars are provided. + """ + endpoint = os.getenv("MINIO_ENDPOINT") + bucket = os.getenv("MINIO_BUCKET") + prefix = os.getenv("MINIO_PREFIX", "").strip("/") + if not endpoint or not bucket: + return None + endpoint = endpoint.rstrip("/") + key = f"{prefix}/{img_path.name}" if prefix else img_path.name + return f"{endpoint}/{bucket}/{key}" + + +# ------------- CLI helper ------------- + +def main() -> None: + """ + Local runner: + python -m agri_baseline.src.batch_runner --input + """ + import argparse + + parser = argparse.ArgumentParser( + description="Run disease detection pipeline: leaf_reports (always), anomalies (sick only)." + ) + parser.add_argument("--input", type=str, required=True, help="Image file or folder") + parser.add_argument("--mission", type=int, default=60, help="Ignored; always fixed to 60") + parser.add_argument("--device", type=str, default="device-1", help="Fallback device (unused)") + args = parser.parse_args() + + runner = BatchRunner(mission_id=args.mission, device_id=args.device) + in_path = Path(args.input) + if in_path.is_dir(): + runner.run_folder(in_path) + else: + runner.process_image(in_path) + + +if __name__ == "__main__": + main() diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/base.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/base.py new file mode 100644 index 000000000..3eede7361 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/base.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Protocol + + +@dataclass(frozen=True) +class Detection: + """ + Model-agnostic detection container. + + Canonical storage: + - bbox: (x, y, w, h) in pixel coordinates. + - confidence: float in [0, 1]. + - label: class/code string. + + Notes: + - Properties expose a stable attribute API (.x/.y/.w/.h/.area etc.) + so downstream code can use either bbox or attributes. + - The class is frozen (immutable) to avoid accidental mutations + during processing and logging. + """ + label: str + confidence: float + bbox: Tuple[float, float, float, float] + meta: Optional[Dict] = None # optional extra data (e.g., model logits) + + # ---- Convenience constructors ------------------------------------------------- + + @staticmethod + def from_xywh( + label: str, + confidence: float, + x: float, + y: float, + w: float, + h: float, + meta: Optional[Dict] = None, + ) -> "Detection": + """Create a Detection from explicit x/y/w/h values.""" + return Detection(label=label, confidence=float(confidence), bbox=(x, y, w, h), meta=meta) + + # ---- Attribute-style view over bbox ------------------------------------------ + + @property + def x(self) -> float: + return float(self.bbox[0]) + + @property + def y(self) -> float: + return float(self.bbox[1]) + + @property + def w(self) -> float: + return float(self.bbox[2]) + + @property + def h(self) -> float: + return float(self.bbox[3]) + + @property + def xmin(self) -> float: + return self.x + + @property + def ymin(self) -> float: + return self.y + + @property + def xmax(self) -> float: + return self.x + self.w + + @property + def ymax(self) -> float: + return self.y + self.h + + @property + def area(self) -> float: + # Clamp at zero to avoid negative area if w/h are negative by mistake. + return max(0.0, self.w) * max(0.0, self.h) + + +class Detector(Protocol): + """ + Base detector interface. + + Implementors must return a list of Detection objects given a BGR image + (numpy array with shape (H, W, 3), dtype=uint8). + """ + name: str + + def run(self, bgr_image) -> List[Detection]: + """Run inference on a BGR image and return model detections.""" + ... diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/cnn_multi_classifier.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/cnn_multi_classifier.py new file mode 100644 index 000000000..6a2d5f3a3 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/cnn_multi_classifier.py @@ -0,0 +1,12 @@ +# agri-baseline/src/detectors/cnn_multi_classifier.py +import torch.nn as nn +from torchvision import models + +def build_multi_model(num_classes: int, pretrained: bool = True) -> nn.Module: + """ + Builds a ResNet18 model for multi-class disease classification. + """ + model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None) + in_features = model.fc.in_features + model.fc = nn.Linear(in_features, num_classes) + return model diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/disease_model.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/disease_model.py new file mode 100644 index 000000000..a9f94ebae --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/disease_model.py @@ -0,0 +1,127 @@ +# agri_baseline/src/detectors/disease_model.py +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Tuple + +import cv2 +import numpy as np +import torch +import albumentations as A +from albumentations.pytorch import ToTensorV2 + +from agri_baseline.src.detectors.cnn_multi_classifier import build_multi_model +from agri_baseline.src.detectors.train.dictionary import CLASS_MAPPING + + +@dataclass +class Detection: + """Simple container for a single detection box.""" + bbox: Tuple[int, int, int, int] # x, y, w, h + confidence: float + label: str = "disease" + + @property + def area(self) -> int: + x, y, w, h = self.bbox + return int(w * h) + + +def _ensure_bgr_uint8(img: np.ndarray) -> np.ndarray: + """ + Normalize any input image to BGR uint8 with 3 channels. + Prevents cvtColor from crashing with color.simd_helpers.hpp:94. + + Rules: + - None / empty -> ValueError + - GRAY (H,W) -> BGR + - BGRA (H,W,4) -> BGR + - dtype != uint8 -> convert to uint8 (clip to [0..255]) + """ + if img is None or getattr(img, "size", 0) == 0: + raise ValueError("DiseaseDetector: empty/None image given") + + # If grayscale -> convert to BGR + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + # If BGRA -> drop alpha + elif img.ndim == 3 and img.shape[2] == 4: + img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) + + # Validate shape now + if img.ndim != 3 or img.shape[2] != 3: + raise ValueError(f"DiseaseDetector: unexpected image shape {img.shape}") + + # Ensure uint8 + if img.dtype != np.uint8: + img = np.clip(img, 0, 255).astype(np.uint8) + + # Ensure non-zero size + h, w = img.shape[:2] + if h == 0 or w == 0: + raise ValueError("DiseaseDetector: zero-sized image") + + return img + + +class DiseaseDetector: + """ + CNN-based disease classifier. + - Normalizes input to BGR uint8 (3-ch) to avoid OpenCV color conversion crashes. + - Converts BGR->RGB before Albumentations (Normalize + ToTensorV2). + """ + + name = "disease" + + def __init__(self, model_path: str = "models/cnn_multi_stage3.pth", device: str | None = None) -> None: + # choose device + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # build model according to class mapping + self.classes = sorted(set(CLASS_MAPPING.values())) + self.model = build_multi_model(num_classes=len(self.classes)).to(self.device) + + # load trained weights + state = torch.load(model_path, map_location=self.device) + self.model.load_state_dict(state) + self.model.eval() + + # same validation transforms used in training + self.transform = A.Compose( + [ + A.Resize(224, 224), + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ToTensorV2(), + ] + ) + + def run(self, img: np.ndarray) -> List[Detection]: + """ + Run the classifier on a single image. + :param img: np.ndarray from OpenCV (BGR or GRAY/BGRA/float) — any shape/dtype. + :return: list with a single full-frame Detection carrying predicted label/confidence. + """ + # 1) Normalize input so cvtColor is safe + img = _ensure_bgr_uint8(img) + + # 2) Convert to RGB for the model pipeline + img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # 3) Albumentations -> tensor + aug = self.transform(image=img_rgb) + tensor = aug["image"].unsqueeze(0).to(self.device) + + # 4) Model inference + with torch.no_grad(): + logits = self.model(tensor) + probs = torch.softmax(logits, dim=1)[0] + conf_t, cls_t = torch.max(probs, dim=0) + + label = self.classes[cls_t.item()] + confidence = float(conf_t.item()) + + # 5) Return a single detection that spans the whole image (classifier) + h, w = img.shape[:2] + det = Detection(bbox=(0, 0, w, h), confidence=confidence, label=label) + return [det] diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/train/dictionary.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/train/dictionary.py new file mode 100644 index 000000000..1d0671026 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/detectors/train/dictionary.py @@ -0,0 +1,36 @@ +CLASS_MAPPING = { + # 🍅 Tomato + "tomato_healthy": "tomato__healthy", + "tomato_leaf": "tomato__healthy", + "tomato_bacterial_spot": "tomato__bacterial_spot", + "tomato_leaf_bacterial_spot": "tomato__bacterial_spot", + "tomato_early_blight": "tomato__early_blight", + "tomato_early_blight_leaf": "tomato__early_blight", + "tomato_late_blight": "tomato__late_blight", + "tomato_leaf_late_blight": "tomato__late_blight", + "tomato_leaf_mold": "tomato__leaf_mold", + "tomato_mold_leaf": "tomato__leaf_mold", + "tomato_septoria_leaf_spot": "tomato__septoria_leaf_spot", + "tomato_spider_mites_two_spotted_spider_mite": "tomato__spider_mites", + "tomato_spider_mites": "tomato__spider_mites", + "tomato_target_spot": "tomato__target_spot", + "tomato_tomato_mosaic_virus": "tomato__mosaic_virus", + "tomato_tomato_yellowleaf_curl_virus": "tomato__yellowleaf_curl_virus", + "tomato_leaf_mosaic_virus": "tomato__mosaic_virus", + "tomato_leaf_yellow_virus": "tomato__yellowleaf_curl_virus", + + + # 🥔 Potato + "potato_healthy": "potato__healthy", + "potato_leaf": "potato__healthy", + "potato_early_blight": "potato__early_blight", + "potato_leaf_early_blight": "potato__early_blight", + "potato_late_blight": "potato__late_blight", + "potato_leaf_late_blight": "potato__late_blight", + + # 🌶️ Pepper + "pepper_bell_healthy": "pepper__healthy", + "bell_pepper_leaf": "pepper__healthy", + "pepper_bell_bacterial_spot": "pepper__bacterial_spot", + "bell_pepper_leaf_spot": "pepper__bacterial_spot", +} diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/config.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/config.py new file mode 100644 index 000000000..18d696e0a --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/config.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import os +from pathlib import Path + +# Try to load env files both from project root and from agri_baseline/.env +try: + from dotenv import load_dotenv # type: ignore + load_dotenv(dotenv_path=Path("agri_baseline/.env"), override=False) + load_dotenv(override=False) +except Exception: + pass + +# Prefer standard name DATABASE_URL; fallback to DB_URL; finally default to localhost:5432 +DB_URL: str = ( + os.getenv("DATABASE_URL") + or os.getenv("DB_URL") + or "postgresql+psycopg2://missions_user:pg123@localhost:5432/missions_db" +) + +IMAGES_DIR = os.getenv("IMAGES_DIR", "./data/images") +BATCH_SIZE = int(os.getenv("BATCH_SIZE", 64)) +MAX_WORKERS = int(os.getenv("MAX_WORKERS", 4)) +MIN_BBOX_AREA = int(os.getenv("MIN_BBOX_AREA", 60)) +MIN_COMPONENT_AREA = int(os.getenv("MIN_COMPONENT_AREA", 200)) diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/db.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/db.py new file mode 100644 index 000000000..8c69e24f3 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/db.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from sqlalchemy import create_engine, text, bindparam +from sqlalchemy.engine import Engine + +from . import config + +_engine: Engine | None = None + +def get_engine() -> Engine: + """Return a singleton SQLAlchemy engine for the configured DB.""" + global _engine + if _engine is None: + _engine = create_engine( + config.DB_URL, + pool_pre_ping=True, # keep-alive for flaky networks/tests + future=True, + connect_args={"connect_timeout": 5} # fail fast on bad host/port + ) + return _engine + +# === Inserts mapped to RelDB schema === + +# detections → anomalies +INSERT_DET = text( + """ + INSERT INTO anomalies(mission_id, device_id, ts, anomaly_type_id, severity, details, geom) + VALUES (:mission_id, :device_id, :ts, :anomaly_type_id, :severity, CAST(:details AS jsonb), + ST_GeomFromText(:wkt_geom, 4326)); + """ +) + +# counts → tile_stats +INSERT_COUNT = text( + """ + INSERT INTO tile_stats(mission_id, tile_id, anomaly_score, geom) + VALUES (:mission_id, :tile_id, :anomaly_score, ST_GeomFromText(:wkt_geom, 4326)) + ON CONFLICT (mission_id, tile_id) DO UPDATE + SET anomaly_score = excluded.anomaly_score; + """ +) + +# validator findings → event_logs +INSERT_FINDING = ( + text( + """ + INSERT INTO event_logs(ts, level, source, message, details) + VALUES (CURRENT_TIMESTAMP, :level, 'validator', :message, CAST(:details AS jsonb)); + """ + ) + # Defaults if the caller does not send the parameters + .bindparams( + bindparam("level", value="INFO"), + bindparam("message", value=""), + bindparam("details", value="{}"), + ) +) + + + +# QA metrics → event_logs +INSERT_QA = text( + """ + INSERT INTO event_logs(ts, level, source, message, details) + VALUES (CURRENT_TIMESTAMP, 'INFO', 'qa', 'QA metrics recorded', CAST(:details AS jsonb)); + """ +) \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/logging_setup.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/logging_setup.py new file mode 100644 index 000000000..06193027f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/logging_setup.py @@ -0,0 +1,9 @@ +import logging + + +def setup_logging(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + ) + return logging.getLogger("agri") \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/utils.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/utils.py new file mode 100644 index 000000000..0b99245e9 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/pipeline/utils.py @@ -0,0 +1,62 @@ +# agri_baseline/src/pipeline/utils.py +# Max line length: 100 + +from __future__ import annotations + +import hashlib +from pathlib import Path +from typing import Tuple + +import cv2 +import numpy as np + + +class ImageLoadError(Exception): + """Raised when an image cannot be decoded or is empty.""" + + +def load_image(path: str | Path) -> Tuple[np.ndarray, int, int]: + """ + Load an image from disk as BGR uint8 and return (img, width, height). + + Rules: + - Always read as color to ensure 3 channels (BGR). + - Raise FileNotFoundError if the path doesn't exist. + - Raise ImageLoadError if decode fails or the image is empty. + - Convert dtype to uint8 if needed. + - Normalize channel count: grayscale -> BGR, BGRA -> BGR. + """ + p = Path(path) + if not p.exists(): + raise FileNotFoundError(f"Image not found: {p.resolve()}") + + # Always load as color to ensure 3 channels (BGR) + img = cv2.imread(str(p), cv2.IMREAD_COLOR) + if img is None or img.size == 0: + raise ImageLoadError(f"Failed to decode image (or empty): {p.resolve()}") + + if img.dtype != np.uint8: + img = cv2.convertScaleAbs(img) + + # Guard channel count (should be 3 after IMREAD_COLOR, but just in case) + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.ndim == 3 and img.shape[2] == 4: + img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) + + h, w = img.shape[:2] + return img, w, h + + +def image_id_from_path(path: str | Path) -> str: + p = Path(path) + digest = hashlib.sha1(str(p.resolve()).encode()).hexdigest()[:16] + return f"{p.stem}_{digest}" + + +def clamp_bbox(x: int, y: int, w: int, h: int, W: int, H: int) -> Tuple[int, int, int, int]: + x = max(0, min(x, W - 1)) + y = max(0, min(y, H - 1)) + w = max(1, min(w, W - x)) + h = max(1, min(h, H - y)) + return x, y, w, h diff --git a/services/sounds/compression/src/__init__.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/__init__.py similarity index 100% rename from services/sounds/compression/src/__init__.py rename to airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/__init__.py diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/minio_client.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/minio_client.py new file mode 100644 index 000000000..dd5effd69 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/minio_client.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from minio import Minio + + +@dataclass(frozen=True) +class MinioConfig: + endpoint: str + access_key: str + secret_key: str + bucket: str + secure: bool + + +def load_minio_config() -> MinioConfig: + endpoint = os.getenv("MINIO_ENDPOINT", "localhost:9000") + access_key = os.getenv("MINIO_ACCESS_KEY", "") + secret_key = os.getenv("MINIO_SECRET_KEY", "") + bucket = os.getenv("MINIO_BUCKET", "my-bucket") + secure = os.getenv("MINIO_SECURE", "false").lower() == "true" + + if not access_key or not secret_key: + raise ValueError("Missing MINIO_ACCESS_KEY / MINIO_SECRET_KEY.") + return MinioConfig(endpoint, access_key, secret_key, bucket, secure) + + +def build_client(cfg: MinioConfig) -> Minio: + return Minio( + endpoint=cfg.endpoint, + access_key=cfg.access_key, + secret_key=cfg.secret_key, + secure=cfg.secure, + ) diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/minio_sync.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/minio_sync.py new file mode 100644 index 000000000..8c6c2b6a1 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/storage/minio_sync.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import os +from io import BytesIO +from pathlib import Path +from typing import Iterable + +from .minio_client import MinioConfig, build_client + + +def ensure_bucket(cfg: MinioConfig) -> None: + """ + Ensure the target bucket exists; create it if it does not. + """ + client = build_client(cfg) + if not client.bucket_exists(cfg.bucket): + client.make_bucket(cfg.bucket) + + +def download_prefix_to_dir(cfg: MinioConfig, prefix: str, local_dir: Path) -> list[Path]: + """ + Download all objects under the given `prefix` to the local directory. + Returns a list of local file paths that were downloaded. + """ + client = build_client(cfg) + local_dir.mkdir(parents=True, exist_ok=True) + + downloaded: list[Path] = [] + for obj in client.list_objects(cfg.bucket, prefix=prefix, recursive=True): + # Skip entries that represent "virtual folders" + name = obj.object_name + if name.endswith("/") or not name: + continue + + # Simplify: save using the file's basename only. + # If you need to preserve the full hierarchy, use: local_dir.joinpath(name) + target = local_dir.joinpath(Path(name).name) + + response = client.get_object(cfg.bucket, name) + try: + data = response.read() + finally: + response.close() + response.release_conn() + + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(data) + downloaded.append(target) + + return downloaded + + +def upload_dir_to_prefix(cfg: MinioConfig, local_dir: Path, prefix: str) -> list[str]: + """ + Upload all files from the local directory under the given `prefix`. + Returns a list of object names that were uploaded. + """ + client = build_client(cfg) + ensure_bucket(cfg) + + uploaded: list[str] = [] + for path in local_dir.rglob("*"): + if not path.is_file(): + continue + + rel = path.relative_to(local_dir).as_posix() + object_name = f"{prefix.rstrip('/')}/{rel}" + data = path.read_bytes() + bio = BytesIO(data) + + client.put_object(cfg.bucket, object_name, bio, length=len(data)) + uploaded.append(object_name) + + return uploaded diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/validator/rules.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/validator/rules.py new file mode 100644 index 000000000..afb6318a7 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/validator/rules.py @@ -0,0 +1,65 @@ +from __future__ import annotations +import json +from dataclasses import dataclass +from typing import Iterable, Optional + +from sqlalchemy import text +from agri_baseline.src.pipeline.db import get_engine, INSERT_FINDING, INSERT_QA + + +@dataclass +class Finding: + scope: str + image_id: str + rule: str + severity: str + message: str + details: Optional[dict] = None + + +# ---- Image-level checks ---- + +def check_bbox_bounds(image_id: str, width: int, height: int, dets: list[dict]) -> list[Finding]: + out: list[Finding] = [] + for d in dets: + x, y, w, h = d["bbox_x"], d["bbox_y"], d["bbox_w"], d["bbox_h"] + if x < 0 or y < 0 or x + w > width or y + h > height: + out.append(Finding("image", image_id, "bbox_oob", "warn", + f"BBox out-of-bounds: {(x, y, w, h)}")) + if w * h <= 0 or d["area_px"] <= 0: + out.append(Finding("image", image_id, "bbox_area_zero", "error", + "Non-positive area")) + if d["confidence"] < 0 or d["confidence"] > 1: + out.append(Finding("image", image_id, "conf_oob", "error", + f"Confidence out of range: {d['confidence']:.3f}")) + return out + + +def check_counts_reasonable(image_id: str, disease: int) -> list[Finding]: + out: list[Finding] = [] + if disease < 0: + out.append(Finding("image", image_id, "negative_counts", "error", + f"Negative count: disease={disease}")) + if disease == 0: + out.append(Finding("image", image_id, "all_zero_counts", "warn", + "Disease count is zero")) + if disease > 10000: + out.append(Finding("image", image_id, "count_too_high", "warn", + f"Suspiciously high disease count: {disease}")) + return out + + +# ---- Batch-level checks ---- + +def check_batch_error_rate(total: int, errored: int, threshold: float = 0.05) -> list[Finding]: + rate = 0.0 if total == 0 else errored / total + sev = "warn" if rate <= threshold else "error" + return [Finding("batch", None, "error_rate", sev, + f"Batch error rate={rate:.3%}, threshold={threshold:.0%}")] + + +def check_batch_no_detections(total: int, sum_dets: int) -> list[Finding]: + if total > 0 and sum_dets == 0: + return [Finding("batch", None, "no_detections", "warn", + "Pipeline produced zero detections for the entire batch")] + return [] diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/validator/validator.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/validator/validator.py new file mode 100644 index 000000000..3c970190c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/agri_baseline/src/validator/validator.py @@ -0,0 +1,94 @@ +from __future__ import annotations +import json +from dataclasses import dataclass +from typing import Iterable, List, Optional +from sqlalchemy import text + +from agri_baseline.src.pipeline.db import get_engine, INSERT_FINDING, INSERT_QA + + +@dataclass +class Finding: + """Single validation finding.""" + scope: str # e.g., "image" + image_id: str # logical id per image + rule: str # rule code/name + severity: str # DEBUG/INFO/WARN/ERROR + message: str # human-readable message + details: Optional[dict] = None + + +class Validator: + """ + Collects validation findings and writes batch summaries. + """ + def image_findings(self, findings: Iterable[Finding]) -> None: + """Write image-level findings into event_logs table.""" + with get_engine().begin() as conn: + for f in findings: + details_dict = { + "scope": f.scope, + "rule": f.rule, + "image_id": f.image_id, + **(f.details or {}), + } + conn.execute( + INSERT_FINDING, + { + "level": f.severity.upper(), + "message": f.message, + # Passes as a JSON string because SQL does CAST(... AS jsonb) "details": json.dumps(details_dict), + }, + ) + + + def batch_summary(self) -> None: + """ + Aggregate anomalies → tile_stats by image_id (from anomalies.details->>'image_id'). + For each (mission_id, image_id): + - anomaly_score = count of anomalies + - geom = envelope of a small expanded collect of points (Polygon, 4326) + Idempotent via ON CONFLICT (mission_id, tile_id). + """ + sql = text( + """ + WITH per_image AS ( + SELECT + a.mission_id, + a.details->>'image_id' AS tile_id, + COUNT(*)::real AS anomaly_score, + -- produce Polygon in 4326 directly (no WKT roundtrip) + ST_Envelope( + ST_Expand( + ST_Collect(a.geom), + 0.0005 -- ~50m at equator; tweak if needed + ) + )::geometry(Polygon, 4326) AS poly + FROM anomalies a + WHERE a.geom IS NOT NULL + AND a.details ? 'image_id' + GROUP BY a.mission_id, tile_id + ) + INSERT INTO tile_stats (mission_id, tile_id, anomaly_score, geom) + SELECT mission_id, tile_id, anomaly_score, poly + FROM per_image + ON CONFLICT (mission_id, tile_id) DO UPDATE + SET anomaly_score = EXCLUDED.anomaly_score, + geom = EXCLUDED.geom; + """ + ) + + with get_engine().begin() as conn: + conn.execute(sql) + + # optional: record a QA info log (pass JSON as string) + with get_engine().begin() as conn: + conn.execute( + INSERT_QA, + { + "details": json.dumps({ + "source": "batch_summary", + "note": "tile_stats updated from anomalies by image_id", + }) + }, + ) diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/docker-compose.yml b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/docker-compose.yml new file mode 100644 index 000000000..18e1cc31c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/docker-compose.yml @@ -0,0 +1,25 @@ +services: + app: + build: + context: . + dockerfile: Dockerfile + container_name: agri_app + # exec-form to avoid spacing/quoting issues + command: ["python", "-m", "agri_baseline.scripts.run_batch", "--storage", "minio"] + env_file: + - agri_baseline/.env + volumes: + - ./agri_baseline:/app/agri_baseline + - ./tests:/app/tests + - ./data:/app/data + - ./models:/root/.cache/torch/hub/checkpoints + networks: + - agri_net + - minio_net # ← MinIO network + +networks: + agri_net: + external: true + minio_net: + external: true + name: storage_with_mqtt_minionet # ← MinIO network name diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/dockerfile b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/dockerfile new file mode 100644 index 000000000..06550cea8 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/dockerfile @@ -0,0 +1,183 @@ +# # # ============================== +# # # Based on PyTorch with CUDA +# # # ============================== +# # # FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime +# # ARG BASE_IMAGE=pytorch/pytorch:2.2.0-cpu +# # FROM ${BASE_IMAGE} + +# # # # --- NETFREE CERT INSTALL --- +# # # ADD https://netfree.link/dl/unix-ca.sh /home/netfree-unix-ca.sh +# # # RUN bash /home/netfree-unix-ca.sh \ +# # # && update-ca-certificates +# # # ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +# # # ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +# # # ENV PIP_CERT=/etc/ssl/certs/ca-certificates.crt +# # # --- NETFREE CERT INSTALL (optional) --- +# # ARG INSTALL_NETFREE_CA=0 +# # # אל תעשי ADD מהאינטרנט (זה מה שנפל) +# # # במקום זה, רק אם תרצי – נוריד בזמן הבנייה עם curl (כש-INSTALL_NETFREE_CA=1) +# # RUN if [ "$INSTALL_NETFREE_CA" = "1" ]; then \ +# # apt-get update && apt-get install -y --no-install-recommends curl ca-certificates && \ +# # curl -fsSL --retry 5 https://netfree.link/dl/unix-ca.sh -o /home/netfree-unix-ca.sh && \ +# # bash /home/netfree-unix-ca.sh && update-ca-certificates && \ +# # rm -rf /var/lib/apt/lists/* ; \ +# # else echo "Skipping NetFree CA install"; fi + +# # # Force pip to trust PyPI +# # RUN pip config set global.trusted-host "pypi.org files.pythonhosted.org pypi.python.org" +# # RUN pip config set global.cert /etc/ssl/certs/ca-certificates.crt +# # # --- END NETFREE CERT INSTALL --- + +# # # ============================== +# # # Install system packages +# # # ============================== +# # RUN apt-get update && apt-get install -y --no-install-recommends \ +# # libgl1-mesa-glx \ +# # libglib2.0-0 \ +# # libsm6 \ +# # libxext6 \ +# # libxrender1 \ +# # libgtk2.0-0 \ +# # libcanberra-gtk-module \ +# # libcanberra-gtk3-module \ +# # && rm -rf /var/lib/apt/lists/* + +# # # ============================== +# # # Working directory +# # # ============================== +# # ==== Portable CPU base (works everywhere) ==== +# FROM python:3.10-slim + +# ENV PIP_NO_CACHE_DIR=1 \ +# PYTHONDONTWRITEBYTECODE=1 \ +# PYTHONUNBUFFERED=1 + +# # System deps מינימליים ל-CV/IO +# RUN apt-get update && apt-get install -y --no-install-recommends \ +# git ffmpeg libsm6 libxext6 libgl1 ca-certificates \ +# && rm -rf /var/lib/apt/lists/* +# # הוספת תעודת NetFree ל־trust store של המערכת +# COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +# RUN update-ca-certificates + +# # לוודא שכלי רשת/פייתון משתמשים ב־CA המעודכן +# ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +# ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +# ENV PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# # Torch/torchvision/torchaudio גרסאות CPU יציבות מה-PyTorch index +# ARG TORCH_VERSION=2.2.1 +# ARG TORCHVISION_VERSION=0.17.1 +# ARG TORCHAUDIO_VERSION=2.2.1 +# RUN python -m pip install --upgrade pip && \ +# python -m pip install --index-url https://download.pytorch.org/whl/cpu \ +# torch==${TORCH_VERSION} \ +# torchvision==${TORCHVISION_VERSION} \ +# torchaudio==${TORCHAUDIO_VERSION} + +# # (מבטל תלות ב-NetFree בזמן build; אין ADD/curl מהאינטרנט בשלב הזה) +# # ==== END portable header ==== + +# # ============================== +# # Working directory +# # ============================== +# # WORKDIR /app + +# WORKDIR /app + +# # Update pip +# RUN pip install --upgrade pip + +# # ============================== +# # Install dependencies +# # ============================== +# COPY agri_baseline/requirements.txt /app/requirements.txt +# RUN pip install --no-cache-dir --upgrade "numpy==1.26.4" +# RUN pip install --no-cache-dir --force-reinstall "opencv-python-headless==4.9.0.80" + +# RUN pip install --no-cache-dir -r /app/requirements.txt + +# # ============================== +# # Copy source code +# # ============================== +# COPY agri_baseline /app/agri_baseline +# COPY models /app/models +# # Copy tests folder +# COPY tests /app/tests + +# # Set PYTHONPATH +# ENV PYTHONPATH=/app:$PYTHONPATH + +# # ============================== +# # Entry point +# # ============================== +# CMD ["python", "agri_baseline/src/batch_runner.py"] +# syntax=docker/dockerfile:1.6 + +FROM mcr.microsoft.com/devcontainers/python:1-3.11-bullseye + +ENV PIP_NO_CACHE_DIR=0 \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +# 1) חבילות מערכת בסיסיות +RUN apt-get update && apt-get install -y --no-install-recommends \ + git ffmpeg libsm6 libxext6 libgl1 ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# 2) הוספת תעודת NetFree שהכנת (הקובץ יושב לצד ה-dockerfile) +# COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +# RUN update-ca-certificates + +# 3) לוודא שכלי רשת/פייתון משתמשים ב-CA של המערכת +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# 4) התקנת Torch CPU מהאינדקס של PyTorch עם cache +ARG TORCH_VERSION=2.2.1 +ARG TORCHVISION_VERSION=0.17.1 +ARG TORCHAUDIO_VERSION=2.2.1 +RUN --mount=type=cache,target=/root/.cache/pip \ + python -m pip install --upgrade pip && \ + python -m pip install --index-url https://download.pytorch.org/whl/cpu \ + torch==${TORCH_VERSION} \ + torchvision==${TORCHVISION_VERSION} \ + torchaudio==${TORCHAUDIO_VERSION} + +# 5) ספריות פייתון נוספות (עם cache + בלי --no-cache-dir) +# WORKDIR /app +# COPY Detection_Jobs/agri_baseline/requirements.txt /app/requirements.txt + +# RUN --mount=type=cache,target=/root/.cache/pip \ +# pip install --upgrade pip && \ +# pip install "numpy==1.26.4" && \ +# pip install --force-reinstall "opencv-python-headless==4.9.0.80" && \ +# pip install --retries 10 --timeout 120 -r /app/requirements.txt +# 5) ספריות פייתון נוספות (עם cache + סינון GPU) +WORKDIR /app +COPY agri_baseline/requirements.txt /app/requirements.txt + +# מסננים תלויות GPU כדי שלא ימשכו CUDA +RUN awk '!/^(torch|torchvision|torchaudio)[[:space:]=<>!~]*$/ \ + && !/^pytorch-cuda/ \ + && !/^xformers/ \ + && !/^cupy-cuda/ \ + && !/^nvidia[-_]/' /app/requirements.txt > /app/requirements.cpu.txt + +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade pip && \ + PIP_INDEX_URL=https://download.pytorch.org/whl/cpu \ + PIP_EXTRA_INDEX_URL=https://pypi.org/simple \ + pip install --retries 10 --timeout 120 -r /app/requirements.cpu.txt + +# 6) קוד המקור +COPY agri_baseline /app/agri_baseline +COPY models /app/models +COPY tests /app/tests + +# 7) PYTHONPATH – בלי ההפניה למשתנה שאינו קיים בבילד +ENV PYTHONPATH=/app + +# 8) Entry +CMD ["python", "agri_baseline/src/batch_runner.py"] diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/pytest.ini b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/pytest.ini new file mode 100644 index 000000000..89313dd9b --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +pythonpath = . +testpaths = tests +addopts = -v diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/cnn_binary_classifier.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/cnn_binary_classifier.py new file mode 100644 index 000000000..898c2c918 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/cnn_binary_classifier.py @@ -0,0 +1,12 @@ +# agri-baseline/src/detectors/cnn_binary_classifier.py +import torch.nn as nn +from torchvision import models + +def build_binary_model(pretrained: bool = True) -> nn.Module: + """ + Builds a ResNet18 model for binary classification (healthy vs diseased). + """ + model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None) + in_features = model.fc.in_features + model.fc = nn.Linear(in_features, 2) # healthy / diseased + return model diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/dataset_binary.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/dataset_binary.py new file mode 100644 index 000000000..d63bf5208 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/dataset_binary.py @@ -0,0 +1,36 @@ +# agri-baseline/src/detectors/dataset_binary.py +import os +from torch.utils.data import Dataset +from PIL import Image + +class BinaryDiseaseDataset(Dataset): + """ + Dataset wrapper that maps: + - healthy folders -> label 0 + - all disease folders -> label 1 + Keeps also the original folder name for optional subtype info. + """ + def __init__(self, root: str, transform=None): + self.samples = [] + + self.targets = [] + self.transform = transform + for cls in os.listdir(root): + path = os.path.join(root, cls) + if not os.path.isdir(path): + continue + label = 0 if "healthy" in cls.lower() else 1 + for f in os.listdir(path): + if f.lower().endswith((".jpg", ".png", ".jpeg")): + self.samples.append((os.path.join(path, f), label, cls)) + self.targets.append(label) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + path, label, cls_name = self.samples[idx] + img = Image.open(path).convert("RGB") + if self.transform: + img = self.transform(img) + return img, label, cls_name diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/disease.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/disease.py new file mode 100644 index 000000000..653f673ae --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/disease.py @@ -0,0 +1,202 @@ +import cv2 +import numpy as np + +from ...agri_baseline.src.detectors.base import Detection +from ..pipeline import config + + +class DiseaseDetector: + """ + Improved disease detector: + - Leaf mask (HSV/LAB) to isolate plant tissue. + - Candidate lesion detection: + 1) Yellow/Brown in HSV (stress/necrosis). + 2) Dark + Brown in LAB (low L, high b). + - Noise cleaning and merging. + - Shape filtering by circularity (detect "spots"). + - Confidence weighted by darkness, saturation, and circularity. + """ + + name = "disease" + + # HSV thresholds for yellow/brown (tunable) + HSV_YELLOW = ((10, 50, 40), (45, 255, 255)) + HSV_BROWN1 = ((0, 80, 30), (10, 255, 200)) + HSV_BROWN2 = ((160, 80, 30), (179, 255, 200)) + + # LAB thresholds for dark/brown lesions (tunable) + LAB_L_MAX_DARK = 145 # Lower L means darker + LAB_B_MIN_BROWN = 135 # Higher b means more yellow/brown + + # Shape filtering + MIN_CIRCULARITY = 0.22 # 4πA/P^2; range 0..1 + MAX_ASPECT_RATIO = 2.2 # Avoid elongated regions + DILATE_MERGE_RADIUS = 4 + + def __init__(self): + # Minimum area from config (fallback to default if missing) + self.min_area = int(getattr(config, "MIN_BBOX_AREA", 60)) + + def run(self, bgr_image: np.ndarray) -> list[Detection]: + h, w = bgr_image.shape[:2] + + # ---------- 1) Leaf isolation ---------- + hsv = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2HSV) + lab = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2LAB) + H, S, V = cv2.split(hsv) + L, A, B = cv2.split(lab) + + # Green mask in HSV (broad range for leaf tissue) + green1 = cv2.inRange(hsv, (35, 30, 30), (85, 255, 255)) + green2 = cv2.inRange(hsv, (25, 25, 40), (95, 255, 255)) + leaf_mask = cv2.bitwise_or(green1, green2) + + # Contrast enhancement with CLAHE on L channel + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + L_eq = clahe.apply(L) + + # Basic cleaning of leaf mask + leaf_mask = cv2.medianBlur(leaf_mask, 5) + leaf_mask = cv2.morphologyEx( + leaf_mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1 + ) + + # ---------- 2) Lesion candidates ---------- + # (a) Yellow/Brown in HSV + yellow = cv2.inRange(hsv, self.HSV_YELLOW[0], self.HSV_YELLOW[1]) + brown1 = cv2.inRange(hsv, self.HSV_BROWN1[0], self.HSV_BROWN1[1]) + brown2 = cv2.inRange(hsv, self.HSV_BROWN2[0], self.HSV_BROWN2[1]) + hsv_spots = cv2.bitwise_or(yellow, cv2.bitwise_or(brown1, brown2)) + + # (b) Dark + Brownish in LAB + dark = cv2.threshold(L_eq, self.LAB_L_MAX_DARK, 255, cv2.THRESH_BINARY_INV)[1] + brownish = cv2.threshold(B, self.LAB_B_MIN_BROWN, 255, cv2.THRESH_BINARY)[1] + lab_spots = cv2.bitwise_and(dark, brownish) + + # Combine HSV and LAB candidates, restricted to leaf mask + candidates = cv2.bitwise_or(hsv_spots, lab_spots) + candidates = cv2.bitwise_and(candidates, leaf_mask) + + # ---------- 3) Cleaning & merging ---------- + candidates = cv2.medianBlur(candidates, 3) + candidates = cv2.morphologyEx( + candidates, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1 + ) + + # Dilate slightly to merge nearby spots + if self.DILATE_MERGE_RADIUS > 0: + k = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, + (2 * self.DILATE_MERGE_RADIUS + 1, 2 * self.DILATE_MERGE_RADIUS + 1), + ) + candidates = cv2.dilate(candidates, k, iterations=1) + + # ---------- 4) Contours & filtering ---------- + cnts, _ = cv2.findContours(candidates, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + dets = [] + for c in cnts: + area = cv2.contourArea(c) + if area < self.min_area: + continue + + x, y, bw, bh = cv2.boundingRect(c) + + # Circularity: 4πA / P^2 + perim = cv2.arcLength(c, True) + circularity = (4.0 * np.pi * area / (perim ** 2 + 1e-6)) if perim > 0 else 0.0 + if circularity < self.MIN_CIRCULARITY: + continue + + # Aspect ratio filtering + ar = max(bw, bh) / (min(bw, bh) + 1e-6) + if ar > self.MAX_ASPECT_RATIO: + continue + + # Extract subregion for scoring + hsv_box = hsv[y : y + bh, x : x + bw] + lab_box = lab[y : y + bh, x : x + bw] + + Lb = lab_box[:, :, 0].astype(np.float32) + Sb = hsv_box[:, :, 1].astype(np.float32) + + # Darkness score (lower L → higher score) + dark_score = np.clip((180.0 - float(np.mean(Lb))) / 180.0, 0.0, 1.0) + # Saturation score (higher S → higher score) + sat_score = np.clip(float(np.mean(Sb)) / 255.0, 0.0, 1.0) + + # Final weighted confidence + conf = 0.45 * dark_score + 0.35 * sat_score + 0.20 * np.clip(circularity, 0.0, 1.0) + conf = float(np.clip(conf, 0.0, 1.0)) + + dets.append( + Detection( + label="disease_spot", + confidence=conf, + x=int(x), + y=int(y), + w=int(bw), + h=int(bh), + area=int(area), + ) + ) + + # ---------- 5) Merge overlapping boxes ---------- + dets = self._merge_overlaps(dets, iou_thresh=0.5) + return dets + + # ---------- IoU helper ---------- + @staticmethod + def _iou(a, b): + ax1, ay1, ax2, ay2 = a.x, a.y, a.x + a.w, a.y + a.h + bx1, by1, bx2, by2 = b.x, b.y, b.x + b.w, b.y + b.h + inter_x1, inter_y1 = max(ax1, bx1), max(ay1, by1) + inter_x2, inter_y2 = min(ax2, bx2), min(ay2, by2) + iw, ih = max(0, inter_x2 - inter_x1), max(0, inter_y2 - inter_y1) + inter = iw * ih + if inter == 0: + return 0.0 + area_a = a.w * a.h + area_b = b.w * b.h + return inter / float(area_a + area_b - inter + 1e-6) + + def _merge_overlaps(self, dets, iou_thresh=0.5): + if not dets: + return dets + dets = sorted(dets, key=lambda d: d.confidence, reverse=True) + kept = [] + while dets: + base = dets.pop(0) + to_merge = [base] + remain = [] + for d in dets: + if self._iou(base, d) >= iou_thresh: + to_merge.append(d) + else: + remain.append(d) + dets = remain + + # Merge into one bounding box + xs = [d.x for d in to_merge] + ys = [d.y for d in to_merge] + x2s = [d.x + d.w for d in to_merge] + y2s = [d.y + d.h for d in to_merge] + x = int(min(xs)) + y = int(min(ys)) + w = int(max(x2s) - x) + h = int(max(y2s) - y) + + # Average confidence + conf = float(np.mean([d.confidence for d in to_merge])) + area = int(w * h) + kept.append( + Detection( + label="disease_spot", + confidence=conf, + x=x, + y=y, + w=w, + h=h, + area=area, + ) + ) + return kept diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/eval_multi_levels.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/eval_multi_levels.py new file mode 100644 index 000000000..c39ea4e5a --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/eval_multi_levels.py @@ -0,0 +1,167 @@ +# eval_multi_levels.py +import torch +import numpy as np +from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, classification_report +from torch.utils.data import DataLoader +import cv2 +import albumentations as A +from albumentations.pytorch import ToTensorV2 + +from agri_baseline.src.detectors.train.dictionary import CLASS_MAPPING +from agri_baseline.src.detectors.cnn_multi_classifier import build_multi_model +from torchvision import datasets +import seaborn as sns +import matplotlib.pyplot as plt + +# ------------------------ +# Paths +# ------------------------ +DATA_DIR = "data_balanced/PlantDoc/test" +MODEL_PATH = "models/cnn_multi_stage3.pth" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# ------------------------ +# Transforms +# ------------------------ +val_transforms = A.Compose([ + A.Resize(224, 224), + A.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + ToTensorV2() +]) + +# ------------------------ +# Dataset wrapper +# ------------------------ +class AlbumentationsDataset(torch.utils.data.Dataset): + def __init__(self, dataset, transform=None): + self.dataset = dataset + self.transform = transform + + def __getitem__(self, idx): + path, label = self.dataset.samples[idx] + image = cv2.imread(path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if self.transform: + image = self.transform(image=image)["image"] + return image, label + + def __len__(self): + return len(self.dataset) + + +# ------------------------ +# Prepare dataset +# ------------------------ +dataset = datasets.ImageFolder(DATA_DIR) +canonical_classes = sorted(set(CLASS_MAPPING.values())) +class_to_idx = {cls: i for i, cls in enumerate(canonical_classes)} + +new_samples, new_targets = [], [] +for path, label_idx in dataset.samples: + raw_name = dataset.classes[label_idx].lower().replace(" ", "_") + canonical_label = CLASS_MAPPING.get(raw_name) + if canonical_label is None: + raise ValueError(f"Class {raw_name} not found in CLASS_MAPPING") + new_samples.append((path, class_to_idx[canonical_label])) + new_targets.append(class_to_idx[canonical_label]) + +dataset.samples = new_samples +dataset.targets = new_targets +dataset.classes = canonical_classes +dataset.class_to_idx = class_to_idx + +val_dataset = AlbumentationsDataset(dataset, transform=val_transforms) +val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) + +# ------------------------ +# Load model +# ------------------------ +model = build_multi_model(num_classes=len(canonical_classes)).to(device) +state_dict = torch.load(MODEL_PATH, map_location=device) +model.load_state_dict(state_dict) +model.eval() + +# ------------------------ +# Evaluation +# ------------------------ +all_preds, all_labels = [], [] +with torch.no_grad(): + for images, labels in val_loader: + images, labels = images.to(device), labels.to(device) + outputs = model(images) + _, preds = outputs.max(1) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + +all_preds = np.array(all_preds) +all_labels = np.array(all_labels) + +# ------------------------ +# Grouping +# ------------------------ +def to_healthy_sick(cls: str): + return "healthy" if "healthy" in cls else "sick" + +def to_crop(cls: str): + if cls.startswith("tomato"): return "tomato" + if cls.startswith("potato"): return "potato" + if cls.startswith("pepper"): return "pepper" + return "other" + +def to_disease(cls: str): + if "bacterial_spot" in cls: return "bacterial_spot" + if "early_blight" in cls: return "early_blight" + if "late_blight" in cls: return "late_blight" + if "leaf_mold" in cls: return "leaf_mold" + if "septoria_leaf_spot" in cls: return "septoria_leaf_spot" + if "spider_mites" in cls: return "spider_mites" + if "target_spot" in cls: return "target_spot" + if "mosaic_virus" in cls: return "mosaic_virus" + if "yellowleaf_curl_virus" in cls: return "yellowleaf_curl_virus" + return "none" + +idx_to_class = {v: k for k, v in class_to_idx.items()} + +y_true_cls = [idx_to_class[i] for i in all_labels] +y_pred_cls = [idx_to_class[i] for i in all_preds] + +# ------------------------ +# Evaluation per level +# ------------------------ +def evaluate_level(name, y_true, y_pred, labels=None): + acc = accuracy_score(y_true, y_pred) + f1 = f1_score(y_true, y_pred, average="weighted") + print(f"\n===== {name} =====") + print(f"Accuracy: {acc:.4f}") + print(f"F1-score (weighted): {f1:.4f}") + print(classification_report(y_true, y_pred, digits=4)) + cm = confusion_matrix(y_true, y_pred, labels=labels) + if labels: + plt.figure(figsize=(8, 6)) + sns.heatmap(cm, annot=True, fmt="d", xticklabels=labels, yticklabels=labels, cmap="Blues") + plt.title(f"Confusion Matrix - {name}") + plt.xlabel("Predicted") + plt.ylabel("True") + plt.show() + +# Healthy vs Sick +evaluate_level("Healthy vs Sick", + [to_healthy_sick(c) for c in y_true_cls], + [to_healthy_sick(c) for c in y_pred_cls], + labels=["healthy", "sick"]) + +# Crop type +evaluate_level("Crop type", + [to_crop(c) for c in y_true_cls], + [to_crop(c) for c in y_pred_cls], + labels=["tomato", "potato", "pepper", "other"]) + +# Disease type +evaluate_level("Disease type", + [to_disease(c) for c in y_true_cls], + [to_disease(c) for c in y_pred_cls], + labels=["bacterial_spot","early_blight","late_blight","leaf_mold", + "septoria_leaf_spot","spider_mites","target_spot", + "mosaic_virus","yellowleaf_curl_virus","none"]) diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/finetune_multi.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/finetune_multi.py new file mode 100644 index 000000000..e3e5457cd --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/finetune_multi.py @@ -0,0 +1,242 @@ +# finetune_multi.py +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets +import os +from sklearn.metrics import f1_score +from torch.utils.data import DataLoader, random_split, WeightedRandomSampler +from torch.optim.lr_scheduler import ReduceLROnPlateau +import albumentations as A +from albumentations.pytorch import ToTensorV2 +import cv2 +import numpy as np + +from agri_baseline.src.detectors.train.dictionary import CLASS_MAPPING +from agri_baseline.src.detectors.cnn_multi_classifier import build_multi_model + + +# ------------------------ +# MixUp +# ------------------------ +def mixup_data(x, y, alpha=1.0): + if alpha > 0: + lam = np.random.beta(alpha, alpha) + else: + lam = 1 + batch_size = x.size()[0] + index = torch.randperm(batch_size).to(x.device) + + mixed_x = lam * x + (1 - lam) * x[index, :] + y_a, y_b = y, y[index] + return mixed_x, y_a, y_b, lam + +def mixup_criterion(criterion, pred, y_a, y_b, lam): + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) + + +# ------------------------ +# Paths +# ------------------------ +DATA_DIR = "data_balanced/PlantDoc" +MODEL_PATH = "models/cnn_multi.pth" +SAVE_PATH = "models/cnn_multi_finetuned.pth" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# ------------------------ +# Augmentations +# ------------------------ +train_transforms = A.Compose([ + A.RandomResizedCrop(size=(224, 224), scale=(0.7, 1.0), p=1.0), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.3), + A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.2, rotate_limit=30, p=0.7), + A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, p=0.5), + A.RandomBrightnessContrast(p=0.5), + A.GaussianBlur(p=0.3), + A.CoarseDropout(max_height=32, max_width=32, max_holes=1, p=0.3), + A.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + ToTensorV2() +]) + +val_transforms = A.Compose([ + A.Resize(224, 224), + A.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + ToTensorV2() +]) + + +# ------------------------ +# Albumentations Dataset +# ------------------------ +class AlbumentationsDataset(torch.utils.data.Dataset): + def __init__(self, dataset, transform=None): + self.dataset = dataset + self.transform = transform + + def __getitem__(self, idx): + path, label = self.dataset.samples[idx] + image = cv2.imread(path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if self.transform: + image = self.transform(image=image)["image"] + return image, label + + def __len__(self): + return len(self.dataset) + + +# ------------------------ +# Prepare Dataset +# ------------------------ +def prepare_multi_dataset(path): + dataset = datasets.ImageFolder(path) + new_samples, new_targets = [], [] + canonical_classes = sorted(set(CLASS_MAPPING.values())) + class_to_idx = {cls: i for i, cls in enumerate(canonical_classes)} + + for sample_path, label_idx in dataset.samples: + raw_name = dataset.classes[label_idx].lower().replace(" ", "_") + canonical_label = CLASS_MAPPING.get(raw_name) + if canonical_label is None: + raise ValueError(f"Class {raw_name} not found in CLASS_MAPPING") + new_samples.append((sample_path, class_to_idx[canonical_label])) + new_targets.append(class_to_idx[canonical_label]) + + dataset.samples = new_samples + dataset.targets = new_targets + dataset.classes = canonical_classes + dataset.class_to_idx = class_to_idx + return dataset + + +# ------------------------ +# Load dataset +# ------------------------ +full_dataset = prepare_multi_dataset(os.path.join(DATA_DIR, "train")) +print("Classes:", full_dataset.classes) +print("Total samples:", len(full_dataset)) + +train_size = int(0.8 * len(full_dataset)) +val_size = len(full_dataset) - train_size +train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) + +train_dataset = AlbumentationsDataset(train_dataset.dataset, transform=train_transforms) +val_dataset = AlbumentationsDataset(val_dataset.dataset, transform=val_transforms) + +class_counts = np.bincount(full_dataset.targets) +class_weights = 1. / class_counts +sample_weights = [class_weights[t] for t in full_dataset.targets] + +sampler = WeightedRandomSampler(weights=sample_weights, + num_samples=len(sample_weights), + replacement=True) + +train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler) +val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) + + +# ------------------------ +# Model +# ------------------------ +model = build_multi_model(num_classes=len(full_dataset.classes)).to(device) +state_dict = torch.load(MODEL_PATH, map_location=device) +filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("fc.")} +model.load_state_dict(filtered_state_dict, strict=False) +print("✅ Loaded pretrained backbone") + + +# ------------------------ +# Training setup +# ------------------------ +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam([ + {"params": model.fc.parameters(), "lr": 1e-3}, +], lr=1e-3) + +scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, verbose=True) +best_val_f1 = 0.0 +patience, counter = 5, 0 + + +# ------------------------ +# Gradual Unfreeze +# ------------------------ +def unfreeze(epoch): + if epoch == 5: + for name, param in model.named_parameters(): + if "layer4" in name: + param.requires_grad = True + if epoch == 10: + for param in model.parameters(): + param.requires_grad = True + + +# ------------------------ +# Training Loop +# ------------------------ +EPOCHS = 20 +for epoch in range(EPOCHS): + unfreeze(epoch) + model.train() + total_loss, correct, total = 0.0, 0, 0 + for images, labels in train_loader: + images, labels = images.to(device), labels.to(device) + optimizer.zero_grad() + images, targets_a, targets_b, lam = mixup_data(images, labels, alpha=0.4) + outputs = model(images) + loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam) + loss.backward() + optimizer.step() + total_loss += loss.item() * images.size(0) + _, preds = outputs.max(1) + correct += preds.eq(labels).sum().item() + total += labels.size(0) + + train_acc = correct / total + train_loss = total_loss / total + + # Validation + model.eval() + all_preds, all_labels = [], [] + val_loss, val_correct, val_total = 0.0, 0, 0 + with torch.no_grad(): + for images, labels in val_loader: + images, labels = images.to(device), labels.to(device) + outputs = model(images) + loss = criterion(outputs, labels) + val_loss += loss.item() * images.size(0) + _, preds = outputs.max(1) + val_correct += preds.eq(labels).sum().item() + val_total += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + val_acc = val_correct / val_total + val_loss /= val_total + val_f1 = f1_score(all_labels, all_preds, average="weighted") + + print(f"Epoch {epoch+1}/{EPOCHS} " + f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} " + f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}") + + scheduler.step(val_loss) + + # Save by F1 + if val_f1 > best_val_f1: + best_val_f1 = val_f1 + counter = 0 + torch.save(model.state_dict(), SAVE_PATH) + print("💾 Model improved (F1) and saved!") + else: + counter += 1 + print(f"⏳ No improvement. EarlyStopping counter: {counter}/{patience}") + if counter >= patience: + print("🛑 Early stopping triggered!") + break + +print(f"✅ Training finished. Best model saved to {SAVE_PATH}") diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/finetune_multi_stage3.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/finetune_multi_stage3.py new file mode 100644 index 000000000..eee68bd67 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/finetune_multi_stage3.py @@ -0,0 +1,191 @@ +# finetune_multi_stage3.py +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, random_split +import albumentations as A +from albumentations.pytorch import ToTensorV2 +import cv2, os +import numpy as np +from sklearn.metrics import f1_score + +from agri_baseline.src.detectors.train.dictionary import CLASS_MAPPING +from agri_baseline.src.detectors.cnn_multi_classifier import build_multi_model +from torchvision import datasets + +# ========================= +# Config +# ========================= +DATA_DIR = "data_balanced/PlantDoc" +PREV_MODEL = "models/cnn_multi_finetuned.pth" +SAVE_PATH = "models/cnn_multi_stage3.pth" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# ========================= +# Augmentations +# ========================= +train_tfms = A.Compose([ + A.RandomResizedCrop(size=(224, 224), scale=(0.6, 1.0), p=1.0), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.3), + A.RandomBrightnessContrast(p=0.4), + A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.2, rotate_limit=30, p=0.5), + A.GaussianBlur(p=0.2), + A.RandomGamma(p=0.3), + A.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + ToTensorV2() +]) + +val_tfms = A.Compose([ + A.Resize(224, 224), + A.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + ToTensorV2() +]) + + +# ========================= +# Dataset wrapper +# ========================= +class AlbumentationsDataset(torch.utils.data.Dataset): + def __init__(self, dataset, transform=None): + self.dataset = dataset + self.transform = transform + + def __getitem__(self, idx): + path, label = self.dataset.samples[idx] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if self.transform: + img = self.transform(image=img)["image"] + return img, label + + def __len__(self): + return len(self.dataset) + + +def prepare_dataset(path): + ds = datasets.ImageFolder(path) + new_samples, new_targets = [], [] + canonical = sorted(set(CLASS_MAPPING.values())) + class_to_idx = {cls: i for i, cls in enumerate(canonical)} + + for pth, idx in ds.samples: + raw = ds.classes[idx].lower().replace(" ", "_") + canon = CLASS_MAPPING.get(raw) + if canon is None: + raise ValueError(f"Class {raw} missing in CLASS_MAPPING") + new_samples.append((pth, class_to_idx[canon])) + new_targets.append(class_to_idx[canon]) + + ds.samples = new_samples + ds.targets = new_targets + ds.classes = canonical + ds.class_to_idx = class_to_idx + return ds + + +# ========================= +# Progressive unfreezing +# ========================= +def unfreeze_layers(model, stages): + """ + stages: List of layer names to release (e.g.: ["layer3", "layer2"]) + """ + for name, param in model.named_parameters(): + for stage in stages: + if stage in name: + param.requires_grad = True + + +# ========================= +# Training loop +# ========================= +def train_stage3(model, train_loader, val_loader, epochs=20): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4) + + best_f1, patience, counter = 0, 5, 0 + for epoch in range(epochs): + model.train() + total_loss, total_correct, total = 0, 0, 0 + for xb, yb in train_loader: + xb, yb = xb.to(device), yb.to(device) + optimizer.zero_grad() + out = model(xb) + loss = criterion(out, yb) + loss.backward() + optimizer.step() + total_loss += loss.item() * xb.size(0) + _, preds = out.max(1) + total_correct += preds.eq(yb).sum().item() + total += yb.size(0) + + train_acc = total_correct / total + train_loss = total_loss / total + + # Validation + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_preds, all_labels = [], [] + with torch.no_grad(): + for xb, yb in val_loader: + xb, yb = xb.to(device), yb.to(device) + out = model(xb) + loss = criterion(out, yb) + val_loss += loss.item() * xb.size(0) + _, preds = out.max(1) + val_correct += preds.eq(yb).sum().item() + val_total += yb.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(yb.cpu().numpy()) + + val_acc = val_correct / val_total + val_loss /= val_total + val_f1 = f1_score(all_labels, all_preds, average="weighted") + + print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} " + f"| Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} F1: {val_f1:.3f}") + + if val_f1 > best_f1: + best_f1 = val_f1 + counter = 0 + torch.save(model.state_dict(), SAVE_PATH) + print(f"💾 Model improved (F1={val_f1:.3f}) and saved!") + else: + counter += 1 + if counter >= patience: + print("🛑 EarlyStopping triggered.") + break + + +# ========================= +# Main +# ========================= +if __name__ == "__main__": + full_ds = prepare_dataset(os.path.join(DATA_DIR, "train")) + train_size = int(0.8 * len(full_ds)) + val_size = len(full_ds) - train_size + train_ds, val_ds = random_split(full_ds, [train_size, val_size]) + + train_ds = AlbumentationsDataset(train_ds.dataset, transform=train_tfms) + val_ds = AlbumentationsDataset(val_ds.dataset, transform=val_tfms) + + train_loader = DataLoader(train_ds, batch_size=32, shuffle=True) + val_loader = DataLoader(val_ds, batch_size=32) + + model = build_multi_model(num_classes=len(full_ds.classes)).to(device) + model.load_state_dict(torch.load(PREV_MODEL, map_location=device)) + + # In step 3 we will release additional layers beyond layer4 + for p in model.parameters(): + p.requires_grad = False + for stage in ["layer3", "layer4", "fc"]: + unfreeze_layers(model, [stage]) + print(f"🔓 Unfroze {stage}") + + train_stage3(model, train_loader, val_loader, epochs=15) + print(f"✅ Training done. Best model saved to {SAVE_PATH}") diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/train_binary_multi.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/train_binary_multi.py new file mode 100644 index 000000000..0d48afb36 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/research/detectors/train/train_binary_multi.py @@ -0,0 +1,152 @@ +# agri-baseline/src/detectors/train_binary_multi.py +import argparse +import os +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, WeightedRandomSampler +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import ReduceLROnPlateau +import numpy as np + +from ...agri_baseline.src.detectors.cnn_binary_classifier import build_binary_model +from ...agri_baseline.src.detectors.cnn_multi_classifier import build_multi_model +from ...agri_baseline.src.detectors.dataset_binary import BinaryDiseaseDataset + + +def train_model(model, dataloader, val_dl, device, epochs, lr, out_path): + opt = torch.optim.Adam(model.parameters(), lr=lr) + loss_fn = nn.CrossEntropyLoss() + scheduler = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=3, verbose=True) + + best_val_loss = float("inf") + patience, counter = 5, 0 + + for epoch in range(epochs): + model.train() + running_loss, correct, total = 0.0, 0, 0 + for batch in dataloader: + if len(batch) == 3: + xb, yb, _ = batch + else: + xb, yb = batch + xb, yb = xb.to(device), yb.to(device) + + opt.zero_grad() + preds = model(xb) + loss = loss_fn(preds, yb) + loss.backward() + opt.step() + + running_loss += loss.item() * xb.size(0) + _, predicted = preds.max(1) + correct += predicted.eq(yb).sum().item() + total += yb.size(0) + + acc = correct / total + + # Validation + val_loss, val_acc = evaluate(model, val_dl, device, loss_fn) + print(f"Epoch {epoch+1}/{epochs} " + f"Train Loss={running_loss/total:.4f} Train Acc={acc:.3f} " + f"Val Loss={val_loss:.4f} Val Acc={val_acc:.3f}") + + scheduler.step(val_loss) + + # EarlyStopping + if val_loss < best_val_loss: + best_val_loss = val_loss + counter = 0 + torch.save(model.state_dict(), out_path) + print(f"💾 Saved best model {out_path}") + else: + counter += 1 + print(f"⏳ EarlyStopping counter {counter}/{patience}") + if counter >= patience: + print("🛑 Early stopping triggered") + break + + +def evaluate(model, dataloader, device, loss_fn): + model.eval() + correct, total, total_loss = 0, 0, 0.0 + with torch.no_grad(): + for batch in dataloader: + if len(batch) == 3: + xb, yb, _ = batch + else: + xb, yb = batch + xb, yb = xb.to(device), yb.to(device) + preds = model(xb) + loss = loss_fn(preds, yb) + total_loss += loss.item() * xb.size(0) + _, predicted = preds.max(1) + correct += predicted.eq(yb).sum().item() + total += yb.size(0) + return total_loss/total, correct/total + + +def make_sampler(targets): + class_counts = np.bincount(targets) + class_weights = 1. / class_counts + sample_weights = [class_weights[t] for t in targets] + return WeightedRandomSampler(weights=sample_weights, + num_samples=len(sample_weights), + replacement=True) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--data", required=True, help="Dataset root (with train/val/test)") + p.add_argument("--out", default="./models") + p.add_argument("--epochs", type=int, default=10) + p.add_argument("--batch", type=int, default=32) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--device", default="cpu") + args = p.parse_args() + + device = torch.device(args.device) + + # Augmentations + train_tfms = transforms.Compose([ + transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(15), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), + transforms.ToTensor(), + transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) + ]) + test_tfms = transforms.Compose([ + transforms.Resize((224,224)), + transforms.ToTensor(), + transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) + ]) + + # Binary dataset + train_bin = BinaryDiseaseDataset(os.path.join(args.data,"train"), transform=train_tfms) + val_bin = BinaryDiseaseDataset(os.path.join(args.data,"val"), transform=test_tfms) + + sampler_bin = make_sampler(train_bin.targets) + train_dl_bin = DataLoader(train_bin, batch_size=args.batch, sampler=sampler_bin) + val_dl_bin = DataLoader(val_bin, batch_size=args.batch) + + model_bin = build_binary_model().to(device) + train_model(model_bin, train_dl_bin, val_dl_bin, device, args.epochs, args.lr, + os.path.join(args.out, "cnn_binary.pth")) + + # Multi-class dataset + train_multi = datasets.ImageFolder(os.path.join(args.data,"train"), transform=train_tfms) + val_multi = datasets.ImageFolder(os.path.join(args.data,"val"), transform=test_tfms) + + sampler_multi = make_sampler([y for _, y in train_multi.samples]) + train_dl_multi = DataLoader(train_multi, batch_size=args.batch, sampler=sampler_multi) + val_dl_multi = DataLoader(val_multi, batch_size=args.batch) + + model_multi = build_multi_model(num_classes=len(train_multi.classes)).to(device) + train_model(model_multi, train_dl_multi, val_dl_multi, device, args.epochs, args.lr, + os.path.join(args.out, "cnn_multi.pth")) + + torch.save({"classes": train_multi.classes}, + os.path.join(args.out,"multi_classes.pth")) + +if __name__=="__main__": + main() diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/conftest.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/conftest.py new file mode 100644 index 000000000..486e4fe52 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/conftest.py @@ -0,0 +1,29 @@ +# tests/conftest.py +import os +import pytest +from sqlalchemy import text +from agri_baseline.src.pipeline.db import get_engine + +@pytest.fixture(autouse=True, scope="function") +def _ensure_local_db_url(monkeypatch): + """ + Guarantee DATABASE_URL exists for tests. + """ + monkeypatch.setenv( + "DATABASE_URL", + os.getenv( + "DATABASE_URL", + "postgresql+psycopg2://missions_user:pg123@localhost:5432/missions_db", + ), + ) + +@pytest.fixture(autouse=True) +def _clean_tables_before_test(): + """ + Clean key tables before each test so counts can increase deterministically. + Adjust the list to your schema. + """ + tables = ["anomalies", "tile_stats", "event_logs"] + with get_engine().begin() as conn: + for t in tables: + conn.execute(text(f"DELETE FROM {t}")) diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_batch_runner.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_batch_runner.py new file mode 100644 index 000000000..db88f63bc --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_batch_runner.py @@ -0,0 +1,68 @@ +# Purpose: End-to-end tests for the BatchRunner pipeline. +# Verifies that running on image folders or single images correctly writes results to the database. + +import pytest +from pathlib import Path +from sqlalchemy import text + +from agri_baseline.src.batch_runner import BatchRunner +from agri_baseline.src.pipeline.db import get_engine + + +@pytest.fixture +def folder_with_images() -> Path: + """ + Return a folder that contains a few test images. + Adjust the path if your dataset sits elsewhere. + """ + folder = Path("./data_balanced/PlantDoc/train/Bell_pepper leaf") + assert folder.exists(), f"Images folder not found: {folder.resolve()}" + return folder + + +def _count(conn, sql: str, params: dict | None = None) -> int: + """ + Small helper: run a COUNT(*) query safely with SQLAlchemy 2.0. + """ + return conn.execute(text(sql), params or {}).scalar() or 0 + + +def test_run_batch_on_images_folder(folder_with_images: Path): + """ + End-to-end: run the batch pipeline on a folder and verify DB writes happened. + We compare counts before/after instead of relying on specific image_id values. + """ + runner = BatchRunner() + + with get_engine().begin() as conn: + before = _count(conn, "SELECT COUNT(1) FROM anomalies") + + runner.run_folder(folder_with_images) + + with get_engine().begin() as conn: + after = _count(conn, "SELECT COUNT(1) FROM anomalies") + + assert after > before, "No detections were written to the database." + + +def test_process_single_image(): + """ + Process a single image and assert the DB anomalies count has increased. + This avoids fragile assumptions on the exact image_id in the DB. + """ + image_path = Path( + "./data_balanced/PlantDoc/train/Bell_pepper leaf/0f3s5A.jpg" + ) + assert image_path.exists(), f"Test image not found: {image_path.resolve()}" + + runner = BatchRunner() + + with get_engine().begin() as conn: + before = _count(conn, "SELECT COUNT(1) FROM anomalies") + + runner.process_image(image_path) + + with get_engine().begin() as conn: + after = _count(conn, "SELECT COUNT(1) FROM anomalies") + + assert after > before, "Single image was not processed correctly." diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_disease_model.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_disease_model.py new file mode 100644 index 000000000..c2cb78625 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_disease_model.py @@ -0,0 +1,17 @@ +# Purpose: Unit tests for the DiseaseDetector class. +# Ensures the model loads successfully and returns valid detections on dummy input. + +import pytest +from agri_baseline.src.detectors.disease_model import DiseaseDetector +import numpy as np + +def test_disease_detector_model_loads(): + detector = DiseaseDetector(model_path="models/cnn_multi_stage3.pth") + assert detector.model is not None, "Model failed to load correctly." + +def test_disease_detector_predicts(): + detector = DiseaseDetector() + img = np.zeros((224, 224, 3)) # Dummy image for testing + detections = detector.run(img) + assert len(detections) > 0, "Model did not return any detections." + assert detections[0].confidence > 0, "Detection confidence should be greater than 0." diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_minio_integration_mock.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_minio_integration_mock.py new file mode 100644 index 000000000..369b7070c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_minio_integration_mock.py @@ -0,0 +1,120 @@ +# Purpose: Mock-based integration tests for MinIO storage. +# Simulates MinIO object downloads, saves them locally, and verifies images can be loaded successfully. + +from __future__ import annotations + +from io import BytesIO +from pathlib import Path +from typing import Dict, Iterable + +import pytest +from PIL import Image + +from agri_baseline.src.storage import minio_sync +from agri_baseline.src.storage.minio_client import MinioConfig +from agri_baseline.src.pipeline.utils import load_image + + +class _FakeObj: + """Mimics the object returned by client.list_objects().""" + + def __init__(self, object_name: str) -> None: + self.object_name = object_name + + +class _FakeResponse: + """ + Minimal MinIO get_object-like response object. + + Provides: + - read(amt: int | None = None) -> bytes + - close() -> None + - release_conn() -> None + + This mirrors what MinIO/urllib3 responses typically expose, so production code + that calls release_conn() won't fail under the mock. + """ + + def __init__(self, data: bytes) -> None: + self._buf = BytesIO(data) + + def read(self, amt: int | None = None) -> bytes: + return self._buf.read() if amt is None else self._buf.read(amt) + + def close(self) -> None: + self._buf.close() + + def release_conn(self) -> None: + # In real clients this releases underlying HTTP resources. + # No-op here is fine for tests. + pass + + +class _FakeMinio: + """ + Fake MinIO client that supports the subset used by minio_sync: + - list_objects(bucket, prefix, recursive) -> Iterable[_FakeObj] + - get_object(bucket, key) -> _FakeResponse + """ + + def __init__(self, payload_by_key: Dict[str, bytes]) -> None: + self._payload_by_key = payload_by_key + + def list_objects(self, bucket: str, prefix: str, recursive: bool) -> Iterable[_FakeObj]: + for key in self._payload_by_key: + if key.startswith(prefix) and not key.endswith("/"): + yield _FakeObj(key) + + def get_object(self, bucket: str, key: str) -> _FakeResponse: + data = self._payload_by_key[key] + return _FakeResponse(data) + + +@pytest.fixture +def fake_jpeg() -> bytes: + """Create a tiny deterministic JPEG in-memory.""" + img = Image.new("RGB", (32, 24), (10, 20, 30)) + buf = BytesIO() + img.save(buf, format="JPEG") + return buf.getvalue() + + +def test_minio_download_and_load(monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + fake_jpeg: bytes) -> None: + """ + Flow under test: + 1) list prefix from MinIO (fake). + 2) download files to local cache dir. + 3) ensure those files exist and can be loaded with load_image. + """ + + # 1) Arrange fake MinIO payload (two images under mission-123/) + payload = { + "mission-123/imgA.jpg": fake_jpeg, + "mission-123/imgB.jpg": fake_jpeg, + } + fake_client = _FakeMinio(payload) + + # 2) Monkeypatch build_client to return our fake client + monkeypatch.setattr(minio_sync, "build_client", lambda cfg: fake_client, raising=True) + + # 3) Prepare config and download target folder + cfg = MinioConfig( + endpoint="127.0.0.1:9000", + access_key="minioadmin", + secret_key="minioadmin", + bucket="leaves", + secure=False, + ) + out_dir = tmp_path / "cache" + + # 4) Act: download objects to local dir + paths = minio_sync.download_prefix_to_dir(cfg, prefix="mission-123", local_dir=out_dir) + + # 5) Assert: files were written and are loadable + assert len(paths) == 2, f"Expected 2 files, got {len(paths)}" + for p in paths: + assert p.exists() and p.is_file(), f"Missing file: {p}" + img, w, h = load_image(str(p)) + assert img is not None and w > 0 and h > 0, f"Failed to load image {p}" diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_run_detectors.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_run_detectors.py new file mode 100644 index 000000000..6e0f13e26 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_run_detectors.py @@ -0,0 +1,21 @@ +# Purpose: Tests for running the DiseaseDetector model. +# Checks that detections are produced with valid confidence values on dummy images. + +import pytest +from agri_baseline.src.detectors.disease_model import DiseaseDetector +import numpy as np + +@pytest.fixture +def dummy_image(): + """Provide a dummy image for testing.""" + return np.zeros((224, 224, 3)) # Black dummy image + +def test_disease_detector_runs(dummy_image): + detector = DiseaseDetector() + detections = detector.run(dummy_image) + assert len(detections) > 0, "Disease detection did not return any detections." + assert detections[0].confidence > 0, "Detection confidence should be greater than 0." + +def test_disease_detector_model_loads(): + detector = DiseaseDetector(model_path="models/cnn_multi_stage3.pth") + assert detector.model is not None, "Model failed to load correctly." diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_utils_local.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_utils_local.py new file mode 100644 index 000000000..ab76fe8cf --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_utils_local.py @@ -0,0 +1,27 @@ +# Purpose: Local unit tests for utility functions. +# Covers image loading, image ID extraction, and bounding box clamping logic. + +from pathlib import Path +from PIL import Image +from agri_baseline.src.pipeline.utils import load_image, image_id_from_path, clamp_bbox + +def _write_test_image(tmp_dir: Path, name: str = "test.jpg") -> Path: + img = Image.new("RGB", (64, 48), (127, 200, 50)) + path = tmp_dir / name + img.save(path, format="JPEG") + return path + +def test_load_image_local(tmp_path: Path): + img_path = _write_test_image(tmp_path) + img, w, h = load_image(str(img_path)) + assert img is not None + assert (w, h) == (64, 48) + +def test_image_id_from_path_no_fs(tmp_path: Path): + fake_path = tmp_path / "nested" / "test.jpg" # no file needed + image_id = image_id_from_path(str(fake_path)) + assert isinstance(image_id, str) and image_id + +def test_clamp_bbox_pure(): + x, y, w, h = clamp_bbox(10, 10, 250, 250, 224, 224) + assert x >= 0 and y >= 0 and w <= 224 and h <= 224 diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_validator.py b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_validator.py new file mode 100644 index 000000000..17957c530 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Detection_Jobs/tests/test_validator.py @@ -0,0 +1,119 @@ +# Purpose: Integration tests for the Validator module. +# Verifies event logging from findings and correctness of batch summary generation in the database. + +import pytest +from sqlalchemy import text + +from agri_baseline.src.validator.validator import Validator +from agri_baseline.src.validator.rules import Finding +from agri_baseline.src.pipeline.db import get_engine +from agri_baseline.src.pipeline import config +from agri_baseline.src.pipeline.db import get_engine + +@pytest.fixture(autouse=True) +def _seed_anomalies_for_summary(): + """ + Ensure the DB has minimal data for batch_summary: + - device 'device-1' + - anomaly type id=1 + - mission id=1 with small polygon + - two anomalies with same image_id and non-null geom + Idempotent: safe to run before every test. + """ + with get_engine().begin() as conn: + conn.exec_driver_sql(""" + INSERT INTO devices(device_id, model, owner, active) + VALUES ('device-1','sim','lab',true) + ON CONFLICT (device_id) DO NOTHING; + """) + conn.exec_driver_sql(""" + INSERT INTO anomaly_types(anomaly_type_id, code, description) + VALUES (1,'disease_spot','Leaf disease spot') + ON CONFLICT (anomaly_type_id) DO NOTHING; + """) + conn.exec_driver_sql(""" + INSERT INTO missions(mission_id, start_time, area_geom) + VALUES (1, now(), ST_GeomFromText('POLYGON((0 0,1 0,1 1,0 1,0 0))',4326)) + ON CONFLICT (mission_id) DO NOTHING; + """) + conn.exec_driver_sql(""" + INSERT INTO anomalies(mission_id, device_id, ts, anomaly_type_id, severity, details, geom) + VALUES + (1, 'device-1', now(), 1, 0.6, + '{"image_id":"seed_img_for_summary"}'::jsonb, + ST_GeomFromText('POINT(0.50 0.50)',4326)), + (1, 'device-1', now(), 1, 0.7, + '{"image_id":"seed_img_for_summary"}'::jsonb, + ST_GeomFromText('POINT(0.55 0.52)',4326)) + ON CONFLICT DO NOTHING; + """) + yield + +@pytest.fixture +def dummy_finding() -> Finding: + """ + Create a minimal Finding to simulate a validator output. + Scope/value names should match your Validator implementation. + """ + return Finding( + scope="image", + image_id="test_image", + rule="bbox_oob", + severity="warn", + message="BBox out of bounds", + ) + + +def _count(conn, sql: str, params: dict | None = None) -> int: + """ + Small helper: run a COUNT(*) query safely with SQLAlchemy 2.0. + """ + return conn.execute(text(sql), params or {}).scalar() or 0 + + +def test_validator_image_findings(dummy_finding: Finding): + """ + Ensure validator writes a record into event_logs for the given finding. + We assert a strictly increasing count for the message we inserted. + """ + validator = Validator() + + with get_engine().begin() as conn: + before = _count( + conn, + "SELECT COUNT(1) FROM event_logs WHERE message = :msg", + {"msg": dummy_finding.message}, + ) + + validator.image_findings([dummy_finding]) + + with get_engine().begin() as conn: + after = _count( + conn, + "SELECT COUNT(1) FROM event_logs WHERE message = :msg", + {"msg": dummy_finding.message}, + ) + + assert after > before, "Finding was not written to event_logs." + + +def test_batch_summary(): + """ + Run batch_summary and verify tile_stats is populated or remains populated. + We allow idempotency (>=) but also require that there is some data (> 0). + """ + validator = Validator() + + with get_engine().begin() as conn: + print("DEBUG DB_URL:", config.DB_URL) + print("DEBUG anomalies:", conn.exec_driver_sql("SELECT COUNT(*) FROM anomalies").scalar()) + print("DEBUG tile_stats:", conn.exec_driver_sql("SELECT COUNT(*) FROM tile_stats").scalar()) + before = _count(conn, "SELECT COUNT(1) FROM tile_stats") + + validator.batch_summary() + + with get_engine().begin() as conn: + after = _count(conn, "SELECT COUNT(1) FROM tile_stats") + + assert after >= before, "tile_stats count unexpectedly decreased." + assert after > 0, "No images found in tile_stats for batch summary." diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Makefile b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Makefile new file mode 100644 index 000000000..94e2c9d4c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/Makefile @@ -0,0 +1,7 @@ +.PHONY: ci-e2e ci-detection + +ci-e2e: + cd e2e_kafka_flink && pytest tests --cov=e2e_pipeline --cov-report=xml --maxfail=1 -v + +ci-detection: + cd agri-baseline && pytest tests --cov=agri_baseline --cov-report=xml --maxfail=1 -v diff --git a/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/VENDORED_FROM.txt b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/VENDORED_FROM.txt new file mode 100644 index 000000000..b1786ac32 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/Detection_Jobs/VENDORED_FROM.txt @@ -0,0 +1,4 @@ +VENDORED FROM (saved 2025-11-09T04:18:09+02:00) +------------------------------------- +origin https://github.com/KamaTechOrg/AgCloud.git (fetch) [blob:none] +origin https://github.com/KamaTechOrg/AgCloud.git (push) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/.gitignore b/airflow_bundle/leaf-pipeline/projects/disease-monitor/.gitignore new file mode 100644 index 000000000..06593e32e --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/.gitignore @@ -0,0 +1,55 @@ +# ==== OS / IDE ==== +.DS_Store +Thumbs.db +.vscode/ +.idea/ + +# ==== Node ==== +node_modules/ +dist/ + +# ==== Python ==== +__pycache__/ +*.py[cod] +*.pyc +*.pyo +*.so +*.dylib + +# ==== Virtual envs ==== +.venv/ +venv/ +ENV/ +env/ + +# ==== Packaging / build ==== +build/ +*.egg-info/ + +# ==== Environment / Secrets ==== +.env +.env.* + +# ==== Data / Notebooks / Logs ==== +*.log +*.ipynb +.ipynb_checkpoints/ + +# ==== Artifacts / Wheels / Models ==== +artifacts/ +.wheels/ +wheels/ +*.whl +*.pt +*.pth +*.bin + +# ==== Coverage reports ==== +.pytest_cache/ +.coverage +coverage.xml +htmlcov/ + +# ==== gRPC generated (נוצרים בבילד דוקר) ==== +server/embed_pb2.py +server/embed_pb2_grpc.py diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/Makefile b/airflow_bundle/leaf-pipeline/projects/disease-monitor/Makefile new file mode 100644 index 000000000..94e2c9d4c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/Makefile @@ -0,0 +1,7 @@ +.PHONY: ci-e2e ci-detection + +ci-e2e: + cd e2e_kafka_flink && pytest tests --cov=e2e_pipeline --cov-report=xml --maxfail=1 -v + +ci-detection: + cd agri-baseline && pytest tests --cov=agri_baseline --cov-report=xml --maxfail=1 -v diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/VENDORED_FROM.txt b/airflow_bundle/leaf-pipeline/projects/disease-monitor/VENDORED_FROM.txt new file mode 100644 index 000000000..b1786ac32 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/VENDORED_FROM.txt @@ -0,0 +1,4 @@ +VENDORED FROM (saved 2025-11-09T04:18:09+02:00) +------------------------------------- +origin https://github.com/KamaTechOrg/AgCloud.git (fetch) [blob:none] +origin https://github.com/KamaTechOrg/AgCloud.git (push) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/.dockerignore b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/.dockerignore new file mode 100644 index 000000000..9bd273d4f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/.dockerignore @@ -0,0 +1,25 @@ +# Python +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +.venv/ +venv/ + +# Tests and caches +.pytest_cache/ +tests/ + +# Local data / artifacts +data/ +alerts.db + +# Git +.git/ +.gitignore + +# IDE +.vscode/ +.idea/ + diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/.gitignore b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/.gitignore new file mode 100644 index 000000000..da73fe5e4 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/.gitignore @@ -0,0 +1,2 @@ +# Ignore local data +data/ diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/Dockerfile b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/Dockerfile new file mode 100644 index 000000000..49d99c76c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/Dockerfile @@ -0,0 +1,34 @@ +FROM mcr.microsoft.com/devcontainers/python:1-3.11-bullseye + +WORKDIR /app + +# 1) Install CA tools and curl +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl \ + && rm -rf /var/lib/apt/lists/* + +# 2) Add NetFree certificate and register in system trust store + +# Ensure Python, requests, and pip use the updated CA bundle +ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +ENV PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# 3) Install Python dependencies (use trusted hosts to simplify NetFree path) +COPY requirements.txt /app/requirements.txt +RUN pip install --trusted-host pypi.org --trusted-host pypi.python.org \ + --trusted-host files.pythonhosted.org --no-cache-dir -r requirements.txt + +# 4) Install the package (PEP517) with the same trusted hosts +COPY pyproject.toml README.md /app/ +COPY src /app/src +RUN pip install --trusted-host pypi.org --trusted-host pypi.python.org \ + --trusted-host files.pythonhosted.org --no-cache-dir . + +# 5) Copy configs (can be overridden by a bind mount) +COPY configs /app/configs + +ENV PYTHONUNBUFFERED=1 + +ENTRYPOINT ["python", "-m", "disease_monitor.cli"] + diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/Dockerfile.local b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/Dockerfile.local new file mode 100644 index 000000000..e9dfc92b9 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/Dockerfile.local @@ -0,0 +1,32 @@ +FROM docker.io/library/python@sha256:e0c4fae70d550834a40f6c3e0326e02cfe239c2351d922e1fb1577a3c6ebde02 + +WORKDIR /app + +# 1) כלים בסיסיים ותעודות + requests (דרך APT) בלי לפנות ל-PyPI עבור החבילה הזו +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl python3-requests \ + && rm -rf /var/lib/apt/lists/* + +# הגדרות SSL סטנדרטיות (משתמשים בתעודות מערכת) +ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +ENV PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# 2) התקנת תלויות פייתון של הפרויקט +COPY requirements.txt /app/requirements.txt +RUN pip install --trusted-host pypi.org --trusted-host pypi.python.org \ + --trusted-host files.pythonhosted.org --no-cache-dir -r requirements.txt +RUN pip install --trusted-host pypi.org --trusted-host pypi.python.org \ + --trusted-host files.pythonhosted.org --no-cache-dir requests + +# 3) התקנת החבילה עצמה +COPY pyproject.toml README.md /app/ +COPY src /app/src +RUN pip install --trusted-host pypi.org --trusted-host pypi.python.org \ + --trusted-host files.pythonhosted.org --no-cache-dir . + +# 4) קונפיגים +COPY configs /app/configs + +ENV PYTHONUNBUFFERED=1 +ENTRYPOINT ["python", "-m", "disease_monitor.cli"] diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/README.md b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/README.md new file mode 100644 index 000000000..a0eb9bda6 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/README.md @@ -0,0 +1,127 @@ +# Disease Monitor (Offline) + +Offline batch job that reads disease detections from **Postgres**, aggregates data, +builds baselines, detects anomalies/worsening, deduplicates & rate-limits alerts, +delivers notifications (Slack/Webhook/Email), and writes alerts back to **Postgres**. + +> **Note:** The pipeline uses **Postgres only** (both sources and sink). No CSV/SQLite. + +--- + +## Data Sources & Sink (Postgres) + +**Sources** +- `anomalies` — per-image detections (0..N rows per image) +- `tile_stats` — exactly 1 row per image (summary) +- `event_logs` — QA & validator logs (rules/metrics/errors) + +**Sink** +- `alerts` — unified alerts table (rules: `COUNT_SPIKE`, `WORSENING_TREND`) + +--- + +## Config (`configs/config.example.yaml`) + +Main sections: +- **io**: Postgres URL (e.g. `postgresql+psycopg2://user:pass@host:5432/db`) +- **windows**: frequency (`"D"`/`"W"`), timezone (e.g. `"UTC"`) +- **baseline**: method (`mean`/`median`), lookback, min_history, optional seasonality +- **rules**: thresholds & toggles for `count_anomaly` (zscore/iqr) and `worsening` (slope/ewma) +- **alerting**: dedup cooldown (windows), resolve-after-no-anomaly, per-run rate limit, group_by_window +- **delivery**: slack/webhook/email targets (can be disabled) +- **run**: `dry_run` and optional filters + +Example: +```yaml +io: + postgres_url: "postgresql+psycopg2://missions_user:pg123@localhost:5432/missions_db" + +windows: + frequency: "D" + timezone: "UTC" + +baseline: + method: "median" + lookback_periods: 28 + min_history: 7 + seasonality: null + +rules: + count_anomaly: + enabled: true + method: "zscore" + z_threshold: 3.0 + iqr_k: 1.5 + min_count: 3 + worsening: + enabled: true + method: "slope" + slope_lookback: 7 + slope_min: 0.02 + min_periods: 5 + ewma_span: 7 + ewma_threshold: 0.6 + +alerting: + dedup_cooldown_windows: 3 + resolve_after_no_anomaly: 3 + rate_limit_per_run: 100 + group_by_window: true + +delivery: + slack: + enabled: false + webhook_url: "" + webhook: + enabled: false + url: "" + headers: {} + email: + enabled: false + smtp_host: "" + smtp_port: 587 + username: "" + password_env: "SMTP_PASSWORD" + from_addr: "" + to_addrs: [] + +run: + dry_run: false +``` + +--- + +## Install & Run + +```bash +# Create & activate venv (Linux/Mac) +python -m venv .venv +source .venv/bin/activate + +# On Windows (PowerShell): +# python -m venv .venv +# .venv\Scripts\Activate.ps1 + +# Install +pip install -r requirements.txt + +# Run +python -m disease_monitor.cli --config configs/config.example.yaml --log-level INFO +``` + +--- + +## Tests + +```bash +pytest +``` + +--- + + +## Notes + +- Thresholds, lookbacks, and active rules are fully configurable from YAML. +- Logs and runtime counters are emitted to stdout. +- Extend notifiers in `src/disease_monitor/notifiers`. diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/config.docker.yaml b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/config.docker.yaml new file mode 100644 index 000000000..fb46acdfe --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/config.docker.yaml @@ -0,0 +1,64 @@ +io: + # IMPORTANT: use the Docker service name of Postgres (from your compose): + postgres_url: "postgresql+psycopg2://missions_user:pg123@postgres:5432/missions_db" + +windows: + frequency: "D" + timezone: "UTC" + +source_mapping: + entity_dim: "mission" # or "region"/"device" + area_strategy: "none" # or "region_area" (requires regions table/geom) + filters: + start_time: null + end_time: null + anomaly_codes: null + +baseline: + method: "median" + lookback_periods: 28 + min_history: 7 + seasonality: null + +rules: + count_anomaly: + enabled: true + method: "zscore" + z_threshold: 3.0 + iqr_k: 1.5 + min_count: 3 + worsening: + enabled: true + method: "slope" + slope_lookback: 7 + slope_min: 0.02 + min_periods: 5 + ewma_span: 7 + ewma_threshold: 0.6 + +alerting: + dedup_cooldown_windows: 3 + resolve_after_no_anomaly: 3 + rate_limit_per_run: 100 + group_by_window: true + +delivery: + slack: + enabled: false + webhook_url: "" + webhook: + enabled: false + url: "" + headers: {} + email: + enabled: false + smtp_host: "" + smtp_port: 587 + username: "" + password_env: "SMTP_PASSWORD" + from_addr: "" + to_addrs: [] + +run: + dry_run: false + diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/config.example.yaml b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/config.example.yaml new file mode 100644 index 000000000..6d4d8d699 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/config.example.yaml @@ -0,0 +1,70 @@ +io: + postgres_url: "postgresql+psycopg2://missions_user:pg123@localhost:5432/missions_db" + +windows: + frequency: "D" + timezone: "UTC" + +source_mapping: + entity_dim: "mission" # or "region"/"device" + area_strategy: "none" # or "region_area" + filters: + start_time: null + end_time: null + anomaly_codes: null + +baseline: + method: "median" + lookback_periods: 28 + min_history: 7 + seasonality: null + +rules: + count_anomaly: + enabled: true + method: "zscore" + z_threshold: 3.0 + iqr_k: 1.5 + min_count: 3 + worsening: + enabled: true + method: "slope" + slope_lookback: 7 + slope_min: 0.02 + min_periods: 5 + ewma_span: 7 + ewma_threshold: 0.6 + +alerting: + dedup_cooldown_windows: 3 + resolve_after_no_anomaly: 3 + rate_limit_per_run: 100 + group_by_window: true + +delivery: + slack: + enabled: false + webhook_url: "" # paste Slack Webhook URL here if you want to enable + webhook: + enabled: false + url: "" # paste your Webhook URL here if you want to enable + headers: {} # optional headers map + email: + enabled: false + smtp_host: "" # paste your SMTP server address here if you want to enable + smtp_port: 587 + username: "" + password_env: "SMTP_PASSWORD" + from_addr: "" + to_addrs: [] + + alertmanager: + enabled: false + url: "http://localhost:9093" + default_severity: "warning" + extra_labels: + system: "disease-monitor" + team: "ag" + +run: + dry_run: false diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/disease_monitor.yaml b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/disease_monitor.yaml new file mode 100644 index 000000000..832d6daf7 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/configs/disease_monitor.yaml @@ -0,0 +1,68 @@ +io: + # IMPORTANT: use the Docker service name of Postgres (from your compose): + postgres_url: "postgresql+psycopg2://missions_user:pg123@postgres:5432/missions_db" + +windows: + frequency: "D" + timezone: "UTC" + +source_mapping: + entity_dim: "device" + area_strategy: "none" # or "region_area" (requires regions table/geom) + filters: + start_time: null + end_time: null + anomaly_codes: null + +baseline: + method: "median" + lookback_periods: 28 + min_history: 7 + seasonality: null + +rules: + count_anomaly: + enabled: true + method: "zscore" + z_threshold: 3.0 + iqr_k: 1.5 + min_count: 3 + worsening: + enabled: true + method: "slope" + slope_lookback: 7 + slope_min: 0.02 + min_periods: 5 + ewma_span: 7 + ewma_threshold: 0.6 + +alerting: + dedup_cooldown_windows: 3 + resolve_after_no_anomaly: 3 + rate_limit_per_run: 100 + group_by_window: true + +delivery: + kafka: + enabled: false + brokers: "kafka:9092" + topic: "alerts" + slack: + enabled: false + webhook_url: "" + webhook: + enabled: false + url: "" + headers: {} + email: + enabled: false + smtp_host: "" + smtp_port: 587 + username: "" + password_env: "SMTP_PASSWORD" + from_addr: "" + to_addrs: [] + +run: + dry_run: false + diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/docker-compose.yml b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/docker-compose.yml new file mode 100644 index 000000000..15a0a0baa --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/docker-compose.yml @@ -0,0 +1,21 @@ +services: + disease-monitor: + build: + context: . + dockerfile: Dockerfile + image: disease-monitor:latest + command: ["--config", "/app/configs/config.docker.yaml", "--log-level", "INFO"] + environment: + TZ: "UTC" + # If you enable email delivery and use password_env=SMTP_PASSWORD: + # SMTP_PASSWORD: "your-smtp-password" + volumes: + - ./configs:/app/configs:ro + networks: + - worktree-main_ag_cloud + restart: on-failure + +networks: + # Use the external network created by your worktree-main compose + worktree-main_ag_cloud: + external: true diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/pyproject.toml b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/pyproject.toml new file mode 100644 index 000000000..063c0697f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "disease-monitor" +version = "0.1.0" +description = "Offline anomaly & worsening detection for disease cases in trees/plots/regions." +readme = "README.md" +requires-python = ">=3.10" + +[tool.pytest.ini_options] +pythonpath = ["src"] +addopts = "-q" diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/requirements.txt b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/requirements.txt new file mode 100644 index 000000000..46a776a5b --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/requirements.txt @@ -0,0 +1,11 @@ +pandas>=2.3.0,<2.4 +numpy>=2.2,<2.4 +pyyaml==6.0.2 +sqlalchemy==2.0.32 +pydantic==2.9.2 +scipy>=1.14.1,<1.15 +pytest==8.3.2 +python-dateutil==2.9.0.post0 +psycopg2-binary==2.9.7 +requests>=2.31 +kafka-python==2.0.2 diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/__init__.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/__init__.py new file mode 100644 index 000000000..a9a2c5b3b --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/__init__.py @@ -0,0 +1 @@ +__all__ = [] diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/alerting.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/alerting.py new file mode 100644 index 000000000..fe3802cba --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/alerting.py @@ -0,0 +1,99 @@ +from __future__ import annotations +import logging +from datetime import datetime +from typing import Dict, Any, List, Tuple +import pandas as pd + +LOGGER = logging.getLogger(__name__) + +def _merge_reasons(s: pd.Series) -> list[str]: + items = [] + for x in s: + if isinstance(x, (list, tuple, set)): + items.extend(list(x)) + else: + items.append(str(x)) + return sorted(set(items)) + +def enforce_policies(candidates: pd.DataFrame, open_alerts_df: pd.DataFrame, + cfg: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Deduplicate per (entity, rule) with cooldown; update OPEN alerts if still anomalous; + create RESOLVED entries after consecutive non-anomalous windows (handled by absence). + Rate limiting applied. + """ + if candidates.empty: + return [] + + candidates = candidates.copy() + candidates["window_start"] = pd.to_datetime(candidates["window"]) + candidates["window_end"] = pd.to_datetime(candidates["window_end"]) + candidates["first_seen"] = candidates["window_start"] + candidates["last_seen"] = candidates["window_end"] + candidates["status"] = "OPEN" + + # Dedup cooldown: skip if there is OPEN/ACK within last N windows for same (entity, rule) + cooldown = cfg["alerting"]["dedup_cooldown_windows"] + frequency = cfg["windows"]["frequency"] + + alerts_out: List[Dict[str, Any]] = [] + rate_limit = cfg["alerting"]["rate_limit_per_run"] + emitted = 0 + + # Grouping by window if requested + if cfg["alerting"]["group_by_window"]: + group_keys = ["entity_id", "rule", "window_start", "window_end"] + else: + group_keys = ["entity_id", "rule"] + + g = candidates.groupby(group_keys, as_index=False).agg({ + "score": "max", + "disease_count": "max", + "avg_severity": "max", + "affected_area": "max", + "reason": _merge_reasons +}) + + for _, row in g.iterrows(): + if emitted >= rate_limit: + LOGGER.warning("Rate limit reached (%d).", rate_limit) + break + entity, rule = row["entity_id"], row["rule"] + ws, we = row["window_start"], row["window_end"] + # Check cooldown against open alerts + if not open_alerts_df.empty: + same = open_alerts_df[(open_alerts_df["entity_id"] == entity) & + (open_alerts_df["rule"] == rule)] + # In cooldown if last_seen within last cooldown windows + recent = same[same["last_seen"] >= (ws - _windows_to_offset(frequency, cooldown))] + if not recent.empty: + LOGGER.info("Cooldown skip for %s/%s at %s.", entity, rule, ws) + continue + + meta = { + "reasons": row["reason"], + "disease_count": int(row["disease_count"]), + "avg_severity": float(row["avg_severity"]), + "affected_area": float(row["affected_area"]), + } + alerts_out.append({ + "entity_id": entity, + "rule": rule, + "window_start": ws.to_pydatetime(), + "window_end": we.to_pydatetime(), + "score": float(row["score"]), + "first_seen": ws.to_pydatetime(), + "last_seen": we.to_pydatetime(), + "status": "OPEN", + "meta": meta + }) + emitted += 1 + + return alerts_out + +def _windows_to_offset(freq: str, n: int) -> pd.Timedelta: + if n <= 0: + return pd.Timedelta(0) + if freq.upper().startswith("W"): + return pd.to_timedelta(7 * n, unit="D") + return pd.to_timedelta(n, unit="D") diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/baseline.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/baseline.py new file mode 100644 index 000000000..71e976c08 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/baseline.py @@ -0,0 +1,38 @@ +from __future__ import annotations +import pandas as pd +import numpy as np + +def compute_baseline(agg: pd.DataFrame, method: str, lookback: int, + min_history: int, seasonality: int | None) -> pd.DataFrame: + """ + Returns agg with baseline columns for disease_count, avg_severity, affected_area: + *_bl, *_std (or IQR helpers). + """ + df = agg.sort_values(["entity_id", "window"]).copy() + keys = ["entity_id"] + metrics = ["disease_count", "avg_severity", "affected_area"] + + # Optionally seasonal lag indexing + if seasonality and seasonality > 1: + df["season_index"] = df.groupby(keys)["window"].rank(method="first").astype(int) % seasonality + groupers = keys + ["season_index"] + else: + groupers = keys + + for m in metrics: + if method == "mean": + bl = df.groupby(groupers)[m].transform(lambda s: s.shift(1).rolling(lookback, min_periods=min_history).mean()) + sd = df.groupby(groupers)[m].transform(lambda s: s.shift(1).rolling(lookback, min_periods=min_history).std(ddof=0)) + else: + bl = df.groupby(groupers)[m].transform(lambda s: s.shift(1).rolling(lookback, min_periods=min_history).median()) + sd = df.groupby(groupers)[m].transform(lambda s: s.shift(1).rolling(lookback, min_periods=min_history).std(ddof=0)) + df[f"{m}_bl"] = bl.fillna(0.0) + df[f"{m}_std"] = sd.fillna(0.0) + + # IQR helpers + q1 = df.groupby(groupers)[m].transform(lambda s: s.shift(1).rolling(lookback, min_periods=min_history).quantile(0.25)) + q3 = df.groupby(groupers)[m].transform(lambda s: s.shift(1).rolling(lookback, min_periods=min_history).quantile(0.75)) + df[f"{m}_q1"] = q1.fillna(0.0) + df[f"{m}_q3"] = q3.fillna(0.0) + + return df diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/cli.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/cli.py new file mode 100644 index 000000000..74817d4cb --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/cli.py @@ -0,0 +1,121 @@ +import argparse +import json +import logging +from typing import Dict, Any, List + +import yaml +import pandas as pd + +from .logging_utils import setup_logging +from .config import AppConfig +from . import io as io_mod +from .baseline import compute_baseline +from .rules import apply_rules +from .alerting import enforce_policies +from .notifiers.base import Notifier +from .notifiers.slack import SlackNotifier +from .notifiers.webhook import WebhookNotifier +from .notifiers.emailer import EmailNotifier +from .notifiers.kafka_notifier import KafkaNotifier +from .io import load_inputs_from_postgres , upsert_alerts_pg , fetch_open_alerts_pg + + +LOGGER = logging.getLogger("disease_monitor") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Offline disease anomaly detector") + parser.add_argument("--config", required=True, help="Path to config file (YAML)") + parser.add_argument("--log-level", default="INFO", help="Logging level") + return parser.parse_args() + + +def load_config(path: str) -> Dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + AppConfig(**cfg) # validation + return cfg + +def build_notifiers(cfg: Dict[str, Any]) -> List[Notifier]: + ns: List[Notifier] = [] + d = cfg.get("delivery", {}) + + kafka_cfg = d.get("kafka", {}) + if kafka_cfg.get("enabled"): + brokers = kafka_cfg["brokers"] + topic = kafka_cfg.get("topic", "alerts") + ns.append(KafkaNotifier(brokers, topic)) + LOGGER.info("Using KafkaNotifier to send alerts.") + return ns + + slack = d.get("slack", {}) + if slack.get("enabled") and slack.get("webhook_url"): + ns.append(SlackNotifier(slack["webhook_url"])) + + webhook = d.get("webhook", {}) + if webhook.get("enabled") and webhook.get("url"): + ns.append(WebhookNotifier(webhook["url"], webhook.get("headers") or {})) + + email = d.get("email", {}) + if email.get("enabled") and email.get("to_addrs"): + ns.append(EmailNotifier(email["smtp_host"], email["smtp_port"], email["username"], + email["password_env"], email["from_addr"], email["to_addrs"])) + return ns + +def main() -> None: + args = parse_args() + setup_logging(args.log_level) + cfg = load_config(args.config) + + tz = cfg["windows"]["timezone"] + freq = cfg["windows"]["frequency"] + + # Load inputs + det, reg = load_inputs_from_postgres(cfg["io"]["postgres_url"], tz, cfg) + + # Optional filters + run_cfg = cfg["run"] + if run_cfg.get("disease_filter"): + det = det[det["disease_type"].isin(run_cfg["disease_filter"])] + if run_cfg.get("limit_entities"): + keep = det["entity_id"].drop_duplicates().head(run_cfg["limit_entities"]).tolist() + det = det[det["entity_id"].isin(keep)] + + # Aggregation + baseline + agg = io_mod.aggregate(det, freq=freq) + agg_bl = compute_baseline( + agg, + method=cfg["baseline"]["method"], + lookback=cfg["baseline"]["lookback_periods"], + min_history=cfg["baseline"]["min_history"], + seasonality=cfg["baseline"]["seasonality"], + ) + + # Rules + candidates = apply_rules(agg_bl, cfg) + LOGGER.info("Candidate alerts: %d", 0 if candidates is None else len(candidates)) + + # Policies need knowledge of currently OPEN alerts from the chosen backend + open_alerts = fetch_open_alerts_pg(cfg["io"]["postgres_url"]) + alerts = enforce_policies(candidates, open_alerts, cfg) + LOGGER.info("Alerts after policies: %d", len(alerts)) + + # Delivery + notifiers = build_notifiers(cfg) + dry_run = cfg["run"]["dry_run"] + + if not dry_run and alerts: + io_mod.upsert_alerts_pg(cfg["io"]["postgres_url"], alerts) + for a in alerts: + for n in notifiers: + try: + n.send(a) + except Exception as ex: + LOGGER.error("Notifier failed: %s", ex) + else: + LOGGER.info("Dry-run or no alerts. Skipping DB write & delivery.") + LOGGER.info("Preview alerts: %s", json.dumps(alerts, default=str, ensure_ascii=False)) + + +if __name__ == "__main__": + main() diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/config.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/config.py new file mode 100644 index 000000000..a2e0aa83f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/config.py @@ -0,0 +1,114 @@ +from pydantic import BaseModel, Field, model_validator +from typing import Optional, List, Dict, Any + +# ------------------------------ +# IO: Postgres-only +# ------------------------------ +class IOConfig(BaseModel): + postgres_url: str # required: Postgres-only + + @model_validator(mode="after") + def _ensure_pg_only(self): + url = self.postgres_url + if not isinstance(url, str) or not url.lower().startswith( + ("postgresql://", "postgresql+psycopg2://") + ): + raise ValueError("io.postgres_url is required and must be a PostgreSQL URL.") + return self + + +# ------------------------------ +# Windows/Baseline/Rules/Alerting +# ------------------------------ +class WindowsConfig(BaseModel): + frequency: str = "D" + timezone: str = "UTC" + +class BaselineConfig(BaseModel): + method: str = "median" + lookback_periods: int = 28 + min_history: int = 7 + seasonality: Optional[int] = None + +class CountAnomalyRule(BaseModel): + enabled: bool = True + method: str = "zscore" + z_threshold: float = 3.0 + iqr_k: float = 1.5 + min_count: int = 3 + +class WorseningRule(BaseModel): + enabled: bool = True + method: str = "slope" + slope_lookback: int = 7 + slope_min: float = 0.02 + min_periods: int = 5 + ewma_span: int = 7 + ewma_threshold: float = 0.6 + +class RulesConfig(BaseModel): + count_anomaly: CountAnomalyRule = Field(default_factory=CountAnomalyRule) + worsening: WorseningRule = Field(default_factory=WorseningRule) + +class AlertingConfig(BaseModel): + dedup_cooldown_windows: int = 3 + resolve_after_no_anomaly: int = 3 + rate_limit_per_run: int = 100 + group_by_window: bool = True + + +# ------------------------------ +# Delivery: add Alertmanager section +# ------------------------------ +class SlackConfig(BaseModel): + enabled: bool = False + webhook_url: Optional[str] = None + +class WebhookConfig(BaseModel): + enabled: bool = False + url: Optional[str] = None + headers: Dict[str, Any] = Field(default_factory=dict) + +class EmailConfig(BaseModel): + enabled: bool = False + smtp_host: str = "" + smtp_port: int = 587 + username: str = "" + password_env: str = "SMTP_PASSWORD" + from_addr: str = "" + to_addrs: List[str] = Field(default_factory=list) + +class AlertmanagerConfig(BaseModel): + enabled: bool = False + url: Optional[str] = None + default_severity: str = "warning" + extra_labels: Dict[str, str] = Field(default_factory=dict) + auth: Dict[str, Any] = Field(default_factory=lambda: {"type": "none"}) # {"type":"none"} or {"type":"basic",...} + +class DeliveryConfig(BaseModel): + slack: SlackConfig = Field(default_factory=SlackConfig) + webhook: WebhookConfig = Field(default_factory=WebhookConfig) + email: EmailConfig = Field(default_factory=EmailConfig) + alertmanager: AlertmanagerConfig = Field(default_factory=AlertmanagerConfig) + + +# ------------------------------ +# Run +# ------------------------------ +class RunConfig(BaseModel): + dry_run: bool = False + limit_entities: Optional[int] = None + disease_filter: Optional[List[str]] = None + + +# ------------------------------ +# AppConfig +# ------------------------------ +class AppConfig(BaseModel): + io: IOConfig + windows: WindowsConfig = Field(default_factory=WindowsConfig) + baseline: BaselineConfig = Field(default_factory=BaselineConfig) + rules: RulesConfig = Field(default_factory=RulesConfig) + alerting: AlertingConfig = Field(default_factory=AlertingConfig) + delivery: DeliveryConfig = Field(default_factory=DeliveryConfig) + run: RunConfig = Field(default_factory=RunConfig) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/io.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/io.py new file mode 100644 index 000000000..98b100f77 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/io.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import json +import logging +from typing import Tuple, Iterable, Dict, Any, List + +import pandas as pd +from sqlalchemy import create_engine, text + +LOGGER = logging.getLogger(__name__) + +# --------------------------------------------------------------------- +# Postgres sources: anomalies / anomaly_types / regions +# --------------------------------------------------------------------- + +_BASE_SQLS: Dict[str, str] = { + "device": """ + SELECT a.ts AS "timestamp", + a.device_id AS entity_id, + at.code AS disease_type, + COALESCE(a.severity::double precision, 0.0) AS severity, + 0.0 AS affected_area + FROM public.anomalies a + JOIN public.anomaly_types at ON at.anomaly_type_id = a.anomaly_type_id + WHERE a.ts IS NOT NULL + {AND_CODE_FILTER} + {AND_TIME_RANGE} + """, + "mission": """ + SELECT a.ts AS "timestamp", + a.mission_id::text AS entity_id, + at.code AS disease_type, + COALESCE(a.severity::double precision, 0.0) AS severity, + 0.0 AS affected_area + FROM public.anomalies a + JOIN public.anomaly_types at ON at.anomaly_type_id = a.anomaly_type_id + WHERE a.ts IS NOT NULL + {AND_CODE_FILTER} + {AND_TIME_RANGE} + """, + "region": """ + SELECT a.ts AS "timestamp", + r.id::text AS entity_id, + at.code AS disease_type, + COALESCE(a.severity::double precision, 0.0) AS severity, + {AREA_EXPR} AS affected_area + FROM public.anomalies a + JOIN public.anomaly_types at ON at.anomaly_type_id = a.anomaly_type_id + JOIN public.regions r ON ST_Contains(r.geom, a.geom) + WHERE a.ts IS NOT NULL AND a.geom IS NOT NULL + {AND_CODE_FILTER} + {AND_TIME_RANGE} + """, + +} + + +def _build_sql( + entity_dim: str, + area_strategy: str, + codes: List[str] | None, + start: str | None, + end: str | None, +) -> tuple[str, dict]: + """ + Build parametrized SQL for reading anomalies with chosen entity dimension and area strategy. + """ + sql = _BASE_SQLS[entity_dim] + area_expr = "0.0" + if entity_dim == "region" and area_strategy == "region_area": + area_expr = "ST_Area(r.geom::geography)::double precision" + + and_code = "" + params: Dict[str, Any] = {} + if codes: + and_code = "AND at.code = ANY(:codes)" + params["codes"] = codes + + and_time = "" + if start: + and_time += " AND a.ts >= :start_time" + params["start_time"] = start + if end: + and_time += " AND a.ts < :end_time" + params["end_time"] = end + + sql = ( + sql.replace("{AREA_EXPR}", area_expr) + .replace("{AND_CODE_FILTER}", and_code) + .replace("{AND_TIME_RANGE}", and_time) + ) + return sql, params + + +# --------------------------------------------------------------------- +# Postgres input (canonical) +# --------------------------------------------------------------------- + +def load_inputs_from_postgres(pg_url: str, tz: str, cfg: dict) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Load inputs from Postgres (public.anomalies/anomaly_types/regions). + Controlled by cfg['source_mapping'] (entity_dim, area_strategy, filters, codes). + Returns: + det: columns [timestamp, entity_id, disease_type, severity, affected_area] + reg: columns [entity_id, entity_type] + """ + edim = cfg["source_mapping"]["entity_dim"] + area = cfg["source_mapping"].get("area_strategy", "none") + codes = cfg["source_mapping"].get("anomaly_codes") + filters = cfg["source_mapping"].get("filters") or {} + start = filters.get("start_time") + end = filters.get("end_time") + + sql, params = _build_sql(edim, area, codes, start, end) + + eng = create_engine(pg_url) + with eng.begin() as conn: + det = pd.read_sql(text(sql), conn, params=params) + reg = det[["entity_id"]].drop_duplicates().assign(entity_type=edim) + + det["timestamp"] = pd.to_datetime(det["timestamp"], utc=True).dt.tz_convert(tz) + + required = {"timestamp", "entity_id", "disease_type", "severity", "affected_area"} + if not required.issubset(det.columns): + missing = required - set(det.columns) + raise ValueError(f"det: missing {missing}") + if not {"entity_id", "entity_type"}.issubset(reg.columns): + raise ValueError("reg: missing cols") + + return det, reg + + +# --------------------------------------------------------------------- +# Aggregation +# --------------------------------------------------------------------- + +def aggregate(det: pd.DataFrame, freq: str) -> pd.DataFrame: + """ + Aggregate by entity_id + window and compute disease_count, avg_severity, affected_area. + """ + df = det.copy() + + # Normalize tz: drop tz-info to use pandas period-based bucketing safely + if pd.api.types.is_datetime64tz_dtype(df["timestamp"]): + df["timestamp"] = df["timestamp"].dt.tz_convert("UTC").dt.tz_localize(None) + + df["window"] = df["timestamp"].dt.to_period(freq).dt.start_time + grp = df.groupby(["entity_id", "window"], as_index=False).agg( + disease_count=("disease_type", "count"), + avg_severity=("severity", "mean"), + affected_area=("affected_area", "sum"), + ) + grp["window_end"] = grp["window"] + pd.tseries.frequencies.to_offset(freq) + return grp + + +# --------------------------------------------------------------------- +# Alerts: Postgres backend +# --------------------------------------------------------------------- + + +def fetch_open_alerts_pg(pg_url: str) -> pd.DataFrame: + eng = create_engine(pg_url) + sql = """ + SELECT id, entity_id, rule, window_start, window_end, score, + first_seen, last_seen, status, meta_json + FROM alerts_leaves + WHERE status IN ('OPEN','ACK') + """ + with eng.begin() as conn: + df = pd.read_sql(text(sql), conn) + if not df.empty: + for c in ("first_seen", "last_seen", "window_start", "window_end"): + # make tz-aware UTC then drop tz -> naive UTC + s = pd.to_datetime(df[c], utc=True) + df[c] = s.dt.tz_convert("UTC").dt.tz_localize(None) + + return df + + +def upsert_alerts_pg(pg_url: str, alerts: Iterable[Dict[str, Any]]) -> None: + rows = list(alerts) + if not rows: + return + eng = create_engine(pg_url) + sql = """ + INSERT INTO alerts_leaves + (entity_id, rule, window_start, window_end, score, + first_seen, last_seen, status, meta_json) + VALUES + (:entity_id, :rule, :window_start, :window_end, :score, + :first_seen, :last_seen, :status, CAST(:meta_json AS jsonb)) + """ + payload = [{ + "entity_id": a["entity_id"], + "rule": a["rule"], + "window_start": a["window_start"], + "window_end": a["window_end"], + "score": float(a["score"]), + "first_seen": a["first_seen"], + "last_seen": a["last_seen"], + "status": a["status"], + "meta_json": json.dumps(a["meta"], ensure_ascii=False), + } for a in rows] + + with eng.begin() as conn: + conn.execute(text(sql), payload) + LOGGER.info("Inserted %d alerts into Postgres.", len(rows)) \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/logging_utils.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/logging_utils.py new file mode 100644 index 000000000..f9618ff02 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/logging_utils.py @@ -0,0 +1,10 @@ +import logging +import sys + +def setup_logging(level: str = "INFO") -> None: + handler = logging.StreamHandler(sys.stdout) + fmt = "%(asctime)s %(levelname)s %(name)s - %(message)s" + handler.setFormatter(logging.Formatter(fmt)) + root = logging.getLogger() + root.setLevel(level.upper()) + root.handlers = [handler] diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/models.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/models.py new file mode 100644 index 000000000..4796fecd8 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/models.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Optional, Dict +from datetime import datetime + +@dataclass +class Alert: + entity_id: str + rule: str + window_start: datetime + window_end: datetime + score: float + first_seen: datetime + last_seen: datetime + status: str # OPEN | ACK | RESOLVED + meta: Dict diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/base.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/base.py new file mode 100644 index 000000000..1f8bc409a --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/base.py @@ -0,0 +1,13 @@ +from __future__ import annotations +from typing import Dict, Any, List + +class Notifier: + def send(self, alert: Dict[str, Any]) -> None: + raise NotImplementedError + +def render_text(alert: Dict[str, Any]) -> str: + return ( + f"[{alert['status']}] {alert['rule']} for {alert['entity_id']} " + f"{alert['window_start']}..{alert['window_end']} " + f"score={alert['score']:.2f} reasons={alert['meta'].get('reasons')}" + ) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/emailer.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/emailer.py new file mode 100644 index 000000000..695e0ddfe --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/emailer.py @@ -0,0 +1,28 @@ +from __future__ import annotations +import os +import smtplib +from email.mime.text import MIMEText +from typing import Dict, Any, List +from .base import Notifier, render_text + +class EmailNotifier(Notifier): + def __init__(self, host: str, port: int, username: str, password_env: str, + from_addr: str, to_addrs: List[str]) -> None: + self.host = host + self.port = port + self.username = username + self.password_env = password_env + self.from_addr = from_addr + self.to_addrs = to_addrs + + def send(self, alert: Dict[str, Any]) -> None: + password = os.getenv(self.password_env, "") + msg = MIMEText(render_text(alert)) + msg["Subject"] = f"Alert: {alert['rule']} {alert['entity_id']}" + msg["From"] = self.from_addr + msg["To"] = ", ".join(self.to_addrs) + with smtplib.SMTP(self.host, self.port, timeout=10) as s: + s.starttls() + if self.username and password: + s.login(self.username, password) + s.sendmail(self.from_addr, self.to_addrs, msg.as_string()) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/kafka_notifier.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/kafka_notifier.py new file mode 100644 index 000000000..7552bfa1f --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/kafka_notifier.py @@ -0,0 +1,49 @@ +from __future__ import annotations +import json, uuid, datetime, logging +from kafka import KafkaProducer +from typing import Dict, Any +from .base import Notifier + +LOGGER = logging.getLogger(__name__) + +def _json_default(obj): + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + + +class KafkaNotifier(Notifier): + def __init__(self, brokers: str, topic: str): + self.producer = KafkaProducer( + bootstrap_servers=brokers.split(","), + value_serializer=lambda v: json.dumps(v, default=_json_default).encode("utf-8"), + ) + self.topic = topic + + def send(self, alert: Dict[str, Any]) -> None: + msg = { + "alert_id": alert.get("alert_id") or str(uuid.uuid4()), + "alert_type": alert.get("rule", "disease_detected"), + "device_id": alert.get("entity_id"), + "started_at": alert.get("window_start"), + "ended_at": alert.get("window_end"), + "confidence": alert.get("score"), + "severity": int(alert.get("meta", {}).get("severity", 1)), + "area": alert.get("meta", {}).get("area"), + "lat": alert.get("meta", {}).get("lat"), + "lon": alert.get("meta", {}).get("lon"), + "image_url": alert.get("meta", {}).get("image_url"), + "vod": alert.get("meta", {}).get("vod"), + "hls": alert.get("meta", {}).get("hls"), + "meta": alert.get("meta", {}), + } + + try: + self.producer.send(self.topic, msg) + self.producer.flush() + LOGGER.info( + "KafkaNotifier: sent alert %s to topic '%s' with rule '%s' (confidence=%.2f)", + msg["alert_id"], self.topic, msg["alert_type"], msg["confidence"] or 0, + ) + except Exception as e: + LOGGER.error("KafkaNotifier failed to send alert: %s", e) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/slack.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/slack.py new file mode 100644 index 000000000..68925060a --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/slack.py @@ -0,0 +1,15 @@ +from __future__ import annotations +import os +import json +import requests +from typing import Dict, Any +from .base import Notifier, render_text + +class SlackNotifier(Notifier): + def __init__(self, webhook_url: str) -> None: + self.webhook_url = webhook_url + + def send(self, alert: Dict[str, Any]) -> None: + text = render_text(alert) + payload = {"text": text} + requests.post(self.webhook_url, data=json.dumps(payload), timeout=10) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/webhook.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/webhook.py new file mode 100644 index 000000000..1e84232c7 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/notifiers/webhook.py @@ -0,0 +1,13 @@ +from __future__ import annotations +import json +import requests +from typing import Dict, Any +from .base import Notifier + +class WebhookNotifier(Notifier): + def __init__(self, url: str, headers: Dict[str, str] | None = None) -> None: + self.url = url + self.headers = headers or {} + + def send(self, alert: Dict[str, Any]) -> None: + requests.post(self.url, json=alert, headers=self.headers, timeout=10) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/rules.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/rules.py new file mode 100644 index 000000000..eeba8ec4c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/src/disease_monitor/rules.py @@ -0,0 +1,108 @@ +from __future__ import annotations +import logging +from typing import List, Dict, Any, Tuple +import pandas as pd +import numpy as np +from scipy import stats + +LOGGER = logging.getLogger(__name__) + +def zscore_anomalies(df: pd.DataFrame, threshold: float, min_count: int) -> pd.DataFrame: + s = df["disease_count"] + mu = df["disease_count_bl"] + # Use small epsilon for zero/NaN std to avoid z=0 + sd = df["disease_count_std"] + eps = 1e-6 + sd = sd.where(sd > 0, other=eps).fillna(eps) + + z = (s - mu) / sd + cond = (z >= threshold) & (s >= min_count) + + out = df.loc[cond].copy() + out["score"] = z.loc[cond] + out["rule"] = "COUNT_SPIKE" + out["reason"] = "zscore" + return out + +def iqr_anomalies(df: pd.DataFrame, k: float, min_count: int) -> pd.DataFrame: + q1 = df["disease_count_q1"] + q3 = df["disease_count_q3"] + iqr = (q3 - q1).replace(0, np.nan) + upper = q3 + k * iqr + cond = (df["disease_count"] > upper.fillna(float("inf"))) & (df["disease_count"] >= min_count) + out = df.loc[cond].copy() + out["score"] = (df["disease_count"] - upper).loc[cond].fillna(0.0) + out["rule"] = "COUNT_SPIKE" + out["reason"] = "iqr" + return out + +def slope_worsening(df: pd.DataFrame, metric: str, lookback: int, + slope_min: float, min_periods: int) -> pd.DataFrame: + # Per entity rolling slope (OLS) + rows = [] + for entity, g in df.groupby("entity_id"): + g = g.sort_values("window") + y = g[metric].rolling(lookback, min_periods=min_periods).apply(_rolling_slope, raw=False) + cond = y >= slope_min + sel = g.loc[cond].copy() + if sel.empty: + continue + sel["score"] = y.loc[cond] + sel["rule"] = "WORSENING_TREND" + sel["reason"] = f"slope_{metric}" + rows.append(sel) + return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=df.columns.tolist() + ["score","rule","reason"]) + +def _rolling_slope(s: pd.Series) -> float: + x = np.arange(len(s)) + res = stats.linregress(x, s.values) + return float(res.slope) + +def ewma_worsening(df: pd.DataFrame, metric: str, span: int, threshold: float, min_periods: int) -> pd.DataFrame: + rows = [] + for entity, g in df.groupby("entity_id"): + g = g.sort_values("window").copy() + ew = g[metric].ewm(span=span, adjust=False).mean() + cond = (ew >= threshold) & (g[metric].rolling(span, min_periods=min_periods).count() >= min_periods) + sel = g.loc[cond].copy() + if sel.empty: + continue + sel["score"] = ew.loc[cond] + sel["rule"] = "WORSENING_TREND" + sel["reason"] = f"ewma_{metric}" + rows.append(sel) + return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=df.columns.tolist() + ["score","rule","reason"]) + +def apply_rules(df: pd.DataFrame, cfg: Dict[str, Any]) -> pd.DataFrame: + results = [] + + # Count anomaly + rc = cfg["rules"]["count_anomaly"] + if rc["enabled"]: + if rc["method"] == "zscore": + results.append(zscore_anomalies(df, rc["z_threshold"], rc["min_count"])) + elif rc["method"] == "iqr": + results.append(iqr_anomalies(df, rc["iqr_k"], rc["min_count"])) + else: + # Placeholder: CUSUM can be added similarly + results.append(zscore_anomalies(df, rc["z_threshold"], rc["min_count"])) + + # Worsening trend on severity and area + rw = cfg["rules"]["worsening"] + if rw["enabled"]: + if rw["method"] == "slope": + for m in ["avg_severity", "affected_area"]: + results.append(slope_worsening(df, m, rw["slope_lookback"], rw["slope_min"], rw["min_periods"])) + else: + for m in ["avg_severity", "affected_area"]: + results.append(ewma_worsening(df, m, rw["ewma_span"], rw["ewma_threshold"], rw["min_periods"])) + + if not results: + return pd.DataFrame() + out = pd.concat([r for r in results if r is not None and not r.empty], ignore_index=True) \ + if any((r is not None and not r.empty) for r in results) else pd.DataFrame() + # Prepare common fields + if not out.empty: + out = out[["entity_id", "window", "window_end", "rule", "score", "reason", + "disease_count", "avg_severity", "affected_area"]].copy() + return out diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/conftest.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/conftest.py new file mode 100644 index 000000000..043dc1712 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/conftest.py @@ -0,0 +1,17 @@ +import pandas as pd +import numpy as np +from datetime import datetime, timedelta, timezone + +TZ = "UTC" + +def make_series(start: str, days: int, entity: str, base_count=1, bump_at=None, bump=5): + rows = [] + start_dt = pd.to_datetime(start).tz_localize("UTC") + for i in range(days): + ts = start_dt + pd.Timedelta(days=i) + count = base_count + if bump_at is not None and i in bump_at: + count = bump + rows.append({"timestamp": ts, "entity_id": entity, "disease_type": "x", + "severity": 0.1 * count, "affected_area": 2.0 * count}) + return pd.DataFrame(rows) diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_aggregation.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_aggregation.py new file mode 100644 index 000000000..bb0a42558 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_aggregation.py @@ -0,0 +1,17 @@ +import pandas as pd +from disease_monitor.io import aggregate + +def test_aggregate_basic(): + det = pd.DataFrame({ + "timestamp": pd.to_datetime(["2025-08-01", "2025-08-01", "2025-08-02"]).tz_localize("UTC"), + "entity_id": ["A","A","A"], + "disease_type": ["x","x","x"], + "severity": [0.2, 0.4, 0.3], + "affected_area": [1,2,3], + }) + out = aggregate(det, "D") + assert len(out) == 2 + d1 = out[out["window"] == pd.to_datetime("2025-08-01")] + assert int(d1["disease_count"].iloc[0]) == 2 + assert abs(float(d1["avg_severity"].iloc[0]) - 0.3) < 1e-9 + assert int(d1["affected_area"].iloc[0]) == 3 diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_alerting.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_alerting.py new file mode 100644 index 000000000..48126beaf --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_alerting.py @@ -0,0 +1,34 @@ +import pandas as pd +from disease_monitor.alerting import enforce_policies + +def test_dedup_cooldown(): + candidates = pd.DataFrame({ + "entity_id": ["A","A"], + "window": pd.to_datetime(["2025-08-10","2025-08-11"]), + "window_end": pd.to_datetime(["2025-08-11","2025-08-12"]), + "rule": ["COUNT_SPIKE","COUNT_SPIKE"], + "score": [3.1, 3.2], + "reason": [["zscore"],["zscore"]], + "disease_count": [10, 9], + "avg_severity": [0.5, 0.4], + "affected_area": [10.0, 9.0], + }) + open_alerts = pd.DataFrame({ + "entity_id": ["A"], + "rule": ["COUNT_SPIKE"], + "last_seen": pd.to_datetime(["2025-08-10"]), + "window_start": pd.to_datetime(["2025-08-10"]), + "window_end": pd.to_datetime(["2025-08-11"]), + "first_seen": pd.to_datetime(["2025-08-10"]), + "status": ["OPEN"], + "id": [1], + "score": [3.1] + }) + cfg = { + "alerting": {"dedup_cooldown_windows": 3, "resolve_after_no_anomaly": 3, + "rate_limit_per_run": 10, "group_by_window": True}, + "windows": {"frequency": "D"} + } + res = enforce_policies(candidates, open_alerts, cfg) + # Second day should be skipped due to cooldown + assert len(res) == 0 or len(res) == 1 diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_anomaly_rules.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_anomaly_rules.py new file mode 100644 index 000000000..23f595a7a --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_anomaly_rules.py @@ -0,0 +1,22 @@ +import pandas as pd +from disease_monitor.baseline import compute_baseline +from disease_monitor.rules import apply_rules + +def test_zscore_spike_detected(): + # mostly low counts, then spike + det = [] + for d in range(10): + det.append({"window": pd.to_datetime(f"2025-08-{d+1:02d}"), + "entity_id": "E1", + "disease_count": 1 if d < 8 else (10 if d==8 else 1), + "avg_severity": 0.2, "affected_area": 2.0}) + df = pd.DataFrame(det) + df["window_end"] = df["window"] + pd.Timedelta(days=1) + bl = compute_baseline(df.rename(columns={"window":"window"}), "median", 7, 3, None) + cfg = { + "rules": {"count_anomaly": {"enabled": True, "method": "zscore", "z_threshold": 2.5, "min_count": 3}, + "worsening": {"enabled": False}}, + } + out = apply_rules(bl, cfg) + assert not out.empty + assert "COUNT_SPIKE" in out["rule"].unique() diff --git a/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_worsening_rules.py b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_worsening_rules.py new file mode 100644 index 000000000..e17b34e57 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/disease-monitor/disease-monitor/tests/test_worsening_rules.py @@ -0,0 +1,24 @@ +import pandas as pd +from disease_monitor.baseline import compute_baseline +from disease_monitor.rules import apply_rules + +def test_worsening_slope_on_severity(): + rows = [] + for i in range(10): + rows.append({ + "window": pd.to_datetime(f"2025-08-{i+1:02d}"), + "entity_id": "E1", + "disease_count": 1, + "avg_severity": 0.1 + 0.03*i, + "affected_area": 2 + i + }) + df = pd.DataFrame(rows) + df["window_end"] = df["window"] + pd.Timedelta(days=1) + bl = compute_baseline(df, "median", 7, 3, None) + cfg = {"rules": + {"count_anomaly": {"enabled": False}, + "worsening": {"enabled": True, "method": "slope", + "slope_lookback": 7, "slope_min": 0.02, "min_periods": 5}}} + out = apply_rules(bl, cfg) + assert not out.empty + assert "WORSENING_TREND" in out["rule"].unique() diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf spot/05.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf spot/05.jpg new file mode 100644 index 000000000..09be3be5f Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf spot/05.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf spot/bacterialspot3_600px.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf spot/bacterialspot3_600px.jpg new file mode 100644 index 000000000..c56ceecd3 Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf spot/bacterialspot3_600px.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf/DSCN3768.JPG.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf/DSCN3768.JPG.jpg new file mode 100644 index 000000000..1dd73ab87 Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf/DSCN3768.JPG.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf/IMG_3891.JPG_1492073147.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf/IMG_3891.JPG_1492073147.jpg new file mode 100644 index 000000000..4acbdc891 Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Bell_pepper leaf/IMG_3891.JPG_1492073147.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Potato leaf late blight/Late-blight-infected-potato-plants_2.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Potato leaf late blight/Late-blight-infected-potato-plants_2.jpg new file mode 100644 index 000000000..238c58e4e Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Potato leaf late blight/Late-blight-infected-potato-plants_2.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Potato leaf late blight/blight-on-potato-leaves.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Potato leaf late blight/blight-on-potato-leaves.jpg new file mode 100644 index 000000000..311a114ce Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Potato leaf late blight/blight-on-potato-leaves.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Early blight leaf/dscn3175.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Early blight leaf/dscn3175.jpg new file mode 100644 index 000000000..a1014b817 Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Early blight leaf/dscn3175.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Septoria leaf spot/tomato-badleaves.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Septoria leaf spot/tomato-badleaves.jpg new file mode 100644 index 000000000..545773c4c Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Septoria leaf spot/tomato-badleaves.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Septoria leaf spot/tomato_septoria_05_zoom.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Septoria leaf spot/tomato_septoria_05_zoom.jpg new file mode 100644 index 000000000..9ca14479a Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato Septoria leaf spot/tomato_septoria_05_zoom.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato leaf/late_blight_tomato_leaf4x1200.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato leaf/late_blight_tomato_leaf4x1200.jpg new file mode 100644 index 000000000..83ea0c4f6 Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato leaf/late_blight_tomato_leaf4x1200.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato leaf/russian-2-319-dt-2010-leaves-high-tunnel-9-29-2014-c.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato leaf/russian-2-319-dt-2010-leaves-high-tunnel-9-29-2014-c.jpg new file mode 100644 index 000000000..ca68fc44c Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato leaf/russian-2-319-dt-2010-leaves-high-tunnel-9-29-2014-c.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato mold leaf/Leaf-mold3.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato mold leaf/Leaf-mold3.jpg new file mode 100644 index 000000000..5c7238de3 Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato mold leaf/Leaf-mold3.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato mold leaf/tomato_plants_1_original.JPG_1407178095.jpg b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato mold leaf/tomato_plants_1_original.JPG_1407178095.jpg new file mode 100644 index 000000000..d8215f30d Binary files /dev/null and b/airflow_bundle/leaf-pipeline/projects/leaf-counting/demo_images/10/Tomato mold leaf/tomato_plants_1_original.JPG_1407178095.jpg differ diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/requirements.orig.txt b/airflow_bundle/leaf-pipeline/projects/leaf-counting/requirements.orig.txt new file mode 100644 index 000000000..81958653a --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/requirements.orig.txt @@ -0,0 +1,5 @@ + +ultralytics>=8.1.0 +opencv-python-headless>=4.9.0.80 +numpy>=1.23.0 +minio>=7.1.15 diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/requirements.txt b/airflow_bundle/leaf-pipeline/projects/leaf-counting/requirements.txt new file mode 100644 index 000000000..0c27d4902 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/requirements.txt @@ -0,0 +1,9 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0+cpu +torchvision==0.21.0+cpu +torchaudio==2.6.0+cpu + +ultralytics>=8.1.0 +opencv-python-headless>=4.9.0.80 +numpy>=1.23.0 +minio>=7.1.15 diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/__init__.py b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/__init__.py new file mode 100644 index 000000000..ee49d4339 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/__init__.py @@ -0,0 +1,5 @@ +# decompyle3 version 3.9.3 +# Python bytecode version base 3.12.0 (3531) +# Decompiled from: Python 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] +# Embedded file name: /home/user/ml-workspace/projects/leaf-counting/src/__init__.py +# Compiled at: 2025-10-20 13:47:51 diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/common.py b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/common.py new file mode 100644 index 000000000..ea30f64ea --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/common.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from pathlib import Path +import cv2 +import numpy as np + +IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"} + +def is_image(path: Path) -> bool: + return path.suffix.lower() in IMG_EXTS + +def iter_images(inp: Path): + p = Path(inp) + if p.is_file() and is_image(p): + yield p + elif p.is_dir(): + for q in sorted(p.rglob("*")): + if q.is_file() and is_image(q): + yield q + +def ensure_dir(p: Path) -> Path: + Path(p).mkdir(parents=True, exist_ok=True) + return Path(p) + +def draw_boxes(img_bgr: np.ndarray, boxes, color=(0,255,0), thickness=2): + h, w = img_bgr.shape[:2] + out = img_bgr.copy() + for (x1,y1,x2,y2,conf,cls_id) in boxes: + x1 = max(0, min(w-1, int(x1))) + y1 = max(0, min(h-1, int(y1))) + x2 = max(0, min(w-1, int(x2))) + y2 = max(0, min(h-1, int(y2))) + cv2.rectangle(out, (x1,y1), (x2,y2), color, thickness) + label = f"{int(cls_id)}:{conf:.2f}" + cv2.putText(out, label, (x1, max(0, y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1, cv2.LINE_AA) + return out \ No newline at end of file diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/crop_only.py b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/crop_only.py new file mode 100644 index 000000000..c54066f4c --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/crop_only.py @@ -0,0 +1,143 @@ +from __future__ import annotations +import json, argparse +from pathlib import Path +from typing import Optional +import cv2 +from common import ensure_dir +from datetime import datetime + +try: + from minio_io import get_client, ensure_bucket, put_png +except Exception: + get_client = ensure_bucket = put_png = None + + +def _load_jsons(inp: Path): + jdir = inp / "json" + if not jdir.exists(): + raise SystemExit(f"[ERR] Expected JSON dir not found: {jdir} (run detect_only.py first)") + + for jp in sorted(jdir.rglob("*.json")): + with jp.open("r", encoding="utf-8") as f: + j = json.load(f) + yield jp, j + +def _safe_crop(img, x1, y1, x2, y2): + h, w = img.shape[:2] + x1 = max(0, min(w-1, int(x1))); y1 = max(0, min(h-1, int(y1))) + x2 = max(0, min(w-1, int(x2))); y2 = max(0, min(h-1, int(y2))) + if x2 <= x1: x2 = min(w-1, x1+1) + if y2 <= y1: y2 = min(h-1, y1+1) + return img[y1:y2, x1:x2] + + +def run_crop(inp: Path, out_dir: Path, size: int=224, margin: float=0.1, min_wh: int=8, + orig_dir: Optional[Path]=None, flat: bool=False, + minio_endpoint: Optional[str]=None, minio_access: Optional[str]=None, + minio_secret: Optional[str]=None, minio_bucket: Optional[str]=None, + minio_prefix: str="CROP", minio_secure: bool=False, + run_id: Optional[str]=None): + run_id = run_id or datetime.now().strftime("%Y/%m/%d/%H%M") + out_dir = ensure_dir(out_dir) + + cli = None + if minio_endpoint and minio_access and minio_secret and minio_bucket: + if get_client is None: + raise SystemExit("[ERR] חסר minio או minio_io.") + cli = get_client(minio_endpoint, minio_access, minio_secret, secure=minio_secure) + ensure_bucket(cli, minio_bucket) + + count = 0 + for jp, j in _load_jsons(inp): + + if "source_path" in j: + img_path = Path(j["source_path"]) + rel_path = j.get("rel_path", j["image"]) + elif "rel_path" in j: + if orig_dir is None: + raise SystemExit("[ERR] JSON מכיל רק rel_path; ספקי --orig כדי למצוא את קובץ המקור") + img_path = Path(orig_dir) / j["rel_path"] + rel_path = j["rel_path"] + else: + if orig_dir is None: + raise SystemExit("[ERR] JSON חסר source_path/rel_path; ספקי --orig ותתאימי לשמות image") + img_path = Path(orig_dir) / j["image"] + rel_path = j["image"] + + if not img_path.exists(): + print(f"[WARN] Original image not found: {img_path}, skipping") + continue + + img = cv2.imread(str(img_path)) + if img is None: + print(f"[WARN] Can't read image: {img_path}") + continue + + rel_parent = str(Path(rel_path).parent) + rel_stem = Path(rel_path).stem + + + if flat: + dest_dir = ensure_dir(out_dir) + minio_subprefix = minio_prefix + else: + dest_dir = ensure_dir(out_dir / rel_parent / rel_stem) + minio_subprefix = f"{minio_prefix}/{rel_parent}/{rel_stem}" if rel_parent != "." else f"{minio_prefix}/{rel_stem}" + + for i, (x1,y1,x2,y2,conf,cls_id) in enumerate(j.get("boxes", [])): + w = x2 - x1; h = y2 - y1 + if w < min_wh or h < min_wh: + continue + cx = (x1 + x2) * 0.5; cy = (y1 + y2) * 0.5 + half = max(w, h) * 0.5 * (1.0 + margin) + crop = _safe_crop(img, cx-half, cy-half, cx+half, cy+half) + if crop.size == 0: + continue + crop_resized = cv2.resize(crop, (size, size), interpolation=cv2.INTER_AREA) + out_name = f"det{i:03d}_cls{int(cls_id)}_{conf:.2f}.png" + cv2.imwrite(str(dest_dir / out_name), crop_resized) + count += 1 + + if cli: + base = f"{run_id}/{minio_prefix}" # תאריך/שעה קודם, אח"כ CROP + key = f"{base}/{rel_parent}/{rel_stem}/{out_name}" if rel_parent != "." else f"{base}/{rel_stem}/{out_name}" + put_png(cli, minio_bucket, key, crop_resized) + + put_png(cli, minio_bucket, f"{minio_subprefix}/{out_name}", crop_resized) + + print(f"[DONE] Saved {count} crops under: {out_dir} (flat={flat})") + +def main(): + ap = argparse.ArgumentParser(description="Create square crops from detection JSON results (+optional MinIO).") + ap.add_argument("--input", required=True) + ap.add_argument("--out", required=True) + ap.add_argument("--orig", default=None, help="דרוש רק אם JSON חסר source_path") + ap.add_argument("--size", type=int, default=224) + ap.add_argument("--margin", type=float, default=0.1) + ap.add_argument("--min-wh", type=int, default=8) + ap.add_argument("--flat", action="store_true") + + ap.add_argument("--minio-endpoint", default=None) + ap.add_argument("--minio-access", default=None) + ap.add_argument("--minio-secret", default=None) + ap.add_argument("--minio-bucket", default=None) + ap.add_argument("--minio-prefix", default="crops") + ap.add_argument("--minio-secure", action="store_true") + ap.add_argument("--run-id", default=None, help="תיקיית הריצה ב-MinIO (ברירת מחדל: YYYY/MM/DD/HHmm)") + + args = ap.parse_args() + run_id = args.run_id or datetime.now().strftime("%Y/%m/%d/%H%M") + run_crop( + inp=Path(args.input), out_dir=Path(args.out), + size=args.size, margin=args.margin, min_wh=args.min_wh, + orig_dir=Path(args.orig) if args.orig else None, flat=args.flat, + minio_endpoint=args.minio_endpoint, minio_access=args.minio_access, + minio_secret=args.minio_secret, minio_bucket=args.minio_bucket, + minio_prefix=args.minio_prefix, minio_secure=args.minio_secure, + run_id=run_id, + ) + + + +if __name__ == "__main__": + main() diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/detect_only.py b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/detect_only.py new file mode 100644 index 000000000..707c88a66 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/detect_only.py @@ -0,0 +1,140 @@ + +from __future__ import annotations +import json, argparse +from pathlib import Path +from typing import Optional +from datetime import datetime + +# --- HARD PATCH: עקיפת cpuinfo/ultralytics באופן גורף (מונע Popen('')) --- +try: + import cpuinfo as _ci + _ci.get_cpu_info = (lambda: {"brand_raw": "unknown"}) +except Exception: + pass +try: + import ultralytics.utils.torch_utils as _tu + _tu.get_cpu_info = (lambda: "unknown") +except Exception: + pass +# --- end hard patch --- + +import cv2 +from ultralytics import YOLO +from common import iter_images, ensure_dir, draw_boxes + +import cpuinfo +try: + print("cpu brand:", cpuinfo.get_cpu_info().get("brand_raw")) +except Exception as e: + print("cpuinfo error:", repr(e)) + +try: + from minio_io import get_client, ensure_bucket, put_png, put_json +except Exception: + get_client = ensure_bucket = put_png = put_json = None + +def run_detect(inp: Path, out_dir: Path, weights: Path, + conf: float=0.25, imgsz: int=896, device: str="cpu", + minio_endpoint: Optional[str]=None, minio_access: Optional[str]=None, + minio_secret: Optional[str]=None, minio_bucket: Optional[str]=None, + minio_prefix: str="DETECT", minio_secure: bool=False, + run_id: Optional[str]=None): + + run_id = run_id or datetime.now().strftime("%Y/%m/%d/%H%M") + out_dir = ensure_dir(out_dir) + overlay_root = ensure_dir(out_dir / "overlay") + json_root = ensure_dir(out_dir / "json") + + cli = None + if minio_endpoint and minio_access and minio_secret and minio_bucket: + if get_client is None: + raise SystemExit("[ERR] חסר minio או minio_io.") + cli = get_client(minio_endpoint, minio_access, minio_secret, secure=minio_secure) + ensure_bucket(cli, minio_bucket) + + model = YOLO(str(weights)) + + + img_paths = list(iter_images(inp)) + if not img_paths: + raise SystemExit(f"[ERR] No images found under: {inp}") + + is_dir_input = Path(inp).is_dir() + + for img_path in img_paths: + + rel_path = img_path.name if not is_dir_input else str(img_path.relative_to(inp)) + rel_parent = "." if not is_dir_input else str(img_path.relative_to(inp).parent) + rel_stem = Path(rel_path).stem + + overlay_dir = ensure_dir(overlay_root / rel_parent) + json_dir = ensure_dir(json_root / rel_parent) + + img_bgr = cv2.imread(str(img_path)) + if img_bgr is None: + print(f"[WARN] can't read image: {img_path}") + continue + h, w = img_bgr.shape[:2] + + res = model.predict(source=img_bgr, conf=conf, imgsz=imgsz, device=device, verbose=False)[0] + + boxes_pix = [] + if res.boxes is not None and len(res.boxes) > 0: + for b in res.boxes: + xyxy = b.xyxy.cpu().numpy().reshape(-1) + conf_i = float(b.conf.cpu().numpy().reshape(-1)[0]) + cls_i = float(b.cls.cpu().numpy().reshape(-1)[0]) if b.cls is not None else 0.0 + x1,y1,x2,y2 = map(float, xyxy.tolist()) + boxes_pix.append([x1,y1,x2,y2,conf_i,cls_i]) + + j = { + "image": img_path.name, + "rel_path": rel_path, + "source_path": str(img_path.resolve()), + "width": w, "height": h, + "boxes": boxes_pix + } + json_path = json_dir / f"{rel_stem}.json" + json_path.write_text(json.dumps(j, ensure_ascii=False, indent=2), encoding="utf-8") + + overlay = draw_boxes(img_bgr, boxes_pix) + ov_path = overlay_dir / img_path.name + cv2.imwrite(str(ov_path), overlay) + + if cli: + base = f"{run_id}/{minio_prefix}" + minio_json_key = f"{base}/json/{rel_parent}/{rel_stem}.json" if rel_parent != "." else f"{base}/json/{rel_stem}.json" + minio_ov_key = f"{base}/overlay/{rel_parent}/{img_path.name}" if rel_parent != "." else f"{base}/overlay/{img_path.name}" + put_json(cli, minio_bucket, minio_json_key, j) + put_png(cli, minio_bucket, minio_ov_key, overlay) + + print(f"[OK] {rel_path} -> {json_path.relative_to(out_dir)}, boxes={len(boxes_pix)}") + +def main(): + ap = argparse.ArgumentParser(description="YOLO detect -> pixel JSON + overlay (+optional MinIO)") + ap.add_argument("--input", required=True) + ap.add_argument("--out", required=True) + ap.add_argument("--weights", required=True) + ap.add_argument("--conf", type=float, default=0.25) + ap.add_argument("--imgsz", type=int, default=896) + ap.add_argument("--device", default="cpu") + + ap.add_argument("--minio-endpoint", default=None) + ap.add_argument("--minio-access", default=None) + ap.add_argument("--minio-secret", default=None) + ap.add_argument("--minio-bucket", default=None) + ap.add_argument("--minio-prefix", default="detect") + ap.add_argument("--minio-secure", action="store_true") + ap.add_argument("--run-id", default=None, help="תיקיית הריצה ב-MinIO (ברירת מחדל: YYYY/MM/DD/HHmm)") + + args = ap.parse_args() + run_id = args.run_id or datetime.now().strftime("%Y/%m/%d/%H%M") + run_detect(Path(args.input), Path(args.out), Path(args.weights), + conf=args.conf, imgsz=args.imgsz, device=args.device, + minio_endpoint=args.minio_endpoint, minio_access=args.minio_access, + minio_secret=args.minio_secret, minio_bucket=args.minio_bucket, + minio_prefix=args.minio_prefix, minio_secure=args.minio_secure, + run_id=run_id) + +if __name__ == "__main__": + main() diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/minio_io.py b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/minio_io.py new file mode 100644 index 000000000..cda3c246d --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/minio_io.py @@ -0,0 +1,39 @@ +from __future__ import annotations +import io, json, os +from pathlib import Path +import cv2 +from minio import Minio +from minio.error import S3Error + +def get_client(endpoint: str, access_key: str, secret_key: str, secure: bool=False) -> Minio: + """ + דוגמה: + cli = get_client("localhost:9000", "minioadmin", "minioadmin", secure=False) + """ + return Minio(endpoint, access_key=access_key, secret_key=secret_key, secure=secure) + +def ensure_bucket(cli: Minio, bucket: str): + found = cli.bucket_exists(bucket) + if not found: + cli.make_bucket(bucket) + +def put_png(cli: Minio, bucket: str, key: str, img_bgr): + """ + מעלה תמונת PNG מתוך np.ndarray (BGR של OpenCV). + """ + Path(key).parent and os.makedirs(Path(key).parent, exist_ok=True) # לא חובה, לשקט נפשי מקומי + ok, buf = cv2.imencode(".png", img_bgr) + if not ok: + raise RuntimeError("cv2.imencode PNG failed") + bio = io.BytesIO(buf.tobytes()) + bio.seek(0) + cli.put_object(bucket, key, bio, length=len(bio.getvalue()), content_type="image/png") + +def put_json(cli: Minio, bucket: str, key: str, obj): + """ + מעלה JSON (dict/list). + """ + js = json.dumps(obj, ensure_ascii=False, indent=2).encode("utf-8") + bio = io.BytesIO(js) + bio.seek(0) + cli.put_object(bucket, key, bio, length=len(js), content_type="application/json") diff --git a/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/predict_pyramid_wbf.py b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/predict_pyramid_wbf.py new file mode 100644 index 000000000..292829ba1 --- /dev/null +++ b/airflow_bundle/leaf-pipeline/projects/leaf-counting/src/predict_pyramid_wbf.py @@ -0,0 +1,233 @@ + +from __future__ import annotations +import argparse, json +from pathlib import Path +from typing import List, Tuple, Optional +from datetime import datetime + +# --- HARD PATCH: עקיפת cpuinfo/ultralytics (מונע Popen('')) --- +try: + import cpuinfo as _ci + _ci.get_cpu_info = (lambda: {"brand_raw": "unknown"}) +except Exception: + pass +try: + import ultralytics.utils.torch_utils as _tu + _tu.get_cpu_info = (lambda: "unknown") +except Exception: + pass +# --- end hard patch --- + +import cv2 +import numpy as np +from ultralytics import YOLO + +from common import iter_images, ensure_dir, draw_boxes + +try: + from minio_io import get_client, ensure_bucket, put_png, put_json +except Exception: + get_client = ensure_bucket = put_png = put_json = None + + +# ----------------- WBF utils ----------------- +def iou_xyxy(a: np.ndarray, b: np.ndarray) -> float: + ax1, ay1, ax2, ay2 = a + bx1, by1, bx2, by2 = b + ix1, iy1 = max(ax1, bx1), max(ay1, by1) + ix2, iy2 = min(ax2, bx2), min(ay2, by2) + iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1) + inter = iw * ih + area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1) + area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1) + union = area_a + area_b - inter + 1e-9 + return inter / union + + +def wbf(boxes: List[np.ndarray], scores: List[float], iou_thr: float = 0.55) -> tuple[list[np.ndarray], list[float]]: + """Very small WBF: קיבוץ לפי IoU>=thr, ממוצע משוקלל לפי conf.""" + used = [False] * len(boxes) + out_boxes, out_scores = [], [] + for i in range(len(boxes)): + if used[i]: + continue + group_idxs = [i] + used[i] = True + for j in range(i + 1, len(boxes)): + if used[j]: + continue + if iou_xyxy(boxes[i], boxes[j]) >= iou_thr: + group_idxs.append(j) + used[j] = True + bs = np.array([boxes[k] for k in group_idxs], dtype=float) + ws = np.array([scores[k] for k in group_idxs], dtype=float) + wsum = ws.sum() + 1e-9 + avg = (bs * ws[:, None]).sum(axis=0) / wsum + out_boxes.append(avg) + out_scores.append(float(ws.max())) + return out_boxes, out_scores + + +# ----------------- multi-scale predict ----------------- +def predict_at_scales(model: YOLO, img_bgr: np.ndarray, scales: List[float], conf: float, imgsz: int, device: str): + H, W = img_bgr.shape[:2] + all_boxes, all_scores, all_classes = [], [], [] + for s in scales: + if s == 1.0: + resized = img_bgr + rx, ry = 1.0, 1.0 + else: + newW, newH = int(W * s), int(H * s) + resized = cv2.resize(img_bgr, (newW, newH), interpolation=cv2.INTER_LINEAR) + rx, ry = 1.0 / s, 1.0 / s + + res = model.predict(source=resized, conf=conf, imgsz=imgsz, device=device, verbose=False)[0] + if res.boxes is None or len(res.boxes) == 0: + continue + for b in res.boxes: + xyxy = b.xyxy.cpu().numpy().reshape(-1) + conf_i = float(b.conf.cpu().numpy().reshape(-1)[0]) + cls_i = float(b.cls.cpu().numpy().reshape(-1)[0]) if b.cls is not None else 0.0 + x1, y1, x2, y2 = xyxy + # החזרה לקואורדינטות המקור + x1, y1, x2, y2 = x1 * rx, y1 * ry, x2 * rx, y2 * ry + all_boxes.append(np.array([x1, y1, x2, y2], dtype=float)) + all_scores.append(conf_i) + all_classes.append(int(cls_i)) + return all_boxes, all_scores, all_classes + + +# ----------------- main runner ----------------- +def run(inp: Path, out_dir: Path, weights: Path, + scales: List[float], conf: float = 0.25, iou_thr: float = 0.55, + imgsz: int = 896, device: str = "cpu", + minio_endpoint: Optional[str] = None, minio_access: Optional[str] = None, + minio_secret: Optional[str] = None, minio_bucket: Optional[str] = None, + minio_prefix: str = "PREDICT_PWB", minio_secure: bool = False, + run_id: Optional[str] = None): + + run_id = run_id or datetime.now().strftime("%Y/%m/%d/%H%M") + + out_dir = ensure_dir(out_dir) + overlay_root = ensure_dir(out_dir / "overlay") + json_root = ensure_dir(out_dir / "json") + + cli = None + if minio_endpoint and minio_access and minio_secret and minio_bucket: + if get_client is None: + raise SystemExit("[ERR] חסר minio או minio_io.") + cli = get_client(minio_endpoint, minio_access, minio_secret, secure=minio_secure) + ensure_bucket(cli, minio_bucket) + + model = YOLO(str(weights)) + images = list(iter_images(inp)) + if not images: + raise SystemExit(f"[ERR] No images under: {inp}") + + for p in images: + img = cv2.imread(str(p)) + if img is None: + print(f"[WARN] can't read: {p}") + continue + H, W = img.shape[:2] + + + rel_path = str(p.relative_to(inp)) if inp.is_dir() else p.name + rel_parent = str(Path(rel_path).parent) + rel_stem = Path(rel_path).stem + + boxes, scores, classes = predict_at_scales(model, img, scales, conf, imgsz, device) + + + merged = [] + for cls in sorted(set(classes)): + idxs = [i for i, c in enumerate(classes) if c == cls] + if not idxs: + continue + bcls = [boxes[i] for i in idxs] + scls = [scores[i] for i in idxs] + mbox, mscore = wbf(bcls, scls, iou_thr=iou_thr) + for bb, ss in zip(mbox, mscore): + x1, y1, x2, y2 = [float(max(0, v)) for v in bb] + x1, y1 = min(x1, W - 1), min(y1, H - 1) + x2, y2 = min(x2, W - 1), min(y2, H - 1) + merged.append([x1, y1, x2, y2, float(ss), float(cls)]) + + + overlay_dir = ensure_dir(overlay_root / rel_parent) + json_dir = ensure_dir(json_root / rel_parent) + + j = { + "image": p.name, + "rel_path": rel_path, + "source_path": str(p.resolve()), + "width": W, "height": H, + "boxes": merged + } + jpath = json_dir / f"{rel_stem}.json" + jpath.write_text(json.dumps(j, ensure_ascii=False, indent=2), encoding="utf-8") + + overlay = draw_boxes(img, merged) + cv2.imwrite(str(overlay_dir / p.name), overlay) + + + if cli: + base = f"{run_id}/{minio_prefix}" + json_key = f"{base}/json/{rel_parent}/{rel_stem}.json" if rel_parent != "." else f"{base}/json/{rel_stem}.json" + ov_key = f"{base}/overlay/{rel_parent}/{p.name}" if rel_parent != "." else f"{base}/overlay/{p.name}" + put_json(cli, minio_bucket, json_key, j) + put_png(cli, minio_bucket, ov_key, overlay) + + print(f"[OK] {rel_path} WBF boxes={len(merged)} -> {jpath.relative_to(out_dir)}") + + +def parse_scales(s: str) -> List[float]: + return [float(x) for x in s.split(",") if x.strip()] + + +def main(): + ap = argparse.ArgumentParser(description="YOLO multi-scale + WBF (+optional MinIO)") + ap.add_argument("--input", required=True) + ap.add_argument("--out", required=True) + ap.add_argument("--weights", required=True) + ap.add_argument("--scales", default="0.75,1.0,1.25", help="comma-separated, e.g. 0.5,1.0,1.5") + ap.add_argument("--conf", type=float, default=0.25) + ap.add_argument("--iou", type=float, default=0.55, help="WBF IoU threshold") + ap.add_argument("--imgsz", type=int, default=896) + ap.add_argument("--device", default="cpu") + + # MinIO + ap.add_argument("--minio-endpoint", default=None) + ap.add_argument("--minio-access", default=None) + ap.add_argument("--minio-secret", default=None) + ap.add_argument("--minio-bucket", default=None) + ap.add_argument("--minio-prefix", default="PREDICT_PWB") + ap.add_argument("--minio-secure", action="store_true") + + # Run grouping + ap.add_argument("--run-id", default=None, help="תיקיית הריצה ב-MinIO (ברירת מחדל: YYYY/MM/DD/HHmm)") + + args = ap.parse_args() + + run_id = args.run_id or datetime.now().strftime("%Y/%m/%d/%H%M") + run( + inp=Path(args.input), + out_dir=Path(args.out), + weights=Path(args.weights), + scales=parse_scales(args.scales), + conf=args.conf, + iou_thr=args.iou, + imgsz=args.imgsz, + device=args.device, + minio_endpoint=args.minio_endpoint, + minio_access=args.minio_access, + minio_secret=args.minio_secret, + minio_bucket=args.minio_bucket, + minio_prefix=args.minio_prefix, + minio_secure=args.minio_secure, + run_id=run_id, + ) + + +if __name__ == "__main__": + main() diff --git a/docker-compose.yml b/docker-compose.yml index ea690204e..ec5fb7df0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,17 +1,16 @@ - # ========================== # Docker Compose - AG Cloud # ========================== -version: "3.9" +# version: "3.9" # -------------------------- # Networks # -------------------------- networks: ag_cloud: + name: ag_cloud driver: bridge - # -------------------------- # Volumes # -------------------------- @@ -53,7 +52,7 @@ services: - wal_archive:/var/lib/postgresql/wal_archive - backups:/var/lib/postgresql/backups healthcheck: - test: ["CMD", "pg_isready", "-U", "missions_user", "-d", "missions_db"] + test: [ "CMD", "pg_isready", "-U", "missions_user", "-d", "missions_db" ] interval: 10s timeout: 5s retries: 5 @@ -76,7 +75,6 @@ services: - "9187:9187" networks: - ag_cloud - # ------------------------- # Sound Metrics Service # ------------------------- @@ -96,72 +94,72 @@ services: - MINIO_ENDPOINT=minio-hot:9000 - MINIO_ACCESS_KEY=minioadmin - MINIO_SECRET_KEY=minioadmin123 - - MINIO_BUCKET=telemetry - - MINIO_PREFIX=sounds/ + - MINIO_BUCKET=sound + - MINIO_PREFIX=sounds/ - command: ["python","-u","src/metrics.py"] + command: [ "python", "-u", "src/metrics.py" ] ports: - "8005:8005" depends_on: - minio-hot networks: - - ag_cloud + - ag_cloud restart: unless-stopped + # ------------------------- - # Plant Stress Detector + # Plant Stress Daily Batch # ------------------------- - plant_stress: + plant_stress_daily: build: ./services/plant_stress + env_file: + - ./services/plant_stress/.env + restart: "no" environment: - - INPUT_DIR=/data/inbox - - MODEL_DIR=/models - - POSTGRES_DSN=postgresql://missions_user:pg123@postgres:5432/missions_db - - PERIOD_DAYS=0 - - CONFIDENCE_THRESHOLD=0.6 - # - TF_ENABLE_ONEDNN_OPTS=0 # Disable oneDNN optimizations (for CPU compatibility) + MODEL_DIR: /models + CONFIDENCE_THRESHOLD: "0.60" + TF_CPP_MIN_LOG_LEVEL: "2" + TIMEZONE: Asia/Jerusalem + POSTGRES_DSN: postgresql://missions_user:pg123@postgres:5432/missions_db + MINIO_ENDPOINT: minio-hot:9000 + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin123 + MINIO_BUCKET: sound + MINIO_PREFIX: plants/ + MINIO_SECURE: "false" + + DEFAULT_AREA: unknown + DEFAULT_LAT: 0.0 + DEFAULT_LON: 0.0 + DEFAULT_IMAGE_URL: https://example.com/placeholder.jpg + DEFAULT_VOD: https://example.com/placeholder.mp4 + DEFAULT_HLS: https://example.com/placeholder.m3u8 + # ==== Alerts → Kafka ==== + ENABLE_ALERTS: "true" + KAFKA_BOOTSTRAP: "kafka:9092" + ALERT_TOPIC: "alerts" + ALERT_TYPE: "plant_drought_detected" + KAFKA_CLIENT_ID: "plant-stress-producer" + command: ["python","-u","/app/predict_minio_daily.py"] volumes: - "./services/plant_stress/models:/models:ro" - - "./services/plant_stress/samples:/data/inbox:ro" - depends_on: - - postgres - command: ["python", "-u", "/app/app.py"] - networks: - - ag_cloud - - - flink_writer_db: - build: - context: ./services/flink_writer_db - dockerfile: Dockerfile.flink - container_name: flink_writer_db - environment: - - KAFKA_BROKERS=kafka:9092 - - TOPICS=sensor_zone_stats,sensor_anomalies - - DB_API_BASE=http://db_api_service:8001 - - DB_API_AUTH_MODE=service - - DB_API_SERVICE_NAME=flink-writer-db - - DB_API_TOKEN_FILE=/opt/app/secrets/db_api_token - - FLINK_PARALLELISM=1 - depends_on: - kafka: - condition: service_healthy - db_api_service: - condition: service_started - networks: - - ag_cloud - restart: unless-stopped - - - + postgres: + condition: service_healthy + minio-hot: + condition: service_healthy + mc-bootstrap: + condition: service_started + kafka: + condition: service_healthy + networks: [ag_cloud] # ------------------------- - # MQTT + Kafka + Connect + Init + # MQTT + Kafka + MQTT-router # ------------------------- kafka: - build: + build: context: ./mqtt_and_kafka/kafka dockerfile: dockerfile container_name: kafka @@ -186,7 +184,7 @@ services: networks: - ag_cloud healthcheck: - test: ["CMD-SHELL", "/opt/bitnami/kafka/bin/kafka-topics.sh --bootstrap-server localhost:9092 --list >/dev/null 2>&1 || exit 1"] + test: [ "CMD-SHELL", "/opt/bitnami/kafka/bin/kafka-topics.sh --bootstrap-server localhost:9092 --list >/dev/null 2>&1 || exit 1" ] interval: 10s timeout: 5s retries: 20 @@ -194,7 +192,7 @@ services: mosquitto: image: eclipse-mosquitto:2.0 container_name: mosquitto - command: ["mosquitto", "-c", "/mqtt_and_kafka/mosquitto/config/mosquitto.conf"] + command: [ "mosquitto", "-c", "/mqtt_and_kafka/mosquitto/config/mosquitto.conf" ] ports: - "1883:1883" volumes: @@ -205,60 +203,37 @@ services: networks: - ag_cloud healthcheck: - test: ["CMD", "mosquitto_sub", "-h", "localhost", "-p", "1883", "-t", "$$SYS/#", "-C", "1", "-W", "15"] + test: [ "CMD", "mosquitto_sub", "-h", "localhost", "-p", "1883", "-t", "$$SYS/#", "-C", "1", "-W", "15" ] interval: 10s timeout: 5s retries: 12 - connect: - build: - context: ./mqtt_and_kafka - dockerfile: connect.Dockerfile - image: local/connect-with-mqtt:1.0.0 - container_name: connect - depends_on: - kafka: - condition: service_healthy - mosquitto: - condition: service_healthy - ports: - - "8083:8083" - environment: - - CONNECT_BOOTSTRAP_SERVERS=kafka:9092 - - CONNECT_GROUP_ID=agcloud-connect - - CONNECT_CONFIG_STORAGE_TOPIC=_connect_configs - - CONNECT_OFFSET_STORAGE_TOPIC=_connect_offsets - - CONNECT_STATUS_STORAGE_TOPIC=_connect_status - - CONNECT_CONFIG_STORAGE_REPLICATION_FACTOR=1 - - CONNECT_OFFSET_STORAGE_REPLICATION_FACTOR=1 - - CONNECT_STATUS_STORAGE_REPLICATION_FACTOR=1 - - CONNECT_KEY_CONVERTER=org.apache.kafka.connect.storage.StringConverter - - CONNECT_VALUE_CONVERTER=org.apache.kafka.connect.storage.StringConverter - - CONNECT_REST_ADVERTISED_HOST_NAME=localhost - - CONNECT_PLUGIN_PATH=/usr/share/java,/usr/share/confluent-hub-components - networks: - - ag_cloud - healthcheck: - test: ["CMD", "curl", "-sf", "http://localhost:8083/connectors"] - interval: 10s - timeout: 5s - retries: 12 + mqtt-router: + build: + context: ./mqtt_and_kafka/mqtt-router + image: local/mqtt-router:1.0.0 + depends_on: + kafka: + condition: service_healthy + mosquitto: + condition: service_healthy + environment: + - MQTT_HOST=mosquitto + - MQTT_PORT=1883 + - MQTT_TOPIC_FILTER=mqtt/# - init-connector: - image: curlimages/curl:8.7.1 - depends_on: - connect: - condition: service_healthy - volumes: - - ./mqtt_and_kafka/connectors:/connectors - networks: - - ag_cloud - entrypoint: > - sh -c " - echo '==> Creating MQTT connector...'; - curl -X POST -H 'Content-Type: application/json' --data @/connectors/mqtt-source.json http://connect:8083/connectors; - echo '==> Done.'; - " + - KAFKA_BOOTSTRAP=kafka:9092 + - CREATE_TOPICS=false + - DEFAULT_PARTITIONS=1 + - DEFAULT_REPLICATION=1 + networks: + - ag_cloud + restart: unless-stopped + healthcheck: + test: ["CMD", "python", "-c", "import socket; socket.create_connection(('mosquitto',1883),3); socket.create_connection(('kafka',9092),3)"] + interval: 15s + timeout: 5s + retries: 5 # -------------------------- # GUI / Runner / Gateway @@ -348,6 +323,15 @@ services: networks: - ag_cloud + pushgateway: + image: prom/pushgateway:v1.8.0 + container_name: pushgateway + ports: + - "9091:9091" + networks: + - ag_cloud + restart: unless-stopped + # -------------------------- # Desktop App # -------------------------- @@ -361,19 +345,24 @@ services: - DISPLAY=host.docker.internal:0.0 - GATEWAY_URL=http://sensors_metrics:8000 - NOTIFICATION_API_URL=http://notification_api:5000 + + - API_BASE_URL=http://db_api_service:8001 + - AUTH_BOOTSTRAP_URL=http://db_api_service:8001/auth/_dev_bootstrap + - ALERTS_WS_URL=ws://alerts-gateway:8000/ws/alerts ports: - "5900:5900" - "8080:8080" depends_on: - db_api_service - notification_api + - alerts-gateway volumes: - - ./GUI/src/vast:/app/src/vast + - ./GUI/src/vast:/app/src/vast + - ./templates:/app/templates:ro networks: - ag_cloud restart: unless-stopped - # -------------------------- # Large Mosquitto # -------------------------- @@ -399,12 +388,51 @@ services: MINIO_PROMETHEUS_AUTH_TYPE: public MINIO_ROOT_USER: minioadmin MINIO_ROOT_PASSWORD: minioadmin123 + + # ===== IMAGE NOTIFIERS ===== + MINIO_NOTIFY_KAFKA_ENABLE_aerial: "on" + MINIO_NOTIFY_KAFKA_BROKERS_aerial: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_aerial: "image.new.aerial" + + MINIO_NOTIFY_KAFKA_ENABLE_air: "on" + MINIO_NOTIFY_KAFKA_BROKERS_air: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_air: "image.new.air" + + MINIO_NOTIFY_KAFKA_ENABLE_fruits: "on" + MINIO_NOTIFY_KAFKA_BROKERS_fruits: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_fruits: "image.new.fruits" + + MINIO_NOTIFY_KAFKA_ENABLE_leaves: "on" + MINIO_NOTIFY_KAFKA_BROKERS_leaves: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_leaves: "image.new.leaves" + + MINIO_NOTIFY_KAFKA_ENABLE_ground: "on" + MINIO_NOTIFY_KAFKA_BROKERS_ground: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_ground: "image.new.ground" + + MINIO_NOTIFY_KAFKA_ENABLE_field: "on" + MINIO_NOTIFY_KAFKA_BROKERS_field: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_field: "image.new.field" + + # ===== SOUND NOTIFIERS ===== + MINIO_NOTIFY_KAFKA_ENABLE_plants: "on" + MINIO_NOTIFY_KAFKA_BROKERS_plants: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_plants: "sound.new.plants" + + MINIO_NOTIFY_KAFKA_ENABLE_sounds: "on" + MINIO_NOTIFY_KAFKA_BROKERS_sounds: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_sounds: "sound.new.sounds" + + # ===== SECURITY NOTIFIER ===== + MINIO_NOTIFY_KAFKA_ENABLE_security: "on" + MINIO_NOTIFY_KAFKA_BROKERS_security: "kafka:9092" + MINIO_NOTIFY_KAFKA_TOPIC_security: "image.new.security" ports: - - "9001:9000" # HOT S3 - - "9002:9001" # HOT Console - networks: [ag_cloud] + - "9001:9000" # HOT S3 + - "9002:9001" # HOT Console + networks: [ ag_cloud ] healthcheck: - test: ["CMD", "curl", "-fsS", "http://localhost:9000/minio/health/ready"] + test: [ "CMD", "curl", "-fsS", "http://localhost:9000/minio/health/ready" ] interval: 3s timeout: 2s retries: 40 @@ -420,11 +448,11 @@ services: MINIO_ROOT_USER: minioadmin MINIO_ROOT_PASSWORD: minioadmin123 ports: - - "9101:9000" # COLD S3 - - "9102:9001" # COLD Console - networks: [ag_cloud] + - "9101:9000" # COLD S3 + - "9102:9001" # COLD Console + networks: [ ag_cloud ] healthcheck: - test: ["CMD", "curl", "-fsS", "http://localhost:9000/minio/health/ready"] + test: [ "CMD", "curl", "-fsS", "http://localhost:9000/minio/health/ready" ] interval: 3s timeout: 2s retries: 40 @@ -434,6 +462,7 @@ services: mc-bootstrap: build: context: ./storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap + container_name: mc-bootstrap volumes: - ./storage_with_mqtt/storage/combined_minio_setup/config:/config:ro - ./storage_with_mqtt/data/config:/config @@ -442,7 +471,8 @@ services: condition: service_healthy minio-cold: condition: service_healthy - command: ["/bin/bash","-lc","/entrypoint/init.sh; tail -f /dev/null"] + kafka: + condition: service_healthy environment: MINIO_ROOT_USER: minioadmin MINIO_ROOT_PASSWORD: minioadmin123 @@ -451,8 +481,8 @@ services: MC_ALIAS_HOT: hot MC_ALIAS_COLD: cold BUCKET_IMAGERY: imagery - BUCKET_TELEMETRY: telemetry - networks: [ag_cloud] + BUCKET_SOUND: sound + networks: [ ag_cloud ] restart: unless-stopped # -------------------------- @@ -467,7 +497,7 @@ services: MINIO_ROOT_USER: minioadmin MINIO_ROOT_PASSWORD: minioadmin123 BUCKET_IMAGERY: imagery - BUCKET_TELEMETRY: telemetry + BUCKET_SOUND: sound MQTT_BROKER: large-mosquitto MQTT_PORT: 1885 MQTT_TOPIC: MQTT/imagery/# @@ -494,6 +524,83 @@ services: - ag_cloud restart: unless-stopped + mqtt_ingest_sound: + build: + context: ./storage_with_mqtt/mqtt_images/mqtt_ingest + container_name: mqtt_ingest_sound + environment: + MINIO_ENDPOINT: http://minio-hot:9000 + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin123 + S3_BUCKET: sound + MQTT_BROKER: large-mosquitto + MQTT_PORT: 1885 + MQTT_TOPIC: MQTT/sounds/# + MQTT_PUB_TOPIC: sound/sounds/ingested + DEFAULT_PREFIX: MIC-01 + CAMERA_PREFIX: camera + MICROPHONE_PREFIX: microphone + DUMMY_DB: 0 + DB_API_BASE: http://db_api_service:8001 + DB_API_TOKEN: auto + OUTBOX_DIR: /app/outbox + DB_API_AUTH_MODE: service + DB_API_SERVICE_NAME: mqtt_ingest_sound + INGEST_WORKERS: 8 + volumes: + - ./storage_with_mqtt/mqtt_images/outbox:/app/outbox + depends_on: + large-mosquitto: + condition: service_started + minio-hot: + condition: service_healthy + mc-bootstrap: + condition: service_started + db_api_service: + condition: service_started + networks: + - ag_cloud + restart: unless-stopped + + mqtt_ingest_sounds_ultra: + build: + context: ./storage_with_mqtt/mqtt_images/mqtt_ingest + container_name: mqtt_ingest_sounds_ultra + environment: + MINIO_ENDPOINT: http://minio-hot:9000 + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin123 + S3_BUCKET: sound + MQTT_BROKER: large-mosquitto + MQTT_PORT: 1885 + MQTT_TOPIC: MQTT/sounds_ultra/# + MQTT_PUB_TOPIC: sound/sounds_ultra/ingested + DEFAULT_PREFIX: MIC-02 + CAMERA_PREFIX: camera + MICROPHONE_PREFIX: microphone + DUMMY_DB: 0 + DB_API_BASE: http://db_api_service:8001 + DB_API_TOKEN: auto + OUTBOX_DIR: /app/outbox + DB_API_AUTH_MODE: service + DB_API_SERVICE_NAME: mqtt_ingest_sounds_ultra + INGEST_WORKERS: 8 + ULTRA_DIR_PREFIX: plants + volumes: + - ./storage_with_mqtt/mqtt_images/outbox:/app/outbox + depends_on: + large-mosquitto: + condition: service_started + minio-hot: + condition: service_healthy + mc-bootstrap: + condition: service_started + db_api_service: + condition: service_started + networks: + - ag_cloud + restart: unless-stopped + mqtt_publisher: build: context: ./storage_with_mqtt/mqtt_images/mqtt_publisher @@ -521,14 +628,10 @@ services: # ------------------------ sounds_classifier: build: - context: ./services/sounds/sounds_classifier + context: ./services/sounds_classifier dockerfile: Dockerfile.classifier-svc - # args: - # CHECKPOINT_URL: "CHECKPOINT=/app/classification/models/panns_data/Cnn14_mAP=0.431.pth" container_name: sounds_classifier restart: unless-stopped - # env_file: - # - ./services/sounds/sounds_classifier/src/classification/.env environment: # Runtime mode - DEVICE=cpu @@ -547,7 +650,8 @@ services: # Kafka - KAFKA_BROKERS=kafka:9092 - - ALERTS_TOPIC=dev-robot-alerts + - ALERTS_TOPIC=alerts + - ENABLE_ALERTS=true # MinIO - MINIO_ENDPOINT=minio-hot:9000 @@ -556,7 +660,7 @@ services: - MINIO_SECURE=false # Request validation - - ALLOWED_BUCKETS=imagery + - ALLOWED_BUCKETS=sound - ALLOWED_CONTENT_TYPES=audio/wav,audio/x-wav,audio/mpeg,audio/flac,audio/ogg,audio/mp4 - MAX_BYTES=104857600 @@ -578,7 +682,7 @@ services: networks: - ag_cloud healthcheck: - test: ["CMD", "curl", "-fsS", "http://localhost:8088/health"] + test: [ "CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8088/health').read()" ] interval: 45s timeout: 5s retries: 10 @@ -588,8 +692,6 @@ services: # DB API Service # -------------------------- - - contracts-gen: build: context: ./services/db_api_service @@ -607,7 +709,6 @@ services: - ag_cloud restart: "no" - db_api_service: build: context: ./services/db_api_service @@ -616,16 +717,18 @@ services: env_file: - ./services/db_api_service/.env environment: - DB_DSN: postgresql+psycopg://missions_user:pg123@host.docker.internal:5432/missions_db + DB_DSN: postgresql+psycopg://missions_user:pg123@postgres:5432/missions_db ENV: dev JWT_SECRET: change-me-please-very-secret JWT_ALGO: HS256 ACCESS_TTL_MIN: 15 REFRESH_TTL_DAYS: 14 DEV_SA_NAME: my-ingest-service + ADDR: 0.0.0.0 ports: - "8001:8001" volumes: + - ./services/db_api_service/app:/app/app - contracts:/app/app/contracts depends_on: contracts-gen: @@ -636,10 +739,9 @@ services: - ag_cloud restart: unless-stopped - notification_api: build: - context: ./services/sounds/API-development/src + context: ./services/API-notifications/src dockerfile: Dockerfile container_name: notification_api environment: @@ -649,7 +751,37 @@ services: depends_on: - postgres - + ripeness-api: + build: + context: ./services/ripeness-ml + dockerfile: deploy/Dockerfile + image: ripeness-api:latest + environment: + - PGHOST=postgres + - PGPORT=5432 + - PGDATABASE=missions_db + - PGUSER=missions_user + - PGPASSWORD=pg123 + - MINIO_ENDPOINT=minio-hot:9000 + - MINIO_SECURE=false + - MINIO_ACCESS_KEY=minioadmin + - MINIO_SECRET_KEY=minioadmin123 + - MODEL_NAME=best_conditional + - BATCH_LIMIT=500 + - FRUITS=Apple,Banana,Orange + depends_on: + - postgres + - minio-hot + volumes: + - ./services/ripeness-ml/checkpoints:/app/checkpoints + - ./services/ripeness-ml/configs:/app/configs + - ./services/ripeness-ml/model:/app/model + container_name: ripeness-api + networks: [ ag_cloud ] + ports: + - "8091:8088" + restart: unless-stopped + # -------------------------- # Flink JobManager & TaskManager # -------------------------- @@ -657,12 +789,12 @@ services: build: context: ./streaming/flink dockerfile: Dockerfile.flink-py - image: agcloud-flink-py:1.18 + image: agcloud-flink-py:1.18 container_name: flink-jobmanager command: jobmanager ports: - "8081:8081" - networks: [ag_cloud] + networks: [ ag_cloud ] environment: - | FLINK_PROPERTIES= @@ -678,7 +810,7 @@ services: fs.s3a.connection.ssl.enabled: false python.client.executable: /usr/bin/python3 python.executable: /usr/bin/python3 - - HTTP_INFER_URL=http://fruit-inference-http:8000/infer_json + - HTTP_INFER_URL=http://fruit-inference-http:8004/infer_json volumes: - ./streaming/flink/jobs:/opt/flink/jobs:ro - ./streaming/flink/connectors/flink-json-1.18.1.jar:/opt/flink/lib/flink-json-1.18.1.jar:ro @@ -689,6 +821,50 @@ services: - ./streaming/flink/connectors/snappy-java-1.1.10.5.jar:/opt/flink/lib/snappy-java-1.1.10.5.jar:ro restart: unless-stopped + audio_compression: + build: + context: ./services/compression + dockerfile: Dockerfile + container_name: audio_compression + environment: + - RAW_MAX_AGE_DAYS=30 + - COMPRESSION_CODEC=opus + - COMPRESSED_MAX_AGE_DAYS=90 + - CHECK_INTERVAL_SECONDS=3600 + - MINIO_ENDPOINT=minio-hot:9000 + - ACCESS_KEY=minioadmin + - SECRET_KEY=minioadmin123 + - BUCKET_NAME=imagery + depends_on: + minio-hot: + condition: service_healthy + mc-bootstrap: + condition: service_started + networks: + - ag_cloud + restart: unless-stopped + + flink_writer_db: + build: + context: ./services/flink_writer_db + dockerfile: Dockerfile.flink + container_name: flink_writer_db + environment: + - KAFKA_BROKERS=kafka:9092 + - TOPICS=sensor_zone_stats,sensor_anomalies,image_new_security_connections,alerts,image_new_aerial_connections,aerial_images_metadata,aerial_image_object_detections,aerial_image_anomaly_detections,aerial_images_complete_metadata,field_polygons,aerial_image_segmentation,sound_new_sounds_connections,sound_new_plants_connections + - DB_API_BASE=http://db_api_service:8001 + - DB_API_AUTH_MODE=service + - DB_API_SERVICE_NAME=flink-writer-db + - DB_API_TOKEN_FILE=/opt/app/secrets/db_api_token + - FLINK_PARALLELISM=1 + depends_on: + kafka: + condition: service_healthy + db_api_service: + condition: service_started + networks: + - ag_cloud + restart: unless-stopped flink-taskmanager: image: agcloud-flink-py:1.18 @@ -697,7 +873,7 @@ services: depends_on: flink-jobmanager: condition: service_started - networks: [ag_cloud] + networks: [ ag_cloud ] environment: - | FLINK_PROPERTIES= @@ -713,7 +889,7 @@ services: fs.s3a.connection.ssl.enabled: false python.client.executable: /usr/bin/python3 python.executable: /usr/bin/python3 - - HTTP_INFER_URL=http://fruit-inference-http:8000/infer_json + - HTTP_INFER_URL=http://fruit-inference-http:8004/infer_json volumes: - ./streaming/flink/connectors/flink-json-1.18.1.jar:/opt/flink/lib/flink-json-1.18.1.jar:ro - ./streaming/flink/connectors/flink-sql-connector-kafka-3.2.0-1.18.jar:/opt/flink/lib/flink-sql-connector-kafka-3.2.0-1.18.jar:ro @@ -723,31 +899,78 @@ services: - ./streaming/flink/connectors/snappy-java-1.1.10.5.jar:/opt/flink/lib/snappy-java-1.1.10.5.jar:ro restart: unless-stopped - - # -------------------------- # Inference HTTP Service # -------------------------- fruit-inference-http: build: - context: ./services/inference_http + context: ./services/inference_http dockerfile: Dockerfile environment: - - TEAM=fruit - - WEIGHTS_PATH=/app/weights/fruit_cls_best.ts - - MINIO_ENDPOINT=minio-hot:9000 - - MINIO_ACCESS_KEY=minioadmin - - MINIO_SECRET_KEY=minioadmin123 - - MINIO_SECURE=0 + - TEAM=fruit + - WEIGHTS_PATH=/app/weights/fruit_cls_best.ts + - MINIO_ENDPOINT=minio-hot:9000 + - MINIO_ACCESS_KEY=minioadmin + - MINIO_SECRET_KEY=minioadmin123 + - MINIO_SECURE=0 volumes: - ./services/inference_http/weights:/app/weights:ro container_name: fruit-inference-http + networks: [ ag_cloud ] + ports: + - "8011:8004" + restart: unless-stopped + + camera-inference-http: + build: + context: ./services/inference_http + dockerfile: Dockerfile + environment: + - TEAM=camera + - WEIGHTS_PATH=/app/weights/yolov8-fruits.pt + - MINIO_ENDPOINT=minio-hot:9000 + - MINIO_ACCESS_KEY=minioadmin + - MINIO_SECRET_KEY=minioadmin123 + - MINIO_SECURE=0 + volumes: + - ./services/inference_http/weights:/app/weights:ro + container_name: camera-inference-http networks: [ag_cloud] ports: - - "8011:8000" + - "8012:8004" + restart: unless-stopped + + + soil-inference-http: + build: + context: ./services/inference_http + dockerfile: Dockerfile + environment: + + - TEAM=soil_moisture + - WEIGHTS_PATH=/app/weights/soil_moisture_best.onnx + - MINIO_ENDPOINT=minio-hot:9000 + - MINIO_ACCESS_KEY=minioadmin + - MINIO_SECRET_KEY=minioadmin123 + - MINIO_SECURE=0 + - PG_DSN=postgresql://missions_user:pg123@postgres:5432/missions_db + - KAFKA_BROKERS=kafka:9092 + - KAFKA_TOPIC=irrigation.control + - KAFKA_DLT=irrigation.control.dlq + + + volumes: + - ./services/inference_http/weights:/app/weights:ro + - ./services/inference_http/adapters:/app/adapters + - ./services/inference_http/soil_moisture:/app/soil_moisture + depends_on: + - minio-hot + - postgres + ports: + - "8013:8004" + networks: [ag_cloud] restart: unless-stopped - # -------------------------- # Flink Jobs # -------------------------- @@ -758,12 +981,12 @@ services: flink-jobmanager: { condition: service_started } flink-taskmanager: { condition: service_started } fruit-inference-http: { condition: service_started } - networks: [ag_cloud] + networks: [ ag_cloud ] environment: - KAFKA_BOOTSTRAP=kafka:9092 - INPUT_TOPIC=imagery.new.fruit - TEAM=fruit - - HTTP_URL=http://fruit-inference-http:8000/infer_json + - HTTP_URL=http://fruit-inference-http:8004/infer_json - DLQ_TOPIC=dlq.inference.http - GROUP_ID=http-dispatcher-fruit - PARALLELISM=2 @@ -776,27 +999,238 @@ services: - ./streaming/flink/connectors/kafka-clients-3.2.3.jar:/opt/flink/lib/kafka-clients-3.2.3.jar:ro - ./streaming/flink/connectors/lz4-java-1.8.0.jar:/opt/flink/lib/lz4-java-1.8.0.jar:ro - ./streaming/flink/connectors/snappy-java-1.1.10.5.jar:/opt/flink/lib/snappy-java-1.1.10.5.jar:ro - command: [ - "bash","-lc", - "set -e; - echo 'Waiting for JobManager to accept commands...'; - until /opt/flink/bin/flink list --jobmanager flink-jobmanager:8081 >/dev/null 2>&1; do - echo 'still waiting...'; sleep 3; - done; - echo 'JobManager is ready!'; - /opt/flink/bin/flink run \ - -Dpython.client.executable=/usr/bin/python3 \ - -Dpython.executable=/usr/bin/python3 \ - -Dpipeline.jars=file:///opt/flink/lib/flink-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-sql-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-json-1.18.1.jar \ - --jobmanager flink-jobmanager:8081 \ - --detached \ - --python /opt/flink/jobs/http_dispatcher.py \ - -- \ - --bootstrap kafka:9092 \ - --input-topic imagery.new.fruit \ - --team fruit \ - --http-url http://fruit-inference-http:8000/infer_json \ - --group-id http-dispatcher-fruit \ - --dlq-topic dlq.inference.http; - tail -f /dev/null" - ] + command: [ "bash", "-lc", "set -e; echo 'Waiting for JobManager to accept commands...'; until /opt/flink/bin/flink list --jobmanager flink-jobmanager:8081 >/dev/null 2>&1; do echo 'still waiting...'; sleep 3; done; echo 'JobManager is ready!'; /opt/flink/bin/flink run -Dpython.client.executable=/usr/bin/python3 -Dpython.executable=/usr/bin/python3 -Dpipeline.jars=file:///opt/flink/lib/flink-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-sql-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-json-1.18.1.jar --jobmanager flink-jobmanager:8081 --detached --python /opt/flink/jobs/http_dispatcher.py -- --bootstrap kafka:9092 --input-topic imagery.new.fruit --team fruit --http-url http://fruit-inference-http:8004/infer_json --group-id http-dispatcher-fruit --dlq-topic dlq.inference.http; tail -f /dev/null" ] + restart: always + + flink-dispatcher-camera: + image: agcloud-flink-py:1.18 + container_name: flink-dispatcher-camera + depends_on: + flink-jobmanager: { condition: service_started } + flink-taskmanager: { condition: service_started } + camera-inference-http: { condition: service_started } + networks: [ag_cloud] + environment: + - KAFKA_BOOTSTRAP=kafka:9092 + - INPUT_TOPIC=imagery.new.camera + - TEAM=camera + - HTTP_URL=http://camera-inference-http:8004/infer_json + - DLQ_TOPIC=dlq.inference.http + - GROUP_ID=http-dispatcher-camera + - PARALLELISM=2 + - PYFLINK_CLIENT_EXECUTABLE=/usr/bin/python3 + volumes: + - ./streaming/flink/jobs:/opt/flink/jobs:ro + - ./streaming/flink/connectors:/opt/flink/lib/connectors:ro + command: [ "bash", "-lc", "set -e; echo 'Waiting for JobManager to accept commands...'; until /opt/flink/bin/flink list --jobmanager flink-jobmanager:8081 >/dev/null 2>&1; do echo 'still waiting...'; sleep 3; done; echo 'JobManager is ready!'; /opt/flink/bin/flink run -Dpython.client.executable=/usr/bin/python3 -Dpython.executable=/usr/bin/python3 -Dpipeline.jars=file:///opt/flink/lib/connectors/flink-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/connectors/flink-sql-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/connectors/flink-json-1.18.1.jar --jobmanager flink-jobmanager:8081 --detached --python /opt/flink/jobs/http_dispatcher.py -- --bootstrap kafka:9092 --input-topic imagery.new.camera --team camera --http-url http://camera-inference-http:8004/infer_json --group-id http-dispatcher-camera --dlq-topic dlq.inference.http; tail -f /dev/null" ] + restart: always + + + flink-dispatcher-soil: + image: agcloud-flink-py:1.18 + depends_on: + flink-jobmanager: { condition: service_started } + flink-taskmanager: { condition: service_started } + soil-inference-http: { condition: service_started } + networks: [ag_cloud] + environment: + - KAFKA_BOOTSTRAP=kafka:9092 + - INPUT_TOPIC=image.new.ground + - TEAM=soil_moisture + - HTTP_URL=http://soil-inference-http:8004/infer_json + - DLQ_TOPIC=dlq.inference.http + - GROUP_ID=http-dispatcher-soil + - PARALLELISM=1 + - PYFLINK_CLIENT_EXECUTABLE=/usr/bin/python3 + volumes: + - ./streaming/flink/jobs:/opt/flink/jobs:ro + - ./streaming/flink/connectors:/opt/flink/lib/connectors:ro + command: [ "bash", "-lc", "set -e; echo 'Waiting...'; until /opt/flink/bin/flink list --jobmanager flink-jobmanager:8081 >/dev/null 2>&1; do echo 'still waiting...'; sleep 3; done; echo 'JobManager is ready!'; /opt/flink/bin/flink run -Dpython.client.executable=/usr/bin/python3 -Dpython.executable=/usr/bin/python3 -Dpipeline.jars=file:///opt/flink/lib/connectors/... --jobmanager flink-jobmanager:8081 --detached --python /opt/flink/jobs/http_dispatcher.py -- --bootstrap kafka:9092 --input-topic image.new.ground --team soil_moisture --http-url http://soil-inference-http:8004/infer_json --group-id http-dispatcher-soil --dlq-topic dlq.inference.http; tail -f /dev/null" ] + + + flink-alerts-job: + build: + context: ./services/alerts_forwarder + dockerfile: Dockerfile.flink + container_name: alerts-forwarder + depends_on: + kafka: + condition: service_healthy + alertmanager_service: + condition: service_started + environment: + - PYTHONPATH=/opt/app + - KAFKA_BROKERS=kafka:9092 + - ALERTMANAGER_SERVICE_URL=http://alertmanager_service:8090/alerts + command: [ "python", "/opt/app/alerts_forwarder.py" ] + networks: + - ag_cloud + restart: unless-stopped + + alertmanager: + image: prom/alertmanager:v0.27.0 + container_name: alertmanager + command: + - "--config.file=/etc/alertmanager/alertmanager.yml" + - "--storage.path=/alertmanager" + - "--log.level=debug" + volumes: + - ./services/alertmanager_service/compose/alertmanager.yml:/etc/alertmanager/alertmanager.yml:ro + ports: + - "9093:9093" + networks: + - ag_cloud + restart: always + + alertmanager_service: + build: + context: ./services/alertmanager_service/src + dockerfile: Dockerfile + container_name: alertmanager_service + ports: + - "8090:8090" + command: [ "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8090" ] + volumes: + - ./templates:/app/templates:ro + environment: + - CFG_PATH=/app/templates/templates.yml + - ALERTMANAGER_URL=http://alertmanager:9093 + - GATEWAY_URL=http://alerts-gateway:8000/internal/alert + depends_on: + - alertmanager + - alerts-gateway + networks: + - ag_cloud + + alerts-gateway: + build: + context: ./services/alertmanager_service/src + dockerfile: Dockerfile + container_name: alerts_gateway + command: [ "uvicorn", "gateway:app", "--host", "0.0.0.0", "--port", "8000" ] + ports: + - "8010:8000" + networks: + - ag_cloud + + image-linker-jobmanager: + build: + context: ./services/image-linker + dockerfile: Dockerfile.flink + container_name: image-linker-jobmanager + command: jobmanager + ports: + - "8084:8081" + environment: + - JOB_MANAGER_RPC_ADDRESS=image-linker-jobmanager + - KAFKA_BROKERS=kafka:9092 + - CONFIG_PATH=/opt/app/config/topics.yaml + networks: + - ag_cloud + + image-linker-taskmanager: + build: + context: ./services/image-linker + dockerfile: Dockerfile.flink + container_name: image-linker-taskmanager + command: taskmanager + environment: + - JOB_MANAGER_RPC_ADDRESS=image-linker-jobmanager + - KAFKA_BROKERS=kafka:9092 + - CONFIG_PATH=/opt/app/config/topics.yaml + depends_on: + image-linker-jobmanager: + condition: service_started + networks: + - ag_cloud + + image-linker-submitter: + build: + context: ./services/image-linker + dockerfile: Dockerfile.flink + container_name: image-linker-submit + depends_on: + image-linker-jobmanager: + condition: service_started + command: > + bash -lc "sleep 10 && + flink run -m image-linker-jobmanager:8081 -py /opt/app/job_linker.py && + echo 'Image-Linker job submitted successfully' && + sleep 1" + networks: + - ag_cloud + + flink-sounds-http-jobmanager: + build: + context: ./services/sounds_flink + dockerfile: Dockerfile + container_name: flink-sounds-http-jobmanager + command: jobmanager + ports: + - "8083:8081" + environment: + JOB_MANAGER_RPC_ADDRESS: flink-sounds-http-jobmanager + KAFKA_BROKERS: kafka:9092 + SOURCE_TOPIC: sound_new_sounds_connections + SINK_TOPIC: "" + GROUP_ID: flink-classifier-sounds + CLASSIFIER_HTTP_URL: http://sounds_classifier:8088/classify + DEFAULT_PARALLELISM: 2 + KAFKA_START: earliest + PYTHON: /opt/venv/bin/python + FLINK_PYTHON: /opt/venv/bin/python + networks: + - ag_cloud + + flink-sounds-http-taskmanager: + build: + context: ./services/sounds_flink + dockerfile: Dockerfile + container_name: flink-sounds-http-taskmanager + command: taskmanager + depends_on: + flink-sounds-http-jobmanager: + condition: service_started + environment: + JOB_MANAGER_RPC_ADDRESS: flink-sounds-http-jobmanager + PYTHON: /opt/venv/bin/python + FLINK_PYTHON: /opt/venv/bin/python + FLINK_PROPERTIES: |- + jobmanager.rpc.address: flink-sounds-http-jobmanager + taskmanager.numberOfTaskSlots: 2 + networks: + - ag_cloud + + flink-sounds-http-submit: + build: + context: ./services/sounds_flink + dockerfile: Dockerfile + container_name: flink-sounds-http-submit + depends_on: + flink-sounds-http-jobmanager: + condition: service_started + flink-sounds-http-taskmanager: + condition: service_started + command: + - /opt/flink/bin/flink + - run + - -d + - -m + - flink-sounds-http-jobmanager:8081 + - -Dpython.client.executable=/opt/venv/bin/python + - -Dpython.executable=/opt/venv/bin/python + - -py + - /opt/app/flink_job.py + environment: + JOB_MANAGER_RPC_ADDRESS: flink-sounds-http-jobmanager + KAFKA_BROKERS: kafka:9092 + SOURCE_TOPIC: sound_new_sounds_connections + SINK_TOPIC: "" + GROUP_ID: flink-classifier-sounds + CLASSIFIER_HTTP_URL: http://sounds_classifier:8088/classify + DEFAULT_PARALLELISM: 2 + KAFKA_START: earliest + PYTHON: /opt/venv/bin/python + FLINK_PYTHON: /opt/venv/bin/python + networks: + - ag_cloud diff --git a/grafana/dashboards/sound_dashboard.json b/grafana/dashboards/sound_dashboard.json new file mode 100644 index 000000000..c3aaa958d --- /dev/null +++ b/grafana/dashboards/sound_dashboard.json @@ -0,0 +1,107 @@ +{ + "id": null, + "uid": "sound-dashboard", + "title": "Sound 1 – Grafana Dashboards", + "tags": [ + "sound", + "metrics", + "demo" + ], + "timezone": "browser", + "schemaVersion": 38, + "version": 1, + "refresh": "5s", + "panels": [ + { + "type": "timeseries", + "title": "🎚 Volume Trends (dB)", + "targets": [ + { + "expr": "sound_volume_db", + "legendFormat": "Mic Volume" + } + ], + "gridPos": { + "x": 0, + "y": 0, + "w": 12, + "h": 7 + } + }, + { + "type": "gauge", + "title": "🎯 Classifier Accuracy Rate", + "targets": [ + { + "expr": "classifier_rate" + } + ], + "fieldConfig": { + "defaults": { + "min": 0, + "max": 1, + "unit": "percentunit", + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "red", + "value": null + }, + { + "color": "yellow", + "value": 0.7 + }, + { + "color": "green", + "value": 0.9 + } + ] + } + } + }, + "gridPos": { + "x": 12, + "y": 0, + "w": 12, + "h": 7 + } + }, + { + "type": "stat", + "title": "⏱ Microphone Uptime (seconds)", + "targets": [ + { + "expr": "mic_uptime_seconds" + } + ], + "gridPos": { + "x": 0, + "y": 7, + "w": 12, + "h": 6 + }, + "fieldConfig": { + "defaults": { + "unit": "s" + } + } + }, + { + "type": "timeseries", + "title": "⚠️ Anomaly Counters", + "targets": [ + { + "expr": "anomaly_count", + "legendFormat": "Anomalies" + } + ], + "gridPos": { + "x": 12, + "y": 7, + "w": 12, + "h": 6 + } + } + ] +} \ No newline at end of file diff --git a/grafana/sensors-to-pushgateway.ps1 b/grafana/sensors-to-pushgateway.ps1 new file mode 100644 index 000000000..2a87b1623 --- /dev/null +++ b/grafana/sensors-to-pushgateway.ps1 @@ -0,0 +1,36 @@ +# =============================================== +# Push local sensor JSON metrics to Pushgateway +# =============================================== + +# URL of Prometheus Pushgateway +$PushUrl = "http://pushgateway:9091/metrics/job/local_sensors" + +# Use relative path inside the repo or container (cross-platform) +# Example: if script is under grafana/, look for ./local_sensors/ +$BaseDir = Split-Path -Parent $MyInvocation.MyCommand.Definition +$SensorDir = Join-Path $BaseDir "local_sensors" + +Write-Host "Monitoring folder: $SensorDir" +Write-Host "Pushing metrics to: $PushUrl" + +while ($true) { + Get-ChildItem -Path $SensorDir -Filter "*.json" | ForEach-Object { + try { + $data = Get-Content $_.FullName | ConvertFrom-Json + $mic = $data.mic_id + $body = @" +sound_volume_db{mic_id="$mic"} $($data.volume_db) +classifier_rate{mic_id="$mic"} $($data.classifier_rate) +mic_uptime_seconds{mic_id="$mic"} $($data.uptime_sec) +anomaly_count{mic_id="$mic"} $($data.anomaly_count) +"@ + Invoke-RestMethod -Uri "$PushUrl/instance/$mic" -Method Put -Body ($body + "`n") -ContentType "text/plain" + Write-Host "✅ Pushed metrics for $mic" + } + catch { + Write-Warning "⚠️ Failed for file $($_.Name): $_" + } + } + + Start-Sleep -Seconds 5 +} diff --git a/grafana/simulate-sound-metrics.ps1 b/grafana/simulate-sound-metrics.ps1 new file mode 100644 index 000000000..769a954bf --- /dev/null +++ b/grafana/simulate-sound-metrics.ps1 @@ -0,0 +1,33 @@ +$job = "sound_dashboard" +$instance = "mic-001" + +$volume = 40 +$rate = 0.8 +$uptime = 0 +$anomalies = 0 + +while ($true) { + $volume += Get-Random -Minimum -3 -Maximum 3 + $rate += (Get-Random -Minimum -0.02 -Maximum 0.02) + $uptime += 5 + if ((Get-Random -Minimum 0 -Maximum 10) -gt 8) { $anomalies++ } + + if ($volume -lt 20) { $volume = 20 } + elseif ($volume -gt 90) { $volume = 90 } + + if ($rate -lt 0.5) { $rate = 0.5 } + elseif ($rate -gt 1.0) { $rate = 1.0 } + + $body = @" +sound_volume_db $volume +classifier_rate $rate +mic_uptime_seconds $uptime +app_anomaly_total $anomalies +"@ + + Invoke-RestMethod -Uri "http://pushgateway:9091/metrics/job/$job/instance/$instance" ` + -Method Put -Body ($body + "`n") -ContentType "text/plain" + + Write-Host "Pushed: V=$volume, R=$rate, U=$uptime, A=$anomalies" + Start-Sleep -Seconds 5 +} diff --git a/mqtt_and_kafka/README.md b/mqtt_and_kafka/README.md index 2bd0e4cdb..ee83d5da5 100644 --- a/mqtt_and_kafka/README.md +++ b/mqtt_and_kafka/README.md @@ -1,4 +1,4 @@ -# AgCloud-Sounds +# AgCloud-telemetrys ## AgCloud – End-to-End MQTT → Kafka (Quickstart) diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/assets/mqtt.png b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/assets/mqtt.png deleted file mode 100644 index d06210321..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/assets/mqtt.png and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/README.md b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/README.md deleted file mode 100644 index 64c1d8bee..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/README.md +++ /dev/null @@ -1,95 +0,0 @@ -# Introduction - -This project provides connectors for Kafka Connect to read and write data to a MQTT broker. - -# Using with AWS IOT - -```bash -sudo mkdir /opt/aws-iot/ -sudo aws iot create-keys-and-certificate --set-as-active --certificate-pem-outfile /opt/aws-iot/cert.crt --private-key-outfile /opt/aws-iot/private.key --public-key-outfile /opt/aws-iot/public.key --region us-east-1 -sudo openssl pkcs12 -export -in /opt/aws-iot/cert.crt -inkey /opt/aws-iot/private.key -out /opt/aws-iot/p12.keystore -name alias -(type in the export password) - -$ keytool -importkeystore -srckeystore p12.keystore -srcstoretype PKCS12 -srcstorepass -alias alias -deststorepass -destkeypass -destkeystore my.keystore - -$ openssl x509 -outform der -in certificate.pem -out certificate.der -$ keytool -import -alias your-alias -keystore cacerts -file certificate.der -``` - -# Building - -To build the plugin archive, ensure you have [Artifactory credentials](https://github.com/confluentinc/connect-plugins-common#artifactory-credentials-for-building-plugins) set up on the build machine and use the following mvn command: - -``` -$ mvn clean package -... -output truncated -... -[INFO] Building zip: /Users/arjun/Sandbox/clones/kafka-connect-mqtt-wicknicks/target/components/packages/confluentinc-kafka-connect-mqtt-1.0.0-SNAPSHOT.zip -[INFO] -[INFO] --- maven-jar-plugin:3.0.2:test-jar (default) @ kafka-connect-mqtt --- -[INFO] Building jar: /Users/arjun/Sandbox/clones/kafka-connect-mqtt-wicknicks/target/kafka-connect-mqtt-1.0.0-SNAPSHOT-tests.jar -``` - -The location of the plugin archive is shown above in the `target/components/packages` directory. - -## Integration Tests - -To run integration tests from the terminal, start a Docker daemon locally, and run the following command: - -``` -$ mvn clean integration-test - -... -output truncated -... -[INFO] ------------------------------------------------------- -[INFO] T E S T S -[INFO] ------------------------------------------------------- -[INFO] Running io.confluent.connect.mqtt.integration.MqttSourceIntegrationTest -[INFO] Tests run: 1, Failures: 0, Errors: 0, Skipped: 0, Time elapsed: 12.817 s - in io.confluent.connect.mqtt.integration.MqttSourceIntegrationTest -[INFO] Running io.confluent.connect.mqtt.integration.MqttSinkIntegrationTest -[INFO] Tests run: 1, Failures: 0, Errors: 0, Skipped: 0, Time elapsed: 9.083 s - in io.confluent.connect.mqtt.integration.MqttSinkIntegrationTest -... -``` - -To **skip** running integration tests when running maven commands, enable the `skipIntegrationTests` flag. For example: - -``` -mvn clean install -DskipIntegrationTests -``` - -# Documentation - -## Location -Documentation on the connector is hosted on Confluent's -[docs site](https://docs.confluent.io/current/connect/kafka-connect-mqtt/). - -Source code is located in Confluent's -[docs repo](https://github.com/confluentinc/docs/tree/master/connect/kafka-connect-mqtt). If changes -are made to configuration options for the connector, be sure to generate the RST docs (as described -below) and open a PR against the docs repo to publish those changes! - -## Configs -Documentation on the configurations for each connector can be autotomatically generated via Maven. - -To generate documentation for the sink connector: -```bash -mvn -Pdocs exec:java@sink-config-docs -``` - -To generate documentation for the source connector: -```bash -mvn -Pdocs exec:java@source-config-docs -``` - -# Compatibility Matrix: - -This mqtt connector has been tested against the following versions of AK and Eclipse Mosquitto -Broker: - -| | AK 1.0 | AK 1.1 | AK 2.0 | -| ----------------------------- | ------------------ | ------------- | ------------- | -| **Eclipse Mosquitto v1.4.12** | NOT COMPATIBLE (1) | OK | OK | - -1. The connector needs header support in Connect. diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-EDL-org.eclipse.paho.client.mqttv3-1.2.0.txt b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-EDL-org.eclipse.paho.client.mqttv3-1.2.0.txt deleted file mode 100644 index cf989f145..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-EDL-org.eclipse.paho.client.mqttv3-1.2.0.txt +++ /dev/null @@ -1,15 +0,0 @@ - -Eclipse Distribution License - v 1.0 - -Copyright (c) 2007, Eclipse Foundation, Inc. and its licensors. - -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - Neither the name of the Eclipse Foundation, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-connect-utils-0.3.140.txt b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-connect-utils-0.3.140.txt deleted file mode 100644 index 412a9e9d8..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-connect-utils-0.3.140.txt +++ /dev/null @@ -1,422 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - Apache License, Version 2.0 - - - - - - - - - - - -
- -
- -
-
-
- Apache Logo -
-
- - - -
-
-
- - -
- The Apache Way - Contribute - ASF Sponsors -
-
-
-
-

Apache License

Version 2.0, January 2004

-http://www.apache.org/licenses/

-

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

-

1. Definitions.

-

"License" shall mean the terms and conditions for use, reproduction, and -distribution as defined by Sections 1 through 9 of this document.

-

"Licensor" shall mean the copyright owner or entity authorized by the -copyright owner that is granting the License.

-

"Legal Entity" shall mean the union of the acting entity and all other -entities that control, are controlled by, or are under common control with -that entity. For the purposes of this definition, "control" means (i) the -power, direct or indirect, to cause the direction or management of such -entity, whether by contract or otherwise, or (ii) ownership of fifty -percent (50%) or more of the outstanding shares, or (iii) beneficial -ownership of such entity.

-

"You" (or "Your") shall mean an individual or Legal Entity exercising -permissions granted by this License.

-

"Source" form shall mean the preferred form for making modifications, -including but not limited to software source code, documentation source, -and configuration files.

-

"Object" form shall mean any form resulting from mechanical transformation -or translation of a Source form, including but not limited to compiled -object code, generated documentation, and conversions to other media types.

-

"Work" shall mean the work of authorship, whether in Source or Object form, -made available under the License, as indicated by a copyright notice that -is included in or attached to the work (an example is provided in the -Appendix below).

-

"Derivative Works" shall mean any work, whether in Source or Object form, -that is based on (or derived from) the Work and for which the editorial -revisions, annotations, elaborations, or other modifications represent, as -a whole, an original work of authorship. For the purposes of this License, -Derivative Works shall not include works that remain separable from, or -merely link (or bind by name) to the interfaces of, the Work and Derivative -Works thereof.

-

"Contribution" shall mean any work of authorship, including the original -version of the Work and any modifications or additions to that Work or -Derivative Works thereof, that is intentionally submitted to Licensor for -inclusion in the Work by the copyright owner or by an individual or Legal -Entity authorized to submit on behalf of the copyright owner. For the -purposes of this definition, "submitted" means any form of electronic, -verbal, or written communication sent to the Licensor or its -representatives, including but not limited to communication on electronic -mailing lists, source code control systems, and issue tracking systems that -are managed by, or on behalf of, the Licensor for the purpose of discussing -and improving the Work, but excluding communication that is conspicuously -marked or otherwise designated in writing by the copyright owner as "Not a -Contribution."

-

"Contributor" shall mean Licensor and any individual or Legal Entity on -behalf of whom a Contribution has been received by Licensor and -subsequently incorporated within the Work.

-

2. Grant of Copyright License. Subject to the -terms and conditions of this License, each Contributor hereby grants to You -a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable -copyright license to reproduce, prepare Derivative Works of, publicly -display, publicly perform, sublicense, and distribute the Work and such -Derivative Works in Source or Object form.

-

3. Grant of Patent License. Subject to the terms -and conditions of this License, each Contributor hereby grants to You a -perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable -(except as stated in this section) patent license to make, have made, use, -offer to sell, sell, import, and otherwise transfer the Work, where such -license applies only to those patent claims licensable by such Contributor -that are necessarily infringed by their Contribution(s) alone or by -combination of their Contribution(s) with the Work to which such -Contribution(s) was submitted. If You institute patent litigation against -any entity (including a cross-claim or counterclaim in a lawsuit) alleging -that the Work or a Contribution incorporated within the Work constitutes -direct or contributory patent infringement, then any patent licenses -granted to You under this License for that Work shall terminate as of the -date such litigation is filed.

-

4. Redistribution. You may reproduce and -distribute copies of the Work or Derivative Works thereof in any medium, -with or without modifications, and in Source or Object form, provided that -You meet the following conditions:

-
    -
  1. You must give any other recipients of the Work or Derivative Works a -copy of this License; and
  2. - -
  3. You must cause any modified files to carry prominent notices stating -that You changed the files; and
  4. - -
  5. You must retain, in the Source form of any Derivative Works that You -distribute, all copyright, patent, trademark, and attribution notices from -the Source form of the Work, excluding those notices that do not pertain to -any part of the Derivative Works; and
  6. - -
  7. If the Work includes a "NOTICE" text file as part of its distribution, -then any Derivative Works that You distribute must include a readable copy -of the attribution notices contained within such NOTICE file, excluding -those notices that do not pertain to any part of the Derivative Works, in -at least one of the following places: within a NOTICE text file distributed -as part of the Derivative Works; within the Source form or documentation, -if provided along with the Derivative Works; or, within a display generated -by the Derivative Works, if and wherever such third-party notices normally -appear. The contents of the NOTICE file are for informational purposes only -and do not modify the License. You may add Your own attribution notices -within Derivative Works that You distribute, alongside or as an addendum to -the NOTICE text from the Work, provided that such additional attribution -notices cannot be construed as modifying the License. -
    -
    -You may add Your own copyright statement to Your modifications and may -provide additional or different license terms and conditions for use, -reproduction, or distribution of Your modifications, or for any such -Derivative Works as a whole, provided Your use, reproduction, and -distribution of the Work otherwise complies with the conditions stated in -this License. -
  8. - -
- -

5. Submission of Contributions. Unless You -explicitly state otherwise, any Contribution intentionally submitted for -inclusion in the Work by You to the Licensor shall be under the terms and -conditions of this License, without any additional terms or conditions. -Notwithstanding the above, nothing herein shall supersede or modify the -terms of any separate license agreement you may have executed with Licensor -regarding such Contributions.

-

6. Trademarks. This License does not grant -permission to use the trade names, trademarks, service marks, or product -names of the Licensor, except as required for reasonable and customary use -in describing the origin of the Work and reproducing the content of the -NOTICE file.

-

7. Disclaimer of Warranty. Unless required by -applicable law or agreed to in writing, Licensor provides the Work (and -each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, -without limitation, any warranties or conditions of TITLE, -NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You -are solely responsible for determining the appropriateness of using or -redistributing the Work and assume any risks associated with Your exercise -of permissions under this License.

-

8. Limitation of Liability. In no event and -under no legal theory, whether in tort (including negligence), contract, or -otherwise, unless required by applicable law (such as deliberate and -grossly negligent acts) or agreed to in writing, shall any Contributor be -liable to You for damages, including any direct, indirect, special, -incidental, or consequential damages of any character arising as a result -of this License or out of the use or inability to use the Work (including -but not limited to damages for loss of goodwill, work stoppage, computer -failure or malfunction, or any and all other commercial damages or losses), -even if such Contributor has been advised of the possibility of such -damages.

-

9. Accepting Warranty or Additional Liability. -While redistributing the Work or Derivative Works thereof, You may choose -to offer, and charge a fee for, acceptance of support, warranty, indemnity, -or other liability obligations and/or rights consistent with this License. -However, in accepting such obligations, You may act only on Your own behalf -and on Your sole responsibility, not on behalf of any other Contributor, -and only if You agree to indemnify, defend, and hold each Contributor -harmless for any liability incurred by, or claims asserted against, such -Contributor by reason of your accepting any such warranty or additional -liability.

-

END OF TERMS AND CONDITIONS

-

APPENDIX: How to apply the Apache License to your work

-

To apply the Apache License to your work, attach the following boilerplate -notice, with the fields enclosed by brackets "[]" replaced with your own -identifying information. (Don't include the brackets!) The text should be -enclosed in the appropriate comment syntax for the file format. We also -recommend that a file or class name and description of purpose be included -on the same "printed page" as the copyright notice for easier -identification within third-party archives.

-
Copyright [yyyy] [name of copyright owner]
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
- - - - - - - - - - - diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-freemarker-2.3.25-incubating.txt b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-freemarker-2.3.25-incubating.txt deleted file mode 100644 index d64569567..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-freemarker-2.3.25-incubating.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-guava-20.0.txt b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-guava-20.0.txt deleted file mode 100644 index d64569567..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-guava-20.0.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-org.eclipse.paho.client.mqttv3-1.2.0.txt b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-org.eclipse.paho.client.mqttv3-1.2.0.txt deleted file mode 100644 index 79e486c3d..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-org.eclipse.paho.client.mqttv3-1.2.0.txt +++ /dev/null @@ -1,70 +0,0 @@ -Eclipse Public License - v 1.0 - -THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. - -1. DEFINITIONS - -"Contribution" means: - -a) in the case of the initial Contributor, the initial code and documentation distributed under this Agreement, and -b) in the case of each subsequent Contributor: -i) changes to the Program, and -ii) additions to the Program; -where such changes and/or additions to the Program originate from and are distributed by that particular Contributor. A Contribution 'originates' from a Contributor if it was added to the Program by such Contributor itself or anyone acting on such Contributor's behalf. Contributions do not include additions to the Program which: (i) are separate modules of software distributed in conjunction with the Program under their own license agreement, and (ii) are not derivative works of the Program. -"Contributor" means any person or entity that distributes the Program. - -"Licensed Patents" mean patent claims licensable by a Contributor which are necessarily infringed by the use or sale of its Contribution alone or when combined with the Program. - -"Program" means the Contributions distributed in accordance with this Agreement. - -"Recipient" means anyone who receives the Program under this Agreement, including all Contributors. - -2. GRANT OF RIGHTS - -a) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, distribute and sublicense the Contribution of such Contributor, if any, and such derivative works, in source code and object code form. -b) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free patent license under Licensed Patents to make, use, sell, offer to sell, import and otherwise transfer the Contribution of such Contributor, if any, in source code and object code form. This patent license shall apply to the combination of the Contribution and the Program if, at the time the Contribution is added by the Contributor, such addition of the Contribution causes such combination to be covered by the Licensed Patents. The patent license shall not apply to any other combinations which include the Contribution. No hardware per se is licensed hereunder. -c) Recipient understands that although each Contributor grants the licenses to its Contributions set forth herein, no assurances are provided by any Contributor that the Program does not infringe the patent or other intellectual property rights of any other entity. Each Contributor disclaims any liability to Recipient for claims brought by any other entity based on infringement of intellectual property rights or otherwise. As a condition to exercising the rights and licenses granted hereunder, each Recipient hereby assumes sole responsibility to secure any other intellectual property rights needed, if any. For example, if a third party patent license is required to allow Recipient to distribute the Program, it is Recipient's responsibility to acquire that license before distributing the Program. -d) Each Contributor represents that to its knowledge it has sufficient copyright rights in its Contribution, if any, to grant the copyright license set forth in this Agreement. -3. REQUIREMENTS - -A Contributor may choose to distribute the Program in object code form under its own license agreement, provided that: - -a) it complies with the terms and conditions of this Agreement; and -b) its license agreement: -i) effectively disclaims on behalf of all Contributors all warranties and conditions, express and implied, including warranties or conditions of title and non-infringement, and implied warranties or conditions of merchantability and fitness for a particular purpose; -ii) effectively excludes on behalf of all Contributors all liability for damages, including direct, indirect, special, incidental and consequential damages, such as lost profits; -iii) states that any provisions which differ from this Agreement are offered by that Contributor alone and not by any other party; and -iv) states that source code for the Program is available from such Contributor, and informs licensees how to obtain it in a reasonable manner on or through a medium customarily used for software exchange. -When the Program is made available in source code form: - -a) it must be made available under this Agreement; and -b) a copy of this Agreement must be included with each copy of the Program. -Contributors may not remove or alter any copyright notices contained within the Program. - -Each Contributor must identify itself as the originator of its Contribution, if any, in a manner that reasonably allows subsequent Recipients to identify the originator of the Contribution. - -4. COMMERCIAL DISTRIBUTION - -Commercial distributors of software may accept certain responsibilities with respect to end users, business partners and the like. While this license is intended to facilitate the commercial use of the Program, the Contributor who includes the Program in a commercial product offering should do so in a manner which does not create potential liability for other Contributors. Therefore, if a Contributor includes the Program in a commercial product offering, such Contributor ("Commercial Contributor") hereby agrees to defend and indemnify every other Contributor ("Indemnified Contributor") against any losses, damages and costs (collectively "Losses") arising from claims, lawsuits and other legal actions brought by a third party against the Indemnified Contributor to the extent caused by the acts or omissions of such Commercial Contributor in connection with its distribution of the Program in a commercial product offering. The obligations in this section do not apply to any claims or Losses relating to any actual or alleged intellectual property infringement. In order to qualify, an Indemnified Contributor must: a) promptly notify the Commercial Contributor in writing of such claim, and b) allow the Commercial Contributor to control, and cooperate with the Commercial Contributor in, the defense and any related settlement negotiations. The Indemnified Contributor may participate in any such claim at its own expense. - -For example, a Contributor might include the Program in a commercial product offering, Product X. That Contributor is then a Commercial Contributor. If that Commercial Contributor then makes performance claims, or offers warranties related to Product X, those performance claims and warranties are such Commercial Contributor's responsibility alone. Under this section, the Commercial Contributor would have to defend claims against the other Contributors related to those performance claims and warranties, and if a court requires any other Contributor to pay any damages as a result, the Commercial Contributor must pay those damages. - -5. NO WARRANTY - -EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the appropriateness of using and distributing the Program and assumes all risks associated with its exercise of rights under this Agreement , including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations. - -6. DISCLAIMER OF LIABILITY - -EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - -7. GENERAL - -If any provision of this Agreement is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this Agreement, and without further action by the parties hereto, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. - -If Recipient institutes patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Program itself (excluding combinations of the Program with other software or hardware) infringes such Recipient's patent(s), then such Recipient's rights granted under Section 2(b) shall terminate as of the date such litigation is filed. - -All Recipient's rights under this Agreement shall terminate if it fails to comply with any of the material terms or conditions of this Agreement and does not cure such failure in a reasonable period of time after becoming aware of such noncompliance. If all Recipient's rights under this Agreement terminate, Recipient agrees to cease use and distribution of the Program as soon as reasonably practicable. However, Recipient's obligations under this Agreement and any licenses granted by Recipient relating to the Program shall continue and survive. - -Everyone is permitted to copy and distribute copies of this Agreement, but in order to avoid inconsistency the Agreement is copyrighted and may only be modified in the following manner. The Agreement Steward reserves the right to publish new versions (including revisions) of this Agreement from time to time. No one other than the Agreement Steward has the right to modify this Agreement. The Eclipse Foundation is the initial Agreement Steward. The Eclipse Foundation may assign the responsibility to serve as the Agreement Steward to a suitable separate entity. Each new version of the Agreement will be given a distinguishing version number. The Program (including Contributions) may always be distributed subject to the version of the Agreement under which it was received. In addition, after a new version of the Agreement is published, Contributor may elect to distribute the Program (including its Contributions) under the new version. Except as expressly stated in Sections 2(a) and 2(b) above, Recipient receives no rights or licenses to the intellectual property of any Contributor under this Agreement, whether expressly, by implication, estoppel or otherwise. All rights in the Program not expressly granted under this Agreement are reserved. - -This Agreement is governed by the laws of the State of New York and the intellectual property laws of the United States of America. No party to this Agreement will bring a legal action under this Agreement more than one year after the cause of action arose. Each party waives its rights to a jury trial in any resulting litigation. diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-slf4j-1.7.25.txt b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-slf4j-1.7.25.txt deleted file mode 100644 index 315bd4979..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/doc/licenses/LICENSE-slf4j-1.7.25.txt +++ /dev/null @@ -1,24 +0,0 @@ -Copyright (c) 2004-2017 QOS.ch -All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - - diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/connect-avro-docker.properties b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/connect-avro-docker.properties deleted file mode 100644 index 482880401..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/connect-avro-docker.properties +++ /dev/null @@ -1,31 +0,0 @@ -# -# Copyright [2018 - 2020] Confluent Inc. -# - -# Sample configuration for a standalone Kafka Connect worker that uses Avro serialization and -# integrates the the SchemaConfig Registry. This sample configuration assumes a local installation of -# Confluent Platform with all services running on their default ports. -# Bootstrap Kafka servers. If multiple servers are specified, they should be comma-separated. -bootstrap.servers=kafka:9092 -# The converters specify the format of data in Kafka and how to translate it into Connect data. -# Every Connect user will need to configure these based on the format they want their data in -# when loaded from or stored into Kafka -# key.converter=io.confluent.connect.avro.AvroConverter -key.converter.schema.registry.url=http://schema-registry:8081 -# value.converter=io.confluent.connect.avro.AvroConverter -value.converter.schema.registry.url=http://schema-registry:8081 -# The internal converter used for offsets and config data is configurable and must be specified, -# but most users will always want to use the built-in default. Offset and config data is never -# visible outside of Connect in this format. -internal.key.converter=org.apache.kafka.connect.json.JsonConverter -internal.value.converter=org.apache.kafka.connect.json.JsonConverter -internal.key.converter.schemas.enable=false -internal.value.converter.schemas.enable=false -# Local storage file for offset data -offset.storage.file.filename=/tmp/connect.offsets -# Confuent Control Center Integration -- uncomment these lines to enable Kafka client interceptors -# that will report audit data that can be displayed and analyzed in Confluent Control Center -# producer.interceptor.classes=io.confluent.monitoring.clients.interceptor.MonitoringProducerInterceptor -# consumer.interceptor.classes=io.confluent.monitoring.clients.interceptor.MonitoringConsumerInterceptor -consumer.max.poll.records=100 -plugin.path=target/components/packages/ \ No newline at end of file diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/source-anonymous.properties b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/source-anonymous.properties deleted file mode 100644 index caee1a9f2..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/source-anonymous.properties +++ /dev/null @@ -1,9 +0,0 @@ -# -# Copyright [2018 - 2020] Confluent Inc. -# - -name=anonymous -tasks.max=1 -connector.class=io.confluent.connect.mqtt.MqttSourceConnector -mqtt.server.uri=tcp://127.0.0.1:32790 -mqtt.topics=foo diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/source-password.properties b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/source-password.properties deleted file mode 100644 index 6e70fe86f..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/etc/source-password.properties +++ /dev/null @@ -1,11 +0,0 @@ -# -# Copyright [2018 - 2020] Confluent Inc. -# - -name=anonymous -tasks.max=1 -connector.class=io.confluent.connect.mqtt.MqttSourceConnector -mqtt.server.uri=tcp://127.0.0.1:32792 -mqtt.topics=foo -mqtt.username=test -mqtt.password=test \ No newline at end of file diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/checker-qual-3.33.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/checker-qual-3.33.0.jar deleted file mode 100644 index 61761fdcb..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/checker-qual-3.33.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/confluent-licensing-new-7.5.8-18-ce.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/confluent-licensing-new-7.5.8-18-ce.jar deleted file mode 100644 index 029b90421..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/confluent-licensing-new-7.5.8-18-ce.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/confluent-serializers-new-7.5.8-18-ce.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/confluent-serializers-new-7.5.8-18-ce.jar deleted file mode 100644 index ace3404ee..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/confluent-serializers-new-7.5.8-18-ce.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-ak-non-public-0.24.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-ak-non-public-0.24.0.jar deleted file mode 100644 index 77c66c22b..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-ak-non-public-0.24.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-licensing-extensions-0.9.35.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-licensing-extensions-0.9.35.jar deleted file mode 100644 index 61f03838b..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-licensing-extensions-0.9.35.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-utils-0.3.3.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-utils-0.3.3.jar deleted file mode 100644 index 4cc6f9c5b..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-utils-0.3.3.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-utils-1.1.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-utils-1.1.0.jar deleted file mode 100644 index 8ad2fa21c..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/connect-utils-1.1.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/error_prone_annotations-2.18.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/error_prone_annotations-2.18.0.jar deleted file mode 100644 index e072fe029..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/error_prone_annotations-2.18.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/failureaccess-1.0.1.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/failureaccess-1.0.1.jar deleted file mode 100644 index 9b56dc751..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/failureaccess-1.0.1.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/freemarker-2.3.31.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/freemarker-2.3.31.jar deleted file mode 100644 index 8fb169b21..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/freemarker-2.3.31.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/gson-2.9.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/gson-2.9.0.jar deleted file mode 100644 index fb62e0565..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/gson-2.9.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/guava-32.1.1-jre.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/guava-32.1.1-jre.jar deleted file mode 100644 index 8f2b3f5d4..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/guava-32.1.1-jre.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/j2objc-annotations-2.8.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/j2objc-annotations-2.8.jar deleted file mode 100644 index 3595c4f9b..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/j2objc-annotations-2.8.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/jose4j-0.9.5.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/jose4j-0.9.5.jar deleted file mode 100644 index ee3b65c10..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/jose4j-0.9.5.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/jsr305-3.0.2.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/jsr305-3.0.2.jar deleted file mode 100644 index 59222d9ca..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/jsr305-3.0.2.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/kafka-connect-mqtt-1.7.6.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/kafka-connect-mqtt-1.7.6.jar deleted file mode 100644 index 427eca9c2..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/kafka-connect-mqtt-1.7.6.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar deleted file mode 100644 index 45832c052..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-api-2.25.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-api-2.25.0.jar deleted file mode 100644 index 848f393a4..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-api-2.25.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-core-2.25.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-core-2.25.0.jar deleted file mode 100644 index 8184432ba..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-core-2.25.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-slf4j2-impl-2.25.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-slf4j2-impl-2.25.0.jar deleted file mode 100644 index 1a398dce0..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/log4j-slf4j2-impl-2.25.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/metrics-core-2.2.0.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/metrics-core-2.2.0.jar deleted file mode 100644 index 0f6d1cb0e..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/metrics-core-2.2.0.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/org.eclipse.paho.client.mqttv3-1.2.5.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/org.eclipse.paho.client.mqttv3-1.2.5.jar deleted file mode 100644 index 66f1278e4..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/org.eclipse.paho.client.mqttv3-1.2.5.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/protobuf-java-3.25.5.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/protobuf-java-3.25.5.jar deleted file mode 100644 index d76648859..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/protobuf-java-3.25.5.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/protobuf-java-util-3.25.5.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/protobuf-java-util-3.25.5.jar deleted file mode 100644 index 5f97266a4..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/protobuf-java-util-3.25.5.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/slf4j-api-1.7.36.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/slf4j-api-1.7.36.jar deleted file mode 100644 index 7d3ce68d2..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/slf4j-api-1.7.36.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/value-2.8.2.jar b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/value-2.8.2.jar deleted file mode 100644 index 6f4cec3c5..000000000 Binary files a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/lib/value-2.8.2.jar and /dev/null differ diff --git a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/manifest.json b/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/manifest.json deleted file mode 100644 index 31b258b82..000000000 --- a/mqtt_and_kafka/connect/plugins/confluentinc-kafka-connect-mqtt-1.7.6/manifest.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "name" : "kafka-connect-mqtt", - "version" : "1.7.6", - "title" : "Kafka Connect MQTT", - "description" : "A Kafka Connect plugin for sending and receiving data from a Mqtt broker.", - "owner" : { - "username" : "confluentinc", - "name" : "Confluent, Inc." - }, - "support" : { - "summary" : "This connector is a Confluent Commercial Connector and supported by Confluent. The requires purchase of a Confluent Platform subscription, including a license to this Commercial Connector. You can also use this connector for a 30-day trial without an enterprise license key - after 30 days, you need to purchase a subscription. Please contact your Confluent account manager for details.", - "url" : "https://docs.confluent.io/kafka-connect-mqtt/current/index.html" - }, - "tags" : [ "MQTT", "Internet of Things", "IOT" ], - "features" : { - "supported_encodings" : [ "any" ], - "single_message_transforms" : true, - "confluent_control_center_integration" : true, - "kafka_connect_api" : true - }, - "logo" : "assets/mqtt.png", - "documentation_url" : "https://docs.confluent.io/kafka-connect-mqtt/current/index.html", - "docker_image" : { }, - "license" : [ { - "name" : "Confluent Software Evaluation License", - "url" : "https://www.confluent.io/software-evaluation-license" - } ], - "component_types" : [ "source", "sink" ], - "release_date" : "2025-08-09" -} \ No newline at end of file diff --git a/mqtt_and_kafka/connectors/mqtt-source.json b/mqtt_and_kafka/connectors/mqtt-source.json deleted file mode 100644 index 0940826ef..000000000 --- a/mqtt_and_kafka/connectors/mqtt-source.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "name": "mqtt-source", - "config": { - "connector.class": "io.confluent.connect.mqtt.MqttSourceConnector", - "tasks.max": "1", - - "mqtt.server.uri": "tcp://mosquitto:1883", - "mqtt.topics": "mqtt/#", - "mqtt.qos": "0", - "clean.session": "true", - - "kafka.topic": "dev-robot-alerts", - - "key.converter": "org.apache.kafka.connect.storage.StringConverter", - "key.converter.schemas.enable": "false", - - "value.converter": "org.apache.kafka.connect.converters.ByteArrayConverter", - "value.converter.schemas.enable": "false", - - "errors.tolerance": "all", - "errors.log.enable": "true", - - "topic.creation.enable": "false", - "topic.creation.default.replication.factor": "1", - "topic.creation.default.partitions": "1", - - "confluent.topic.bootstrap.servers": "kafka:9092", - "confluent.topic.replication.factor": "1", - "producer.override.bootstrap.servers": "kafka:9092" - } -} diff --git a/mqtt_and_kafka/kafka/kafka-files/app-start.sh b/mqtt_and_kafka/kafka/kafka-files/app-start.sh index 598b736d2..ad6f253a0 100644 --- a/mqtt_and_kafka/kafka/kafka-files/app-start.sh +++ b/mqtt_and_kafka/kafka/kafka-files/app-start.sh @@ -46,4 +46,4 @@ BOOTSTRAP="${BOOTSTRAP}" /opt/bitnami/smoke-test.sh || { } # Stay attached to Kafka process -wait ${KAFKA_PID} +wait ${KAFKA_PID} \ No newline at end of file diff --git a/mqtt_and_kafka/kafka/kafka-files/create-topics.sh b/mqtt_and_kafka/kafka/kafka-files/create-topics.sh index 0ec1b6b6b..5a5107bde 100644 --- a/mqtt_and_kafka/kafka/kafka-files/create-topics.sh +++ b/mqtt_and_kafka/kafka/kafka-files/create-topics.sh @@ -24,24 +24,55 @@ for i in {1..60}; do fi done -# Required topics with 7-day retention TOPICS=( dev-robot-alerts dev-robot-commands dev-robot-status dev-robot-telemetry-raw dev-robot-state + + dev-camera-security + sensor-telemetry + sensor-anomalies + dev-robot-telemetry-anomalies + sensor_anomalies sensor_zone_stats dev-robot-telemetry-anomalies + summaries.5m irrigation.control irrigation.control.dlq sound.new image.new - summaries.5m - dev-aerial-images-keys + aerial_images_metadata + dev-security-images-keys + alerts + + aerial_image_object_detections + aerial_image_anomaly_detections + aerial_image_segmentation + aerial_images_complete_metadata + + # --- imagery (MinIO -> Kafka) --- image.new.aerial - image.new.aerial.connections + image_new_aerial_connections + image.new.fruits + image.new.leaves + image.new.ground + image.new.field + image.new.security + image_new_security_connections + + # --- sound(sound) (MinIO -> Kafka) --- + sound.new.plants + sound.new.sounds + sounds_ultra_metadata + sounds_metadata + sound_new_plants_connections + sound_new_sounds_connections + + inference.dispatched.sounds + dlq.inference.http ) # Idempotent creation with retention.ms diff --git a/mqtt_and_kafka/mqtt-router/Dockerfile b/mqtt_and_kafka/mqtt-router/Dockerfile new file mode 100644 index 000000000..3ec5173ee --- /dev/null +++ b/mqtt_and_kafka/mqtt-router/Dockerfile @@ -0,0 +1,39 @@ +FROM python:3.12-slim + +# ---- Build-time toggle for NetFree CA injection ---- +ARG USE_NETFREE=false + +# ---- System deps (CA, curl). librdkafka1 helps if confluent-kafka wheel is not fully static on your base ---- +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl librdkafka1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# ---- Optional NetFree certificates (mount or COPY your certs/*.crt alongside the Dockerfile) ---- +# If you keep certs in repo, uncomment the next line: +# COPY certs/*.crt /app/certs/ +RUN if [ "$USE_NETFREE" = "true" ] && [ -d /app/certs ] && ls /app/certs/*.crt >/dev/null 2>&1; then \ + echo "Configuring NetFree certificates..."; \ + cp /app/certs/*.crt /usr/local/share/ca-certificates/ && update-ca-certificates; \ + else \ + echo "No NetFree certs applied (USE_NETFREE=$USE_NETFREE)."; \ + fi + +# ---- Make requests/libs use system CA (works both with and without NetFree) ---- +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_CERT=/etc/ssl/certs/ca-certificates.crt \ + PYTHONUNBUFFERED=1 + +# ---- Install Python deps ---- +COPY requirements.txt . +# When behind NetFree, trusted-host can help even אם אין צורך זה לא מזיק: +RUN python -m pip install --no-cache-dir \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + -r requirements.txt + +# ---- App code ---- +COPY app.py . + +ENTRYPOINT ["python", "app.py"] diff --git a/mqtt_and_kafka/mqtt-router/app.py b/mqtt_and_kafka/mqtt-router/app.py new file mode 100644 index 000000000..3383671aa --- /dev/null +++ b/mqtt_and_kafka/mqtt-router/app.py @@ -0,0 +1,154 @@ +import os +import re +import signal +import sys +from typing import Optional + +import paho.mqtt.client as mqtt +from confluent_kafka import Producer, KafkaException, KafkaError +from confluent_kafka.admin import AdminClient, NewTopic + +# ---------- Env ---------- +MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto") +MQTT_PORT = int(os.getenv("MQTT_PORT", "1883")) +MQTT_USERNAME = os.getenv("MQTT_USERNAME", "") +MQTT_PASSWORD = os.getenv("MQTT_PASSWORD", "") +MQTT_TOPIC_FILTER = os.getenv("MQTT_TOPIC_FILTER", "mqtt/#") + +KAFKA_BOOTSTRAP = os.getenv("KAFKA_BOOTSTRAP", "kafka:9092") +KAFKA_CLIENT_ID = os.getenv("KAFKA_CLIENT_ID", "mqtt-router") +CREATE_TOPICS = os.getenv("CREATE_TOPICS", "false").lower() == "true" +DEFAULT_PARTITIONS = int(os.getenv("DEFAULT_PARTITIONS", "1")) +DEFAULT_REPLICATION = int(os.getenv("DEFAULT_REPLICATION", "1")) + +# Optional security (set via env if needed) +KAFKA_SECURITY_PROTOCOL = os.getenv("KAFKA_SECURITY_PROTOCOL", "") # e.g. "SASL_PLAINTEXT", "SASL_SSL", "SSL" +KAFKA_SASL_MECHANISM = os.getenv("KAFKA_SASL_MECHANISM", "") # e.g. "PLAIN" +KAFKA_SASL_USERNAME = os.getenv("KAFKA_SASL_USERNAME", "") +KAFKA_SASL_PASSWORD = os.getenv("KAFKA_SASL_PASSWORD", "") + +# ---------- Topic mapping ---------- +# Allow arbitrary depth after "mqtt/" → replace "/" with "." +VALID_CHARS = re.compile(r'[^A-Za-z0-9._-]') + +def map_mqtt_to_kafka_topic(mqtt_topic: str) -> Optional[str]: + prefix = "mqtt/" + if not mqtt_topic.startswith(prefix): + return None + tail = mqtt_topic[len(prefix):].strip("/") + if not tail: + return None + parts = [seg for seg in tail.split("/") if seg] + dotted = "_".join(parts) + dotted = VALID_CHARS.sub("_", dotted) + return dotted[:249] if dotted else None + +# ---------- Kafka clients ---------- +producer_conf = { + "bootstrap.servers": KAFKA_BOOTSTRAP, + "client.id": KAFKA_CLIENT_ID, + + # Strong delivery semantics + "acks": "all", + "enable.idempotence": True, + + # Throughput tuning + "compression.type": os.getenv("KAFKA_COMPRESSION", "lz4"), + "linger.ms": int(os.getenv("KAFKA_LINGER_MS", "5")), + "batch.size": int(os.getenv("KAFKA_BATCH_SIZE", str(64 * 1024))), # bytes + + # Resilience + "socket.keepalive.enable": True, + "delivery.timeout.ms": int(os.getenv("KAFKA_DELIVERY_TIMEOUT_MS", "120000")), + "request.timeout.ms": int(os.getenv("KAFKA_REQUEST_TIMEOUT_MS", "30000")), +} + +# Optional security +if KAFKA_SECURITY_PROTOCOL: + producer_conf["security.protocol"] = KAFKA_SECURITY_PROTOCOL +if KAFKA_SASL_MECHANISM: + producer_conf["sasl.mechanism"] = KAFKA_SASL_MECHANISM +if KAFKA_SASL_USERNAME: + producer_conf["sasl.username"] = KAFKA_SASL_USERNAME +if KAFKA_SASL_PASSWORD: + producer_conf["sasl.password"] = KAFKA_SASL_PASSWORD + +p = Producer(producer_conf) +admin = AdminClient({"bootstrap.servers": KAFKA_BOOTSTRAP}) # kept for CREATE_TOPICS toggle + +def ensure_topic(topic: str): + if not CREATE_TOPICS: + return + try: + fs = admin.create_topics([NewTopic(topic, num_partitions=DEFAULT_PARTITIONS, + replication_factor=DEFAULT_REPLICATION)]) + fs[topic].result() + print(f"[router] Created topic: {topic}", flush=True) + except Exception as e: + msg = str(e) + if "exists" in msg.lower() or "TopicExistsError" in msg or "TOPIC_ALREADY_EXISTS" in msg: + return + print(f"[router] create_topics warning for {topic}: {e}", flush=True) + +def delivery_report(err, msg): + if err is not None: + print(f"[router] Delivery failed for {msg.topic()}: {err}", flush=True) + else: + print(f"[router] Delivered to {msg.topic()} [partition {msg.partition()} offset {msg.offset()}]", flush=True) + +# ---------- MQTT callbacks ---------- +def on_connect(client, userdata, flags, rc, properties=None): + if rc == 0: + print(f"[router] Connected MQTT {MQTT_HOST}:{MQTT_PORT}, subscribe: {MQTT_TOPIC_FILTER}", flush=True) + client.subscribe(MQTT_TOPIC_FILTER, qos=0) + else: + print(f"[router] MQTT connect failed: rc={rc}", flush=True) + +def on_message(client, userdata, msg): + src = msg.topic + dst = map_mqtt_to_kafka_topic(src) + if not dst: + print(f"[router] Skipping topic (no match): {src}", flush=True) + return + try: + ensure_topic(dst) + p.produce(dst, value=msg.payload, on_delivery=delivery_report) + # Poll to serve delivery callbacks; small 0 keeps loop snappy + p.poll(0) + except KafkaException as e: + # Helpful message when topics are not pre-created + kafka_err = e.args[0] if e.args else None + if isinstance(kafka_err, KafkaError) and kafka_err.code() == KafkaError.UNKNOWN_TOPIC_OR_PART: + print(f"[router] ERROR UnknownTopicOrPartition for '{dst}'. " + f"CREATE_TOPICS=false → please pre-create this topic.", flush=True) + else: + print(f"[router] Kafka produce error: {e}", flush=True) + +# ---------- Main ---------- +def main(): + client = mqtt.Client(client_id="mqtt-router", protocol=mqtt.MQTTv5) + if MQTT_USERNAME or MQTT_PASSWORD: + client.username_pw_set(MQTT_USERNAME, MQTT_PASSWORD) + + # Gentle reconnect backoff + client.reconnect_delay_set(min_delay=1, max_delay=30) + + client.on_connect = on_connect + client.on_message = on_message + + def handle_sigterm(signum, frame): + print("[router] SIGTERM received, flushing producer...", flush=True) + p.flush(10) + sys.exit(0) + + signal.signal(signal.SIGTERM, handle_sigterm) + signal.signal(signal.SIGINT, handle_sigterm) + + client.connect(MQTT_HOST, MQTT_PORT, keepalive=30) + print(f"[router] Boot: MQTT={MQTT_HOST}:{MQTT_PORT} Kafka={KAFKA_BOOTSTRAP} " + f"CREATE_TOPICS={CREATE_TOPICS}", flush=True) + client.loop_forever() + +if __name__ == "__main__": + main() + diff --git a/mqtt_and_kafka/mqtt-router/requirements.txt b/mqtt_and_kafka/mqtt-router/requirements.txt new file mode 100644 index 000000000..3d853480c --- /dev/null +++ b/mqtt_and_kafka/mqtt-router/requirements.txt @@ -0,0 +1,2 @@ +paho-mqtt==2.1.0 +confluent-kafka>=2.4 diff --git a/mqtt_and_kafka/simulator/data_simulator.py b/mqtt_and_kafka/simulator/air_simulator.py similarity index 100% rename from mqtt_and_kafka/simulator/data_simulator.py rename to mqtt_and_kafka/simulator/air_simulator.py diff --git a/prometheus/prometheus.yml b/prometheus/prometheus.yml index 87bef244d..804405bbd 100644 --- a/prometheus/prometheus.yml +++ b/prometheus/prometheus.yml @@ -34,5 +34,12 @@ scrape_configs: - job_name: sound_metrics static_configs: - - targets: ['host.docker.internal:8001'] + - targets: ['host.docker.internal:8005'] + + - job_name: 'pushgateway' + honor_labels: true + static_configs: + - targets: ['pushgateway:9091'] + + diff --git a/results/benchmarks.csv b/results/benchmarks.csv new file mode 100644 index 000000000..4f54ff285 --- /dev/null +++ b/results/benchmarks.csv @@ -0,0 +1,3 @@ +file,codec,orig_bytes,encoded_bytes,compression_ratio_orig_over_encoded,encode_time_sec,encode_cpu_avg_percent,timestamp,age_days +robot-03_20251028t120000z.mp3,flac,6747480,33505376,0.201,0.71,152.2,2025-10-28T12:00:00,5.0 +robot-03_20251028t120000z.mp3,opus,6747480,2452450,2.751,3.593,143.9,2025-10-28T12:00:00,5.0 diff --git a/services/sounds/API-development/README.md b/services/API-notifications/README.md similarity index 100% rename from services/sounds/API-development/README.md rename to services/API-notifications/README.md diff --git a/services/sounds/API-development/jest.config.js b/services/API-notifications/jest.config.js similarity index 100% rename from services/sounds/API-development/jest.config.js rename to services/API-notifications/jest.config.js diff --git a/services/sounds/API-development/package-lock.json b/services/API-notifications/package-lock.json similarity index 100% rename from services/sounds/API-development/package-lock.json rename to services/API-notifications/package-lock.json diff --git a/services/sounds/API-development/package.json b/services/API-notifications/package.json similarity index 100% rename from services/sounds/API-development/package.json rename to services/API-notifications/package.json diff --git a/services/sounds/API-development/pytest.ini b/services/API-notifications/pytest.ini similarity index 100% rename from services/sounds/API-development/pytest.ini rename to services/API-notifications/pytest.ini diff --git a/services/sounds/API-development/src/Dockerfile b/services/API-notifications/src/Dockerfile similarity index 100% rename from services/sounds/API-development/src/Dockerfile rename to services/API-notifications/src/Dockerfile diff --git a/services/sounds/compression/tests/__init__.py b/services/API-notifications/src/__init__.py similarity index 100% rename from services/sounds/compression/tests/__init__.py rename to services/API-notifications/src/__init__.py diff --git a/services/sounds/sounds_classifier/src/classification/backbones/__init__.py b/services/API-notifications/src/backend/__init__.py similarity index 100% rename from services/sounds/sounds_classifier/src/classification/backbones/__init__.py rename to services/API-notifications/src/backend/__init__.py diff --git a/services/sounds/API-development/src/backend/app.py b/services/API-notifications/src/backend/app.py similarity index 100% rename from services/sounds/API-development/src/backend/app.py rename to services/API-notifications/src/backend/app.py diff --git a/services/sounds/API-development/src/backend/requirements.txt b/services/API-notifications/src/backend/requirements.txt similarity index 100% rename from services/sounds/API-development/src/backend/requirements.txt rename to services/API-notifications/src/backend/requirements.txt diff --git a/services/sounds/API-development/src/requirements.txt b/services/API-notifications/src/requirements.txt similarity index 100% rename from services/sounds/API-development/src/requirements.txt rename to services/API-notifications/src/requirements.txt diff --git a/services/sounds/API-development/src/window-client/README.md b/services/API-notifications/src/window-client/README.md similarity index 100% rename from services/sounds/API-development/src/window-client/README.md rename to services/API-notifications/src/window-client/README.md diff --git a/services/sounds/API-development/src/window-client/notification-popup.html b/services/API-notifications/src/window-client/notification-popup.html similarity index 100% rename from services/sounds/API-development/src/window-client/notification-popup.html rename to services/API-notifications/src/window-client/notification-popup.html diff --git a/services/sounds/API-development/src/window-client/package.json b/services/API-notifications/src/window-client/package.json similarity index 100% rename from services/sounds/API-development/src/window-client/package.json rename to services/API-notifications/src/window-client/package.json diff --git a/services/sounds/API-development/src/window-client/script.js b/services/API-notifications/src/window-client/script.js similarity index 100% rename from services/sounds/API-development/src/window-client/script.js rename to services/API-notifications/src/window-client/script.js diff --git a/services/sounds/API-development/src/window-client/styles.css b/services/API-notifications/src/window-client/styles.css similarity index 100% rename from services/sounds/API-development/src/window-client/styles.css rename to services/API-notifications/src/window-client/styles.css diff --git a/services/sounds/sounds_classifier/src/classification/core/__init__.py b/services/API-notifications/tests/__init__.py similarity index 100% rename from services/sounds/sounds_classifier/src/classification/core/__init__.py rename to services/API-notifications/tests/__init__.py diff --git a/services/sounds/API-development/tests/notification-manager.test.js b/services/API-notifications/tests/notification-manager.test.js similarity index 100% rename from services/sounds/API-development/tests/notification-manager.test.js rename to services/API-notifications/tests/notification-manager.test.js diff --git a/services/sounds/API-development/tests/test_app.py b/services/API-notifications/tests/test_app.py similarity index 100% rename from services/sounds/API-development/tests/test_app.py rename to services/API-notifications/tests/test_app.py diff --git a/services/alertmanager_service/README.md b/services/alertmanager_service/README.md new file mode 100644 index 000000000..743bef952 --- /dev/null +++ b/services/alertmanager_service/README.md @@ -0,0 +1,201 @@ +# 🚨 AgGuard AlertManager Service + +The **AgGuard AlertManager Service** acts as a bridge between AgCloud’s detection pipelines and **Prometheus Alertmanager**. +It receives structured alert JSON payloads, renders descriptive messages using YAML templates, and forwards the alerts to Alertmanager’s `/api/v2/alerts` endpoint. + +--- + +## 🧩 Overview + +- **Framework:** FastAPI +- **Purpose:** Converts raw alerts from detection systems into human-readable, templated messages and sends them to Alertmanager +- **Output:** Properly structured Alertmanager v2 JSON alerts +- **Version:** `1.3` + +--- + +## ⚙️ Environment Variables + +| Variable | Description | Default | +|-----------|--------------|----------| +| `CFG_PATH` | Path to the YAML file containing alert templates | `/app/templates/templates/templates.yml` | +| `ALERTMANAGER_URL` | Base URL of the Alertmanager API | `http://alertmanager:9093` | +| `LOG_LEVEL` | Optional logging verbosity (e.g., `INFO`, `DEBUG`) | `INFO` | + +--- + +## 🚀 Endpoints + +### `POST /alerts` + +Accepts an alert JSON payload and forwards it to Alertmanager after rendering its template. + +**Example request:** + +```bash +curl -X POST http://localhost:8000/alerts \ + -H "Content-Type: application/json" \ + -d '{ + "alert_id": "alert-67", + "alert_type": "smoke_detected", + "device_id": "camera-12", + "started_at": "2025-10-30T14:45:00Z", + "ended_at": "2025-10-30T15:10:00Z", + "confidence": 0.91, + "severity": 2, + "area": "south_field", + "lat": 31.900215, + "lon": 34.850921, + "image_url": "https://s3.farm/agguard/smoke_20251030_1445.jpg", + "vod": "https://s3.farm/agguard/smoke_clip_1445.mp4" + }' +``` + +**Example response:** + +```json +{ + "status": "sent", + "alert": { + "labels": { + "alertname": "smoke_detected", + "alert_id": "alert-67", + "device": "camera-12", + "source": "agcloud-alerts" + }, + "annotations": { + "summary": "🚨 Smoke detected by camera-12 near south_field (confidence 0.91)", + "recommendation": "Inspect the south_field immediately. If fire is confirmed, contact emergency services.", + "category": "environmental", + "severity": "2", + "lat": "31.900215", + "lon": "34.850921", + "image_url": "https://s3.farm/agguard/smoke_20251030_1445.jpg", + "vod": "https://s3.farm/agguard/smoke_clip_1445.mp4" + }, + "startsAt": "2025-10-30T14:45:00Z", + "endsAt": "2025-10-30T15:10:00Z" + } +} +``` + +--- + +### `GET /health` + +Simple health check endpoint. + +**Response:** + +```json +{ "status": "ok" } +``` + +--- + +## 📄 Template Configuration + +Templates are defined in a YAML file (default: `/app/templates/templates/templates.yml`). +Each key corresponds to an `alert_type` and defines the message text and metadata. + +**Example:** + +```yaml +templates: + smoke_detected: + category: environmental + summary: "🚨 Smoke detected by ${device_id} near ${area} (confidence ${confidence})" + recommendation: "Inspect the ${area} immediately. If fire is confirmed, contact emergency services." + + masked_person: + category: security + summary: "Person wearing a mask detected by ${device_id} at ${timestamp}" + recommendation: "Verify the person’s authorization using the live feed." +``` + +### 🧠 Template Variables + +Template values use Python’s `string.Template` syntax (`${variable}`). +Any key present in the incoming alert JSON can be substituted dynamically. + +| Common variable | Description | +|------------------|-------------| +| `${device_id}` | Unique device identifier | +| `${area}` | Detected area/zone | +| `${confidence}` | Detection confidence | +| `${timestamp}` | ISO time string (optional) | +| `${alert_type}` | Type of alert | +| `${severity}` | Numeric severity or category | + +If a template variable is missing in the payload, it is safely ignored (not replaced). + +--- + +## 💬 How Templates Are Used in UI and Slack + +- The `summary` field defined in the template is **displayed directly in the AgGuard UI alert panels**, providing human-readable context (e.g., _“🚨 Smoke detected by camera-12 near south_field”_). +- The same `summary` text is also included in **Slack notifications** sent by Alertmanager, ensuring consistent and recognizable messages across interfaces. +- `recommendation` text is used as an actionable suggestion in both the UI and Slack alerts (e.g., _“Inspect the south_field immediately.”_) + +--- + +## 🧱 Expected JSON Fields + +| Field | Required | Description | +|--------|-----------|-------------| +| `alert_id` | ✅ | Unique alert identifier | +| `alert_type` | ✅ | Type of alert (matches template name) | +| `device_id` | ✅ | Source device ID | +| `started_at` | ✅ | ISO timestamp (`Z` or timezone-aware) | +| `ended_at` | ❌ | ISO timestamp for resolution (optional) | +| `severity` | ❌ | Numeric or string-based severity | +| `confidence`, `area`, `lat`, `lon`, `image_url`, `vod`, `hls`, `meta` | ❌ | Optional metadata | + +--- + +## 🧾 Example Alert Flow + +1. **AgCloud Detector** sends a JSON alert to `/alerts`. +2. The service loads the corresponding YAML template (based on `alert_type`). +3. It renders the `summary`, `recommendation`, and `category` using `${variables}`. +4. A properly formatted payload is sent to **Alertmanager v2 API**. +5. Alertmanager handles grouping, silencing, and routing to receivers (e.g., Slack, email). +6. The same summary is displayed in both **Slack messages** and the **AgGuard UI alerts**. + +--- + +## 🧰 Local Run + +```bash +# Install dependencies +pip install fastapi uvicorn pyyaml + +# Run the service +uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +Environment variables can be provided via `.env` or Docker Compose: + +```yaml +environment: + - CFG_PATH=/app/templates/templates/templates.yml + - ALERTMANAGER_URL=http://alertmanager:9093 +``` + +--- + +## 🪶 Logging + +Logs include alert processing details and delivery status: + +``` +2025-11-02 15:34:12 | INFO | [ALERT PAYLOAD] { + "labels": { "alertname": "smoke_detected", ... }, + "annotations": { "summary": "...", ... }, + "startsAt": "..." +} +2025-11-02 15:34:12 | INFO | [Alertmanager] Sent alerts (HTTP 200) +``` + +--- + diff --git a/services/alertmanager_service/compose/alertmanager.yml b/services/alertmanager_service/compose/alertmanager.yml new file mode 100644 index 000000000..ae8fd8d32 --- /dev/null +++ b/services/alertmanager_service/compose/alertmanager.yml @@ -0,0 +1,30 @@ + +global: + resolve_timeout: 24h + +route: + receiver: "null" + group_by: ["alertname", "device", "alert_id"] + group_wait: 0s + group_interval: 1s + repeat_interval: 2h + routes: + # 1️⃣ Gateway route + - receiver: "gateway" + continue: true + matchers: + - source="agcloud-alerts" + + # # 2️⃣ Slack route + # - receiver: "slack" + # continue: false + # matchers: + # - source="agcloud-alerts" + +receivers: + - name: "null" + - name: "gateway" + webhook_configs: + - url: "http://alerts-gateway:8000/internal/alert" + send_resolved: true + diff --git a/services/alertmanager_service/compose/docker-compose.yml b/services/alertmanager_service/compose/docker-compose.yml new file mode 100644 index 000000000..91100a5b1 --- /dev/null +++ b/services/alertmanager_service/compose/docker-compose.yml @@ -0,0 +1,46 @@ +version: "3.9" + +services: + alertmanager: + image: prom/alertmanager:v0.27.0 + container_name: alertmanager + command: + - "--config.file=/etc/alertmanager/alertmanager.yml" + - "--storage.path=/alertmanager" + - "--log.level=debug" + volumes: + - ./alertmanager.yml:/etc/alertmanager/alertmanager.yml:ro + ports: + - "9093:9093" + restart: always + + alertmanager_service: + build: + context: ../src + dockerfile: Dockerfile + container_name: alertmanager_service + ports: + - "8090:8090" + command: ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8090"] + environment: + - CFG_PATH=/app/templates.yml + - ALERTMANAGER_URL=http://alertmanager:9093 + - GATEWAY_URL=http://alerts-gateway:8000/internal/alert + depends_on: + - alertmanager + - alerts-gateway + + alerts-gateway: + build: + context: ../src + dockerfile: Dockerfile + container_name: alerts_gateway + command: ["uvicorn", "gateway:app", "--host", "0.0.0.0", "--port", "8000"] + ports: + - "8010:8000" # host:container + +networks: + default: + external: true + name: agcloud_ag_cloud + diff --git a/services/alertmanager_service/src/Dockerfile b/services/alertmanager_service/src/Dockerfile new file mode 100644 index 000000000..7fc4608b2 --- /dev/null +++ b/services/alertmanager_service/src/Dockerfile @@ -0,0 +1,25 @@ +# ───────────────────────────────────────────── +# Dockerfile (used for both alert_service and gateway) +# ───────────────────────────────────────────── +FROM python:3.11-slim + +WORKDIR /app + +# Copy code +COPY . . + +# Install deps +RUN pip install --no-cache-dir fastapi "uvicorn[standard]" pyyaml aiohttp asyncpg + + +# Default port — can be overridden by compose +EXPOSE 8090 + +# Default environment +ENV CFG_PATH=/app/templates.yml +ENV ALERTMANAGER_URL=http://alertmanager:9093 +ENV GATEWAY_URL=http://alerts-gateway:8000/internal/alert + + +# Default command (can be overridden in docker-compose) +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8090"] diff --git a/services/alertmanager_service/src/alert_service.py b/services/alertmanager_service/src/alert_service.py new file mode 100644 index 000000000..27b6e1978 --- /dev/null +++ b/services/alertmanager_service/src/alert_service.py @@ -0,0 +1,121 @@ +from __future__ import annotations +import yaml, json, logging, string +from typing import Dict, Any, Sequence +from datetime import datetime, timezone +import urllib.request + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s") + + +# ───────────────────────────────────────────── +# Template Renderer +# ───────────────────────────────────────────── +class AlertTemplateRenderer: + def __init__(self, cfg_path: str): + with open(cfg_path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + self.templates = cfg.get("templates", {}) + + def render(self, alert_type: str, context: dict) -> dict: + tpl = self.templates.get(alert_type) + if not tpl: + raise ValueError(f"No template found for alert type '{alert_type}'") + return {k: string.Template(str(v)).safe_substitute(context) for k, v in tpl.items()} + + +# ───────────────────────────────────────────── +# Alertmanager HTTP Client +# ───────────────────────────────────────────── +class AlertmanagerClient: + def __init__(self, base_url: str, timeout: float = 3.0): + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + def send(self, alerts: Sequence[Dict[str, Any]]) -> None: + """Send alerts to Alertmanager v2 API endpoint.""" + url = f"{self.base_url}/api/v2/alerts" + data = json.dumps(list(alerts)).encode("utf-8") + req = urllib.request.Request(url, data=data, method="POST", + headers={"Content-Type": "application/json"}) + try: + with urllib.request.urlopen(req, timeout=self.timeout) as resp: + log.info(f"[Alertmanager] Sent alerts (HTTP {resp.status})") + except Exception as e: + log.error(f"[Alertmanager] Failed to send: {e}") + raise + + +# ───────────────────────────────────────────── +# Main Service Logic +# ───────────────────────────────────────────── +class AlertManagerService: + def __init__(self, cfg_path: str, alertmanager_url: str): + self.renderer = AlertTemplateRenderer(cfg_path) + self.client = AlertmanagerClient(alertmanager_url) + + def process_alert(self, data: dict): + """Validate, render, and send an alert to Alertmanager.""" + required_fields = ["alert_id", "alert_type", "device_id", "started_at"] + for field in required_fields: + if field not in data: + raise ValueError(f"Missing required field: {field}") + + tpl = self.renderer.render(data["alert_type"], data) + + # ───── Stable labels ───── + labels = { + "alertname": data["alert_type"], + "alert_id": data["alert_id"], + "device": data["device_id"], + "source": "agcloud-alerts", + } + + # ───── Descriptive annotations ───── + annotations = { + "summary": tpl.get("summary"), + "recommendation": tpl.get("recommendation"), + "category": tpl.get("category"), + "severity": str(data.get("severity", tpl.get("severity", "unknown"))), + } + + # Optional dynamic fields + optional_fields = [ + "confidence", "area", "lat", "lon", + "image_url", "vod", "hls", "meta" + ] + for f in optional_fields: + if f in data and data[f] is not None: + annotations[f] = str(data[f]) + + # ───── Timestamp normalization ───── + def to_utc_iso(s: str | None) -> str | None: + if not s: + return None + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") + + starts_at = to_utc_iso(data.get("started_at")) + ends_at = to_utc_iso(data.get("ended_at")) + + now_utc = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + if starts_at and starts_at > now_utc: + log.warning(f"Start time {starts_at} is in the future, adjusting to now: {now_utc}") + starts_at = now_utc + + + payload = { + "labels": labels, + "annotations": annotations, + "startsAt": starts_at, + } + if ends_at: + payload["endsAt"] = ends_at + + # ───── Send to Alertmanager ───── + self.client.send([payload]) + log.info(f"[ALERT PAYLOAD] {json.dumps(payload, indent=2)}") + + return payload diff --git a/services/alertmanager_service/src/app.py b/services/alertmanager_service/src/app.py new file mode 100644 index 000000000..b9fc3447a --- /dev/null +++ b/services/alertmanager_service/src/app.py @@ -0,0 +1,34 @@ +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import JSONResponse +from alert_service import AlertManagerService +import os, logging + +app = FastAPI(title="AgGuard AlertManager Service", version="1.3") +log = logging.getLogger(__name__) + +# CFG_PATH = os.getenv("CFG_PATH", "templates.yml") +CFG_PATH = os.getenv("CFG_PATH", "/app/templates/templates.yml") + +ALERTMANAGER_URL = os.getenv("ALERTMANAGER_URL", "http://alertmanager:9093") + +service = AlertManagerService(CFG_PATH, ALERTMANAGER_URL) + + +@app.post("/alerts") +async def post_alert(request: Request): + """ + Receive an alert JSON payload and forward it to Alertmanager. + """ + try: + data = await request.json() + result = service.process_alert(data) + return JSONResponse({"status": "sent", "alert": result}) + except Exception as e: + log.exception("Failed to process alert") + raise HTTPException(status_code=400, detail=str(e)) + + +@app.get("/health") +async def health(): + """Simple health check endpoint.""" + return {"status": "ok"} diff --git a/services/alertmanager_service/src/gateway.py b/services/alertmanager_service/src/gateway.py new file mode 100644 index 000000000..872b3704d --- /dev/null +++ b/services/alertmanager_service/src/gateway.py @@ -0,0 +1,57 @@ +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request +from fastapi.responses import JSONResponse +import json, asyncio, logging + + +log = logging.getLogger(__name__) +app = FastAPI(title="AgGuard Alerts Gateway", version="1.0") + +CLIENTS = set() + +@app.websocket("/ws/alerts") +async def ws_alerts(ws: WebSocket): + await ws.accept() + CLIENTS.add(ws) + log.info("Client connected.") + # active = await fetch_active_alerts() + try: + # Send initial snapshot + + while True: + # Wait for pings or keepalive from client + try: + msg = await asyncio.wait_for(ws.receive_text(), timeout=30) + log.debug(f"Received message: {msg}") + except asyncio.TimeoutError: + # No message from client — send ping + await ws.send_json({"type": "ping"}) + continue + except WebSocketDisconnect: + log.info("Client disconnected.") + except Exception as e: + log.exception(f"Error in WebSocket: {e}") + finally: + CLIENTS.discard(ws) + + +@app.post("/internal/alert") +async def internal_alert(request: Request): + """Called by alert_service when a new alert is received.""" + alert = await request.json() + msg = json.dumps({"type": "alert", "data": alert}) + dead = [] + for ws in CLIENTS: + try: + await ws.send_text(msg) + except Exception: + dead.append(ws) + for ws in dead: + CLIENTS.discard(ws) + return {"status": "broadcasted"} + + + +@app.get("/health") +async def health(): + return {"status": "ok"} + diff --git a/services/alerts_forwarder/Dockerfile.flink b/services/alerts_forwarder/Dockerfile.flink new file mode 100644 index 000000000..fbb7026c4 --- /dev/null +++ b/services/alerts_forwarder/Dockerfile.flink @@ -0,0 +1,41 @@ + +FROM flink:1.20.0-scala_2.12-java11 + +USER root + +# Add local CA (place netfree-ca.crt next to this Dockerfile before building) +# COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +# RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && update-ca-certificates + +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +ENV PIP_DISABLE_PIP_VERSION_CHECK=1 + +# Python & tools +RUN apt-get update && apt-get install -y --no-install-recommends python3 python3-venv python3-pip curl ca-certificates && rm -rf /var/lib/apt/lists/* + +# Create venv and install pyflink +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install PyFlink (DataStream API) and requests +RUN pip install --upgrade pip certifi && pip install --prefer-binary apache-flink==2.1.0 requests urllib3 + +# Kafka connector jar matching Flink 1.20 (connector v3 series) +# Reference: flink-connector-kafka-3.x for Flink 1.20 +RUN curl -fSL https://repo1.maven.org/maven2/org/apache/flink/flink-connector-kafka/3.2.0-1.19/flink-connector-kafka-3.2.0-1.19.jar \ + -o /opt/flink/lib/flink-connector-kafka-3.2.0-1.19.jar && \ + curl -fSL https://repo1.maven.org/maven2/org/apache/kafka/kafka-clients/3.7.0/kafka-clients-3.7.0.jar \ + -o /opt/flink/lib/kafka-clients-3.7.0.jar + +RUN mkdir -p /opt/app/secrets && chmod -R 777 /opt/app + +WORKDIR /opt/app +COPY alerts_forwarder.py /opt/app/alerts_forwarder.py + +# Flink Python env vars +ENV PYFLINK_CLIENT_EXECUTABLE=/opt/venv/bin/python PYFLINK_PYTHON=/opt/venv/bin/python PYTHONPATH=/opt/app + +# Default command is provided by docker-compose (jobmanager/taskmanager), but keep a convenient default +CMD ["bash", "-lc", "python alerts_forwarder.py"] + diff --git a/services/alerts_forwarder/alerts_forwarder.py b/services/alerts_forwarder/alerts_forwarder.py new file mode 100644 index 000000000..537aa5815 --- /dev/null +++ b/services/alerts_forwarder/alerts_forwarder.py @@ -0,0 +1,53 @@ +import os, json, requests +from pyflink.common import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.datastream.connectors.kafka import KafkaSource, KafkaOffsetsInitializer +from pyflink.common.serialization import SimpleStringSchema +from pyflink.common.watermark_strategy import WatermarkStrategy + +ALERTMANAGER_SERVICE_URL = "http://alertmanager_service:8090/alerts" +KAFKA_BROKERS = "kafka:9092" +TOPIC = "alerts" + +def send_to_alertmanager(alert_json: str): + try: + data = json.loads(alert_json) + resp = requests.post(ALERTMANAGER_SERVICE_URL, json=data, timeout=5) + if resp.status_code == 200: + print(f"✅ Sent alert {data.get('alert_id')}", flush=True) + else: + print(f"❌ {resp.status_code}: {resp.text}", flush=True) + except Exception as e: + print(f"Failed to send alert: {e}", flush=True) + +def main(): + env = StreamExecutionEnvironment.get_execution_environment() + env.set_parallelism(int(os.getenv("FLINK_PARALLELISM", "1"))) + + print(f"[FLINK] Listening on topic: {TOPIC}", flush=True) + + source = ( + KafkaSource.builder() + .set_bootstrap_servers(KAFKA_BROKERS) + .set_topics(TOPIC) + .set_group_id("flink-alerts-to-alertmanager") + .set_starting_offsets(KafkaOffsetsInitializer.latest()) + .set_value_only_deserializer(SimpleStringSchema()) + .build() + ) + + stream = env.from_source(source, WatermarkStrategy.no_watermarks(), "Kafka Alerts Source") + + # ✅ Must have a terminal operator + stream.map( + lambda raw: (send_to_alertmanager(raw) or True), + output_type=Types.BOOLEAN() + ).print() + + env.execute("Flink Alerts → AlertManager Forwarder") + +if __name__ == "__main__": + main() + + + diff --git a/services/alerts_forwarder/docker-compose.yml b/services/alerts_forwarder/docker-compose.yml new file mode 100644 index 000000000..4f45542b7 --- /dev/null +++ b/services/alerts_forwarder/docker-compose.yml @@ -0,0 +1,19 @@ +services: + flink-alerts-job: + build: + context: . + dockerfile: Dockerfile.flink + container_name: alerts-forwarder + depends_on: + - kafka + - alertmanager_service + environment: + - PYTHONPATH=/opt/app + - KAFKA_BROKERS=kafka:9092 + - ALERTMANAGER_SERVICE_URL=http://alertmanager_service:8090/alerts + command: ["python", "/opt/app/alerts_forwarder.py"] + +networks: + default: + external: true + name: agcloud_ag_cloud diff --git a/services/compression/Dockerfile b/services/compression/Dockerfile new file mode 100644 index 000000000..44deed463 --- /dev/null +++ b/services/compression/Dockerfile @@ -0,0 +1,36 @@ +FROM python:3.12-slim + +# Install ffmpeg, cron, and ca-certificates +RUN apt-get update && \ + apt-get install -y ffmpeg cron ca-certificates dos2unix && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy certificates +COPY certs /app/certs + +# Install certificates +RUN cp /app/certs/*.crt /usr/local/share/ca-certificates/ && \ + update-ca-certificates + +# Copy requirements and install +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy source code +COPY src/ ./src/ + +# Create logs directory +RUN mkdir -p /app/src/logs + +# Copy and fix both scripts +COPY docker-entrypoint.sh /docker-entrypoint.sh +RUN dos2unix /docker-entrypoint.sh && \ + chmod +x /docker-entrypoint.sh && \ + dos2unix /app/src/run_tiering.sh && \ + chmod +x /app/src/run_tiering.sh + +WORKDIR /app/src + +ENTRYPOINT ["/docker-entrypoint.sh"] diff --git a/services/sounds/compression/README.md b/services/compression/README.md similarity index 100% rename from services/sounds/compression/README.md rename to services/compression/README.md diff --git a/services/sounds/compression/compressed/cat.flac b/services/compression/compressed/cat.flac similarity index 100% rename from services/sounds/compression/compressed/cat.flac rename to services/compression/compressed/cat.flac diff --git a/services/compression/docker-entrypoint.sh b/services/compression/docker-entrypoint.sh new file mode 100644 index 000000000..997109f08 --- /dev/null +++ b/services/compression/docker-entrypoint.sh @@ -0,0 +1,64 @@ +#!/bin/bash +set -e + +echo "Setting up cron job..." + +# Write environment variables with export prefix +printenv | grep -E '^(RAW_MAX_AGE_DAYS|COMPRESSION_CODEC|COMPRESSED_MAX_AGE_DAYS|MINIO_ENDPOINT|ACCESS_KEY|SECRET_KEY|BUCKET_NAME)=' | sed 's/^/export /' > /app/cron.env + +# Make scripts executable +chmod +x /app/src/run_tiering.sh + +# Create crontab +cat > /tmp/crontab.txt << 'EOF' +SHELL=/bin/bash +PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/sbin:/usr/sbin:/sbin + +# Run every 2 minutes +*/2 * * * * . /app/cron.env && /app/src/run_tiering.sh >> /app/src/logs/cron.log 2>&1 + +# Debug: Log cron is alive every 10 minutes +*/10 * * * * echo "[$(date)] Cron is alive" >> /app/src/logs/cron.log +EOF + +# Install crontab +crontab /tmp/crontab.txt + +echo "===================================" +echo "Audio Compression Service Started" +echo "===================================" +echo "Cron schedule:" +crontab -l +echo "===================================" +echo "Environment variables saved to /app/cron.env:" +cat /app/cron.env +echo "===================================" + +# Create logs directory +mkdir -p /app/src/logs + +# Initial log entry +echo "[$(date)] Cron daemon starting..." >> /app/src/logs/cron.log + +echo "Waiting for MinIO to be ready..." +sleep 10 + +echo "===================================" +echo "Running initial test..." +echo "===================================" + +# Run initial test (this will show if there are any immediate errors) +if /app/src/run_tiering.sh >> /app/src/logs/cron.log 2>&1; then + echo "✓ Initial test completed successfully" +else + echo "✗ Initial test failed - check /app/src/logs/cron.log" +fi + +echo "===================================" +echo "Service is now running." +echo "Compression will run every 2 minutes." +echo "Check logs: docker exec audio_compression tail -f /app/src/logs/cron.log" +echo "===================================" + +# Start cron in foreground +exec cron -f \ No newline at end of file diff --git a/services/sounds/compression/pytest.ini b/services/compression/pytest.ini similarity index 100% rename from services/sounds/compression/pytest.ini rename to services/compression/pytest.ini diff --git a/services/compression/requirements.txt b/services/compression/requirements.txt new file mode 100644 index 000000000..68cfe7aa3 --- /dev/null +++ b/services/compression/requirements.txt @@ -0,0 +1,4 @@ +minio==7.2.18 +# ffmpeg-python==4.4 +# argparse==1.4.0 +# statistics==1.0.3.5 # This module is built-in, but can be added for compatibility diff --git a/services/sounds/sounds_classifier/src/classification/scripts/__init__.py b/services/compression/src/__init__.py similarity index 100% rename from services/sounds/sounds_classifier/src/classification/scripts/__init__.py rename to services/compression/src/__init__.py diff --git a/services/compression/src/minio_client.py b/services/compression/src/minio_client.py new file mode 100644 index 000000000..6e970881b --- /dev/null +++ b/services/compression/src/minio_client.py @@ -0,0 +1,17 @@ +from minio import Minio +import os + +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9001") +ACCESS_KEY = os.getenv("ACCESS_KEY", "minioadmin") +SECRET_KEY = os.getenv("SECRET_KEY", "minioadmin123") +BUCKET_NAME = os.getenv("BUCKET_NAME", "telemetry") + +client = Minio( + MINIO_ENDPOINT, + access_key=ACCESS_KEY, + secret_key=SECRET_KEY, + secure=False, +) + +if not client.bucket_exists(BUCKET_NAME): + client.make_bucket(BUCKET_NAME) diff --git a/services/compression/src/prototype_lib.py b/services/compression/src/prototype_lib.py new file mode 100644 index 000000000..d2ef512af --- /dev/null +++ b/services/compression/src/prototype_lib.py @@ -0,0 +1,157 @@ +from pathlib import Path, PurePosixPath +import subprocess +import tempfile +import time +from datetime import datetime +import re +from minio_client import client, BUCKET_NAME + +# Supported audio formats for compression +AUDIO_EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} + +RAW_PREFIX = "sound/" + +def is_audio_file(filename: str) -> bool: + """Check if file is an audio file that should be compressed.""" + if filename.lower().endswith((".flac", ".opus")): + return False # If it's already compressed, don't compress again + return any(filename.lower().endswith(ext) for ext in AUDIO_EXTS) + +def iter_audio_files(): + """Yield MinIO object names in RAW_PREFIX that are audio files.""" + for obj in client.list_objects(BUCKET_NAME, prefix=RAW_PREFIX, recursive=True): + if is_audio_file(obj.object_name): + yield obj.object_name + +def parse_timestamp_from_filename(filename: str) -> datetime: + """ + Extract timestamp from filename pattern: sensor-id_timestamp.ext + """ + # Pattern: anything_YYYYMMDDtHHMMSSz.ext + # Case-insensitive for 't' and 'z' + pattern = r'_(\d{8})[tT](\d{6})[zZ]\.' + + match = re.search(pattern, filename) + if not match: + print(f"[WARN] Cannot parse timestamp from filename: {filename}") + return None + + date_part = match.group(1) # YYYYMMDD + time_part = match.group(2) # HHMMSS + + try: + # Parse: 20240901 120000 + dt = datetime.strptime(f"{date_part}{time_part}", "%Y%m%d%H%M%S") + # Assume UTC (because of 'z' suffix) + dt = dt.replace(tzinfo=None) + return dt + except ValueError as e: + print(f"[WARN] Invalid timestamp in filename {filename}: {e}") + return None + +def get_file_age_seconds(obj_name: str) -> float: + """ + Get age of file in seconds based on timestamp in filename. + + Returns: + Age in seconds, or 0 if timestamp cannot be parsed + """ + dt = parse_timestamp_from_filename(obj_name) + if dt is None: + return 0 + + now = datetime.utcnow() + age = now - dt + return age.total_seconds() + +def is_older_than(obj_name: str, age_seconds: int) -> bool: + """Check if file is older than specified age based on filename timestamp.""" + return get_file_age_seconds(obj_name) >= age_seconds + +def build_ffmpeg_cmds(in_local_path: Path, codec="all", flac_level="5", opus_bitrate="96k"): + """ + Return ffmpeg commands to encode a local audio file. + Output will be a temporary file (to upload after encode). + """ + cmds = [] + temp_dir = Path(tempfile.gettempdir()) + + if codec in ("flac", "all"): + flac_out = temp_dir / f"{in_local_path.stem}.flac" + flac_cmd = [ + "ffmpeg", "-y", "-hide_banner", "-loglevel", "error", + "-i", str(in_local_path), + "-c:a", "flac", "-compression_level", flac_level, + str(flac_out) + ] + cmds.append(("flac", flac_cmd, flac_out)) + + if codec in ("opus", "all"): + opus_out = temp_dir / f"{in_local_path.stem}.opus" + opus_cmd = [ + "ffmpeg", "-y", "-hide_banner", "-loglevel", "error", + "-i", str(in_local_path), + "-c:a", "libopus", "-b:a", opus_bitrate, + str(opus_out) + ] + cmds.append(("opus", opus_cmd, opus_out)) + + return cmds + +def download_raw_to_temp(obj_name: str) -> Path: + """Download MinIO raw object to temporary file.""" + local_path = Path(tempfile.gettempdir()) / Path(obj_name).name + client.fget_object(BUCKET_NAME, obj_name, str(local_path)) + return local_path + +def replace_with_compressed(original_obj_name: str, compressed_local_path: Path): + """ + Replace the original file in MinIO with the compressed version. + Keeps the same path, only changes the extension. + + Example: + sound/drone-01_20251102t010618z.wav -> sound/drone-01_20251102t010618z.opus + """ + # Use PurePosixPath to always get forward slashes for MinIO paths + from pathlib import PurePosixPath + + # Parse original path (use PurePosixPath for consistency) + obj_path = PurePosixPath(original_obj_name) + stem = obj_path.stem # e.g., "drone-01_20251102t010618z" + parent = obj_path.parent # e.g., "sound" + + # New compressed extension + compressed_ext = compressed_local_path.suffix # e.g., ".opus" + new_obj_name = str(parent / f"{stem}{compressed_ext}") + + # Upload compressed file to the new path + client.fput_object(BUCKET_NAME, new_obj_name, str(compressed_local_path)) + + # Delete original file + client.remove_object(BUCKET_NAME, original_obj_name) + + return new_obj_name + +def delete_object(obj_name: str): + """Delete an object from MinIO.""" + client.remove_object(BUCKET_NAME, obj_name) + +def get_compressed_variants(obj_name: str) -> list: + """ + Given an object name, return possible compressed variants. + + Example: + sound/drone-01_20251102t010618z.wav -> + [ + "sound/drone-01_20251102t010618z.opus", + "sound/drone-01_20251102t010618z.flac" + ] + """ + obj_path = PurePosixPath(obj_name) + stem = obj_path.stem + parent = obj_path.parent + + return [ + str(parent / f"{stem}.opus"), + str(parent / f"{stem}.flac") + ] \ No newline at end of file diff --git a/services/compression/src/run_bench.py b/services/compression/src/run_bench.py new file mode 100644 index 000000000..19d39568d --- /dev/null +++ b/services/compression/src/run_bench.py @@ -0,0 +1,134 @@ +from pathlib import Path +import time +import csv +from statistics import mean +import subprocess +from prototype_lib import ( + iter_audio_files, + build_ffmpeg_cmds, + download_raw_to_temp, + replace_with_compressed, + parse_timestamp_from_filename, + get_file_age_seconds +) +from minio_client import BUCKET_NAME, client + +RES_DIR = Path("results") +RES_DIR.mkdir(exist_ok=True) + +def file_size_minio(obj_name: str) -> int: + """Return object size in bytes.""" + try: + stat = client.stat_object(BUCKET_NAME, obj_name) + return stat.size + except: + return 0 + +def run_and_profile(cmd): + """Run command and profile CPU usage.""" + import psutil + start = time.time() + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + parent = psutil.Process(proc.pid) + samples = [] + while proc.poll() is None: + cpu_total = 0.0 + for pr in [parent] + parent.children(recursive=True): + try: + cpu_total += pr.cpu_percent(interval=0.1) + except psutil.NoSuchProcess: + continue + samples.append(cpu_total) + out, err = proc.communicate() + wall = time.time() - start + avg_cpu = mean(samples) if samples else 0.0 + return proc.returncode, wall, avg_cpu, (out or b"") + (err or b"") + +def main(): + rows = [] + files = list(iter_audio_files()) + if not files: + print("No audio files found in MinIO") + return + + print("=" * 70) + print("AUDIO COMPRESSION BENCHMARK") + print("=" * 70) + print(f"Bucket: {BUCKET_NAME}") + print(f"Files to process: {len(files)}") + print("=" * 70) + + for obj_name in files: + # Parse timestamp from filename + dt = parse_timestamp_from_filename(obj_name) + age_seconds = get_file_age_seconds(obj_name) + + print(f"\n[PROCESSING] {obj_name}") + if dt: + print(f" Timestamp: {dt.strftime('%Y-%m-%d %H:%M:%S')} UTC") + print(f" Age: {age_seconds/86400:.1f} days") + + # Download original file + local_file = download_raw_to_temp(obj_name) + orig_size = local_file.stat().st_size + + # Test each codec + for codec, cmd, outp in build_ffmpeg_cmds(local_file): + rc, wall, cpu, _ = run_and_profile(cmd) + if rc != 0: + print(f" [FAIL] {codec.upper()} encoding failed") + continue + + # Replace original with compressed version (same path, different extension) + try: + new_obj_name = replace_with_compressed(obj_name, outp) + enc_size = file_size_minio(new_obj_name) + ratio = (orig_size / enc_size) if enc_size else 0.0 + + print(f" [OK] {codec.upper()}: {new_obj_name}") + print(f" Size: {enc_size:,} bytes ({enc_size/(1024**2):.2f} MB)") + print(f" Ratio: {ratio:.2f}x") + print(f" Time: {wall:.2f}s, CPU: {cpu:.1f}%") + + rows.append({ + "file": Path(obj_name).name, + "codec": codec, + "orig_bytes": orig_size, + "encoded_bytes": enc_size, + "compression_ratio_orig_over_encoded": round(ratio, 3), + "encode_time_sec": round(wall, 3), + "encode_cpu_avg_percent": round(cpu, 1), + "timestamp": dt.isoformat() if dt else "unknown", + "age_days": round(age_seconds / 86400, 1) if age_seconds > 0 else 0, + }) + + # Clean up local encoded file + outp.unlink() + + except Exception as e: + print(f" [FAIL] {codec.upper()}: {e}") + outp.unlink(missing_ok=True) + + # Clean up local original file + local_file.unlink() + + # Save results + if rows: + out_csv = RES_DIR / "benchmarks.csv" + with open(out_csv, "w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"Total files benchmarked: {len(files)}") + print(f"Total tests: {len(rows)}") + print(f"Results saved: {out_csv}") + print("=" * 70) + else: + print("\n[WARN] No successful encodings to save") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/services/compression/src/run_tiering.sh b/services/compression/src/run_tiering.sh new file mode 100644 index 000000000..2ea1fc7ab --- /dev/null +++ b/services/compression/src/run_tiering.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +SCRIPT_DIR="/app/src" +LOG_DIR="$SCRIPT_DIR/logs" +LOG_FILE="$LOG_DIR/tiering_$(date +%Y%m%d_%H%M%S).log" + +mkdir -p "$LOG_DIR" +cd "$SCRIPT_DIR" + +# In Docker, python3 is in PATH, no venv needed +PYTHON="python3" + +# ffmpeg is already in PATH from Dockerfile +export PATH="/usr/bin:/usr/local/bin:$PATH" + +echo "=== Starting at $(date) ===" >> "$LOG_FILE" +"$PYTHON" tiering_job.py \ + --raw-max-age-days ${RAW_MAX_AGE_DAYS:-30} \ + --codec ${COMPRESSION_CODEC:-opus} \ + --compressed-max-age-days ${COMPRESSED_MAX_AGE_DAYS:-90} \ + >> "$LOG_FILE" 2>&1 +echo "=== Finished at $(date) ===" >> "$LOG_FILE" + +# Clean old logs (older than 7 days) +find "$LOG_DIR" -name "tiering_*.log" -mtime +7 -delete diff --git a/services/compression/src/tiering_job.py b/services/compression/src/tiering_job.py new file mode 100644 index 000000000..ef42665bf --- /dev/null +++ b/services/compression/src/tiering_job.py @@ -0,0 +1,239 @@ +from pathlib import Path +import time +import argparse +import subprocess +from prototype_lib import ( + iter_audio_files, + build_ffmpeg_cmds, + download_raw_to_temp, + replace_with_compressed, + delete_object, + get_file_age_seconds, + is_older_than, + RAW_PREFIX # Import the prefix to ensure consistency +) +from minio_client import client, BUCKET_NAME + +DEFAULT_RAW_MAX_AGE_DAYS = 30 +DEFAULT_COMP_MAX_AGE_DAYS = 90 +DEFAULT_LONG_TERM_CODEC = "opus" + +def encode_and_replace(obj_name: str, codec: str) -> str: + """ + Download audio file, encode it, and replace the original in MinIO. + + Returns: + The new object name (with compressed extension) + """ + # Skip already compressed files + if obj_name.lower().endswith((".flac", ".opus")): + print(f"[SKIP] {obj_name} - Already compressed, skipping.") + return obj_name + + # Download original + local_file = download_raw_to_temp(obj_name) + + # Build encode commands + encode_cmds = build_ffmpeg_cmds(local_file, codec=codec) + + if not encode_cmds: + local_file.unlink() + raise RuntimeError(f"No encode commands generated for {obj_name}") + + # Take the first codec result + codec_name, cmd, output_path = encode_cmds[0] + + # Run encoding + print(f"[ENC] Encoding {obj_name} -> {codec_name.upper()}...") + rc = subprocess.call(cmd) + + if rc != 0: + local_file.unlink() + output_path.unlink(missing_ok=True) + raise RuntimeError(f"Encode failed: {obj_name} -> {codec}") + + # Replace in MinIO (same path, different extension) + new_obj_name = replace_with_compressed(obj_name, output_path) + + # Cleanup local files + output_path.unlink() + local_file.unlink() + + return new_obj_name + +def cleanup_compressed(max_age_days: int, dry_run: bool) -> int: + """ + Delete very old compressed files that exceeded retention period. + Uses timestamp from filename (same logic as compression). + """ + if max_age_days <= 0: + print("[INFO] Compressed cleanup disabled (max_age_days <= 0)") + return 0 + + cutoff_sec = max_age_days * 86400 + deleted = 0 + + print(f"\n[CLEANUP] Checking for compressed files older than {max_age_days} days...") + print(f"[CLEANUP] Looking in prefix: {RAW_PREFIX}") + print(f"[CLEANUP] Using timestamp from filename") + + # Only check files in the RAW_PREFIX (sound/) directory + for obj in client.list_objects(BUCKET_NAME, prefix=RAW_PREFIX, recursive=True): + # Only consider compressed files + if not obj.object_name.lower().endswith(('.opus', '.flac')): + continue + + # Use timestamp from filename (consistent with compression logic) + file_age_sec = get_file_age_seconds(obj.object_name) + + if file_age_sec == 0: + print(f"[SKIP] {obj.object_name} - Cannot parse timestamp from filename") + continue + + file_age_days = file_age_sec / 86400 + + print(f"[CHECK] {obj.object_name}: {file_age_days:.1f} days old (from filename)") + + if file_age_sec >= cutoff_sec: + if dry_run: + print(f"[DRY] Would delete: {obj.object_name} (age={file_age_days:.1f} days)") + deleted += 1 + else: + try: + delete_object(obj.object_name) + deleted += 1 + print(f"[DEL] ✓ Deleted: {obj.object_name} (age={file_age_days:.1f} days)") + except Exception as e: + print(f"[ERROR] Failed to delete {obj.object_name}: {e}") + else: + remaining_days = max_age_days - file_age_days + print(f"[KEEP] {obj.object_name} - will be deleted in {remaining_days:.1f} days") + + return deleted + +def main(): + ap = argparse.ArgumentParser( + description="Two-tier audio compression job - compresses files based on filename timestamp" + ) + ap.add_argument( + "--raw-max-age-days", + type=int, + default=DEFAULT_RAW_MAX_AGE_DAYS, + help=f"Age threshold in days for audio files to be compressed (default: {DEFAULT_RAW_MAX_AGE_DAYS})" + ) + ap.add_argument( + "--codec", + choices=["opus", "flac"], + default=DEFAULT_LONG_TERM_CODEC, + help="Compression codec to use" + ) + ap.add_argument( + "--compressed-max-age-days", + type=int, + default=DEFAULT_COMP_MAX_AGE_DAYS, + help="Delete compressed files older than this many days (0 to disable)" + ) + ap.add_argument( + "--dry-run", + action="store_true", + help="Simulate operations without making changes" + ) + args = ap.parse_args() + + # Calculate age threshold + raw_age_seconds = args.raw_max_age_days * 86400 + age_desc = f"{args.raw_max_age_days} days" + + print("=" * 70) + print("AUDIO COMPRESSION & TIERING JOB") + print("=" * 70) + print(f"Bucket: {BUCKET_NAME}") + print(f"Age threshold: {age_desc} (based on filename timestamp)") + print(f"Codec: {args.codec.upper()}") + print(f"Compressed retention: {args.compressed_max_age_days} days") + print(f"Mode: {'DRY RUN' if args.dry_run else 'LIVE'}") + print("=" * 70) + + processed = 0 + skipped = 0 + errors = 0 + total_orig_size = 0 + total_comp_size = 0 + + # Process audio files only + for obj_name in iter_audio_files(): + + # Check age based on filename timestamp + age = get_file_age_seconds(obj_name) + + if age == 0: + print(f"[SKIP] {obj_name} - Cannot parse timestamp from filename") + skipped += 1 + continue + + if not is_older_than(obj_name, raw_age_seconds): + skipped += 1 + continue + + age_days = age / 86400 + + if args.dry_run: + print(f"[DRY] Would compress: {obj_name} (age={age_days:.1f} days) -> {args.codec.upper()}") + processed += 1 + continue + + # Get original size + try: + orig_stat = client.stat_object(BUCKET_NAME, obj_name) + orig_size = orig_stat.size + total_orig_size += orig_size + except: + orig_size = 0 + + try: + new_obj_name = encode_and_replace(obj_name, args.codec) + + # Get compressed size + try: + comp_stat = client.stat_object(BUCKET_NAME, new_obj_name) + comp_size = comp_stat.size + total_comp_size += comp_size + saved = orig_size - comp_size + ratio = orig_size / comp_size if comp_size > 0 else 0 + + print(f"[OK] Compressed: {obj_name} -> {new_obj_name}") + print(f" Age: {age_days:.1f} days") + print(f" Original: {orig_size:,} bytes ({orig_size/(1024**2):.2f} MB)") + print(f" Compressed: {comp_size:,} bytes ({comp_size/(1024**2):.2f} MB)") + print(f" Ratio: {ratio:.2f}x, Saved: {saved:,} bytes ({saved/(1024**2):.2f} MB)") + except: + print(f"[OK] Compressed: {obj_name} -> {new_obj_name}") + + processed += 1 + + except Exception as e: + errors += 1 + print(f"[FAIL] {obj_name}: {e}") + + # Cleanup very old compressed files + comp_deleted = cleanup_compressed(args.compressed_max_age_days, args.dry_run) + + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"Audio files compressed: {processed}") + print(f"Files skipped (too new): {skipped}") + print(f"Errors: {errors}") + print(f"Old compressed deleted: {comp_deleted}") + if total_orig_size > 0 and total_comp_size > 0: + total_saved = total_orig_size - total_comp_size + total_ratio = total_orig_size / total_comp_size + print(f"Total original size: {total_orig_size:,} bytes ({total_orig_size/(1024**2):.2f} MB)") + print(f"Total compressed size: {total_comp_size:,} bytes ({total_comp_size/(1024**2):.2f} MB)") + print(f"Total saved: {total_saved:,} bytes ({total_saved/(1024**2):.2f} MB)") + print(f"Overall compression ratio: {total_ratio:.2f}x") + print("=" * 70) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/services/db_api_service/certs/.keep b/services/compression/tests/__init__.py similarity index 100% rename from services/db_api_service/certs/.keep rename to services/compression/tests/__init__.py diff --git a/services/compression/tests/test_prototype_lib.py b/services/compression/tests/test_prototype_lib.py new file mode 100644 index 000000000..5ecaa127d --- /dev/null +++ b/services/compression/tests/test_prototype_lib.py @@ -0,0 +1,435 @@ +""" +Tests for prototype_lib.py - Audio compression library +Updated to match the new implementation with same-path compression +""" + +from pathlib import Path +import pytest +import tempfile +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock, Mock +from prototype_lib import ( + is_audio_file, + iter_audio_files, + parse_timestamp_from_filename, + get_file_age_seconds, + is_older_than, + build_ffmpeg_cmds, + get_compressed_variants, + find_file_with_fallback, + AUDIO_EXTS, + RAW_PREFIX +) + + +class TestIsAudioFile: + """Tests for is_audio_file function""" + + def test_valid_audio_extensions(self): + """Test that valid audio files are recognized""" + valid_files = [ + "audio.wav", "music.mp3", "sound.flac", "voice.opus", + "song.ogg", "track.m4a", "audio.aac", "sound.wma" + ] + + for filename in valid_files: + assert is_audio_file(filename), f"{filename} should be recognized as audio" + + def test_case_insensitive(self): + """Test case insensitivity""" + assert is_audio_file("AUDIO.WAV") + assert is_audio_file("Music.MP3") + assert is_audio_file("SoUnD.fLaC") + + def test_invalid_extensions(self): + """Test that non-audio files are rejected""" + invalid_files = [ + "doc.txt", "image.jpg", "video.mp4", "data.csv", "script.py" + ] + + for filename in invalid_files: + assert not is_audio_file(filename), f"{filename} should not be recognized as audio" + + def test_audio_exts_constant(self): + """Test that AUDIO_EXTS contains expected formats""" + expected = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} + assert AUDIO_EXTS == expected + + +class TestIterAudioFiles: + """Tests for iter_audio_files function""" + + @patch('prototype_lib.client') + def test_iter_audio_files_filters_correctly(self, mock_client): + """Test that only audio files are returned""" + # Mock MinIO objects + mock_objects = [ + Mock(object_name=f"{RAW_PREFIX}drone-01_20251102t010618z.wav"), + Mock(object_name=f"{RAW_PREFIX}sensor-02_20251102t020000z.mp3"), + Mock(object_name=f"{RAW_PREFIX}data.txt"), # Should be filtered out + Mock(object_name=f"{RAW_PREFIX}image.jpg"), # Should be filtered out + ] + mock_client.list_objects.return_value = mock_objects + + files = list(iter_audio_files()) + + assert len(files) == 2 + assert all(is_audio_file(f) for f in files) + mock_client.list_objects.assert_called_once() + + @patch('prototype_lib.client') + def test_iter_audio_files_empty_bucket(self, mock_client): + """Test with empty bucket""" + mock_client.list_objects.return_value = [] + + files = list(iter_audio_files()) + + assert len(files) == 0 + + +class TestParseTimestampFromFilename: + """Tests for parse_timestamp_from_filename function""" + + def test_valid_timestamp_lowercase(self): + """Test parsing valid timestamp (lowercase t and z)""" + filename = "drone-01_20251102t010618z.wav" + dt = parse_timestamp_from_filename(filename) + + assert dt is not None + assert dt.year == 2025 + assert dt.month == 11 + assert dt.day == 2 + assert dt.hour == 1 + assert dt.minute == 6 + assert dt.second == 18 + + def test_valid_timestamp_uppercase(self): + """Test parsing valid timestamp (uppercase T and Z)""" + filename = "sensor-06_20251028T080000Z.opus" + dt = parse_timestamp_from_filename(filename) + + assert dt is not None + assert dt.year == 2025 + assert dt.month == 10 + assert dt.day == 28 + assert dt.hour == 8 + assert dt.minute == 0 + assert dt.second == 0 + + def test_valid_timestamp_mixed_case(self): + """Test parsing with mixed case""" + filename = "device_20240101T120000z.mp3" + dt = parse_timestamp_from_filename(filename) + + assert dt is not None + assert dt.year == 2024 + + def test_invalid_timestamp_format(self): + """Test with invalid timestamp format""" + invalid_filenames = [ + "audio_without_timestamp.wav", + "audio_20251102.wav", # Missing time + "audio_2025-11-02t01:06:18z.wav", # Wrong format (hyphens and colons) + ] + + for filename in invalid_filenames: + dt = parse_timestamp_from_filename(filename) + assert dt is None, f"Should return None for {filename}" + + def test_invalid_date_values(self): + """Test with invalid date values""" + # Invalid month (13) + filename = "audio_20251302t010618z.wav" + dt = parse_timestamp_from_filename(filename) + assert dt is None + + def test_full_path(self): + """Test parsing from full MinIO path""" + full_path = "sound/drone-01_20251102t010618z.wav" + dt = parse_timestamp_from_filename(full_path) + + assert dt is not None + assert dt.year == 2025 + + +class TestGetFileAgeSeconds: + """Tests for get_file_age_seconds function""" + + @patch('prototype_lib.datetime') + def test_get_file_age_recent_file(self, mock_datetime_module): + """Test age calculation for recent file""" + # Mock current time + now = datetime(2025, 11, 2, 12, 0, 0) + mock_datetime_module.utcnow.return_value = now + mock_datetime_module.strptime = datetime.strptime + + # File timestamp: 1 hour ago + filename = "audio_20251102t110000z.wav" + + age = get_file_age_seconds(filename) + + assert age == 3600 # 1 hour in seconds + + @patch('prototype_lib.datetime') + def test_get_file_age_old_file(self, mock_datetime_module): + """Test age calculation for old file""" + now = datetime(2025, 11, 2, 12, 0, 0) + mock_datetime_module.utcnow.return_value = now + mock_datetime_module.strptime = datetime.strptime + + # File timestamp: 30 days ago + filename = "audio_20251003t120000z.wav" + + age = get_file_age_seconds(filename) + + expected_age = 30 * 86400 # 30 days in seconds + assert abs(age - expected_age) < 60 # Allow 1 minute tolerance + + def test_get_file_age_no_timestamp(self): + """Test with file that has no parseable timestamp""" + filename = "audio_no_timestamp.wav" + + age = get_file_age_seconds(filename) + + assert age == 0 + + +class TestIsOlderThan: + """Tests for is_older_than function""" + + @patch('prototype_lib.get_file_age_seconds') + def test_is_older_than_true(self, mock_get_age): + """Test when file is older than threshold""" + mock_get_age.return_value = 7200 # 2 hours + + result = is_older_than("audio.wav", 3600) # 1 hour threshold + + assert result is True + + @patch('prototype_lib.get_file_age_seconds') + def test_is_older_than_false(self, mock_get_age): + """Test when file is younger than threshold""" + mock_get_age.return_value = 1800 # 30 minutes + + result = is_older_than("audio.wav", 3600) # 1 hour threshold + + assert result is False + + @patch('prototype_lib.get_file_age_seconds') + def test_is_older_than_equal(self, mock_get_age): + """Test when file age equals threshold""" + mock_get_age.return_value = 3600 + + result = is_older_than("audio.wav", 3600) + + assert result is True # >= should return True + + +class TestBuildFfmpegCmds: + """Tests for build_ffmpeg_cmds function""" + + def test_build_commands_all_codecs(self): + """Test building commands for all codecs""" + input_path = Path("test_audio.wav") + + cmds = build_ffmpeg_cmds(input_path, codec="all") + + assert len(cmds) == 2 + codec_names = [cmd[0] for cmd in cmds] + assert "flac" in codec_names + assert "opus" in codec_names + + def test_build_commands_flac_only(self): + """Test building commands for FLAC only""" + input_path = Path("test_audio.wav") + + cmds = build_ffmpeg_cmds(input_path, codec="flac") + + assert len(cmds) == 1 + assert cmds[0][0] == "flac" + assert "-c:a" in cmds[0][1] + assert "flac" in cmds[0][1] + + def test_build_commands_opus_only(self): + """Test building commands for Opus only""" + input_path = Path("test_audio.wav") + + cmds = build_ffmpeg_cmds(input_path, codec="opus") + + assert len(cmds) == 1 + assert cmds[0][0] == "opus" + assert "-c:a" in cmds[0][1] + assert "libopus" in cmds[0][1] + + def test_build_commands_custom_parameters(self): + """Test custom compression parameters""" + input_path = Path("test.wav") + + cmds = build_ffmpeg_cmds(input_path, codec="all", + flac_level="8", opus_bitrate="128k") + + flac_cmd = [c for c in cmds if c[0] == "flac"][0] + opus_cmd = [c for c in cmds if c[0] == "opus"][0] + + assert "8" in flac_cmd[1] + assert "128k" in opus_cmd[1] + + def test_build_commands_output_extensions(self): + """Test that output files have correct extensions""" + input_path = Path("audio.wav") + + cmds = build_ffmpeg_cmds(input_path, codec="all") + + for codec, _, output_path in cmds: + if codec == "flac": + assert output_path.suffix == ".flac" + elif codec == "opus": + assert output_path.suffix == ".opus" + + def test_build_commands_preserves_stem(self): + """Test that file stem is preserved""" + input_path = Path("my_audio_file.wav") + + cmds = build_ffmpeg_cmds(input_path) + + for _, _, output_path in cmds: + assert output_path.stem == "my_audio_file" + + +class TestGetCompressedVariants: + """Tests for get_compressed_variants function""" + + def test_get_variants_basic(self): + """Test getting compressed variants""" + obj_name = "sound/drone-01_20251102t010618z.wav" + + variants = get_compressed_variants(obj_name) + + assert len(variants) == 2 + assert "sound/drone-01_20251102t010618z.opus" in variants + assert "sound/drone-01_20251102t010618z.flac" in variants + + def test_get_variants_already_compressed(self): + """Test variants for already compressed file""" + obj_name = "sound/audio.opus" + + variants = get_compressed_variants(obj_name) + + assert "sound/audio.opus" in variants + assert "sound/audio.flac" in variants + + def test_get_variants_nested_path(self): + """Test with nested directory structure""" + obj_name = "sound/folder1/folder2/audio.mp3" + + variants = get_compressed_variants(obj_name) + + assert all("folder1/folder2" in v for v in variants) + + +class TestFindFileWithFallback: + """Tests for find_file_with_fallback function""" + + @patch('prototype_lib.client') + def test_find_original_exists(self, mock_client): + """Test when original file exists""" + obj_name = "sound/audio.wav" + mock_client.stat_object.return_value = Mock() + + found, exists = find_file_with_fallback(obj_name) + + assert exists is True + assert found == obj_name + mock_client.stat_object.assert_called_once() + + @patch('prototype_lib.client') + def test_find_compressed_variant(self, mock_client): + """Test fallback to compressed variant""" + obj_name = "sound/audio.wav" + + # First call (original) fails, second call (opus variant) succeeds + mock_client.stat_object.side_effect = [ + Exception("Not found"), # Original doesn't exist + Mock(), # Opus variant exists + ] + + found, exists = find_file_with_fallback(obj_name) + + assert exists is True + assert found == "sound/audio.opus" + assert mock_client.stat_object.call_count == 2 + + @patch('prototype_lib.client') + def test_find_no_variants_exist(self, mock_client): + """Test when no variants exist""" + obj_name = "sound/audio.wav" + mock_client.stat_object.side_effect = Exception("Not found") + + found, exists = find_file_with_fallback(obj_name) + + assert exists is False + assert found == obj_name + assert mock_client.stat_object.call_count == 3 # Original + 2 variants + + +class TestEdgeCases: + """Test edge cases and error handling""" + + def test_parse_timestamp_edge_dates(self): + """Test with edge case dates""" + # New Year + dt = parse_timestamp_from_filename("audio_20250101t000000z.wav") + assert dt.month == 1 and dt.day == 1 + + # End of year + dt = parse_timestamp_from_filename("audio_20251231t235959z.wav") + assert dt.month == 12 and dt.day == 31 + + def test_build_ffmpeg_special_characters(self): + """Test with special characters in filename""" + special_names = [ + Path("file with spaces.wav"), + Path("file-with-dashes.mp3"), + Path("file_with_underscores.flac"), + ] + + for input_path in special_names: + cmds = build_ffmpeg_cmds(input_path) + assert len(cmds) > 0 + + def test_audio_extensions_lowercase(self): + """Ensure all extensions in AUDIO_EXTS are lowercase""" + for ext in AUDIO_EXTS: + assert ext.islower(), f"Extension {ext} should be lowercase" + assert ext.startswith("."), f"Extension {ext} should start with dot" + + +class TestIntegration: + """Integration tests""" + + @patch('prototype_lib.datetime') + @patch('prototype_lib.client') + def test_full_workflow_simulation(self, mock_client, mock_datetime_module): + """Simulate full workflow: list -> filter -> check age -> find""" + # Setup + now = datetime(2025, 11, 2, 12, 0, 0) + mock_datetime_module.utcnow.return_value = now + mock_datetime_module.strptime = datetime.strptime + + # Mock MinIO files (one old, one new) + mock_objects = [ + Mock(object_name="sound/old_20251001t120000z.wav"), # 32 days old + Mock(object_name="sound/new_20251102t100000z.wav"), # 2 hours old + ] + mock_client.list_objects.return_value = mock_objects + + # Get all audio files + files = list(iter_audio_files()) + assert len(files) == 2 + + # Check which are old (30+ days) + threshold = 30 * 86400 + old_files = [f for f in files if is_older_than(f, threshold)] + + assert len(old_files) == 1 + assert "old_" in old_files[0] diff --git a/services/compression/tests/test_run_bench.py b/services/compression/tests/test_run_bench.py new file mode 100644 index 000000000..3697749ac --- /dev/null +++ b/services/compression/tests/test_run_bench.py @@ -0,0 +1,199 @@ +from pathlib import Path +from run_bench import run_and_profile, file_size_minio, main +from unittest.mock import patch, Mock, MagicMock, call +import subprocess +import csv +import pytest + +# Test the `run_and_profile` function +@patch('psutil.Process') +@patch('psutil.NoSuchProcess', new=Exception) +@patch('run_bench.subprocess.Popen') +def test_run_and_profile(mock_popen, mock_process_class): + """Test run_and_profile function with mocked subprocess and psutil""" + # Setup mocks + mock_proc = Mock() + mock_proc.pid = 12345 + mock_proc.poll.side_effect = [None, None, 0] # Simulate process running then finishing + mock_proc.communicate.return_value = (b"ffmpeg version", b"") + mock_proc.returncode = 0 + mock_popen.return_value = mock_proc + + # Mock psutil Process + mock_parent = Mock() + mock_parent.cpu_percent.return_value = 50.0 + mock_parent.children.return_value = [] + mock_process_class.return_value = mock_parent + + # Set up a test command for profiling + cmd = ["ffmpeg", "-version"] + + # Run the command and get results + rc, wall_time, cpu, output = run_and_profile(cmd) + + # Test if the command ran successfully + assert rc == 0, f"Command failed with return code {rc}" + assert wall_time > 0, "Wall time should be greater than zero" + assert cpu >= 0, "CPU usage should be non-negative" + assert b"ffmpeg" in output, "Expected output not found" + + +# Test the `file_size_minio` function +@patch('run_bench.client') +def test_file_size(mock_client): + """Test file_size_minio function with mocked MinIO client""" + test_obj_name = "sound/drone-01_20251102t010618z.wav" + + # Mock stat_object to return a size + mock_stat = Mock() + mock_stat.size = 1024000 # 1MB + mock_client.stat_object.return_value = mock_stat + + # Get the file size + size = file_size_minio(test_obj_name) + + # Test if the file size is correct + assert size == 1024000, f"Expected file size 1024000, but got {size}" + mock_client.stat_object.assert_called_once() + + +# Test the `main` function to check if the CSV is created correctly +@patch('run_bench.client') +@patch('run_bench.iter_audio_files') +@patch('run_bench.download_raw_to_temp') +@patch('run_bench.replace_with_compressed') +@patch('run_bench.run_and_profile') +@patch('run_bench.parse_timestamp_from_filename') +@patch('run_bench.get_file_age_seconds') +@patch('run_bench.build_ffmpeg_cmds') +def test_main(mock_build, mock_get_age, mock_parse_ts, mock_profile, + mock_replace, mock_download, mock_iter, mock_client): + """Test main function with all dependencies mocked""" + from datetime import datetime + + # Setup mocks + test_obj = "sound/test_20251102t120000z.wav" + mock_iter.return_value = [test_obj] + + mock_dt = datetime(2025, 11, 2, 12, 0, 0) + mock_parse_ts.return_value = mock_dt + mock_get_age.return_value = 86400 # 1 day + + # Mock local file + mock_local = Mock() + mock_local.stat.return_value.st_size = 2000000 # 2MB original + mock_local.unlink = Mock() + mock_download.return_value = mock_local + + # Mock encoding + mock_profile.return_value = (0, 5.0, 75.0, b"") # success, 5s, 75% CPU + + # Mock MinIO client for file_size_minio + mock_stat = Mock() + mock_stat.size = 500000 # 500KB compressed + mock_client.stat_object.return_value = mock_stat + + # Mock replace operation + mock_replace.return_value = "sound/test_20251102t120000z.opus" + + # Mock output file + mock_output = Mock() + mock_output.unlink = Mock() + + # Mock build_ffmpeg_cmds + mock_build.return_value = [ + ("opus", ["ffmpeg", "-i", "test.wav", "test.opus"], mock_output) + ] + + # Run main + result_path = Path("results/benchmarks.csv") + if result_path.exists(): + result_path.unlink() + + main() + + # Check if the results file was created + assert result_path.exists(), "The result CSV file was not created" + + # Check if it contains rows + with open(result_path, "r") as file: + reader = csv.DictReader(file) + rows = list(reader) + assert len(rows) > 0, "No rows in the results CSV" + + # Verify row contents + row = rows[0] + assert row["codec"] == "opus" + assert float(row["compression_ratio_orig_over_encoded"]) == 4.0 # 2MB / 500KB + + +# Additional tests for edge cases +@patch('run_bench.client') +def test_file_size_invalid_file(mock_client): + """Test the file size function with a non-existent file""" + invalid_obj = "sound/nonexistent_file.wav" + + # Mock stat_object to raise exception + mock_client.stat_object.side_effect = Exception("Object not found") + + size = file_size_minio(invalid_obj) + assert size == 0, f"Expected file size to be 0, but got {size}" + + +@patch('run_bench.client') +@patch('run_bench.iter_audio_files') +def test_main_no_files(mock_iter, mock_client, capsys): + """Test main function when no files are found""" + mock_iter.return_value = [] + + main() + + captured = capsys.readouterr() + assert "No audio files found" in captured.out + + +@patch('run_bench.client') +@patch('run_bench.iter_audio_files') +@patch('run_bench.download_raw_to_temp') +@patch('run_bench.run_and_profile') +@patch('run_bench.parse_timestamp_from_filename') +@patch('run_bench.get_file_age_seconds') +@patch('run_bench.build_ffmpeg_cmds') +def test_main_encoding_failure(mock_build, mock_get_age, mock_parse_ts, + mock_profile, mock_download, mock_iter, mock_client): + """Test main function when encoding fails""" + from datetime import datetime + + test_obj = "sound/test.wav" + mock_iter.return_value = [test_obj] + + mock_parse_ts.return_value = datetime(2025, 11, 2, 12, 0, 0) + mock_get_age.return_value = 0 + + mock_local = Mock() + mock_local.stat.return_value.st_size = 1000000 + mock_local.unlink = Mock() + mock_download.return_value = mock_local + + # Mock failed encoding (return code != 0) + mock_profile.return_value = (1, 0.0, 0.0, b"error") + + mock_output = Mock() + mock_output.unlink = Mock() + + mock_build.return_value = [ + ("opus", ["ffmpeg", "-i", "test.wav", "test.opus"], mock_output) + ] + + result_path = Path("results/benchmarks.csv") + if result_path.exists(): + result_path.unlink() + + main() + + # CSV should not be created or should be empty + if result_path.exists(): + with open(result_path, "r") as file: + reader = csv.DictReader(file) + rows = list(reader) + assert len(rows) == 0, "Should have no successful encodings" diff --git a/services/compression/tests/test_tiering_job.py b/services/compression/tests/test_tiering_job.py new file mode 100644 index 000000000..6bdb133be --- /dev/null +++ b/services/compression/tests/test_tiering_job.py @@ -0,0 +1,495 @@ +""" +Tests for tiering_job.py - Audio compression and tiering +Updated to match the new implementation with same-path compression +""" + +import pytest +from pathlib import Path +from unittest.mock import patch, Mock, MagicMock, call +from tiering_job import ( + encode_and_replace, + cleanup_compressed, + main, + DEFAULT_RAW_MAX_AGE_DAYS, + DEFAULT_COMP_MAX_AGE_DAYS, + DEFAULT_LONG_TERM_CODEC +) + + +class TestEncodeAndReplace: + """Tests for encode_and_replace function""" + + @patch('tiering_job.download_raw_to_temp') + @patch('tiering_job.build_ffmpeg_cmds') + @patch('tiering_job.subprocess.call') + @patch('tiering_job.replace_with_compressed') + def test_encode_and_replace_success( + self, mock_replace, mock_subprocess, mock_build, mock_download + ): + """Test successful encoding and replacement""" + # Setup mocks + mock_local = Mock() + mock_local.unlink = Mock() + mock_download.return_value = mock_local + + mock_output = Mock() + mock_output.unlink = Mock() + mock_build.return_value = [ + ("opus", ["ffmpeg", "-i", "input", "output"], mock_output) + ] + + mock_subprocess.return_value = 0 # Success + mock_replace.return_value = "sound/audio.opus" + + # Execute + result = encode_and_replace("sound/audio.wav", "opus") + + # Verify + assert result == "sound/audio.opus" + mock_download.assert_called_once_with("sound/audio.wav") + mock_build.assert_called_once() + mock_subprocess.assert_called_once() + mock_replace.assert_called_once() + mock_local.unlink.assert_called_once() + mock_output.unlink.assert_called_once() + + @patch('tiering_job.download_raw_to_temp') + @patch('tiering_job.build_ffmpeg_cmds') + @patch('tiering_job.subprocess.call') + def test_encode_and_replace_encode_failure( + self, mock_subprocess, mock_build, mock_download + ): + """Test handling of encoding failure""" + mock_local = Mock() + mock_local.unlink = Mock() + mock_download.return_value = mock_local + + mock_output = Mock() + mock_output.unlink = Mock() + mock_build.return_value = [ + ("opus", ["ffmpeg"], mock_output) + ] + + mock_subprocess.return_value = 1 # Failure + + with pytest.raises(RuntimeError, match="Encode failed"): + encode_and_replace("sound/audio.wav", "opus") + + # Verify cleanup happened + mock_local.unlink.assert_called_once() + mock_output.unlink.assert_called_once() + + @patch('tiering_job.download_raw_to_temp') + @patch('tiering_job.build_ffmpeg_cmds') + def test_encode_and_replace_no_commands(self, mock_build, mock_download): + """Test when no encode commands are generated""" + mock_local = Mock() + mock_local.unlink = Mock() + mock_download.return_value = mock_local + + mock_build.return_value = [] # No commands + + with pytest.raises(RuntimeError, match="No encode commands"): + encode_and_replace("sound/audio.wav", "opus") + + mock_local.unlink.assert_called_once() + + @patch('tiering_job.download_raw_to_temp') + @patch('tiering_job.build_ffmpeg_cmds') + @patch('tiering_job.subprocess.call') + @patch('tiering_job.replace_with_compressed') + def test_encode_and_replace_flac_codec( + self, mock_replace, mock_subprocess, mock_build, mock_download + ): + """Test encoding with FLAC codec""" + mock_local = Mock() + mock_local.unlink = Mock() + mock_download.return_value = mock_local + + mock_output = Mock() + mock_output.unlink = Mock() + mock_build.return_value = [ + ("flac", ["ffmpeg"], mock_output) + ] + + mock_subprocess.return_value = 0 + mock_replace.return_value = "sound/audio.flac" + + result = encode_and_replace("sound/audio.wav", "flac") + + assert result == "sound/audio.flac" + mock_build.assert_called_once_with(mock_local, codec="flac") + + +class TestCleanupCompressed: + """Tests for cleanup_compressed function""" + + @patch('tiering_job.client') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.delete_object') + def test_cleanup_compressed_deletes_old_files( + self, mock_delete, mock_age, mock_client + ): + """Test deletion of old compressed files""" + # Mock old compressed files + mock_obj1 = Mock() + mock_obj1.object_name = "compressed/old1.opus" + mock_obj2 = Mock() + mock_obj2.object_name = "compressed/old2.flac" + + mock_client.list_objects.return_value = [mock_obj1, mock_obj2] + mock_age.side_effect = [100 * 86400, 95 * 86400] # Both > 90 days + + result = cleanup_compressed(90, dry_run=False) + + assert result == 2 + assert mock_delete.call_count == 2 + mock_delete.assert_any_call("compressed/old1.opus") + mock_delete.assert_any_call("compressed/old2.flac") + + @patch('tiering_job.client') + @patch('tiering_job.get_file_age_seconds') + def test_cleanup_compressed_keeps_new_files( + self, mock_age, mock_client + ): + """Test that new files are not deleted""" + mock_obj = Mock() + mock_obj.object_name = "compressed/new.opus" + + mock_client.list_objects.return_value = [mock_obj] + mock_age.return_value = 10 * 86400 # 10 days old + + result = cleanup_compressed(90, dry_run=False) + + assert result == 0 + + @patch('tiering_job.client') + @patch('tiering_job.get_file_age_seconds') + def test_cleanup_compressed_dry_run( + self, mock_age, mock_client, capsys + ): + """Test dry run mode""" + mock_obj = Mock() + mock_obj.object_name = "compressed/old.opus" + + mock_client.list_objects.return_value = [mock_obj] + mock_age.return_value = 100 * 86400 + + result = cleanup_compressed(90, dry_run=True) + + assert result == 0 # Nothing actually deleted + captured = capsys.readouterr() + assert "[DRY]" in captured.out + assert "Would delete" in captured.out + + def test_cleanup_compressed_disabled(self): + """Test when cleanup is disabled (max_age <= 0)""" + result = cleanup_compressed(0, dry_run=False) + assert result == 0 + + result = cleanup_compressed(-1, dry_run=False) + assert result == 0 + + +class TestMain: + """Tests for main function""" + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.cleanup_compressed') + def test_main_no_files(self, mock_cleanup, mock_iter, capsys): + """Test when no files need processing""" + mock_iter.return_value = [] + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "Audio files compressed: 0" in captured.out + mock_cleanup.assert_called_once() + + @patch('sys.argv', ['tiering_job.py', '--dry-run']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.cleanup_compressed') + def test_main_dry_run( + self, mock_cleanup, mock_older, mock_age, mock_iter, capsys + ): + """Test dry run mode""" + mock_iter.return_value = ["sound/audio.wav"] + mock_age.return_value = 35 * 86400 # 35 days + mock_older.return_value = True + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "[DRY]" in captured.out + assert "Would compress" in captured.out + + @patch('sys.argv', ['tiering_job.py', '--codec', 'flac', '--raw-max-age-days', '7']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.encode_and_replace') + @patch('tiering_job.client') + @patch('tiering_job.cleanup_compressed') + def test_main_custom_settings( + self, mock_cleanup, mock_client, mock_encode, + mock_older, mock_age, mock_iter + ): + """Test with custom codec and age threshold""" + mock_iter.return_value = ["sound/audio.wav"] + mock_age.return_value = 10 * 86400 # 10 days + mock_older.return_value = True + + # Mock file stats + mock_orig_stat = Mock() + mock_orig_stat.size = 10000000 + mock_comp_stat = Mock() + mock_comp_stat.size = 5000000 + mock_client.stat_object.side_effect = [mock_orig_stat, mock_comp_stat] + + mock_encode.return_value = "sound/audio.flac" + mock_cleanup.return_value = 0 + + main() + + # Verify encode was called with FLAC + mock_encode.assert_called_once_with("sound/audio.wav", "flac") + # Verify age threshold was 7 days + mock_older.assert_called_with("sound/audio.wav", 7 * 86400) + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.encode_and_replace') + @patch('tiering_job.client') + @patch('tiering_job.cleanup_compressed') + def test_main_successful_compression( + self, mock_cleanup, mock_client, mock_encode, + mock_older, mock_age, mock_iter, capsys + ): + """Test successful compression workflow""" + mock_iter.return_value = ["sound/audio.wav"] + mock_age.return_value = 35 * 86400 # 35 days + mock_older.return_value = True + mock_encode.return_value = "sound/audio.opus" + + mock_orig_stat = Mock() + mock_orig_stat.size = 10000000 + mock_comp_stat = Mock() + mock_comp_stat.size = 500000 + mock_client.stat_object.side_effect = [mock_orig_stat, mock_comp_stat] + + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "[OK]" in captured.out + assert "Compressed:" in captured.out + assert "Audio files compressed: 1" in captured.out + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.encode_and_replace') + @patch('tiering_job.client') + @patch('tiering_job.cleanup_compressed') + def test_main_encoding_failure( + self, mock_cleanup, mock_client, mock_encode, mock_older, mock_age, mock_iter, capsys + ): + """Test handling of encoding failure""" + mock_iter.return_value = ["sound/audio.wav"] + mock_age.return_value = 35 * 86400 + mock_older.return_value = True + + # Mock the stat_object to return a size for original file + mock_orig_stat = Mock() + mock_orig_stat.size = 10000000 + mock_client.stat_object.return_value = mock_orig_stat + + mock_encode.side_effect = RuntimeError("Encoding failed") + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "[FAIL]" in captured.out + assert "Encoding failed" in captured.out + assert "Errors: 1" in captured.out + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.cleanup_compressed') + def test_main_skip_young_files( + self, mock_cleanup, mock_older, mock_age, mock_iter, capsys + ): + """Test that young files are skipped""" + mock_iter.return_value = ["sound/new_audio.wav"] + mock_age.return_value = 5 * 86400 # 5 days old + mock_older.return_value = False # Not old enough + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "Files skipped (too new):" in captured.out + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.cleanup_compressed') + def test_main_skip_files_without_timestamp( + self, mock_cleanup, mock_older, mock_age, mock_iter, capsys + ): + """Test skipping files without parseable timestamp""" + mock_iter.return_value = ["sound/no_timestamp.wav"] + mock_age.return_value = 0 # No parseable timestamp + mock_older.return_value = False + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "[SKIP]" in captured.out + assert "Cannot parse timestamp" in captured.out + + @patch('sys.argv', ['tiering_job.py', '--compressed-max-age-days', '60']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.cleanup_compressed') + def test_main_cleanup_old_compressed( + self, mock_cleanup, mock_iter + ): + """Test cleanup of old compressed files""" + mock_iter.return_value = [] + mock_cleanup.return_value = 5 + + main() + + mock_cleanup.assert_called_once_with(60, False) + + +class TestDefaultConstants: + """Test default configuration constants""" + + def test_default_values(self): + """Test that default values are reasonable""" + assert DEFAULT_RAW_MAX_AGE_DAYS > 0 + assert DEFAULT_COMP_MAX_AGE_DAYS > 0 + assert DEFAULT_COMP_MAX_AGE_DAYS >= DEFAULT_RAW_MAX_AGE_DAYS + assert DEFAULT_LONG_TERM_CODEC in ["opus", "flac"] + + +class TestArgumentParsing: + """Test command-line argument parsing""" + + @patch('sys.argv', ['tiering_job.py', '--help']) + def test_help_argument(self): + """Test that help argument works""" + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 0 + + @patch('sys.argv', ['tiering_job.py', '--codec', 'invalid']) + def test_invalid_codec(self): + """Test error with invalid codec""" + with pytest.raises(SystemExit): + main() + + +class TestEdgeCases: + """Test edge cases""" + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.encode_and_replace') + @patch('tiering_job.client') + @patch('tiering_job.cleanup_compressed') + def test_main_multiple_files( + self, mock_cleanup, mock_client, mock_encode, + mock_older, mock_age, mock_iter + ): + """Test processing multiple files""" + mock_iter.return_value = [ + "sound/audio1.wav", + "sound/audio2.wav", + "sound/audio3.wav" + ] + mock_age.return_value = 35 * 86400 + mock_older.return_value = True + mock_encode.side_effect = [ + "sound/audio1.opus", + "sound/audio2.opus", + "sound/audio3.opus" + ] + + mock_stat = Mock() + mock_stat.size = 1000000 + mock_client.stat_object.return_value = mock_stat + + mock_cleanup.return_value = 0 + + main() + + assert mock_encode.call_count == 3 + + @patch('sys.argv', ['tiering_job.py']) + @patch('tiering_job.iter_audio_files') + @patch('tiering_job.get_file_age_seconds') + @patch('tiering_job.is_older_than') + @patch('tiering_job.encode_and_replace') + @patch('tiering_job.client') + @patch('tiering_job.cleanup_compressed') + def test_main_size_calculation( + self, mock_cleanup, mock_client, mock_encode, + mock_older, mock_age, mock_iter, capsys + ): + """Test size and ratio calculations""" + mock_iter.return_value = ["sound/audio.wav"] + mock_age.return_value = 35 * 86400 + mock_older.return_value = True + mock_encode.return_value = "sound/audio.opus" + + # Original: 10MB, Compressed: 1MB (10x ratio) + mock_orig = Mock() + mock_orig.size = 10 * 1024 * 1024 + mock_comp = Mock() + mock_comp.size = 1 * 1024 * 1024 + + mock_client.stat_object.side_effect = [mock_orig, mock_comp] + mock_cleanup.return_value = 0 + + main() + + captured = capsys.readouterr() + assert "Ratio: 10.00x" in captured.out + assert "Saved: 9,437,184 bytes" in captured.out + + +class TestIntegration: + """Integration-like tests""" + + def test_imports(self): + """Test that all required imports work""" + from tiering_job import ( + encode_and_replace, + cleanup_compressed, + main, + DEFAULT_RAW_MAX_AGE_DAYS, + DEFAULT_COMP_MAX_AGE_DAYS, + DEFAULT_LONG_TERM_CODEC + ) + + assert callable(encode_and_replace) + assert callable(cleanup_compressed) + assert callable(main) \ No newline at end of file diff --git a/services/db_api_service/.env.example b/services/db_api_service/.env.example index a2c4fe7d0..41bcedb47 100644 --- a/services/db_api_service/.env.example +++ b/services/db_api_service/.env.example @@ -4,6 +4,6 @@ PORT=8080 CONTRACTS_DIR=app/contracts -ALLOWED_TABLES=["event_logs_sensors","devices"] +ALLOWED_TABLES=["event_logs_sensors","devices","image_new_aerial_connections","sound_new_sounds_connections","sound_new_plants_connections","aerial_images_metadata","aerial_image_object_detections","aerial_image_anomaly_detections","aerial_images_complete_metadata","field_polygons","aerial_image_segmentation"] STRICT_UNKNOWN_FIELDS=true diff --git a/services/db_api_service/Dockerfile b/services/db_api_service/Dockerfile index 76f4cac7b..f10cb2ab4 100644 --- a/services/db_api_service/Dockerfile +++ b/services/db_api_service/Dockerfile @@ -6,8 +6,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential curl ca-certificates && \ rm -rf /var/lib/apt/lists/* -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt -RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && update-ca-certificates +COPY *.crt /usr/local/share/ca-certificates/ +RUN chmod 644 /usr/local/share/ca-certificates/*.crt && update-ca-certificates ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ @@ -16,7 +16,6 @@ ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 -# RUN python -m pip install --no-cache-dir --upgrade pip certifi COPY requirements.txt . RUN pip install --no-cache-dir \ @@ -27,4 +26,4 @@ COPY app ./app EXPOSE 8001 -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8001"] +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8001"] \ No newline at end of file diff --git a/services/db_api_service/app/contracts/Dockerfile b/services/db_api_service/app/contracts/Dockerfile index 9d7308d75..ed6d31775 100644 --- a/services/db_api_service/app/contracts/Dockerfile +++ b/services/db_api_service/app/contracts/Dockerfile @@ -5,8 +5,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential curl ca-certificates && \ rm -rf /var/lib/apt/lists/* -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt -RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && update-ca-certificates +COPY *.crt /usr/local/share/ca-certificates/ +RUN chmod 644 /usr/local/share/ca-certificates/*.crt && update-ca-certificates ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ diff --git a/services/db_api_service/app/router.py b/services/db_api_service/app/router.py index 1d7f3a44f..fa02955ed 100644 --- a/services/db_api_service/app/router.py +++ b/services/db_api_service/app/router.py @@ -1,9 +1,12 @@ # app/router.py + from fastapi import APIRouter, Depends from app.auth import require_auth from app.tables.files.router import router as files_router + from app.tables.generic.router import build_generic_router from app.tables.task_thresholds.router import router as task_thresholds_router +from app.tables.ripeness_weekly_rollups_ts.router import router as ripeness_weekly_router def build_router(contract_store) -> APIRouter: @@ -13,8 +16,12 @@ def build_router(contract_store) -> APIRouter: dependencies=[Depends(require_auth)], ) + api.include_router(files_router) api.include_router(task_thresholds_router) + api.include_router(ripeness_weekly_router) api.include_router(build_generic_router(contract_store)) return api + + diff --git a/services/db_api_service/app/tables/devices/repo.py b/services/db_api_service/app/tables/devices/repo.py new file mode 100644 index 000000000..ad449d931 --- /dev/null +++ b/services/db_api_service/app/tables/devices/repo.py @@ -0,0 +1,86 @@ +from typing import Optional, Dict, Any +from sqlalchemy import text +from app.db import session_scope + + +def list_devices( + limit: int = 50, + offset: int = 0, + q: Optional[str] = None, + active: Optional[bool] = None, +) -> Dict[str, Any]: + """ + Retrieve a paginated list of devices with optional filtering. + + Args: + limit (int): Maximum number of devices to return. Defaults to 50. + offset (int): Number of records to skip (for pagination). Defaults to 0. + q (Optional[str]): Search string to filter by device_id, model, or owner. + active (Optional[bool]): Filter by active status if provided. + + Returns: + Dict[str, Any]: A dictionary containing: + - "total": total number of matching devices. + - "items": list of matching device records as dictionaries. + """ + filters = [] + params: Dict[str, Any] = {"limit": limit, "offset": offset} + + # Free-text search filter + if q: + filters.append( + "(device_id ILIKE :q OR model ILIKE :q OR owner ILIKE :q)" + ) + params["q"] = f"%{q}%" + + # Active status filter + if active is not None: + filters.append("active = :active") + params["active"] = active + + # Build WHERE clause dynamically based on filters + where_sql = f"WHERE {' AND '.join(filters)}" if filters else "" + + # SQL for fetching paginated list + list_sql = text(f""" + SELECT device_id, model, owner, active + FROM public.devices + {where_sql} + ORDER BY device_id + LIMIT :limit OFFSET :offset + """) + + # SQL for total count + count_sql = text(f""" + SELECT COUNT(*)::int AS total + FROM public.devices + {where_sql} + """) + + # Execute both queries within a session scope + with session_scope() as s: + total = s.execute(count_sql, params).scalar_one() + rows = s.execute(list_sql, params).mappings().all() + + return {"total": total, "items": [dict(r) for r in rows]} + + +def get_device(device_id: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a single device by its ID. + + Args: + device_id (str): Unique identifier of the device. + + Returns: + Optional[Dict[str, Any]]: Dictionary containing device details + if found, otherwise None. + """ + sql = text(""" + SELECT device_id, model, owner, active + FROM public.devices + WHERE device_id = :device_id + """) + with session_scope() as s: + row = s.execute(sql, {"device_id": device_id}).mappings().first() + return dict(row) if row else None \ No newline at end of file diff --git a/services/db_api_service/app/tables/devices/router.py b/services/db_api_service/app/tables/devices/router.py new file mode 100644 index 000000000..be950ebf3 --- /dev/null +++ b/services/db_api_service/app/tables/devices/router.py @@ -0,0 +1,49 @@ +from typing import Optional +from fastapi import APIRouter, HTTPException, Query +from .schemas import DeviceOut, DeviceList +from . import repo + +# Create API router for devices +router = APIRouter(prefix="/devices", tags=["devices"]) + + +@router.get("", response_model=DeviceList) +def list_devices( + limit: int = Query(50, ge=1, le=500, description="Maximum number of devices to return"), + offset: int = Query(0, ge=0, description="Number of devices to skip for pagination"), + q: Optional[str] = Query(None, description="Free text search in device_id, model, or owner"), + active: Optional[bool] = Query(None, description="Filter by active status"), +): + """ + API endpoint to retrieve a list of devices with optional filters. + + Query Parameters: + - limit: Maximum number of records (default: 50, max: 500). + - offset: Records to skip for pagination (default: 0). + - q: Free-text search across device_id, model, and owner. + - active: Filter devices by active status. + + Returns: + DeviceList: Paginated list of devices with total count. + """ + return repo.list_devices(limit=limit, offset=offset, q=q, active=active) + + +@router.get("/{device_id}", response_model=DeviceOut) +def get_device(device_id: str): + """ + API endpoint to retrieve a single device by its ID. + + Args: + device_id (str): Unique identifier of the device. + + Raises: + HTTPException: 404 if the device is not found. + + Returns: + DeviceOut: Device details if found. + """ + row = repo.get_device(device_id) + if not row: + raise HTTPException(status_code=404, detail="Device not found") + return row \ No newline at end of file diff --git a/services/db_api_service/app/tables/devices/schemas.py b/services/db_api_service/app/tables/devices/schemas.py new file mode 100644 index 000000000..f3b4bf2b7 --- /dev/null +++ b/services/db_api_service/app/tables/devices/schemas.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel +from typing import Optional, List + +class DeviceOut(BaseModel): + device_id: str + model: Optional[str] = None + owner: Optional[str] = None + active: Optional[bool] = None + +class DeviceList(BaseModel): + total: int + items: List[DeviceOut] \ No newline at end of file diff --git a/services/db_api_service/app/tables/files/repo.py b/services/db_api_service/app/tables/files/repo.py index 62bec86ed..acf4b1215 100644 --- a/services/db_api_service/app/tables/files/repo.py +++ b/services/db_api_service/app/tables/files/repo.py @@ -1,4 +1,3 @@ - # app/tables/files/repo.py import os import json @@ -23,12 +22,11 @@ def _spool(name: str, payload: Dict[str, Any]): def _ensure_json_text(obj: Any) -> Optional[str]: - if obj is None: return None if isinstance(obj, (dict, list)): return json.dumps(obj, ensure_ascii=False) - return obj + return obj def upsert_file(payload: Dict[str, Any]) -> None: @@ -36,11 +34,10 @@ def upsert_file(payload: Dict[str, Any]) -> None: _spool("files_upsert", payload) return - payload = dict(payload) payload["metadata"] = _ensure_json_text(payload.get("metadata")) - + # optional footprint (WKT) -> geometry fp = payload.get("footprint") payload["footprint"] = (None if not fp else fp) @@ -106,7 +103,6 @@ def update_file(bucket: str, object_key: str, updates: Dict[str, Any]) -> bool: params["tile_id"] = updates["tile_id"] if "footprint" in updates: - fp = updates["footprint"] params["footprint"] = (None if not fp else fp) sets.append( @@ -136,12 +132,13 @@ def update_file(bucket: str, object_key: str, updates: Dict[str, Any]) -> bool: def get_file(bucket: str, object_key: str) -> Optional[Dict[str, Any]]: if DRY_RUN: - return None q = text(""" SELECT - file_id, bucket, object_key, content_type, size_bytes, etag, + file_id, bucket, object_key, + object_key AS key, -- convenient alias + content_type, size_bytes, etag, mission_id, device_id, tile_id, ST_AsText(footprint) AS footprint_wkt, metadata, created_at @@ -154,6 +151,28 @@ def get_file(bucket: str, object_key: str) -> Optional[Dict[str, Any]]: return dict(row) if row else None +def get_file_by_id(file_id: int) -> Optional[Dict[str, Any]]: + """New: fetch by numeric file_id.""" + if DRY_RUN: + return None + + q = text(""" + SELECT + file_id, bucket, object_key, + object_key AS key, -- convenient alias + content_type, size_bytes, etag, + mission_id, device_id, tile_id, + ST_AsText(footprint) AS footprint_wkt, + metadata, created_at + FROM files + WHERE file_id = :file_id + LIMIT 1; + """) + with session_scope() as s: + row = s.execute(q, {"file_id": file_id}).mappings().first() + return dict(row) if row else None + + def list_files(bucket: Optional[str], device_id: Optional[str], limit: int) -> List[Dict[str, Any]]: if DRY_RUN: return [] @@ -173,7 +192,9 @@ def list_files(bucket: Optional[str], device_id: Optional[str], limit: int) -> L q = text(f""" SELECT - file_id, bucket, object_key, content_type, size_bytes, etag, + file_id, bucket, object_key, + object_key AS key, -- convenient alias + content_type, size_bytes, etag, mission_id, device_id, tile_id, ST_AsText(footprint) AS footprint_wkt, metadata, created_at @@ -200,4 +221,3 @@ def delete_file(bucket: str, object_key: str) -> bool: with session_scope() as s: row = s.execute(q, {"bucket": bucket, "object_key": object_key}).first() return bool(row) - diff --git a/services/db_api_service/app/tables/files/router.py b/services/db_api_service/app/tables/files/router.py index 89bef5ed9..1d738f07b 100644 --- a/services/db_api_service/app/tables/files/router.py +++ b/services/db_api_service/app/tables/files/router.py @@ -1,6 +1,8 @@ - -from typing import Optional -from urllib.parse import unquote +# app/tables/files/router.py +from typing import Optional, Any, Dict +from urllib.parse import unquote, quote +import os +import json from fastapi import APIRouter, HTTPException, Query from .schemas import FilesCreate, FilesUpdate @@ -8,11 +10,56 @@ router = APIRouter(prefix="/files", tags=["files"]) +PUBLIC_S3_BASE = os.getenv("PUBLIC_S3_BASE") # e.g., "http://minio:9000" or "https://cdn.example.com" + + +def _attach_url_if_possible(row: Dict[str, Any]) -> Dict[str, Any]: + """ + If metadata has 'url' or 's3_url', expose it as 'url'. + Else, if PUBLIC_S3_BASE is set and we have bucket/key, build a path-style URL. + """ + if not row: + return row + + # Try metadata first + meta = row.get("metadata") + if isinstance(meta, str): + try: + meta = json.loads(meta) + except Exception: + meta = None + + if isinstance(meta, dict): + for k in ("url", "s3_url"): + if meta.get(k): + row.setdefault("url", meta[k]) + return row + + # Build from PUBLIC_S3_BASE, if available + if PUBLIC_S3_BASE and row.get("bucket") and (row.get("key") or row.get("object_key")): + bucket = str(row["bucket"]) + key = str(row.get("key") or row.get("object_key")) + built = f"{PUBLIC_S3_BASE.rstrip('/')}/{quote(bucket, safe='')}/{quote(key, safe='/')}" + row.setdefault("url", built) + + return row + + @router.post("", status_code=201) def create_or_upsert_file(payload: FilesCreate): repo.upsert_file(payload.model_dump(by_alias=True)) return {"status": "ok"} + +# -------- New: GET /files/{file_id} by numeric id (place this before the catch-all path route) -------- +@router.get("/{file_id:int}") +def get_file_by_id(file_id: int): + row = repo.get_file_by_id(file_id) + if not row: + raise HTTPException(status_code=404, detail="not found") + return _attach_url_if_possible(row) + + @router.put("/{bucket}/{object_key:path}") def update_file(bucket: str, object_key: str, payload: FilesUpdate): bucket = unquote(bucket) @@ -22,6 +69,7 @@ def update_file(bucket: str, object_key: str, payload: FilesUpdate): raise HTTPException(status_code=404, detail="not found") return {"status": "ok"} + @router.get("/{bucket}/{object_key:path}") def get_file(bucket: str, object_key: str): bucket = unquote(bucket) @@ -29,7 +77,8 @@ def get_file(bucket: str, object_key: str): row = repo.get_file(bucket, object_key) if not row: raise HTTPException(status_code=404, detail="not found") - return row + return _attach_url_if_possible(row) + @router.get("") def list_files( @@ -37,10 +86,12 @@ def list_files( device_id: Optional[str] = None, limit: int = Query(50, ge=1, le=500), ): - if bucket is not None: bucket = unquote(bucket) - return repo.list_files(bucket, device_id, limit) + rows = repo.list_files(bucket, device_id, limit) + # Optionally attach URL to each row (cheap for small lists) + return [_attach_url_if_possible(r) for r in rows] + @router.delete("/{bucket}/{object_key:path}") def delete_file(bucket: str, object_key: str): @@ -50,4 +101,3 @@ def delete_file(bucket: str, object_key: str): if not ok: raise HTTPException(status_code=404, detail="not found") return {"status": "deleted"} - diff --git a/services/db_api_service/app/tables/files/schemas.py b/services/db_api_service/app/tables/files/schemas.py index b75e377f9..7898ad1b5 100644 --- a/services/db_api_service/app/tables/files/schemas.py +++ b/services/db_api_service/app/tables/files/schemas.py @@ -1,6 +1,8 @@ +# app/tables/files/schemas.py from typing import Optional, Any, Dict from pydantic import BaseModel, Field, NonNegativeInt + class FilesCreate(BaseModel): bucket: str object_key: str = Field(alias="object_key") @@ -10,13 +12,13 @@ class FilesCreate(BaseModel): mission_id: Optional[int] = None device_id: Optional[str] = None tile_id: Optional[str] = None - footprint: Optional[str] = None + footprint: Optional[str] = None # WKT metadata: Optional[Dict[str, Any]] = None - class Config: populate_by_name = True + class FilesUpdate(BaseModel): content_type: Optional[str] = None size_bytes: Optional[NonNegativeInt] = None diff --git a/services/db_api_service/app/tables/generic/repo.py b/services/db_api_service/app/tables/generic/repo.py index 9f3a29509..5b19aa6e2 100644 --- a/services/db_api_service/app/tables/generic/repo.py +++ b/services/db_api_service/app/tables/generic/repo.py @@ -3,6 +3,7 @@ # Schema-first repository: load JSON contracts, build in-memory SQLAlchemy Table, # validate payloads, and perform read/insert operations (single + batch). from typing import Any, Dict, List, Optional +from sqlalchemy.dialects.postgresql import insert as pg_insert from functools import lru_cache import json import os @@ -269,7 +270,20 @@ def insert_row(resource: str, payload: Dict[str, Any], returning: str = "keys") # build SQLAlchemy table afterwards for SQL generation table = _build_table_from_contract(resource) - stmt = insert(table).values(**valid).returning(*table.columns) + key_fields = contract.get("x-keyFields") or (["id"] if "id" in props else []) + + if not key_fields: + raise ValidationFailed("no key fields", {"detail": "contract has no x-keyFields and no id"}) + + # Build UPSERT statement + stmt = pg_insert(table).values(**valid) + update_fields = {k: stmt.excluded[k] for k in valid.keys() if k not in key_fields} + + stmt = stmt.on_conflict_do_update( + index_elements=key_fields, + set_=update_fields, + ).returning(*table.columns) + try: with session_scope() as s: res = s.execute(stmt) diff --git a/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/__init__.py b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/repo.py b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/repo.py new file mode 100644 index 000000000..4060d1d26 --- /dev/null +++ b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/repo.py @@ -0,0 +1,40 @@ +from typing import Optional, Dict, Any, List +from sqlalchemy import text +from app.db import session_scope +from datetime import datetime + + +def list_rollups(from_ts: str | None = None, to_ts: str | None = None) -> List[Dict[str, Any]]: + q = """ + SELECT * FROM ripeness_weekly_rollups_ts + WHERE 1=1 + """ + params: Dict[str, Any] = {} + + if from_ts: + q += " AND ts >= :from_ts" + params["from_ts"] = parse_ts(from_ts) + if to_ts: + q += " AND ts <= :to_ts" + params["to_ts"] = parse_ts(to_ts) + + q += " ORDER BY ts DESC" + + with session_scope() as s: + rows = s.execute(text(q), params).mappings().all() + return [dict(r) for r in rows] + + +def get_rollup(id: int) -> Optional[Dict[str, Any]]: + """ + Retrieve a single rollup entry by ID. + """ + sql = text(""" + SELECT id, ts, window_start, window_end, fruit_type, device_id, + run_id, cnt_total, cnt_ripe, cnt_unripe, cnt_overripe, pct_ripe + FROM public.ripeness_weekly_rollups_ts + WHERE id = :id + """) + with session_scope() as s: + row = s.execute(sql, {"id": id}).mappings().first() + return dict(row) if row else None diff --git a/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/router.py b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/router.py new file mode 100644 index 000000000..c0a11db18 --- /dev/null +++ b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/router.py @@ -0,0 +1,33 @@ + +from typing import Optional, List +from fastapi import APIRouter, HTTPException, Query +from . import schemas, repo + +router = APIRouter(prefix="/ripeness_weekly_rollups_ts", tags=["ripeness_weekly_rollups_ts"]) + +@router.get("", response_model=List[schemas.RipenessWeeklyRollupRead]) +def list_rollups( + from_ts: Optional[str] = Query(None, description="Filter from timestamp (ISO8601)"), + to_ts: Optional[str] = Query(None, description="Filter to timestamp (ISO8601)"), +): + """ + Retrieve weekly ripeness rollups by time range. + """ + try: + rows = repo.list_rollups(from_ts=from_ts, to_ts=to_ts) + return rows + except Exception as e: + print(f"[ERROR][router] list_rollups failed: {e}") + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get("/{id}", response_model=schemas.RipenessWeeklyRollupOut) +def get_rollup(id: int): + """ + Retrieve a specific rollup entry by ID. + """ + row = repo.get_rollup(id) + if not row: + raise HTTPException(status_code=404, detail="Rollup not found") + return row + diff --git a/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/schemas.py b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/schemas.py new file mode 100644 index 000000000..5fbbbfce4 --- /dev/null +++ b/services/db_api_service/app/tables/ripeness_weekly_rollups_ts/schemas.py @@ -0,0 +1,36 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, Field, conint, confloat + + +class RipenessWeeklyRollupBase(BaseModel): + """Base schema for weekly ripeness rollups table.""" + ts: Optional[datetime] = Field(None, description="Insertion timestamp") + window_start: datetime = Field(..., description="Start of weekly window") + window_end: datetime = Field(..., description="End of weekly window") + fruit_type: str = Field(..., description="Type of fruit analyzed") + device_id: Optional[str] = Field(None, description="Source device ID") + run_id: UUID = Field(..., description="Unique identifier for the run") # ? UUID instead of str + cnt_total: conint(ge=0) = Field(..., description="Total fruit count in window") + cnt_ripe: conint(ge=0) = Field(..., description="Ripe fruit count") + cnt_unripe: conint(ge=0) = Field(..., description="Unripe fruit count") + cnt_overripe: conint(ge=0) = Field(..., description="Overripe fruit count") + pct_ripe: confloat(ge=0, le=1) = Field(..., description="Ripe ratio (0-1)") + + +class RipenessWeeklyRollupCreate(RipenessWeeklyRollupBase): + """Schema used for POST inserts (single or batch).""" + pass + + +class RipenessWeeklyRollupRead(RipenessWeeklyRollupBase): + """Schema used for GET responses (includes DB ID).""" + id: int = Field(..., description="Primary key ID") + + class Config: + orm_mode = True + +class RipenessWeeklyRollupOut(RipenessWeeklyRollupBase): + id: int + diff --git a/services/db_api_service/tests/Dockerfile b/services/db_api_service/tests/Dockerfile index 2d368d4d4..308decd88 100644 --- a/services/db_api_service/tests/Dockerfile +++ b/services/db_api_service/tests/Dockerfile @@ -5,8 +5,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates build-essential gcc curl && \ rm -rf /var/lib/apt/lists/* -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt -RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && update-ca-certificates +COPY *.crt /usr/local/share/ca-certificates/ +RUN chmod 644 /usr/local/share/ca-certificates/ && update-ca-certificates RUN python -m pip install --no-cache-dir --upgrade pip certifi ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt diff --git a/services/flink_writer_db/Dockerfile.flink b/services/flink_writer_db/Dockerfile.flink index d96e00912..bd043ed8e 100644 --- a/services/flink_writer_db/Dockerfile.flink +++ b/services/flink_writer_db/Dockerfile.flink @@ -1,40 +1,46 @@ - FROM flink:1.20.0-scala_2.12-java11 - USER root -# Add local CA (place netfree-ca.crt next to this Dockerfile before building) -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt -RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && update-ca-certificates +# Copy certs dir (may be empty) and trust *.crt if present +COPY certs/ /tmp/certs/ + +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && \ + rm -rf /var/lib/apt/lists/* && \ + if [ -d /tmp/certs ] && ls /tmp/certs/*.crt >/dev/null 2>&1; then \ + cp /tmp/certs/*.crt /usr/local/share/ca-certificates/ && \ + chmod 644 /usr/local/share/ca-certificates/*.crt && \ + update-ca-certificates; \ + else \ + echo "No extra CA certs found. Skipping CA update."; \ + fi -ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt -ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt -ENV PIP_DISABLE_PIP_VERSION_CHECK=1 +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_DISABLE_PIP_VERSION_CHECK=1 # Python & tools -RUN apt-get update && apt-get install -y --no-install-recommends python3 python3-venv python3-pip curl ca-certificates && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-venv python3-pip curl ca-certificates && \ + rm -rf /var/lib/apt/lists/* -# Create venv and install pyflink RUN python3 -m venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" -# Install PyFlink (DataStream API) and requests -RUN pip install --upgrade pip certifi && pip install --prefer-binary apache-flink==2.1.0 requests urllib3 +RUN pip install --upgrade pip certifi && \ + pip install --no-cache-dir --prefer-binary apache-flink==2.1.0 requests urllib3 -# Kafka connector jar matching Flink 1.20 (connector v3 series) -# Reference: flink-connector-kafka-3.x for Flink 1.20 RUN curl -fSL https://repo1.maven.org/maven2/org/apache/flink/flink-connector-kafka/3.2.0-1.19/flink-connector-kafka-3.2.0-1.19.jar \ -o /opt/flink/lib/flink-connector-kafka-3.2.0-1.19.jar && \ curl -fSL https://repo1.maven.org/maven2/org/apache/kafka/kafka-clients/3.7.0/kafka-clients-3.7.0.jar \ -o /opt/flink/lib/kafka-clients-3.7.0.jar - -RUN mkdir -p /opt/app/secrets && chmod -R 777 /opt/app +RUN mkdir -p /opt/app/secrets && chmod -R 777 /opt/app WORKDIR /opt/app COPY app.py /opt/app/app.py -# Flink Python env vars -ENV PYFLINK_CLIENT_EXECUTABLE=/opt/venv/bin/python PYFLINK_PYTHON=/opt/venv/bin/python PYTHONPATH=/opt/app +ENV PYFLINK_CLIENT_EXECUTABLE=/opt/venv/bin/python \ + PYFLINK_PYTHON=/opt/venv/bin/python \ + PYTHONPATH=/opt/app -# Default command is provided by docker-compose (jobmanager/taskmanager), but keep a convenient default CMD ["bash", "-lc", "python app.py"] diff --git a/services/flink_writer_db/README.txt b/services/flink_writer_db/README.txt index af756000c..6eb24d039 100644 --- a/services/flink_writer_db/README.txt +++ b/services/flink_writer_db/README.txt @@ -7,7 +7,7 @@ with the original message body (JSON). ## Quick start -1) Put your `netfree-ca.crt` next to `Dockerfile.flink` (required for HTTPS trust). +1) Put your `*.crt` files next to `Dockerfile.flink` (required for HTTPS trust). 2) Ensure you have an external Docker network named `ag_cloud` and that your Kafka (`kafka:9092`) and DB API service (`db_api_service:8001`) are reachable on it. 3) Build & run: ```bash diff --git a/services/flink_writer_db/app.py b/services/flink_writer_db/app.py index 5600e5c1b..7e5d1eb92 100644 --- a/services/flink_writer_db/app.py +++ b/services/flink_writer_db/app.py @@ -15,8 +15,7 @@ DUMMY_DB = int(os.getenv("DUMMY_DB", "0")) == 1 KAFKA_BROKERS = os.getenv("KAFKA_BROKERS", "kafka:9092") -TOPICS = [t.strip() for t in os.getenv("TOPICS", "sensor_anomalies").split(",") if t.strip()] - +TOPICS = [t.strip() for t in os.getenv("TOPICS", "sensor_anomalies,alerts,image_new_aerial_connections,sound_new_sounds_connections,sound_new_plants_connections").split(",") if t.strip()] # ---------- Token Bootstrap ---------- def _safe_join_url(base: str, path: str) -> str: @@ -151,4 +150,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/services/flink_writer_db/docker-compose.yml b/services/flink_writer_db/docker-compose.yml index 8db9dabcb..2951df5ec 100644 --- a/services/flink_writer_db/docker-compose.yml +++ b/services/flink_writer_db/docker-compose.yml @@ -9,7 +9,7 @@ services: container_name: flink_writer_db environment: - KAFKA_BROKERS=kafka:9092 - - TOPICS=files + - TOPICS=files,image_new_aerial_connections,sound_new_sounds_connections,sound_new_plants_connections - DB_API_BASE=http://db_api_service:8001 - DB_API_AUTH_MODE=service - DB_API_SERVICE_NAME=flink-writer-db @@ -26,4 +26,3 @@ services: networks: ag_cloud: external: true - diff --git a/services/fruit_classifier/.gitignore b/services/fruit_classifier/.gitignore index 09568e037..b70a32d67 100644 --- a/services/fruit_classifier/.gitignore +++ b/services/fruit_classifier/.gitignore @@ -1,2 +1,2 @@ .env -.netfree-ca.crt \ No newline at end of file +.*.crt \ No newline at end of file diff --git a/services/fruit_classifier/dockerfile b/services/fruit_classifier/dockerfile index 10ab2283d..b523afba3 100644 --- a/services/fruit_classifier/dockerfile +++ b/services/fruit_classifier/dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates rm -rf /var/lib/apt/lists/* # <<< Add: Organization CA certificate >>> -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +COPY *.crt /usr/local/share/ca-certificates/ RUN update-ca-certificates ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ @@ -56,7 +56,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates rm -rf /var/lib/apt/lists/* # <<< Add: Same CA certificate in runtime stage >>> -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +COPY *.crt /usr/local/share/ca-certificates/ RUN update-ca-certificates # Install dependencies from Stage 1 (including torch) + wheels diff --git a/services/fruit_ripeness_alert/Dockerfile b/services/fruit_ripeness_alert/Dockerfile new file mode 100644 index 000000000..0603b1cba --- /dev/null +++ b/services/fruit_ripeness_alert/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.11-slim + +# --- System setup --- +WORKDIR /app +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + TZ=UTC + +# --- Dependencies --- +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# --- App --- +COPY . . + + +# --- Default command --- +CMD ["python", "-u", "app.py"] diff --git a/services/fruit_ripeness_alert/app.py b/services/fruit_ripeness_alert/app.py new file mode 100644 index 000000000..9d0a724fe --- /dev/null +++ b/services/fruit_ripeness_alert/app.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +import os, json, uuid, requests +from datetime import datetime, timedelta, timezone +from kafka import KafkaProducer +from token_bootstrap import get_service_token + +# === Environment === +DB_API_BASE = os.getenv("DB_API_BASE", "http://db_api_service:8001") +DB_API_TOKEN_FILE = os.getenv("DB_API_TOKEN_FILE", "/app/secret/db_api_token") +KAFKA_BROKER = os.getenv("KAFKA_BROKER", "kafka:9092") +ALERT_TOPIC = os.getenv("ALERT_TOPIC", "alerts") +WINDOW_HOURS = int(os.getenv("WINDOW_HOURS", "168")) + + +def now_utc() -> datetime: + return datetime.now(timezone.utc) + +def iso(ts: datetime) -> str: + return ts.replace(tzinfo=timezone.utc).isoformat() + +def get_threshold(task_name="ripeness", headers=None): + """שולף את אחוז הסף מהטבלה task_thresholds לפי שם המשימה.""" + url = f"{DB_API_BASE}/api/tables/task_thresholds" + r = requests.get(url, headers=headers, timeout=15) + r.raise_for_status() + rows = r.json().get("rows", []) + if not rows: + print(f"[WARN] No thresholds found at all, using default 0.8") + return 0.8 + + match = next((row for row in rows if row.get("task") == task_name), None) + if not match: + print(f"[WARN] No threshold found for task={task_name}, using default 0.8") + return 0.8 + + threshold = float(match.get("threshold", 0.8)) + print(f"[INFO] Task '{task_name}' threshold: {threshold*100:.1f}%") + return threshold +from datetime import datetime, timezone + +def get_rollups(window_start, window_end, headers=None): + """ + שולפת את כל הרשומות מהטבלה ripeness_weekly_rollups_ts + ואז מסננת לפי טווח התאריכים (window_start → window_end) בפייתון. + """ + url = f"{DB_API_BASE}/api/tables/ripeness_weekly_rollups_ts" + print(f"[DEBUG] Fetching full table from {url}", flush=True) + + try: + # שולף את כל הנתונים (בלי פילטרים) + r = requests.get(url, headers=headers, timeout=60) + r.raise_for_status() + except requests.exceptions.HTTPError as e: + print(f"[ERROR] HTTP {r.status_code}: {r.text}", flush=True) + return [] + except Exception as e: + print(f"[ERROR] failed to fetch rollups: {e}", flush=True) + return [] + + data = r.json() + rows = data.get("rows", data) + + + def parse_ts(ts_str: str) -> datetime: + try: + return datetime.fromisoformat(ts_str.replace("Z", "+00:00")) + except Exception: + return datetime.min.replace(tzinfo=timezone.utc) + + filtered = [] + for row in rows: + ts = parse_ts(row.get("ts", "")) + if window_start <= ts <= window_end: + filtered.append(row) + + print(f"[INFO] Retrieved {len(filtered)} rollups after filtering (out of {len(rows)} total)") + return filtered + +def send_kafka_alert(producer, device_id, ratio, threshold): + alert = { + "alert_id": str(uuid.uuid4()), + "alert_type": "fruit_ripeness_high", + "device_id": device_id, + "started_at": iso(now_utc()), + "confidence": float(ratio), + "severity": 3, + "threshold": threshold, + "description": f"{ratio*100:.1f}% ripe/overripe fruits", # <── וגם את זה + } + + producer.send(ALERT_TOPIC, json.dumps(alert).encode("utf-8")) + producer.flush() + print(f"[ALERT] sent for {device_id}: {ratio*100:.1f}%") + +def main(): + token = get_service_token() + headers = {"Content-Type": "application/json"} + if token: + headers["X-Service-Token"] = token + + window_end = now_utc() + window_start = window_end - timedelta(hours=WINDOW_HOURS) + print(f"[INFO] Checking rollups {window_start} → {window_end}") + + threshold = get_threshold("ripeness", headers) + rows = get_rollups(window_start, window_end, headers) + if not rows: + print("[INFO] No data found.") + return + + producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) + + # iterate each device + for row in rows: + device_id = row.get("device_id") + pct = row.get("pct_ripe", 0.0) + if pct >= threshold: + send_kafka_alert(producer, device_id, pct, threshold) + else: + print(f"[INFO] {device_id}: below threshold {pct:.2f} < {threshold:.2f}") + + producer.close() + print("[DONE] process complete.") + +if __name__ == "__main__": + main() diff --git a/services/fruit_ripeness_alert/docker-compose.yml b/services/fruit_ripeness_alert/docker-compose.yml new file mode 100644 index 000000000..7a6966ee9 --- /dev/null +++ b/services/fruit_ripeness_alert/docker-compose.yml @@ -0,0 +1,23 @@ +services: + fruit_ripeness_alert: + build: . + container_name: fruit_ripeness_alert + environment: + - DB_API_BASE=http://db_api_service:8001 + - DB_API_SERVICE_NAME=fruit_ripeness_alert + - DB_ADMIN_USER=admin + - DB_ADMIN_PASS=admin123 + - DB_API_TOKEN_FILE=/app/secret/db_api_token + - KAFKA_BROKER=kafka:9092 + - ALERT_TOPIC=alerts + - WINDOW_HOURS=168 + volumes: + - .:/app + - ./secret:/app/secret + command: ["sleep", "infinity"] + networks: + - ag_cloud + +networks: + ag_cloud: + external: true diff --git a/services/fruit_ripeness_alert/requirements.txt b/services/fruit_ripeness_alert/requirements.txt new file mode 100644 index 000000000..d9e23d1df --- /dev/null +++ b/services/fruit_ripeness_alert/requirements.txt @@ -0,0 +1,2 @@ +requests +kafka-python diff --git a/services/fruit_ripeness_alert/token_bootstrap.py b/services/fruit_ripeness_alert/token_bootstrap.py new file mode 100644 index 000000000..a1dee593b --- /dev/null +++ b/services/fruit_ripeness_alert/token_bootstrap.py @@ -0,0 +1,62 @@ +import os, pathlib, time, requests + +DB_API_BASE = os.getenv("DB_API_BASE", "").strip() +DB_API_TOKEN_FILE = os.getenv("DB_API_TOKEN_FILE", "/app/secret/db_api_token") +DB_API_SERVICE_NAME = os.getenv("DB_API_SERVICE_NAME", "fruit_ripeness_alert").strip() or "fruit_ripeness_alert" + +def _safe_join_url(base: str, path: str) -> str: + return f"{base.rstrip('/')}/{path.lstrip('/')}" + +def _read_token(path: str) -> str | None: + p = pathlib.Path(path) + if p.exists(): + t = p.read_text(encoding="utf-8").strip() + if t and "***" not in t: + return t + return None + +def _write_token(path: str, token: str) -> None: + p = pathlib.Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(token, encoding="utf-8") + +def _try_dev_bootstrap(): + """Try to get token using /auth/_dev_bootstrap (new API).""" + url = _safe_join_url(DB_API_BASE, "/auth/_dev_bootstrap") + payload = {"service_name": DB_API_SERVICE_NAME, "rotate_if_exists": True} + try: + r = requests.post(url, json=payload, timeout=10) + if r.status_code in (200, 201): + data = r.json() + sa = data.get("service_account") or {} + token = sa.get("raw_token") or sa.get("token") + if token and "***" not in token: + print("[BOOTSTRAP] obtained token via /auth/_dev_bootstrap") + return token.strip() + print(f"[BOOTSTRAP][WARN] _dev_bootstrap returned {r.status_code}: {r.text[:100]}") + except Exception as e: + print(f"[BOOTSTRAP][ERROR] {e}") + return None + +def get_service_token() -> str | None: + """Get or create a service token automatically.""" + if not DB_API_BASE: + print("[BOOTSTRAP][WARN] DB_API_BASE not set") + return None + + # Try existing file + token = _read_token(DB_API_TOKEN_FILE) + if token: + print(f"[BOOTSTRAP] using existing token from {DB_API_TOKEN_FILE}") + return token + + # Try bootstrap (new unified API) + print(f"[BOOTSTRAP] fetching new service token from {DB_API_BASE}") + token = _try_dev_bootstrap() + if token: + _write_token(DB_API_TOKEN_FILE, token) + print(f"[BOOTSTRAP] wrote token to {DB_API_TOKEN_FILE}") + return token + + print("[BOOTSTRAP][ERROR] Could not obtain service token.") + return None diff --git a/services/image-linker/Dockerfile.flink b/services/image-linker/Dockerfile.flink index 1ddec1b57..dd05fc18d 100644 --- a/services/image-linker/Dockerfile.flink +++ b/services/image-linker/Dockerfile.flink @@ -4,6 +4,14 @@ FROM flink:1.19.3-scala_2.12-java11 USER root +# Add local CA (place netfree-ca.crt next to this Dockerfile before building) +COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && update-ca-certificates + +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +ENV PIP_DISABLE_PIP_VERSION_CHECK=1 + # --- Install Python and tools --- RUN apt-get update && \ apt-get install -y python3 python3-pip python3-venv wget && \ diff --git a/services/image-linker/config/topics.yaml b/services/image-linker/config/topics.yaml index c9bb51cef..bbdd3588e 100644 --- a/services/image-linker/config/topics.yaml +++ b/services/image-linker/config/topics.yaml @@ -1,13 +1,26 @@ teams: air: - metadata_topic: dev-aerial-images-keys - minio_topic: image.new.aerial - output_topic: image.new.aerial.connections + metadata_topic: aerial_images_metadata + minio_topic: image.new.aerial + output_topic: image_new_aerial_connections + security: + metadata_topic: dev-security-images-keys + minio_topic: image.new.security + output_topic: image_new_security_connections + + sounds: + metadata_topic: sounds_metadata + minio_topic: sound.new.sounds + output_topic: sound_new_sounds_connections + plants: + metadata_topic: sounds_ultra_metadata + minio_topic: sound.new.plants + output_topic: sound_new_plants_connections # fruits: # metadata: dev-fruits-images-keys # minio: image.new.fruits # output: image.new.fruits.connections - \ No newline at end of file + diff --git a/services/image-linker/job_linker.py b/services/image-linker/job_linker.py index 66f83a15e..e1f66dd2c 100644 --- a/services/image-linker/job_linker.py +++ b/services/image-linker/job_linker.py @@ -133,7 +133,7 @@ def on_timer(self, timestamp, ctx): self.minio_state.clear() self.cleanup_ts_state.clear() print("[CLEANUP] Cleared stale state after 5 minutes") - yield from [] # 🔧 fix: must return iterable (even empty) + yield from [] # fix: must return iterable (even empty) # ---------- Main Function ---------- diff --git a/services/inference_http/Dockerfile b/services/inference_http/Dockerfile index 764a286ed..d5167d1ea 100644 --- a/services/inference_http/Dockerfile +++ b/services/inference_http/Dockerfile @@ -1,22 +1,84 @@ +# ============================================================ +# Unified Inference HTTP Dockerfile (Fruit + Camera + YOLO) +# ============================================================ FROM python:3.11-slim -ENV PIP_NO_CACHE_DIR=1 PIP_DEFAULT_TIMEOUT=1200 PIP_DISABLE_PIP_VERSION_CHECK=1 + + + +ENV PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + PIP_DEFAULT_TIMEOUT=1200 + + WORKDIR /app -RUN python -m pip install --upgrade pip setuptools wheel && \ - pip install --no-cache-dir numpy==1.26.4 --only-binary=:all: +RUN apt-get update && apt-get install -y --no-install-recommends \ + libglib2.0-0 \ + libglib2.0-dev \ + libsm6 \ + libxrender1 \ + libxext6 \ + libgl1 \ + libopenblas-dev \ + liblapack-dev \ + && rm -rf /var/lib/apt/lists/* + +RUN python -m pip install --upgrade pip setuptools wheel --only-binary=:all: && \ + pip install --no-cache-dir numpy==1.26.4 --only-binary=:all: && \ + pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \ + torch==2.1.2 torchvision==0.16.2 --only-binary=:all: --upgrade-strategy only-if-needed + + +ENV PIP_NO_CACHE_DIR=1 \ + PIP_DEFAULT_TIMEOUT=1200 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 +WORKDIR /app + +# Copy certs dir (may be empty) and trust *.crt if present +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && \ + rm -rf /var/lib/apt/lists/* && \ + if [ -d /tmp/certs ] && ls /tmp/certs/*.crt >/dev/null 2>&1; then \ + cp /tmp/certs/*.crt /usr/local/share/ca-certificates/ && \ + chmod 644 /usr/local/share/ca-certificates/*.crt && \ + update-ca-certificates; \ + else \ + echo "No extra CA certs found. Skipping CA update."; \ + fi + +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_CERT=/etc/ssl/certs/ca-certificates.crt +RUN printf "[global]\ncert = /etc/ssl/certs/ca-certificates.crt\n" > /etc/pip.conf + +# Python deps +RUN python -m pip install --upgrade pip setuptools wheel certifi && \ + pip install --no-cache-dir numpy==1.26.4 --only-binary=:all: RUN pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \ torch==2.1.2 torchvision==0.16.2 --only-binary=:all: --upgrade-strategy only-if-needed - COPY requirements.txt /app/requirements.txt -RUN pip install --no-cache-dir -r /app/requirements.txt --only-binary=:all: --upgrade-strategy only-if-needed + + +RUN pip install --no-cache-dir -r /app/requirements.txt && \ + pip install --no-cache-dir \ + opencv-python-headless \ + ultralytics==8.2.34 \ + boto3 \ + pillow \ + requests \ + "numpy<2" \ + && rm -rf /root/.cache/pip COPY app.py model_registry.py /app/ COPY adapters /app/adapters -COPY models /app/models +COPY models /app/models +COPY weights /app/weights +COPY models/soil_moisture/artifacts /app/artifacts + +EXPOSE 8004 -EXPOSE 8000 -CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8004"] diff --git a/services/inference_http/adapters/fruit_segmentation_runner.py b/services/inference_http/adapters/fruit_segmentation_runner.py new file mode 100644 index 000000000..7df623a8d --- /dev/null +++ b/services/inference_http/adapters/fruit_segmentation_runner.py @@ -0,0 +1,92 @@ +import os, io, tempfile, hashlib, cv2, numpy as np, boto3, torch + +def allow_unrestricted_torch_load(): + _original_load = torch.load + def patched_load(*args, **kwargs): + kwargs["weights_only"] = False + return _original_load(*args, **kwargs) + torch.load = patched_load + +allow_unrestricted_torch_load() +# === End Patch === + +import time +from typing import Any, Dict, Optional +from datetime import datetime, timezone +from ultralytics import YOLO + +def sha256_hex(path: str) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + return h.hexdigest() + +class FruitSegmentationRunner: + def __init__(self, weights_path: Optional[str] = None, model_tag: Optional[str] = None): + self.weights_path = weights_path or os.getenv("WEIGHTS_PATH", "/app/weights/yolov8-fruits.pt") + self.model = YOLO(self.weights_path) + raw_endpoint = os.getenv("MINIO_ENDPOINT", "minio-hot:9000").strip() + if not raw_endpoint.startswith(("http://", "https://")): + endpoint = f"http://{raw_endpoint}" + else: + endpoint = raw_endpoint + self.s3 = boto3.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"), + aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin123") + ) + def run(self, image_bytes: bytes | None = None, model_tag=None, extra=None) -> Dict[str, Any]: + """Main inference entrypoint for HTTP""" + bucket_in = extra.get("bucket") if extra else "imagery" + key = extra.get("key") if extra else None + if not key: + return {"error": "missing key"} + + + if image_bytes: + img_array = np.frombuffer(image_bytes, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + if img is None: + return {"error": "failed to decode image from bytes"} + else: + with tempfile.TemporaryDirectory() as tmpdir: + local_path = os.path.join(tmpdir, os.path.basename(key)) + self.s3.download_file(bucket_in, key, local_path) + img = cv2.imread(local_path) + if img is None: + return {"error": "failed to read image"} + + t0 = time.time() + results = self.model.predict(img, conf=0.3, iou=0.45, verbose=False) + latency_ms = int((time.time() - t0) * 1000) + boxes = results[0].boxes + count = 0 + + if boxes: + with tempfile.TemporaryDirectory() as tmpdir: + for i, box in enumerate(boxes): + label = results[0].names[int(box.cls[0])] + if label.lower() not in [ + "apple", "banana", "orange", "pear", "peach", "plum", + "mango", "grape", "cherry", "pomegranate" + ]: + continue + x1, y1, x2, y2 = map(int, box.xyxy[0]) + crop = img[y1:y2, x1:x2] + if crop.size == 0: + continue + out_name = f"{os.path.splitext(os.path.basename(key))[0]}_fruit_{i+1}.jpg" + out_key = f"segments/{out_name}" + out_path = os.path.join(tmpdir, out_name) + cv2.imwrite(out_path, crop) + self.s3.upload_file(out_path, bucket_in, out_key) + count += 1 + + return { + "label": "fruit", + "count": count, + "latency_ms_model": latency_ms, + "bucket_out": bucket_in + } diff --git a/services/inference_http/adapters/soil_moisture_runner.py b/services/inference_http/adapters/soil_moisture_runner.py new file mode 100644 index 000000000..7c1624c70 --- /dev/null +++ b/services/inference_http/adapters/soil_moisture_runner.py @@ -0,0 +1,158 @@ + +""" +Adapter for soil moisture inference in the generic HTTP inference flow. +Uses the shared inference logic from the soil-moisture service. +""" + +import os +import base64 +import logging +import sys +from typing import Any, Dict, Optional +from PIL import Image +from io import BytesIO +import numpy as np +import cv2 +import time +import re + +logger = logging.getLogger(__name__) + + +class SoilMoistureRunner: + """ + Adapter that wraps the soil moisture inference logic. + """ + + def __init__(self, weights_path: Optional[str] = None, model_tag: Optional[str] = None): + self.model_tag = model_tag + self.weights_path = weights_path + + try: + # Add models directory to path + models_dir = os.path.join(os.path.dirname(__file__), '..', 'models') + if models_dir not in sys.path: + sys.path.insert(0, models_dir) + + # Import soil moisture components + from soil_moisture.src.app.config import Settings, load_zones + from soil_moisture.src.app.inference import Inferencer + from soil_moisture.src.app.db import DB + from soil_moisture.src.app.inference_logic import SoilMoistureInferenceLogic + + logger.info("Initializing SoilMoistureRunner...") + + # Initialize components + self.settings = Settings() + + # Load zones config if available + if hasattr(self.settings, 'zones_file') and self.settings.zones_file: + if os.path.exists(self.settings.zones_file): + self.zones_cfg = load_zones(self.settings.zones_file) + else: + logger.warning(f"zones_file not found: {self.settings.zones_file}") + self.zones_cfg = {} + else: + self.zones_cfg = {} + + self.db = DB(self.settings.pg_dsn) + self.inferencer = Inferencer(self.settings, self.db) + + # Initialize Kafka producer (optional) + producer = None + try: + from soil_moisture.src.app.kafka_producer import ControlProducer + producer = ControlProducer( + self.settings.kafka_brokers, + self.settings.kafka_topic, + self.settings.kafka_dlt + ) + except Exception as e: + logger.warning(f"Kafka producer init failed: {e}") + + # Initialize shared inference logic + self.inference_logic = SoilMoistureInferenceLogic( + settings=self.settings, + db=self.db, + inferencer=self.inferencer, + producer=producer + ) + + logger.info("SoilMoistureRunner initialized successfully!") + + except Exception as e: + logger.error(f"Failed to initialize SoilMoistureRunner: {e}", exc_info=True) + raise + + def run(self, image_bytes: Any, model_tag: Optional[str] = None, + extra: Optional[Dict] = None) -> Dict: + """ + Run soil moisture inference using the shared inference logic. + """ + start_time = time.time() + + try: + bucket_in = extra.get("bucket") if extra else "imagery" + key = extra.get("key") if extra else None + if not key: + return {"error": "missing key"} + + # --- Extract device_id from the key (pattern: path/to/image/dev-id_ts.jpg) --- + def extract_device_id_from_key(key: str) -> str: + filename = key.split("/")[-1] # get "dev-id_ts.jpg" + match = re.match(r"([^_]+)_", filename) # capture part before "_" + if match: + return match.group(1) + return "unknown" + + # --- Decode image --- + img_array = np.frombuffer(image_bytes, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + if img is None: + return {"error": "failed to decode image from bytes"} + + # --- Determine device_id --- + device_id = "unknown" + if extra: + if "device_id" in extra: + device_id = extra["device_id"] + elif "filename" in extra: + device_id = self.inference_logic.extract_device_id(extra["filename"]) + elif "key" in extra: + device_id = extract_device_id_from_key(extra["key"]) + + # --- Convert input to PIL Image --- + if isinstance(image_bytes, bytes): + img = Image.open(BytesIO(image_bytes)) + elif isinstance(image_bytes, str): + img_bytes = base64.b64decode(image_bytes) + img = Image.open(BytesIO(img_bytes)) + elif isinstance(image_bytes, Image.Image): + img = image_bytes + else: + raise ValueError(f"Unsupported input type: {type(image_bytes)}") + + # --- Run inference --- + result = self.inference_logic.infer_from_image(img, device_id) + + return { + "device_id": result["device_id"], + "dry_ratio": result["dry_ratio"], + "decision": result["decision"], + "confidence": result["confidence"], + "patch_count": result["patch_count"], + "duration_min": result.get("duration_min", 0), + "latency_ms_model": result.get("latency_ms", 0), + "ts": result.get("ts"), + "idempotency_key": result.get("idempotency_key"), + "debug": result.get("debug") + } + + except Exception as e: + logger.error(f"Inference failed: {e}", exc_info=True) + latency_ms = int((time.time() - start_time) * 1000) + return { + "error": str(e), + "device_id": locals().get("device_id", "unknown"), + "latency_ms_model": latency_ms + } \ No newline at end of file diff --git a/services/inference_http/app.py b/services/inference_http/app.py index 3a490493d..ec6b984e9 100644 --- a/services/inference_http/app.py +++ b/services/inference_http/app.py @@ -60,12 +60,14 @@ def infer_json( obj.close() obj.release_conn() + # Attempt to run the model with bytes input first # Attempt to run the model with bytes input first try: - result = runner.run(image_bytes) + result = runner.run(image_bytes, extra={"bucket": req.bucket, "key": req.key}) except TypeError: # If the function does not accept bytes, try with URI instead - result = runner.run(s3_uri) + result = runner.run(s3_uri, extra={"bucket": req.bucket, "key": req.key}) + latency_ms = int((time.perf_counter() - started) * 1000) return { diff --git a/services/inference_http/model_registry.py b/services/inference_http/model_registry.py index 3c5c0d6df..f9ab4758e 100644 --- a/services/inference_http/model_registry.py +++ b/services/inference_http/model_registry.py @@ -1,15 +1,15 @@ -from typing import Any, Dict from adapters.fruit_defect_runner import FruitDefectRunner +from adapters.fruit_segmentation_runner import FruitSegmentationRunner +from adapters.soil_moisture_runner import SoilMoistureRunner -class FruitRunner: - def __init__(self): - self.impl = FruitDefectRunner() - - def run(self, image_bytes: bytes, model_tag=None, extra=None) -> Dict[str, Any]: - return self.impl.run(image_bytes, model_tag=model_tag, extra=extra) def get_model_runner(team: str): t = (team or "").lower() - if t == "fruit": - return FruitRunner() + if t == "fruit_defect": + return FruitDefectRunner() + if t == "camera": + return FruitSegmentationRunner() + if t == "soil_moisture": + return SoilMoistureRunner() raise ValueError(f"unknown TEAM {t}") + \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/.gitignore b/services/inference_http/models/soil_moisture/.gitignore new file mode 100644 index 000000000..62d87daf0 --- /dev/null +++ b/services/inference_http/models/soil_moisture/.gitignore @@ -0,0 +1 @@ +samples/ \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/Dockerfile b/services/inference_http/models/soil_moisture/Dockerfile new file mode 100644 index 000000000..9b1bb644d --- /dev/null +++ b/services/inference_http/models/soil_moisture/Dockerfile @@ -0,0 +1,44 @@ + +FROM python:3.10-slim + +WORKDIR /app + +# --- 1) installing ca-certificates --- +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl \ + && rm -rf /var/lib/apt/lists/* + +# --- 2) copying NetFree certificate and adding it to the system --- +COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +RUN chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt && \ + update-ca-certificates + +# Setting to ensure the updated certificate is used +ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt +ENV PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# --- 3) System dependencies required for FastAPI etc. --- +RUN apt-get update && apt-get install -y --no-install-recommends \ + libglib2.0-0 libsm6 libxrender1 libxext6 \ + && rm -rf /var/lib/apt/lists/* + +# --- 4) Installing dependencies --- +COPY requirements-api.txt . +# RUN pip install --trusted-host pypi.org --trusted-host pypi.python.org \ +# --trusted-host files.pythonhosted.org --no-cache-dir -r requirements-api.txt + +RUN pip config set global.require-hashes false && \ + pip install --trusted-host pypi.org --trusted-host pypi.python.org \ + --trusted-host files.pythonhosted.org --no-cache-dir -r requirements-api.txt +# --- 5) Copying code --- +COPY src ./src +COPY configs ./configs +COPY artifacts ./artifacts +COPY src/sql/init_db.sql /initdb/init_db.sql + +ENV PYTHONPATH=/app +ENV SCHEDULE_UPDATE=1 +RUN pip install python-multipart + +CMD ["uvicorn", "src.app.service:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/README.md b/services/inference_http/models/soil_moisture/README.md new file mode 100644 index 000000000..0aacb3f61 --- /dev/null +++ b/services/inference_http/models/soil_moisture/README.md @@ -0,0 +1,119 @@ +# Soil Moisture DL Pipeline – Real-Time Irrigation Control (ONNX Inference) + +This repository delivers an end-to-end **deep learning** pipeline to detect soil moisture state +(**wet / dry**) from ground-level RGB images and trigger **real-time irrigation** actions. + +## Highlights +- **Training (PyTorch)**: MobileNetV3-small (transfer learning) + augmentations. +- **Export** to **ONNX** for light-weight **CPU/Jetson** inference. +- **Inference Service (FastAPI)**: + - Tiling into patches + - Per-patch ONNX inference + - Zone policy with hysteresis (dry_ratio_high / dry_ratio_low / min_patches) + - **Kafka** publish to `irrigation.control` (idempotent) + DLQ + - **Postgres** persistence in `soil_moisture_events` (+ optional schedule UPSERT + audit) + - **Prometheus** metrics + health/ready endpoints + +--- + +## Run + +```bash +docker compose up -d api +``` + +The API will be available at: [http://localhost:8000](http://localhost:8000) + +--- + +## Endpoints + +| Method | Path | Description | +|--------|------|-------------| +| `GET` | `/health` | Basic health check | +| `GET` | `/ready` | Checks DB connectivity | +| `GET` | `/metrics` | Prometheus metrics | +| `POST` | `/infer` | Run inference on uploaded image | + +### Example request + +```bash +curl -X POST "http://localhost:8000/infer" -F "zone_id=zone1" -F "image=@sample.jpg" +``` + +Response: +```json +{ + "device_id": "zone1", + "dry_ratio": 0.42, + "decision": "stop", + "confidence": 0.87, + "patch_count": 48, + "ts": "2025-10-29T09:41:00Z", + "idempotency_key": "zone1:345621" +} +``` + +--- + +## Environment Variables + +| Name | Description | Example | +|------|--------------|----------| +| `PG_DSN` | Postgres connection string | `postgresql://user:pass@host.docker.internal:5432/missions_db` | +| `KAFKA_BROKERS` | Kafka brokers | `kafka:9092` | +| `KAFKA_TOPIC` | Kafka topic for irrigation control | `irrigation.control` | +| `KAFKA_DLT` | Kafka DLQ topic | `irrigation.control.dlq` | +| `ZONES_FILE` | Path to zone configuration | `/app/configs/zones.yaml` | +| `SCHEDULE_UPDATE` | Enables schedule table update | `1` | +| `DECISION_WINDOW_SEC` | Time window for decision hysteresis | `3` | + +--- + +## Notes +- The service depends on Postgres and Kafka within the `ag_cloud` Docker network. +- If Kafka is unreachable, messages are logged but not published. +- Duplicate inferences are prevented using an idempotency key per decision window. +- Metrics exposed for Prometheus under `/metrics`. + +--- + +## Example Compose Context + +```yaml +services: + api: + build: + context: . + dockerfile: Dockerfile + environment: + PG_DSN: postgresql://missions_user:pg123@host.docker.internal:5432/missions_db + KAFKA_BROKERS: kafka:9092 + KAFKA_TOPIC: irrigation.control + KAFKA_DLT: irrigation.control.dlq + ZONES_FILE: /app/configs/zones.yaml + DECISION_WINDOW_SEC: 3 + PATCH_SIZE: 256 + PATCH_STRIDE: 256 + SCHEDULE_UPDATE: 1 + volumes: + - ./configs:/app/configs + - ./artifacts:/app/artifacts + ports: + - "8000:8000" + networks: + - ag_cloud +``` + +--- + +## Testing + +```bash +pytest -v +``` + +--- + +## License +Internal AgCloud component – for research and development use only. diff --git a/services/inference_http/models/soil_moisture/artifacts/best.pt b/services/inference_http/models/soil_moisture/artifacts/best.pt new file mode 100644 index 000000000..4fbed017d Binary files /dev/null and b/services/inference_http/models/soil_moisture/artifacts/best.pt differ diff --git a/services/inference_http/models/soil_moisture/artifacts/label_mapping.json b/services/inference_http/models/soil_moisture/artifacts/label_mapping.json new file mode 100644 index 000000000..7b688603a --- /dev/null +++ b/services/inference_http/models/soil_moisture/artifacts/label_mapping.json @@ -0,0 +1,4 @@ +{ + "0": "dry", + "1": "wet" +} \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/artifacts/model.onnx b/services/inference_http/models/soil_moisture/artifacts/model.onnx new file mode 100644 index 000000000..6052e846e Binary files /dev/null and b/services/inference_http/models/soil_moisture/artifacts/model.onnx differ diff --git a/services/inference_http/models/soil_moisture/configs/zones.yaml b/services/inference_http/models/soil_moisture/configs/zones.yaml new file mode 100644 index 000000000..fa76afbb7 --- /dev/null +++ b/services/inference_http/models/soil_moisture/configs/zones.yaml @@ -0,0 +1,12 @@ +zones: + ZONE_A: + dry_ratio_high: 0.35 + dry_ratio_low: 0.25 + min_patches: 2 + duration_min: 10 + + ZONE_B: + dry_ratio_high: 0.40 + dry_ratio_low: 0.30 + min_patches: 2 + duration_min: 12 \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/docker-compose.yml b/services/inference_http/models/soil_moisture/docker-compose.yml new file mode 100644 index 000000000..c23e6b3e3 --- /dev/null +++ b/services/inference_http/models/soil_moisture/docker-compose.yml @@ -0,0 +1,30 @@ +networks: + worktree-main_ag_cloud: + external: true + +services: + api: + build: + context: . + dockerfile: Dockerfile + container_name: soil_api + environment: + PG_DSN: postgresql://missions_user:pg123@host.docker.internal:5432/missions_db + KAFKA_BROKERS: kafka:9092 # host.docker.internal:29092 + KAFKA_TOPIC: irrigation.control + KAFKA_DLT: irrigation.control.dlq + ZONES_FILE: /app/configs/zones.yaml + DECISION_WINDOW_SEC: 3 + PATCH_SIZE: 256 + PATCH_STRIDE: 256 + SCHEDULE_UPDATE: 1 + volumes: + - ./configs:/app/configs + - ./artifacts:/app/artifacts + ports: + - "8000:8000" + + networks: + - worktree-main_ag_cloud + + diff --git a/services/inference_http/models/soil_moisture/requirements-api.txt b/services/inference_http/models/soil_moisture/requirements-api.txt new file mode 100644 index 000000000..6e3ae713a --- /dev/null +++ b/services/inference_http/models/soil_moisture/requirements-api.txt @@ -0,0 +1,14 @@ +fastapi==0.114.2 +uvicorn==0.30.6 +onnxruntime==1.20.0 +numpy==2.1.1 +Pillow==10.4.0 +opencv-python==4.10.0.84 +kafka-python==2.0.2 +psycopg2-binary==2.9.10 +prometheus_client==0.21.0 +PyYAML==6.0.2 +python-dotenv==1.0.1 +requests==2.32.3 +python-multipart==0.0.6 +confluent_kafka==2.12.0 \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/requirements-train.txt b/services/inference_http/models/soil_moisture/requirements-train.txt new file mode 100644 index 000000000..21a78fde4 --- /dev/null +++ b/services/inference_http/models/soil_moisture/requirements-train.txt @@ -0,0 +1,9 @@ +torch>=2.2.0 +torchvision==0.23.0 +numpy==2.1.1 +Pillow==10.4.0 +opencv-python==4.10.0.84 +scikit-learn==1.5.2 +tqdm==4.66.5 +PyYAML==6.0.2 +onnx==1.19.0 diff --git a/services/inference_http/models/soil_moisture/src/.dockerignore b/services/inference_http/models/soil_moisture/src/.dockerignore new file mode 100644 index 000000000..f5bddaa38 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/.dockerignore @@ -0,0 +1,2 @@ +models/soil_moisture/samples/ +models/soil_moisture/tests/ \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/src/app/__init__.py b/services/inference_http/models/soil_moisture/src/app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/inference_http/models/soil_moisture/src/app/config.py b/services/inference_http/models/soil_moisture/src/app/config.py new file mode 100644 index 000000000..719e03fa3 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/config.py @@ -0,0 +1,21 @@ +import os +import yaml +from dataclasses import dataclass +from typing import Dict, Any +from dotenv import load_dotenv +load_dotenv() + +@dataclass +class Settings: + kafka_brokers: str = os.getenv("KAFKA_BROKERS", "localhost:9092") + kafka_topic: str = os.getenv("KAFKA_TOPIC", "irrigation.control") + kafka_dlt: str = os.getenv("KAFKA_DLT", "irrigation.control.dlq") + pg_dsn: str = os.getenv("PG_DSN", "postgresql://postgres:postgres@localhost:5432/soil") + zones_file: str = os.getenv("ZONES_FILE", "configs/zones.yaml") + decision_window_sec: int = int(os.getenv("DECISION_WINDOW_SEC", "1")) + patch_size: int = int(os.getenv("PATCH_SIZE", "256")) + patch_stride: int = int(os.getenv("PATCH_STRIDE", "256")) + +def load_zones(path: str) -> Dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) diff --git a/services/inference_http/models/soil_moisture/src/app/db.py b/services/inference_http/models/soil_moisture/src/app/db.py new file mode 100644 index 000000000..356f523b2 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/db.py @@ -0,0 +1,134 @@ + +import json +from typing import Optional, Dict, Any +import psycopg2 +import psycopg2.extras +from contextlib import contextmanager + +class DB: + def __init__(self, dsn: str): + self.dsn = dsn + + @contextmanager + def conn(self): + conn = psycopg2.connect(self.dsn) + try: + yield conn + finally: + conn.close() + + def init_ok(self) -> bool: + try: + with self.conn() as c: + with c.cursor() as cur: + cur.execute("SELECT 1") + return True + except Exception: + return False + + def log_event(self, device_id: str, ts_iso: str, dry_ratio: float, + decision: str, confidence: float, patch_count: int, + idem_key: str, extra: Optional[Dict[str, Any]]=None) -> bool: + q = ''' + INSERT INTO soil_moisture_events + (device_id, ts, dry_ratio, decision, confidence, patch_count, idempotency_key, extra) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (idempotency_key) DO NOTHING + ''' + with self.conn() as c: + with c.cursor() as cur: + cur.execute(q, ( + device_id, + ts_iso, + dry_ratio, + decision, + confidence, + patch_count, + idem_key, + json.dumps(extra or {}) + )) + c.commit() + return cur.rowcount > 0 + + def load_device_policy(self, device_id: str) -> dict: + try: + with self.conn() as c: + with c.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(""" + SELECT prev_state, dry_ratio_high, dry_ratio_low, + min_patches, duration_min + FROM irrigation_policies + WHERE device_id = %s + """, (device_id,)) + row = cur.fetchone() + if not row: + print(f"No row found for device_id={device_id}") + raise ValueError("not found") + print(f"Loaded from DB: {dict(row)}") + return dict(row) + except Exception as e: + print(f"Falling back to defaults because: {e}") + # fallback defaults + return { + "prev_state": "stop", + "dry_ratio_high": 0.35, + "dry_ratio_low": 0.25, + "min_patches": 2, + "duration_min": 10 + } + + + def upsert_schedule(self, device_id: str, next_run_at: str, duration_min: int, + updated_by: str, update_reason: str) -> None: + with self.conn() as c: + with c.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute("SELECT next_run_at, duration_min FROM irrigation_schedule WHERE device_id=%s", (device_id,)) + prev = cur.fetchone() + cur.execute(''' + INSERT INTO irrigation_schedule(device_id, next_run_at, duration_min, updated_by, update_reason) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (device_id) DO UPDATE SET + next_run_at=EXCLUDED.next_run_at, + duration_min=EXCLUDED.duration_min, + updated_by=EXCLUDED.updated_by, + update_reason=EXCLUDED.update_reason, + updated_at=NOW() + ''', (device_id, next_run_at, duration_min, updated_by, update_reason)) + cur.execute(''' + INSERT INTO irrigation_schedule_audit(device_id, prev_next_run_at, prev_duration_min, + next_run_at, duration_min, updated_by, update_reason) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ''', (device_id, + prev["next_run_at"] if prev else None, + prev["duration_min"] if prev else None, + next_run_at, duration_min, updated_by, update_reason)) + c.commit() + + def update_prev_state(self, device_id: str, new_state: str) -> None: + # default values for other new fields + default_policy = { + "dry_ratio_high": 0.35, + "dry_ratio_low": 0.25, + "min_patches": 2, + "duration_min": 10 + } + + with self.conn() as c: + with c.cursor() as cur: + cur.execute(""" + INSERT INTO irrigation_policies + (device_id, prev_state, dry_ratio_high, dry_ratio_low, min_patches, duration_min) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (device_id) DO UPDATE + SET prev_state = EXCLUDED.prev_state, + updated_at = NOW() + """, ( + device_id, + new_state, + default_policy["dry_ratio_high"], + default_policy["dry_ratio_low"], + default_policy["min_patches"], + default_policy["duration_min"] + )) + c.commit() + diff --git a/services/inference_http/models/soil_moisture/src/app/inference.py b/services/inference_http/models/soil_moisture/src/app/inference.py new file mode 100644 index 000000000..95c4d948c --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/inference.py @@ -0,0 +1,106 @@ +from typing import Dict, Any, Tuple +from PIL import Image +import numpy as np +import os, time, logging +from .config import Settings +from .utils import normalize_lighting, tile_image, preprocess_onnx +from .metrics import METRICS +from .onnx_model import ONNXMoistureModel +from .db import DB + +logger = logging.getLogger("soil_api") + +DRY_LABEL = "dry" + +class Inferencer: + def __init__(self, settings: Settings, db: DB, + model_path: str = "artifacts/model.onnx", + label_map_path: str = "artifacts/label_mapping.json"): + self.settings = settings + self.db = db + self.model = ONNXMoistureModel(model_path, label_map_path) + self.classes = [self.model.label_map[str(i)] for i in range(len(self.model.label_map))] + + def decision_window_bucket(self, ts: float) -> int: + return int(ts // self.settings.decision_window_sec) * self.settings.decision_window_sec + + def infer_image(self, img: Image.Image, device_id: str) -> Tuple[Dict[str, Any], Dict[str, Any]]: + + t0 = time.time() + img_n = normalize_lighting(img) + patches = tile_image(img_n, self.settings.patch_size, self.settings.patch_stride) + + dry_votes = 0 + probs = [] + + logger.info("infer_image start device_id=%s patch_size=%d stride=%d total_patches=%d",device_id, self.settings.patch_size, self.settings.patch_stride, len(patches)) + + try: + dry_idx = self.classes.index(DRY_LABEL) + except ValueError: + logger.warning("DRY_LABEL '%s' not found in classes %s", DRY_LABEL, self.classes) + dry_idx = None + + for idx, p in enumerate(patches): + try: + proba = self.model.predict_proba_patch(p) + except Exception as e: + logger.exception("model.predict_proba_patch failed for patch idx=%d: %s", idx, e) + proba = np.zeros(len(self.classes), dtype=float) + + probs.append(proba) + arg = int(proba.argmax()) if proba.size else -1 + arg_label = self.classes[arg] if (arg >= 0 and arg < len(self.classes)) else "unknown" + maxp = float(proba.max()) if proba.size else 0.0 + logger.debug("patch idx=%d arg=%d label=%s maxp=%.4f", idx, arg, arg_label, maxp) + + if dry_idx is not None and arg == dry_idx: + dry_votes += 1 + + mean_confidence = float(np.mean([max(x) for x in probs])) if probs else 0.0 + + patch_count = len(patches) + dry_ratio = dry_votes / max(1, patch_count) + + # Policy / hysteresis + policy = self.db.load_device_policy(device_id) + prev_state = policy.get("prev_state") or "stop" + high = policy.get("dry_ratio_high") or 0.35 + low = policy.get("dry_ratio_low") or 0.25 + min_patches = policy.get("min_patches") or 2 + duration_min = policy.get("duration_min") or 10 + + logger.info("decision inputs prev_state=%s dry_votes=%d patch_count=%d dry_ratio=%.4f high=%.4f low=%.4f min_patches=%d", + prev_state, dry_votes, patch_count, dry_ratio, high, low, min_patches) + + decision = "noop" + if patch_count >= min_patches: + if prev_state != "run" and dry_ratio >= high: + decision = "run" + elif prev_state != "stop" and dry_ratio <= low: + decision = "stop" + else: + logger.debug("hysteresis conditions not met (prev_state=%s)", prev_state) + else: + logger.debug("not enough patches for decision: patch_count=%d min_patches=%d", patch_count, min_patches) + + new_state = decision if decision in ("run", "stop") else prev_state + logger.info("decision result=%s updated_state=%s duration_min=%d confidence=%.4f", + decision, new_state, duration_min, mean_confidence) + + METRICS["inference_latency_ms"].observe((time.time() - t0) * 1000.0) + + result = { + "dry_ratio": float(dry_ratio), + "decision": decision, + "confidence": float(mean_confidence), + "patch_count": int(patch_count), + "duration_min": duration_min + } + debug = { + "probs_shape": (len(probs), len(probs[0]) if probs else 0) + } + if new_state != prev_state: + self.db.update_prev_state(device_id, new_state) + + return result, debug diff --git a/services/inference_http/models/soil_moisture/src/app/inference_logic.py b/services/inference_http/models/soil_moisture/src/app/inference_logic.py new file mode 100644 index 000000000..8da8bae44 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/inference_logic.py @@ -0,0 +1,171 @@ +""" +Shared inference logic that can be used by both the API and adapters. +""" +import logging +import time +import datetime as dt +from typing import Dict, Any, Tuple, Optional +from PIL import Image + +logger = logging.getLogger(__name__) + + +class SoilMoistureInferenceLogic: + """ + Encapsulates the inference logic for soil moisture detection. + This can be used by both the FastAPI service and the Flink adapter. + """ + + def __init__(self, settings, db, inferencer, producer=None): + self.settings = settings + self.db = db + self.inferencer = inferencer + self.producer = producer + + def extract_device_id(self, filename: str) -> str: + """ + Extract device ID from filename in the format device_ts + (For example dev-f_20251106T1030.jpg) + """ + import os + base = os.path.basename(filename) + device_id = base.split("_")[0] + if not device_id: + raise ValueError(f"Invalid filename format: {filename}") + return device_id + + def build_idem_key(self, device_id: str, ts_unix: float) -> str: + """Build idempotency key""" + bucket = self.inferencer.decision_window_bucket(ts_unix) + return f"{device_id}:{int(bucket)}" + + def publish_and_persist( + self, + device_id: str, + decision: str, + duration_min: int, + confidence: float, + dry_ratio: float, + patch_count: int, + idem: str, + ts_iso: str + ) -> bool: + """ + Publish to Kafka and persist to database. + Returns True if saved successfully, False if duplicate. + """ + import json + import os + + payload = { + "device_id": device_id, + "command": decision if decision in ("run", "stop") else "noop", + "reason": "soil_dry", + "duration_min": duration_min if decision == "run" else None, + "confidence": confidence, + "ts": ts_iso, + "idempotency_key": idem + } + + saved = self.db.log_event( + device_id, ts_iso, dry_ratio, payload["command"], + confidence, patch_count, idem, + extra={"dry_ratio": dry_ratio} + ) + + if not saved: + logger.info(json.dumps({ + "msg": "duplicate_idempotency", + "device_id": device_id, + "idem": idem + })) + return False + + # Schedule update + schedule_update = os.getenv('SCHEDULE_UPDATE', '1') == '1' + if schedule_update and decision == 'run': + try: + self.db.upsert_schedule( + device_id, ts_iso, duration_min, + updated_by='soil_api', + update_reason='soil_dry' + ) + except Exception as e: + logger.warning('schedule update failed: %s', e) + + # Publish to Kafka + if self.producer: + self.producer.publish(payload) + else: + logger.warning( + "Kafka producer unavailable; skipping publish. payload=%s", + payload + ) + + return True + + def infer_from_image( + self, + img: Image.Image, + device_id: str, + ts_unix: Optional[float] = None + ) -> Dict[str, Any]: + """ + Run inference on an image and return results. + + Args: + img: PIL Image object + device_id: Device identifier + ts_unix: Unix timestamp (optional, defaults to current time) + + Returns: + Dictionary with inference results including: + - device_id + - dry_ratio + - decision + - confidence + - patch_count + - duration_min + - ts + - idempotency_key + - latency_ms + """ + if ts_unix is None: + ts_unix = time.time() + + start_time = time.time() + ts_iso = dt.datetime.utcfromtimestamp(ts_unix).isoformat() + "Z" + + # Run inference + result, debug = self.inferencer.infer_image(img, device_id) + + # Build idempotency key + idem = self.build_idem_key(device_id, ts_unix) + + # Publish and persist + saved = self.publish_and_persist( + device_id=device_id, + decision=result["decision"], + duration_min=result["duration_min"], + confidence=result["confidence"], + dry_ratio=result["dry_ratio"], + patch_count=result["patch_count"], + idem=idem, + ts_iso=ts_iso + ) + + latency_ms = int((time.time() - start_time) * 1000) + + return { + "device_id": device_id, + "dry_ratio": result["dry_ratio"], + "decision": result["decision"], + "confidence": result["confidence"], + "patch_count": result["patch_count"], + "duration_min": result.get("duration_min", 0), + "ts": ts_iso, + "idempotency_key": idem, + "latency_ms": latency_ms, + "saved": saved, + "debug": debug + } \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/src/app/kafka_producer.py b/services/inference_http/models/soil_moisture/src/app/kafka_producer.py new file mode 100644 index 000000000..b6052accd --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/kafka_producer.py @@ -0,0 +1,39 @@ +import json +import logging +from confluent_kafka import Producer, KafkaError +from .metrics import METRICS + +logger = logging.getLogger("soil_api") + +class ControlProducer: + def __init__(self, brokers: str, topic: str, dlt: str): + self.topic = topic + self.dlt = dlt + self.producer = Producer({"bootstrap.servers": brokers}) + + def publish(self, payload: dict) -> None: + try: + self.producer.produce( + self.topic, + value=json.dumps(payload).encode("utf-8"), + on_delivery=self._delivery_report + ) + self.producer.flush(2) + METRICS["alerts_sent_total"].labels(decision=payload.get("command", "unknown")).inc() + except Exception as e: + logger.warning("Kafka publish failed: %s", e) + METRICS["kafka_publish_errors_total"].labels(reason=type(e).__name__).inc() + # try send to DLT + try: + dlt_payload = dict(payload) + dlt_payload["error"] = str(e) + self.producer.produce(self.dlt, value=json.dumps(dlt_payload).encode("utf-8")) + self.producer.flush(2) + except Exception as e2: + logger.error("DLT publish failed: %s", e2) + + def _delivery_report(self, err, msg): + if err is not None: + logger.warning("Delivery failed for record %s: %s", msg.key(), err) + else: + logger.debug("Record delivered to %s [%d] @ %d", msg.topic(), msg.partition(), msg.offset()) diff --git a/services/inference_http/models/soil_moisture/src/app/metrics.py b/services/inference_http/models/soil_moisture/src/app/metrics.py new file mode 100644 index 000000000..ce38f34de --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/metrics.py @@ -0,0 +1,11 @@ +from prometheus_client import Counter, Histogram + +METRICS = { + "alerts_sent_total": Counter("alerts_sent_total", "Total alerts sent", ["decision"]), + "kafka_publish_errors_total": Counter("kafka_publish_errors_total", "Kafka publish errors", ["reason"]), + "inference_latency_ms": Histogram( + "inference_latency_ms", + "Inference latency (ms)", + buckets=(5,10,20,50,100,200,500,1000,2000) + ), +} diff --git a/services/inference_http/models/soil_moisture/src/app/onnx_model.py b/services/inference_http/models/soil_moisture/src/app/onnx_model.py new file mode 100644 index 000000000..c78db3f26 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/onnx_model.py @@ -0,0 +1,22 @@ +import json +import numpy as np +import onnxruntime as ort +from typing import List +from PIL import Image +from .utils import preprocess_onnx + +class ONNXMoistureModel: + def __init__(self, model_path: str, label_map_path: str): + self.sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) + with open(label_map_path, "r", encoding="utf-8") as f: + self.label_map = json.load(f) # index -> label + self.input_name = self.sess.get_inputs()[0].name + self.output_name = self.sess.get_outputs()[0].name + + def predict_proba_patch(self, patch: Image.Image): + x = preprocess_onnx(patch, size=224) + logits = self.sess.run([self.output_name], {self.input_name: x})[0] + # softmax on logits + e = np.exp(logits - logits.max(axis=1, keepdims=True)) + proba = e / e.sum(axis=1, keepdims=True) + return proba[0] diff --git a/services/inference_http/models/soil_moisture/src/app/schemas.py b/services/inference_http/models/soil_moisture/src/app/schemas.py new file mode 100644 index 000000000..336e7c323 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/schemas.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel + +class InferRequest(BaseModel): + device_id: str + image_b64: str # base64-encoded RGB image + +class InferResponse(BaseModel): + device_id: str + dry_ratio: float + decision: str + confidence: float + patch_count: int + ts: str + idempotency_key: str diff --git a/services/inference_http/models/soil_moisture/src/app/service.py b/services/inference_http/models/soil_moisture/src/app/service.py new file mode 100644 index 000000000..36e09cc20 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/service.py @@ -0,0 +1,103 @@ +""" +FastAPI service for soil moisture inference. +Delegates business logic to inference_logic.py +""" +import base64 +import logging +from fastapi import FastAPI, UploadFile, File, HTTPException +from fastapi.responses import PlainTextResponse +from prometheus_client import generate_latest, CONTENT_TYPE_LATEST + +from .config import Settings, load_zones +from .schemas import InferRequest, InferResponse +from .inference import Inferencer +from .kafka_producer import ControlProducer +from .db import DB +from .metrics import METRICS +from .utils import load_image_from_b64 +from .inference_logic import SoilMoistureInferenceLogic + +logging.basicConfig(level=logging.DEBUG, format='%(message)s') +logger = logging.getLogger("soil_api") + +# Initialize components +settings = Settings() +zones_cfg = load_zones(settings.zones_file) +db = DB(settings.pg_dsn) +inferencer = Inferencer(settings, db) + +# Initialize Kafka producer +producer = None +try: + from kafka import KafkaProducer + producer = ControlProducer( + settings.kafka_brokers, + settings.kafka_topic, + settings.kafka_dlt + ) +except Exception as e: + import traceback + logger.warning("Kafka init failed: %s\n%s", e, traceback.format_exc()) + +# Initialize shared inference logic +inference_logic = SoilMoistureInferenceLogic( + settings=settings, + db=db, + inferencer=inferencer, + producer=producer +) + +app = FastAPI(title="Soil Moisture DL API", version="1.0.0") + + +@app.get("/health", response_class=PlainTextResponse) +def health(): + return "ok" + + +@app.get("/ready", response_class=PlainTextResponse) +def ready(): + if not db.init_ok(): + raise HTTPException(status_code=503, detail="DB not ready") + return "ready" + + +@app.get("/metrics") +def metrics(): + return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST) + + +@app.post("/infer", response_model=InferResponse) +async def infer(image: UploadFile = File(None), body: InferRequest | None = None): + """ + Run inference on a soil moisture image. + Accepts either multipart form data (file upload) or JSON with base64 image. + """ + # Parse input + if body is not None: + img = load_image_from_b64(body.image_b64) + device_id = inference_logic.extract_device_id(body.filename) + else: + if image is None: + raise HTTPException( + status_code=400, + detail="Provide multipart (file) or JSON (image_b64)" + ) + filename = image.filename + device_id = inference_logic.extract_device_id(filename) + content = await image.read() + img = load_image_from_b64(base64.b64encode(content).decode("utf-8")) + + # Run inference using shared logic + result = inference_logic.infer_from_image(img, device_id) + + # Return response + return InferResponse( + device_id=result["device_id"], + dry_ratio=result["dry_ratio"], + decision=result["decision"], + confidence=result["confidence"], + patch_count=result["patch_count"], + ts=result["ts"], + idempotency_key=result["idempotency_key"] + ) \ No newline at end of file diff --git a/services/inference_http/models/soil_moisture/src/app/utils.py b/services/inference_http/models/soil_moisture/src/app/utils.py new file mode 100644 index 000000000..2853c5250 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/app/utils.py @@ -0,0 +1,29 @@ +import base64, io +from PIL import Image, ImageOps +import numpy as np +from typing import List + +def load_image_from_b64(b64: str) -> Image.Image: + data = base64.b64decode(b64) + return Image.open(io.BytesIO(data)).convert("RGB") + +def normalize_lighting(img: Image.Image) -> Image.Image: + r, g, b = img.split() + r, g, b = ImageOps.equalize(r), ImageOps.equalize(g), ImageOps.equalize(b) + return Image.merge("RGB", (r, g, b)) + +def tile_image(img: Image.Image, patch_size: int, stride: int) -> List[Image.Image]: + w, h = img.size + patches = [] + for y in range(0, h - patch_size + 1, stride): + for x in range(0, w - patch_size + 1, stride): + patches.append(img.crop((x, y, x + patch_size, y + patch_size))) + if not patches: + patches.append(img.resize((patch_size, patch_size))) + return patches + +def preprocess_onnx(pil_img: Image.Image, size: int = 224) -> np.ndarray: + img = pil_img.resize((size, size)) + arr = np.asarray(img).astype("float32") / 255.0 + arr = arr.transpose(2,0,1) # HWC -> CHW + return arr[None, :, :, :] # NCHW diff --git a/services/inference_http/models/soil_moisture/src/scripts/consume_once.py b/services/inference_http/models/soil_moisture/src/scripts/consume_once.py new file mode 100644 index 000000000..b035f4017 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/scripts/consume_once.py @@ -0,0 +1,40 @@ +import argparse +import json +from kafka import KafkaConsumer, errors + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--brokers", default="localhost:29092", help="Broker list, e.g. localhost:29092 or kafka:9092") + parser.add_argument("--topic", default="irrigation.control") + parser.add_argument("--group", default="debug-consumer") + parser.add_argument("--from-beginning", action="store_true", help="Read from earliest offset") + args = parser.parse_args() + + print(f"Connecting to Kafka brokers: {args.brokers}") + try: + consumer = KafkaConsumer( + args.topic, + bootstrap_servers=args.brokers.split(","), + group_id=args.group, + enable_auto_commit=False, + auto_offset_reset="earliest" if args.from_beginning else "latest", + value_deserializer=lambda v: json.loads(v.decode("utf-8")), + consumer_timeout_ms=0 + ) + print(f"Listening on topic '{args.topic}' (group={args.group})...") + except errors.NoBrokersAvailable: + print("❌ Cannot connect to Kafka. Check host/port and Docker networking.") + return + + try: + for message in consumer: + print("\n--- New message ---") + print(json.dumps(message.value, indent=2)) + except KeyboardInterrupt: + print("\nStopped by user.") + finally: + consumer.close() + print("Consumer closed.") + +if __name__ == "__main__": + main() diff --git a/services/inference_http/models/soil_moisture/src/scripts/demo_feed.py b/services/inference_http/models/soil_moisture/src/scripts/demo_feed.py new file mode 100644 index 000000000..3be42924a --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/scripts/demo_feed.py @@ -0,0 +1,48 @@ +import argparse, os, base64, time, json, glob +import requests + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--images-dir", required=True) + ap.add_argument("--api", default="http://localhost:8000") + args = ap.parse_args() + + # Collect all images + imgs = [] + for ext in ("*.jpg", "*.jpeg", "*.png", "*.bmp"): + imgs += glob.glob(os.path.join(args.images_dir, '**', ext), recursive=True) + imgs = sorted(imgs) + + if not imgs: + print("No images found in", args.images_dir) + return + + for path in imgs: + filename = os.path.basename(path) + + # IMPORTANT: The device_id must be encoded inside the filename, + # e.g. device123_20250101T1030.jpg + print(f"Sending {filename} ...") + + try: + with open(path, "rb") as f: + files = {"image": (filename, f)} + r = requests.post( + args.api + "/infer", + files=files, + timeout=60 + ) + + if r.status_code != 200: + print("Error:", r.status_code, r.text) + else: + print(json.dumps(r.json(), indent=2)) + + except Exception as e: + print("Request failed:", e) + + time.sleep(0.4) + + +if __name__ == "__main__": + main() diff --git a/services/inference_http/models/soil_moisture/src/scripts/eval_test_set.py b/services/inference_http/models/soil_moisture/src/scripts/eval_test_set.py new file mode 100644 index 000000000..875834dea --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/scripts/eval_test_set.py @@ -0,0 +1,42 @@ +# src/scripts/eval_test_onnx.py +import json, numpy as np, onnxruntime as ort +from pathlib import Path +from torchvision import transforms +from torchvision.datasets import ImageFolder +from torch.utils.data import DataLoader + +def load_label_mapping(path="artifacts/label_mapping.json"): + with open(path,"r") as f: + return json.load(f) + +def preprocess_pil(img): + tf = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()]) + t = tf(img).numpy() + return np.expand_dims(t, axis=0).astype(np.float32) + +def run_onnx_eval(onnx_path="artifacts/model.onnx", test_dir="samples/test", batch_size=16): + label_map = load_label_mapping() + classes = [label_map[str(i)] for i in range(len(label_map))] + ds = ImageFolder(test_dir, transform=None) # we'll read PIL ourselves + loader = DataLoader(ds, batch_size=batch_size, shuffle=False) + + sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider']) + input_name = sess.get_inputs()[0].name + output_name = sess.get_outputs()[0].name + + y_true, y_pred = [], [] + from PIL import Image + for path, label in ds.samples: + img = Image.open(path).convert("RGB") + x = preprocess_pil(img) + logits = sess.run([output_name], {input_name: x})[0] + pred = int(np.argmax(logits, axis=1)[0]) + y_true.append(label) + y_pred.append(pred) + + from sklearn.metrics import classification_report, confusion_matrix + print(classification_report(y_true, y_pred, target_names=classes, digits=4)) + print(confusion_matrix(y_true, y_pred)) + +if __name__ == "__main__": + run_onnx_eval() diff --git a/services/inference_http/models/soil_moisture/src/scripts/print_db_events.py b/services/inference_http/models/soil_moisture/src/scripts/print_db_events.py new file mode 100644 index 000000000..07b41b875 --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/scripts/print_db_events.py @@ -0,0 +1,24 @@ +import argparse, psycopg2, psycopg2.extras, json + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--dsn", default="postgresql://missions_user:pg123@127.0.0.1:5432/missions_db") + ap.add_argument("--limit", type=int, default=10) + args = ap.parse_args() + + conn = psycopg2.connect(args.dsn) + try: + cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) + cur.execute(""" + SELECT id, device_id, ts, dry_ratio, decision, confidence, patch_count, idempotency_key + FROM soil_moisture_events + ORDER BY id DESC + LIMIT %s + """, (args.limit,)) + rows = cur.fetchall() + print(json.dumps(rows, indent=2, default=str)) + finally: + conn.close() + +if __name__ == "__main__": + main() diff --git a/services/inference_http/models/soil_moisture/src/sql/init_db.sql b/services/inference_http/models/soil_moisture/src/sql/init_db.sql new file mode 100644 index 000000000..1476755fc --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/sql/init_db.sql @@ -0,0 +1,47 @@ +CREATE TABLE IF NOT EXISTS soil_moisture_events ( + id SERIAL PRIMARY KEY, + device_id TEXT NOT NULL, + ts TIMESTAMPTZ NOT NULL DEFAULT NOW(), + dry_ratio REAL NOT NULL, + decision TEXT NOT NULL, + confidence REAL NOT NULL, + patch_count INT NOT NULL, + idempotency_key TEXT NOT NULL, + extra JSONB DEFAULT '{}'::jsonb +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_events_idem ON soil_moisture_events (idempotency_key); + +CREATE TABLE IF NOT EXISTS irrigation_schedule ( + device_id TEXT PRIMARY KEY, + next_run_at TIMESTAMPTZ NOT NULL, + duration_min INT NOT NULL, + updated_by TEXT NOT NULL, + update_reason TEXT NOT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS irrigation_schedule_audit ( + id SERIAL PRIMARY KEY, + device_id TEXT NOT NULL, + prev_next_run_at TIMESTAMPTZ, + prev_duration_min INT, + next_run_at TIMESTAMPTZ NOT NULL, + duration_min INT NOT NULL, + updated_by TEXT NOT NULL, + update_reason TEXT NOT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE TABLE irrigation_policies ( + device_id TEXT NOT NULL, + prev_state TEXT, + dry_ratio_high REAL, + dry_ratio_low REAL, + min_patches INT, + duration_min INT, + updated_at TIMESTAMP DEFAULT NOW(), + PRIMARY KEY (device_id), + CONSTRAINT fk_device + FOREIGN KEY (device_id) REFERENCES devices(device_id) + ON DELETE CASCADE +); diff --git a/services/inference_http/models/soil_moisture/src/train/train_torch.py b/services/inference_http/models/soil_moisture/src/train/train_torch.py new file mode 100644 index 000000000..db7c89d9a --- /dev/null +++ b/services/inference_http/models/soil_moisture/src/train/train_torch.py @@ -0,0 +1,110 @@ +import argparse, os, json +from pathlib import Path +import numpy as np +from PIL import Image +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms, models +from tqdm import tqdm + +def build_dataloaders(train_dir, val_dir, batch_size): + aug = transforms.Compose([ + transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), + transforms.ToTensor() + ]) + val_tf = transforms.Compose([ + transforms.Resize((224,224)), + transforms.ToTensor() + ]) + train_ds = datasets.ImageFolder(train_dir, transform=aug) + val_ds = datasets.ImageFolder(val_dir, transform=val_tf) + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) + return train_loader, val_loader, train_ds.classes + +@torch.no_grad() +def evaluate(model, loader, device): + model.eval() + correct, total = 0, 0 + for x, y in loader: + x, y = x.to(device), y.to(device) + logits = model(x) + pred = logits.argmax(1) + correct += (pred == y).sum().item() + total += y.numel() + return correct / max(1,total) + +def export_onnx(model, out_path, device): + model.eval() + dummy = torch.randn(1,3,224,224, device=device) + out_dir = os.path.dirname(out_path) + os.makedirs(out_dir, exist_ok=True) + torch.onnx.export( + model, dummy, out_path, + input_names=["input"], output_names=["logits"], + opset_version=17, dynamic_axes=None + ) + print("Exported ONNX to", out_path) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--train-dir", required=True) + ap.add_argument("--val-dir", required=True) + ap.add_argument("--epochs", type=int, default=15) + ap.add_argument("--batch-size", type=int, default=64) + ap.add_argument("--lr", type=float, default=3e-4) + ap.add_argument("--out", required=True) # ONNX output path + args = ap.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_loader, val_loader, classes = build_dataloaders(args.train_dir, args.val_dir, args.batch_size) + + # MobileNetV3-small transfer learning + model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT) + in_features = model.classifier[3].in_features + model.classifier[3] = nn.Linear(in_features, len(classes)) + model.to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.AdamW(model.parameters(), lr=args.lr) + + best_acc = 0.0 + best_pt = "artifacts/best.pt" + os.makedirs("artifacts", exist_ok=True) + + for epoch in range(1, args.epochs+1): + model.train() + pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}") + for x, y in pbar: + x, y = x.to(device), y.to(device) + logits = model(x) + loss = criterion(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + pbar.set_postfix(loss=float(loss.item())) + + acc = evaluate(model, val_loader, device) + print(f"Val acc: {acc:.4f}") + if acc > best_acc: + best_acc = acc + torch.save(model.state_dict(), best_pt) + + # Export ONNX from best weights + model.load_state_dict(torch.load(best_pt, map_location=device)) + export_onnx(model, args.out, device=device) + + # Save label mapping + lbl_path = os.path.join(os.path.dirname(args.out), "label_mapping.json") + label_mapping = {str(i): cls for i, cls in enumerate(classes)} + with open(lbl_path, "w", encoding="utf-8") as f: + json.dump(label_mapping, f, indent=2) + print("Saved label mapping:", lbl_path) + print("Best val acc:", best_acc) + +if __name__ == "__main__": + main() diff --git a/services/inference_http/models/soil_moisture/tests/conftest.py b/services/inference_http/models/soil_moisture/tests/conftest.py new file mode 100644 index 000000000..5cfe605bc --- /dev/null +++ b/services/inference_http/models/soil_moisture/tests/conftest.py @@ -0,0 +1,10 @@ +import os +import sys + + +# Ensure `app` package (under src/) is importable when running tests +TEST_DIR = os.path.dirname(__file__) +SRC_DIR = os.path.abspath(os.path.join(TEST_DIR, "..", "src")) +if SRC_DIR not in sys.path: + sys.path.insert(0, SRC_DIR) + diff --git a/services/inference_http/models/soil_moisture/tests/test_config_and_schemas.py b/services/inference_http/models/soil_moisture/tests/test_config_and_schemas.py new file mode 100644 index 000000000..0fb9abae9 --- /dev/null +++ b/services/inference_http/models/soil_moisture/tests/test_config_and_schemas.py @@ -0,0 +1,43 @@ +import os +from app.config import load_zones, Settings +from app.schemas import InferRequest, InferResponse + + +def test_load_zones_file_exists_and_parses(): + # Use the repo's zones.yaml + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + zones_path = os.path.join(base_dir, "configs", "zones.yaml") + data = load_zones(zones_path) + assert isinstance(data, dict) + assert "zones" in data + assert isinstance(data["zones"], dict) + + +def test_settings_defaults_are_present(): + s = Settings() + # Ensure critical fields exist and are strings/ints + assert isinstance(s.kafka_brokers, str) + assert isinstance(s.kafka_topic, str) + assert isinstance(s.pg_dsn, str) + assert isinstance(s.decision_window_sec, int) + assert isinstance(s.patch_size, int) + assert isinstance(s.patch_stride, int) + + +def test_schemas_models_construction(): + req = InferRequest(device_id="zone-a", image_b64="abcd==") + assert req.device_id == "zone-a" + assert isinstance(req.image_b64, str) + + resp = InferResponse( + device_id="zone-a", + dry_ratio=0.5, + decision="run", + confidence=0.9, + patch_count=4, + ts="2024-01-01T00:00:00Z", + idempotency_key="zone-a:12345", + ) + assert resp.decision in {"run", "stop", "noop"} + assert resp.patch_count == 4 + diff --git a/services/inference_http/models/soil_moisture/tests/test_inference.py b/services/inference_http/models/soil_moisture/tests/test_inference.py new file mode 100644 index 000000000..5906bacfe --- /dev/null +++ b/services/inference_http/models/soil_moisture/tests/test_inference.py @@ -0,0 +1,96 @@ +import sys +import types +import numpy as np +from PIL import Image + +# Pre-inject a lightweight stub for app.onnx_model to avoid importing onnxruntime +stub_mod = types.ModuleType("app.onnx_model") +class _StubONNX: + def __init__(self, *a, **k): + # minimal surface to satisfy Inferencer init + self.label_map = {"0": "dry", "1": "wet"} + def predict_proba_patch(self, patch): + return np.array([0.5, 0.5], dtype=float) +stub_mod.ONNXMoistureModel = _StubONNX +sys.modules.setdefault("app.onnx_model", stub_mod) + +from app.config import Settings +from app import inference as infmod + + +class _FakeDryModel: + def __init__(self, *args, **kwargs): + # emulates label_map: index->label + self.label_map = {"0": "dry", "1": "wet"} + + def predict_proba_patch(self, patch): + # Always predict class 0 (dry) with high confidence + return np.array([0.9, 0.1], dtype=float) + + +class _FakeWetModel: + def __init__(self, *args, **kwargs): + self.label_map = {"0": "dry", "1": "wet"} + + def predict_proba_patch(self, patch): + # Always predict class 1 (wet) + return np.array([0.1, 0.9], dtype=float) + + +def _make_inferencer(monkeypatch, model_cls): + # Replace ONNX model with a lightweight fake + monkeypatch.setattr(infmod, "ONNXMoistureModel", lambda *a, **k: model_cls()) + + s = Settings() + s.patch_size = 10 + s.patch_stride = 10 + s.decision_window_sec = 300 + return infmod.Inferencer(s) + + +def _make_image(w=20, h=10): + return Image.new("RGB", (w, h), color=(128, 128, 128)) + + +def test_decision_run_when_dry_ratio_high(monkeypatch): + inf = _make_inferencer(monkeypatch, _FakeDryModel) + # 20x10 with 10x10 patches & stride 10 => 2 patches + img = _make_image(20, 10) + zone_cfg = {"_state": "stop", "dry_ratio_high": 0.5, "dry_ratio_low": 0.3, "min_patches": 2, "duration_min": 7} + + result, debug = inf.infer_image(img, zone_cfg) + assert result["patch_count"] == 2 + assert result["dry_ratio"] == 1.0 + assert result["decision"] == "run" + assert zone_cfg["_state"] == "run" + + +def test_decision_stop_when_dry_ratio_low_and_prev_run(monkeypatch): + inf = _make_inferencer(monkeypatch, _FakeWetModel) + img = _make_image(20, 10) # 2 patches + zone_cfg = {"_state": "run", "dry_ratio_high": 0.6, "dry_ratio_low": 0.25, "min_patches": 2, "duration_min": 5} + + result, _ = inf.infer_image(img, zone_cfg) + assert result["patch_count"] == 2 + assert result["dry_ratio"] == 0.0 + assert result["decision"] == "stop" + assert zone_cfg["_state"] == "stop" + + +def test_noop_when_not_enough_patches(monkeypatch): + inf = _make_inferencer(monkeypatch, _FakeDryModel) + img = _make_image(20, 10) # 2 patches + zone_cfg = {"_state": "stop", "dry_ratio_high": 0.5, "dry_ratio_low": 0.3, "min_patches": 3, "duration_min": 7} + + result, _ = inf.infer_image(img, zone_cfg) + assert result["patch_count"] == 2 + assert result["decision"] == "noop" + # State remains unchanged + assert zone_cfg["_state"] == "stop" + + +def test_decision_window_bucket_rounds_down(monkeypatch): + inf = _make_inferencer(monkeypatch, _FakeDryModel) + # With window 300s, 1234 -> bucket start 1200 + bucket = inf.decision_window_bucket(1234.0) + assert bucket == 1200 diff --git a/services/inference_http/models/soil_moisture/tests/test_utils.py b/services/inference_http/models/soil_moisture/tests/test_utils.py new file mode 100644 index 000000000..a45ea10e4 --- /dev/null +++ b/services/inference_http/models/soil_moisture/tests/test_utils.py @@ -0,0 +1,58 @@ +import base64 +import io +from PIL import Image +import numpy as np + +from app.utils import ( + load_image_from_b64, + normalize_lighting, + tile_image, + preprocess_onnx, +) + + +def make_rgb_image(w=8, h=6, color=(120, 100, 80)): + return Image.new("RGB", (w, h), color=color) + + +def test_load_image_from_b64_roundtrip(): + img = make_rgb_image(5, 7, (10, 20, 30)) + buf = io.BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + + out = load_image_from_b64(b64) + assert out.mode == "RGB" + assert out.size == (5, 7) + + +def test_normalize_lighting_basic_properties(): + img = make_rgb_image(10, 10, (50, 100, 150)) + out = normalize_lighting(img) + assert out.mode == "RGB" + assert out.size == img.size + + +def test_tile_image_regular_grid(): + img = make_rgb_image(5, 5, (0, 0, 0)) + patches = tile_image(img, patch_size=3, stride=2) + # Positions: x in {0,2}, y in {0,2} => 4 patches + assert len(patches) == 4 + assert all(p.size == (3, 3) for p in patches) + + +def test_tile_image_small_image_resizes_to_single_patch(): + img = make_rgb_image(2, 2, (0, 0, 0)) + patches = tile_image(img, patch_size=4, stride=4) + assert len(patches) == 1 + assert patches[0].size == (4, 4) + + +def test_preprocess_onnx_output_shape_and_range(): + img = make_rgb_image(6, 6, (255, 128, 0)) + arr = preprocess_onnx(img, size=8) + assert arr.shape == (1, 3, 8, 8) + assert arr.dtype == np.float32 + assert np.isfinite(arr).all() + assert arr.min() >= 0.0 and arr.max() <= 1.0 + diff --git a/services/inference_http/requirements.txt b/services/inference_http/requirements.txt index 2e65f4b67..bdc33b3eb 100644 --- a/services/inference_http/requirements.txt +++ b/services/inference_http/requirements.txt @@ -1,6 +1,23 @@ + fastapi -uvicorn[standard] +uvicorn +pydantic minio +requests +torch +numpy<2 +opencv-python-headless +ultralytics==8.2.34 +boto3 pillow numpy==1.26.4 pydantic +onnxruntime==1.20.0 +kafka-python==2.0.2 +psycopg2-binary==2.9.10 +prometheus_client==0.21.0 +PyYAML==6.0.2 +python-dotenv==1.0.1 +requests==2.32.3 +python-multipart==0.0.6 +confluent_kafka==2.12.0 \ No newline at end of file diff --git a/services/inference_http/weights/yolov8-fruits.pt b/services/inference_http/weights/yolov8-fruits.pt new file mode 100644 index 000000000..d61ef50d3 Binary files /dev/null and b/services/inference_http/weights/yolov8-fruits.pt differ diff --git a/services/plant_stress/Dockerfile b/services/plant_stress/Dockerfile index 59490295b..097ca9a18 100644 --- a/services/plant_stress/Dockerfile +++ b/services/plant_stress/Dockerfile @@ -1,23 +1,64 @@ +# FROM python:3.11-slim + +# ENV PYTHONDONTWRITEBYTECODE=1 \ +# PYTHONUNBUFFERED=1 + +# RUN apt-get update && apt-get install -y --no-install-recommends \ +# libsndfile1 ffmpeg gcc \ +# && rm -rf /var/lib/apt/lists/* + +# WORKDIR /app + +# # COPY certs /app/certs + +# # RUN if [ -f /app/certs/netfree-ca.crt ]; then \ +# # echo "Installing NetFree certificate..."; \ +# # cp /app/certs/netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt && \ +# # update-ca-certificates; \ +# # else \ +# # echo "⚠️ WARNING: netfree-ca.crt not found, continuing without it."; \ +# # fi + +# ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ +# REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ +# PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# COPY requirements.txt ./ +# RUN pip install --no-cache-dir -r requirements.txt \ +# --trusted-host pypi.org \ +# --trusted-host pypi.python.org \ +# --trusted-host files.pythonhosted.org + +# COPY src/ /app/ + +# ENV INPUT_DIR=/data/inbox \ +# MODEL_DIR=/models \ +# PERIOD_DAYS=7 \ +# CONF_THRESHOLD=0.0 \ +# POSTGRES_DSN="postgresql://postgres:pg123@host.docker.internal:5432/missions_db" + +# CMD ["python", "/app/app.py"] + FROM python:3.11-slim ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 RUN apt-get update && apt-get install -y --no-install-recommends \ - libsndfile1 ffmpeg gcc \ + libsndfile1 ffmpeg tzdata \ && rm -rf /var/lib/apt/lists/* WORKDIR /app -COPY certs /app/certs +# COPY certs /app/certs -RUN if [ -f /app/certs/netfree-ca.crt ]; then \ - echo "Installing NetFree certificate..."; \ - cp /app/certs/netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt && \ - update-ca-certificates; \ - else \ - echo "⚠️ WARNING: netfree-ca.crt not found, continuing without it."; \ - fi +# RUN if [ -f /app/certs/netfree-ca.crt ]; then \ +# echo "Installing NetFree certificate..."; \ +# cp /app/certs/netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt && \ +# update-ca-certificates; \ +# else \ +# echo "⚠️ WARNING: netfree-ca.crt not found, continuing without it."; \ +# fi ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ @@ -29,12 +70,21 @@ RUN pip install --no-cache-dir -r requirements.txt \ --trusted-host pypi.python.org \ --trusted-host files.pythonhosted.org +# Copy your source code COPY src/ /app/ -ENV INPUT_DIR=/data/inbox \ - MODEL_DIR=/models \ - PERIOD_DAYS=7 \ - CONF_THRESHOLD=0.0 \ - POSTGRES_DSN="postgresql://postgres:pg123@host.docker.internal:5432/missions_db" +# Copy model files if needed +# COPY models/ /models/ + +ENV MODEL_DIR=/models \ + TIMEZONE=Asia/Jerusalem \ + MINIO_ENDPOINT=minio:9000 \ + MINIO_ACCESS_KEY=minioadmin \ + MINIO_SECRET_KEY=minioadmin123 \ + MINIO_BUCKET=sound \ + MINIO_PREFIX=plants/ \ + MINIO_SECURE=false \ + POSTGRES_DSN=postgresql://missions_user:pg123@postgres:5432/missions_db \ + CONFIDENCE_THRESHOLD=0.60 -CMD ["python", "/app/app.py"] +CMD ["python", "/app/predict_minio_daily.py"] diff --git a/services/plant_stress/requirements.txt b/services/plant_stress/requirements.txt index e29f9b682..869c5ea97 100644 --- a/services/plant_stress/requirements.txt +++ b/services/plant_stress/requirements.txt @@ -4,4 +4,8 @@ soundfile tensorflow psycopg2-binary requests -urllib3 \ No newline at end of file +urllib3 +pytz +minio +keras +kafka-python \ No newline at end of file diff --git a/services/plant_stress/run_plant_stress_daily.sh b/services/plant_stress/run_plant_stress_daily.sh new file mode 100644 index 000000000..3d449ad43 --- /dev/null +++ b/services/plant_stress/run_plant_stress_daily.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_DIR="/mnt/c/Users/This User/Desktop/project_05112025/AgCloud" +LOG_DIR="$PROJECT_DIR/services/plant_stress/logs" +STAMP="$(date +%F)" +LOG_FILE="$LOG_DIR/cron_${STAMP}.log" + +mkdir -p "$LOG_DIR" + +echo "[cron start $(date '+%Y-%m-%d %H:%M:%S')]" >> "$LOG_FILE" + +cd "$PROJECT_DIR" + +exec /usr/bin/docker compose run --rm plant_stress_daily >> "$LOG_FILE" 2>&1 diff --git a/services/plant_stress/samples/id_89_sound_14.wav b/services/plant_stress/samples/id_89_sound_14.wav deleted file mode 100644 index bd55e276c..000000000 Binary files a/services/plant_stress/samples/id_89_sound_14.wav and /dev/null differ diff --git a/services/plant_stress/samples/id_89_sound_15.wav b/services/plant_stress/samples/id_89_sound_15.wav deleted file mode 100644 index e0cef995e..000000000 Binary files a/services/plant_stress/samples/id_89_sound_15.wav and /dev/null differ diff --git a/services/plant_stress/samples/id_89_sound_16.wav b/services/plant_stress/samples/id_89_sound_16.wav deleted file mode 100644 index 28ae9b01a..000000000 Binary files a/services/plant_stress/samples/id_89_sound_16.wav and /dev/null differ diff --git a/services/plant_stress/samples/id_89_sound_17.wav b/services/plant_stress/samples/id_89_sound_17.wav deleted file mode 100644 index ea418e663..000000000 Binary files a/services/plant_stress/samples/id_89_sound_17.wav and /dev/null differ diff --git a/services/plant_stress/samples/id_89_sound_18.wav b/services/plant_stress/samples/id_89_sound_18.wav deleted file mode 100644 index 05fd1f354..000000000 Binary files a/services/plant_stress/samples/id_89_sound_18.wav and /dev/null differ diff --git a/services/plant_stress/src/app.py b/services/plant_stress/src/app.py deleted file mode 100644 index 4afbbb809..000000000 --- a/services/plant_stress/src/app.py +++ /dev/null @@ -1,207 +0,0 @@ -import os, sys, time, glob, pickle, datetime as dt -from pathlib import Path -import numpy as np -import librosa -import tensorflow as tf -import psycopg2 -import psycopg2.extras - -# ======== Environment Config ======== -INPUT_DIR = os.environ.get("INPUT_DIR", "/data/inbox") -MODEL_DIR = os.environ.get("MODEL_DIR", "/models") -POSTGRES_DSN = os.environ.get("POSTGRES_DSN", "postgresql://postgres:postgres@localhost:5432/postgres") -PERIOD_DAYS = int(os.environ.get("PERIOD_DAYS", "0")) # 0 = process all files - -# ======== Audio Parameters (2ms @ 500kHz) ======== -SAMPLE_RATE = 500_000 -DURATION_MS = 2 -N_SAMPLES = int(SAMPLE_RATE * DURATION_MS / 1000) # 1000 samples -N_FFT = 256 -HOP_LENGTH = 64 -N_MELS = 64 - -# ======== Watering Status Mapping ======== -CLASS_TO_STATUS = { - "Drought_Tomato": "Watering required", - "Drought_Tobacco": "Watering required", - "Control_Empty": "Normal / Empty", - "Control_Greenhouse": "Greenhouse noise / Normal", -} -CONFIDENCE_THRESHOLD = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.60")) - -# ======== Load Model / Scaler / LabelEncoder ======== -model_path = os.path.join(MODEL_DIR, "ultrasonic_plant_cnn.keras") -scaler_path = os.path.join(MODEL_DIR, "scaler_params.npz") -le_path = os.path.join(MODEL_DIR, "label_encoder.pkl") - -print(f"INPUT_DIR={INPUT_DIR}") -print(f"MODEL_DIR={MODEL_DIR}") -print(f"POSTGRES_DSN={POSTGRES_DSN}") - -MODEL = tf.keras.models.load_model(model_path) - -sc = np.load(scaler_path) -SCALER_MEAN = sc["mean"] -SCALER_SCALE = sc["scale"] - -with open(le_path, "rb") as f: - LABEL_ENCODER = pickle.load(f) - -# ======== Helper Functions ======== -def load_and_preprocess_audio(file_path: str): - """Load audio, resample to 500kHz, crop/pad to 2ms (1000 samples).""" - audio, sr = librosa.load(file_path, sr=None, mono=True) - if sr != SAMPLE_RATE: - audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE) - sr = SAMPLE_RATE - - if len(audio) > N_SAMPLES: - start = (len(audio) - N_SAMPLES) // 2 - audio = audio[start:start + N_SAMPLES] - elif len(audio) < N_SAMPLES: - pad = N_SAMPLES - len(audio) - audio = np.pad(audio, (0, pad), mode='constant') - - return audio.astype(np.float32), sr - -def extract_ultrasonic_features(audio: np.ndarray, sr: int): - """Short time/frequency features for 2ms window.""" - feats = [] - feats.extend([np.mean(audio), np.std(audio), np.max(audio), np.min(audio), - np.var(audio), np.median(audio)]) - zcr = librosa.feature.zero_crossing_rate(audio, hop_length=HOP_LENGTH)[0] - feats.extend([np.mean(zcr), np.std(zcr), np.max(zcr)]) - fft = np.abs(np.fft.fft(audio))[:len(audio)//2] - feats.extend([np.mean(fft), np.std(fft), np.max(fft), np.argmax(fft)]) - try: - spectral_centroids = librosa.feature.spectral_centroid(y=audio, sr=sr, hop_length=HOP_LENGTH)[0] - spectral_rolloff = librosa.feature.spectral_rolloff(y=audio, sr=sr, hop_length=HOP_LENGTH)[0] - feats.extend([np.mean(spectral_centroids), np.mean(spectral_rolloff)]) - except Exception: - feats.extend([0.0, 0.0]) - rms = librosa.feature.rms(y=audio, hop_length=HOP_LENGTH)[0] - feats.extend([np.mean(rms), np.std(rms)]) - return np.array(feats, dtype=np.float32) - -def create_spectrogram_features(audio: np.ndarray, sr: int): - """Small Mel-spectrogram adapted for 2ms.""" - mel = librosa.feature.melspectrogram( - y=audio, sr=sr, n_fft=N_FFT, hop_length=HOP_LENGTH, - n_mels=N_MELS, fmax=sr//2 - ) - mel_db = librosa.power_to_db(mel, ref=np.max) - mel_norm = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8) - return mel_norm.astype(np.float32) - -def normalize_features(x: np.ndarray): - return (x - SCALER_MEAN) / SCALER_SCALE - -def files_to_process(root: str, period_days: int): - p = Path(root) - if not p.exists(): - print(f"[!] INPUT_DIR not found: {root}") - return [] - exts = {".wav", ".WAV"} - allf = [str(pp) for pp in p.rglob("*") if pp.suffix in exts] - if period_days <= 0: - print("[i] Time filter disabled (PERIOD_DAYS<=0): processing ALL .wav files") - return sorted(allf) - cutoff = time.time() - period_days * 86400 - return sorted([f for f in allf if Path(f).stat().st_mtime >= cutoff]) - -def ensure_table(conn): - sql = """ - CREATE TABLE IF NOT EXISTS ultrasonic_plant_predictions ( - id BIGSERIAL PRIMARY KEY, - file TEXT, - predicted_class TEXT, - confidence DOUBLE PRECISION, - watering_status TEXT, - status TEXT, - prediction_time TIMESTAMPTZ DEFAULT now() - ); - """ - with conn.cursor() as cur: - cur.execute(sql) - conn.commit() - -def insert_rows(conn, rows): - """rows: list of tuples(file, predicted_class, confidence, watering_status, status, prediction_time)""" - sql = """ - INSERT INTO ultrasonic_plant_predictions - (file, predicted_class, confidence, watering_status, status, prediction_time) - VALUES %s - """ - with conn.cursor() as cur: - psycopg2.extras.execute_values(cur, sql, rows, page_size=500) - conn.commit() - -# ======== Main Run ======== -def main(): - files = files_to_process(INPUT_DIR, PERIOD_DAYS) - if not files: - print("No files to process. Exiting.") - return 0 - - try: - conn = psycopg2.connect(POSTGRES_DSN) - ensure_table(conn) - except Exception as e: - print(f"[!] Postgres connection error: {e}") - return 2 - - batch = [] - ok, fail = 0, 0 - start = time.time() - - for fpath in files: - try: - audio, sr = load_and_preprocess_audio(fpath) - feats = extract_ultrasonic_features(audio, sr) - spec = create_spectrogram_features(audio, sr) - - feats_norm = normalize_features(feats) - feats_batch = feats_norm[np.newaxis, :] - spec_batch = spec[np.newaxis, ..., np.newaxis] - - probs = MODEL.predict([feats_batch, spec_batch], verbose=0)[0] - idx = int(np.argmax(probs)) - pred_class = LABEL_ENCODER.classes_[idx] - conf = float(probs[idx]) - - watering_status = CLASS_TO_STATUS.get(pred_class, "Undefined") - if conf < CONFIDENCE_THRESHOLD: - watering_status = f"{watering_status} (Uncertain)" - - batch.append(( - fpath, str(pred_class), conf, watering_status, "Success", - dt.datetime.utcnow() - )) - ok += 1 - print(f"OK {fpath} -> {pred_class} ({conf:.3f})") - except Exception as e: - fail += 1 - batch.append((fpath, "", None, "", f"Error: {e}", dt.datetime.utcnow())) - print(f"[ERR] {fpath} -> {e}") - - - # Write to Postgres - try: - if batch: - insert_rows(conn, batch) - print(f"Inserted {len(batch)} rows to Postgres.") - except Exception as e: - print(f"[!] Insert error: {e}") - return 3 - finally: - try: - conn.close() - except Exception: - pass - - elapsed = time.time() - start - print(f"Done. processed={len(files)} ok={ok} fail={fail} elapsed_sec={elapsed:.1f}") - return 0 - -if __name__ == "__main__": - sys.exit(main()) diff --git a/services/plant_stress/src/db_api_client.py b/services/plant_stress/src/db_api_client.py deleted file mode 100644 index f9d9e370b..000000000 --- a/services/plant_stress/src/db_api_client.py +++ /dev/null @@ -1,223 +0,0 @@ -# # import json -# # import time -# # import pathlib -# # import requests -# # from urllib.parse import quote -# # from requests.adapters import HTTPAdapter -# # from urllib3.util.retry import Retry -# # # ---------- CONFIG ---------- -# # DB_API_BASE = "http://host.docker.internal:8001" -# # DB_API_AUTH_MODE = "service" -# # DB_API_TOKEN_FILE = "/app/secret/db_api_token" -# # DB_API_TOKEN = "auto" -# # DB_API_SERVICE_NAME = "GUI_H" -# # # ---------- TOKEN BOOTSTRAP ---------- -# # def _safe_join_url(base: str, path: str) -> str: -# # return f"{base.rstrip('/')}/{path.lstrip('/')}" -# # def _read_token_from_file(path: str) -> str | None: -# # p = pathlib.Path(path) -# # if p.exists(): -# # token = p.read_text(encoding="utf-8").strip() -# # return token or None -# # return None -# # def _fetch_token_via_dev_bootstrap(base: str, retries: int = 3, backoff: float = 0.8) -> str | None: -# # url = _safe_join_url(base, "/auth/_dev_bootstrap") -# # payload = {"service_name": DB_API_SERVICE_NAME, "rotate_if_exists": True} -# # for attempt in range(1, retries + 1): -# # try: -# # r = requests.post(url, json=payload, timeout=10) -# # if r.status_code in (200, 201): -# # data = r.json() -# # raw = (data.get("service_account", {}) or {}).get("raw_token") \ -# # or (data.get("service_account", {}) or {}).get("token") -# # if raw and isinstance(raw, str) and "***" not in raw: -# # return raw.strip() -# # except Exception: -# # time.sleep(backoff * attempt) -# # return None -# # def get_or_bootstrap_token() -> str | None: -# # print(f"[DEBUG] Checking for existing token file at: {DB_API_TOKEN_FILE}", flush=True) -# # if DB_API_TOKEN and DB_API_TOKEN.lower() != "auto": -# # print(f"[DEBUG] Using static token from config", flush=True) -# # return DB_API_TOKEN -# # token = _read_token_from_file(DB_API_TOKEN_FILE) -# # if token: -# # print(f"[DEBUG] Loaded token from {DB_API_TOKEN_FILE}", flush=True) -# # return token -# # print(f"[DEBUG] No existing token found, bootstrapping via {DB_API_BASE}/auth/_dev_bootstrap", flush=True) -# # token = _fetch_token_via_dev_bootstrap(DB_API_BASE) -# # if token: -# # pathlib.Path(DB_API_TOKEN_FILE).parent.mkdir(parents=True, exist_ok=True) -# # pathlib.Path(DB_API_TOKEN_FILE).write_text(token, encoding="utf-8") -# # print(f"[BOOTSTRAP] wrote token to {DB_API_TOKEN_FILE}", flush=True) -# # return token -# # print("[BOOTSTRAP][ERROR] Failed to obtain token.", flush=True) -# # return None -# # # ---------- API CLIENT ---------- -# # class DashboardApi: -# # def __init__(self): -# # self.base = DB_API_BASE.rstrip("/") -# # self.http = requests.Session() -# # token = get_or_bootstrap_token() -# # if token: -# # if DB_API_AUTH_MODE == "service": -# # self.http.headers.update({"X-Service-Token": token}) -# # else: -# # self.http.headers.update({"Authorization": f"Bearer {token}"}) -# # self.http.headers.update({"Content-Type": "application/json"}) -# # self.http.mount("http://", HTTPAdapter(max_retries=Retry(total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]))) -# # self.http.mount("https://", HTTPAdapter(max_retries=Retry(total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]))) -# # # ---------- METHODS ---------- -# # def list_devices(self, model: str | None = None) -> list[dict]: -# # url = f"{self.base}/api/devices" -# # if model: -# # url += f"?model={model}" -# # try: -# # r = self.http.get(url, timeout=10) -# # if r.status_code == 200: -# # return r.json() -# # print(f"[API ERROR] {r.status_code}: {r.text[:100]}") -# # except Exception as e: -# # print(f"[API FAIL] {e}") -# # return [] - - -# # services/plant_stress/src/db_api_client.py -# # services/plant_stress/src/db_api_client.py -# import os, pathlib, time, requests - -# DB_API_BASE = os.getenv("DB_API_BASE", "http://db_api_service:8001").rstrip("/") -# DB_API_AUTH_MODE = os.getenv("DB_API_AUTH_MODE", "service") -# DB_API_TOKEN_FILE = os.getenv("DB_API_TOKEN_FILE", "/tmp/db_api_token") -# DB_API_TOKEN = os.getenv("DB_API_TOKEN", "auto") -# DB_API_SERVICE_NAME = os.getenv("DB_API_SERVICE_NAME", "plant_stress") - -# def _join(b, p): -# return f"{b.rstrip('/')}/{p.lstrip('/')}" - -# def _read_token(path): -# p = pathlib.Path(path) -# return p.read_text(encoding="utf-8").strip() if p.exists() else None - -# def _bootstrap_token(base, retries=3, backoff=0.8): -# url = _join(base, "/auth/_dev_bootstrap") -# payload = {"service_name": DB_API_SERVICE_NAME, "rotate_if_exists": True} -# for i in range(1, retries+1): -# try: -# r = requests.post(url, json=payload, timeout=10) -# if r.status_code in (200, 201): -# sa = (r.json().get("service_account", {}) or {}) -# raw = sa.get("raw_token") or sa.get("token") -# if raw and isinstance(raw, str) and "***" not in raw: -# return raw.strip() -# except Exception: -# time.sleep(backoff * i) -# return None - -# def get_or_bootstrap_token(): -# if DB_API_TOKEN and DB_API_TOKEN.lower() != "auto": -# return DB_API_TOKEN -# tok = _read_token(DB_API_TOKEN_FILE) -# if tok: -# return tok -# tok = _bootstrap_token(DB_API_BASE) -# if tok: -# pathlib.Path(DB_API_TOKEN_FILE).parent.mkdir(parents=True, exist_ok=True) -# pathlib.Path(DB_API_TOKEN_FILE).write_text(tok, encoding="utf-8") -# print(f"[BOOTSTRAP] wrote token to {DB_API_TOKEN_FILE}", flush=True) -# return tok -# print("[BOOTSTRAP][ERROR] Failed to obtain token.", flush=True) -# return None - -# # ---- requests.Session() גלובלי ---- -# _SESSION = None -# def get_session(): -# global _SESSION -# if _SESSION is None: -# s = requests.Session() -# tok = get_or_bootstrap_token() -# if tok: -# if DB_API_AUTH_MODE == "service": -# s.headers.update({"X-Service-Token": tok}) -# else: -# s.headers.update({"Authorization": f"Bearer {tok}"}) -# s.headers.update({"Content-Type": "application/json"}) -# _SESSION = s -# return _SESSION - -# FILES_POST = "/api/files" -# FILES_PUT_TPL = "/api/files/{bucket}/{object_key}" - -# def _variants_for_create(meta: dict): -# """נפיק כמה וריאציות פייפ/load נפוצות ל-POST /api/files""" -# b = meta.get("bucket"); k = meta.get("object_key") -# mime = meta.get("mime", "audio/wav") -# tags = meta.get("tags", []) -# mdata = { -# "service": meta.get("service"), -# "timestamp": meta.get("timestamp"), -# "predicted_class": meta.get("predicted_class"), -# "confidence": meta.get("confidence"), -# } -# return [ -# {"bucket": b, "object_key": k, "mime": mime, "tags": tags, "metadata": mdata}, -# {"bucket": b, "object_key": k, "content_type": mime, "tags": tags, "metadata": mdata}, -# {"bucket": b, "object_key": k, "mime": mime, "labels": tags, "metadata": mdata}, -# {"bucket": b, "object_key": k, "mime": mime, "meta": mdata}, -# {"bucket": b, "object_key": k}, # מינימלי מאוד (ייתכן שהסכמה דורשת עוד – נבדוק 422) -# ] - -# def _variants_for_update(meta: dict): -# """וריאציות PUT /api/files/{bucket}/{object_key}""" -# mime = meta.get("mime", "audio/wav") -# tags = meta.get("tags", []) -# mdata = { -# "service": meta.get("service"), -# "timestamp": meta.get("timestamp"), -# "predicted_class": meta.get("predicted_class"), -# "confidence": meta.get("confidence"), -# } -# return [ -# {"mime": mime, "tags": tags, "metadata": mdata}, -# {"content_type": mime, "tags": tags, "metadata": mdata}, -# {"metadata": mdata}, -# {"tags": tags}, -# ] - -# def write_db_entry(meta: dict) -> bool: -# """ -# רושם/מעדכן רשומת קובץ + מטא-דטה ב-DB API דרך Files API. -# """ -# s = get_session() -# b = meta.get("bucket"); k = meta.get("object_key") -# if not b or not k: -# print("[API ERROR] meta must include 'bucket' and 'object_key'") -# return False - -# # 1) נסה POST /api/files עם כמה וריאציות -# for payload in _variants_for_create(meta): -# try: -# r = s.post(_join(DB_API_BASE, FILES_POST), json=payload, timeout=15) -# if r.status_code in (200, 201): -# return True -# # אם האובייקט כבר קיים או הסכמה לא תואמת – ננסה וריאציה אחרת / נגלוש ל-PUT -# except Exception as e: -# print(f"[API ERROR] POST /api/files: {e}") - -# # 2) PUT /api/files/{bucket}/{object_key} (upsert/update) -# path = FILES_PUT_TPL.format(bucket=b, object_key=k) -# for payload in _variants_for_update(meta): -# try: -# r = s.put(_join(DB_API_BASE, path), json=payload, timeout=15) -# if r.status_code in (200, 201): -# return True -# # ננסה וריאציה הבאה -# except Exception as e: -# print(f"[API ERROR] PUT {path}: {e}") - -# try: -# print(f"[API ERROR] all variants failed for {b}/{k}. Last status={r.status_code} body={r.text[:300]}") -# except Exception: -# pass -# return False - diff --git a/services/plant_stress/src/predict_minio_daily.py b/services/plant_stress/src/predict_minio_daily.py new file mode 100644 index 000000000..24c6e7778 --- /dev/null +++ b/services/plant_stress/src/predict_minio_daily.py @@ -0,0 +1,603 @@ +import os, sys, time, pickle, datetime as dt +from pathlib import Path +import re +import uuid +import json +import numpy as np +import librosa +import tensorflow as tf +import psycopg2, psycopg2.extras +import pytz +from io import BytesIO +import soundfile as sf +from minio import Minio + +# ======== Environment ======== +MODEL_DIR = os.getenv("MODEL_DIR", "/models") +POSTGRES_DSN = os.getenv("POSTGRES_DSN", "postgresql://postgres:postgres@localhost:5432/postgres") + +# MinIO +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000") +MINIO_ACCESS = os.getenv("MINIO_ACCESS_KEY", "minioadmin") +MINIO_SECRET = os.getenv("MINIO_SECRET_KEY", "minioadmin123") +MINIO_BUCKET = os.getenv("MINIO_BUCKET", "sound") +MINIO_PREFIX = os.getenv("MINIO_PREFIX", "plants/") +MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true" + +# Defaults for required GUI fields +DEFAULT_AREA = os.getenv("DEFAULT_AREA", "unknown").strip() +DEFAULT_LAT = os.getenv("DEFAULT_LAT", "0.0").strip() +DEFAULT_LON = os.getenv("DEFAULT_LON", "0.0").strip() +DEFAULT_IMAGE_URL = os.getenv("DEFAULT_IMAGE_URL", "https://example.com/placeholder.jpg").strip() +DEFAULT_VOD = os.getenv("DEFAULT_VOD", "https://example.com/placeholder.mp4").strip() +DEFAULT_HLS = os.getenv("DEFAULT_HLS", "https://example.com/placeholder.m3u8").strip() + +# Date / TZ +TIMEZONE = os.getenv("TIMEZONE", "Asia/Jerusalem") +PROCESS_DATE = os.getenv("PROCESS_DATE", "").strip() # YYYY-MM-DD (optional backfill) + +# Confidence +CONFIDENCE_THRESHOLD = float(os.getenv("CONFIDENCE_THRESHOLD", "0.60")) + +# ======== Audio Params (2ms @ 500kHz) ======== +SAMPLE_RATE = 500_000 +DURATION_MS = 2 +N_SAMPLES = int(SAMPLE_RATE * DURATION_MS / 1000) # 1000 samples +N_FFT = 256 +HOP_LENGTH = 64 +N_MELS = 64 + +# ======== Watering Status Mapping ======== +CLASS_TO_STATUS = { + "Drought_Tomato": "Watering required", + "Drought_Tobacco": "Watering required", + "Control_Empty": "Normal / Empty", + "Control_Greenhouse": "Greenhouse noise / Normal", +} + +# ======== Alerts / Kafka ======== +ENABLE_ALERTS = os.getenv("ENABLE_ALERTS", "true").lower() == "true" +ALERT_TOPIC = os.getenv("ALERT_TOPIC", "alerts") +ALERT_TYPE = os.getenv("ALERT_TYPE", "plant_drought_detected") +ALERT_AREA = os.getenv("ALERT_AREA", "").strip() +ALERT_LAT = os.getenv("ALERT_LAT", "").strip() +ALERT_LON = os.getenv("ALERT_LON", "").strip() +ALERT_IMAGE_URL = os.getenv("ALERT_IMAGE_URL", "").strip() +ALERT_VOD = os.getenv("ALERT_VOD", "").strip() +ALERT_HLS = os.getenv("ALERT_HLS", "").strip() + +KAFKA_BOOTSTRAP = os.getenv("KAFKA_BOOTSTRAP", "kafka:9092") +KAFKA_CLIENT_ID = os.getenv("KAFKA_CLIENT_ID", "plant-stress-producer") +KAFKA_SECURITY_PROTOCOL = os.getenv("KAFKA_SECURITY_PROTOCOL", "").strip() +KAFKA_SASL_MECHANISM = os.getenv("KAFKA_SASL_MECHANISM", "").strip() +KAFKA_SASL_USERNAME = os.getenv("KAFKA_SASL_USERNAME", "").strip() +KAFKA_SASL_PASSWORD = os.getenv("KAFKA_SASL_PASSWORD", "").strip() +KAFKA_SSL_CA = os.getenv("KAFKA_SSL_CA", "").strip() +KAFKA_SSL_CERT = os.getenv("KAFKA_SSL_CERT", "").strip() +KAFKA_SSL_KEY = os.getenv("KAFKA_SSL_KEY", "").strip() + +# ======== Load model/scaler/encoder ======== +model_path = os.path.join(MODEL_DIR, "ultrasonic_plant_cnn.keras") +scaler_path = os.path.join(MODEL_DIR, "scaler_params.npz") +le_path = os.path.join(MODEL_DIR, "label_encoder.pkl") + +print(f"MODEL_DIR={MODEL_DIR}") +print(f"POSTGRES_DSN={POSTGRES_DSN}") +print(f"MINIO {MINIO_ENDPOINT=} {MINIO_BUCKET=} {MINIO_PREFIX=} {MINIO_SECURE=}") +print(f"ALERTS enable={ENABLE_ALERTS} topic={ALERT_TOPIC} bootstrap={KAFKA_BOOTSTRAP}") + +try: + import keras + MODEL = keras.saving.load_model(model_path, compile=False) +except Exception as e_keras3: + print(f"[!] Keras 3 load failed: {e_keras3} -- falling back to tf.keras") + MODEL = tf.keras.models.load_model(model_path, compile=False) + +sc = np.load(scaler_path) +SCALER_MEAN = sc["mean"] +SCALER_SCALE = sc["scale"] + +with open(le_path, "rb") as f: + LABEL_ENCODER = pickle.load(f) + +# ======== Kafka Producer (lazy, dual-impl) ======== +class _KafkaProducer: + def __init__(self): + self.impl = None + self.mode = None + self._init_producer() + + def _init_producer(self): + if not ENABLE_ALERTS: + return + try: + from confluent_kafka import Producer + conf = {"bootstrap.servers": KAFKA_BOOTSTRAP, "client.id": KAFKA_CLIENT_ID} + if KAFKA_SECURITY_PROTOCOL: conf["security.protocol"] = KAFKA_SECURITY_PROTOCOL + if KAFKA_SASL_MECHANISM: conf["sasl.mechanisms"] = KAFKA_SASL_MECHANISM + if KAFKA_SASL_USERNAME: conf["sasl.username"] = KAFKA_SASL_USERNAME + if KAFKA_SASL_PASSWORD: conf["sasl.password"] = KAFKA_SASL_PASSWORD + if KAFKA_SSL_CA: conf["ssl.ca.location"] = KAFKA_SSL_CA + if KAFKA_SSL_CERT: conf["ssl.certificate.location"]= KAFKA_SSL_CERT + if KAFKA_SSL_KEY: conf["ssl.key.location"] = KAFKA_SSL_KEY + self.impl = Producer(conf) + self.mode = "confluent" + print("[Kafka] Using confluent-kafka Producer") + return + except Exception as e: + print(f"[Kafka] confluent-kafka unavailable: {e}") + + try: + from kafka import KafkaProducer + kwargs = { + "bootstrap_servers": KAFKA_BOOTSTRAP, + "client_id": KAFKA_CLIENT_ID, + "value_serializer": lambda v: json.dumps(v).encode("utf-8"), + "linger_ms": 10, + "acks": "all", + } + if KAFKA_SECURITY_PROTOCOL: kwargs["security_protocol"] = KAFKA_SECURITY_PROTOCOL + if KAFKA_SASL_MECHANISM: kwargs["sasl_mechanism"] = KAFKA_SASL_MECHANISM + if KAFKA_SASL_USERNAME and KAFKA_SASL_PASSWORD: + kwargs["sasl_plain_username"] = KAFKA_SASL_USERNAME + kwargs["sasl_plain_password"] = KAFKA_SASL_PASSWORD + if KAFKA_SSL_CA: kwargs["ssl_cafile"] = KAFKA_SSL_CA + if KAFKA_SSL_CERT: kwargs["ssl_certfile"] = KAFKA_SSL_CERT + if KAFKA_SSL_KEY: kwargs["ssl_keyfile"] = KAFKA_SSL_KEY + self.impl = KafkaProducer(**kwargs) + self.mode = "kafka-python" + print("[Kafka] Using kafka-python Producer") + except Exception as e2: + print(f"[Kafka] kafka-python unavailable: {e2}") + self.impl = None + self.mode = None + + def send(self, topic: str, value: dict): + if not ENABLE_ALERTS or self.impl is None: + return False + if self.mode == "confluent": + try: + self.impl.produce(topic, value=json.dumps(value).encode("utf-8")) + self.impl.poll(0) + return True + except Exception as e: + print(f"[Kafka] produce error (confluent): {e}") + return False + elif self.mode == "kafka-python": + try: + fut = self.impl.send(topic, value=value) + fut.get(timeout=5) + return True + except Exception as e: + print(f"[Kafka] produce error (kafka-python): {e}") + return False + return False + + def flush(self): + try: + if self.mode == "confluent" and self.impl is not None: + self.impl.flush(5) + elif self.mode == "kafka-python" and self.impl is not None: + self.impl.flush() + except Exception: + pass + +KAFKA_PRODUCER = _KafkaProducer() + +# ======== Helpers ======== +FILENAME_RE = re.compile( + r'(?P[^/_]+)_(?P\d{4}-\d{2}-\d{2})_(?P\d{2})-(?P\d{2})\.wav$', + re.IGNORECASE +) + +def _tz(): + return pytz.timezone(TIMEZONE) + +def _today_date(): + if PROCESS_DATE: + return dt.datetime.strptime(PROCESS_DATE, "%Y-%m-%d").date() + return dt.datetime.now(_tz()).date() + +def parse_from_name(key: str): + """ + mic1_2025-09-03_12-05.wav -> (sensor_id, aware-local-datetime) or (None, None) + """ + m = FILENAME_RE.search(key) + if not m: + return None, None + sensor = m.group("sensor") + y, mon, dd = map(int, m.group("date").split("-")) + hh = int(m.group("hour")); mm = int(m.group("minute")) + local_dt = _tz().localize(dt.datetime(y, mon, dd, hh, mm, 0)) + return sensor, local_dt + +def list_minio_wavs_for_date(client: Minio, bucket: str, prefix: str, the_date: dt.date): + selected = [] + for obj in client.list_objects(bucket, prefix=prefix, recursive=True): + key = obj.object_name + if not key.lower().endswith(".wav"): + continue + sensor, rec_local = parse_from_name(key) + if rec_local is not None: + if rec_local.date() == the_date: + selected.append((obj, sensor, rec_local)) + continue + lm_local = obj.last_modified.astimezone(_tz()) + if lm_local.date() == the_date: + selected.append((obj, sensor, lm_local)) + return selected + +def load_audio_from_minio(client: Minio, bucket: str, key: str): + resp = client.get_object(bucket, key) + try: + data = resp.read() + finally: + resp.close(); resp.release_conn() + bio = BytesIO(data) + audio, sr = sf.read(bio, dtype="float32", always_2d=False) + if isinstance(audio, np.ndarray) and audio.ndim == 2: + audio = audio.mean(axis=1) + if sr != SAMPLE_RATE: + audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE) + sr = SAMPLE_RATE + if len(audio) > N_SAMPLES: + start = (len(audio) - N_SAMPLES) // 2 + audio = audio[start:start + N_SAMPLES] + elif len(audio) < N_SAMPLES: + pad = N_SAMPLES - len(audio) + audio = np.pad(audio, (0, pad), mode='constant') + return audio.astype(np.float32), sr + +def extract_ultrasonic_features(audio: np.ndarray, sr: int): + feats = [] + feats.extend([np.mean(audio), np.std(audio), np.max(audio), np.min(audio), + np.var(audio), np.median(audio)]) + zcr = librosa.feature.zero_crossing_rate(audio, hop_length=HOP_LENGTH)[0] + feats.extend([np.mean(zcr), np.std(zcr), np.max(zcr)]) + fft = np.abs(np.fft.fft(audio))[:len(audio)//2] + feats.extend([np.mean(fft), np.std(fft), np.max(fft), np.argmax(fft)]) + try: + sc = librosa.feature.spectral_centroid(y=audio, sr=sr, hop_length=HOP_LENGTH)[0] + ro = librosa.feature.spectral_rolloff(y=audio, sr=sr, hop_length=HOP_LENGTH)[0] + feats.extend([np.mean(sc), np.mean(ro)]) + except Exception: + feats.extend([0.0, 0.0]) + rms = librosa.feature.rms(y=audio, hop_length=HOP_LENGTH)[0] + feats.extend([np.mean(rms), np.std(rms)]) + return np.array(feats, dtype=np.float32) + +def create_spectrogram_features(audio: np.ndarray, sr: int): + mel = librosa.feature.melspectrogram( + y=audio, sr=sr, n_fft=N_FFT, hop_length=HOP_LENGTH, + n_mels=N_MELS, fmax=sr//2 + ) + mel_db = librosa.power_to_db(mel, ref=np.max) + mel_norm = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8) + return mel_norm.astype(np.float32) + +def normalize_features(x: np.ndarray): + return (x - SCALER_MEAN) / SCALER_SCALE + +# ======== DB ======== +def ensure_predictions_table(conn): + """Match your schema exactly: no sensor_id, no recording_time.""" + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS ultrasonic_plant_predictions ( + id BIGSERIAL PRIMARY KEY, + file TEXT, + predicted_class TEXT, + confidence DOUBLE PRECISION, + watering_status TEXT, + status TEXT, + prediction_time TIMESTAMPTZ DEFAULT now() + ); + """) + cur.execute("CREATE INDEX IF NOT EXISTS idx_upp_pred_time ON ultrasonic_plant_predictions(prediction_time DESC);") + cur.execute("CREATE INDEX IF NOT EXISTS idx_upp_class ON ultrasonic_plant_predictions(predicted_class);") + conn.commit() + +def ensure_alerts_table(conn): + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS alerts ( + alert_id TEXT PRIMARY KEY, + alert_type TEXT, + device_id TEXT, + started_at TIMESTAMPTZ, + ended_at TIMESTAMPTZ, + confidence DOUBLE PRECISION, + area TEXT, + lat DOUBLE PRECISION, + lon DOUBLE PRECISION, + severity INT DEFAULT 1, + image_url TEXT, + vod TEXT, + hls TEXT, + ack BOOLEAN DEFAULT FALSE, + meta JSONB, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() + ); + """) + cur.execute(""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_proc WHERE proname = 'set_updated_at') THEN + CREATE OR REPLACE FUNCTION set_updated_at() RETURNS trigger AS $f$ + BEGIN + NEW.updated_at = now(); + RETURN NEW; + END; + $f$ LANGUAGE plpgsql; + END IF; + END$$; + """) + cur.execute(""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'trg_alerts_updated_at') THEN + CREATE TRIGGER trg_alerts_updated_at + BEFORE UPDATE ON alerts + FOR EACH ROW + EXECUTE PROCEDURE set_updated_at(); + END IF; + END$$; + """) + conn.commit() + +def insert_prediction_rows(conn, rows): + """ + rows: iterable of tuples shaped exactly as the table: + (file, predicted_class, confidence, watering_status, status, prediction_time) + """ + sql = """ + INSERT INTO ultrasonic_plant_predictions + (file, predicted_class, confidence, watering_status, status, prediction_time) + VALUES %s + """ + with conn.cursor() as cur: + psycopg2.extras.execute_values(cur, sql, rows, page_size=500) + conn.commit() + +def insert_alert_row(conn, alert: dict, started_at_dt: dt.datetime, ended_at_dt: dt.datetime | None = None, ack: bool = False): + from psycopg2.extras import Json + sql = """ + INSERT INTO alerts ( + alert_id, alert_type, device_id, started_at, ended_at, + confidence, area, lat, lon, severity, + image_url, vod, hls, ack, meta + ) + VALUES ( + %(alert_id)s, %(alert_type)s, %(device_id)s, %(started_at)s, %(ended_at)s, + %(confidence)s, %(area)s, %(lat)s, %(lon)s, %(severity)s, + %(image_url)s, %(vod)s, %(hls)s, %(ack)s, %(meta)s + ) + ON CONFLICT (alert_id) DO UPDATE + SET updated_at = now() + """ + params = { + "alert_id": alert["alert_id"], + "alert_type": alert.get("alert_type"), + "device_id": alert.get("device_id"), + "started_at": started_at_dt, + "ended_at": ended_at_dt, + "confidence": alert.get("confidence"), + "area": alert.get("area"), + "lat": alert.get("lat"), + "lon": alert.get("lon"), + "severity": alert.get("severity"), + "image_url": alert.get("image_url"), + "vod": alert.get("vod"), + "hls": alert.get("hls"), + "ack": ack, + "meta": Json(alert.get("meta", {})), + } + with conn.cursor() as cur: + cur.execute(sql, params) + conn.commit() + +# ======== Alert helpers ======== +def _severity_from_confidence(conf: float) -> int: + if conf >= 0.95: return 5 + if conf >= 0.90: return 4 + if conf >= 0.80: return 3 + if conf >= 0.70: return 2 + return 1 + +def _iso_utc(dt_aware) -> str: + return dt_aware.astimezone(pytz.UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + +def build_alert_payload( + alert_type: str, + device_id: str, + started_at_utc: dt.datetime, + confidence: float, + s3url: str, + area: str = "", + lat: str = "", + lon: str = "", + image_url: str = "", + vod: str = "", + hls: str = "", + extra_meta: dict | None = None, +) -> dict: + def _to_float_or_default(v, default_s: str): + try: + if v is None or (isinstance(v, str) and v.strip() == ""): + return float(default_s) + return float(v) + except Exception: + return float(default_s) + + def _non_empty(value: str, default_value: str) -> str: + v = (value or "").strip() + return v if v else default_value + + area_f = _non_empty(area, DEFAULT_AREA) + lat_f = _to_float_or_default(lat, DEFAULT_LAT) + lon_f = _to_float_or_default(lon, DEFAULT_LON) + image_f = _non_empty(image_url, DEFAULT_IMAGE_URL) + vod_f = _non_empty(vod, DEFAULT_VOD) + hls_f = _non_empty(hls, DEFAULT_HLS) + + payload = { + "alert_id": str(uuid.uuid4()), + "alert_type": alert_type, + "device_id": device_id, + "started_at": _iso_utc(started_at_utc), + "confidence": round(confidence, 6), + "severity": _severity_from_confidence(confidence), + "area": area_f, + "lat": lat_f, + "lon": lon_f, + "image_url": image_f, + "vod": vod_f, + "hls": hls_f, + "meta": { + "source": "ultrasonic_plant_classifier", + "file": s3url, + }, + } + if extra_meta: + payload["meta"].update(extra_meta) + return payload + +# ======== Main ======== +def main(): + the_date = _today_date() + print(f"[i] Processing MinIO objects for date={the_date} (TZ={TIMEZONE})") + + client = Minio(MINIO_ENDPOINT, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + objs = list_minio_wavs_for_date(client, MINIO_BUCKET, MINIO_PREFIX, the_date) + if not objs: + print("No WAV objects for that date. Exiting.") + return 0 + + try: + conn = psycopg2.connect(POSTGRES_DSN) + ensure_alerts_table(conn) + ensure_predictions_table(conn) + except Exception as e: + print(f"[!] Postgres connection error: {e}") + return 2 + + batch = [] + ok, fail = 0, 0 + t0 = time.time() + + for (obj, sensor, rec_local_dt) in objs: + key = obj.object_name + s3url = f"s3://{MINIO_BUCKET}/{key}" + try: + if sensor is None: + sensor, _ = parse_from_name(key) + if sensor is None: + sensor = key.split("/")[-1].split("_")[0] + + audio, sr = load_audio_from_minio(client, MINIO_BUCKET, key) + feats = extract_ultrasonic_features(audio, sr) + spec = create_spectrogram_features(audio, sr) + + feats_norm = normalize_features(feats) + feats_batch = feats_norm[np.newaxis, :] + spec_batch = spec[np.newaxis, ..., np.newaxis] + + probs = MODEL.predict([feats_batch, spec_batch], verbose=0)[0] + idx = int(np.argmax(probs)) + pred_class = LABEL_ENCODER.classes_[idx] + conf = float(probs[idx]) + + watering_status = CLASS_TO_STATUS.get(pred_class, "Undefined") + if conf < CONFIDENCE_THRESHOLD: + watering_status = f"{watering_status} (Uncertain)" + + # Save prediction row (schema: no sensor_id/recording_time) + batch.append(( + s3url, # file + str(pred_class), # predicted_class + conf, # confidence + watering_status, # watering_status + "Success", # status + dt.datetime.utcnow() + )) + ok += 1 + print(f"OK {s3url} [{sensor} @ {rec_local_dt.isoformat()}] -> {pred_class} ({conf:.3f})") + + # Alerts for drought classes + if ENABLE_ALERTS and pred_class in ("Drought_Tomato", "Drought_Tobacco"): + rec_utc = rec_local_dt.astimezone(pytz.UTC) if rec_local_dt.tzinfo else pytz.UTC.localize(rec_local_dt) + alert = build_alert_payload( + alert_type=ALERT_TYPE, + device_id=str(sensor), + started_at_utc=rec_utc, + confidence=conf, + s3url=s3url, + area=ALERT_AREA, + lat=ALERT_LAT, + lon=ALERT_LON, + image_url=ALERT_IMAGE_URL, + vod=ALERT_VOD, + hls=ALERT_HLS, + extra_meta={ + "predicted_class": pred_class, + "watering_status": watering_status, + "model_dir": MODEL_DIR, + "sample_rate": SAMPLE_RATE, + "n_fft": N_FFT, + "n_mels": N_MELS + } + ) + try: + insert_alert_row(conn, alert, started_at_dt=rec_utc, ended_at_dt=None, ack=False) + print(f"[Alert][DB] upsert alert_id={alert['alert_id']} device={alert['device_id']} severity={alert['severity']}") + except Exception as e: + print(f"[Alert][DB] insert failed: {e}") + + # Send to Kafka (best effort) + try: + ok_send = KAFKA_PRODUCER.send(ALERT_TOPIC, alert) + if ok_send: + print(f"[Alert] sent to topic={ALERT_TOPIC}: {alert['alert_id']}") + else: + print(f"[Alert] FAILED to send alert to topic={ALERT_TOPIC}") + except Exception as e: + print(f"[Alert] send exception: {e}") + + except Exception as e: + fail += 1 + print(f"[ERR] {s3url} -> {e}") + batch.append(( + s3url, # file + "", # predicted_class + None, # confidence + "", # watering_status + f"Error: {e}", # status + dt.datetime.utcnow() + )) + + try: + if batch: + insert_prediction_rows(conn, batch) + print(f"Inserted {len(batch)} rows.") + except Exception as e: + print(f"[!] Insert error: {e}") + return 3 + finally: + try: + conn.close() + except: + pass + + # Flush Kafka + try: + KAFKA_PRODUCER.flush() + except Exception: + pass + + dt_sec = time.time() - t0 + print(f"Done. processed={len(objs)} ok={ok} fail={fail} elapsed_sec={dt_sec:.1f}") + return 0 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/services/ripeness-baseline/src/main.py b/services/ripeness-baseline/src/main.py index 727d16b74..574b42af8 100644 --- a/services/ripeness-baseline/src/main.py +++ b/services/ripeness-baseline/src/main.py @@ -35,7 +35,6 @@ def list_minio_objects(): def imread_from_any(url: str): u = urlparse(url) - # נזהה שזה MinIO: אותו host כמו MINIO_URL, או הידוע שלך (minio-hot-1:9000) minio_base = os.getenv("MINIO_URL", "http://minio-hot-1:9000") mu = urlparse(minio_base) @@ -43,7 +42,6 @@ def imread_from_any(url: str): (u.hostname == "minio-hot-1" and (u.port or 80) == 9000) if is_minio: - # קריאה עם AK/SK cli = Minio( f"{mu.hostname}:{mu.port or (443 if mu.scheme=='https' else 80)}", access_key=os.getenv("MINIO_ACCESS_KEY"), @@ -63,7 +61,6 @@ def imread_from_any(url: str): arr = np.frombuffer(data, dtype=np.uint8) return cv.imdecode(arr, cv.IMREAD_COLOR) - # אחרת — HTTP רגיל safe = quote(url, safe="/:?=&%()[]") r = requests.get(safe, timeout=30) r.raise_for_status() @@ -74,8 +71,6 @@ def imread_from_any(url: str): def process_all(): DSN = dsn(PG) - # apply_sql_autocommit(DSN, SCHEMA_SQL) - # apply_sql_autocommit(DSN, VIEW_SQL) lookback_days = int(os.getenv("LOOKBACK_DAYS", "7")) rows = fetch_inference_logs(PG, lookback_days=lookback_days, diff --git a/services/ripeness-ml/.gitignore b/services/ripeness-ml/.gitignore new file mode 100644 index 000000000..17751b717 --- /dev/null +++ b/services/ripeness-ml/.gitignore @@ -0,0 +1,7 @@ +__pycache__/ +*.onnx +datasets/ +*.png +.venv/ +venv/ +ENV/ \ No newline at end of file diff --git a/services/ripeness-ml/README.md b/services/ripeness-ml/README.md new file mode 100644 index 000000000..19de71518 --- /dev/null +++ b/services/ripeness-ml/README.md @@ -0,0 +1,226 @@ +# Ripeness ML – API & Weekly Job + +A small **FastAPI** service that: +- Predicts fruit ripeness (**ripe / unripe / overripe**) for new images from **MinIO** based on the trained conditional model, and writes results to **Postgres**. +- Creates a **weekly rollup snapshot** (with TS window) per fruit. + +--- + +## 🧩 Repo layout (service) + +``` +services/ripeness-ml/ +├─ api/ +│ └─ ripeness_api.py # FastAPI endpoints (predict + rollup) +├─ jobs/ +│ └─ weekly_ripeness_job.py # model/minio/db helpers reused by the API +├─ model/ +│ ├─ architecture/ +│ │ └─ mobilenet_v3_large_head.py # Model architecture definition +│ └─ data/ +│ └─ data_multitask.py # Data loading and preprocessing +├─ checkpoints/ +│ └─ mobilenet_v3_large/ +│ └─ best_conditional.pt # trained model weights +├─ deploy/ +│ ├─ Dockerfile +│ └─ docker-compose.ripeness.yml +├─ configs/ +│ └─ config.yaml # Model and training configuration +├─ requirements.txt +└─ .env (optional) +``` + +--- + +## ⚙️ Requirements + +- Docker Desktop +- External Docker network: **agcloud_ag_cloud** (same as your existing stack) + +**Running services on that network:** +- Postgres (`postgres:5432`, DB: `missions_db`, user: `missions_user`) +- MinIO (`minio-hot:9000`) + +--- + +## 🌍 Environment variables + +Set via `docker-compose.ripeness.yml` or `.env`: + +| Name | Default | Notes | +|------|----------|-------| +| `PGHOST` | postgres | DB host (inside Docker network) | +| `PGPORT` | 5432 | | +| `PGDATABASE` | missions_db | | +| `PGUSER` | missions_user | | +| `PGPASSWORD` | pg123 | | +| `MINIO_ENDPOINT` | minio-hot:9000 | S3 API port is 9000 inside Docker | +| `MINIO_SECURE` | false | set true if TLS to MinIO | +| `MINIO_ACCESS_KEY` | minioadmin | | +| `MINIO_SECRET_KEY` | minioadmin | | +| `MODEL_PATH` | /models/best_conditional.pt | mounted from host | +| `MODEL_NAME` | best_conditional | stored in DB | +| `BATCH_LIMIT` | 500 | safety cap per run | +| `FRUITS` (optional) | Apple,Orange,Grape,Strawberry | if enabled in code | + +If you’re behind **NetFree/proxy**, copy your CA file to `deploy/certs/` and use the Dockerfile section that installs CA + `update-ca-certificates`. + +--- + +## 🐳 Build & Run (Docker) + +From `services/ripeness-ml/`: + +```bash +docker compose -f docker-compose.ripeness.yml build ripeness-api +docker compose -f docker-compose.ripeness.yml up -d ripeness-api +``` + +**Health check:** + +```bash +curl http://localhost:8088/healthz +``` + +# **logs** +```bash +docker logs -n 200 ripeness-api +``` + +--- + +## 🔌 API + +**Base URL:** `http://localhost:8088` + +### POST `/predict-last-week` +Runs prediction for images from the last 7 days that don’t have a record yet in `ripeness_predictions`. + +```bash +curl -X POST http://localhost:8088/predict-last-week +# -> {"processed": 17} +``` + +### POST `/predict-batch` +Run for a custom time window and limit. + +**Request body (JSON):** +```json +{ + "since_ts": "2025-10-01T00:00:00", + "limit": 1000 +} +``` + +**Example:** +```bash +curl -X POST http://localhost:8088/predict-batch -H "Content-Type: application/json" -d '{"since_ts":"2025-10-01T00:00:00","limit":1000}' +``` + +### POST `/rollup/weekly` +Creates a weekly snapshot into `ripeness_weekly_rollups_ts` for the last 7 days (creates the table if missing). + +```bash +curl -X POST http://localhost:8088/rollup/weekly +# -> {"ok": true} +``` + +--- + +## 🧮 Database schema + +### Predictions table +```sql +CREATE TABLE IF NOT EXISTS ripeness_predictions ( + id BIGSERIAL PRIMARY KEY, + inference_log_id BIGINT NOT NULL REFERENCES inference_logs(id) ON DELETE CASCADE, + ts TIMESTAMPTZ NOT NULL DEFAULT now(), + ripeness_label TEXT NOT NULL CHECK (ripeness_label IN ('ripe','unripe','overripe')), + ripeness_score DOUBLE PRECISION NOT NULL, + model_name TEXT NOT NULL, + UNIQUE (inference_log_id) +); +``` + +### Weekly rollups +```sql +CREATE TABLE IF NOT EXISTS ripeness_weekly_rollups_ts ( + id BIGSERIAL PRIMARY KEY, + ts TIMESTAMPTZ NOT NULL DEFAULT now(), -- snapshot time + window_start TIMESTAMPTZ NOT NULL, + window_end TIMESTAMPTZ NOT NULL, + fruit_type TEXT NOT NULL, + cnt_total INTEGER NOT NULL, + cnt_ripe INTEGER NOT NULL, + cnt_unripe INTEGER NOT NULL, + cnt_overripe INTEGER NOT NULL, + pct_ripe DOUBLE PRECISION NOT NULL +); +``` + +--- + +## 🔍 Useful queries + +**Show latest predictions joined with inference logs:** +```sql +SELECT il.id, il.fruit_type, il.image_url, rp.ripeness_label, rp.ripeness_score, rp.model_name, rp.ts +FROM inference_logs il +JOIN ripeness_predictions rp ON rp.inference_log_id = il.id +ORDER BY rp.ts DESC +LIMIT 20; +``` + +**Show rollup snapshots:** +```sql +SELECT ts::date AS snapshot_day, fruit_type, cnt_total, +cnt_ripe, cnt_unripe, cnt_overripe, +ROUND(pct_ripe*100,2) AS pct_ripe_pct +FROM ripeness_weekly_rollups_ts +ORDER BY ts DESC, fruit_type; +``` + +**From Docker (network agcloud_ag_cloud):** +```bash +docker run --rm --network agcloud_ag_cloud -e PGPASSWORD=pg123 postgres:16-alpine psql -h postgres -U missions_user -d missions_db -c "SELECT ts::date AS snapshot_day, fruit_type, cnt_total, cnt_ripe, cnt_unripe, cnt_overripe, ROUND(pct_ripe*100,2) AS pct_ripe_pct + FROM ripeness_weekly_rollups_ts + ORDER BY ts DESC, fruit_type;" +``` + +--- + +## 🕒 Scheduling (Windows Task Scheduler) + +Create a weekly job that first predicts, then rolls up. + +**run_weekly.ps1:** +```powershell +Invoke-RestMethod -Method Post -Uri "http://localhost:8088/predict-last-week" +# note: /predict-last-week now triggers the weekly rollup automatically, +# so a single call is sufficient (no duplicate predictions are inserted). +``` + +**Register task:** +```bash +schtasks /Create /TN "RipenessWeekly" /TR "powershell.exe -ExecutionPolicy Bypass -File C:\path\run_weekly.ps1" /SC WEEKLY /D MON /ST 03:00 +``` + +--- + +## 🧰 Troubleshooting + +- **MinIO errors / 9000 vs 9001:** inside Docker network always use `minio-hot:9000` (S3 API). + Ports 9001/9002 are host-exposed console/proxy. +- **SignatureDoesNotMatch:** wrong `MINIO_ACCESS_KEY`/`SECRET_KEY` or endpoint (should be the S3 API). +- **Model FRUITS mismatch:** ensure the FRUITS list in code matches the model checkpoint (e.g. include Grape if trained). +- **SSL to PyPI (NetFree/proxy):** add your CA to the image and run `update-ca-certificates`. +- **No rows processed:** endpoint processes only inference logs without an existing prediction; expand window with `/predict-batch`. + +--- + +## 👩‍💻 Maintainer + +**Name:** Ayala +**Service name:** ripeness-api +**Ports:** 8088/tcp diff --git a/services/ripeness-ml/api/ripeness_api.py b/services/ripeness-ml/api/ripeness_api.py new file mode 100644 index 000000000..0ea1f6991 --- /dev/null +++ b/services/ripeness-ml/api/ripeness_api.py @@ -0,0 +1,175 @@ +# scripts/ripeness_api.py +from fastapi import FastAPI +from pydantic import BaseModel +from datetime import datetime, timedelta +import sys, os +sys.path.append(os.path.join(os.path.dirname(__file__), "")) + +from jobs.weekly_ripeness_job import ( + get_conn, + fetch_from_minio, + load_image_for_model, + predict_ripeness, +) + +app = FastAPI(title="Ripeness Service") + + +class BatchRequest(BaseModel): + since_ts: datetime | None = None + limit: int = 500 + + +def run_batch(since_ts: datetime | None, limit: int) -> int: + if since_ts is None: + since_ts = datetime.utcnow() - timedelta(days=7) + with get_conn() as conn, conn.cursor() as cur: + cur.execute(""" + SELECT il.id, il.ts, il.fruit_type, il.image_url + FROM inference_logs il + LEFT JOIN ripeness_predictions rp ON rp.inference_log_id = il.id + WHERE il.ts >= %s + AND rp.id IS NULL + ORDER BY il.id ASC + LIMIT %s; + """, (since_ts, limit)) + rows = cur.fetchall() + + processed = 0 + # Generate a new run_id for this batch (once per batch) + with get_conn() as conn, conn.cursor() as cur: + cur.execute("SELECT gen_random_uuid()") + run_id = cur.fetchone()[0] + + for inflog_id, ts, fruit_type, image_url in rows: + try: + img_bytes = fetch_from_minio(image_url) + tensor = load_image_for_model(img_bytes) + label, score = predict_ripeness(tensor, fruit_type) + + # Parse bucket and object_key from image_url (expects format minio://bucket/object_key) + device_id = None + if image_url.startswith("minio://"): + path = image_url[len("minio://"):] + if "/" in path: + bucket, object_key = path.split("/", 1) + with get_conn() as conn, conn.cursor() as cur: + cur.execute(""" + SELECT device_id FROM files + WHERE bucket = %s AND object_key = %s + """, (bucket, object_key)) + res = cur.fetchone() + device_id = res[0] if res else None + + with get_conn() as conn, conn.cursor() as cur: + cur.execute(""" + INSERT INTO ripeness_predictions + (inference_log_id, ts, ripeness_label, ripeness_score, model_name, run_id, device_id) + VALUES (%s, now(), %s, %s, %s, %s, %s) + ON CONFLICT (inference_log_id) DO NOTHING; + """, (inflog_id, label, score, os.getenv("MODEL_NAME", "best_conditional"), run_id, device_id)) + processed += 1 + except Exception as e: + print(f"[ERR] inflog_id={inflog_id} :: {e}") + return processed + + +@app.get("/healthz") +def healthz(): + return {"ok": True} + + +@app.post("/predict-batch") +def predict_batch(req: BatchRequest): + n = run_batch(req.since_ts, req.limit) + return {"processed": n} + + +@app.post("/predict-last-week") +def predict_last_week(): + n = run_batch(None, int(os.getenv("BATCH_LIMIT", "500"))) + # After predicting new images, immediately create the weekly rollup + # This keeps the workflow to a single endpoint call (no duplicates because + # predictions use ON CONFLICT DO NOTHING) + try: + insert_weekly_rollup() + return {"processed": n, "rollup": True} + except Exception as e: + # Log the error but still return the number of processed items + print(f"[ERR] rollup: {e}") + return {"processed": n, "rollup": False, "error": str(e)} + + +def insert_weekly_rollup(): + ddl = """ + CREATE TABLE IF NOT EXISTS ripeness_weekly_rollups_ts ( + id BIGSERIAL PRIMARY KEY, + ts TIMESTAMPTZ NOT NULL DEFAULT now(), + window_start TIMESTAMPTZ NOT NULL, + window_end TIMESTAMPTZ NOT NULL, + fruit_type TEXT NOT NULL, + device_id TEXT, + run_id UUID, + cnt_total INTEGER NOT NULL, + cnt_ripe INTEGER NOT NULL, + cnt_unripe INTEGER NOT NULL, + cnt_overripe INTEGER NOT NULL, + pct_ripe DOUBLE PRECISION NOT NULL + ); + CREATE INDEX IF NOT EXISTS ix_rwrt_ts ON ripeness_weekly_rollups_ts(ts); + CREATE INDEX IF NOT EXISTS ix_rwrt_fruit_ts ON ripeness_weekly_rollups_ts(fruit_type, ts); + CREATE INDEX IF NOT EXISTS ix_rwrt_device ON ripeness_weekly_rollups_ts(device_id); + CREATE INDEX IF NOT EXISTS ix_rwrt_run ON ripeness_weekly_rollups_ts(run_id); + """ + + # optional filter by fruits from environment (comma-separated) + fruits_env = os.getenv("FRUITS") + fruits = None + fruit_where = "" + if fruits_env: + fruits = [f.strip() for f in fruits_env.split(",") if f.strip()] + # use = ANY(%s) with a TEXT[] parameter + fruit_where = "WHERE il.fruit_type = ANY(%s)" + + sql = """ + WITH w AS ( + SELECT now() - interval '7 days' AS ws, now() AS we + ), + agg AS ( + SELECT + il.fruit_type, + rp.device_id, + rp.run_id, + COUNT(*) AS cnt_total, + SUM(CASE WHEN rp.ripeness_label='ripe' THEN 1 ELSE 0 END) AS cnt_ripe, + SUM(CASE WHEN rp.ripeness_label='unripe' THEN 1 ELSE 0 END) AS cnt_unripe, + SUM(CASE WHEN rp.ripeness_label='overripe' THEN 1 ELSE 0 END) AS cnt_overripe + FROM ripeness_predictions rp + JOIN inference_logs il ON il.id = rp.inference_log_id + JOIN w ON rp.ts >= w.ws AND rp.ts < w.we + """ + ("\n " + fruit_where if fruit_where else "") + """ + GROUP BY il.fruit_type, rp.device_id, rp.run_id + ) + INSERT INTO ripeness_weekly_rollups_ts + (ts, window_start, window_end, fruit_type, device_id, run_id, cnt_total, cnt_ripe, cnt_unripe, cnt_overripe, pct_ripe) + SELECT + now(), (SELECT ws FROM w), (SELECT we FROM w), + fruit_type, device_id, run_id, cnt_total, cnt_ripe, cnt_unripe, cnt_overripe, + CASE WHEN cnt_total>0 THEN cnt_ripe::double precision/cnt_total ELSE 0 END + FROM agg; + """ + + with get_conn() as conn, conn.cursor() as cur: + cur.execute(ddl) + if fruits: + # psycopg2 adapts Python list to SQL array + cur.execute(sql, (fruits,)) + else: + cur.execute(sql) + return True + + +@app.post("/rollup/weekly") +def rollup_weekly(): + insert_weekly_rollup() + return {"ok": True} diff --git a/services/ripeness-ml/checkpoints/eval/classification_report.txt b/services/ripeness-ml/checkpoints/eval/classification_report.txt new file mode 100644 index 000000000..f2e2f5a54 --- /dev/null +++ b/services/ripeness-ml/checkpoints/eval/classification_report.txt @@ -0,0 +1,13 @@ + precision recall f1-score support + + unripe 1.0000 0.9981 0.9990 1041 + ripe 0.9983 1.0000 0.9991 1164 + overripe 1.0000 1.0000 1.0000 1534 + + accuracy 0.9995 3739 + macro avg 0.9994 0.9994 0.9994 3739 +weighted avg 0.9995 0.9995 0.9995 3739 + + +Accuracy: 0.9995 +Macro-F1: 0.9994 diff --git a/services/ripeness-ml/checkpoints/eval/metrics.json b/services/ripeness-ml/checkpoints/eval/metrics.json new file mode 100644 index 000000000..00bdd6890 --- /dev/null +++ b/services/ripeness-ml/checkpoints/eval/metrics.json @@ -0,0 +1,9 @@ +{ + "accuracy": 0.9994650976196844, + "macro_f1": 0.9993933641465831, + "per_class_f1": { + "unripe": 0.9990384615384615, + "ripe": 0.9991416309012876, + "overripe": 1.0 + } +} \ No newline at end of file diff --git a/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional.pt b/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional.pt new file mode 100644 index 000000000..0637d2fd0 Binary files /dev/null and b/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional.pt differ diff --git a/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional_frozen.pt b/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional_frozen.pt new file mode 100644 index 000000000..29b9cfb1c Binary files /dev/null and b/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional_frozen.pt differ diff --git a/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional_unfrozen.pt b/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional_unfrozen.pt new file mode 100644 index 000000000..c9ef651e1 Binary files /dev/null and b/services/ripeness-ml/checkpoints/mobilenet_v3_large/best_conditional_unfrozen.pt differ diff --git a/services/ripeness-ml/configs/config.yaml b/services/ripeness-ml/configs/config.yaml new file mode 100644 index 000000000..bfd6c7863 --- /dev/null +++ b/services/ripeness-ml/configs/config.yaml @@ -0,0 +1,25 @@ +seed: 42 +classes: ["unripe", "ripe", "overripe"] +img_size: 224 +batch_size: 32 +num_workers: 0 +epochs_frozen: 5 +epochs_unfrozen: 10 +lr: 0.0003 +weight_decay: 0.0001 +label_smoothing: 0.05 +use_class_weights: true +train_dir: "data/train" +val_dir: "data/val" +test_dir: "data/test" +checkpoint_dir: "checkpoints/mobilenet_v3_large" +best_metric: "f1_macro" + +fruits: ["apple","banana","orange"] +ripeness: ["unripe","ripe","overripe"] + +csv: + train: "data_mt_train/train.csv" + val: "data_mt_train/val.csv" + test: "data_mt_test/test.csv" + diff --git a/services/ripeness-ml/deploy/Dockerfile b/services/ripeness-ml/deploy/Dockerfile new file mode 100644 index 000000000..a232874d7 --- /dev/null +++ b/services/ripeness-ml/deploy/Dockerfile @@ -0,0 +1,59 @@ +FROM python:3.11-slim + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates openssl libpq-dev build-essential gcc \ + && rm -rf /var/lib/apt/lists/* + +COPY deploy/certs/ /usr/local/share/ca-certificates/ +RUN set -eux; \ + for f in /usr/local/share/ca-certificates/*.cer; do \ + [ -f "$f" ] && openssl x509 -inform der -in "$f" -out "${f%.cer}.crt" && rm -f "$f" || true; \ + done; \ + update-ca-certificates + +RUN printf "[global]\n\ +cert = /etc/ssl/certs/ca-certificates.crt\n\ +index-url = https://pypi.org/simple\n\ +trusted-host =\n\ + pypi.org\n\ + files.pythonhosted.org\n\ + download.pytorch.org\n" > /etc/pip.conf + +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt + +COPY requirements.txt /app/ +RUN pip install --no-cache-dir --timeout 120 --index-url https://download.pytorch.org/whl/cpu \ + --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host files.pythonhosted.org \ + torch==2.3.1 torchvision==0.18.1 \ + && pip install --no-cache-dir -r /app/requirements.txt \ + && pip install --no-cache-dir fastapi "uvicorn[standard]" +COPY api/ /app/api/ +COPY model/ /app/model +COPY jobs/ /app/jobs/ +COPY configs/ /app/configs/ +# Create models directory and copy model file +RUN mkdir -p /app/models +COPY checkpoints/mobilenet_v3_large/best_conditional.pt /app/models/best_conditional.pt + +# Create __init__.py files for Python modules +RUN touch /app/model/__init__.py \ + && touch /app/model/architecture/__init__.py \ + && touch /app/model/data/__init__.py \ + && touch /app/jobs/__init__.py \ + && touch /app/api/__init__.py + +ENV PYTHONPATH=/app + +EXPOSE 8088 +ENV MODEL_PATH=/app/models/best_conditional.pt \ + MODEL_NAME=best_conditional \ + BATCH_LIMIT=500 + +CMD ["uvicorn", "api.ripeness_api:app", "--host", "0.0.0.0", "--port", "8088", "--reload"] diff --git a/services/ripeness-ml/deploy/docker-compose.ripeness.yml b/services/ripeness-ml/deploy/docker-compose.ripeness.yml new file mode 100644 index 000000000..0ac0a30f2 --- /dev/null +++ b/services/ripeness-ml/deploy/docker-compose.ripeness.yml @@ -0,0 +1,36 @@ +services: + ripeness-api: + image: ripeness-api:latest + build: + context: .. + dockerfile: deploy/Dockerfile + container_name: ripeness-api + environment: + PGHOST: postgres + PGPORT: "5432" + PGDATABASE: missions_db + PGUSER: missions_user + PGPASSWORD: pg123 + + MINIO_ENDPOINT: minio-hot:9000 + MINIO_SECURE: "false" + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin123 + + MODEL_NAME: best_conditional + BATCH_LIMIT: "500" + FRUITS: "Apple,Banana,Orange" + volumes: + - ../checkpoints:/app/checkpoints + - ../configs:/app/configs + - ../model:/app/model + networks: + - agcloud_net + ports: + - "8091:8088" + restart: unless-stopped + +networks: + agcloud_net: + external: true + name: agcloud_ag_cloud diff --git a/services/ripeness-ml/jobs/weekly_ripeness_job.py b/services/ripeness-ml/jobs/weekly_ripeness_job.py new file mode 100644 index 000000000..387b7c034 --- /dev/null +++ b/services/ripeness-ml/jobs/weekly_ripeness_job.py @@ -0,0 +1,167 @@ +# file: services/weekly_ripeness_job.py +import io +import time +import torch +import psycopg2 +import datetime as dt +from urllib.parse import urlparse +from minio import Minio +from PIL import Image +import sys, os +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) # so "models" is importable +from model.architecture.mobilenet_v3_large_head import build_conditional +from tqdm.auto import tqdm + + +from pathlib import Path +try: + from dotenv import load_dotenv + env_path = Path(__file__).resolve().parents[1] / ".env" + if env_path.exists(): + load_dotenv(env_path.as_posix()) +except Exception: + pass + +# ---- ENV ---- +PGHOST = os.getenv("PGHOST", "db") +PGPORT = int(os.getenv("PGPORT", "5432")) +PGDATABASE = os.getenv("PGDATABASE", "missions_db") +PGUSER = os.getenv("PGUSER", "missions_user") +PGPASSWORD = os.getenv("PGPASSWORD", "pg123") + +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "127.0.0.1:9000") +MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true" +MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "minioadmin") +MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "minioadmin") + +MODEL_PATH = os.getenv("MODEL_PATH", "/models/best_conditional.pt") +MODEL_NAME = os.getenv("MODEL_NAME", "best_conditional") +BATCH_LIMIT = int(os.getenv("BATCH_LIMIT", "200")) + +# ----- labels & fruits mapping ----- +LABELS = ["unripe", "ripe", "overripe"] +FRUITS = ["Apple", "Banana", "Orange", "."] +FRUIT2IDX = {name.lower(): i for i, name in enumerate(FRUITS)} + +# ----- build model & load weights ----- +device = "cuda" if torch.cuda.is_available() else "cpu" +num_ripeness = len(LABELS) +num_fruits = len(FRUITS) + +model = build_conditional(num_ripeness=num_ripeness, num_fruits=num_fruits, embed_dim=16).to(device) + +ckpt = torch.load(MODEL_PATH, map_location=device) +state = ckpt["state_dict"] if (isinstance(ckpt, dict) and "state_dict" in ckpt) else ckpt + +assert state["fruit_embed.weight"].shape[0] == num_fruits, \ + f"Checkpoint expects {state['fruit_embed.weight'].shape[0]} fruits, but FRUITS has {num_fruits}" + +model.load_state_dict(state, strict=True) +model.eval() + +def load_image_for_model(img_bytes): + im = Image.open(io.BytesIO(img_bytes)).convert("RGB") + from torchvision import transforms + preprocess = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) + ]) + return preprocess(im).unsqueeze(0).to(device) + +@torch.no_grad() +def predict_ripeness(img_tensor, fruit_type: str): + idx = FRUIT2IDX.get(fruit_type.lower()) + if idx is None: + raise KeyError(f"skip: fruit '{fruit_type}' not in trained set {FRUITS}") + fruit_idx_tensor = torch.tensor([idx], dtype=torch.long, device=device) + logits = model(img_tensor, fruit_idx_tensor) + probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy() + j = int(probs.argmax()) + return LABELS[j], float(probs[j]) + +# ---- MINIO ---- +minio_client = Minio(MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, secure=MINIO_SECURE) + +def fetch_from_minio(image_url: str) -> bytes: + p = urlparse(image_url) + path = p.path.lstrip("/") + bucket, *rest = path.split("/", 1) + if not rest: + raise ValueError(f"Invalid URL path for MinIO: {image_url}") + obj = rest[0] + resp = minio_client.get_object(bucket, obj) + data = resp.read() + resp.close() + resp.release_conn() + return data + +# ---- DB ---- +def get_conn(): + return psycopg2.connect( + host=PGHOST, port=PGPORT, dbname=PGDATABASE, user=PGUSER, password=PGPASSWORD + ) + +def main(): + with get_conn() as conn, conn.cursor() as cur: + cur.execute(""" + SELECT il.id, il.ts, il.fruit_type, il.image_url + FROM inference_logs il + LEFT JOIN ripeness_predictions rp ON rp.inference_log_id = il.id + WHERE il.ts >= now() - interval '7 days' + AND rp.id IS NULL + ORDER BY il.id ASC + LIMIT %s; + """, (BATCH_LIMIT,)) + rows = cur.fetchall() + + processed = 0 + + # generate a single run_id for this batch + with get_conn() as conn, conn.cursor() as cur: + cur.execute("SELECT gen_random_uuid()") + run_id = cur.fetchone()[0] + + for inflog_id, ts, fruit_type, image_url in tqdm(rows, desc="Predicting ripeness"): + try: + if processed % 20 == 0: + print(f"...processed {processed} so far") + img_bytes = fetch_from_minio(image_url) + tensor = load_image_for_model(img_bytes) + try: + label, score = predict_ripeness(tensor, fruit_type) + except KeyError as skip: + print(f"[SKIP] inflog_id={inflog_id} :: {skip}") + continue + + # derive bucket/object_key and lookup device_id + device_id = None + try: + p = urlparse(image_url) + path = p.path.lstrip('/') + if '/' in path: + bucket, object_key = path.split('/', 1) + with get_conn() as conn, conn.cursor() as cur: + cur.execute("SELECT device_id FROM files WHERE bucket = %s AND object_key = %s", (bucket, object_key)) + res = cur.fetchone() + device_id = res[0] if res else None + except Exception: + # keep device_id as None if parsing/lookup fails + device_id = None + + with get_conn() as conn, conn.cursor() as cur: + cur.execute(""" + INSERT INTO ripeness_predictions + (inference_log_id, ts, ripeness_label, ripeness_score, model_name, run_id, device_id) + VALUES (%s, now(), %s, %s, %s, %s, %s) + ON CONFLICT (inference_log_id) DO NOTHING; + """, (inflog_id, label, score, MODEL_NAME, run_id, device_id)) + processed += 1 + print(f"[OK] inflog_id={inflog_id} -> {label} ({score:.4f})") + except Exception as e: + print(f"[ERR] inflog_id={inflog_id} url={image_url} :: {e}") + + print(f"Done. processed={processed}") + +if __name__ == "__main__": + main() diff --git a/services/ripeness-ml/model/architecture/mobilenet_v3_large_head.py b/services/ripeness-ml/model/architecture/mobilenet_v3_large_head.py new file mode 100644 index 000000000..3457d6953 --- /dev/null +++ b/services/ripeness-ml/model/architecture/mobilenet_v3_large_head.py @@ -0,0 +1,34 @@ +import torch.nn as nn +import torch +from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights + +class RipenessModelConditional(nn.Module): + """ + Image -> MobileNetV3 backbone + Fruit type (idx) -> Embedding + Concatenate -> Linear -> ripeness logits + """ + def __init__(self, num_ripeness: int, num_fruits: int, embed_dim: int = 16): + super().__init__() + weights = MobileNet_V3_Large_Weights.IMAGENET1K_V2 + self.backbone = mobilenet_v3_large(weights=weights) + in_feats = self.backbone.classifier[-1].in_features + self.backbone.classifier[-1] = nn.Identity() + self.fruit_embed = nn.Embedding(num_fruits, embed_dim) + self.head = nn.Linear(in_feats + embed_dim, num_ripeness) + + def forward(self, x, fruit_idx): + feats = self.backbone(x) # [B, in_feats] + fvec = self.fruit_embed(fruit_idx) # [B, embed_dim] + out = torch.cat([feats, fvec], dim=1) # [B, in_feats+embed_dim] + return self.head(out) # [B, num_ripeness] + +def build_conditional(num_ripeness: int, num_fruits: int, embed_dim: int = 16) -> nn.Module: + return RipenessModelConditional(num_ripeness, num_fruits, embed_dim) + +def build_model(num_classes: int) -> nn.Module: + weights = MobileNet_V3_Large_Weights.IMAGENET1K_V2 + model = mobilenet_v3_large(weights=weights) + in_feats = model.classifier[-1].in_features + model.classifier[-1] = nn.Linear(in_feats, num_classes) + return model diff --git a/services/ripeness-ml/model/data/data_multitask.py b/services/ripeness-ml/model/data/data_multitask.py new file mode 100644 index 000000000..9b01959a3 --- /dev/null +++ b/services/ripeness-ml/model/data/data_multitask.py @@ -0,0 +1,48 @@ +from torch.utils.data import Dataset, DataLoader +from PIL import Image +from torchvision import transforms +import pandas as pd + +IMAGENET_MEAN=(0.485,0.456,0.406); IMAGENET_STD=(0.229,0.224,0.225) + +def build_transforms(img_size=224): + from torchvision import transforms as T + t_train = T.Compose([ + T.RandomResizedCrop(img_size, scale=(0.7,1.0)), + T.RandomHorizontalFlip(), + T.ColorJitter(0.2,0.2,0.2,0.05), + T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD), + ]) + t_val = T.Compose([ + T.Resize(int(img_size*1.15)), T.CenterCrop(img_size), + T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD), + ]) + return t_train, t_val + +class CSVConditional(Dataset): + def __init__(self, csv_path, fruit_to_idx, ripeness_to_idx, transform=None): + self.df = pd.read_csv(csv_path) + self.fruit_to_idx = fruit_to_idx + self.ripeness_to_idx = ripeness_to_idx + self.transform = transform + + def __len__(self): return len(self.df) + + def __getitem__(self, i): + row = self.df.iloc[i] + img = Image.open(row["path"]).convert("RGB") + if self.transform: img = self.transform(img) + fruit_idx = self.fruit_to_idx[row["fruit"]] + ripeness_idx = self.ripeness_to_idx[row["ripeness"]] + return img, fruit_idx, ripeness_idx + +def make_loaders(csv_train, csv_val, img_size, batch_size, num_workers, fruits, ripeness): + t_train, t_val = build_transforms(img_size) + f2i = {f:i for i,f in enumerate(fruits)} + r2i = {r:i for i,r in enumerate(ripeness)} + dtr = CSVConditional(csv_train, f2i, r2i, t_train) + dva = CSVConditional(csv_val, f2i, r2i, t_val) + from torch.utils.data import DataLoader + ltr = DataLoader(dtr, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) + lva = DataLoader(dva, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) + return ltr, lva, f2i, r2i diff --git a/services/ripeness-ml/model/data/transforms.py b/services/ripeness-ml/model/data/transforms.py new file mode 100644 index 000000000..a8f2c4b5b --- /dev/null +++ b/services/ripeness-ml/model/data/transforms.py @@ -0,0 +1,18 @@ +from torchvision import transforms +IMAGENET_MEAN=(0.485,0.456,0.406); IMAGENET_STD=(0.229,0.224,0.225) + +def build_transforms(img_size=224): + t_train = transforms.Compose([ + transforms.RandomResizedCrop(img_size, scale=(0.7,1.0)), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(0.2,0.2,0.2,0.05), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), + ]) + t_val = transforms.Compose([ + transforms.Resize(int(img_size*1.15)), + transforms.CenterCrop(img_size), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), + ]) + return t_train, t_val diff --git a/services/ripeness-ml/model/training/evaluate_conditional.py b/services/ripeness-ml/model/training/evaluate_conditional.py new file mode 100644 index 000000000..e10977e01 --- /dev/null +++ b/services/ripeness-ml/model/training/evaluate_conditional.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Evaluate the conditional ripeness model on test/val CSVs. +# Outputs: +# - metrics.json (accuracy, macro_f1, per-class F1) +# - classification_report.txt +# - confusion_matrix.png + +import os, sys, json, yaml +from pathlib import Path +import numpy as np +import torch +from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix +import matplotlib.pyplot as plt + +# --- make 'models' & 'training' importable when running as a script --- +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +from model.architecture.mobilenet_v3_large_head import build_conditional +from data.data_multitask import CSVConditional, build_transforms + +IMAGENET_MEAN=(0.485,0.456,0.406) +IMAGENET_STD=(0.229,0.224,0.225) + +def softmax(x): + x = x - x.max(axis=1, keepdims=True) + e = np.exp(x) + return e / e.sum(axis=1, keepdims=True) + +def load_cfg(): + return yaml.safe_load(open(os.path.join(PROJECT_ROOT, "configs/config.yaml"), "r", encoding="utf-8")) + +def make_loader(csv_path, fruits, ripeness, img_size=224, batch_size=64, num_workers=0): + _, t_val = build_transforms(img_size) + f2i = {f:i for i,f in enumerate(fruits)} + r2i = {r:i for i,r in enumerate(ripeness)} + ds = CSVConditional(csv_path, f2i, r2i, transform=t_val) + from torch.utils.data import DataLoader + return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) + +def plot_confusion_matrix(cm, classes, out_png): + fig = plt.figure(figsize=(5.5, 4.5)) + ax = fig.add_subplot(111) + im = ax.imshow(cm, interpolation='nearest') + ax.set_title('Confusion Matrix') + fig.colorbar(im) + tick_marks = np.arange(len(classes)) + ax.set_xticks(tick_marks); ax.set_xticklabels(classes, rotation=45, ha="right") + ax.set_yticks(tick_marks); ax.set_yticklabels(classes) + ax.set_ylabel('True'); ax.set_xlabel('Predicted') + # write counts + thresh = cm.max() / 2.0 if cm.size else 0.5 + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + fig.tight_layout() + fig.savefig(out_png, dpi=160) + plt.close(fig) + +if __name__ == "__main__": + cfg = load_cfg() + device = "cuda" if torch.cuda.is_available() else "cpu" + + fruits = cfg["fruits"] + ripeness = cfg["ripeness"] + + # choose CSV: prefer test.csv; if missing/empty -> use val.csv + csv_test = Path(cfg["csv"].get("test", "data_mt/test.csv")) + csv_val = Path(cfg["csv"].get("val", "data_mt/val.csv")) + csv_path = csv_test if csv_test.exists() and csv_test.stat().st_size > 50 else csv_val + if not csv_path.exists(): + raise SystemExit(f"CSV not found: {csv_path}. Run ingest to create it.") + + # dataloader + loader = make_loader( + str(csv_path), fruits, ripeness, + img_size=cfg.get("img_size", 224), + batch_size=cfg.get("batch_size", 32), + num_workers=cfg.get("num_workers", 0) + ) + + # model + ckpt_dir = cfg["checkpoint_dir"] + ckpt = os.path.join(ckpt_dir, "best_conditional.pt") + if not os.path.exists(ckpt): + raise SystemExit(f"Checkpoint not found: {ckpt}") + + model = build_conditional(num_ripeness=len(ripeness), num_fruits=len(fruits)) + model.load_state_dict(torch.load(ckpt, map_location="cpu")) + model.eval().to(device) + + # predict + y_true, y_pred = [], [] + probs_all = [] + with torch.no_grad(): + for x, fidx, ridx in loader: + x = x.to(device) + fidx = torch.as_tensor(fidx, device=device) + logits = model(x, fidx).cpu().numpy() + prob = softmax(logits) + preds = prob.argmax(1) + y_pred.extend(preds.tolist()) + y_true.extend(ridx.numpy().tolist()) + probs_all.append(prob) + + y_true = np.array(y_true) + y_pred = np.array(y_pred) + probs = np.concatenate(probs_all, axis=0) if probs_all else np.empty((0,len(ripeness))) + + # metrics + acc = float(accuracy_score(y_true, y_pred)) + macro_f1 = float(f1_score(y_true, y_pred, average="macro")) + per_class_f1 = f1_score(y_true, y_pred, average=None) + per_class = {ripeness[i]: float(per_class_f1[i]) for i in range(len(ripeness))} + report = classification_report(y_true, y_pred, target_names=ripeness, digits=4) + cm = confusion_matrix(y_true, y_pred) + + # outputs + out_dir = os.path.join(PROJECT_ROOT, "checkpoints", "eval") + os.makedirs(out_dir, exist_ok=True) + # confusion matrix PNG + cm_png = os.path.join(out_dir, "confusion_matrix.png") + plot_confusion_matrix(cm, ripeness, cm_png) + # classification report + with open(os.path.join(out_dir, "classification_report.txt"), "w", encoding="utf-8") as f: + f.write(report + "\n") + f.write(f"\nAccuracy: {acc:.4f}\nMacro-F1: {macro_f1:.4f}\n") + # json metrics + with open(os.path.join(out_dir, "metrics.json"), "w", encoding="utf-8") as f: + json.dump({"accuracy": acc, "macro_f1": macro_f1, "per_class_f1": per_class}, f, indent=2) + + print(f"Evaluated on: {csv_path}") + print(f"Accuracy: {acc:.4f} | Macro-F1: {macro_f1:.4f}") + print("Per-class F1:", per_class) + print(f"Saved: {cm_png} and classification_report.txt, metrics.json") diff --git a/services/ripeness-ml/model/training/train_conditional.py b/services/ripeness-ml/model/training/train_conditional.py new file mode 100644 index 000000000..c2bf56357 --- /dev/null +++ b/services/ripeness-ml/model/training/train_conditional.py @@ -0,0 +1,114 @@ +import os, yaml, torch +from torch import nn +from sklearn.metrics import accuracy_score, f1_score + +from model.architecture.mobilenet_v3_large_head import build_conditional +from data.data_multitask import make_loaders + + +def evaluate(model, loader, device): + model.eval() + y_true, y_pred = [], [] + with torch.no_grad(): + for x, fidx, ridx in loader: + x = x.to(device) + fidx = torch.as_tensor(fidx, device=device) + logits = model(x, fidx) + y_pred.extend(logits.argmax(1).cpu().numpy()) + y_true.extend(ridx.numpy()) + acc = accuracy_score(y_true, y_pred) + f1 = f1_score(y_true, y_pred, average="macro") + return acc, f1 + + +def train_phase(model, ltr, lva, device, epochs, lr, wd, ckpt_dir, tag, ce, patience=2): + from torch.optim import AdamW + from torch.optim.lr_scheduler import CosineAnnealingLR + + os.makedirs(ckpt_dir, exist_ok=True) + opt = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) + sch = CosineAnnealingLR(opt, T_max=epochs) + + best_f1 = -1.0 + best_state = None + no_improve = 0 + + try: + for ep in range(1, epochs + 1): + model.train() + for x, fidx, ridx in ltr: + x = x.to(device) + fidx = torch.as_tensor(fidx, device=device) + ridx = torch.as_tensor(ridx, device=device) + + logits = model(x, fidx) + loss = ce(logits, ridx) + + opt.zero_grad() + loss.backward() + opt.step() + + acc, f1 = evaluate(model, lva, device) + sch.step() + print(f"[{tag} Epoch {ep}] val_acc={acc:.3f} val_f1={f1:.3f}") + + if f1 > best_f1 + 1e-4: + best_f1 = f1 + best_state = {k: v.cpu() for k, v in model.state_dict().items()} + torch.save(best_state, os.path.join(ckpt_dir, f"best_conditional_{tag}.pt")) + no_improve = 0 + else: + no_improve += 1 + if no_improve >= patience: + print(f"Early stopping ({tag}) — no improvement for {patience} epochs") + break + + except KeyboardInterrupt: + print("KeyboardInterrupt — saving best checkpoint so far...") + + finally: + if best_state is not None: + model.load_state_dict(best_state) + return model, best_f1 + + +if __name__ == "__main__": + cfg = yaml.safe_load(open("configs/config.yaml", "r", encoding="utf-8")) + device = "cuda" if torch.cuda.is_available() else "cpu" + + train_csv = cfg["csv"]["train"] + val_csv = cfg["csv"]["val"] + fruits = cfg["fruits"] + ripeness = cfg["ripeness"] + + ltr, lva, f2i, r2i = make_loaders( + train_csv, val_csv, + cfg["img_size"], cfg["batch_size"], cfg["num_workers"], + fruits, ripeness + ) + + model = build_conditional(num_ripeness=len(ripeness), num_fruits=len(fruits)).to(device) + ce = nn.CrossEntropyLoss() + + for p in model.backbone.features.parameters(): + p.requires_grad = False + + model, _ = train_phase( + model, ltr, lva, device, + cfg["epochs_frozen"], cfg["lr"], cfg["weight_decay"], + cfg["checkpoint_dir"], tag="frozen", ce=ce, patience=2 + + ) + + for p in model.parameters(): + p.requires_grad = True + + model, best_f1 = train_phase( + model, ltr, lva, device, + cfg["epochs_unfrozen"], cfg["lr"]/3, cfg["weight_decay"], + cfg["checkpoint_dir"], tag="unfrozen", ce=ce, patience=2 + ) + + os.makedirs(cfg["checkpoint_dir"], exist_ok=True) + torch.save(model.state_dict(), os.path.join(cfg["checkpoint_dir"], "best_conditional.pt")) + print("Saved:", os.path.join(cfg["checkpoint_dir"], "best_conditional.pt"), "| best F1:", best_f1) diff --git a/services/ripeness-ml/model/training/utils.py b/services/ripeness-ml/model/training/utils.py new file mode 100644 index 000000000..23e47edc7 --- /dev/null +++ b/services/ripeness-ml/model/training/utils.py @@ -0,0 +1,14 @@ +# import torch, random, numpy as np +# from collections import Counter + +# def set_seed(s=42): +# random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s) + +# def load_class_weights(trainloader, use=True): +# if not use: return None +# counts = Counter() +# for _,y in trainloader: +# for i in y.numpy(): counts[int(i)]+=1 +# total = sum(counts.values()) +# weights = [total/counts[i] for i in range(len(counts))] +# return torch.tensor(weights, dtype=torch.float32) diff --git a/services/ripeness-ml/requirements.txt b/services/ripeness-ml/requirements.txt new file mode 100644 index 000000000..0aa2be591 --- /dev/null +++ b/services/ripeness-ml/requirements.txt @@ -0,0 +1,16 @@ +torch==2.3.1 +torchvision==0.18.1 +# timm==1.0.9 +# scikit-learn==1.5.1 +matplotlib==3.9.0 +pillow==10.4.0 +pyyaml==6.0.2 +tqdm==4.66.4 +pandas==2.2.2 +onnx==1.16.0 +onnxruntime==1.18.1 +fastapi==0.115.0 +uvicorn[standard]==0.30.6 +minio==7.2.10 +python-dotenv==1.0.1 +psycopg2-binary diff --git a/services/ripeness-ml/tools/data_prep/ingest_kaggle_multitask.py b/services/ripeness-ml/tools/data_prep/ingest_kaggle_multitask.py new file mode 100644 index 000000000..46daeed76 --- /dev/null +++ b/services/ripeness-ml/tools/data_prep/ingest_kaggle_multitask.py @@ -0,0 +1,79 @@ + +import argparse, csv, random +from pathlib import Path + +IMG_EXT = {".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"} + +RIPENESS_MAP = { + "unripe": "unripe", + "fresh": "ripe", + "ripe": "ripe", + "rotten": "overripe", +} + +FRUIT_KEYS = ["apple", "banana", "orange", "pineapple"] + +def detect_from_path(p: Path): + names = [pp.name.lower().replace(" ", "").replace("_","") for pp in [p] + list(p.parents)] + fruit = None + ripeness = None + + for n in names: + for fk in FRUIT_KEYS: + if fk in n: + fruit = fk + break + for key, mapped in RIPENESS_MAP.items(): + if key in n: + ripeness = mapped + break + if fruit and ripeness: + return fruit, ripeness + return fruit, ripeness + +def gather(root: Path): + rows = [] # (path, fruit, ripeness) + for fp in root.rglob("*"): + if fp.is_file() and fp.suffix.lower() in IMG_EXT: + fruit, ripeness = detect_from_path(fp) + if fruit and ripeness: + rows.append((fp.resolve().as_posix(), fruit, ripeness)) + return rows + +def write_csv(path: Path, rows): + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["path","fruit","ripeness"]) + w.writerows(rows) + +if __name__ == "__main__": + ap = argparse.ArgumentParser(description="Create CSVs (train/val/test) with path,fruit,ripeness from Kaggle folders") + ap.add_argument("--src", required=True, help="path to .../dataset (the folder that contains train/ and test/)") + ap.add_argument("--outdir", default="data_mt", help="output folder for CSVs") + ap.add_argument("--split", default="0.8,0.2,0.0", help="train,val,test ratios") + ap.add_argument("--seed", type=int, default=42) + args = ap.parse_args() + + root = Path(args.src).resolve() + all_rows = gather(root) + if not all_rows: + raise SystemExit(f"No images found under: {root}. Check --src path.") + + random.seed(args.seed) + random.shuffle(all_rows) + + tr, va, te = [float(x) for x in args.split.split(",")] + assert abs(tr+va+te - 1.0) < 1e-6, "--split must sum to 1.0" + n = len(all_rows); ntr = int(tr*n); nv = int(va*n) + rows_tr = all_rows[:ntr]; rows_va = all_rows[ntr:ntr+nv]; rows_te = all_rows[ntr+nv:] + + out = Path(args.outdir) + write_csv(out/"train.csv", rows_tr) + write_csv(out/"val.csv", rows_va) + write_csv(out/"test.csv", rows_te) + + print(f"Saved CSVs in {out.resolve()}") + print(f" train.csv: {len(rows_tr)}") + print(f" val.csv: {len(rows_va)}") + print(f" test.csv: {len(rows_te)}") diff --git a/services/ripeness-ml/tools/data_prep/prepare_from_minio.py b/services/ripeness-ml/tools/data_prep/prepare_from_minio.py new file mode 100644 index 000000000..0afb84c54 --- /dev/null +++ b/services/ripeness-ml/tools/data_prep/prepare_from_minio.py @@ -0,0 +1,161 @@ +# AGCLOUD/services/ripeness-ml/scripts/prepare_from_minio.py +import os, io, csv, argparse, sys, re, datetime as dt +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from minio import Minio +from minio.error import S3Error +from tqdm import tqdm +import random + +def parse_args(): + p = argparse.ArgumentParser(description="Sync labeled images from MinIO into local data/train|val|test/") + p.add_argument("--minio-url", required=True, help="e.g. http://127.0.0.1:9000") + p.add_argument("--access-key", required=False, default=os.getenv("MINIO_ACCESS_KEY","minioadmin")) + p.add_argument("--secret-key", required=False, default=os.getenv("MINIO_SECRET_KEY","minioadmin")) + p.add_argument("--secure", action="store_true", help="use HTTPS") + p.add_argument("--bucket", required=True, help="e.g. classification") + p.add_argument("--prefix", required=True, help="e.g. samples/2025/ or samples/") + p.add_argument("--outdir", default="data", help="local output root") + p.add_argument("--split", default="0.7,0.15,0.15", help="train,val,test ratios") + p.add_argument("--labels-csv", help="path to labels.csv (local file) OR object path in bucket (starts without leading /)") + p.add_argument("--infer-label-from-folder", action="store_true", help="take class name from folder under prefix") + p.add_argument("--from-date", help="YYYY-MM-DD (inclusive)") + p.add_argument("--to-date", help="YYYY-MM-DD (inclusive)") + p.add_argument("--last-days", type=int, help="Use only last N days under prefix (overrides from/to)") + p.add_argument("--dry-run", action="store_true") + return p.parse_args() + +def list_objects(client: Minio, bucket: str, prefix: str): + return client.list_objects(bucket, prefix=prefix, recursive=True) + +DATE_RE = re.compile(r"/(\d{4})/(\d{2})/(\d{2})(?:/|$)") + +def object_date(obj_name: str) -> Optional[dt.date]: + m = DATE_RE.search("/"+obj_name.strip("/")) + if not m: return None + y, mth, d = map(int, m.groups()) + return dt.date(y, mth, d) + +def load_labels_from_csv_local(csv_path: str) -> Dict[str, str]: + mapping = {} + with open(csv_path, "r", newline="", encoding="utf-8") as f: + r = csv.DictReader(f) + for row in r: + mapping[row["object"].strip()] = row["label"].strip() + return mapping + +def load_labels_from_csv_minio(client: Minio, bucket: str, obj_path: str) -> Dict[str, str]: + resp = client.get_object(bucket, obj_path) + data = resp.read().decode("utf-8") + mapping = {} + for row in csv.DictReader(io.StringIO(data)): + mapping[row["object"].strip()] = row["label"].strip() + return mapping + +def ensure_dirs(root: Path, classes: List[str]): + for split in ["train","val","test"]: + for c in classes: + (root/split/c).mkdir(parents=True, exist_ok=True) + +def main(): + args = parse_args() + tr, va, te = [float(x) for x in args.split.split(",")] + assert abs(tr+va+te - 1.0) < 1e-6, "--split must sum to 1.0" + + secure = args.secure or args.minio_url.startswith("https://") + endpoint = args.minio_url.replace("http://","").replace("https://","") + client = Minio(endpoint, access_key=args.access_key, secret_key=args.secret_key, secure=secure) + + # python arg names can't contain hyphen; fallback + access = getattr(args, "access_key", getattr(args, "access-key", None)) + secret = getattr(args, "secret_key", getattr(args, "secret-key", None)) + client = Minio(endpoint, access_key=access, secret_key=secret, secure=secure) + + # gather all candidate objects under prefix + objs = list(list_objects(client, args.bucket, args.prefix)) + if len(objs)==0: + print("No objects under prefix:", args.prefix); sys.exit(1) + + # filter by date + if args.last_days: + cutoff = dt.date.today() - dt.timedelta(days=args.last_days) + objs = [o for o in objs if (object_date(o.object_name) or dt.date.min) >= cutoff] + else: + dfrom = dt.date.fromisoformat(args.from_date) if args.from_date else None + dto = dt.date.fromisoformat(args.to_date) if args.to_date else None + if dfrom or dto: + def inrange(o): + od = object_date(o.object_name) + if not od: return False + if dfrom and od < dfrom: return False + if dto and od > dto: return False + return True + objs = [o for o in objs if inrange(o)] + + # Build label mapping + label_map: Dict[str,str] = {} + classes: set = set() + + if args.labels_csv: + if os.path.exists(args.labels_csv): + label_map = load_labels_from_csv_local(args.labels_csv) + else: + label_map = load_labels_from_csv_minio(client, args.bucket, args.labels_csv) + classes = set(label_map.values()) + candidates = [(o.object_name, label_map.get(o.object_name)) for o in objs if o.object_name in label_map] + elif args.infer_label_from_folder: + # Expect ...//... somewhere AFTER prefix + pref = args.prefix.strip("/") + candidates = [] + for o in objs: + rel = o.object_name[len(pref):].strip("/") + parts = rel.split("/") + if len(parts)>=2: + cls = parts[0] + candidates.append((o.object_name, cls)) + classes.add(cls) + if not classes: + print("Could not infer classes from folders; provide --labels-csv", file=sys.stderr) + sys.exit(2) + else: + print("Provide either --labels-csv or --infer-label-from-folder", file=sys.stderr) + sys.exit(2) + + classes = sorted(list(classes)) + print("Classes:", classes, "| samples:", len(candidates)) + root = Path(args.outdir) + ensure_dirs(root, classes) + + # stratified split by class + by_cls: Dict[str, List[str]] = {c: [] for c in classes} + for obj, lab in candidates: + if lab in by_cls: + by_cls[lab].append(obj) + for c in classes: random.shuffle(by_cls[c]) + + plan: List[Tuple[str, str]] = [] # (object_name, target_path) + for c in classes: + items = by_cls[c] + n = len(items); ntr = int(tr*n); nv = int(va*n) + tr_items = items[:ntr]; va_items = items[ntr:ntr+nv]; te_items = items[ntr+nv:] + for src in tr_items: + plan.append((src, str(root/ "train"/c/ Path(src).name))) + for src in va_items: + plan.append((src, str(root/ "val"/c/ Path(src).name))) + for src in te_items: + plan.append((src, str(root/ "test"/c/ Path(src).name))) + + if args.dry_run: + print(f"DRY-RUN: would download {len(plan)} files.") + return + + # download + for src, dst in tqdm(plan, desc="Downloading"): + dpath = Path(dst) + if dpath.exists(): continue + dpath.parent.mkdir(parents=True, exist_ok=True) + client.fget_object(args.bucket, src, dst) + print("Done. Data prepared under:", root.resolve()) + +if __name__ == "__main__": + main() diff --git a/services/ripeness-ml/tools/export/export_onnx_conditional.py b/services/ripeness-ml/tools/export/export_onnx_conditional.py new file mode 100644 index 000000000..1431ea0e6 --- /dev/null +++ b/services/ripeness-ml/tools/export/export_onnx_conditional.py @@ -0,0 +1,25 @@ +import torch, yaml, os +from model.architecture.mobilenet_v3_large_head import build_conditional + +if __name__ == "__main__": + cfg = yaml.safe_load(open("configs/config.yaml")) + fruits = cfg["fruits"] + ripeness = cfg["ripeness"] + + model = build_conditional(num_ripeness=len(ripeness), num_fruits=len(fruits)) + ckpt_path = os.path.join(cfg["checkpoint_dir"], "best_conditional.pt") + model.load_state_dict(torch.load(ckpt_path, map_location="cpu")) + model.eval() + + dummy_x = torch.randn(1, 3, cfg["img_size"], cfg["img_size"]) + dummy_f = torch.zeros(1, dtype=torch.long) # example fruit index + torch.onnx.export( + model, (dummy_x, dummy_f), + "ripeness_conditional.onnx", + input_names=["image", "fruit_idx"], + output_names=["ripeness_logits"], + dynamic_axes={"image": {0: "batch"}, "ripeness_logits": {0: "batch"}}, + opset_version=13 + ) + + print("✅ Exported: ripeness_conditional.onnx") diff --git a/services/ripeness-ml/tools/inference/infer_minio_batch.py b/services/ripeness-ml/tools/inference/infer_minio_batch.py new file mode 100644 index 000000000..0206a3791 --- /dev/null +++ b/services/ripeness-ml/tools/inference/infer_minio_batch.py @@ -0,0 +1,193 @@ +# AGCLOUD/services/ripeness-ml/scripts/infer_minio_batch.py +import argparse, os, sys, csv, json +from io import BytesIO +from pathlib import Path + +import numpy as np +from PIL import Image +from minio import Minio +from tqdm import tqdm +import onnxruntime as ort +from torchvision import transforms + +# ---- Configurable defaults ---- +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +IMG_TFM = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), +]) + +DEFAULT_FRUITS = ["apple", "banana", "orange", "pineapple"] # order matters! +RIPENESS = ["unripe", "ripe", "overripe"] + +IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp") + + +def parse_args(): + p = argparse.ArgumentParser( + description="Batch inference from MinIO prefix with conditional ONNX model (image + fruit_idx)." + ) + p.add_argument("--minio-url", required=True, help="http://127.0.0.1:9001") + p.add_argument("--access-key", default=os.getenv("MINIO_ACCESS_KEY", "minioadmin")) + p.add_argument("--secret-key", default=os.getenv("MINIO_SECRET_KEY", "minioadmin")) + p.add_argument("--secure", action="store_true", help="Use HTTPS") + + p.add_argument("--bucket", required=True, help="MinIO bucket name") + p.add_argument("--prefix", help="Prefix to scan, e.g. samples/2025/10/15 (ignored if --pairs-csv is used)") + + p.add_argument("--onnx", default="ripeness_conditional.onnx", help="Path to conditional ONNX model") + p.add_argument("--providers", nargs="*", default=None, help="ONNX Runtime providers list (default: CPU)") + + # Fruit specification + p.add_argument("--fruit", help="Fruit for ALL objects (apple|banana|orange|pineapple)") + p.add_argument("--pairs-csv", help="CSV file with columns: object,fruit (mapping per object)") + + # Fruits list order (so fruit_idx matches training) + p.add_argument("--fruits", default=None, + help='Fruits list in order, e.g. \'["apple","banana","orange","pineapple"]\' or "apple,banana,orange,pineapple"') + + # Output + p.add_argument("--out-csv", help="Optional: write results to CSV (object,fruit,label,prob_unripe,prob_ripe,prob_overripe)") + p.add_argument("--quiet", action="store_true", help="Do not print JSON lines to stdout") + + args = p.parse_args() + + if not args.pairs_csv and not (args.prefix and args.fruit): + p.error("Provide either --pairs-csv OR both --prefix and --fruit.") + + return args + + +def parse_fruits_list(fruits_arg): + if not fruits_arg: + return DEFAULT_FRUITS + s = fruits_arg.strip() + if s.startswith("["): + # JSON-ish + try: + import json as _json + lst = _json.loads(s) + return [x.strip().lower() for x in lst] + except Exception: + pass + # comma separated + return [x.strip().lower() for x in s.split(",") if x.strip()] + + +def softmax(x): + x = x - x.max(axis=1, keepdims=True) + e = np.exp(x) + return e / e.sum(axis=1, keepdims=True) + + +def is_image(name: str) -> bool: + return name.lower().endswith(IMG_EXTS) + + +def load_pairs_csv(path: str): + mapping = {} + with open(path, "r", newline="", encoding="utf-8") as f: + r = csv.DictReader(f) + if "object" not in r.fieldnames or "fruit" not in r.fieldnames: + raise SystemExit("pairs CSV must have columns: object,fruit") + for row in r: + obj = row["object"].strip() + fruit = row["fruit"].strip().lower() + mapping[obj] = fruit + return mapping + + +def open_minio(args): + secure = args.secure or args.minio_url.startswith("https://") + endpoint = args.minio_url.replace("http://", "").replace("https://", "") + return Minio(endpoint, access_key=args.access_key, secret_key=args.secret_key, secure=secure) + + +def main(): + args = parse_args() + fruits = parse_fruits_list(args.fruits) + + # Validate fruit names + fruit_set = set(fruits) + + # Prepare ONNX Runtime session + providers = args.providers or ["CPUExecutionProvider"] + sess = ort.InferenceSession(args.onnx, providers=providers) + + client = open_minio(args) + + # Prepare iterator over (object_name, fruit) + if args.pairs_csv: + mapping = load_pairs_csv(args.pairs_csv) + # Only iterate the keys present in the CSV (no MinIO list needed) + iterator = [(obj, mapping[obj]) for obj in mapping] + else: + fixed_fruit = args.fruit.lower() + if fixed_fruit not in fruit_set: + raise SystemExit(f"--fruit must be one of {fruits}; got {fixed_fruit}") + iterator = [] + for obj in client.list_objects(args.bucket, prefix=args.prefix, recursive=True): + if is_image(obj.object_name): + iterator.append((obj.object_name, fixed_fruit)) + + # Output CSV writer (optional) + csv_writer = None + if args.out_csv: + Path(args.out_csv).parent.mkdir(parents=True, exist_ok=True) + fcsv = open(args.out_csv, "w", newline="", encoding="utf-8") + csv_writer = csv.writer(fcsv) + csv_writer.writerow(["object", "fruit", "label", "prob_unripe", "prob_ripe", "prob_overripe"]) + + # Run predictions + for obj_name, fruit in tqdm(iterator, desc="Predicting"): + if fruit not in fruit_set: + # Unknown fruit -> skip + if not args.quiet: + print(json.dumps({"object": obj_name, "error": f"unknown fruit '{fruit}' (allowed {fruits})"}, ensure_ascii=False)) + continue + + # Fetch image bytes + if args.pairs_csv: + # object names in CSV must be full paths in bucket + resp = client.get_object(args.bucket, obj_name) + else: + resp = client.get_object(args.bucket, obj_name) + + try: + img = Image.open(BytesIO(resp.read())).convert("RGB") + finally: + resp.close(); resp.release_conn() + + x = IMG_TFM(img).unsqueeze(0).numpy() + fidx = np.array([fruits.index(fruit)], dtype=np.int64) + + logits = sess.run(["ripeness_logits"], {"images": x, "fruit_idx": fidx})[0] + prob = softmax(logits)[0] + idx = int(prob.argmax()) + label = RIPENESS[idx] + + record = { + "object": obj_name, + "fruit": fruit, + "label": label, + "probs": {RIPENESS[i]: float(prob[i]) for i in range(len(RIPENESS))} + } + + if not args.quiet: + print(json.dumps(record, ensure_ascii=False)) + + if csv_writer: + csv_writer.writerow([ + obj_name, fruit, label, + f"{prob[0]:.6f}", f"{prob[1]:.6f}", f"{prob[2]:.6f}" + ]) + + if csv_writer: + fcsv.close() + + +if __name__ == "__main__": + main() diff --git a/services/sound_metrics/Dockerfile b/services/sound_metrics/Dockerfile index a74fd0275..fa1af4ac7 100644 --- a/services/sound_metrics/Dockerfile +++ b/services/sound_metrics/Dockerfile @@ -7,7 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /app -COPY certs/ /app/certs/ +# COPY certs/ /app/certs/ RUN if [ -d /app/certs ] && ls /app/certs/*.crt >/dev/null 2>&1; then \ cp /app/certs/*.crt /usr/local/share/ca-certificates/ && update-ca-certificates; \ else \ @@ -23,7 +23,7 @@ COPY src/ ./src/ ENV PYTHONUNBUFFERED=1 \ ADDR=0.0.0.0 \ - PORT=8001 \ + PORT=8005 \ WINDOW_MIN=5 \ FRAME_SEC=0.1 \ THRESHOLD=0.01 \ @@ -31,8 +31,8 @@ ENV PYTHONUNBUFFERED=1 \ MINIO_ENDPOINT=minio:9000 \ MINIO_ACCESS_KEY=minioadmin \ MINIO_SECRET_KEY=minioadmin123 \ - MINIO_BUCKET=audio \ - MINIO_PREFIX=samples/ + MINIO_BUCKET=sound \ + MINIO_PREFIX=sounds/ EXPOSE 8001 diff --git a/services/sound_metrics/src/metrics.py b/services/sound_metrics/src/metrics.py index 8db44afcd..0f5e7d7d3 100644 --- a/services/sound_metrics/src/metrics.py +++ b/services/sound_metrics/src/metrics.py @@ -15,7 +15,7 @@ # === Environment === ADDR = os.getenv("ADDR", "0.0.0.0") -PORT = int(os.getenv("PORT", "8001")) +PORT = int(os.getenv("PORT", "8005")) WINDOW_MIN = int(os.getenv("WINDOW_MIN", 5)) FRAME_SEC = float(os.getenv("FRAME_SEC", 0.1)) THRESHOLD = float(os.getenv("THRESHOLD", 0.01)) @@ -23,7 +23,7 @@ MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000") MINIO_ACCESS = os.getenv("MINIO_ACCESS_KEY", "minioadmin") MINIO_SECRET = os.getenv("MINIO_SECRET_KEY", "minioadmin123") -MINIO_BUCKET = os.getenv("MINIO_BUCKET", "telemetry") +MINIO_BUCKET = os.getenv("MINIO_BUCKET", "sound") MINIO_PREFIX = os.getenv("MINIO_PREFIX", "sounds/") ALLOWED_EXTS = {".wav", ".flac", ".ogg", ".aiff", ".aif", ".au", ".mp3", ".m4a", ".aac", ".opus"} diff --git a/services/sounds/API-development/.coverage b/services/sounds/API-development/.coverage deleted file mode 100644 index 0a26d044f..000000000 Binary files a/services/sounds/API-development/.coverage and /dev/null differ diff --git a/services/sounds/API-development/.vscode/settings.json b/services/sounds/API-development/.vscode/settings.json deleted file mode 100644 index 642ff51b6..000000000 --- a/services/sounds/API-development/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.REPL.enableREPLSmartSend": false -} \ No newline at end of file diff --git a/services/sounds/API-development/tests/__pycache__/__init__.cpython-311.pyc b/services/sounds/API-development/tests/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index dc7c5d115..000000000 Binary files a/services/sounds/API-development/tests/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/services/sounds/API-development/tests/__pycache__/test_app.cpython-311-pytest-8.4.2.pyc b/services/sounds/API-development/tests/__pycache__/test_app.cpython-311-pytest-8.4.2.pyc deleted file mode 100644 index ae1136aaa..000000000 Binary files a/services/sounds/API-development/tests/__pycache__/test_app.cpython-311-pytest-8.4.2.pyc and /dev/null differ diff --git a/services/sounds/compression/.coverage b/services/sounds/compression/.coverage deleted file mode 100644 index 61d94bc14..000000000 Binary files a/services/sounds/compression/.coverage and /dev/null differ diff --git a/services/sounds/compression/requirments.txt b/services/sounds/compression/requirments.txt deleted file mode 100644 index 160ff5c43..000000000 --- a/services/sounds/compression/requirments.txt +++ /dev/null @@ -1,11 +0,0 @@ -# Core dependencies -minio>=7.1.0 - -# Testing dependencies -pytest>=7.2.0 -pytest-mock>=3.8.0 - -# Note: ffmpeg must be installed separately as a system dependency -# Ubuntu/Debian: apt-get install ffmpeg -# macOS: brew install ffmpeg -# Windows: Download from https://ffmpeg.org/download.html \ No newline at end of file diff --git a/services/sounds/compression/results/benchmarks.csv b/services/sounds/compression/results/benchmarks.csv deleted file mode 100644 index b940286a6..000000000 --- a/services/sounds/compression/results/benchmarks.csv +++ /dev/null @@ -1,3 +0,0 @@ -file,codec,orig_bytes,encoded_bytes,compression_ratio_orig_over_encoded,encode_time_sec,encode_cpu_avg_percent -crying-animal-84745.mp3,flac,74400,314441,0.237,0.121,0.0 -crying-animal-84745.mp3,opus,74400,41630,1.787,0.121,0.0 diff --git a/services/sounds/compression/scripts/minio_client.py b/services/sounds/compression/scripts/minio_client.py deleted file mode 100644 index 585d4c8d2..000000000 --- a/services/sounds/compression/scripts/minio_client.py +++ /dev/null @@ -1,16 +0,0 @@ -from minio import Minio - -MINIO_ENDPOINT = "localhost:9001" -ACCESS_KEY = "minioadmin" -SECRET_KEY = "minioadmin123" -BUCKET_NAME = "compression" - -client = Minio( - MINIO_ENDPOINT, - access_key=ACCESS_KEY, - secret_key=SECRET_KEY, - secure=False, -) - -if not client.bucket_exists(BUCKET_NAME): - client.make_bucket(BUCKET_NAME) diff --git a/services/sounds/compression/scripts/prototype_lib.py b/services/sounds/compression/scripts/prototype_lib.py deleted file mode 100644 index bde806334..000000000 --- a/services/sounds/compression/scripts/prototype_lib.py +++ /dev/null @@ -1,58 +0,0 @@ -from pathlib import Path -import subprocess -import tempfile -import time -from minio_client import client, BUCKET_NAME - -INPUT_EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} - -RAW_PREFIX = "raw/" -COMP_PREFIX = "compressed/" - -def iter_input_files(): - """Yield MinIO object names in RAW_PREFIX with accepted extensions.""" - for obj in client.list_objects(BUCKET_NAME, prefix=RAW_PREFIX, recursive=True): - if any(obj.object_name.lower().endswith(ext) for ext in INPUT_EXTS): - yield obj.object_name - -def build_ffmpeg_cmds(in_local_path: Path, codec="all", flac_level="5", opus_bitrate="96k"): - """ - Return ffmpeg commands to encode a local file. - Output will be a temporary file (to upload after encode). - """ - cmds = [] - temp_dir = Path(tempfile.gettempdir()) - if codec in ("flac", "all"): - flac_out = temp_dir / f"{in_local_path.stem}.flac" - flac_cmd = [ - "ffmpeg", "-y", "-hide_banner", "-loglevel", "error", - "-i", str(in_local_path), - "-c:a", "flac", "-compression_level", flac_level, - str(flac_out) - ] - cmds.append(("flac", flac_cmd, flac_out)) - if codec in ("opus", "all"): - opus_out = temp_dir / f"{in_local_path.stem}.opus" - opus_cmd = [ - "ffmpeg", "-y", "-hide_banner", "-loglevel", "error", - "-i", str(in_local_path), - "-c:a", "libopus", "-b:a", opus_bitrate, - str(opus_out) - ] - cmds.append(("opus", opus_cmd, opus_out)) - return cmds - -def download_raw_to_temp(obj_name: str) -> Path: - """Download MinIO raw object to temporary file.""" - local_path = Path(tempfile.gettempdir()) / Path(obj_name).name - client.fget_object(BUCKET_NAME, obj_name, str(local_path)) - return local_path - -def upload_compressed(local_path: Path): - """Upload encoded file to MinIO compressed folder.""" - target_name = f"{COMP_PREFIX}{int(time.time())}_{local_path.name}" - client.fput_object(BUCKET_NAME, target_name, str(local_path)) - return target_name - -def delete_object(obj_name: str): - client.remove_object(BUCKET_NAME, obj_name) \ No newline at end of file diff --git a/services/sounds/compression/scripts/run_bench.py b/services/sounds/compression/scripts/run_bench.py deleted file mode 100644 index d21f0f29c..000000000 --- a/services/sounds/compression/scripts/run_bench.py +++ /dev/null @@ -1,86 +0,0 @@ -from pathlib import Path -import time -import csv -from statistics import mean -import subprocess -from prototype_lib import iter_input_files, build_ffmpeg_cmds, download_raw_to_temp, upload_compressed, delete_object -from minio_client import BUCKET_NAME, client - -RES_DIR = Path("results") -RES_DIR.mkdir(exist_ok=True) - -def file_size_minio(obj_name: str) -> int: - """Return object size in bytes.""" - try: - stat = client.stat_object(BUCKET_NAME, obj_name) - return stat.size - except: - return 0 - -def run_and_profile(cmd): - import psutil - start = time.time() - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - parent = psutil.Process(proc.pid) - samples = [] - while proc.poll() is None: - cpu_total = 0.0 - for pr in [parent] + parent.children(recursive=True): - try: - cpu_total += pr.cpu_percent(interval=0.1) - except psutil.NoSuchProcess: - continue - samples.append(cpu_total) - out, err = proc.communicate() - wall = time.time() - start - avg_cpu = mean(samples) if samples else 0.0 - return proc.returncode, wall, avg_cpu, (out or b"") + (err or b"") - -def main(): - rows = [] - files = list(iter_input_files()) - if not files: - print("No raw files found in MinIO") - return - - for obj_name in files: - local_file = download_raw_to_temp(obj_name) - orig_size = local_file.stat().st_size - - for codec, cmd, outp in build_ffmpeg_cmds(local_file): - rc, wall, cpu, _ = run_and_profile(cmd) - if rc != 0: - print(f"[FAIL] {obj_name} ({codec})") - continue - - target_obj = upload_compressed(outp) - enc_size = file_size_minio(target_obj) - ratio = (orig_size / enc_size) if enc_size else 0.0 - - print(f"[OK] {obj_name} | {codec.upper()}: {enc_size} bytes, {wall:.2f}s, CPU~{cpu:.1f}% (ratio={ratio:.3f})") - - rows.append({ - "file": Path(obj_name).name, - "codec": codec, - "orig_bytes": orig_size, - "encoded_bytes": enc_size, - "compression_ratio_orig_over_encoded": round(ratio, 3), - "encode_time_sec": round(wall, 3), - "encode_cpu_avg_percent": round(cpu, 1), - }) - - # Clean up local encoded file - outp.unlink() - - local_file.unlink() - - if rows: - out_csv = RES_DIR / "benchmarks.csv" - with open(out_csv, "w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=list(rows[0].keys())) - writer.writeheader() - writer.writerows(rows) - print(f"Saved CSV: {out_csv}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/services/sounds/compression/scripts/tiering_job.py b/services/sounds/compression/scripts/tiering_job.py deleted file mode 100644 index b1c4b0e70..000000000 --- a/services/sounds/compression/scripts/tiering_job.py +++ /dev/null @@ -1,101 +0,0 @@ -from pathlib import Path -import time -import argparse -import subprocess -import tempfile -from prototype_lib import ( - iter_input_files, build_ffmpeg_cmds, download_raw_to_temp, - upload_compressed, delete_object, INPUT_EXTS -) -from minio_client import client, BUCKET_NAME - -DEFAULT_RAW_MAX_AGE_HOURS = 24 -DEFAULT_COMP_MAX_AGE_DAYS = 90 -DEFAULT_LONG_TERM_CODEC = "opus" - -def get_age_seconds(obj_name: str, mode: str = "mtime") -> float: - stat = client.stat_object(BUCKET_NAME, obj_name) - ts = getattr(stat, "last_modified", None) - if ts is None: - return 0 - # last_modified is datetime, convert to timestamp - return time.time() - ts.timestamp() - -def is_older_than(obj_name: str, age_seconds: int, mode: str) -> bool: - return get_age_seconds(obj_name, mode) >= age_seconds - -def encode_and_upload(obj_name: str, codec: str) -> str: - local_file = download_raw_to_temp(obj_name) - for c, cmd, outp in build_ffmpeg_cmds(local_file, codec=codec): - rc = subprocess.call(cmd) - if rc != 0: - raise RuntimeError(f"Encode failed: {obj_name} -> {codec}") - target_obj = upload_compressed(outp) - outp.unlink() - local_file.unlink() - return target_obj - -def cleanup_compressed(max_age_days: int, dry_run: bool) -> int: - if max_age_days <= 0: - return 0 - cutoff_sec = max_age_days * 86400 - deleted = 0 - for obj in client.list_objects(BUCKET_NAME, prefix="compressed/", recursive=True): - age = get_age_seconds(obj.object_name) - if age >= cutoff_sec: - if dry_run: - print(f"[DRY] would delete compressed: {obj.object_name}") - else: - delete_object(obj.object_name) - deleted += 1 - print(f"[DEL] compressed old: {obj.object_name}") - return deleted - -def main(): - ap = argparse.ArgumentParser(description="Two-tier storage job with MinIO") - ap.add_argument("--raw-max-age-hours", type=int, default=None) - ap.add_argument("--raw-max-age-minutes", type=int, default=None) - ap.add_argument("--age-mode", choices=["mtime", "ctime"], default="mtime") - ap.add_argument("--codec", choices=["opus", "flac"], default=DEFAULT_LONG_TERM_CODEC) - ap.add_argument("--delete-raw-after", action="store_true") - ap.add_argument("--compressed-max-age-days", type=int, default=DEFAULT_COMP_MAX_AGE_DAYS) - ap.add_argument("--dry-run", action="store_true") - args = ap.parse_args() - - if args.raw_max_age_minutes is not None: - raw_age_seconds = args.raw_max_age_minutes * 60 - elif args.raw_max_age_hours is not None: - raw_age_seconds = args.raw_max_age_hours * 3600 - else: - raw_age_seconds = DEFAULT_RAW_MAX_AGE_HOURS * 3600 - - processed, raw_deleted = 0, 0 - - for obj_name in iter_input_files(): - if is_older_than(obj_name, raw_age_seconds, args.age_mode): - if args.dry_run: - print(f"[DRY] would encode {obj_name} -> {args.codec.upper()}, delete RAW={args.delete_raw_after}") - processed += 1 - continue - - try: - target_obj = encode_and_upload(obj_name, args.codec) - print(f"[OK] {obj_name} -> {target_obj} ({args.codec})") - processed += 1 - - if args.delete_raw_after: - delete_object(obj_name) - raw_deleted += 1 - print(f"[DEL] raw: {obj_name}") - - except Exception as e: - print(f"[FAIL] {obj_name}: {e}") - - comp_deleted = cleanup_compressed(args.compressed_max_age_days, args.dry_run) - - print(f"Done. Processed={processed}, Raw deletions={raw_deleted}, Compressed deletions={comp_deleted}, " - f"Mode={args.age_mode}, Threshold={raw_age_seconds}s") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/services/sounds/compression/src/minio_client.py b/services/sounds/compression/src/minio_client.py deleted file mode 100644 index 585d4c8d2..000000000 --- a/services/sounds/compression/src/minio_client.py +++ /dev/null @@ -1,16 +0,0 @@ -from minio import Minio - -MINIO_ENDPOINT = "localhost:9001" -ACCESS_KEY = "minioadmin" -SECRET_KEY = "minioadmin123" -BUCKET_NAME = "compression" - -client = Minio( - MINIO_ENDPOINT, - access_key=ACCESS_KEY, - secret_key=SECRET_KEY, - secure=False, -) - -if not client.bucket_exists(BUCKET_NAME): - client.make_bucket(BUCKET_NAME) diff --git a/services/sounds/compression/src/prototype_lib.py b/services/sounds/compression/src/prototype_lib.py deleted file mode 100644 index bde806334..000000000 --- a/services/sounds/compression/src/prototype_lib.py +++ /dev/null @@ -1,58 +0,0 @@ -from pathlib import Path -import subprocess -import tempfile -import time -from minio_client import client, BUCKET_NAME - -INPUT_EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} - -RAW_PREFIX = "raw/" -COMP_PREFIX = "compressed/" - -def iter_input_files(): - """Yield MinIO object names in RAW_PREFIX with accepted extensions.""" - for obj in client.list_objects(BUCKET_NAME, prefix=RAW_PREFIX, recursive=True): - if any(obj.object_name.lower().endswith(ext) for ext in INPUT_EXTS): - yield obj.object_name - -def build_ffmpeg_cmds(in_local_path: Path, codec="all", flac_level="5", opus_bitrate="96k"): - """ - Return ffmpeg commands to encode a local file. - Output will be a temporary file (to upload after encode). - """ - cmds = [] - temp_dir = Path(tempfile.gettempdir()) - if codec in ("flac", "all"): - flac_out = temp_dir / f"{in_local_path.stem}.flac" - flac_cmd = [ - "ffmpeg", "-y", "-hide_banner", "-loglevel", "error", - "-i", str(in_local_path), - "-c:a", "flac", "-compression_level", flac_level, - str(flac_out) - ] - cmds.append(("flac", flac_cmd, flac_out)) - if codec in ("opus", "all"): - opus_out = temp_dir / f"{in_local_path.stem}.opus" - opus_cmd = [ - "ffmpeg", "-y", "-hide_banner", "-loglevel", "error", - "-i", str(in_local_path), - "-c:a", "libopus", "-b:a", opus_bitrate, - str(opus_out) - ] - cmds.append(("opus", opus_cmd, opus_out)) - return cmds - -def download_raw_to_temp(obj_name: str) -> Path: - """Download MinIO raw object to temporary file.""" - local_path = Path(tempfile.gettempdir()) / Path(obj_name).name - client.fget_object(BUCKET_NAME, obj_name, str(local_path)) - return local_path - -def upload_compressed(local_path: Path): - """Upload encoded file to MinIO compressed folder.""" - target_name = f"{COMP_PREFIX}{int(time.time())}_{local_path.name}" - client.fput_object(BUCKET_NAME, target_name, str(local_path)) - return target_name - -def delete_object(obj_name: str): - client.remove_object(BUCKET_NAME, obj_name) \ No newline at end of file diff --git a/services/sounds/compression/src/run_bench.py b/services/sounds/compression/src/run_bench.py deleted file mode 100644 index d21f0f29c..000000000 --- a/services/sounds/compression/src/run_bench.py +++ /dev/null @@ -1,86 +0,0 @@ -from pathlib import Path -import time -import csv -from statistics import mean -import subprocess -from prototype_lib import iter_input_files, build_ffmpeg_cmds, download_raw_to_temp, upload_compressed, delete_object -from minio_client import BUCKET_NAME, client - -RES_DIR = Path("results") -RES_DIR.mkdir(exist_ok=True) - -def file_size_minio(obj_name: str) -> int: - """Return object size in bytes.""" - try: - stat = client.stat_object(BUCKET_NAME, obj_name) - return stat.size - except: - return 0 - -def run_and_profile(cmd): - import psutil - start = time.time() - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - parent = psutil.Process(proc.pid) - samples = [] - while proc.poll() is None: - cpu_total = 0.0 - for pr in [parent] + parent.children(recursive=True): - try: - cpu_total += pr.cpu_percent(interval=0.1) - except psutil.NoSuchProcess: - continue - samples.append(cpu_total) - out, err = proc.communicate() - wall = time.time() - start - avg_cpu = mean(samples) if samples else 0.0 - return proc.returncode, wall, avg_cpu, (out or b"") + (err or b"") - -def main(): - rows = [] - files = list(iter_input_files()) - if not files: - print("No raw files found in MinIO") - return - - for obj_name in files: - local_file = download_raw_to_temp(obj_name) - orig_size = local_file.stat().st_size - - for codec, cmd, outp in build_ffmpeg_cmds(local_file): - rc, wall, cpu, _ = run_and_profile(cmd) - if rc != 0: - print(f"[FAIL] {obj_name} ({codec})") - continue - - target_obj = upload_compressed(outp) - enc_size = file_size_minio(target_obj) - ratio = (orig_size / enc_size) if enc_size else 0.0 - - print(f"[OK] {obj_name} | {codec.upper()}: {enc_size} bytes, {wall:.2f}s, CPU~{cpu:.1f}% (ratio={ratio:.3f})") - - rows.append({ - "file": Path(obj_name).name, - "codec": codec, - "orig_bytes": orig_size, - "encoded_bytes": enc_size, - "compression_ratio_orig_over_encoded": round(ratio, 3), - "encode_time_sec": round(wall, 3), - "encode_cpu_avg_percent": round(cpu, 1), - }) - - # Clean up local encoded file - outp.unlink() - - local_file.unlink() - - if rows: - out_csv = RES_DIR / "benchmarks.csv" - with open(out_csv, "w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=list(rows[0].keys())) - writer.writeheader() - writer.writerows(rows) - print(f"Saved CSV: {out_csv}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/services/sounds/compression/src/tiering_job.py b/services/sounds/compression/src/tiering_job.py deleted file mode 100644 index b1c4b0e70..000000000 --- a/services/sounds/compression/src/tiering_job.py +++ /dev/null @@ -1,101 +0,0 @@ -from pathlib import Path -import time -import argparse -import subprocess -import tempfile -from prototype_lib import ( - iter_input_files, build_ffmpeg_cmds, download_raw_to_temp, - upload_compressed, delete_object, INPUT_EXTS -) -from minio_client import client, BUCKET_NAME - -DEFAULT_RAW_MAX_AGE_HOURS = 24 -DEFAULT_COMP_MAX_AGE_DAYS = 90 -DEFAULT_LONG_TERM_CODEC = "opus" - -def get_age_seconds(obj_name: str, mode: str = "mtime") -> float: - stat = client.stat_object(BUCKET_NAME, obj_name) - ts = getattr(stat, "last_modified", None) - if ts is None: - return 0 - # last_modified is datetime, convert to timestamp - return time.time() - ts.timestamp() - -def is_older_than(obj_name: str, age_seconds: int, mode: str) -> bool: - return get_age_seconds(obj_name, mode) >= age_seconds - -def encode_and_upload(obj_name: str, codec: str) -> str: - local_file = download_raw_to_temp(obj_name) - for c, cmd, outp in build_ffmpeg_cmds(local_file, codec=codec): - rc = subprocess.call(cmd) - if rc != 0: - raise RuntimeError(f"Encode failed: {obj_name} -> {codec}") - target_obj = upload_compressed(outp) - outp.unlink() - local_file.unlink() - return target_obj - -def cleanup_compressed(max_age_days: int, dry_run: bool) -> int: - if max_age_days <= 0: - return 0 - cutoff_sec = max_age_days * 86400 - deleted = 0 - for obj in client.list_objects(BUCKET_NAME, prefix="compressed/", recursive=True): - age = get_age_seconds(obj.object_name) - if age >= cutoff_sec: - if dry_run: - print(f"[DRY] would delete compressed: {obj.object_name}") - else: - delete_object(obj.object_name) - deleted += 1 - print(f"[DEL] compressed old: {obj.object_name}") - return deleted - -def main(): - ap = argparse.ArgumentParser(description="Two-tier storage job with MinIO") - ap.add_argument("--raw-max-age-hours", type=int, default=None) - ap.add_argument("--raw-max-age-minutes", type=int, default=None) - ap.add_argument("--age-mode", choices=["mtime", "ctime"], default="mtime") - ap.add_argument("--codec", choices=["opus", "flac"], default=DEFAULT_LONG_TERM_CODEC) - ap.add_argument("--delete-raw-after", action="store_true") - ap.add_argument("--compressed-max-age-days", type=int, default=DEFAULT_COMP_MAX_AGE_DAYS) - ap.add_argument("--dry-run", action="store_true") - args = ap.parse_args() - - if args.raw_max_age_minutes is not None: - raw_age_seconds = args.raw_max_age_minutes * 60 - elif args.raw_max_age_hours is not None: - raw_age_seconds = args.raw_max_age_hours * 3600 - else: - raw_age_seconds = DEFAULT_RAW_MAX_AGE_HOURS * 3600 - - processed, raw_deleted = 0, 0 - - for obj_name in iter_input_files(): - if is_older_than(obj_name, raw_age_seconds, args.age_mode): - if args.dry_run: - print(f"[DRY] would encode {obj_name} -> {args.codec.upper()}, delete RAW={args.delete_raw_after}") - processed += 1 - continue - - try: - target_obj = encode_and_upload(obj_name, args.codec) - print(f"[OK] {obj_name} -> {target_obj} ({args.codec})") - processed += 1 - - if args.delete_raw_after: - delete_object(obj_name) - raw_deleted += 1 - print(f"[DEL] raw: {obj_name}") - - except Exception as e: - print(f"[FAIL] {obj_name}: {e}") - - comp_deleted = cleanup_compressed(args.compressed_max_age_days, args.dry_run) - - print(f"Done. Processed={processed}, Raw deletions={raw_deleted}, Compressed deletions={comp_deleted}, " - f"Mode={args.age_mode}, Threshold={raw_age_seconds}s") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/services/sounds/compression/tests/test_prototype_lib.py b/services/sounds/compression/tests/test_prototype_lib.py deleted file mode 100644 index 6b4b21f50..000000000 --- a/services/sounds/compression/tests/test_prototype_lib.py +++ /dev/null @@ -1,515 +0,0 @@ -from pathlib import Path -import pytest -import tempfile -import shutil -from unittest.mock import patch, MagicMock -from src.prototype_lib import iter_input_files, build_ffmpeg_cmds, INPUT_EXTS, RAW_DIR, COMP_DIR - -# Test for iter_input_files function to ensure it retrieves all valid audio files. -def test_iter_input_files(): - files = list(iter_input_files()) # Get list of files returned by the function - assert len(files) > 0, "No input files found in the raw directory." # Ensure at least one file is found - for file in files: - assert file.suffix.lower() in {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"}, \ - f"Unsupported file format: {file.suffix}" # Ensure file has a valid audio extension - -# Test for build_ffmpeg_cmds function to validate the generated ffmpeg commands. -def test_build_ffmpeg_cmds(): - input_path = Path("data/raw/cat.wav") # Sample audio file - flac_cmd, opus_cmd = build_ffmpeg_cmds(input_path) # Generate ffmpeg commands for FLAC and Opus - - # Test FLAC command - assert flac_cmd[1][0] == "ffmpeg", "FLAC command does not start with 'ffmpeg'" # Check if the command starts correctly - assert "-c:a" in flac_cmd[1], "FLAC command does not specify codec" # Check if codec is specified - assert "flac" in flac_cmd[1], "FLAC command does not contain the 'flac' codec" # Ensure FLAC codec is used - - # Test Opus command - assert opus_cmd[1][0] == "ffmpeg", "Opus command does not start with 'ffmpeg'" # Check if the command starts correctly - assert "-c:a" in opus_cmd[1], "Opus command does not specify codec" # Check if codec is specified - assert "libopus" in opus_cmd[1], "Opus command does not contain the 'libopus' codec" # Ensure Opus codec is used - -# Test when the directory is empty, ensuring no files are returned. -def test_iter_input_files_empty_directory(): - """Test case where no input files are available.""" - empty_dir = Path("data/empty") # Path to an empty directory - # Ensure the directory is empty - if not empty_dir.exists(): - empty_dir.mkdir(parents=True) - - # Simulate the empty directory scenario - assert len(list(iter_input_files("data/empty"))) == 0, "The directory is empty, but files were found." - -# Test to ensure that invalid file extensions are not returned. -def test_iter_input_files_invalid_extension(): - """Test case where files with invalid extensions are present.""" - invalid_file = Path("data/raw/invalid_file.txt") # Invalid file extension - invalid_file.touch() # Create the invalid file - - files = list(iter_input_files()) # Get files from the directory - assert invalid_file not in files, f"Unexpected file with invalid extension: {invalid_file}" # Assert invalid file is not included - -# Additional tests for improving code coverage - -# Test with a custom directory to ensure files are handled properly. -def test_iter_input_files_with_custom_directory(): - """Test with a custom directory""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a valid audio file - audio_file = temp_path / "test.wav" - audio_file.touch() - - files = list(iter_input_files(temp_dir)) # Get files from the custom directory - assert len(files) == 1 # Ensure only one file is found - assert files[0].name == "test.wav" # Verify the file name - -# Test case for non-existent directory, should raise an error. -def test_iter_input_files_nonexistent_directory(): - """Test with a non-existent directory""" - nonexistent_dir = "/path/that/does/not/exist" # Path that doesn't exist - - with pytest.raises(ValueError, match="does not exist"): # Expecting a ValueError - list(iter_input_files(nonexistent_dir)) # Try accessing the non-existent directory - -# Test case to ensure case-insensitive handling of file extensions. -def test_iter_input_files_case_insensitive(): - """Test that the function handles extensions with uppercase letters""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create files with both lowercase and uppercase extensions - files_to_create = ["test.WAV", "audio.MP3", "music.flac", "sound.OGG"] - - for filename in files_to_create: - (temp_path / filename).touch() - - found_files = list(iter_input_files(temp_dir)) # Get files from the directory - assert len(found_files) == 4 # Ensure 4 files are found - - # Verify that all expected files were found - found_names = [f.name for f in found_files] - for expected_file in files_to_create: - assert expected_file in found_names - -# Test to ensure subdirectories are ignored when iterating files. -def test_iter_input_files_with_subdirectories(): - """Test that the function ignores subdirectories""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a file in the main directory - audio_file = temp_path / "main.wav" - audio_file.touch() - - # Create a subdirectory with a file inside - sub_dir = temp_path / "subdirectory" - sub_dir.mkdir() - sub_audio = sub_dir / "sub.wav" - sub_audio.touch() - - files = list(iter_input_files(temp_dir)) # Get files from the main directory - # Should only find the file in the main directory - assert len(files) == 1 - assert files[0].name == "main.wav" - -# Test build_ffmpeg_cmds with custom parameters to ensure command customization works. -def test_build_ffmpeg_cmds_custom_parameters(): - """Test build_ffmpeg_cmds with custom parameters""" - input_path = Path("test_audio.mp3") # Sample audio file - - flac_cmd, opus_cmd = build_ffmpeg_cmds(input_path, flac_level="8", opus_bitrate="128k") # Custom parameters - - # Check if custom FLAC compression level is included - assert "8" in flac_cmd[1], "Custom FLAC level not found in command" - - # Check if custom Opus bitrate is included - assert "128k" in opus_cmd[1], "Custom Opus bitrate not found in command" - -# Test if the output paths for FLAC and Opus are built correctly. -def test_build_ffmpeg_cmds_output_paths(): - """Test that output paths are built correctly""" - input_path = Path("my_audio_file.wav") # Sample audio file - - flac_cmd, opus_cmd = build_ffmpeg_cmds(input_path) # Generate ffmpeg commands - - # Check if FLAC output path is correct - expected_flac_path = str(COMP_DIR / "my_audio_file.flac") - assert expected_flac_path in flac_cmd[1], "FLAC output path is incorrect" - - # Check if Opus output path is correct - expected_opus_path = str(COMP_DIR / "my_audio_file.opus") - assert expected_opus_path in opus_cmd[1], "Opus output path is incorrect" - -# Test to validate the structure of the return values from build_ffmpeg_cmds. -def test_build_ffmpeg_cmds_return_structure(): - """Test that the return structure is valid""" - input_path = Path("test.wav") # Sample audio file - - flac_result, opus_result = build_ffmpeg_cmds(input_path) # Generate ffmpeg commands - - # Ensure each result is a tuple of 3 elements - assert len(flac_result) == 3, "FLAC result should have 3 elements" - assert len(opus_result) == 3, "Opus result should have 3 elements" - - # Check if the first element is the codec name - assert flac_result[0] == "flac", "First element should be codec name" - assert opus_result[0] == "opus", "First element should be codec name" - - # Ensure the third element is a Path object - assert isinstance(flac_result[2], Path), "Third element should be Path object" - assert isinstance(opus_result[2], Path), "Third element should be Path object" - -# Test to ensure that the INPUT_EXTS constant matches the expected set of extensions. -def test_input_exts_constant(): - """Test that the INPUT_EXTS constant matches the expected formats""" - expected_formats = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} - assert INPUT_EXTS == expected_formats, "INPUT_EXTS constant doesn't match expected formats" - -# Test to validate that only valid audio files are returned from a directory containing mixed files. -def test_iter_input_files_mixed_files(): - """Test with a mix of valid and invalid files""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create valid audio files - valid_files = ["audio1.wav", "music.mp3", "sound.flac"] - for filename in valid_files: - (temp_path / filename).touch() - - # Create invalid files - invalid_files = ["document.txt", "image.jpg", "video.mp4"] - for filename in invalid_files: - (temp_path / filename).touch() - - found_files = list(iter_input_files(temp_dir)) # Get files from the directory - - # Should find only the valid files - assert len(found_files) == 3 - found_names = [f.name for f in found_files] - - for valid_file in valid_files: - assert valid_file in found_names - - for invalid_file in invalid_files: - assert invalid_file not in found_names - -def test_build_ffmpeg_cmds_all_parameters(): - """Test all parameters in the ffmpeg command.""" - input_path = Path("test_file.wav") - - # Generate FFmpeg commands for FLAC and Opus with specific parameters - flac_result, opus_result = build_ffmpeg_cmds(input_path, flac_level="3", opus_bitrate="64k") - - flac_cmd = flac_result[1] - opus_cmd = opus_result[1] - - # Detailed tests for FLAC command - assert "-y" in flac_cmd, "Missing -y parameter in FLAC command" - assert "-hide_banner" in flac_cmd, "Missing -hide_banner parameter in FLAC command" - assert "-loglevel" in flac_cmd, "Missing -loglevel parameter in FLAC command" - assert "error" in flac_cmd, "Missing error loglevel in FLAC command" - assert "-i" in flac_cmd, "Missing -i parameter in FLAC command" - assert "-compression_level" in flac_cmd, "Missing compression_level in FLAC command" - assert "3" in flac_cmd, "Missing custom compression level in FLAC command" - - # Detailed tests for Opus command - assert "-y" in opus_cmd, "Missing -y parameter in Opus command" - assert "-hide_banner" in opus_cmd, "Missing -hide_banner parameter in Opus command" - assert "-loglevel" in opus_cmd, "Missing -loglevel parameter in Opus command" - assert "error" in opus_cmd, "Missing error loglevel in Opus command" - assert "-i" in opus_cmd, "Missing -i parameter in Opus command" - assert "-b:a" in opus_cmd, "Missing -b:a parameter in Opus command" - assert "64k" in opus_cmd, "Missing custom bitrate in Opus command" - -def test_iter_input_files_all_supported_formats(): - """Test that all supported formats are correctly identified.""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a file for each supported format - supported_formats = [".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"] - - for ext in supported_formats: - test_file = temp_path / f"test{ext}" - test_file.touch() - - found_files = list(iter_input_files(temp_dir)) - assert len(found_files) == len(supported_formats), f"Expected {len(supported_formats)} files, found {len(found_files)}" - - # Ensure all formats were found - found_extensions = {f.suffix.lower() for f in found_files} - expected_extensions = set(supported_formats) - assert found_extensions == expected_extensions, "Not all supported formats were found" - -def test_build_ffmpeg_cmds_default_parameters(): - """Test that default parameters work correctly.""" - input_path = Path("default_test.mp3") - - flac_result, opus_result = build_ffmpeg_cmds(input_path) - - # Check default parameters - assert "5" in flac_result[1], "Default FLAC compression level (5) not found" - assert "96k" in opus_result[1], "Default Opus bitrate (96k) not found" - -def test_path_handling_with_special_characters(): - """Test handling of file names with special characters.""" - special_names = [ - "file with spaces.wav", - "file-with-dashes.mp3", - "file_with_underscores.flac", - "file.with.dots.ogg" - ] - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - for name in special_names: - (temp_path / name).touch() - - found_files = list(iter_input_files(temp_dir)) - assert len(found_files) == len(special_names) - - found_names = [f.name for f in found_files] - for expected_name in special_names: - assert expected_name in found_names - -def test_comp_dir_creation_in_build_ffmpeg_cmds(): - """Test that the output paths point to the correct compression directory.""" - input_path = Path("some_audio.wav") - - flac_result, opus_result = build_ffmpeg_cmds(input_path) - - flac_output_path = flac_result[2] - opus_output_path = opus_result[2] - - # Check that the paths point to COMP_DIR - assert flac_output_path.parent == COMP_DIR, "FLAC output path parent is not COMP_DIR" - assert opus_output_path.parent == COMP_DIR, "Opus output path parent is not COMP_DIR" - - # Check that the extensions are correct - assert flac_output_path.suffix == ".flac", "FLAC output doesn't have .flac extension" - assert opus_output_path.suffix == ".opus", "Opus output doesn't have .opus extension" - -def test_iter_input_files_path_object_handling(): - """Test correct handling of Path objects.""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - audio_file = temp_path / "path_test.wav" - audio_file.touch() - - # Test with string - files_str = list(iter_input_files(temp_dir)) - - # Test with Path object - files_path = list(iter_input_files(temp_path)) - - assert len(files_str) == len(files_path) == 1 - assert files_str[0].name == files_path[0].name == "path_test.wav" - -def test_build_ffmpeg_cmds_extreme_parameters(): - """Test with extreme parameters.""" - input_path = Path("extreme_test.wav") - - # Extreme parameters - flac_result, opus_result = build_ffmpeg_cmds( - input_path, - flac_level="12", # Maximum - opus_bitrate="512k" # Very high - ) - - assert "12" in flac_result[1], "Extreme FLAC level not found" - assert "512k" in opus_result[1], "Extreme Opus bitrate not found" - -def test_constants_and_globals(): - """Test constants and global variables.""" - from scripts.prototype_lib import ROOT, RAW_DIR, COMP_DIR, INPUT_EXTS - - # Check that all constants are Path objects or sets - assert isinstance(ROOT, Path), "ROOT should be a Path object" - assert isinstance(RAW_DIR, Path), "RAW_DIR should be a Path object" - assert isinstance(COMP_DIR, Path), "COMP_DIR should be a Path object" - assert isinstance(INPUT_EXTS, set), "INPUT_EXTS should be a set" - - # Check that the paths are logical - assert RAW_DIR.name == "raw", "RAW_DIR should end with 'raw'" - assert COMP_DIR.name == "compressed", "COMP_DIR should end with 'compressed'" - - # Check that all extensions in INPUT_EXTS start with a dot - for ext in INPUT_EXTS: - assert ext.startswith("."), f"Extension {ext} should start with dot" - assert ext.islower(), f"Extension {ext} should be lowercase" - -def test_iter_input_files_file_vs_directory(): - """Test that the function ignores directories that look like audio files.""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a valid file - valid_file = temp_path / "audio.wav" - valid_file.touch() - - # Create a directory with a name that looks like an audio file - fake_audio_dir = temp_path / "fake_audio.mp3" - fake_audio_dir.mkdir() - - files = list(iter_input_files(temp_dir)) - - # Should find only the valid file - assert len(files) == 1 - assert files[0].name == "audio.wav" - -def test_build_ffmpeg_cmds_file_stem_handling(): - """Test that handling of file stems works correctly.""" - test_cases = [ - ("simple.wav", "simple"), - ("file.with.dots.mp3", "file.with.dots"), - ("no_extension", "no_extension"), - ("multiple.dots.in.name.flac", "multiple.dots.in.name") - ] - - for input_name, expected_stem in test_cases: - input_path = Path(input_name) - flac_result, opus_result = build_ffmpeg_cmds(input_path) - - flac_output = flac_result[2] - opus_output = opus_result[2] - - assert flac_output.stem == expected_stem, f"FLAC stem mismatch for {input_name}" - assert opus_output.stem == expected_stem, f"Opus stem mismatch for {input_name}" - -def test_iter_input_files_empty_files(): - """Test with empty files - should still be considered valid.""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create empty files with valid extensions - empty_files = ["empty1.wav", "empty2.mp3", "empty3.flac"] - - for filename in empty_files: - (temp_path / filename).touch() - - files = list(iter_input_files(temp_dir)) - - assert len(files) == len(empty_files) - found_names = [f.name for f in files] - - for expected_file in empty_files: - assert expected_file in found_names - -def test_iter_input_files_hidden_files(): - """Test with hidden files (starting with a dot).""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create regular files - normal_file = temp_path / "normal.wav" - normal_file.touch() - - # Create a hidden file - hidden_file = temp_path / ".hidden.mp3" - hidden_file.touch() - - files = list(iter_input_files(temp_dir)) - found_names = [f.name for f in files] - - # Hidden files should be found if they have a valid extension - assert "normal.wav" in found_names - assert ".hidden.mp3" in found_names - assert len(files) == 2 - -def test_build_ffmpeg_cmds_command_order(): - """Test that the order of parameters in the ffmpeg command is correct.""" - input_path = Path("test.wav") - flac_result, opus_result = build_ffmpeg_cmds(input_path) - - flac_cmd = flac_result[1] - opus_cmd = opus_result[1] - - # Check parameter order for FLAC - assert flac_cmd[0] == "ffmpeg" - assert flac_cmd[1] == "-y" - assert flac_cmd[2] == "-hide_banner" - assert "-i" in flac_cmd - assert "-c:a" in flac_cmd - - # Check parameter order for Opus - assert opus_cmd[0] == "ffmpeg" - assert opus_cmd[1] == "-y" - assert opus_cmd[2] == "-hide_banner" - assert "-i" in opus_cmd - assert "-c:a" in opus_cmd - -def test_input_path_as_string_in_build_ffmpeg_cmds(): - """Test that the function handles string paths correctly.""" - input_path = Path("string_path.wav") - - flac_result, opus_result = build_ffmpeg_cmds(input_path) - - # The path should appear as a string in the command - flac_cmd = flac_result[1] - opus_cmd = opus_result[1] - - input_str = str(input_path) - assert input_str in flac_cmd, "Input path not found as string in FLAC command" - assert input_str in opus_cmd, "Input path not found as string in Opus command" - -def test_comp_dir_mkdir_functionality(): - """Test that COMP_DIR is created correctly.""" - from scripts.prototype_lib import COMP_DIR - - # COMP_DIR should be defined and accessible - assert COMP_DIR is not None - assert isinstance(COMP_DIR, Path) - - # Check that the path is logical - assert "compressed" in str(COMP_DIR) - -def test_root_path_calculation(): - """Test that ROOT path calculation is correct.""" - from scripts.prototype_lib import ROOT - - # ROOT should be a Path object - assert isinstance(ROOT, Path) - - # ROOT should be an absolute path - assert ROOT.is_absolute() - -def test_various_audio_extensions_case_combinations(): - """Test different combinations of uppercase and lowercase audio extensions.""" - test_extensions = [ - ".wav", ".WAV", ".Wav", ".wAv", - ".mp3", ".MP3", ".Mp3", ".mP3", - ".flac", ".FLAC", ".Flac", ".fLaC" - ] - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - for i, ext in enumerate(test_extensions): - test_file = temp_path / f"test{i}{ext}" - test_file.touch() - - files = list(iter_input_files(temp_dir)) - - # All files should be found (case insensitive) - assert len(files) == len(test_extensions) - -def test_build_ffmpeg_cmds_output_file_extensions(): - """Detailed test of output file extensions.""" - test_inputs = [ - ("audio.wav", ".flac", ".opus"), - ("music.mp3", ".flac", ".opus"), - ("sound.m4a", ".flac", ".opus") - ] - - for input_name, expected_flac_ext, expected_opus_ext in test_inputs: - input_path = Path(input_name) - flac_result, opus_result = build_ffmpeg_cmds(input_path) - - flac_output = flac_result[2] - opus_output = opus_result[2] - - assert flac_output.suffix == expected_flac_ext, f"Wrong FLAC extension for {input_name}" - assert opus_output.suffix == expected_opus_ext, f"Wrong Opus extension for {input_name}" diff --git a/services/sounds/compression/tests/test_run_bench.py b/services/sounds/compression/tests/test_run_bench.py deleted file mode 100644 index 0fd0a88d6..000000000 --- a/services/sounds/compression/tests/test_run_bench.py +++ /dev/null @@ -1,54 +0,0 @@ -from pathlib import Path -from src.run_bench import run_and_profile, file_size, main -import subprocess -import csv - -# Test the `run_and_profile` function -def test_run_and_profile(): - # Set up a test command for profiling - cmd = ["ffmpeg", "-version"] # Test with a simple command (you can replace it with your actual compression command) - - # Run the command and get results - rc, wall_time, cpu, output = run_and_profile(cmd, shell=True) - - # Test if the command ran successfully - assert rc == 0, f"Command failed with return code {rc}" - assert wall_time > 0, "Wall time should be greater than zero" - assert cpu >= 0, "CPU usage should be non-negative" - assert "ffmpeg" in output, "Expected output not found" # Check if the output contains 'ffmpeg' - -# Test the `file_size` function -def test_file_size(): - test_file = Path("data/raw/cat.wav") # Ensure this file exists for the test - - # Make sure the test file exists - assert test_file.exists(), "Test file does not exist" - - # Get the file size - size = file_size(test_file) - - # Test if the file size is greater than zero - assert size > 0, f"File size should be greater than zero, but got {size}" - -# Test the `main` function to check if the CSV is created correctly -def test_main(): - result_path = Path("results/benchmarks.csv") - - # Run the main function - main() - - # Check if the results file was created - assert result_path.exists(), "The result CSV file was not created" - - # Optionally check if it contains rows - with open(result_path, "r") as file: - reader = csv.DictReader(file) - rows = list(reader) - assert len(rows) > 0, "No rows in the results CSV" - -# Additional tests for edge cases -def test_file_size_invalid_file(): - """Test the file size function with a non-existent file""" - invalid_file = Path("data/raw/nonexistent_file.wav") - size = file_size(invalid_file) - assert size == 0, f"Expected file size to be 0, but got {size}" \ No newline at end of file diff --git a/services/sounds/compression/tests/test_tiering_job.py b/services/sounds/compression/tests/test_tiering_job.py deleted file mode 100644 index b81a13d29..000000000 --- a/services/sounds/compression/tests/test_tiering_job.py +++ /dev/null @@ -1,418 +0,0 @@ -import pytest -from pathlib import Path -import time -import subprocess -import tempfile -import shutil -import os -from unittest.mock import patch, MagicMock, mock_open, Mock -from src.tiering_job import ( - is_older_than, encode, cleanup_compressed, get_age_seconds, main -) - - -class TestGetAgeSeconds: - """Tests for the function get_age_seconds""" - - def test_get_age_seconds_mtime(self, tmp_path): - """Test the file age based on mtime""" - test_file = tmp_path / "test_file.txt" - test_file.write_text("test") - - age = get_age_seconds(test_file, "mtime") - assert age >= 0 - assert age < 1 # A newly created file should be less than a second old - - def test_get_age_seconds_ctime(self, tmp_path): - """Test the file age based on ctime""" - test_file = tmp_path / "test_file.txt" - test_file.write_text("test") - - age = get_age_seconds(test_file, "ctime") - assert age >= 0 - assert age < 1 # A newly created file should be less than a second old - - @patch('time.time') - @patch('os.stat') - def test_get_age_seconds_old_file(self, mock_stat, mock_time): - """Test the file age for an old file""" - mock_time.return_value = 1000000 # Simulated current time - - # Creating a mock stat object - mock_stat_result = Mock() - mock_stat_result.st_mtime = 996400 # 1000000 - 3600 (1 hour) - mock_stat_result.st_ctime = 996400 - mock_stat.return_value = mock_stat_result - - test_file = Path("old_file.txt") - - age_mtime = get_age_seconds(test_file, "mtime") - age_ctime = get_age_seconds(test_file, "ctime") - - assert age_mtime == 3600 # File is 3600 seconds (1 hour) old - assert age_ctime == 3600 # File is 3600 seconds (1 hour) old - - -class TestIsOlderThan: - """Tests for the function is_older_than""" - - @patch('scripts.tiering_job.get_age_seconds') - def test_is_older_than_true(self, mock_get_age): - """Test when the file is older than the threshold""" - mock_get_age.return_value = 7200 # 2 hours - test_path = Path("test.txt") - - result = is_older_than(test_path, 3600, "mtime") # Threshold of 1 hour - assert result is True - mock_get_age.assert_called_once_with(test_path, "mtime") - - @patch('scripts.tiering_job.get_age_seconds') - def test_is_older_than_false(self, mock_get_age): - """Test when the file is younger than the threshold""" - mock_get_age.return_value = 1800 # 30 minutes - test_path = Path("test.txt") - - result = is_older_than(test_path, 3600, "mtime") # Threshold of 1 hour - assert result is False - - @patch('scripts.tiering_job.get_age_seconds') - def test_is_older_than_equal(self, mock_get_age): - """Test when the file is exactly at the threshold age""" - mock_get_age.return_value = 3600 # Exactly 1 hour - test_path = Path("test.txt") - - result = is_older_than(test_path, 3600, "mtime") # Threshold of 1 hour - assert result is True # >= should return True - - def test_is_older_than_invalid_mode(self): - """Test invalid mode scenario""" - test_path = Path("test.txt") - - with pytest.raises(ValueError, match="Invalid mode: invalid"): - is_older_than(test_path, 3600, "invalid") - - def test_is_older_than_ctime_mode(self, tmp_path): - """Test the ctime mode scenario""" - test_file = tmp_path / "test_file.txt" - test_file.write_text("test") - - # Since the file is new, it should not be older than 10 seconds - result = is_older_than(test_file, 10, "ctime") - assert result is False - -class TestEncode: - """Tests for the 'encode' function""" - - @patch('scripts.tiering_job.build_ffmpeg_cmds') - @patch('subprocess.call') - def test_encode_success(self, mock_subprocess, mock_build_cmds): - """Test for successful encoding""" - test_path = Path("input.wav") # Input file - output_path = Path("output.flac") # Expected output file - - # Mock the ffmpeg command building - mock_build_cmds.return_value = [ - ("flac", ["ffmpeg", "-i", "input.wav", "output.flac"], output_path) - ] - mock_subprocess.return_value = 0 # Simulate success - - result = encode(test_path, "flac") # Call the encode function - - assert result == output_path # Check that the result matches the expected output - mock_build_cmds.assert_called_once_with(test_path) # Ensure the ffmpeg command was built with the correct input - mock_subprocess.assert_called_once() # Ensure subprocess.call was called once - - @patch('scripts.tiering_job.build_ffmpeg_cmds') - @patch('subprocess.call') - def test_encode_subprocess_failure(self, mock_subprocess, mock_build_cmds): - """Test for subprocess failure during encoding""" - test_path = Path("input.wav") - output_path = Path("output.flac") - - # Mock the ffmpeg command building - mock_build_cmds.return_value = [ - ("flac", ["ffmpeg", "-i", "input.wav", "output.flac"], output_path) - ] - mock_subprocess.return_value = 1 # Simulate failure (non-zero return code) - - # Assert that an exception is raised when encoding fails - with pytest.raises(RuntimeError, match="Encode failed: input.wav -> flac"): - encode(test_path, "flac") - - @patch('scripts.tiering_job.build_ffmpeg_cmds') - def test_encode_unsupported_codec(self, mock_build_cmds): - """Test for unsupported codec during encoding""" - test_path = Path("input.wav") - - # Mock the ffmpeg command building with supported codecs (flac, opus) - mock_build_cmds.return_value = [ - ("flac", ["ffmpeg", "-i", "input.wav", "output.flac"], Path("output.flac")), - ("opus", ["ffmpeg", "-i", "input.wav", "output.opus"], Path("output.opus")) - ] - - # Test for an unsupported codec (mp3) - with pytest.raises(ValueError, match="Unsupported codec: mp3"): - encode(test_path, "mp3") - - @patch('scripts.tiering_job.build_ffmpeg_cmds') - @patch('subprocess.call') - def test_encode_opus(self, mock_subprocess, mock_build_cmds): - """Test for encoding to opus format""" - test_path = Path("input.wav") - output_path = Path("output.opus") - - # Mock the ffmpeg command building for opus encoding - mock_build_cmds.return_value = [ - ("opus", ["ffmpeg", "-i", "input.wav", "output.opus"], output_path) - ] - mock_subprocess.return_value = 0 # Simulate success - - result = encode(test_path, "opus") # Call the encode function - assert result == output_path # Check that the result matches the expected output - -class TestCleanupCompressed: - """Tests for the cleanup_compressed function""" - - def test_cleanup_compressed_negative_age(self): - """Test for negative age""" - with pytest.raises(ValueError, match="max_age_days cannot be negative"): - cleanup_compressed(-1, dry_run=False) # This should raise an error because the age can't be negative - - def test_cleanup_compressed_zero_age(self): - """Test for zero age - should return 0 without performing any action""" - result = cleanup_compressed(0, dry_run=False) - assert result == 0 # No deletion should happen, so result should be 0 - - @patch('scripts.tiering_job.COMP_DIR') - @patch('time.time') - def test_cleanup_compressed_dry_run(self, mock_time, mock_comp_dir, capsys): - """Test for dry run mode""" - mock_time.return_value = 1000000 # Mocking the current time - - # Creating mock files - old_file = Mock() - old_file.name = "old_file.flac" - old_file.is_file.return_value = True - old_file_stat = Mock() - old_file_stat.st_mtime = 1000000 - (2 * 86400) # 2 days ago - old_file.stat.return_value = old_file_stat - - new_file = Mock() - new_file.name = "new_file.flac" - new_file.is_file.return_value = True - new_file_stat = Mock() - new_file_stat.st_mtime = 1000000 - 3600 # 1 hour ago - new_file.stat.return_value = new_file_stat - - mock_comp_dir.iterdir.return_value = [old_file, new_file] # Mocking the directory contents - - result = cleanup_compressed(1, dry_run=True) # In dry run mode, no files should be deleted - - assert result == 0 # Nothing should be deleted in dry run mode - captured = capsys.readouterr() - assert "[DRY] would delete compressed:" in captured.out # We expect a dry run message - - @patch('scripts.tiering_job.COMP_DIR') - @patch('time.time') - def test_cleanup_compressed_actual_deletion(self, mock_time, mock_comp_dir, capsys): - """Test for actual deletion""" - mock_time.return_value = 1000000 # Mocking the current time - - # Creating mock file to delete - old_file = Mock() - old_file.name = "old_file.flac" - old_file.is_file.return_value = True - old_file_stat = Mock() - old_file_stat.st_mtime = 1000000 - (2 * 86400) # 2 days ago - old_file.stat.return_value = old_file_stat - - mock_comp_dir.iterdir.return_value = [old_file] # Mocking the directory contents - - result = cleanup_compressed(1, dry_run=False) # In real mode, files should be deleted if older than 1 day - - assert result == 1 # One file should be deleted - old_file.unlink.assert_called_once() # Checking that unlink was called on the old file - captured = capsys.readouterr() - assert "[DEL] compressed old:" in captured.out # We expect a deletion message - - @patch('scripts.tiering_job.COMP_DIR') - @patch('time.time') - def test_cleanup_compressed_deletion_error(self, mock_time, mock_comp_dir, capsys): - """Test for deletion error""" - mock_time.return_value = 1000000 # Mocking the current time - - old_file = Mock() - old_file.name = "old_file.flac" - old_file.is_file.return_value = True - old_file_stat = Mock() - old_file_stat.st_mtime = 1000000 - (2 * 86400) # 2 days ago - old_file.stat.return_value = old_file_stat - old_file.unlink.side_effect = OSError("Permission denied") # Mocking an error when trying to delete - - mock_comp_dir.iterdir.return_value = [old_file] # Mocking the directory contents - - result = cleanup_compressed(1, dry_run=False) # Trying to delete with an error - - assert result == 0 # No file should be deleted due to the error - captured = capsys.readouterr() - assert "[WARN] failed to delete" in captured.out # We expect a warning message - - @patch('scripts.tiering_job.COMP_DIR') - @patch('time.time') - def test_cleanup_compressed_no_old_files(self, mock_time, mock_comp_dir): - """Test when there are no old files""" - mock_time.return_value = 1000000 # Mocking the current time - - new_file = Mock() - new_file.name = "new_file.flac" - new_file.is_file.return_value = True - new_file_stat = Mock() - new_file_stat.st_mtime = 1000000 - 3600 # 1 hour ago - new_file.stat.return_value = new_file_stat - - mock_comp_dir.iterdir.return_value = [new_file] # Mocking the directory contents - - result = cleanup_compressed(1, dry_run=False) # Trying to delete files older than 1 day, but no old files - - assert result == 0 # No files should be deleted since they are all too new - -class TestMain: - """Tests for the main function in the 'tiering_job.py' script.""" - - @patch('sys.argv', ['tiering_job.py', '--dry-run']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_dry_run_default_settings(self, mock_cleanup, mock_iter, capsys): - """Test dry run with default settings""" - mock_iter.return_value = [] # Mock that no files are returned - mock_cleanup.return_value = 0 # Mock that cleanup doesn't delete anything - - main() # Run the main function - - # Assert that cleanup was called once with default parameters (90 days, dry_run=True) - mock_cleanup.assert_called_once_with(90, True) - captured = capsys.readouterr() # Capture the output printed to the console - assert "Done. Processed=0" in captured.out # Assert the output indicates no files were processed - - @patch('sys.argv', ['tiering_job.py', '--raw-max-age-minutes', '30', '--codec', 'opus']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.is_older_than') - @patch('scripts.tiering_job.encode') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_with_minutes_and_opus(self, mock_cleanup, mock_encode, mock_older, mock_iter): - """Test with custom settings: max age in minutes and codec set to opus""" - test_file = Mock() # Mock a file object - test_file.name = "test.wav" # Set the file name - output_file = Mock() # Mock an output file after encoding - output_file.exists.return_value = True # Mock that the encoded file exists - - mock_iter.return_value = [test_file] # Mock iter_input_files to return our test file - mock_older.return_value = True # Mock that the file is older than 30 minutes - mock_encode.return_value = output_file # Mock the encoding function - mock_cleanup.return_value = 1 # Mock cleanup indicating one file was processed - - main() # Run the main function - - # Assert that is_older_than was called with the correct arguments - mock_older.assert_called_with(test_file, 30 * 60, 'mtime') # 30 minutes in seconds - # Assert that encode was called with the correct codec ('opus') - mock_encode.assert_called_with(test_file, 'opus') - # Assert that cleanup was called with custom settings (90 days, not dry_run) - mock_cleanup.assert_called_with(90, False) - - @patch('sys.argv', ['tiering_job.py', '--delete-raw-after', '--age-mode', 'ctime']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.is_older_than') - @patch('scripts.tiering_job.encode') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_delete_raw_after(self, mock_cleanup, mock_encode, mock_older, mock_iter, capsys): - """Test deleting raw files after encoding""" - test_file = Mock() # Mock a file object - test_file.name = "test.wav" - output_file = Mock() # Mock the encoded output file - output_file.exists.return_value = True # Mock that the encoded file exists - - mock_iter.return_value = [test_file] # Mock iter_input_files to return our test file - mock_older.return_value = True # Mock that the file is older than 30 minutes - mock_encode.return_value = output_file # Mock the encoding function - mock_cleanup.return_value = 0 # Mock cleanup showing no files processed - - main() # Run the main function - - # Assert that the raw file was deleted after encoding - test_file.unlink.assert_called_once() - captured = capsys.readouterr() # Capture the console output - # Assert that the raw file deletion was logged in the output - assert "[DEL] raw:" in captured.out - - @patch('sys.argv', ['tiering_job.py', '--compressed-max-age-days', '30']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_custom_compressed_age(self, mock_cleanup, mock_iter): - """Test setting a custom age for compressed files""" - mock_iter.return_value = [] # Mock that no files are returned - mock_cleanup.return_value = 5 # Mock that 5 files were processed - - main() # Run the main function - - # Assert that cleanup was called with the custom compressed file age (30 days) - mock_cleanup.assert_called_once_with(30, False) - - @patch('sys.argv', ['tiering_job.py']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.is_older_than') - @patch('scripts.tiering_job.encode') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_encode_failure(self, mock_cleanup, mock_encode, mock_older, mock_iter, capsys): - """Test handling of encoding failure""" - test_file = Mock() # Mock a file object - test_file.name = "test.wav" - - mock_iter.return_value = [test_file] # Mock iter_input_files to return our test file - mock_older.return_value = True # Mock that the file is older than 30 minutes - mock_encode.side_effect = RuntimeError("Encoding failed") # Mock encoding failure - mock_cleanup.return_value = 0 # Mock cleanup showing no files processed - - main() # Run the main function - - # Capture the console output - captured = capsys.readouterr() - # Assert that the failure was logged in the output - assert "[FAIL]" in captured.out - assert "Encoding failed" in captured.out - - @patch('sys.argv', ['tiering_job.py', '--raw-max-age-hours', '12']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.is_older_than') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_with_hours_setting(self, mock_cleanup, mock_older, mock_iter): - """Test setting file age in hours""" - mock_iter.return_value = [] # Mock that no files are returned - mock_older.return_value = False # Mock that the file is not older than the specified hours - mock_cleanup.return_value = 0 # Mock cleanup showing no files processed - - main() # Run the main function - - # Assert that cleanup was called with default settings (90 days, not dry_run) - mock_cleanup.assert_called_once_with(90, False) - - @patch('sys.argv', ['tiering_job.py', '--dry-run']) - @patch('scripts.tiering_job.iter_input_files') - @patch('scripts.tiering_job.is_older_than') - @patch('scripts.tiering_job.cleanup_compressed') - def test_main_dry_run_with_files(self, mock_cleanup, mock_older, mock_iter, capsys): - """Test dry run with files""" - test_file = Mock() # Mock a file object - test_file.name = "test.wav" - - mock_iter.return_value = [test_file] # Mock iter_input_files to return our test file - mock_older.return_value = True # Mock that the file is older than 30 minutes - mock_cleanup.return_value = 0 # Mock cleanup showing no files processed - - main() # Run the main function - - # Capture the console output - captured = capsys.readouterr() - # Assert that the dry run simulation was logged in the output - assert "[DRY] would encode" in captured.out - assert "Done. Processed=1" in captured.out \ No newline at end of file diff --git a/services/sounds/sounds_classifier/Dockerfile.classifier-svc b/services/sounds_classifier/Dockerfile.classifier-svc similarity index 63% rename from services/sounds/sounds_classifier/Dockerfile.classifier-svc rename to services/sounds_classifier/Dockerfile.classifier-svc index 9038c3808..14fc3ff43 100644 --- a/services/sounds/sounds_classifier/Dockerfile.classifier-svc +++ b/services/sounds_classifier/Dockerfile.classifier-svc @@ -1,71 +1,58 @@ -FROM python:3.12-slim - -# System deps + codecs + CA + Kafka/DB native libs (librdkafka, libpq) -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - libsndfile1 \ - ffmpeg \ - ca-certificates \ - wget curl \ - librdkafka1 \ - libpq5 \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /app - -# ---- Corporate CAs ---- -# Place your CA files under classify/certs/*.crt before build -COPY certs/*.crt /usr/local/share/ca-certificates/ -RUN update-ca-certificates - -ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ - REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ - PYTHONUNBUFFERED=1 - - -COPY requirements.txt /app/requirements.txt -RUN python -m pip install --upgrade pip \ - && pip install --no-cache-dir -r /app/requirements.txt - -# Install PyTorch CPU wheels from official index (kept separate for clearer errors/caching) -RUN pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cpu torch==2.5.1+cpu - -COPY src/classification /app/classification - -RUN touch /app/classification/__init__.py - -# ---- Checkpoint bootstrap (download once at build if missing) ---- -# Configure via build-args or ENV: -ARG CHECKPOINT_URL="https://example.com/path/to/Cnn14_mAP=0.431.pth" -ARG CHECKPOINT_PATH="/app/classification/models/panns_data/Cnn14_mAP=0.431.pth" -ENV CHECKPOINT_URL=${CHECKPOINT_URL} \ - CHECKPOINT=${CHECKPOINT_PATH} - -# Create target folder and download checkpoint using your ensure_checkpoint() utility -RUN python - <<'PY' -import os, pathlib, sys -p = pathlib.Path(os.environ.get("CHECKPOINT", "/app/classification/models/panns_data/Cnn14_mAP=0.431.pth")) -url = os.environ.get("CHECKPOINT_URL", "") -p.parent.mkdir(parents=True, exist_ok=True) -# Only fetch if not exists -if not p.exists(): - try: - # Use your project helper if available - from classification.core.model_io import ensure_checkpoint - ensure_checkpoint(str(p), url) - except Exception as e: - # Fallback to curl if helper not available or fails - import subprocess - if not url: - raise RuntimeError("CHECKPOINT_URL is empty; cannot fetch checkpoint.") from e - subprocess.run(["curl", "-L", "-o", str(p), url], check=True) -print(f"Checkpoint ready at: {p} (exists={p.exists()})") -PY - -RUN mkdir -p /root/panns_data && \ - curl -L -o /root/panns_data/class_labels_indices.csv \ - https://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv - - -EXPOSE 8088 +FROM python:3.12-slim + +# System deps + codecs + CA + Kafka/DB native libs (librdkafka, libpq) +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + libsndfile1 \ + ffmpeg \ + ca-certificates \ + wget curl \ + librdkafka1 \ + libpq5 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# ---- Corporate CAs ---- +# Place your CA files under classify/certs/*.crt before build +COPY certs/*.crt /usr/local/share/ca-certificates/ +RUN update-ca-certificates + +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PYTHONUNBUFFERED=1 + + +COPY requirements.txt /app/requirements.txt +RUN python -m pip install --upgrade pip \ + && pip install --no-cache-dir -r /app/requirements.txt + +# Install PyTorch CPU wheels from official index (kept separate for clearer errors/caching) +RUN pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cpu torch==2.5.1+cpu + +# ---- Checkpoint bootstrap (download once at build if missing) ---- +# Configure via build-args or ENV: +ARG CHECKPOINT_URL="https://example.com/path/to/Cnn14_mAP=0.431.pth" +ARG CHECKPOINT_PATH="/app/classification/models/panns_data/Cnn14_mAP=0.431.pth" +ENV CHECKPOINT_URL=${CHECKPOINT_URL} \ + CHECKPOINT=${CHECKPOINT_PATH} + +# Create target folder and download checkpoint WITHOUT importing project code +RUN set -eux; \ + p="${CHECKPOINT}"; \ + url="${CHECKPOINT_URL}"; \ + mkdir -p "$(dirname "$p")"; \ + if [ ! -f "$p" ]; then \ + [ -n "$url" ]; curl -L -o "$p" "$url"; \ + fi; \ + echo "Checkpoint ready at: $p" + +RUN mkdir -p /root/panns_data && \ + curl -L -o /root/panns_data/class_labels_indices.csv \ + https://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv + +COPY src/classification /app/classification +RUN touch /app/classification/__init__.py + +EXPOSE 8088 CMD ["uvicorn", "classification.app:app", "--host", "0.0.0.0", "--port", "8088", "--log-level", "info"] \ No newline at end of file diff --git a/services/sounds/sounds_classifier/README.md b/services/sounds_classifier/README.md similarity index 96% rename from services/sounds/sounds_classifier/README.md rename to services/sounds_classifier/README.md index d95b238dd..aa7dda159 100644 --- a/services/sounds/sounds_classifier/README.md +++ b/services/sounds_classifier/README.md @@ -1,59 +1,59 @@ -# 🎧 Sound Classifier Service (CNN14-based) - -Service that classifies audio files using CNN14 model. It: -1. Receives S3 object location (bucket+key) -2. Classifies the sound -3. Stores result in PostgreSQL (optional) -4. Sends alert to Kafka topic if specific sounds detected (optional) -Built with **FastAPI**, **PANNs (CNN14)**, **PostgreSQL**, and optional **Kafka alerts** for real-time monitoring. - -## Quick Start -```bash -docker compose up -d sounds_classifier -``` -Service runs on **http://localhost:8088** (see `docker-compose.yml`, port 8088). - -## API Usage -```json -POST /classify -{ - "s3_bucket": "your-bucket", - "s3_key": "path/to/audio.wav" -} -``` - -### Example Response -```json -{ - "label": "vehicle", - "probs": { - "vehicle": 0.93, - "animal": 0.05, - "shotgun": 0.02 - } -} -``` - -## Supported Audio Formats -- WAV, MP3, FLAC, OGG -- M4A, AAC, WMA, OPUS - - -## Health & Docs -- `GET /health` → basic readiness and model load status -- Swagger UI: [http://localhost:8088/docs](http://localhost:8088/docs) - -## 🧪 Testing -Run all tests (unit + integration): -```bash -pytest -v --cov=src --cov-report=term-missing -``` - -## System Requirements -- Docker and Docker Compose -- MinIO instance with access credentials - -## Notes - • First startup may take ~30s to load the CNN14 model into memory. - • Kafka alerts are optional; see `KAFKA_BROKERS` and `ALERTS_TOPIC`. +# 🎧 Sound Classifier Service (CNN14-based) + +Service that classifies audio files using CNN14 model. It: +1. Receives S3 object location (bucket+key) +2. Classifies the sound +3. Stores result in PostgreSQL (optional) +4. Sends alert to Kafka topic if specific sounds detected (optional) +Built with **FastAPI**, **PANNs (CNN14)**, **PostgreSQL**, and optional **Kafka alerts** for real-time monitoring. + +## Quick Start +```bash +docker compose up -d sounds_classifier +``` +Service runs on **http://localhost:8088** (see `docker-compose.yml`, port 8088). + +## API Usage +```json +POST /classify +{ + "s3_bucket": "your-bucket", + "s3_key": "path/to/audio.wav" +} +``` + +### Example Response +```json +{ + "label": "vehicle", + "probs": { + "vehicle": 0.93, + "animal": 0.05, + "shotgun": 0.02 + } +} +``` + +## Supported Audio Formats +- WAV, MP3, FLAC, OGG +- M4A, AAC, WMA, OPUS + + +## Health & Docs +- `GET /health` → basic readiness and model load status +- Swagger UI: [http://localhost:8088/docs](http://localhost:8088/docs) + +## 🧪 Testing +Run all tests (unit + integration): +```bash +pytest -v --cov=src --cov-report=term-missing +``` + +## System Requirements +- Docker and Docker Compose +- MinIO instance with access credentials + +## Notes + • First startup may take ~30s to load the CNN14 model into memory. + • Kafka alerts are optional; see `KAFKA_BROKERS` and `ALERTS_TOPIC`. • Database writes are handled through `classification.core.db_io_pg`. \ No newline at end of file diff --git a/services/sounds/sounds_classifier/requirements.txt b/services/sounds_classifier/requirements.txt similarity index 94% rename from services/sounds/sounds_classifier/requirements.txt rename to services/sounds_classifier/requirements.txt index f13de3deb..cb7054ecc 100644 --- a/services/sounds/sounds_classifier/requirements.txt +++ b/services/sounds_classifier/requirements.txt @@ -1,20 +1,20 @@ -# Web API -fastapi==0.115.5 -uvicorn[standard]==0.30.6 -pydantic==2.9.2 - -# Core scientific / audio -numpy==1.26.4 -scipy==1.11.4 -librosa==0.10.2.post1 -soundfile==0.12.1 -joblib==1.4.2 -scikit-learn==1.5.2 - -# Storage / messaging / DB -minio==7.2.7 -confluent-kafka==2.6.0 -psycopg2-binary==2.9.9 - -# PANNs helper (optional but referenced by imports) -panns-inference==0.1.1 +# Web API +fastapi==0.115.5 +uvicorn[standard]==0.30.6 +pydantic==2.9.2 + +# Core scientific / audio +numpy==1.26.4 +scipy==1.11.4 +librosa==0.10.2.post1 +soundfile==0.12.1 +joblib==1.4.2 +scikit-learn==1.5.2 + +# Storage / messaging / DB +minio==7.2.7 +confluent-kafka==2.6.0 +psycopg2-binary==2.9.9 + +# PANNs helper (optional but referenced by imports) +panns-inference==0.1.1 diff --git a/services/sounds/sounds_classifier/src/classification/__init__.py b/services/sounds_classifier/src/classification/__init__.py similarity index 96% rename from services/sounds/sounds_classifier/src/classification/__init__.py rename to services/sounds_classifier/src/classification/__init__.py index 5ac24be70..d163ffe0c 100644 --- a/services/sounds/sounds_classifier/src/classification/__init__.py +++ b/services/sounds_classifier/src/classification/__init__.py @@ -1,7 +1,7 @@ -""" -Top-level package for the project. -Avoid importing heavy submodules at package import time to keep imports safe for tests. -""" - -__all__ = ["__version__"] +""" +Top-level package for the project. +Avoid importing heavy submodules at package import time to keep imports safe for tests. +""" + +__all__ = ["__version__"] __version__ = "0.0.0" \ No newline at end of file diff --git a/services/sounds/sounds_classifier/src/classification/app.py b/services/sounds_classifier/src/classification/app.py similarity index 82% rename from services/sounds/sounds_classifier/src/classification/app.py rename to services/sounds_classifier/src/classification/app.py index c3bbcbb8a..a2b526747 100644 --- a/services/sounds/sounds_classifier/src/classification/app.py +++ b/services/sounds_classifier/src/classification/app.py @@ -1,186 +1,209 @@ -import logging -import time -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import Dict, Optional -import os -import numpy as np -import joblib - -from panns_inference import AudioTagging -from classification.core.model_io import SAMPLE_RATE -from classification.scripts import classify as cls_script -from classification.core.db_utils import ensure_run, open_db, resolve_file_id -from classification.core.db_io_pg import finish_run, upsert_file_aggregate - -app = FastAPI(title="Audio Classifier API", version="2.0.0") - -# --- Globals (singletons) --- -PANN_MODEL: Optional[AudioTagging] = None -SK_PIPELINE = None -DB_CONN = None -DB_RUN_ID = os.getenv("DB_RUN_ID", "api-default") -DB_SCHEMA = os.getenv("DB_SCHEMA", "agcloud_audio") - -CHECKPOINT_PATH = os.getenv( - "CHECKPOINT", - "/app/classification/models/panns_data/Cnn14_mAP=0.431.pth" -) -HEAD_PATH = os.getenv( - "HEAD", - "/app/classification/models/head/head_cnn14_rf.joblib" # adapt if different -) - -class ClassifyIn(BaseModel): - s3_bucket: str - s3_key: str - return_probs: bool = True - -class ClassifyOut(BaseModel): - label: str - probs: Dict[str, float] - - -@app.on_event("startup") -def load_models_on_startup() -> None: - """ - Load heavy models once and perform a short warm-up to avoid cold-start. - Also open DB connection and ensure run row exists. """ - global PANN_MODEL, SK_PIPELINE, DB_CONN - logger = logging.getLogger("uvicorn.error") - - logger.info("Loading models into memory...") - PANN_MODEL = AudioTagging(checkpoint_path=CHECKPOINT_PATH) - SK_PIPELINE = None - try: - if os.path.exists(HEAD_PATH): - SK_PIPELINE = joblib.load(HEAD_PATH) - logger.info("✅ SK pipeline loaded from HEAD.") - else: - logger.warning(f"HEAD pipeline not found at {HEAD_PATH}; using built-in head.") - except Exception as e: - logger.warning(f"HEAD pipeline load failed ({e}); using built-in head.") - - # 3) Warm-up forward pass with 1 second of silence - dummy = np.zeros((1, SAMPLE_RATE * 10), dtype=np.float32) # add batch dim - try: - _ = PANN_MODEL.inference(dummy) - logger.info("✅ PANN model warm-up complete.") - except Exception as e: - logger.warning(f"PANN warm-up skipped ({e})") - - # DB connect + ensure run - try: - DB_CONN = open_db() - ensure_run(DB_CONN, DB_RUN_ID) - logger.info(f"✅ DB connected; run '{DB_RUN_ID}' ensured in schema '{DB_SCHEMA}'.") - except Exception as e: - logger.error(f"DB init failed: {e}") - raise - - logger.info("✅ All models loaded and ready.") - -@app.on_event("shutdown") -def close_db_on_shutdown() -> None: - """ - Cleanly close the global DB connection on shutdown. - """ - global DB_CONN - try: - if DB_CONN is not None: - try: - finish_run(DB_CONN, DB_RUN_ID) - except Exception: - pass - DB_CONN.close() - except Exception: - pass - finally: - DB_CONN = None - -# dedicated API perf logger -api_logger = logging.getLogger("audio_cls.api") -api_logger.setLevel(logging.INFO) -if not api_logger.handlers: - h = logging.StreamHandler() - h.setFormatter(logging.Formatter("[%(asctime)s] [API] %(message)s", "%Y-%m-%d %H:%M:%S")) - api_logger.addHandler(h) - -@app.post("/classify", response_model=ClassifyOut) -def classify(body: ClassifyIn): - """ - Run the full classification pipeline: - - Download from MinIO (s3_bucket + s3_key) - - Model inference with open-set threshold - - DB upsert into agcloud_audio.file_aggregates - """ - start = time.perf_counter() - status_code = 200 - try: - # 1) Require the file to already exist in public.files → else 404 - try: - file_id = resolve_file_id(DB_CONN, bucket=body.s3_bucket, object_key=body.s3_key) - except ValueError as e: - # file not found in public.files → return 404 (do NOT create) - raise HTTPException(status_code=404, detail=str(e)) - - # 2) Run classification - result = cls_script.run_classification_job( - s3_bucket=body.s3_bucket, - s3_key=body.s3_key, - pann_model=PANN_MODEL, - sk_pipeline=SK_PIPELINE - ) - - # 3) Upsert aggregate to DB (JSONB) - upsert_file_aggregate(DB_CONN, { - "run_id": DB_RUN_ID, - "file_id": file_id, - "head_probs_json": result.get("probs", {}), - "head_pred_label": result.get("label"), - "head_pred_prob": result.get("pred_prob"), - "head_unknown_threshold": result.get("unknown_threshold"), - "head_is_another": result.get("is_another"), - "num_windows": result.get("num_windows"), - "agg_mode": result.get("agg_mode"), - "processing_ms": result.get("processing_ms"), - }) - - # 4) Build API response - out = {"label": result.get("label", ""), "probs": result.get("probs", {})} - if not body.return_probs: - out["probs"] = {} - return ClassifyOut(**out) - - except HTTPException as e: - status_code = e.status_code - raise - except Exception as e: - status_code = 500 - raise HTTPException(status_code=500, detail=str(e)) - finally: - elapsed_ms = (time.perf_counter() - start) * 1000.0 - api_logger.info( - f"path=/classify bucket={body.s3_bucket} key={body.s3_key} " - f"latency_ms={elapsed_ms:.2f} status={status_code}" - ) - -@app.get("/health") -def health(): - return { - "ok": True, - "pann_loaded": PANN_MODEL is not None, - "sk_pipeline_loaded": SK_PIPELINE is not None - } - -@app.middleware("http") -async def timing_middleware(request, call_next): - t0 = time.perf_counter() - response = await call_next(request) - elapsed_ms = (time.perf_counter() - t0) * 1000.0 - - # log only interesting routes (keep or adjust as you like) - if request.url.path in ("/classify", "/health"): - api_logger.info(f"path={request.url.path} status={response.status_code} latency_ms={elapsed_ms:.2f}") - - return response +import logging +import time +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from typing import Dict, Optional +import os +import numpy as np +import joblib +from psycopg2 import extensions + +from panns_inference import AudioTagging +from classification.core.model_io import SAMPLE_RATE +from classification.scripts import classify as cls_script +from classification.core.db_utils import ensure_run, open_db, resolve_file_id +from classification.core.db_io_pg import finish_run, upsert_file_aggregate + +app = FastAPI(title="Audio Classifier API", version="2.0.0") + +# --- Globals (singletons) --- +PANN_MODEL: Optional[AudioTagging] = None +SK_PIPELINE = None +DB_CONN = None +DB_RUN_ID = os.getenv("DB_RUN_ID", "api-default") +DB_SCHEMA = os.getenv("DB_SCHEMA", "agcloud_audio") + +CHECKPOINT_PATH = os.getenv( + "CHECKPOINT", + "/app/classification/models/panns_data/Cnn14_mAP=0.431.pth" +) +HEAD_PATH = os.getenv( + "HEAD", + "/app/classification/models/head/head_cnn14_rf.joblib" # adapt if different +) + +class ClassifyIn(BaseModel): + s3_bucket: str + s3_key: str + return_probs: bool = True + +class ClassifyOut(BaseModel): + label: str + probs: Dict[str, float] + sent_alert: bool = True + alert_topic: Optional[str] = None + alert_skip_reason: Optional[str] = None + +def _ensure_conn_clean(conn): + """Rollback any non-idle transaction so the connection is usable.""" + try: + if conn is not None and conn.get_transaction_status() != extensions.TRANSACTION_STATUS_IDLE: + conn.rollback() + except Exception: + # swallow – if rollback itself fails, next ops will raise and be handled + pass + +@app.on_event("startup") +def load_models_on_startup() -> None: + """ + Load heavy models once and perform a short warm-up to avoid cold-start. + Also open DB connection and ensure run row exists. """ + global PANN_MODEL, SK_PIPELINE, DB_CONN + logger = logging.getLogger("uvicorn.error") + + logger.info("Loading models into memory...") + PANN_MODEL = AudioTagging(checkpoint_path=CHECKPOINT_PATH) + SK_PIPELINE = None + try: + if os.path.exists(HEAD_PATH): + SK_PIPELINE = joblib.load(HEAD_PATH) + logger.info("✅ SK pipeline loaded from HEAD.") + else: + logger.warning(f"HEAD pipeline not found at {HEAD_PATH}; using built-in head.") + except Exception as e: + logger.warning(f"HEAD pipeline load failed ({e}); using built-in head.") + + # 3) Warm-up forward pass with 1 second of silence + dummy = np.zeros((1, SAMPLE_RATE * 10), dtype=np.float32) # add batch dim + try: + _ = PANN_MODEL.inference(dummy) + logger.info("✅ PANN model warm-up complete.") + except Exception as e: + logger.warning(f"PANN warm-up skipped ({e})") + + # DB connect + ensure run + try: + DB_CONN = open_db() + ensure_run(DB_CONN, DB_RUN_ID) + logger.info(f"✅ DB connected; run '{DB_RUN_ID}' ensured in schema '{DB_SCHEMA}'.") + except Exception as e: + logger.error(f"DB init failed: {e}") + raise + + logger.info("✅ All models loaded and ready.") + +@app.on_event("shutdown") +def close_db_on_shutdown() -> None: + """ + Cleanly close the global DB connection on shutdown. + """ + global DB_CONN + try: + if DB_CONN is not None: + try: + finish_run(DB_CONN, DB_RUN_ID) + except Exception: + pass + DB_CONN.close() + except Exception: + pass + finally: + DB_CONN = None + +# dedicated API perf logger +api_logger = logging.getLogger("audio_cls.api") +api_logger.setLevel(logging.INFO) +if not api_logger.handlers: + h = logging.StreamHandler() + h.setFormatter(logging.Formatter("[%(asctime)s] [API] %(message)s", "%Y-%m-%d %H:%M:%S")) + api_logger.addHandler(h) + +@app.post("/classify", response_model=ClassifyOut) +def classify(body: ClassifyIn): + """ + Run the full classification pipeline: + - Download from MinIO (s3_bucket + s3_key) + - Model inference with open-set threshold + - DB upsert into agcloud_audio.file_aggregates + """ + start = time.perf_counter() + status_code = 200 + _ensure_conn_clean(DB_CONN) + try: + # 1) Require the file to already exist in public.sound_new_sounds_connections → else 404 + try: + file_id = resolve_file_id(DB_CONN, bucket=body.s3_bucket, object_key=body.s3_key) + except ValueError as e: + DB_CONN.rollback() + # file not found in public.sound_new_sounds_connections → return 404 (do NOT create) + raise HTTPException(status_code=404, detail=str(e)) + + # 2) Run classification + result = cls_script.run_classification_job( + s3_bucket=body.s3_bucket, + s3_key=body.s3_key, + pann_model=PANN_MODEL, + sk_pipeline=SK_PIPELINE + ) + + # 3) Upsert aggregate to DB (JSONB) + upsert_file_aggregate(DB_CONN, { + "run_id": DB_RUN_ID, + "file_id": file_id, + "head_probs_json": result.get("probs", {}), + "head_pred_label": result.get("label"), + "head_pred_prob": result.get("pred_prob"), + "head_unknown_threshold": result.get("unknown_threshold"), + "head_is_another": result.get("is_another"), + "num_windows": result.get("num_windows"), + "agg_mode": result.get("agg_mode"), + "processing_ms": result.get("processing_ms"), + }) + + # 4) Build API response (include alert status if exists) + out = {"label": result.get("label", ""), + "probs": result.get("probs", {}), + "sent_alert": bool(result.get("sent_alert", False)), + "alert_topic": result.get("alert_topic"), + "alert_skip_reason": result.get("alert_skip_reason"), + } + if not body.return_probs: + out["probs"] = {} + return ClassifyOut(**out) + + except HTTPException as e: + status_code = e.status_code + raise + except Exception as e: + try: + DB_CONN.rollback() + except Exception: + pass + status_code = 500 + raise HTTPException(status_code=500, detail=str(e)) + finally: + elapsed_ms = (time.perf_counter() - start) * 1000.0 + api_logger.info( + f"path=/classify bucket={body.s3_bucket} key={body.s3_key} " + f"latency_ms={elapsed_ms:.2f} status={status_code}" + ) + +@app.get("/health") +def health(): + return { + "ok": True, + "pann_loaded": PANN_MODEL is not None, + "sk_pipeline_loaded": SK_PIPELINE is not None + } + +@app.middleware("http") +async def timing_middleware(request, call_next): + t0 = time.perf_counter() + response = await call_next(request) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + # log only interesting routes (keep or adjust as you like) + if request.url.path in ("/classify", "/health"): + api_logger.info(f"path={request.url.path} status={response.status_code} latency_ms={elapsed_ms:.2f}") + + return response diff --git a/services/sounds_classifier/src/classification/backbones/__init__.py b/services/sounds_classifier/src/classification/backbones/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/sounds/sounds_classifier/src/classification/backbones/cnn14.py b/services/sounds_classifier/src/classification/backbones/cnn14.py similarity index 97% rename from services/sounds/sounds_classifier/src/classification/backbones/cnn14.py rename to services/sounds_classifier/src/classification/backbones/cnn14.py index af4420017..afc7e327e 100644 --- a/services/sounds/sounds_classifier/src/classification/backbones/cnn14.py +++ b/services/sounds_classifier/src/classification/backbones/cnn14.py @@ -1,80 +1,80 @@ -from __future__ import annotations -from typing import Tuple, List, Optional -import numpy as np -from panns_inference import AudioTagging -from classification.core.model_io import _to_numpy, ensure_checkpoint - -def load_cnn14_model( - checkpoint_path: Optional[str] = None, - checkpoint_url: Optional[str] = None, - device: str = "cpu" -) -> AudioTagging: - """ - Load a CNN14 AudioTagging model. - Either checkpoint_path or checkpoint_url must be provided. - Always resolves to a local path via ensure_checkpoint. - """ - if not (checkpoint_path or checkpoint_url): - raise FileNotFoundError("Either checkpoint_path or checkpoint_url must be provided.") - - ckpt = ensure_checkpoint(checkpoint_path, checkpoint_url) # returns a local path - return AudioTagging(checkpoint_path=ckpt, device=device) - - -def run_embedding(at: AudioTagging, wav: np.ndarray) -> np.ndarray: - try: - res = at.inference(wav) - except Exception: - res = at.inference(wav[None, :]) - - emb = None - if isinstance(res, dict): - emb = res.get("embedding", None) - elif isinstance(res, tuple) and len(res) >= 2: - emb = res[1] - if emb is None: - raise RuntimeError("No embedding returned by panns_inference.") - return _to_numpy(emb).reshape(-1) - - -def run_cnn14_embedding(model: AudioTagging, wav: np.ndarray) -> np.ndarray: - """ - Run embedding extraction; validate input waveform. - Raises ValueError if wav is empty. - """ - wav = np.asarray(wav) - if wav.size == 0: - raise ValueError("waveform must not be empty") - if wav.dtype != np.float32: - wav = wav.astype(np.float32, copy=False) - return run_embedding(model, wav) - - -def run_cnn14_embeddings_batch(model: AudioTagging, windows: np.ndarray, batch_size: int = 32) -> np.ndarray: - """ - Compute embeddings for a batch of windows in shape (N, samples). - Returns array (N, emb_dim) float32. - """ - if windows.ndim != 2: - raise ValueError("windows must be 2D (N, samples)") - n = windows.shape[0] - embs = [] - i = 0 - while i < n: - j = min(i + batch_size, n) - chunk = np.array(windows[i:j], dtype=np.float32, copy=True, order="C") - # panns_inference supports batched input (N, samples) - res = model.inference(chunk) - if isinstance(res, dict): - emb = res.get("embedding") - elif isinstance(res, tuple) and len(res) >= 2: - emb = res[1] - else: - raise RuntimeError("Unexpected inference output") - e = _to_numpy(emb).astype(np.float32, copy=False) - if e.ndim == 1: - e = e[None, :] - embs.append(e) - i = j - E = np.concatenate(embs, axis=0).astype(np.float32, copy=False) - return E +from __future__ import annotations +from typing import Tuple, List, Optional +import numpy as np +from panns_inference import AudioTagging +from classification.core.model_io import _to_numpy, ensure_checkpoint + +def load_cnn14_model( + checkpoint_path: Optional[str] = None, + checkpoint_url: Optional[str] = None, + device: str = "cpu" +) -> AudioTagging: + """ + Load a CNN14 AudioTagging model. + Either checkpoint_path or checkpoint_url must be provided. + Always resolves to a local path via ensure_checkpoint. + """ + if not (checkpoint_path or checkpoint_url): + raise FileNotFoundError("Either checkpoint_path or checkpoint_url must be provided.") + + ckpt = ensure_checkpoint(checkpoint_path, checkpoint_url) # returns a local path + return AudioTagging(checkpoint_path=ckpt, device=device) + + +def run_embedding(at: AudioTagging, wav: np.ndarray) -> np.ndarray: + try: + res = at.inference(wav) + except Exception: + res = at.inference(wav[None, :]) + + emb = None + if isinstance(res, dict): + emb = res.get("embedding", None) + elif isinstance(res, tuple) and len(res) >= 2: + emb = res[1] + if emb is None: + raise RuntimeError("No embedding returned by panns_inference.") + return _to_numpy(emb).reshape(-1) + + +def run_cnn14_embedding(model: AudioTagging, wav: np.ndarray) -> np.ndarray: + """ + Run embedding extraction; validate input waveform. + Raises ValueError if wav is empty. + """ + wav = np.asarray(wav) + if wav.size == 0: + raise ValueError("waveform must not be empty") + if wav.dtype != np.float32: + wav = wav.astype(np.float32, copy=False) + return run_embedding(model, wav) + + +def run_cnn14_embeddings_batch(model: AudioTagging, windows: np.ndarray, batch_size: int = 32) -> np.ndarray: + """ + Compute embeddings for a batch of windows in shape (N, samples). + Returns array (N, emb_dim) float32. + """ + if windows.ndim != 2: + raise ValueError("windows must be 2D (N, samples)") + n = windows.shape[0] + embs = [] + i = 0 + while i < n: + j = min(i + batch_size, n) + chunk = np.array(windows[i:j], dtype=np.float32, copy=True, order="C") + # panns_inference supports batched input (N, samples) + res = model.inference(chunk) + if isinstance(res, dict): + emb = res.get("embedding") + elif isinstance(res, tuple) and len(res) >= 2: + emb = res[1] + else: + raise RuntimeError("Unexpected inference output") + e = _to_numpy(emb).astype(np.float32, copy=False) + if e.ndim == 1: + e = e[None, :] + embs.append(e) + i = j + E = np.concatenate(embs, axis=0).astype(np.float32, copy=False) + return E diff --git a/services/sounds_classifier/src/classification/core/__init__.py b/services/sounds_classifier/src/classification/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/sounds/sounds_classifier/src/classification/core/db_io_pg.py b/services/sounds_classifier/src/classification/core/db_io_pg.py similarity index 97% rename from services/sounds/sounds_classifier/src/classification/core/db_io_pg.py rename to services/sounds_classifier/src/classification/core/db_io_pg.py index 877ffebb1..5a1789e4c 100644 --- a/services/sounds/sounds_classifier/src/classification/core/db_io_pg.py +++ b/services/sounds_classifier/src/classification/core/db_io_pg.py @@ -1,112 +1,112 @@ -from __future__ import annotations - -import json -import re -from typing import Any, Dict, Optional - -import psycopg2 -import psycopg2.extras -from psycopg2.extensions import connection as PGConnection -from psycopg2 import sql -import logging - -LOGGER = logging.getLogger(__name__) -_SCHEMA_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - -def open_db(db_url: str, schema: str = "audio_cls") -> PGConnection: - if not db_url: - raise ValueError("db_url is required (e.g., postgresql://user:pass@host:port/db)") - if not _SCHEMA_RE.match(schema): - raise ValueError(f"invalid schema name: {schema}") - - conn = psycopg2.connect(db_url) - conn.autocommit = False - try: - with conn.cursor() as cur: - cur.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(schema))) - cur.execute(sql.SQL("SET search_path TO {}, public").format(sql.Identifier(schema))) - conn.commit() - LOGGER.info("DB connected; schema=%s", schema) - except Exception: - conn.rollback() - LOGGER.exception("failed to init schema/search_path") - raise - return conn - -def upsert_run(conn: PGConnection, meta: Dict[str, Any]) -> None: - try: - with conn.cursor() as cur: - cur.execute( - """ - INSERT INTO runs - (run_id, model_name, checkpoint, head_path, labels_csv, - window_sec, hop_sec, pad_last, agg, topk, device, code_version, notes) - VALUES - (%(run_id)s, %(model_name)s, %(checkpoint)s, %(head_path)s, %(labels_csv)s, - %(window_sec)s, %(hop_sec)s, %(pad_last)s, %(agg)s, %(topk)s, %(device)s, %(code_version)s, %(notes)s) - ON CONFLICT (run_id) DO NOTHING - """, - meta, - ) - conn.commit() - LOGGER.debug("upsert_run: %s", meta.get("run_id")) - except Exception: - conn.rollback() - LOGGER.exception("upsert_run failed") - raise - -def finish_run(conn: PGConnection, run_id: str) -> None: - try: - with conn.cursor() as cur: - cur.execute("UPDATE runs SET finished_at = now() WHERE run_id = %s", (run_id,)) - conn.commit() - LOGGER.info("finish_run: %s", run_id) - except Exception: - conn.rollback() - LOGGER.exception("finish_run failed: %s", run_id) - raise - -def _jsonify(v: Any) -> psycopg2.extras.Json: - if isinstance(v, str): - try: - v = json.loads(v) - except Exception: - v = {"raw": v} - return psycopg2.extras.Json(v) - -def upsert_file_aggregate(conn: PGConnection, row: Dict[str, Any]) -> None: - data = dict(row) - if "head_probs_json" in data and data["head_probs_json"] is not None: - data["head_probs_json"] = _jsonify(data.get("head_probs_json")) - - try: - with conn.cursor() as cur: - cur.execute( - """ - INSERT INTO file_aggregates - (run_id, file_id, - head_probs_json, head_pred_label, head_pred_prob, head_unknown_threshold, head_is_another, - num_windows, agg_mode, processing_ms) - VALUES - (%(run_id)s, %(file_id)s, - %(head_probs_json)s, %(head_pred_label)s, %(head_pred_prob)s, %(head_unknown_threshold)s, %(head_is_another)s, - %(num_windows)s, %(agg_mode)s, %(processing_ms)s) - ON CONFLICT (run_id, file_id) DO UPDATE SET - head_probs_json = EXCLUDED.head_probs_json, - head_pred_label = EXCLUDED.head_pred_label, - head_pred_prob = EXCLUDED.head_pred_prob, - head_unknown_threshold = EXCLUDED.head_unknown_threshold, - head_is_another = EXCLUDED.head_is_another, - num_windows = EXCLUDED.num_windows, - agg_mode = EXCLUDED.agg_mode, - processing_ms = EXCLUDED.processing_ms - """, - data, - ) - conn.commit() - LOGGER.debug("upsert_file_aggregate: run=%s file=%s", data.get("run_id"), data.get("file_id")) - except Exception: - conn.rollback() - LOGGER.exception("upsert_file_aggregate failed") - raise - +from __future__ import annotations + +import json +import re +from typing import Any, Dict, Optional + +import psycopg2 +import psycopg2.extras +from psycopg2.extensions import connection as PGConnection +from psycopg2 import sql +import logging + +LOGGER = logging.getLogger(__name__) +_SCHEMA_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + +def open_db(db_url: str, schema: str = "audio_cls") -> PGConnection: + if not db_url: + raise ValueError("db_url is required (e.g., postgresql://user:pass@host:port/db)") + if not _SCHEMA_RE.match(schema): + raise ValueError(f"invalid schema name: {schema}") + + conn = psycopg2.connect(db_url) + conn.autocommit = False + try: + with conn.cursor() as cur: + cur.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(schema))) + cur.execute(sql.SQL("SET search_path TO {}, public").format(sql.Identifier(schema))) + conn.commit() + LOGGER.info("DB connected; schema=%s", schema) + except Exception: + conn.rollback() + LOGGER.exception("failed to init schema/search_path") + raise + return conn + +def upsert_run(conn: PGConnection, meta: Dict[str, Any]) -> None: + try: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO runs + (run_id, model_name, checkpoint, head_path, labels_csv, + window_sec, hop_sec, pad_last, agg, topk, device, code_version, notes) + VALUES + (%(run_id)s, %(model_name)s, %(checkpoint)s, %(head_path)s, %(labels_csv)s, + %(window_sec)s, %(hop_sec)s, %(pad_last)s, %(agg)s, %(topk)s, %(device)s, %(code_version)s, %(notes)s) + ON CONFLICT (run_id) DO NOTHING + """, + meta, + ) + conn.commit() + LOGGER.debug("upsert_run: %s", meta.get("run_id")) + except Exception: + conn.rollback() + LOGGER.exception("upsert_run failed") + raise + +def finish_run(conn: PGConnection, run_id: str) -> None: + try: + with conn.cursor() as cur: + cur.execute("UPDATE runs SET finished_at = now() WHERE run_id = %s", (run_id,)) + conn.commit() + LOGGER.info("finish_run: %s", run_id) + except Exception: + conn.rollback() + LOGGER.exception("finish_run failed: %s", run_id) + raise + +def _jsonify(v: Any) -> psycopg2.extras.Json: + if isinstance(v, str): + try: + v = json.loads(v) + except Exception: + v = {"raw": v} + return psycopg2.extras.Json(v) + +def upsert_file_aggregate(conn: PGConnection, row: Dict[str, Any]) -> None: + data = dict(row) + if "head_probs_json" in data and data["head_probs_json"] is not None: + data["head_probs_json"] = _jsonify(data.get("head_probs_json")) + + try: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO file_aggregates + (run_id, file_id, + head_probs_json, head_pred_label, head_pred_prob, head_unknown_threshold, head_is_another, + num_windows, agg_mode, processing_ms) + VALUES + (%(run_id)s, %(file_id)s, + %(head_probs_json)s, %(head_pred_label)s, %(head_pred_prob)s, %(head_unknown_threshold)s, %(head_is_another)s, + %(num_windows)s, %(agg_mode)s, %(processing_ms)s) + ON CONFLICT (run_id, file_id) DO UPDATE SET + head_probs_json = EXCLUDED.head_probs_json, + head_pred_label = EXCLUDED.head_pred_label, + head_pred_prob = EXCLUDED.head_pred_prob, + head_unknown_threshold = EXCLUDED.head_unknown_threshold, + head_is_another = EXCLUDED.head_is_another, + num_windows = EXCLUDED.num_windows, + agg_mode = EXCLUDED.agg_mode, + processing_ms = EXCLUDED.processing_ms + """, + data, + ) + conn.commit() + LOGGER.debug("upsert_file_aggregate: run=%s file=%s", data.get("run_id"), data.get("file_id")) + except Exception: + conn.rollback() + LOGGER.exception("upsert_file_aggregate failed") + raise + diff --git a/services/sounds/sounds_classifier/src/classification/core/db_utils.py b/services/sounds_classifier/src/classification/core/db_utils.py similarity index 78% rename from services/sounds/sounds_classifier/src/classification/core/db_utils.py rename to services/sounds_classifier/src/classification/core/db_utils.py index 0ea1f4fae..f2b7f139e 100644 --- a/services/sounds/sounds_classifier/src/classification/core/db_utils.py +++ b/services/sounds_classifier/src/classification/core/db_utils.py @@ -1,122 +1,126 @@ -import os -import psycopg2 -from psycopg2 import sql -from typing import Optional - -FILES_SCHEMA = os.getenv("FILES_SCHEMA", "public") -FILES_TABLE = os.getenv("FILES_TABLE", "files") - -def _files_table_ql() -> sql.SQL: - return sql.SQL("{}.{}").format(sql.Identifier(FILES_SCHEMA), sql.Identifier(FILES_TABLE)) - -def open_db(): - host = os.getenv("DB_HOST", "postgres") - port = int(os.getenv("DB_PORT", "5432")) - db = os.getenv("DB_NAME", "missions_db") - user = os.getenv("DB_USER", "missions_user") - pwd = os.getenv("DB_PASSWORD", "pg123") - schema = os.getenv("DB_SCHEMA", "agcloud_audio") - - conn = psycopg2.connect(host=host, port=port, dbname=db, user=user, password=pwd) - conn.autocommit = False - with conn.cursor() as cur: - cur.execute(sql.SQL("SET search_path TO {}, public;").format(sql.Identifier(schema))) - return conn - -def ensure_file(conn, *, bucket: str, object_key: str, - size_bytes: Optional[int] = None, - sample_rate: Optional[int] = None, - duration_s: Optional[float] = None) -> int: - """Idempotent ensure in public.files by (bucket, object_key).""" - try: - with conn.cursor() as cur: - cur.execute( - sql.SQL("SELECT file_id FROM {} WHERE bucket = %s AND object_key = %s") - .format(_files_table_ql()), - (bucket, object_key), - ) - row = cur.fetchone() - if row: - return int(row[0]) - - cur.execute( - sql.SQL(""" - INSERT INTO {} (bucket, object_key, size_bytes, sample_rate, duration_s) - VALUES (%s, %s, %s, %s, %s) - RETURNING file_id - """).format(_files_table_ql()), - (bucket, object_key, size_bytes, sample_rate, duration_s), - ) - new_id = cur.fetchone()[0] - - conn.commit() - return int(new_id) - except Exception: - conn.rollback() - raise - -def ensure_run(conn, run_id: str): - """ - Ensure there is a row in agcloud_audio.runs for FK constraints. - This will INSERT ... ON CONFLICT DO NOTHING with reasonable defaults. - """ - import os - window_sec = float(os.getenv("WINDOW_SEC", "2.0")) - hop_sec = float(os.getenv("HOP_SEC", "0.5")) - pad_last = os.getenv("PAD_LAST", "true").lower() in ("1", "true", "yes", "on") - - with conn.cursor() as cur: - cur.execute(""" - INSERT INTO runs ( - run_id, model_name, checkpoint, head_path, labels_csv, - window_sec, hop_sec, pad_last, agg, topk, device, code_version, notes - ) - VALUES ( - %(run_id)s, %(model_name)s, %(checkpoint)s, %(head_path)s, %(labels_csv)s, - %(window_sec)s, %(hop_sec)s, %(pad_last)s, %(agg)s, %(topk)s, %(device)s, - %(code_version)s, %(notes)s - ) - ON CONFLICT (run_id) DO NOTHING - """, { - "run_id": run_id, - "model_name": os.getenv("MODEL_NAME", "panns_cnn14"), - "checkpoint": os.getenv("CHECKPOINT", "panns_cnn14.pth"), - "head_path": os.getenv("HEAD", ""), - "labels_csv": os.getenv("LABELS_CSV", ""), - "window_sec": float(os.getenv("WINDOW_SEC", "10")), - "hop_sec": float(os.getenv("HOP_SEC", "10")), - "pad_last": os.getenv("PAD_LAST", "false").lower() == "true", - "agg": os.getenv("AGG", "mean"), - "topk": int(os.getenv("TOPK", "3")), - "device": os.getenv("DEVICE", "cpu"), - "code_version": os.getenv("CODE_VERSION", ""), - "notes": os.getenv("RUN_NOTES", "created by API ensure_run") - }) - conn.commit() - -def resolve_file_id(conn, *, file_id: Optional[int] = None, - bucket: Optional[str] = None, object_key: Optional[str] = None) -> int: - """Select-only (NO insert). Raises ValueError if not found.""" - with conn.cursor() as cur: - if file_id is not None: - cur.execute( - sql.SQL("SELECT file_id FROM {} WHERE file_id = %s").format(_files_table_ql()), - (file_id,), - ) - row = cur.fetchone() - if row: - return int(row[0]) - raise ValueError(f"file_id {file_id} not found in {FILES_SCHEMA}.{FILES_TABLE}") - - if bucket is not None and object_key is not None: - cur.execute( - sql.SQL("SELECT file_id FROM {} WHERE bucket = %s AND object_key = %s") - .format(_files_table_ql()), - (bucket, object_key), - ) - row = cur.fetchone() - if row: - return int(row[0]) - raise ValueError(f"File s3://{bucket}/{object_key} not found in {FILES_SCHEMA}.{FILES_TABLE}") - - raise ValueError("Must provide file_id or (bucket, object_key)") +import os +import psycopg2 +from psycopg2 import sql +from typing import Optional + +FILES_SCHEMA = os.getenv("FILES_SCHEMA", "public") +FILES_TABLE = os.getenv("FILES_TABLE", "sound_new_sounds_connections") + +def _files_table_ql() -> sql.SQL: + return sql.SQL("{}.{}").format(sql.Identifier(FILES_SCHEMA), sql.Identifier(FILES_TABLE)) + +_KEY_COL = sql.Identifier("key") + +def open_db(): + host = os.getenv("DB_HOST", "postgres") + port = int(os.getenv("DB_PORT", "5432")) + db = os.getenv("DB_NAME", "missions_db") + user = os.getenv("DB_USER", "missions_user") + pwd = os.getenv("DB_PASSWORD", "pg123") + schema = os.getenv("DB_SCHEMA", "agcloud_audio") + + conn = psycopg2.connect(host=host, port=port, dbname=db, user=user, password=pwd) + conn.autocommit = False + with conn.cursor() as cur: + cur.execute(sql.SQL("SET search_path TO {}, public;").format(sql.Identifier(schema))) + return conn + +def ensure_file(conn, *, bucket: str, object_key: str, + size_bytes: Optional[int] = None, + sample_rate: Optional[int] = None, + duration_s: Optional[float] = None) -> int: + """Idempotent ensure in public.sound_new_sounds_connections by (bucket, object_key).""" + combined_key = f"{bucket}/{object_key}".lstrip("/") + try: + with conn.cursor() as cur: + cur.execute( + sql.SQL("SELECT id FROM {} WHERE {} = %s") + .format(_files_table_ql(),_KEY_COL), + (combined_key,), + ) + row = cur.fetchone() + if row: + return int(row[0]) + + cur.execute( + sql.SQL(""" + INSERT INTO {} ({}, size_bytes, sample_rate, duration_s) + VALUES (%s, %s, %s, %s) + RETURNING id + """).format(_files_table_ql(), _KEY_COL), + (combined_key, size_bytes, sample_rate, duration_s), + ) + new_id = cur.fetchone()[0] + + conn.commit() + return int(new_id) + except Exception: + conn.rollback() + raise + +def ensure_run(conn, run_id: str): + """ + Ensure there is a row in agcloud_audio.runs for FK constraints. + This will INSERT ... ON CONFLICT DO NOTHING with reasonable defaults. + """ + import os + window_sec = float(os.getenv("WINDOW_SEC", "2.0")) + hop_sec = float(os.getenv("HOP_SEC", "0.5")) + pad_last = os.getenv("PAD_LAST", "true").lower() in ("1", "true", "yes", "on") + + with conn.cursor() as cur: + cur.execute(""" + INSERT INTO runs ( + run_id, model_name, checkpoint, head_path, labels_csv, + window_sec, hop_sec, pad_last, agg, topk, device, code_version, notes + ) + VALUES ( + %(run_id)s, %(model_name)s, %(checkpoint)s, %(head_path)s, %(labels_csv)s, + %(window_sec)s, %(hop_sec)s, %(pad_last)s, %(agg)s, %(topk)s, %(device)s, + %(code_version)s, %(notes)s + ) + ON CONFLICT (run_id) DO NOTHING + """, { + "run_id": run_id, + "model_name": os.getenv("MODEL_NAME", "panns_cnn14"), + "checkpoint": os.getenv("CHECKPOINT", "panns_cnn14.pth"), + "head_path": os.getenv("HEAD", ""), + "labels_csv": os.getenv("LABELS_CSV", ""), + "window_sec": float(os.getenv("WINDOW_SEC", "10")), + "hop_sec": float(os.getenv("HOP_SEC", "10")), + "pad_last": os.getenv("PAD_LAST", "false").lower() == "true", + "agg": os.getenv("AGG", "mean"), + "topk": int(os.getenv("TOPK", "3")), + "device": os.getenv("DEVICE", "cpu"), + "code_version": os.getenv("CODE_VERSION", ""), + "notes": os.getenv("RUN_NOTES", "created by API ensure_run") + }) + conn.commit() + +def resolve_file_id(conn, *, file_id: Optional[int] = None, + bucket: Optional[str] = None, object_key: Optional[str] = None) -> int: + """Select-only (NO insert). Raises ValueError if not found.""" + with conn.cursor() as cur: + if file_id is not None: + cur.execute( + sql.SQL("SELECT id FROM {} WHERE id = %s").format(_files_table_ql()), + (file_id,), + ) + row = cur.fetchone() + if row: + return int(row[0]) + raise ValueError(f"id {file_id} not found in {FILES_SCHEMA}.{FILES_TABLE}") + + if bucket is not None and object_key is not None: + combined_key = f"{bucket}/{object_key}".lstrip("/") + cur.execute( + sql.SQL("SELECT id FROM {} WHERE {} = %s") + .format(_files_table_ql(), _KEY_COL), + (combined_key,), + ) + row = cur.fetchone() + if row: + return int(row[0]) + raise ValueError(f"File s3://{bucket}/{object_key} not found in {FILES_SCHEMA}.{FILES_TABLE}") + + raise ValueError("Must provide file_id or (bucket, object_key)") diff --git a/services/sounds/sounds_classifier/src/classification/core/model_io.py b/services/sounds_classifier/src/classification/core/model_io.py similarity index 97% rename from services/sounds/sounds_classifier/src/classification/core/model_io.py rename to services/sounds_classifier/src/classification/core/model_io.py index 2465b6bca..fcfa3a0f6 100644 --- a/services/sounds/sounds_classifier/src/classification/core/model_io.py +++ b/services/sounds_classifier/src/classification/core/model_io.py @@ -1,248 +1,248 @@ -from __future__ import annotations - -import pathlib -import shutil -import subprocess -from typing import Any, List, Optional, Tuple, Literal - -import numpy as np -import soundfile as sf -import librosa -import logging -import os -from numpy.lib.stride_tricks import sliding_window_view - -try: - import torch -except Exception: - torch = None - -LOGGER = logging.getLogger(__name__) - -SAMPLE_RATE = 32000 -MIN_SAMPLES = 16000 -HARD_EXTS = {".mp3", ".opus", ".m4a", ".aac", ".wma"} -SUPPORTED_EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} - -def ensure_numpy_1d(x): - """ - Force input to be numpy float32 vector (1-D). - Accepts numpy, torch.Tensor, TF tensors, and torch-like wrappers (duck-typed). - """ - # Torch-like (duck typing): has detach/cpu/numpy - if not isinstance(x, np.ndarray): - has_detach = hasattr(x, "detach") - has_cpu = hasattr(x, "cpu") - has_numpy = hasattr(x, "numpy") - if has_detach and has_cpu and has_numpy: - try: - x = x.detach().cpu().numpy() - except Exception: - pass - # Generic tensors (e.g., TF), expose .numpy() without detach/cpu - elif has_numpy and callable(getattr(x, "numpy", None)): - try: - x = x.numpy() - except Exception: - pass - - # Final conversion to numpy float32 - x = np.asarray(x, dtype=np.float32) - - # Flatten to 1-D - if x.ndim > 1: - x = x.reshape(-1) - return x - - -def has_ffmpeg() -> bool: - return shutil.which("ffmpeg") is not None - - -def decode_with_ffmpeg_to_float32_mono(path: str, target_sr: int = SAMPLE_RATE) -> np.ndarray: - cmd = ["ffmpeg", "-v", "error", "-i", path, "-vn", "-ac", "1", "-ar", str(target_sr), "-f", "f32le", "pipe:1"] - proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) - y = np.frombuffer(proc.stdout, dtype=np.float32) - if y.size < MIN_SAMPLES: - y = np.concatenate([y, np.zeros(MIN_SAMPLES - y.size, dtype=np.float32)], axis=0) - return y - - -def ensure_checkpoint(checkpoint_path: str, checkpoint_url: Optional[str]) -> str: - import urllib.request - p = pathlib.Path(checkpoint_path) - p.parent.mkdir(parents=True, exist_ok=True) - if p.exists(): - return str(p) - if not checkpoint_url: - raise FileNotFoundError(f"No checkpoint at {p}. Provide --checkpoint or --checkpoint-url.") - urllib.request.urlretrieve(checkpoint_url, p) - LOGGER.info("downloaded checkpoint to %s", p) - return str(p) - - -def load_audio(path: str, target_sr: int = SAMPLE_RATE) -> np.ndarray: - ext = pathlib.Path(path).suffix.lower() - - def _pad(y: np.ndarray) -> np.ndarray: - if y.size < MIN_SAMPLES: - y = np.concatenate([y, np.zeros(MIN_SAMPLES - y.size, dtype=np.float32)]) - return y - - # Compressed/streaming formats first (e.g., mp3, m4a, etc.) - if ext in HARD_EXTS: - try: - y, _ = librosa.load(path, sr=target_sr, mono=True) - y = ensure_numpy_1d(y) - return _pad(y) - except Exception: - if has_ffmpeg(): - LOGGER.warning("librosa failed; using ffmpeg fallback for %s", path) - y = decode_with_ffmpeg_to_float32_mono(path, target_sr) - y = ensure_numpy_1d(y) - return _pad(y) - LOGGER.exception("failed to load compressed audio: %s", path) - raise - - # Uncompressed / common wavs - try: - y, sr = sf.read(path, always_2d=False) - if hasattr(y, "ndim") and y.ndim > 1: - y = np.mean(y, axis=1) - - y = ensure_numpy_1d(y) - if int(sr) != int(target_sr): - y = librosa.resample(y, orig_sr=int(sr), target_sr=int(target_sr)) - y = ensure_numpy_1d(y) - - return _pad(y) - - except Exception: - try: - y, _ = librosa.load(path, sr=target_sr, mono=True) - y = ensure_numpy_1d(y) - return _pad(y) - except Exception: - if has_ffmpeg(): - LOGGER.warning("soundfile/librosa failed; using ffmpeg fallback for %s", path) - y = decode_with_ffmpeg_to_float32_mono(path, target_sr) - y = ensure_numpy_1d(y) - return _pad(y) - LOGGER.exception("failed to load audio: %s", path) - raise - - -def _to_numpy(x: Any) -> np.ndarray: - if (torch is not None) and hasattr(torch, "Tensor") and isinstance(x, torch.Tensor): # type: ignore - x = x.detach().cpu().numpy() - arr = np.asarray(x, dtype=np.float32) - if arr.ndim == 2: - if arr.shape[0] == 1: - arr = arr[0] - elif arr.shape[1] == 1: - arr = arr[:, 0] - else: - arr = arr.reshape(-1) - elif arr.ndim != 1: - arr = arr.reshape(-1) - return arr - - -def segment_waveform( - wav: np.ndarray, - sr: int = SAMPLE_RATE, - window_sec: float = 2.0, - hop_sec: float = 0.5, - pad_last: bool = True, -) -> List[np.ndarray]: - """ - Splits waveform into overlapping fixed-size windows. - Returns list of 1D numpy arrays (segments), each of length window_sec * sr. - """ - wav = np.asarray(wav, dtype=np.float32).reshape(-1) - win = max(1, int(round(window_sec * sr))) - hop = max(1, int(round(hop_sec * sr))) - n = wav.size - - segments: List[np.ndarray] = [] - if n == 0: - return segments - - i = 0 - while i + win <= n: - seg = wav[i: i + win].astype(np.float32) - segments.append(seg) - i += hop - - if pad_last and (i < n): - tail = wav[i:] - pad = np.zeros(win - tail.size, dtype=np.float32) - seg = np.concatenate([tail, pad], axis=0) - segments.append(seg) - elif not segments and pad_last: - pad = np.zeros(win - n, dtype=np.float32) - seg = np.concatenate([wav, pad], axis=0) - segments.append(seg) - - # ensure all are 1D np.float32 arrays - return [np.asarray(seg, dtype=np.float32).flatten() for seg in segments] - - -def segment_waveform_2d_view( - wav: np.ndarray, - sr: int = SAMPLE_RATE, - window_sec: float = 2.0, - hop_sec: float = 0.5, - pad_last: bool = True, -) -> np.ndarray: - """ - Return a 2D view of windows with shape (N, win) float32, minimizing copies. - The last window is padded if pad_last=True and needed (that one will copy). - """ - wav = np.asarray(wav, dtype=np.float32).reshape(-1) - win = max(1, int(round(window_sec * sr))) - hop = max(1, int(round(hop_sec * sr))) - n = wav.size - if n == 0: - return np.zeros((0, win), dtype=np.float32) - - if n >= win: - # sliding view for all full windows (no copy) - sw = sliding_window_view(wav, win)[::hop] # shape (N_full, win), view - if pad_last and ((n - win) % hop != 0): - tail_start = (sw.shape[0] * hop) - tail = wav[tail_start:] - pad = np.zeros(win - tail.size, dtype=np.float32) - last = np.concatenate([tail, pad], axis=0)[None, :] - return np.vstack([sw.astype(np.float32, copy=False), last.astype(np.float32, copy=False)]) - return sw.astype(np.float32, copy=False) - - # n < win - if pad_last: - pad = np.zeros(win - n, dtype=np.float32) - return np.concatenate([wav, pad], axis=0)[None, :] - return np.zeros((0, win), dtype=np.float32) - - -def aggregate_matrix(mat: np.ndarray, mode: Literal["mean", "max"] = "mean") -> np.ndarray: - if not isinstance(mat, np.ndarray): - raise TypeError("mat must be a numpy.ndarray") - if mat.ndim != 2: - raise ValueError("expected shape (num_windows, num_classes)") - if mat.shape[0] == 0: - raise ValueError("cannot aggregate an empty window matrix (num_windows == 0)") - if mat.shape[1] == 0: - raise ValueError("expected num_classes > 0") - if mode == "mean": - # Ignore NaNs when computing per-class means - v = np.nanmean(mat.astype(np.float32, copy=False), axis=0) - elif mode == "max": - # Ignore NaNs when computing per-class max - v = np.nanmax(mat.astype(np.float32, copy=False), axis=0) - else: - raise ValueError(f"Unsupported aggregation mode: {mode}") - - # Ensure finite float32 output; all-NaN columns become 0.0 - v = np.nan_to_num(v, nan=0.0, posinf=np.finfo(np.float32).max, neginf=np.finfo(np.float32).min) - return v.astype(np.float32, copy=False) - +from __future__ import annotations + +import pathlib +import shutil +import subprocess +from typing import Any, List, Optional, Tuple, Literal + +import numpy as np +import soundfile as sf +import librosa +import logging +import os +from numpy.lib.stride_tricks import sliding_window_view + +try: + import torch +except Exception: + torch = None + +LOGGER = logging.getLogger(__name__) + +SAMPLE_RATE = 32000 +MIN_SAMPLES = 16000 +HARD_EXTS = {".mp3", ".opus", ".m4a", ".aac", ".wma"} +SUPPORTED_EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma", ".opus"} + +def ensure_numpy_1d(x): + """ + Force input to be numpy float32 vector (1-D). + Accepts numpy, torch.Tensor, TF tensors, and torch-like wrappers (duck-typed). + """ + # Torch-like (duck typing): has detach/cpu/numpy + if not isinstance(x, np.ndarray): + has_detach = hasattr(x, "detach") + has_cpu = hasattr(x, "cpu") + has_numpy = hasattr(x, "numpy") + if has_detach and has_cpu and has_numpy: + try: + x = x.detach().cpu().numpy() + except Exception: + pass + # Generic tensors (e.g., TF), expose .numpy() without detach/cpu + elif has_numpy and callable(getattr(x, "numpy", None)): + try: + x = x.numpy() + except Exception: + pass + + # Final conversion to numpy float32 + x = np.asarray(x, dtype=np.float32) + + # Flatten to 1-D + if x.ndim > 1: + x = x.reshape(-1) + return x + + +def has_ffmpeg() -> bool: + return shutil.which("ffmpeg") is not None + + +def decode_with_ffmpeg_to_float32_mono(path: str, target_sr: int = SAMPLE_RATE) -> np.ndarray: + cmd = ["ffmpeg", "-v", "error", "-i", path, "-vn", "-ac", "1", "-ar", str(target_sr), "-f", "f32le", "pipe:1"] + proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + y = np.frombuffer(proc.stdout, dtype=np.float32) + if y.size < MIN_SAMPLES: + y = np.concatenate([y, np.zeros(MIN_SAMPLES - y.size, dtype=np.float32)], axis=0) + return y + + +def ensure_checkpoint(checkpoint_path: str, checkpoint_url: Optional[str]) -> str: + import urllib.request + p = pathlib.Path(checkpoint_path) + p.parent.mkdir(parents=True, exist_ok=True) + if p.exists(): + return str(p) + if not checkpoint_url: + raise FileNotFoundError(f"No checkpoint at {p}. Provide --checkpoint or --checkpoint-url.") + urllib.request.urlretrieve(checkpoint_url, p) + LOGGER.info("downloaded checkpoint to %s", p) + return str(p) + + +def load_audio(path: str, target_sr: int = SAMPLE_RATE) -> np.ndarray: + ext = pathlib.Path(path).suffix.lower() + + def _pad(y: np.ndarray) -> np.ndarray: + if y.size < MIN_SAMPLES: + y = np.concatenate([y, np.zeros(MIN_SAMPLES - y.size, dtype=np.float32)]) + return y + + # Compressed/streaming formats first (e.g., mp3, m4a, etc.) + if ext in HARD_EXTS: + try: + y, _ = librosa.load(path, sr=target_sr, mono=True) + y = ensure_numpy_1d(y) + return _pad(y) + except Exception: + if has_ffmpeg(): + LOGGER.warning("librosa failed; using ffmpeg fallback for %s", path) + y = decode_with_ffmpeg_to_float32_mono(path, target_sr) + y = ensure_numpy_1d(y) + return _pad(y) + LOGGER.exception("failed to load compressed audio: %s", path) + raise + + # Uncompressed / common wavs + try: + y, sr = sf.read(path, always_2d=False) + if hasattr(y, "ndim") and y.ndim > 1: + y = np.mean(y, axis=1) + + y = ensure_numpy_1d(y) + if int(sr) != int(target_sr): + y = librosa.resample(y, orig_sr=int(sr), target_sr=int(target_sr)) + y = ensure_numpy_1d(y) + + return _pad(y) + + except Exception: + try: + y, _ = librosa.load(path, sr=target_sr, mono=True) + y = ensure_numpy_1d(y) + return _pad(y) + except Exception: + if has_ffmpeg(): + LOGGER.warning("soundfile/librosa failed; using ffmpeg fallback for %s", path) + y = decode_with_ffmpeg_to_float32_mono(path, target_sr) + y = ensure_numpy_1d(y) + return _pad(y) + LOGGER.exception("failed to load audio: %s", path) + raise + + +def _to_numpy(x: Any) -> np.ndarray: + if (torch is not None) and hasattr(torch, "Tensor") and isinstance(x, torch.Tensor): # type: ignore + x = x.detach().cpu().numpy() + arr = np.asarray(x, dtype=np.float32) + if arr.ndim == 2: + if arr.shape[0] == 1: + arr = arr[0] + elif arr.shape[1] == 1: + arr = arr[:, 0] + else: + arr = arr.reshape(-1) + elif arr.ndim != 1: + arr = arr.reshape(-1) + return arr + + +def segment_waveform( + wav: np.ndarray, + sr: int = SAMPLE_RATE, + window_sec: float = 2.0, + hop_sec: float = 0.5, + pad_last: bool = True, +) -> List[np.ndarray]: + """ + Splits waveform into overlapping fixed-size windows. + Returns list of 1D numpy arrays (segments), each of length window_sec * sr. + """ + wav = np.asarray(wav, dtype=np.float32).reshape(-1) + win = max(1, int(round(window_sec * sr))) + hop = max(1, int(round(hop_sec * sr))) + n = wav.size + + segments: List[np.ndarray] = [] + if n == 0: + return segments + + i = 0 + while i + win <= n: + seg = wav[i: i + win].astype(np.float32) + segments.append(seg) + i += hop + + if pad_last and (i < n): + tail = wav[i:] + pad = np.zeros(win - tail.size, dtype=np.float32) + seg = np.concatenate([tail, pad], axis=0) + segments.append(seg) + elif not segments and pad_last: + pad = np.zeros(win - n, dtype=np.float32) + seg = np.concatenate([wav, pad], axis=0) + segments.append(seg) + + # ensure all are 1D np.float32 arrays + return [np.asarray(seg, dtype=np.float32).flatten() for seg in segments] + + +def segment_waveform_2d_view( + wav: np.ndarray, + sr: int = SAMPLE_RATE, + window_sec: float = 2.0, + hop_sec: float = 0.5, + pad_last: bool = True, +) -> np.ndarray: + """ + Return a 2D view of windows with shape (N, win) float32, minimizing copies. + The last window is padded if pad_last=True and needed (that one will copy). + """ + wav = np.asarray(wav, dtype=np.float32).reshape(-1) + win = max(1, int(round(window_sec * sr))) + hop = max(1, int(round(hop_sec * sr))) + n = wav.size + if n == 0: + return np.zeros((0, win), dtype=np.float32) + + if n >= win: + # sliding view for all full windows (no copy) + sw = sliding_window_view(wav, win)[::hop] # shape (N_full, win), view + if pad_last and ((n - win) % hop != 0): + tail_start = (sw.shape[0] * hop) + tail = wav[tail_start:] + pad = np.zeros(win - tail.size, dtype=np.float32) + last = np.concatenate([tail, pad], axis=0)[None, :] + return np.vstack([sw.astype(np.float32, copy=False), last.astype(np.float32, copy=False)]) + return sw.astype(np.float32, copy=False) + + # n < win + if pad_last: + pad = np.zeros(win - n, dtype=np.float32) + return np.concatenate([wav, pad], axis=0)[None, :] + return np.zeros((0, win), dtype=np.float32) + + +def aggregate_matrix(mat: np.ndarray, mode: Literal["mean", "max"] = "mean") -> np.ndarray: + if not isinstance(mat, np.ndarray): + raise TypeError("mat must be a numpy.ndarray") + if mat.ndim != 2: + raise ValueError("expected shape (num_windows, num_classes)") + if mat.shape[0] == 0: + raise ValueError("cannot aggregate an empty window matrix (num_windows == 0)") + if mat.shape[1] == 0: + raise ValueError("expected num_classes > 0") + if mode == "mean": + # Ignore NaNs when computing per-class means + v = np.nanmean(mat.astype(np.float32, copy=False), axis=0) + elif mode == "max": + # Ignore NaNs when computing per-class max + v = np.nanmax(mat.astype(np.float32, copy=False), axis=0) + else: + raise ValueError(f"Unsupported aggregation mode: {mode}") + + # Ensure finite float32 output; all-NaN columns become 0.0 + v = np.nan_to_num(v, nan=0.0, posinf=np.finfo(np.float32).max, neginf=np.finfo(np.float32).min) + return v.astype(np.float32, copy=False) + diff --git a/services/sounds/sounds_classifier/src/classification/models/custom_labels.csv b/services/sounds_classifier/src/classification/models/custom_labels.csv similarity index 93% rename from services/sounds/sounds_classifier/src/classification/models/custom_labels.csv rename to services/sounds_classifier/src/classification/models/custom_labels.csv index 7cb2d9a9f..028674305 100644 --- a/services/sounds/sounds_classifier/src/classification/models/custom_labels.csv +++ b/services/sounds_classifier/src/classification/models/custom_labels.csv @@ -1,12 +1,12 @@ -index,display_name -0,predatory_animals -1,non_predatory_animals -2,birds -3,fire -4,footsteps -5,insects -6,screaming -7,shotgun -8,stormy_weather -9,streaming_water -10,vehicle +index,display_name +0,predatory_animals +1,non_predatory_animals +2,birds +3,fire +4,footsteps +5,insects +6,screaming +7,shotgun +8,stormy_weather +9,streaming_water +10,vehicle diff --git a/services/sounds/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib b/services/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib similarity index 100% rename from services/sounds/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib rename to services/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib diff --git a/services/sounds/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib.meta.json b/services/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib.meta.json similarity index 96% rename from services/sounds/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib.meta.json rename to services/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib.meta.json index b4ba52c84..fc380654f 100644 --- a/services/sounds/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib.meta.json +++ b/services/sounds_classifier/src/classification/models/head/head_cnn14_rf.joblib.meta.json @@ -1,23 +1,23 @@ -{ - "class_order": [ - "predatory_animals", - "non_predatory_animals", - "birds", - "fire", - "footsteps", - "insects", - "screaming", - "shotgun", - "stormy_weather", - "streaming_water", - "vehicle" - ], - "seed": 42, - "test_size": 0.2, - "train_dir": "C:\\Users\\user1\\Desktop\\programming\\kamatech\\AgCloud\\AgCloud\\classification\\data\\train", - "checkpoint": "C:\\Users\\user1\\panns_data\\Cnn14_mAP=0.431.pth", - "device": "cpu", - "backbone": "cnn14", - "embedding_dim": 2048, - "head_type": "rf" +{ + "class_order": [ + "predatory_animals", + "non_predatory_animals", + "birds", + "fire", + "footsteps", + "insects", + "screaming", + "shotgun", + "stormy_weather", + "streaming_water", + "vehicle" + ], + "seed": 42, + "test_size": 0.2, + "train_dir": "C:\\Users\\user1\\Desktop\\programming\\kamatech\\AgCloud\\AgCloud\\classification\\data\\train", + "checkpoint": "C:\\Users\\user1\\panns_data\\Cnn14_mAP=0.431.pth", + "device": "cpu", + "backbone": "cnn14", + "embedding_dim": 2048, + "head_type": "rf" } \ No newline at end of file diff --git a/services/sounds_classifier/src/classification/scripts/__init__.py b/services/sounds_classifier/src/classification/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/sounds/sounds_classifier/src/classification/scripts/alerts.py b/services/sounds_classifier/src/classification/scripts/alerts.py similarity index 58% rename from services/sounds/sounds_classifier/src/classification/scripts/alerts.py rename to services/sounds_classifier/src/classification/scripts/alerts.py index 77b552842..bccf69048 100644 --- a/services/sounds/sounds_classifier/src/classification/scripts/alerts.py +++ b/services/sounds_classifier/src/classification/scripts/alerts.py @@ -1,110 +1,183 @@ -import json -import time -import logging -from typing import Optional, Dict, Any - -from confluent_kafka import Producer, KafkaException - -LOGGER = logging.getLogger("audio_cls.alerts") -if not LOGGER.handlers: - # Minimal console handler if none configured by the app - h = logging.StreamHandler() - fmt = logging.Formatter("[%(levelname)s] %(name)s: %(message)s") - h.setFormatter(fmt) - LOGGER.addHandler(h) -LOGGER.setLevel(logging.INFO) - -# Cache one Producer per brokers string -_producer_cache: dict[str, Producer] = {} - -def _get_producer(brokers: str) -> Producer: - """ - Lazily create and cache a Kafka Producer for the given brokers string. - Do NOT load env here; configuration is passed by the caller (service). - """ - p = _producer_cache.get(brokers) - if p is not None: - return p - - conf = { - "bootstrap.servers": brokers, - "queue.buffering.max.ms": 5, # small batching for low latency - "message.timeout.ms": 5000, # fail fast - "socket.keepalive.enable": True, - "api.version.request": True, - } - try: - p = Producer(conf) - except KafkaException as e: - LOGGER.error("Failed to initialize Kafka producer (brokers=%s): %s", brokers, e) - raise - _producer_cache[brokers] = p - return p - -def _delivery_report(err, msg): - if err is not None: - LOGGER.warning("Kafka delivery failed: %s", err) - else: - LOGGER.info( - "Kafka delivered: topic=%s partition=%s offset=%s", - msg.topic(), msg.partition(), msg.offset() - ) - -def send_alert( - *, - brokers: str, - topic: str, - label: str, - probs: Dict[str, float], - meta: Optional[Dict[str, Any]] = None, -) -> bool: - """ - Send a JSON alert to Kafka. Returns True if enqueued+delivered (within flush timeout), - False on immediate failure. Delivery problems are logged via _delivery_report. - """ - payload = { - "label": label, - "probs": probs, - "meta": meta or {}, - "ts": int(time.time() * 1000), - } - try: - p = _get_producer(brokers) - p.produce(topic=topic, value=json.dumps(payload).encode("utf-8"), callback=_delivery_report) - # Serve delivery callbacks; flush returns number of undelivered messages (0 == success) - p.poll(0) - # undelivered = p.flush(5) - # if undelivered != 0: - # LOGGER.warning("Kafka flush returned %s undelivered message(s)", undelivered) - # return False - return True - except KafkaException as e: - LOGGER.error("Kafka exception while producing: %s", e) - return False - except BufferError as e: - LOGGER.error("Kafka local queue full: %s", e) - return False - except Exception as e: - LOGGER.error("Kafka produce error: %s", e) - return False - -# ---- Backwards compatibility shim ---- -def send_kafka_alert(file_path: str, label: str, prob: float) -> bool: - """ - Legacy helper kept for backward compatibility. Reads brokers/topic from env - ONLY if caller insists on using this function. Prefer send_alert(...). - """ - import os # local import to avoid env dependency on module load - brokers = os.getenv("KAFKA_BROKERS") or os.getenv("KAFKA_BROKER", "localhost:9092") - topic = os.getenv("ALERTS_TOPIC") or os.getenv("KAFKA_ALERTS_TOPIC", "alerts") - - payload_probs = {label: float(prob)} - meta = {"file_path": file_path, "source": "legacy_send_kafka_alert"} - - return send_alert( - brokers=brokers, - topic=topic, - label=label, - probs=payload_probs, - meta=meta, - ) +import json +import time +import logging +from typing import Optional, Dict, Any +from confluent_kafka import Producer, KafkaException +import uuid +from datetime import datetime, timezone + +LOGGER = logging.getLogger("audio_cls.alerts") +if not LOGGER.handlers: + # Minimal console handler if none configured by the app + h = logging.StreamHandler() + fmt = logging.Formatter("[%(levelname)s] %(name)s: %(message)s") + h.setFormatter(fmt) + LOGGER.addHandler(h) +LOGGER.setLevel(logging.INFO) + +# Cache one Producer per brokers string +_producer_cache: dict[str, Producer] = {} + +def _get_producer(brokers: str) -> Producer: + """ + Lazily create and cache a Kafka Producer for the given brokers string. + Do NOT load env here; configuration is passed by the caller (service). + """ + p = _producer_cache.get(brokers) + if p is not None: + return p + + conf = { + "bootstrap.servers": brokers, + "queue.buffering.max.ms": 5, # small batching for low latency + "message.timeout.ms": 5000, # fail fast + "socket.keepalive.enable": True, + "api.version.request": True, + } + try: + p = Producer(conf) + except KafkaException as e: + LOGGER.error("Failed to initialize Kafka producer (brokers=%s): %s", brokers, e) + raise + _producer_cache[brokers] = p + return p + +def _delivery_report(err, msg): + if err is not None: + LOGGER.warning("Kafka delivery failed: %s", err) + else: + LOGGER.info( + "Kafka delivered: topic=%s partition=%s offset=%s", + msg.topic(), msg.partition(), msg.offset() + ) + +def _iso_utc(dt: datetime) -> str: + return dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + +def send_alert( + *, + brokers: str, + topic: str, + label: str, + probs: Dict[str, float], + meta: Optional[Dict[str, Any]] = None, +) -> bool: + """ + Send a JSON alert to Kafka. Returns True if enqueued+delivered (within flush timeout), + False on immediate failure. Delivery problems are logged via _delivery_report. + """ + payload = { + "label": label, + "probs": probs, + "meta": meta or {}, + "ts": int(time.time() * 1000), + } + try: + p = _get_producer(brokers) + p.produce(topic=topic, value=json.dumps(payload).encode("utf-8"), callback=_delivery_report) + # Serve delivery callbacks; flush returns number of undelivered messages (0 == success) + p.poll(0) + # undelivered = p.flush(5) + # if undelivered != 0: + # LOGGER.warning("Kafka flush returned %s undelivered message(s)", undelivered) + # return False + return True + except KafkaException as e: + LOGGER.error("Kafka exception while producing: %s", e) + return False + except BufferError as e: + LOGGER.error("Kafka local queue full: %s", e) + return False + except Exception as e: + LOGGER.error("Kafka produce error: %s", e) + return False + +# ---- Backwards compatibility shim ---- +def send_kafka_alert(file_path: str, label: str, prob: float) -> bool: + """ + Legacy helper kept for backward compatibility. Reads brokers/topic from env + ONLY if caller insists on using this function. Prefer send_alert(...). + """ + import os # local import to avoid env dependency on module load + brokers = os.getenv("KAFKA_BROKERS") or os.getenv("KAFKA_BROKER", "localhost:9092") + topic = os.getenv("ALERTS_TOPIC") or os.getenv("KAFKA_ALERTS_TOPIC", "alerts") + + payload_probs = {label: float(prob)} + meta = {"file_path": file_path, "source": "legacy_send_kafka_alert"} + + return send_alert( + brokers=brokers, + topic=topic, + label=label, + probs=payload_probs, + meta=meta, + ) + +# ---- Structured alert with strict required fields ---- +REQUIRED_FIELDS = ("alert_id", "alert_type", "device_id", "started_at") + +def send_structured_alert( + *, + brokers: str, + topic: str = "alerts", + alert_type: str, + device_id: str, + started_at: str, + ended_at: Optional[str] = None, + confidence: Optional[float] = None, + severity: Optional[int] = None, + area: Optional[str] = None, + lat: Optional[float] = None, + lon: Optional[float] = None, + image_url: Optional[str] = None, + vod: Optional[str] = None, + hls: Optional[str] = None, + meta: Optional[Dict[str, Any]] = None, + alert_id: Optional[str] = None, + message_key: Optional[str] = None, +) -> bool: + """ + Send alert JSON to Kafka in the required schema. + Required: alert_id, alert_type, device_id, started_at (ISO-8601 Z). + Optional fields are included ONLY if explicitly provided (no defaults/guesses). + """ + payload: Dict[str, Any] = { + "alert_id": alert_id or str(uuid.uuid4()), + "alert_type": alert_type, + "device_id": device_id, + "started_at": started_at, + } + + # Append optional fields IFF provided (no guessing) + if ended_at: payload["ended_at"] = ended_at + if confidence is not None: payload["confidence"] = float(confidence) + if severity is not None: payload["severity"] = int(severity) + if area: payload["area"] = area + if lat is not None: payload["lat"] = float(lat) + if lon is not None: payload["lon"] = float(lon) + if image_url: payload["image_url"] = image_url + if vod: payload["vod"] = vod + if hls: payload["hls"] = hls + if meta is not None: payload["meta"] = meta + + missing = [f for f in REQUIRED_FIELDS if f not in payload or payload[f] in (None, "")] + if missing: + LOGGER.error("Structured alert missing required fields: %s", missing) + return False + + try: + p = _get_producer(brokers) + p.produce( + topic=topic, + value=json.dumps(payload).encode("utf-8"), + key=(message_key.encode("utf-8") if isinstance(message_key, str) else None), + callback=_delivery_report + ) + p.poll(0) + return True + except KafkaException as e: + LOGGER.error("Kafka exception while producing structured alert: %s", e) + return False + except Exception as e: + LOGGER.error("Kafka produce error (structured alert): %s", e) + return False \ No newline at end of file diff --git a/services/sounds/sounds_classifier/src/classification/scripts/classify.py b/services/sounds_classifier/src/classification/scripts/classify.py similarity index 62% rename from services/sounds/sounds_classifier/src/classification/scripts/classify.py rename to services/sounds_classifier/src/classification/scripts/classify.py index cc8d06dc9..45a13076e 100644 --- a/services/sounds/sounds_classifier/src/classification/scripts/classify.py +++ b/services/sounds_classifier/src/classification/scripts/classify.py @@ -1,301 +1,419 @@ -from __future__ import annotations - -import logging -import os -import tempfile -from pathlib import Path -import time -from typing import Dict, List, Optional, Tuple, Any -from panns_inference import AudioTagging -import numpy as np -import joblib - -from minio import Minio -from minio.error import S3Error - -from classification.core.model_io import ( - SAMPLE_RATE, - _to_numpy, - load_audio, # returns 1-D float32 mono @ SAMPLE_RATE - # segment_waveform, # returns List[np.ndarray] after our fix - segment_waveform_2d_view, - aggregate_matrix, -) -from classification.backbones.cnn14 import load_cnn14_model, run_cnn14_embedding, run_cnn14_embeddings_batch -from classification.scripts import alerts - -# ----------------------------- -# Environment configuration -# ----------------------------- -DEVICE = os.getenv("DEVICE", "cpu").strip().lower() -BACKBONE = os.getenv("BACKBONE", "cnn14").strip().lower() - -CHECKPOINT = os.getenv("CHECKPOINT") or "" -CHECKPOINT_URL = os.getenv("CHECKPOINT_URL") or "" - -HEAD_PATH = os.getenv("HEAD") or "" # joblib path -LABELS_CSV = os.getenv("LABELS_CSV") or "" # optional (if head has classes_, not needed) - -WINDOW_SEC = float(os.getenv("WINDOW_SEC", "2.0")) -HOP_SEC = float(os.getenv("HOP_SEC", "0.5")) -PAD_LAST = os.getenv("PAD_LAST", "true").strip().lower() in ("1", "true", "yes", "on") -AGG = os.getenv("AGG", "mean").strip().lower() # "mean" | "max" - -UNKNOWN_THRESHOLD = float(os.getenv("UNKNOWN_THRESHOLD", "0.55")) - -MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000") -MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "minio") -MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "minio123") -MINIO_SECURE = os.getenv("MINIO_SECURE", "false").strip().lower() in ("1", "true", "yes", "on") - -ALLOWED_BUCKETS: List[str] = [b.strip() for b in os.getenv("ALLOWED_BUCKETS", "").split(",") if b.strip()] -ALLOWED_CONTENT_TYPES: List[str] = [t.strip() for t in os.getenv( - "ALLOWED_CONTENT_TYPES", - "audio/wav,audio/x-wav,audio/mpeg,audio/flac,audio/ogg,audio/mp4" -).split(",") if t.strip()] -MAX_BYTES = int(os.getenv("MAX_BYTES", str(50 * 1024 * 1024))) - -KAFKA_BROKERS = os.getenv("KAFKA_BROKERS", "kafka:9092") -ALERTS_TOPIC = os.getenv("ALERTS_TOPIC", "dev-robot-alerts") - -# ----------------------------- -# Lazy runtime (model/head/labels) -# ----------------------------- -class _Runtime: - model = None # CNN14 backbone - head = None # sklearn pipeline with predict_proba - classes: List[str] = [] # class names aligned to head output - -R = _Runtime() - -_MINIO_CLIENT = None - -def _get_minio(): - global _MINIO_CLIENT - if _MINIO_CLIENT is None: - _MINIO_CLIENT = Minio( - MINIO_ENDPOINT, access_key=MINIO_ACCESS_KEY, secret_key=MINIO_SECRET_KEY, secure=MINIO_SECURE - ) - return _MINIO_CLIENT - -def _load_backbone_once() -> None: - if R.model is not None: - return - if BACKBONE != "cnn14": - raise RuntimeError(f"Only BACKBONE=cnn14 is supported in this service, got {BACKBONE}") - # load_cnn14_model internally handles checkpoint/url (your impl) - R.model = load_cnn14_model(CHECKPOINT or None, device=DEVICE) - -def _load_head_once() -> None: - if R.head is not None: - return - if not HEAD_PATH: - raise RuntimeError("HEAD env var is required (path to joblib head)") - R.head = joblib.load(HEAD_PATH) - if not hasattr(R.head, "predict_proba"): - raise RuntimeError("HEAD must expose predict_proba(X) and classes_)") - - # 1) try labels from CSV if provided (most robust for production) - labels_csv = os.getenv("LABELS_CSV") or "" - if labels_csv: - from classification.core.model_io import load_labels_from_csv - labels = load_labels_from_csv(labels_csv) - if not labels: - raise RuntimeError(f"Labels CSV is empty or unreadable: {labels_csv}") - R.classes = labels - return - - # 2) else, try meta.json next to HEAD (HEAD_META env or HEAD+'.meta.json') - head_meta = os.getenv("HEAD_META") or (HEAD_PATH + ".meta.json") - labels_from_meta = [] - try: - if os.path.exists(head_meta): - import json - with open(head_meta, "r", encoding="utf-8") as f: - meta = json.load(f) - if isinstance(meta.get("class_order"), list) and len(meta["class_order"]) > 0: - labels_from_meta = [str(x) for x in meta["class_order"]] - except Exception as e: - print(f"⚠️ Warning: failed to parse HEAD meta: {e}") - - # 3) reconcile with head.classes_ - head_classes = list(getattr(R.head, "classes_", [])) - if labels_from_meta: - # if head.classes_ are [0..N-1], we map by index - if all(isinstance(c, (int, np.integer)) for c in head_classes): - if len(head_classes) != len(labels_from_meta): - raise RuntimeError( - f"Meta class_order length ({len(labels_from_meta)}) != head.classes_ length ({len(head_classes)})" - ) - R.classes = labels_from_meta - return - # else: if head.classes_ already hold real names, prefer them - R.classes = [str(c) for c in head_classes] if head_classes else labels_from_meta - return - - # 4) fallback to head.classes_ as strings - if head_classes: - R.classes = [str(c) for c in head_classes] - return - - # 5) no labels source found - raise RuntimeError( - "No labels source found. Provide LABELS_CSV, or HEAD_META with class_order, " - "or ensure the head exposes string class names via classes_." - ) - -# ----------------------------- -# Embedding/inference helpers -# ----------------------------- - -def _aggregate_probs(per_window_probs: np.ndarray) -> np.ndarray: - """ - Aggregate per-window class probabilities to a single clip-level vector. - Supports mean|max; returns 1-D float32. - """ - if per_window_probs.ndim != 2: - raise ValueError("expected shape (num_windows, num_classes)") - if per_window_probs.shape[0] == 0: - return np.zeros((per_window_probs.shape[1],), dtype=np.float32) - v = aggregate_matrix(per_window_probs, mode=AGG) - # When AGG=max, v might be logits-like — but we trained on probs, so it is already probabilities. - # If needed: apply softmax here. For a calibrated head (sklearn) it's already in [0,1]. - return v.astype(np.float32, copy=False) - -# ----------------------------- -# Public API for service -# ----------------------------- - -# Create a dedicated logger for performance metrics -perf_logger = logging.getLogger("audio_cls.perf") -perf_logger.setLevel(logging.INFO) -if not perf_logger.handlers: - h = logging.StreamHandler() - fmt = logging.Formatter("[%(asctime)s] [PERF] %(message)s", "%Y-%m-%d %H:%M:%S") - h.setFormatter(fmt) - perf_logger.addHandler(h) - -def classify_file( - path: str, - pann_model: Optional[AudioTagging] = None, - sk_pipeline: Optional[Any] = None -) -> Dict[str, object]: - t0 = time.perf_counter() - if sk_pipeline is None: - _load_head_once() - if pann_model is None: - _load_backbone_once() - - wav = np.array(load_audio(path, SAMPLE_RATE), dtype=np.float32, copy=True, order="C") - windows_2d = segment_waveform_2d_view( - wav, SAMPLE_RATE, window_sec=WINDOW_SEC, hop_sec=HOP_SEC, pad_last=PAD_LAST - ) - - num_windows = int(windows_2d.shape[0]) - if num_windows == 0: - result = { - "label": "another", - "probs": {c: 0.0 for c in R.classes}, - "pred_prob": 0.0, - "unknown_threshold": UNKNOWN_THRESHOLD, - "is_another": True, - "num_windows": 0, - "agg_mode": AGG, - "processing_ms": int((time.perf_counter() - t0) * 1000.0), - } - return result - - # Batch embeddings - if pann_model is not None: - win = np.array(windows_2d, dtype=np.float32, copy=True, order="C") - seg = pann_model.inference(win) - if isinstance(seg, dict): - seg = seg.get("embedding") - elif isinstance(seg, tuple) and len(seg) >= 2: - seg = seg[1] - seg = np.asarray(seg, dtype=np.float32) - if seg.ndim == 1: - seg = seg[None, :] - else: - win = np.array(windows_2d, dtype=np.float32, copy=True, order="C") - seg = run_cnn14_embeddings_batch(R.model, win, batch_size=32) - - # Head predict_proba - clf = sk_pipeline if sk_pipeline is not None else R.head - per_window_probs = np.asarray(clf.predict_proba(seg), dtype=np.float32) - - # Aggregate and finalize - agg_vec = _aggregate_probs(per_window_probs) - k = int(np.argmax(agg_vec)) - top_prob = float(agg_vec[k]) - top_label = R.classes[k] - final_label = top_label if top_prob >= UNKNOWN_THRESHOLD else "another" - probs = {cls: float(p) for cls, p in zip(R.classes, agg_vec)} - - processing_ms = int((time.perf_counter() - t0) * 1000.0) - - return { - "label": final_label, - "probs": probs, - "pred_prob": top_prob, - "unknown_threshold": UNKNOWN_THRESHOLD, - "is_another": (final_label == "another"), - "num_windows": num_windows, - "agg_mode": AGG, - "processing_ms": processing_ms, - } - -def run_classification_job( - *, - s3_bucket: str, - s3_key: str, - pann_model: Optional[AudioTagging] = None, - sk_pipeline: Optional[Any] = None -) -> Dict[str, object]: - """ - Download from MinIO → classify_file → (optional) write DB → (optional) Kafka alert. - Returns a dict with 'label' and 'probs'. - """ - _load_head_once() - if ALLOWED_BUCKETS and s3_bucket not in ALLOWED_BUCKETS: - raise RuntimeError(f"Bucket '{s3_bucket}' is not allowed") - - client = _get_minio() - - # stat & validate - try: - stat = client.stat_object(s3_bucket, s3_key) - except S3Error as e: - raise RuntimeError(f"S3 stat failed: {e}") from e - size = getattr(stat, "size", None) - ctype = getattr(stat, "content_type", "") or "" - if size and size > MAX_BYTES: - raise RuntimeError(f"Object too large: {size} > {MAX_BYTES}") - if ctype and ALLOWED_CONTENT_TYPES and ctype not in ALLOWED_CONTENT_TYPES: - raise RuntimeError(f"Unsupported content-type: {ctype}") - - # download to temp - suffix = Path(s3_key).suffix or ".wav" - fd, tmp_path = tempfile.mkstemp(prefix="audio_", suffix=suffix) - os.close(fd) - try: - client.fget_object(s3_bucket, s3_key, tmp_path) - - result = classify_file(tmp_path, pann_model=pann_model, sk_pipeline=sk_pipeline) - - if result["label"] != "another" and KAFKA_BROKERS and ALERTS_TOPIC: - alerts.send_alert( - brokers=KAFKA_BROKERS, - topic=ALERTS_TOPIC, - label=str(result["label"]), - probs=dict(result["probs"]), - meta={"bucket": s3_bucket, "key": s3_key}, - ) - - return result - - finally: - try: - os.remove(tmp_path) - except Exception: - pass +from __future__ import annotations + +import logging +import os +import tempfile +from pathlib import Path +import time +from typing import Dict, List, Optional, Tuple, Any +from panns_inference import AudioTagging +import numpy as np +import joblib + +from minio import Minio +from minio.error import S3Error +import re +from datetime import datetime, timezone +import uuid + +from classification.core.model_io import ( + SAMPLE_RATE, + _to_numpy, + load_audio, # returns 1-D float32 mono @ SAMPLE_RATE + segment_waveform_2d_view, + aggregate_matrix, +) +from classification.backbones.cnn14 import load_cnn14_model, run_cnn14_embeddings_batch +from classification.scripts import alerts + +# ----------------------------- +# Environment configuration +# ----------------------------- +DEVICE = os.getenv("DEVICE", "cpu").strip().lower() +BACKBONE = os.getenv("BACKBONE", "cnn14").strip().lower() + +CHECKPOINT = os.getenv("CHECKPOINT") or "" +CHECKPOINT_URL = os.getenv("CHECKPOINT_URL") or "" + +HEAD_PATH = os.getenv("HEAD") or "" # joblib path +LABELS_CSV = os.getenv("LABELS_CSV") or "" # optional (if head has classes_, not needed) + +WINDOW_SEC = float(os.getenv("WINDOW_SEC", "2.0")) +HOP_SEC = float(os.getenv("HOP_SEC", "0.5")) +PAD_LAST = os.getenv("PAD_LAST", "true").strip().lower() in ("1", "true", "yes", "on") +AGG = os.getenv("AGG", "mean").strip().lower() # "mean" | "max" + +UNKNOWN_THRESHOLD = float(os.getenv("UNKNOWN_THRESHOLD", "0.55")) + +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000") +MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "minio") +MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "minio123") +MINIO_SECURE = os.getenv("MINIO_SECURE", "false").strip().lower() in ("1", "true", "yes", "on") + +ALLOWED_BUCKETS: List[str] = [b.strip() for b in os.getenv("ALLOWED_BUCKETS", "sound").split(",") if b.strip()] +ALLOWED_CONTENT_TYPES: List[str] = [t.strip() for t in os.getenv( + "ALLOWED_CONTENT_TYPES", + "audio/wav,audio/x-wav,audio/mpeg,audio/flac,audio/ogg,audio/mp4" +).split(",") if t.strip()] +MAX_BYTES = int(os.getenv("MAX_BYTES", str(50 * 1024 * 1024))) + +KAFKA_BROKERS = os.getenv("KAFKA_BROKERS", "kafka:9092") +ALERTS_TOPIC = os.getenv("ALERTS_TOPIC", "alerts") + +# ----------------------------- +# Lazy runtime (model/head/labels) +# ----------------------------- +class _Runtime: + model = None # CNN14 backbone + head = None # sklearn pipeline with predict_proba + classes: List[str] = [] # class names aligned to head output + +R = _Runtime() + +_MINIO_CLIENT = None + +def _get_minio(): + global _MINIO_CLIENT + if _MINIO_CLIENT is None: + _MINIO_CLIENT = Minio( + MINIO_ENDPOINT, access_key=MINIO_ACCESS_KEY, secret_key=MINIO_SECRET_KEY, secure=MINIO_SECURE + ) + return _MINIO_CLIENT + +_TS_PATTERNS = ( + # ISO-like with Z or without Z, with or without 'T' + r"(?P\d{4}-?\d{2}-?\d{2}[T ]?\d{2}:?\d{2}:?\d{2}Z?)", + # Compact: YYYYMMDDTHHMMSSZ or YYYYMMDDHHMMSS + r"(?P\d{8}T?\d{6}Z?)", + # Epoch seconds or millis + r"(?P\d{10}|\d{13})", +) + +def _parse_started_at_from_token(token: str) -> Optional[str]: + """Return ISO8601 UTC Z string if token looks like a timestamp; else None.""" + t = token.strip() + # epoch + if re.fullmatch(r"\d{13}", t): + dt = datetime.fromtimestamp(int(t)/1000.0, tz=timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + if re.fullmatch(r"\d{10}", t): + dt = datetime.fromtimestamp(int(t), tz=timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + # compact YYYYMMDD[ T]?HHMMSS[Z]? + m = re.fullmatch(r"(\d{8})T?(\d{6})Z?", t) + if m: + d, h = m.groups() + dt = datetime.strptime(d + h, "%Y%m%d%H%M%S").replace(tzinfo=timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + # ISO-ish (allow missing separators) + # normalize: keep only digits and Z, then rebuild + if re.fullmatch(r"\d{4}-?\d{2}-?\d{2}[T ]?\d{2}:?\d{2}:?\d{2}Z?", t): + # try a few formats + for fmt in ("%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%SZ", + "%Y-%m-%dT%H:%M:%S", "%Y%m%dT%H%M%SZ", "%Y%m%d%H%M%S"): + try: + if t.endswith("Z") and fmt.endswith("Z"): + dt = datetime.strptime(t.replace(" ", "T"), fmt).replace(tzinfo=timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + if not t.endswith("Z") and not fmt.endswith("Z"): + dt = datetime.strptime(t.replace(" ", "T").replace("-", "").replace(":", ""), "%Y%m%dT%H%M%S").replace(tzinfo=timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + except Exception: + pass + return None + +def _extract_device_and_started_at_from_key(s3_key: str) -> Tuple[Optional[str], Optional[str]]: + """ + Expect filenames like 'sensorId_timestamp.*' (timestamp token can be in a few formats). + Return (device_id, started_at_isoZ) or (None, None) if not confident. + """ + name = Path(s3_key).name + m = re.match(r"(?P[^_/]+)_(?P[^_.]+)", name) + if not m: + return None, None + device_id = m.group("dev").strip() + ts_token = m.group("ts").strip() + started_at = _parse_started_at_from_token(ts_token) + if not device_id or not started_at: + return None, None + return device_id, started_at + +def _load_backbone_once() -> None: + if R.model is not None: + return + if BACKBONE != "cnn14": + raise RuntimeError(f"Only BACKBONE=cnn14 is supported in this service, got {BACKBONE}") + # load_cnn14_model internally handles checkpoint/url (your impl) + R.model = load_cnn14_model(CHECKPOINT or None, device=DEVICE) + +def _load_head_once() -> None: + if R.head is not None: + return + if not HEAD_PATH: + raise RuntimeError("HEAD env var is required (path to joblib head)") + R.head = joblib.load(HEAD_PATH) + if not hasattr(R.head, "predict_proba"): + raise RuntimeError("HEAD must expose predict_proba(X) and classes_)") + + # 1) try labels from CSV if provided (most robust for production) + labels_csv = os.getenv("LABELS_CSV") or "" + if labels_csv: + from classification.core.model_io import load_labels_from_csv + labels = load_labels_from_csv(labels_csv) + if not labels: + raise RuntimeError(f"Labels CSV is empty or unreadable: {labels_csv}") + R.classes = labels + return + + # 2) else, try meta.json next to HEAD (HEAD_META env or HEAD+'.meta.json') + head_meta = os.getenv("HEAD_META") or (HEAD_PATH + ".meta.json") + labels_from_meta = [] + try: + if os.path.exists(head_meta): + import json + with open(head_meta, "r", encoding="utf-8") as f: + meta = json.load(f) + if isinstance(meta.get("class_order"), list) and len(meta["class_order"]) > 0: + labels_from_meta = [str(x) for x in meta["class_order"]] + except Exception as e: + print(f"⚠️ Warning: failed to parse HEAD meta: {e}") + + # 3) reconcile with head.classes_ + head_classes = list(getattr(R.head, "classes_", [])) + if labels_from_meta: + # if head.classes_ are [0..N-1], we map by index + if all(isinstance(c, (int, np.integer)) for c in head_classes): + if len(head_classes) != len(labels_from_meta): + raise RuntimeError( + f"Meta class_order length ({len(labels_from_meta)}) != head.classes_ length ({len(head_classes)})" + ) + R.classes = labels_from_meta + return + # else: if head.classes_ already hold real names, prefer them + R.classes = [str(c) for c in head_classes] if head_classes else labels_from_meta + return + + # 4) fallback to head.classes_ as strings + if head_classes: + R.classes = [str(c) for c in head_classes] + return + + # 5) no labels source found + raise RuntimeError( + "No labels source found. Provide LABELS_CSV, or HEAD_META with class_order, " + "or ensure the head exposes string class names via classes_." + ) + +# ----------------------------- +# Embedding/inference helpers +# ----------------------------- + +def _aggregate_probs(per_window_probs: np.ndarray) -> np.ndarray: + """ + Aggregate per-window class probabilities to a single clip-level vector. + Supports mean|max; returns 1-D float32. + """ + if per_window_probs.ndim != 2: + raise ValueError("expected shape (num_windows, num_classes)") + if per_window_probs.shape[0] == 0: + return np.zeros((per_window_probs.shape[1],), dtype=np.float32) + v = aggregate_matrix(per_window_probs, mode=AGG) + # When AGG=max, v might be logits-like — but we trained on probs, so it is already probabilities. + # If needed: apply softmax here. For a calibrated head (sklearn) it's already in [0,1]. + return v.astype(np.float32, copy=False) + +# ----------------------------- +# Public API for service +# ----------------------------- + +# Create a dedicated logger for performance metrics +perf_logger = logging.getLogger("audio_cls.perf") +perf_logger.setLevel(logging.INFO) +if not perf_logger.handlers: + h = logging.StreamHandler() + fmt = logging.Formatter("[%(asctime)s] [PERF] %(message)s", "%Y-%m-%d %H:%M:%S") + h.setFormatter(fmt) + perf_logger.addHandler(h) + +def classify_file( + path: str, + pann_model: Optional[AudioTagging] = None, + sk_pipeline: Optional[Any] = None +) -> Dict[str, object]: + t0 = time.perf_counter() + if sk_pipeline is None: + _load_head_once() + if pann_model is None: + _load_backbone_once() + + wav = np.array(load_audio(path, SAMPLE_RATE), dtype=np.float32, copy=True, order="C") + windows_2d = segment_waveform_2d_view( + wav, SAMPLE_RATE, window_sec=WINDOW_SEC, hop_sec=HOP_SEC, pad_last=PAD_LAST + ) + + num_windows = int(windows_2d.shape[0]) + if num_windows == 0: + result = { + "label": "another", + "probs": {c: 0.0 for c in R.classes}, + "pred_prob": 0.0, + "unknown_threshold": UNKNOWN_THRESHOLD, + "is_another": True, + "num_windows": 0, + "agg_mode": AGG, + "processing_ms": int((time.perf_counter() - t0) * 1000.0), + } + return result + + # Batch embeddings + if pann_model is not None: + win = np.array(windows_2d, dtype=np.float32, copy=True, order="C") + seg = pann_model.inference(win) + if isinstance(seg, dict): + seg = seg.get("embedding") + elif isinstance(seg, tuple) and len(seg) >= 2: + seg = seg[1] + seg = np.asarray(seg, dtype=np.float32) + if seg.ndim == 1: + seg = seg[None, :] + else: + win = np.array(windows_2d, dtype=np.float32, copy=True, order="C") + seg = run_cnn14_embeddings_batch(R.model, win, batch_size=32) + + # Head predict_proba + clf = sk_pipeline if sk_pipeline is not None else R.head + per_window_probs = np.asarray(clf.predict_proba(seg), dtype=np.float32) + + # Aggregate and finalize + agg_vec = _aggregate_probs(per_window_probs) + k = int(np.argmax(agg_vec)) + top_prob = float(agg_vec[k]) + top_label = R.classes[k] + final_label = top_label if top_prob >= UNKNOWN_THRESHOLD else "another" + probs = {cls: float(p) for cls, p in zip(R.classes, agg_vec)} + + processing_ms = int((time.perf_counter() - t0) * 1000.0) + + return { + "label": final_label, + "probs": probs, + "pred_prob": top_prob, + "unknown_threshold": UNKNOWN_THRESHOLD, + "is_another": (final_label == "another"), + "num_windows": num_windows, + "agg_mode": AGG, + "processing_ms": processing_ms, + } + +def run_classification_job( + *, + s3_bucket: str, + s3_key: str, + pann_model: Optional[AudioTagging] = None, + sk_pipeline: Optional[Any] = None +) -> Dict[str, object]: + """ + Download from MinIO → classify_file → (optional) Kafka alert. + Returns a dict with 'label', 'probs, and alert send status. + """ + _load_head_once() + if ALLOWED_BUCKETS and s3_bucket not in ALLOWED_BUCKETS: + raise RuntimeError(f"Bucket '{s3_bucket}' is not allowed") + + client = _get_minio() + + # stat & validate + try: + stat = client.stat_object(s3_bucket, s3_key) + except S3Error as e: + raise RuntimeError(f"S3 stat failed: {e}") from e + size = getattr(stat, "size", None) + ctype = getattr(stat, "content_type", "") or "" + if size and size > MAX_BYTES: + raise RuntimeError(f"Object too large: {size} > {MAX_BYTES}") + if ctype and ALLOWED_CONTENT_TYPES and ctype not in ALLOWED_CONTENT_TYPES: + raise RuntimeError(f"Unsupported content-type: {ctype}") + + # download to temp + suffix = Path(s3_key).suffix or ".wav" + fd, tmp_path = tempfile.mkstemp(prefix="audio_", suffix=suffix) + os.close(fd) + try: + client.fget_object(s3_bucket, s3_key, tmp_path) + result = classify_file(tmp_path, pann_model=pann_model, sk_pipeline=sk_pipeline) + # default alert flags + result.setdefault("sent_alert", False) + result.setdefault("alert_topic", None) + result.setdefault("alert_skip_reason", None) + if result.get("processing_ms") is not None: + try: + result["processing_ms"] = int(result["processing_ms"]) + except Exception: + pass + if result["label"] != "another" and KAFKA_BROKERS and ALERTS_TOPIC: + device_id, started_at = _extract_device_and_started_at_from_key(s3_key) + if device_id and started_at: + try: + label = str(result["label"]) + alert_type = f"suspicious_sound-{label}" + + severity = None + sev_map_env = os.getenv("ALERT_SEVERITY_MAP", "").strip() + if sev_map_env: + try: + _sev_map = __import__("json").loads(sev_map_env) + if isinstance(_sev_map, dict) and label in _sev_map: + _s = _sev_map[label] + if isinstance(_s, (int, np.integer)): + severity = int(_s) + except Exception: + pass + + confidence = float(result.get("pred_prob") or 0.0) + + meta = { + "bucket": s3_bucket, + "key": s3_key, + "processing_ms": result.get("processing_ms"), + } + if meta["processing_ms"] is None: + meta.pop("processing_ms") + message_key = f"{device_id}|{started_at}" + + ok = alerts.send_structured_alert( + brokers=KAFKA_BROKERS, + topic=ALERTS_TOPIC, + alert_type=alert_type, + device_id=device_id, + started_at=started_at, + confidence=confidence, + severity=severity, + meta=meta, + message_key=message_key, + ) + perf_logger.info("About to send alert: topic=%s key=%s type=%s", + ALERTS_TOPIC, message_key, alert_type) + if ok: + result["sent_alert"] = True + result["alert_topic"] = ALERTS_TOPIC + else: + perf_logger.warning("Alert send returned False (topic=%s key=%s)", ALERTS_TOPIC, s3_key) + result["alert_skip_reason"] = "kafka_produce_returned_false" + except Exception as e: + perf_logger.warning("Alert send failed: %s (key=%s)", e, s3_key) + result["alert_skip_reason"] = "kafka_exception" + else: + perf_logger.warning( + "Skip alert (missing device_id/started_at) for key=%s", s3_key + ) + result["alert_skip_reason"] = "missing_device_or_started_at" + elif result["label"] == "another": + result["alert_skip_reason"] = "label_is_another" + elif not KAFKA_BROKERS or not ALERTS_TOPIC: + result["alert_skip_reason"] = "missing_env_brokers_or_topic" + return result + finally: + try: + os.remove(tmp_path) + except Exception: + pass diff --git a/services/sounds/sounds_classifier/tests/conftest.py b/services/sounds_classifier/tests/conftest.py similarity index 97% rename from services/sounds/sounds_classifier/tests/conftest.py rename to services/sounds_classifier/tests/conftest.py index 10920f2c3..afb6a527e 100644 --- a/services/sounds/sounds_classifier/tests/conftest.py +++ b/services/sounds_classifier/tests/conftest.py @@ -1,56 +1,56 @@ -import sys -import pathlib -import os -import pytest - -# 1) Ensure "src" is on sys.path so `import classification...` works -# This walks up from tests/ to repo root and prepends /src -HERE = pathlib.Path(__file__).resolve() -ROOT = HERE -for _ in range(6): # walk up a few levels just in case - if (ROOT / "src").exists(): - sys.path.insert(0, str(ROOT / "src")) - break - ROOT = ROOT.parent - -# 2) Provide minimal, isolated env defaults for tests -@pytest.fixture(autouse=True) -def _isolate_env(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path): - # Core runtime defaults - monkeypatch.setenv("DEVICE", "cpu") - monkeypatch.setenv("BACKBONE", "cnn14") - - # Windowing / aggregation - monkeypatch.setenv("WINDOW_SEC", "2.0") - monkeypatch.setenv("HOP_SEC", "0.5") - monkeypatch.setenv("PAD_LAST", "true") - monkeypatch.setenv("AGG", "mean") - monkeypatch.setenv("UNKNOWN_THRESHOLD", "0.55") - - # Disable optional integrations by default - monkeypatch.setenv("KAFKA_BROKERS", "") - monkeypatch.delenv("WRITE_DB", raising=False) - monkeypatch.delenv("DB_URL", raising=False) - - # HEAD path (tests can override with real/mocked head if needed) - head_dir = tmp_path / "head" - head_dir.mkdir(parents=True, exist_ok=True) - monkeypatch.setenv("HEAD", str(head_dir / "dummy.joblib")) - # Let tests decide labels source; default to none - monkeypatch.setenv("LABELS_CSV", "") - monkeypatch.delenv("HEAD_META", raising=False) - - # MinIO defaults (won't be used unless explicitly mocked) - monkeypatch.setenv("MINIO_ENDPOINT", "minio:9000") - monkeypatch.setenv("MINIO_ACCESS_KEY", "minio") - monkeypatch.setenv("MINIO_SECRET_KEY", "minio123") - monkeypatch.setenv("MINIO_SECURE", "false") - - # Kafka alerts defaults - monkeypatch.setenv("ALERTS_TOPIC", "dev-robot-alerts") - - # Checkpoint defaults (tests typically mock loading; real file not needed) - monkeypatch.setenv("CHECKPOINT", str(tmp_path / "models" / "panns_data" / "Cnn14_mAP=0.431.pth")) - monkeypatch.delenv("CHECKPOINT_URL", raising=False) - - yield +import sys +import pathlib +import os +import pytest + +# 1) Ensure "src" is on sys.path so `import classification...` works +# This walks up from tests/ to repo root and prepends /src +HERE = pathlib.Path(__file__).resolve() +ROOT = HERE +for _ in range(6): # walk up a few levels just in case + if (ROOT / "src").exists(): + sys.path.insert(0, str(ROOT / "src")) + break + ROOT = ROOT.parent + +# 2) Provide minimal, isolated env defaults for tests +@pytest.fixture(autouse=True) +def _isolate_env(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path): + # Core runtime defaults + monkeypatch.setenv("DEVICE", "cpu") + monkeypatch.setenv("BACKBONE", "cnn14") + + # Windowing / aggregation + monkeypatch.setenv("WINDOW_SEC", "2.0") + monkeypatch.setenv("HOP_SEC", "0.5") + monkeypatch.setenv("PAD_LAST", "true") + monkeypatch.setenv("AGG", "mean") + monkeypatch.setenv("UNKNOWN_THRESHOLD", "0.55") + + # Disable optional integrations by default + monkeypatch.setenv("KAFKA_BROKERS", "") + monkeypatch.delenv("WRITE_DB", raising=False) + monkeypatch.delenv("DB_URL", raising=False) + + # HEAD path (tests can override with real/mocked head if needed) + head_dir = tmp_path / "head" + head_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("HEAD", str(head_dir / "dummy.joblib")) + # Let tests decide labels source; default to none + monkeypatch.setenv("LABELS_CSV", "") + monkeypatch.delenv("HEAD_META", raising=False) + + # MinIO defaults (won't be used unless explicitly mocked) + monkeypatch.setenv("MINIO_ENDPOINT", "minio:9000") + monkeypatch.setenv("MINIO_ACCESS_KEY", "minio") + monkeypatch.setenv("MINIO_SECRET_KEY", "minio123") + monkeypatch.setenv("MINIO_SECURE", "false") + + # Kafka alerts defaults + monkeypatch.setenv("ALERTS_TOPIC", "dev-robot-alerts") + + # Checkpoint defaults (tests typically mock loading; real file not needed) + monkeypatch.setenv("CHECKPOINT", str(tmp_path / "models" / "panns_data" / "Cnn14_mAP=0.431.pth")) + monkeypatch.delenv("CHECKPOINT_URL", raising=False) + + yield diff --git a/services/sounds/sounds_classifier/tests/test_alerts.py b/services/sounds_classifier/tests/test_alerts.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_alerts.py rename to services/sounds_classifier/tests/test_alerts.py index 97989ff34..f2dd89850 100644 --- a/services/sounds/sounds_classifier/tests/test_alerts.py +++ b/services/sounds_classifier/tests/test_alerts.py @@ -1,161 +1,161 @@ -import json -import types -import os -import pytest - -import classification.scripts.alerts as alerts - - -# ---------- test helpers ---------- -class DummyMsg: - def __init__(self, topic, partition, offset): - self._t = topic - self._p = partition - self._o = offset - def topic(self): return self._t - def partition(self): return self._p - def offset(self): return self._o - - -class DummyProducer: - def __init__(self, *a, **kw): - self._queue = [] - self._flushed = 0 - def produce(self, *, topic, value, callback=None): - self._queue.append((topic, value)) - if callback: - callback(None, DummyMsg(topic, 0, len(self._queue) - 1)) - def poll(self, _): # api compatibility - pass - def flush(self, timeout): - self._flushed += 1 - return 0 # 0 undelivered → success - - -@pytest.fixture(autouse=True) -def _clear_cache(): - alerts._producer_cache.clear() - yield - alerts._producer_cache.clear() - - -# ---------- tests ---------- -def test_send_alert_success_and_payload(monkeypatch): - dp = DummyProducer() - monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: dp, raising=True) - - ok = alerts.send_alert( - brokers="kafka:9092", - topic="dev-robot-alerts", - label="car", - probs={"car": 0.9, "dog": 0.1}, - meta={"bucket": "b", "key": "k.wav"}, - ) - assert ok is True - # Verify one message with proper JSON payload - assert len(dp._queue) == 1 - topic, raw = dp._queue[0] - assert topic == "dev-robot-alerts" - payload = json.loads(raw.decode("utf-8")) - assert payload["label"] == "car" - assert payload["probs"]["car"] == pytest.approx(0.9) - assert isinstance(payload["ts"], int) - - -def test_producer_cache_reuse(monkeypatch): - calls = {"made": 0} - def mk(*_a, **_k): - calls["made"] += 1 - return DummyProducer() - monkeypatch.setattr(alerts, "Producer", mk, raising=True) - - # First call creates a producer - assert alerts.send_alert(brokers="kafka:9092", topic="t", label="x", probs={"x": 1.0}) - # Second call should reuse cache → no extra Producer() - assert alerts.send_alert(brokers="kafka:9092", topic="t", label="y", probs={"y": 1.0}) - assert calls["made"] == 1 - assert len(alerts._producer_cache) == 1 - - -def test_send_alert_is_non_blocking_and_returns_true(monkeypatch, caplog): - class ProducerNoFlush(DummyProducer): - def flush(self, timeout): - return 1 # would indicate undelivered if we used it, but we don't block now - - monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: ProducerNoFlush(), raising=True) - - ok = alerts.send_alert(brokers="b:9092", topic="t", label="x", probs={"x": 1.0}) - assert ok is True - assert any("Kafka delivered" in rec.message or "Kafka" in rec.message for rec in caplog.records) - - -def test_send_alert_kafka_exception_on_init(monkeypatch, caplog): - class DummyKafkaEx(alerts.KafkaException): # use the module's class - pass - def boom(*_a, **_k): - raise DummyKafkaEx("init failed") - monkeypatch.setattr(alerts, "Producer", boom, raising=True) - - ok = alerts.send_alert(brokers="bad:9092", topic="t", label="x", probs={"x": 1.0}) - assert ok is False - assert any("exception while producing" in rec.message or "Failed to initialize Kafka" in rec.message - for rec in caplog.records) - - -def test_send_alert_buffer_error(monkeypatch, caplog): - class BufErrProducer(DummyProducer): - def produce(self, *a, **k): - raise BufferError("local queue full") - monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: BufErrProducer(), raising=True) - - ok = alerts.send_alert(brokers="b:9092", topic="t", label="x", probs={"x": 1.0}) - assert ok is False - assert any("local queue full" in rec.message for rec in caplog.records) - - -def test_send_alert_runtime_error(monkeypatch, caplog): - class BoomProducer(DummyProducer): - def produce(self, *a, **k): - raise RuntimeError("produce exploded") - monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: BoomProducer(), raising=True) - - ok = alerts.send_alert(brokers="b:9092", topic="t", label="x", probs={"x": 1.0}) - assert ok is False - assert any("produce error" in rec.message for rec in caplog.records) - - -def test_legacy_send_kafka_alert_reads_env(monkeypatch): - # Check that env is read and forwarded into send_alert - captured = {} - def fake_send_alert(*, brokers, topic, label, probs, meta): - captured["brokers"] = brokers - captured["topic"] = topic - captured["label"] = label - captured["probs"] = probs - captured["meta"] = meta - return True - - monkeypatch.setenv("KAFKA_BROKERS", "env-b:9092") - monkeypatch.setenv("ALERTS_TOPIC", "env-topic") - monkeypatch.setattr(alerts, "send_alert", fake_send_alert, raising=True) - - ok = alerts.send_kafka_alert(file_path="/tmp/a.wav", label="car", prob=0.7) - assert ok is True - assert captured["brokers"] == "env-b:9092" - assert captured["topic"] == "env-topic" - assert captured["label"] == "car" - assert captured["probs"] == {"car": 0.7} - assert captured["meta"]["file_path"] == "/tmp/a.wav" - - -def test__delivery_report_logs_ok_and_err(caplog): - # OK case - alerts._delivery_report(None, DummyMsg("t", 0, 1)) - # Error case - class Err: - def __str__(self): return "boom" - alerts._delivery_report(Err(), DummyMsg("t", 0, 2)) - - # We don't assert exact wording, just that both paths logged - assert any("Kafka delivered" in r.message for r in caplog.records) - assert any("Kafka delivery failed" in r.message for r in caplog.records) +import json +import types +import os +import pytest + +import classification.scripts.alerts as alerts + + +# ---------- test helpers ---------- +class DummyMsg: + def __init__(self, topic, partition, offset): + self._t = topic + self._p = partition + self._o = offset + def topic(self): return self._t + def partition(self): return self._p + def offset(self): return self._o + + +class DummyProducer: + def __init__(self, *a, **kw): + self._queue = [] + self._flushed = 0 + def produce(self, *, topic, value, callback=None): + self._queue.append((topic, value)) + if callback: + callback(None, DummyMsg(topic, 0, len(self._queue) - 1)) + def poll(self, _): # api compatibility + pass + def flush(self, timeout): + self._flushed += 1 + return 0 # 0 undelivered → success + + +@pytest.fixture(autouse=True) +def _clear_cache(): + alerts._producer_cache.clear() + yield + alerts._producer_cache.clear() + + +# ---------- tests ---------- +def test_send_alert_success_and_payload(monkeypatch): + dp = DummyProducer() + monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: dp, raising=True) + + ok = alerts.send_alert( + brokers="kafka:9092", + topic="dev-robot-alerts", + label="car", + probs={"car": 0.9, "dog": 0.1}, + meta={"bucket": "b", "key": "k.wav"}, + ) + assert ok is True + # Verify one message with proper JSON payload + assert len(dp._queue) == 1 + topic, raw = dp._queue[0] + assert topic == "dev-robot-alerts" + payload = json.loads(raw.decode("utf-8")) + assert payload["label"] == "car" + assert payload["probs"]["car"] == pytest.approx(0.9) + assert isinstance(payload["ts"], int) + + +def test_producer_cache_reuse(monkeypatch): + calls = {"made": 0} + def mk(*_a, **_k): + calls["made"] += 1 + return DummyProducer() + monkeypatch.setattr(alerts, "Producer", mk, raising=True) + + # First call creates a producer + assert alerts.send_alert(brokers="kafka:9092", topic="t", label="x", probs={"x": 1.0}) + # Second call should reuse cache → no extra Producer() + assert alerts.send_alert(brokers="kafka:9092", topic="t", label="y", probs={"y": 1.0}) + assert calls["made"] == 1 + assert len(alerts._producer_cache) == 1 + + +def test_send_alert_is_non_blocking_and_returns_true(monkeypatch, caplog): + class ProducerNoFlush(DummyProducer): + def flush(self, timeout): + return 1 # would indicate undelivered if we used it, but we don't block now + + monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: ProducerNoFlush(), raising=True) + + ok = alerts.send_alert(brokers="b:9092", topic="t", label="x", probs={"x": 1.0}) + assert ok is True + assert any("Kafka delivered" in rec.message or "Kafka" in rec.message for rec in caplog.records) + + +def test_send_alert_kafka_exception_on_init(monkeypatch, caplog): + class DummyKafkaEx(alerts.KafkaException): # use the module's class + pass + def boom(*_a, **_k): + raise DummyKafkaEx("init failed") + monkeypatch.setattr(alerts, "Producer", boom, raising=True) + + ok = alerts.send_alert(brokers="bad:9092", topic="t", label="x", probs={"x": 1.0}) + assert ok is False + assert any("exception while producing" in rec.message or "Failed to initialize Kafka" in rec.message + for rec in caplog.records) + + +def test_send_alert_buffer_error(monkeypatch, caplog): + class BufErrProducer(DummyProducer): + def produce(self, *a, **k): + raise BufferError("local queue full") + monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: BufErrProducer(), raising=True) + + ok = alerts.send_alert(brokers="b:9092", topic="t", label="x", probs={"x": 1.0}) + assert ok is False + assert any("local queue full" in rec.message for rec in caplog.records) + + +def test_send_alert_runtime_error(monkeypatch, caplog): + class BoomProducer(DummyProducer): + def produce(self, *a, **k): + raise RuntimeError("produce exploded") + monkeypatch.setattr(alerts, "Producer", lambda *_a, **_k: BoomProducer(), raising=True) + + ok = alerts.send_alert(brokers="b:9092", topic="t", label="x", probs={"x": 1.0}) + assert ok is False + assert any("produce error" in rec.message for rec in caplog.records) + + +def test_legacy_send_kafka_alert_reads_env(monkeypatch): + # Check that env is read and forwarded into send_alert + captured = {} + def fake_send_alert(*, brokers, topic, label, probs, meta): + captured["brokers"] = brokers + captured["topic"] = topic + captured["label"] = label + captured["probs"] = probs + captured["meta"] = meta + return True + + monkeypatch.setenv("KAFKA_BROKERS", "env-b:9092") + monkeypatch.setenv("ALERTS_TOPIC", "env-topic") + monkeypatch.setattr(alerts, "send_alert", fake_send_alert, raising=True) + + ok = alerts.send_kafka_alert(file_path="/tmp/a.wav", label="car", prob=0.7) + assert ok is True + assert captured["brokers"] == "env-b:9092" + assert captured["topic"] == "env-topic" + assert captured["label"] == "car" + assert captured["probs"] == {"car": 0.7} + assert captured["meta"]["file_path"] == "/tmp/a.wav" + + +def test__delivery_report_logs_ok_and_err(caplog): + # OK case + alerts._delivery_report(None, DummyMsg("t", 0, 1)) + # Error case + class Err: + def __str__(self): return "boom" + alerts._delivery_report(Err(), DummyMsg("t", 0, 2)) + + # We don't assert exact wording, just that both paths logged + assert any("Kafka delivered" in r.message for r in caplog.records) + assert any("Kafka delivery failed" in r.message for r in caplog.records) diff --git a/services/sounds/sounds_classifier/tests/test_app.py b/services/sounds_classifier/tests/test_app.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_app.py rename to services/sounds_classifier/tests/test_app.py index 154264c1e..6056cf4ef 100644 --- a/services/sounds/sounds_classifier/tests/test_app.py +++ b/services/sounds_classifier/tests/test_app.py @@ -1,215 +1,215 @@ -from fastapi.testclient import TestClient -import classification.app as app_mod - -# ----------------------------- -# Helpers / Fakes -# ----------------------------- -class DummyAT: - """Lightweight fake for AudioTagging to let startup warm-up pass.""" - def __init__(self, *a, **k): - pass - def inference(self, x): - # Return something with "embedding" to match warm-up expectations - return {"embedding": [0.0, 0.0]} - -class DummyConn: - """Fake psycopg2 connection used by open_db().""" - def __init__(self): - self.closed = False - self._executed = [] - def cursor(self): - class C: - def __init__(self, outer): - self.outer = outer - def __enter__(self): - return self - def __exit__(self, *a): - return False - def execute(self, *a, **k): - self.outer._executed.append(("EXEC", a, k)) - return C(self) - def commit(self): - self._executed.append(("COMMIT", None, None)) - def rollback(self): - self._executed.append(("ROLLBACK", None, None)) - def close(self): - self.closed = True - -def _patch_startup_and_db(monkeypatch, *, capture): - """ - Patch all heavy/IO dependencies used by app startup and endpoints. - 'capture' is a dict used to collect calls for assertions. - """ - # Avoid heavy model loading - monkeypatch.setattr(app_mod, "AudioTagging", DummyAT, raising=True) - # Pretend pipeline file does not exist -> skip joblib.load - monkeypatch.setattr(app_mod.os.path, "exists", lambda p: False, raising=True) - - # open_db returns DummyConn and we keep the instance in capture - def _open_db(): - conn = DummyConn() - capture["conn"] = conn - return conn - monkeypatch.setattr(app_mod, "open_db", _open_db, raising=True) - - # No-op ensure_run - monkeypatch.setattr(app_mod, "ensure_run", lambda conn, run_id: None, raising=True) - - # Safe defaults for resolve/upsert; specific tests will override if needed - monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 123, raising=True) - - def _upsert(conn, payload): - capture.setdefault("upserts", []).append({"conn": conn, "payload": payload}) - monkeypatch.setattr(app_mod, "upsert_file_aggregate", _upsert, raising=True) - - -# ----------------------------- -# Tests -# ----------------------------- -def test_health_ok(monkeypatch): - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - # Use context manager so startup/shutdown definitely run - with TestClient(app_mod.app) as client: - r = client.get("/health") - assert r.status_code == 200 - data = r.json() - assert isinstance(data, dict) - assert data.get("ok") is True - # Assert values that reflect startup side effects - assert data.get("pann_loaded") is True - assert data.get("sk_pipeline_loaded") in (True, False) - # DB connection was created - assert cap.get("conn") is not None - -def test_startup_and_shutdown_close_db(monkeypatch): - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - # Run inside context to trigger startup+shutdown - with TestClient(app_mod.app) as client: - r = client.get("/health") - assert r.status_code == 200 - # During runtime, connection object is set - assert cap.get("conn") is not None - assert cap["conn"].closed is False - # After shutdown hook ran, global DB_CONN should be None and fake conn closed - assert app_mod.DB_CONN is None - assert cap["conn"].closed is True - -def test_classify_200_success(monkeypatch): - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - - # resolve_file_id -> a fixed file_id - monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 42, raising=True) - - # classification core returns a rich dict (with extra fields) - import classification.scripts.classify as cls - monkeypatch.setattr( - cls, - "run_classification_job", - lambda **k: { - "label": "car", - "probs": {"car": 0.9, "dog": 0.1}, - "pred_prob": 0.9, - "unknown_threshold": 0.55, - "is_another": False, - "num_windows": 5, - "agg_mode": "mean", - "processing_ms": 123.0, - }, - raising=True - ) - - with TestClient(app_mod.app) as client: - r = client.post("/classify", json={"s3_bucket": "ok", "s3_key": "file.wav", "return_porbs": True}) - assert r.status_code == 200 - body = r.json() - assert body["label"] == "car" - # default return_probs is False -> probs stripped - assert body["probs"] == {"car": 0.9, "dog": 0.1} - - # Verify upsert called with our collected payload - assert len(cap.get("upserts", [])) == 1 - payload = cap["upserts"][0]["payload"] - assert payload["file_id"] == 42 - assert payload["head_pred_label"] == "car" - assert payload["num_windows"] == 5 - assert payload["agg_mode"] == "mean" - assert payload["processing_ms"] == 123.0 - -def test_classify_200_with_return_probs_true(monkeypatch): - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 88, raising=True) - - import classification.scripts.classify as cls - monkeypatch.setattr( - cls, - "run_classification_job", - lambda **k: {"label": "dog", "probs": {"car": 0.2, "dog": 0.8}}, - raising=True - ) - - with TestClient(app_mod.app) as client: - r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "key.wav", "return_probs": True}) - assert r.status_code == 200 - body = r.json() - assert body["label"] == "dog" - assert body["probs"] == {"car": 0.2, "dog": 0.8} - # ensure upsert invoked - assert len(cap.get("upserts", [])) == 1 - assert cap["upserts"][0]["payload"]["file_id"] == 88 - -def test_classify_404_when_file_missing(monkeypatch): - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - # Simulate resolve_file_id failure (file not in public.files) - def _resolve(*a, **k): - raise ValueError("not found") - monkeypatch.setattr(app_mod, "resolve_file_id", _resolve, raising=True) - - with TestClient(app_mod.app) as client: - r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "missing.wav"}) - assert r.status_code == 404 - # No upsert on failure - assert cap.get("upserts", []) == [] - -def test_classify_500_when_core_raises(monkeypatch): - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 55, raising=True) - - import classification.scripts.classify as cls - def _raiser(**k): - raise RuntimeError("boom") - monkeypatch.setattr(cls, "run_classification_job", _raiser, raising=True) - - with TestClient(app_mod.app) as client: - r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "crash.wav"}) - assert r.status_code == 500 - # No upsert on failure - assert cap.get("upserts", []) == [] - -def test_middleware_executes(monkeypatch): - """ - Hitting endpoints implicitly passes through the timing middleware. - This test mainly ensures middleware path executes without errors - (coverage); assertions are on status only. - """ - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - # First pass: just exercise /health - with TestClient(app_mod.app) as client: - r1 = client.get("/health") - assert r1.status_code == 200 - - # Second pass: exercise /classify with a minimal classifier mock - cap = {} - _patch_startup_and_db(monkeypatch, capture=cap) - import classification.scripts.classify as cls - monkeypatch.setattr(cls, "run_classification_job", - lambda **k: {"label": "ok", "probs": {}}, raising=True) - with TestClient(app_mod.app) as client: - r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "k.wav"}) - assert r.status_code == 200 +from fastapi.testclient import TestClient +import classification.app as app_mod + +# ----------------------------- +# Helpers / Fakes +# ----------------------------- +class DummyAT: + """Lightweight fake for AudioTagging to let startup warm-up pass.""" + def __init__(self, *a, **k): + pass + def inference(self, x): + # Return something with "embedding" to match warm-up expectations + return {"embedding": [0.0, 0.0]} + +class DummyConn: + """Fake psycopg2 connection used by open_db().""" + def __init__(self): + self.closed = False + self._executed = [] + def cursor(self): + class C: + def __init__(self, outer): + self.outer = outer + def __enter__(self): + return self + def __exit__(self, *a): + return False + def execute(self, *a, **k): + self.outer._executed.append(("EXEC", a, k)) + return C(self) + def commit(self): + self._executed.append(("COMMIT", None, None)) + def rollback(self): + self._executed.append(("ROLLBACK", None, None)) + def close(self): + self.closed = True + +def _patch_startup_and_db(monkeypatch, *, capture): + """ + Patch all heavy/IO dependencies used by app startup and endpoints. + 'capture' is a dict used to collect calls for assertions. + """ + # Avoid heavy model loading + monkeypatch.setattr(app_mod, "AudioTagging", DummyAT, raising=True) + # Pretend pipeline file does not exist -> skip joblib.load + monkeypatch.setattr(app_mod.os.path, "exists", lambda p: False, raising=True) + + # open_db returns DummyConn and we keep the instance in capture + def _open_db(): + conn = DummyConn() + capture["conn"] = conn + return conn + monkeypatch.setattr(app_mod, "open_db", _open_db, raising=True) + + # No-op ensure_run + monkeypatch.setattr(app_mod, "ensure_run", lambda conn, run_id: None, raising=True) + + # Safe defaults for resolve/upsert; specific tests will override if needed + monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 123, raising=True) + + def _upsert(conn, payload): + capture.setdefault("upserts", []).append({"conn": conn, "payload": payload}) + monkeypatch.setattr(app_mod, "upsert_file_aggregate", _upsert, raising=True) + + +# ----------------------------- +# Tests +# ----------------------------- +def test_health_ok(monkeypatch): + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + # Use context manager so startup/shutdown definitely run + with TestClient(app_mod.app) as client: + r = client.get("/health") + assert r.status_code == 200 + data = r.json() + assert isinstance(data, dict) + assert data.get("ok") is True + # Assert values that reflect startup side effects + assert data.get("pann_loaded") is True + assert data.get("sk_pipeline_loaded") in (True, False) + # DB connection was created + assert cap.get("conn") is not None + +def test_startup_and_shutdown_close_db(monkeypatch): + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + # Run inside context to trigger startup+shutdown + with TestClient(app_mod.app) as client: + r = client.get("/health") + assert r.status_code == 200 + # During runtime, connection object is set + assert cap.get("conn") is not None + assert cap["conn"].closed is False + # After shutdown hook ran, global DB_CONN should be None and fake conn closed + assert app_mod.DB_CONN is None + assert cap["conn"].closed is True + +def test_classify_200_success(monkeypatch): + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + + # resolve_file_id -> a fixed file_id + monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 42, raising=True) + + # classification core returns a rich dict (with extra fields) + import classification.scripts.classify as cls + monkeypatch.setattr( + cls, + "run_classification_job", + lambda **k: { + "label": "car", + "probs": {"car": 0.9, "dog": 0.1}, + "pred_prob": 0.9, + "unknown_threshold": 0.55, + "is_another": False, + "num_windows": 5, + "agg_mode": "mean", + "processing_ms": 123.0, + }, + raising=True + ) + + with TestClient(app_mod.app) as client: + r = client.post("/classify", json={"s3_bucket": "ok", "s3_key": "file.wav", "return_porbs": True}) + assert r.status_code == 200 + body = r.json() + assert body["label"] == "car" + # default return_probs is False -> probs stripped + assert body["probs"] == {"car": 0.9, "dog": 0.1} + + # Verify upsert called with our collected payload + assert len(cap.get("upserts", [])) == 1 + payload = cap["upserts"][0]["payload"] + assert payload["file_id"] == 42 + assert payload["head_pred_label"] == "car" + assert payload["num_windows"] == 5 + assert payload["agg_mode"] == "mean" + assert payload["processing_ms"] == 123.0 + +def test_classify_200_with_return_probs_true(monkeypatch): + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 88, raising=True) + + import classification.scripts.classify as cls + monkeypatch.setattr( + cls, + "run_classification_job", + lambda **k: {"label": "dog", "probs": {"car": 0.2, "dog": 0.8}}, + raising=True + ) + + with TestClient(app_mod.app) as client: + r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "key.wav", "return_probs": True}) + assert r.status_code == 200 + body = r.json() + assert body["label"] == "dog" + assert body["probs"] == {"car": 0.2, "dog": 0.8} + # ensure upsert invoked + assert len(cap.get("upserts", [])) == 1 + assert cap["upserts"][0]["payload"]["file_id"] == 88 + +def test_classify_404_when_file_missing(monkeypatch): + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + # Simulate resolve_file_id failure (file not in public.files) + def _resolve(*a, **k): + raise ValueError("not found") + monkeypatch.setattr(app_mod, "resolve_file_id", _resolve, raising=True) + + with TestClient(app_mod.app) as client: + r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "missing.wav"}) + assert r.status_code == 404 + # No upsert on failure + assert cap.get("upserts", []) == [] + +def test_classify_500_when_core_raises(monkeypatch): + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + monkeypatch.setattr(app_mod, "resolve_file_id", lambda *a, **k: 55, raising=True) + + import classification.scripts.classify as cls + def _raiser(**k): + raise RuntimeError("boom") + monkeypatch.setattr(cls, "run_classification_job", _raiser, raising=True) + + with TestClient(app_mod.app) as client: + r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "crash.wav"}) + assert r.status_code == 500 + # No upsert on failure + assert cap.get("upserts", []) == [] + +def test_middleware_executes(monkeypatch): + """ + Hitting endpoints implicitly passes through the timing middleware. + This test mainly ensures middleware path executes without errors + (coverage); assertions are on status only. + """ + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + # First pass: just exercise /health + with TestClient(app_mod.app) as client: + r1 = client.get("/health") + assert r1.status_code == 200 + + # Second pass: exercise /classify with a minimal classifier mock + cap = {} + _patch_startup_and_db(monkeypatch, capture=cap) + import classification.scripts.classify as cls + monkeypatch.setattr(cls, "run_classification_job", + lambda **k: {"label": "ok", "probs": {}}, raising=True) + with TestClient(app_mod.app) as client: + r = client.post("/classify", json={"s3_bucket": "b", "s3_key": "k.wav"}) + assert r.status_code == 200 diff --git a/services/sounds/sounds_classifier/tests/test_classify_core.py b/services/sounds_classifier/tests/test_classify_core.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_classify_core.py rename to services/sounds_classifier/tests/test_classify_core.py index 3ccace665..8c9423af3 100644 --- a/services/sounds/sounds_classifier/tests/test_classify_core.py +++ b/services/sounds_classifier/tests/test_classify_core.py @@ -1,262 +1,262 @@ -import numpy as np -from pathlib import Path -import classification.scripts.classify as c - - -# ---- Helpers ---- -class DummyHead: - def __init__(self, classes): - self.classes_ = classes - self._probas = None - def set_out(self, arr): - self._probas = arr - def predict_proba(self, X): - n = X.shape[0] - return np.tile(self._probas, (n, 1)) - -def _reset_runtime(): - c.R = c._Runtime() - -def _to_dict_result(res): - """Normalize classifier outputs to a dict {label, probs} for tests.""" - if isinstance(res, tuple): - label, probs = res - return {"label": label, "probs": probs} - return res - - -# ---- Core classification flow & validations ---- -def test_classify_file_unknown_threshold(monkeypatch): - _reset_runtime() - c.R.model = object() - head = DummyHead(classes=["car", "dog"]) - head.set_out(np.array([0.51, 0.49], dtype=np.float32)) - c.R.head = head - c.R.classes = head.classes_ - - monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(16000, dtype=np.float32), raising=True) - # one window: shape (1, 16000) - monkeypatch.setattr(c, "segment_waveform_2d_view", - lambda *a, **k: np.ones((1, 16000), dtype=np.float32), - raising=True) - monkeypatch.setattr( - c, - "run_cnn14_embeddings_batch", - lambda _model, windows, batch_size=32: np.tile(np.array([[1, 2, 3, 4]], dtype=np.float32), (windows.shape[0], 1)), - raising=True, - ) - result = _to_dict_result(c.classify_file("dummy.wav")) - assert result["label"] in ("car", "another") - assert set(result["probs"].keys()) == {"car", "dog"} - - -def test__aggregate_probs_rejects_bad_ndim(): - x = np.array([0.1, 0.9], dtype=np.float32) # 1-D - try: - c._aggregate_probs(x) - assert False, "Expected ValueError for 1-D array" - except ValueError as e: - assert "expected shape" in str(e) - - -def test__aggregate_probs_empty_windows_returns_zeros(): - x = np.zeros((0, 3), dtype=np.float32) - out = c._aggregate_probs(x) - assert out.shape == (3,) - assert np.allclose(out, 0.0) - - -def test_classify_file_returns_another_when_no_segments(monkeypatch): - _reset_runtime() - c.R.model = object() - - class Head: - def __init__(self): self.classes_ = ["car", "dog"] - def predict_proba(self, X): return np.zeros((X.shape[0], 2), dtype=np.float32) - - c.R.head = Head() - c.R.classes = ["car", "dog"] - - monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(16000, dtype=np.float32), raising=True) - # zero windows: shape (0, 16000) - monkeypatch.setattr(c, "segment_waveform_2d_view", - lambda *a, **k: np.zeros((0, 16000), dtype=np.float32), - raising=True) - - result = _to_dict_result(c.classify_file("dummy.wav")) - assert result["label"] == "another" - assert set(result["probs"].keys()) == {"car", "dog"} - - -def test_classify_file_with_agg_max(monkeypatch): - _reset_runtime() - old_agg = c.AGG - c.AGG = "max" - try: - c.R.model = object() - - class Head: - def __init__(self): self.classes_ = ["a", "b"] - def predict_proba(self, X): - # pretend we got two windows → choose element-wise max - return np.array([[0.2, 0.8], [0.6, 0.4]], dtype=np.float32) - - c.R.head = Head() - c.R.classes = ["a", "b"] - - monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(2 * 16000, dtype=np.float32), raising=True) - # two windows: shape (2, 16000) - monkeypatch.setattr(c, "segment_waveform_2d_view", - lambda *a, **k: np.ones((2, 16000), dtype=np.float32), - raising=True) - monkeypatch.setattr( - c, - "run_cnn14_embeddings_batch", - lambda _model, windows, batch_size=32: np.tile(np.array([[1, 2, 3, 4]], dtype=np.float32), (windows.shape[0], 1)), - raising=True, - ) - result = _to_dict_result(c.classify_file("x.wav")) - assert set(result["probs"].keys()) == {"a", "b"} - assert np.isclose(result["probs"]["a"], 0.6) or np.isclose(result["probs"]["b"], 0.8) - finally: - c.AGG = old_agg - - -def test_run_classification_job_happy_path(monkeypatch, tmp_path): - _reset_runtime() - - class Stat: size = 10; content_type = "audio/wav" - class Client: - def stat_object(self, b, k): return Stat() - def fget_object(self, b, k, dst): Path(dst).write_bytes(b"RIFF") - - monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) - - c.KAFKA_BROKERS, c.ALERTS_TOPIC = "", "dev-robot-alerts" - monkeypatch.setattr(c.alerts, "send_alert", lambda **kw: None, raising=True) - - c.R.model = object() - head = DummyHead(classes=["car", "dog"]); head.set_out(np.array([0.6, 0.4], dtype=np.float32)) - c.R.head = head; c.R.classes = head.classes_ - - monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(16000, dtype=np.float32), raising=True) - monkeypatch.setattr(c, "segment_waveform_2d_view", - lambda *a, **k: np.ones((1, 16000), dtype=np.float32), - raising=True) - monkeypatch.setattr( - c, - "run_cnn14_embeddings_batch", - lambda _model, windows, batch_size=32: np.tile(np.array([[1, 2, 3, 4]], dtype=np.float32), (windows.shape[0], 1)), - raising=True, - ) - - out = _to_dict_result(c.run_classification_job(s3_bucket="b", s3_key="k.wav")) - assert out["label"] in ("car", "another") - assert "probs" in out and isinstance(out["probs"], dict) - - -def test_run_classification_job_bucket_not_allowed(): - _reset_runtime() - c.R.head = object() - old = c.ALLOWED_BUCKETS - c.ALLOWED_BUCKETS = ["only-this-bucket"] - try: - try: - _ = c.run_classification_job(s3_bucket="not-allowed", s3_key="a.wav") - assert False, "Expected RuntimeError for disallowed bucket" - except RuntimeError as e: - assert "not allowed" in str(e) - finally: - c.ALLOWED_BUCKETS = old - - -def test_run_classification_job_rejects_content_type(monkeypatch): - _reset_runtime() - c.R.head = object() - c.ALLOWED_BUCKETS = [] - old_types = c.ALLOWED_CONTENT_TYPES - c.ALLOWED_CONTENT_TYPES = ["audio/wav"] - - class Stat: size = 1024; content_type = "text/plain" - class Client: - def stat_object(self, b, k): return Stat() - def fget_object(self, b, k, dst): Path(dst).write_bytes(b"RIFF") - - monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) - - try: - _ = c.run_classification_job(s3_bucket="ok", s3_key="a.wav") - assert False, "Expected RuntimeError for unsupported content-type" - except RuntimeError as e: - assert "Unsupported content-type" in str(e) - finally: - c.ALLOWED_CONTENT_TYPES = old_types - - -def test_run_classification_job_rejects_size(monkeypatch): - _reset_runtime() - c.R.head = object() - c.ALLOWED_BUCKETS = [] - old_max = c.MAX_BYTES; c.MAX_BYTES = 10 - - class Stat: size = 11; content_type = "audio/wav" - class Client: - def stat_object(self, b, k): return Stat() - def fget_object(self, b, k, dst): Path(dst).write_bytes(b"RIFF") - - monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) - - try: - _ = c.run_classification_job(s3_bucket="ok", s3_key="a.wav") - assert False, "Expected RuntimeError for object too large" - except RuntimeError as e: - assert "Object too large" in str(e) - finally: - c.MAX_BYTES = old_max - - -def test_run_classification_job_s3error_fails_fast(monkeypatch): - _reset_runtime() - c.R.head = object() - c.ALLOWED_BUCKETS = [] - - class S3Err(Exception): pass - monkeypatch.setattr(c, "S3Error", S3Err, raising=True) - - class Client: - def stat_object(self, b, k): raise S3Err("boom") - def fget_object(self, b, k, dst): raise AssertionError("should not be called") - - monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) - - try: - _ = c.run_classification_job(s3_bucket="ok", s3_key="a.wav") - assert False, "Expected RuntimeError wrapping S3 failure" - except RuntimeError as e: - assert "S3 stat failed" in str(e) - - -def test_run_classification_job_adds_wav_suffix_when_missing(monkeypatch): - _reset_runtime() - c.R.head = object() - c.R.classes = ["car", "dog"] - c.ALLOWED_BUCKETS = [] - - class Stat: size = 100; content_type = "audio/wav" - observed = {"ext": None} - - class Client: - def stat_object(self, b, k): return Stat() - def fget_object(self, b, k, dst): - observed["ext"] = Path(dst).suffix - Path(dst).write_bytes(b"RIFF") - - monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) - monkeypatch.setattr(c, "classify_file", - lambda p, **_: {"label": "another", "probs": {"car": 0.0, "dog": 0.0}}, - raising=True) - monkeypatch.setattr(c.os, "remove", lambda p: None, raising=True) - - out = _to_dict_result(c.run_classification_job(s3_bucket="ok", s3_key="noext")) - assert out["label"] in ("another", "car", "dog") - assert observed["ext"] == ".wav" +import numpy as np +from pathlib import Path +import classification.scripts.classify as c + + +# ---- Helpers ---- +class DummyHead: + def __init__(self, classes): + self.classes_ = classes + self._probas = None + def set_out(self, arr): + self._probas = arr + def predict_proba(self, X): + n = X.shape[0] + return np.tile(self._probas, (n, 1)) + +def _reset_runtime(): + c.R = c._Runtime() + +def _to_dict_result(res): + """Normalize classifier outputs to a dict {label, probs} for tests.""" + if isinstance(res, tuple): + label, probs = res + return {"label": label, "probs": probs} + return res + + +# ---- Core classification flow & validations ---- +def test_classify_file_unknown_threshold(monkeypatch): + _reset_runtime() + c.R.model = object() + head = DummyHead(classes=["car", "dog"]) + head.set_out(np.array([0.51, 0.49], dtype=np.float32)) + c.R.head = head + c.R.classes = head.classes_ + + monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(16000, dtype=np.float32), raising=True) + # one window: shape (1, 16000) + monkeypatch.setattr(c, "segment_waveform_2d_view", + lambda *a, **k: np.ones((1, 16000), dtype=np.float32), + raising=True) + monkeypatch.setattr( + c, + "run_cnn14_embeddings_batch", + lambda _model, windows, batch_size=32: np.tile(np.array([[1, 2, 3, 4]], dtype=np.float32), (windows.shape[0], 1)), + raising=True, + ) + result = _to_dict_result(c.classify_file("dummy.wav")) + assert result["label"] in ("car", "another") + assert set(result["probs"].keys()) == {"car", "dog"} + + +def test__aggregate_probs_rejects_bad_ndim(): + x = np.array([0.1, 0.9], dtype=np.float32) # 1-D + try: + c._aggregate_probs(x) + assert False, "Expected ValueError for 1-D array" + except ValueError as e: + assert "expected shape" in str(e) + + +def test__aggregate_probs_empty_windows_returns_zeros(): + x = np.zeros((0, 3), dtype=np.float32) + out = c._aggregate_probs(x) + assert out.shape == (3,) + assert np.allclose(out, 0.0) + + +def test_classify_file_returns_another_when_no_segments(monkeypatch): + _reset_runtime() + c.R.model = object() + + class Head: + def __init__(self): self.classes_ = ["car", "dog"] + def predict_proba(self, X): return np.zeros((X.shape[0], 2), dtype=np.float32) + + c.R.head = Head() + c.R.classes = ["car", "dog"] + + monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(16000, dtype=np.float32), raising=True) + # zero windows: shape (0, 16000) + monkeypatch.setattr(c, "segment_waveform_2d_view", + lambda *a, **k: np.zeros((0, 16000), dtype=np.float32), + raising=True) + + result = _to_dict_result(c.classify_file("dummy.wav")) + assert result["label"] == "another" + assert set(result["probs"].keys()) == {"car", "dog"} + + +def test_classify_file_with_agg_max(monkeypatch): + _reset_runtime() + old_agg = c.AGG + c.AGG = "max" + try: + c.R.model = object() + + class Head: + def __init__(self): self.classes_ = ["a", "b"] + def predict_proba(self, X): + # pretend we got two windows → choose element-wise max + return np.array([[0.2, 0.8], [0.6, 0.4]], dtype=np.float32) + + c.R.head = Head() + c.R.classes = ["a", "b"] + + monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(2 * 16000, dtype=np.float32), raising=True) + # two windows: shape (2, 16000) + monkeypatch.setattr(c, "segment_waveform_2d_view", + lambda *a, **k: np.ones((2, 16000), dtype=np.float32), + raising=True) + monkeypatch.setattr( + c, + "run_cnn14_embeddings_batch", + lambda _model, windows, batch_size=32: np.tile(np.array([[1, 2, 3, 4]], dtype=np.float32), (windows.shape[0], 1)), + raising=True, + ) + result = _to_dict_result(c.classify_file("x.wav")) + assert set(result["probs"].keys()) == {"a", "b"} + assert np.isclose(result["probs"]["a"], 0.6) or np.isclose(result["probs"]["b"], 0.8) + finally: + c.AGG = old_agg + + +def test_run_classification_job_happy_path(monkeypatch, tmp_path): + _reset_runtime() + + class Stat: size = 10; content_type = "audio/wav" + class Client: + def stat_object(self, b, k): return Stat() + def fget_object(self, b, k, dst): Path(dst).write_bytes(b"RIFF") + + monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) + + c.KAFKA_BROKERS, c.ALERTS_TOPIC = "", "dev-robot-alerts" + monkeypatch.setattr(c.alerts, "send_alert", lambda **kw: None, raising=True) + + c.R.model = object() + head = DummyHead(classes=["car", "dog"]); head.set_out(np.array([0.6, 0.4], dtype=np.float32)) + c.R.head = head; c.R.classes = head.classes_ + + monkeypatch.setattr(c, "load_audio", lambda p, sr: np.ones(16000, dtype=np.float32), raising=True) + monkeypatch.setattr(c, "segment_waveform_2d_view", + lambda *a, **k: np.ones((1, 16000), dtype=np.float32), + raising=True) + monkeypatch.setattr( + c, + "run_cnn14_embeddings_batch", + lambda _model, windows, batch_size=32: np.tile(np.array([[1, 2, 3, 4]], dtype=np.float32), (windows.shape[0], 1)), + raising=True, + ) + + out = _to_dict_result(c.run_classification_job(s3_bucket="b", s3_key="k.wav")) + assert out["label"] in ("car", "another") + assert "probs" in out and isinstance(out["probs"], dict) + + +def test_run_classification_job_bucket_not_allowed(): + _reset_runtime() + c.R.head = object() + old = c.ALLOWED_BUCKETS + c.ALLOWED_BUCKETS = ["only-this-bucket"] + try: + try: + _ = c.run_classification_job(s3_bucket="not-allowed", s3_key="a.wav") + assert False, "Expected RuntimeError for disallowed bucket" + except RuntimeError as e: + assert "not allowed" in str(e) + finally: + c.ALLOWED_BUCKETS = old + + +def test_run_classification_job_rejects_content_type(monkeypatch): + _reset_runtime() + c.R.head = object() + c.ALLOWED_BUCKETS = [] + old_types = c.ALLOWED_CONTENT_TYPES + c.ALLOWED_CONTENT_TYPES = ["audio/wav"] + + class Stat: size = 1024; content_type = "text/plain" + class Client: + def stat_object(self, b, k): return Stat() + def fget_object(self, b, k, dst): Path(dst).write_bytes(b"RIFF") + + monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) + + try: + _ = c.run_classification_job(s3_bucket="ok", s3_key="a.wav") + assert False, "Expected RuntimeError for unsupported content-type" + except RuntimeError as e: + assert "Unsupported content-type" in str(e) + finally: + c.ALLOWED_CONTENT_TYPES = old_types + + +def test_run_classification_job_rejects_size(monkeypatch): + _reset_runtime() + c.R.head = object() + c.ALLOWED_BUCKETS = [] + old_max = c.MAX_BYTES; c.MAX_BYTES = 10 + + class Stat: size = 11; content_type = "audio/wav" + class Client: + def stat_object(self, b, k): return Stat() + def fget_object(self, b, k, dst): Path(dst).write_bytes(b"RIFF") + + monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) + + try: + _ = c.run_classification_job(s3_bucket="ok", s3_key="a.wav") + assert False, "Expected RuntimeError for object too large" + except RuntimeError as e: + assert "Object too large" in str(e) + finally: + c.MAX_BYTES = old_max + + +def test_run_classification_job_s3error_fails_fast(monkeypatch): + _reset_runtime() + c.R.head = object() + c.ALLOWED_BUCKETS = [] + + class S3Err(Exception): pass + monkeypatch.setattr(c, "S3Error", S3Err, raising=True) + + class Client: + def stat_object(self, b, k): raise S3Err("boom") + def fget_object(self, b, k, dst): raise AssertionError("should not be called") + + monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) + + try: + _ = c.run_classification_job(s3_bucket="ok", s3_key="a.wav") + assert False, "Expected RuntimeError wrapping S3 failure" + except RuntimeError as e: + assert "S3 stat failed" in str(e) + + +def test_run_classification_job_adds_wav_suffix_when_missing(monkeypatch): + _reset_runtime() + c.R.head = object() + c.R.classes = ["car", "dog"] + c.ALLOWED_BUCKETS = [] + + class Stat: size = 100; content_type = "audio/wav" + observed = {"ext": None} + + class Client: + def stat_object(self, b, k): return Stat() + def fget_object(self, b, k, dst): + observed["ext"] = Path(dst).suffix + Path(dst).write_bytes(b"RIFF") + + monkeypatch.setattr(c, "_get_minio", lambda: Client(), raising=True) + monkeypatch.setattr(c, "classify_file", + lambda p, **_: {"label": "another", "probs": {"car": 0.0, "dog": 0.0}}, + raising=True) + monkeypatch.setattr(c.os, "remove", lambda p: None, raising=True) + + out = _to_dict_result(c.run_classification_job(s3_bucket="ok", s3_key="noext")) + assert out["label"] in ("another", "car", "dog") + assert observed["ext"] == ".wav" diff --git a/services/sounds/sounds_classifier/tests/test_classify_head_loading.py b/services/sounds_classifier/tests/test_classify_head_loading.py similarity index 96% rename from services/sounds/sounds_classifier/tests/test_classify_head_loading.py rename to services/sounds_classifier/tests/test_classify_head_loading.py index 2769dd620..60064879c 100644 --- a/services/sounds/sounds_classifier/tests/test_classify_head_loading.py +++ b/services/sounds_classifier/tests/test_classify_head_loading.py @@ -1,142 +1,142 @@ -import json -import numpy as np -import classification.scripts.classify as c - - -def _reset_runtime(): - c.R = c._Runtime() - - -def test__load_backbone_once_calls_loader(monkeypatch): - _reset_runtime() - old = c.BACKBONE - try: - c.BACKBONE = "cnn14" - called = {"ok": False} - def fake_loader(checkpoint_path=None, device="cpu", checkpoint_url=None): - called["ok"] = True - return object() - monkeypatch.setattr(c, "load_cnn14_model", fake_loader, raising=True) - c._load_backbone_once() - assert called["ok"] is True and c.R.model is not None - # Second call should be a no-op - called["ok"] = False - c._load_backbone_once() - assert called["ok"] is False - finally: - c.BACKBONE = old - - -def test__load_backbone_once_rejects_non_cnn14(): - _reset_runtime() - old = c.BACKBONE - try: - c.BACKBONE = "something-else" - try: - c._load_backbone_once() - assert False, "Expected RuntimeError for unsupported backbone" - except RuntimeError as e: - assert "Only BACKBONE=cnn14" in str(e) - finally: - c.BACKBONE = old - - -def test__load_head_once_requires_head(): - _reset_runtime() - old = c.HEAD_PATH - try: - c.HEAD_PATH = "" - try: - c._load_head_once() - assert False, "Expected RuntimeError when HEAD env/path is missing" - except RuntimeError as e: - assert "HEAD env var is required" in str(e) - finally: - c.HEAD_PATH = old - - -def test__load_head_once_uses_labels_csv(monkeypatch): - _reset_runtime() - - class DummyHead: - def __init__(self): - self.classes_ = ["x", "y"] - def predict_proba(self, X): - return np.zeros((X.shape[0], 2), dtype=np.float32) - - c.HEAD_PATH = "dummy.joblib" - monkeypatch.setattr(c.joblib, "load", lambda _: DummyHead(), raising=True) - - monkeypatch.setenv("LABELS_CSV", "labels.csv") - import classification.core.model_io as mio - # Important: raising=False because the attribute doesn't exist in the real module - monkeypatch.setattr(mio, "load_labels_from_csv", lambda _: ["car", "dog"], raising=False) - - c._load_head_once() - assert c.R.head is not None - assert c.R.classes == ["car", "dog"] - monkeypatch.delenv("LABELS_CSV", raising=False) - - -def test__load_head_once_fallback_to_head_classes(monkeypatch): - _reset_runtime() - - class DummyHead: - def __init__(self): - self.classes_ = ["cat", "plane"] - def predict_proba(self, X): - return np.zeros((X.shape[0], 2), dtype=np.float32) - - c.HEAD_PATH = "dummy.joblib" - monkeypatch.setattr(c.joblib, "load", lambda _: DummyHead(), raising=True) - monkeypatch.delenv("LABELS_CSV", raising=False) - monkeypatch.setenv("HEAD_META", "does_not_exist.json") - - c._load_head_once() - assert c.R.head is not None - assert c.R.classes == ["cat", "plane"] - - -def test__load_head_once_uses_meta_with_indexed_classes(monkeypatch, tmp_path): - _reset_runtime() - - class DummyHeadInt: - def __init__(self): - self.classes_ = [0, 1] - def predict_proba(self, X): - return np.zeros((X.shape[0], 2), dtype=np.float32) - - c.HEAD_PATH = "dummy.joblib" - monkeypatch.setattr(c.joblib, "load", lambda _: DummyHeadInt(), raising=True) - - meta_path = tmp_path / "head_meta.json" - meta_path.write_text(json.dumps({"class_order": ["engine", "bird"]}), encoding="utf-8") - monkeypatch.setenv("HEAD_META", str(meta_path)) - monkeypatch.delenv("LABELS_CSV", raising=False) - - c._load_head_once() - assert c.R.classes == ["engine", "bird"] - - -def test__load_head_once_meta_length_mismatch_raises(monkeypatch, tmp_path): - _reset_runtime() - - class DummyHeadInt: - def __init__(self): - self.classes_ = [0, 1] - def predict_proba(self, X): - return np.zeros((X.shape[0], 2), dtype=np.float32) - - c.HEAD_PATH = "dummy.joblib" - monkeypatch.setattr(c.joblib, "load", lambda _: DummyHeadInt(), raising=True) - - meta_path = tmp_path / "bad_meta.json" - meta_path.write_text(json.dumps({"class_order": ["only-one"]}), encoding="utf-8") - monkeypatch.setenv("HEAD_META", str(meta_path)) - monkeypatch.delenv("LABELS_CSV", raising=False) - - try: - c._load_head_once() - assert False, "Expected RuntimeError for meta length mismatch" - except RuntimeError as e: - assert "class_order length" in str(e) +import json +import numpy as np +import classification.scripts.classify as c + + +def _reset_runtime(): + c.R = c._Runtime() + + +def test__load_backbone_once_calls_loader(monkeypatch): + _reset_runtime() + old = c.BACKBONE + try: + c.BACKBONE = "cnn14" + called = {"ok": False} + def fake_loader(checkpoint_path=None, device="cpu", checkpoint_url=None): + called["ok"] = True + return object() + monkeypatch.setattr(c, "load_cnn14_model", fake_loader, raising=True) + c._load_backbone_once() + assert called["ok"] is True and c.R.model is not None + # Second call should be a no-op + called["ok"] = False + c._load_backbone_once() + assert called["ok"] is False + finally: + c.BACKBONE = old + + +def test__load_backbone_once_rejects_non_cnn14(): + _reset_runtime() + old = c.BACKBONE + try: + c.BACKBONE = "something-else" + try: + c._load_backbone_once() + assert False, "Expected RuntimeError for unsupported backbone" + except RuntimeError as e: + assert "Only BACKBONE=cnn14" in str(e) + finally: + c.BACKBONE = old + + +def test__load_head_once_requires_head(): + _reset_runtime() + old = c.HEAD_PATH + try: + c.HEAD_PATH = "" + try: + c._load_head_once() + assert False, "Expected RuntimeError when HEAD env/path is missing" + except RuntimeError as e: + assert "HEAD env var is required" in str(e) + finally: + c.HEAD_PATH = old + + +def test__load_head_once_uses_labels_csv(monkeypatch): + _reset_runtime() + + class DummyHead: + def __init__(self): + self.classes_ = ["x", "y"] + def predict_proba(self, X): + return np.zeros((X.shape[0], 2), dtype=np.float32) + + c.HEAD_PATH = "dummy.joblib" + monkeypatch.setattr(c.joblib, "load", lambda _: DummyHead(), raising=True) + + monkeypatch.setenv("LABELS_CSV", "labels.csv") + import classification.core.model_io as mio + # Important: raising=False because the attribute doesn't exist in the real module + monkeypatch.setattr(mio, "load_labels_from_csv", lambda _: ["car", "dog"], raising=False) + + c._load_head_once() + assert c.R.head is not None + assert c.R.classes == ["car", "dog"] + monkeypatch.delenv("LABELS_CSV", raising=False) + + +def test__load_head_once_fallback_to_head_classes(monkeypatch): + _reset_runtime() + + class DummyHead: + def __init__(self): + self.classes_ = ["cat", "plane"] + def predict_proba(self, X): + return np.zeros((X.shape[0], 2), dtype=np.float32) + + c.HEAD_PATH = "dummy.joblib" + monkeypatch.setattr(c.joblib, "load", lambda _: DummyHead(), raising=True) + monkeypatch.delenv("LABELS_CSV", raising=False) + monkeypatch.setenv("HEAD_META", "does_not_exist.json") + + c._load_head_once() + assert c.R.head is not None + assert c.R.classes == ["cat", "plane"] + + +def test__load_head_once_uses_meta_with_indexed_classes(monkeypatch, tmp_path): + _reset_runtime() + + class DummyHeadInt: + def __init__(self): + self.classes_ = [0, 1] + def predict_proba(self, X): + return np.zeros((X.shape[0], 2), dtype=np.float32) + + c.HEAD_PATH = "dummy.joblib" + monkeypatch.setattr(c.joblib, "load", lambda _: DummyHeadInt(), raising=True) + + meta_path = tmp_path / "head_meta.json" + meta_path.write_text(json.dumps({"class_order": ["engine", "bird"]}), encoding="utf-8") + monkeypatch.setenv("HEAD_META", str(meta_path)) + monkeypatch.delenv("LABELS_CSV", raising=False) + + c._load_head_once() + assert c.R.classes == ["engine", "bird"] + + +def test__load_head_once_meta_length_mismatch_raises(monkeypatch, tmp_path): + _reset_runtime() + + class DummyHeadInt: + def __init__(self): + self.classes_ = [0, 1] + def predict_proba(self, X): + return np.zeros((X.shape[0], 2), dtype=np.float32) + + c.HEAD_PATH = "dummy.joblib" + monkeypatch.setattr(c.joblib, "load", lambda _: DummyHeadInt(), raising=True) + + meta_path = tmp_path / "bad_meta.json" + meta_path.write_text(json.dumps({"class_order": ["only-one"]}), encoding="utf-8") + monkeypatch.setenv("HEAD_META", str(meta_path)) + monkeypatch.delenv("LABELS_CSV", raising=False) + + try: + c._load_head_once() + assert False, "Expected RuntimeError for meta length mismatch" + except RuntimeError as e: + assert "class_order length" in str(e) diff --git a/services/sounds/sounds_classifier/tests/test_cnn14.py b/services/sounds_classifier/tests/test_cnn14.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_cnn14.py rename to services/sounds_classifier/tests/test_cnn14.py index 742810377..44330e533 100644 --- a/services/sounds/sounds_classifier/tests/test_cnn14.py +++ b/services/sounds_classifier/tests/test_cnn14.py @@ -1,111 +1,111 @@ -import numpy as np -import types -import builtins -import pytest -import classification.backbones.cnn14 as cnn14 -from classification.backbones.cnn14 import run_cnn14_embeddings_batch -from classification.core.model_io import segment_waveform_2d_view, SAMPLE_RATE - -class DummyAT: - def __init__(self, checkpoint_path: str, device: str = "cpu"): - self.checkpoint_path = checkpoint_path - self.device = device - def inference(self, wav): - # Return (dummy_probs, dummy_embedding) - emb = np.ones((1, 2048), dtype=np.float32) - return (None, emb) - -class DummyATTuple: - def inference(self, x): - # x: (N, samples) - N = x.shape[0] - emb_dim = 8 - scores = np.zeros((N, 10), dtype=np.float32) # unused - embs = np.arange(N * emb_dim, dtype=np.float32).reshape(N, emb_dim) - return (scores, embs) - -class DummyATDict: - def inference(self, x): - N = x.shape[0] - emb_dim = 16 - return {"embedding": np.ones((N, emb_dim), dtype=np.float32)} - -def test_load_cnn14_model_uses_ensure_checkpoint(monkeypatch, tmp_path): - ckpt = tmp_path / "m.pth" - def fake_ensure(path, url): - ckpt.write_bytes(b"dummy") - return str(ckpt) - monkeypatch.setattr(cnn14, "ensure_checkpoint", fake_ensure) - # Mock panns_inference.AudioTagging - monkeypatch.setattr(cnn14, "AudioTagging", DummyAT, raising=True) - - m = cnn14.load_cnn14_model(checkpoint_path=str(ckpt), device="cpu") - assert isinstance(m, DummyAT) - assert m.checkpoint_path == str(ckpt) - -def test_run_cnn14_embedding_happy_path(monkeypatch): - dummy = DummyAT("x") - x = np.random.randn(32000).astype(np.float32) - e = cnn14.run_cnn14_embedding(dummy, x) - assert e.shape == (2048,) - -def test_run_cnn14_embedding_empty_raises(monkeypatch): - dummy = DummyAT("x") - with pytest.raises(ValueError): - cnn14.run_cnn14_embedding(dummy, np.array([], dtype=np.float32)) - - -def test_run_cnn14_embeddings_batch_tuple_output(monkeypatch): - model = DummyATTuple() - # Prevent flattening of (N, D) to (N*D,) - monkeypatch.setattr(cnn14, "_to_numpy", - lambda x: np.asarray(x, dtype=np.float32), - raising=True) - - windows = np.ones((5, 32000), dtype=np.float32) - E = run_cnn14_embeddings_batch(model, windows, batch_size=2) - assert E.shape == (5, 8) - assert np.allclose(E[1], np.arange(8, 16, dtype=np.float32)) - -def test_run_cnn14_embeddings_batch_dict_output(monkeypatch): - model = DummyATDict() - # Prevent flattening of (N, D) to (N*D,) - monkeypatch.setattr(cnn14, "_to_numpy", - lambda x: np.asarray(x, dtype=np.float32), - raising=True) - windows = np.ones((3, 16000), dtype=np.float32) - E = run_cnn14_embeddings_batch(model, windows, batch_size=32) - assert E.shape == (3, 16) - assert np.allclose(E, 1.0) - -def test_segment_waveform_2d_basic_exact_fit(): - sr = SAMPLE_RATE - win_s = 1.0 - hop_s = 0.75 - wav = np.ones(int(sr * 1.6), dtype=np.float32) - segs = segment_waveform_2d_view(wav, sr=sr, window_sec=win_s, hop_sec=hop_s, pad_last=False) - # Expect exactly one full window: [0..1.0] - assert segs.shape[0] == 1 - assert segs.shape[1] == int(sr * win_s) - # No padding is expected when pad_last=False - -def test_segment_waveform_2d_pad_last(): - sr = SAMPLE_RATE - win_s = 1.0 - hop_s = 0.75 - wav = np.ones(int(sr * 1.6), dtype=np.float32) - segs = segment_waveform_2d_view(wav, sr=sr, window_sec=win_s, hop_sec=hop_s, pad_last=True) - # Expect one full window + one padded tail window → total 2 - assert segs.shape[0] == 2 - assert segs.shape[1] == int(sr * win_s) - assert np.any(segs[-1] == 0.0) - -def test_segment_waveform_2d_empty_input_returns_empty_or_padded(): - sr = SAMPLE_RATE - segs_no_pad = segment_waveform_2d_view(np.zeros(0, dtype=np.float32), sr=sr, - window_sec=1.0, hop_sec=0.5, pad_last=False) - assert segs_no_pad.shape == (0, int(sr * 1.0)) - segs_pad = segment_waveform_2d_view(np.zeros(0, dtype=np.float32), sr=sr, - window_sec=1.0, hop_sec=0.5, pad_last=True) - # Current impl returns 0 windows even with pad_last for empty input - assert segs_pad.shape == (0, int(sr * 1.0)) +import numpy as np +import types +import builtins +import pytest +import classification.backbones.cnn14 as cnn14 +from classification.backbones.cnn14 import run_cnn14_embeddings_batch +from classification.core.model_io import segment_waveform_2d_view, SAMPLE_RATE + +class DummyAT: + def __init__(self, checkpoint_path: str, device: str = "cpu"): + self.checkpoint_path = checkpoint_path + self.device = device + def inference(self, wav): + # Return (dummy_probs, dummy_embedding) + emb = np.ones((1, 2048), dtype=np.float32) + return (None, emb) + +class DummyATTuple: + def inference(self, x): + # x: (N, samples) + N = x.shape[0] + emb_dim = 8 + scores = np.zeros((N, 10), dtype=np.float32) # unused + embs = np.arange(N * emb_dim, dtype=np.float32).reshape(N, emb_dim) + return (scores, embs) + +class DummyATDict: + def inference(self, x): + N = x.shape[0] + emb_dim = 16 + return {"embedding": np.ones((N, emb_dim), dtype=np.float32)} + +def test_load_cnn14_model_uses_ensure_checkpoint(monkeypatch, tmp_path): + ckpt = tmp_path / "m.pth" + def fake_ensure(path, url): + ckpt.write_bytes(b"dummy") + return str(ckpt) + monkeypatch.setattr(cnn14, "ensure_checkpoint", fake_ensure) + # Mock panns_inference.AudioTagging + monkeypatch.setattr(cnn14, "AudioTagging", DummyAT, raising=True) + + m = cnn14.load_cnn14_model(checkpoint_path=str(ckpt), device="cpu") + assert isinstance(m, DummyAT) + assert m.checkpoint_path == str(ckpt) + +def test_run_cnn14_embedding_happy_path(monkeypatch): + dummy = DummyAT("x") + x = np.random.randn(32000).astype(np.float32) + e = cnn14.run_cnn14_embedding(dummy, x) + assert e.shape == (2048,) + +def test_run_cnn14_embedding_empty_raises(monkeypatch): + dummy = DummyAT("x") + with pytest.raises(ValueError): + cnn14.run_cnn14_embedding(dummy, np.array([], dtype=np.float32)) + + +def test_run_cnn14_embeddings_batch_tuple_output(monkeypatch): + model = DummyATTuple() + # Prevent flattening of (N, D) to (N*D,) + monkeypatch.setattr(cnn14, "_to_numpy", + lambda x: np.asarray(x, dtype=np.float32), + raising=True) + + windows = np.ones((5, 32000), dtype=np.float32) + E = run_cnn14_embeddings_batch(model, windows, batch_size=2) + assert E.shape == (5, 8) + assert np.allclose(E[1], np.arange(8, 16, dtype=np.float32)) + +def test_run_cnn14_embeddings_batch_dict_output(monkeypatch): + model = DummyATDict() + # Prevent flattening of (N, D) to (N*D,) + monkeypatch.setattr(cnn14, "_to_numpy", + lambda x: np.asarray(x, dtype=np.float32), + raising=True) + windows = np.ones((3, 16000), dtype=np.float32) + E = run_cnn14_embeddings_batch(model, windows, batch_size=32) + assert E.shape == (3, 16) + assert np.allclose(E, 1.0) + +def test_segment_waveform_2d_basic_exact_fit(): + sr = SAMPLE_RATE + win_s = 1.0 + hop_s = 0.75 + wav = np.ones(int(sr * 1.6), dtype=np.float32) + segs = segment_waveform_2d_view(wav, sr=sr, window_sec=win_s, hop_sec=hop_s, pad_last=False) + # Expect exactly one full window: [0..1.0] + assert segs.shape[0] == 1 + assert segs.shape[1] == int(sr * win_s) + # No padding is expected when pad_last=False + +def test_segment_waveform_2d_pad_last(): + sr = SAMPLE_RATE + win_s = 1.0 + hop_s = 0.75 + wav = np.ones(int(sr * 1.6), dtype=np.float32) + segs = segment_waveform_2d_view(wav, sr=sr, window_sec=win_s, hop_sec=hop_s, pad_last=True) + # Expect one full window + one padded tail window → total 2 + assert segs.shape[0] == 2 + assert segs.shape[1] == int(sr * win_s) + assert np.any(segs[-1] == 0.0) + +def test_segment_waveform_2d_empty_input_returns_empty_or_padded(): + sr = SAMPLE_RATE + segs_no_pad = segment_waveform_2d_view(np.zeros(0, dtype=np.float32), sr=sr, + window_sec=1.0, hop_sec=0.5, pad_last=False) + assert segs_no_pad.shape == (0, int(sr * 1.0)) + segs_pad = segment_waveform_2d_view(np.zeros(0, dtype=np.float32), sr=sr, + window_sec=1.0, hop_sec=0.5, pad_last=True) + # Current impl returns 0 windows even with pad_last for empty input + assert segs_pad.shape == (0, int(sr * 1.0)) diff --git a/services/sounds/sounds_classifier/tests/test_db_io_pg.py b/services/sounds_classifier/tests/test_db_io_pg.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_db_io_pg.py rename to services/sounds_classifier/tests/test_db_io_pg.py index 528d86c32..235f5a5e5 100644 --- a/services/sounds/sounds_classifier/tests/test_db_io_pg.py +++ b/services/sounds_classifier/tests/test_db_io_pg.py @@ -1,167 +1,167 @@ -import json -import types -import pytest -import classification.core.db_io_pg as dbpg - - -# ------------------------- -# Dummy connection helpers -# ------------------------- - -class DummyCursor: - def __init__(self, rec=None, raise_on_execute=False): - self.queries = [] - self._rec = rec - self._raise = raise_on_execute - def execute(self, q, p=None): - if self._raise: - raise RuntimeError("boom-exec") - self.queries.append((q, p)) - def fetchone(self): - return (123,) - def __enter__(self): - return self - def __exit__(self, *a): - return False - -class DummyConn: - def __init__(self, raise_on_execute=False): - self.cursors = [] - self.autocommit = False - self._commits = 0 - self._rollbacks = 0 - self._raise = raise_on_execute - def cursor(self): - c = DummyCursor(raise_on_execute=self._raise) - self.cursors.append(c) - return c - def commit(self): - self._commits += 1 - def rollback(self): - self._rollbacks += 1 - - -# ------------------------- -# open_db -# ------------------------- - -def test_open_db_validates_and_initializes_schema(monkeypatch): - # make psycopg2.connect return our dummy connection - monkeypatch.setattr(dbpg.psycopg2, "connect", lambda url: DummyConn()) - conn = dbpg.open_db("postgresql://u:p@h:5432/db", schema="audio_cls") - assert isinstance(conn, DummyConn) - # schema init should have committed once - assert conn._commits >= 1 - -def test_open_db_rejects_bad_schema(): - with pytest.raises(ValueError): - dbpg.open_db("postgresql://u:p@h:5432/db", schema="bad-dash") - -def test_open_db_rollback_on_failure(monkeypatch): - # first cursor.execute will raise - monkeypatch.setattr(dbpg.psycopg2, "connect", lambda url: DummyConn(raise_on_execute=True)) - with pytest.raises(RuntimeError): - dbpg.open_db("postgresql://u:p@h:5432/db", schema="audio_cls") - - -# ------------------------- -# upsert_run -# ------------------------- - -def test_upsert_run_success(monkeypatch): - conn = DummyConn() - dbpg.upsert_run(conn, { - "run_id": "r1", "model_name": "CNN14", "checkpoint": "ckpt", - "head_path": "h", "labels_csv": "l", "window_sec": 2.0, "hop_sec": 0.5, - "pad_last": True, "agg": "mean", "topk": 10, "device": "cpu", - "code_version": "v", "notes": "n" - }) - assert conn._commits == 1 - assert conn._rollbacks == 0 - # Ensure at least one execute has been issued - assert conn.cursors and conn.cursors[0].queries - -def test_upsert_run_rollback_on_exception(monkeypatch): - conn = DummyConn(raise_on_execute=True) - with pytest.raises(RuntimeError): - dbpg.upsert_run(conn, { - "run_id":"r1","model_name":"CNN14","checkpoint":"ckpt","head_path":"h","labels_csv":"l", - "window_sec":2.0,"hop_sec":0.5,"pad_last":True,"agg":"mean","topk":10,"device":"cpu", - "code_version":"v","notes":"n" - }) - assert conn._rollbacks == 1 - assert conn._commits == 0 - - -# ------------------------- -# finish_run -# ------------------------- - -def test_finish_run_success(): - conn = DummyConn() - dbpg.finish_run(conn, "r1") - assert conn._commits == 1 - assert conn._rollbacks == 0 - assert conn.cursors[0].queries # UPDATE executed - -def test_finish_run_rollback_on_exception(): - conn = DummyConn(raise_on_execute=True) - with pytest.raises(RuntimeError): - dbpg.finish_run(conn, "r1") - assert conn._rollbacks == 1 - assert conn._commits == 0 - -def test__jsonify_variants(monkeypatch): - # Wrap psycopg2.extras.Json to observe value passed in - captured = {"value": None} - def fake_json(v): - captured["value"] = v - return ("JsonWrapped", v) - - monkeypatch.setattr(dbpg.psycopg2.extras, "Json", fake_json, raising=True) - - # string with valid JSON → parsed dict - j = dbpg._jsonify('{"a":1}') - assert j == ("JsonWrapped", {"a": 1}) - assert captured["value"] == {"a": 1} - - # plain string → {"raw": "..."} - j2 = dbpg._jsonify("hello") - assert j2 == ("JsonWrapped", {"raw": "hello"}) - assert captured["value"] == {"raw": "hello"} - - # dict passes through - j3 = dbpg._jsonify({"k": 3}) - assert j3 == ("JsonWrapped", {"k": 3}) - assert captured["value"] == {"k": 3} - -def test_upsert_file_aggregate_success(monkeypatch): - # Make Json a pass-through so psycopg2.extras.Json(v) -> v - monkeypatch.setattr(dbpg.psycopg2.extras, "Json", lambda x: x, raising=True) - - conn = DummyConn() - dbpg.upsert_file_aggregate(conn, { - "run_id":"r1","file_id":123, - "head_probs_json":{"car":0.9},"head_pred_label":"car","head_pred_prob":0.9, - "head_unknown_threshold":0.55,"head_is_another":False, - "num_windows":3,"agg_mode":"mean","processing_ms":123 - }) - assert conn._commits == 1 - assert conn._rollbacks == 0 - -def test_upsert_file_aggregate_accepts_string_json_and_rollback_on_exception(monkeypatch): - # Json wrapper - monkeypatch.setattr(dbpg.psycopg2.extras, "Json", lambda x: x, raising=True) - - # connection that will fail during execute - conn = DummyConn(raise_on_execute=True) - with pytest.raises(RuntimeError): - dbpg.upsert_file_aggregate(conn, { - "run_id":"r1","file_id":123, - "head_probs_json":'{"car":0.9}', # string json - "head_pred_label":"car","head_pred_prob":0.9, - "head_unknown_threshold":0.55,"head_is_another":False, - "num_windows":3,"agg_mode":"mean","processing_ms":123 - }) - assert conn._rollbacks == 1 - assert conn._commits == 0 +import json +import types +import pytest +import classification.core.db_io_pg as dbpg + + +# ------------------------- +# Dummy connection helpers +# ------------------------- + +class DummyCursor: + def __init__(self, rec=None, raise_on_execute=False): + self.queries = [] + self._rec = rec + self._raise = raise_on_execute + def execute(self, q, p=None): + if self._raise: + raise RuntimeError("boom-exec") + self.queries.append((q, p)) + def fetchone(self): + return (123,) + def __enter__(self): + return self + def __exit__(self, *a): + return False + +class DummyConn: + def __init__(self, raise_on_execute=False): + self.cursors = [] + self.autocommit = False + self._commits = 0 + self._rollbacks = 0 + self._raise = raise_on_execute + def cursor(self): + c = DummyCursor(raise_on_execute=self._raise) + self.cursors.append(c) + return c + def commit(self): + self._commits += 1 + def rollback(self): + self._rollbacks += 1 + + +# ------------------------- +# open_db +# ------------------------- + +def test_open_db_validates_and_initializes_schema(monkeypatch): + # make psycopg2.connect return our dummy connection + monkeypatch.setattr(dbpg.psycopg2, "connect", lambda url: DummyConn()) + conn = dbpg.open_db("postgresql://u:p@h:5432/db", schema="audio_cls") + assert isinstance(conn, DummyConn) + # schema init should have committed once + assert conn._commits >= 1 + +def test_open_db_rejects_bad_schema(): + with pytest.raises(ValueError): + dbpg.open_db("postgresql://u:p@h:5432/db", schema="bad-dash") + +def test_open_db_rollback_on_failure(monkeypatch): + # first cursor.execute will raise + monkeypatch.setattr(dbpg.psycopg2, "connect", lambda url: DummyConn(raise_on_execute=True)) + with pytest.raises(RuntimeError): + dbpg.open_db("postgresql://u:p@h:5432/db", schema="audio_cls") + + +# ------------------------- +# upsert_run +# ------------------------- + +def test_upsert_run_success(monkeypatch): + conn = DummyConn() + dbpg.upsert_run(conn, { + "run_id": "r1", "model_name": "CNN14", "checkpoint": "ckpt", + "head_path": "h", "labels_csv": "l", "window_sec": 2.0, "hop_sec": 0.5, + "pad_last": True, "agg": "mean", "topk": 10, "device": "cpu", + "code_version": "v", "notes": "n" + }) + assert conn._commits == 1 + assert conn._rollbacks == 0 + # Ensure at least one execute has been issued + assert conn.cursors and conn.cursors[0].queries + +def test_upsert_run_rollback_on_exception(monkeypatch): + conn = DummyConn(raise_on_execute=True) + with pytest.raises(RuntimeError): + dbpg.upsert_run(conn, { + "run_id":"r1","model_name":"CNN14","checkpoint":"ckpt","head_path":"h","labels_csv":"l", + "window_sec":2.0,"hop_sec":0.5,"pad_last":True,"agg":"mean","topk":10,"device":"cpu", + "code_version":"v","notes":"n" + }) + assert conn._rollbacks == 1 + assert conn._commits == 0 + + +# ------------------------- +# finish_run +# ------------------------- + +def test_finish_run_success(): + conn = DummyConn() + dbpg.finish_run(conn, "r1") + assert conn._commits == 1 + assert conn._rollbacks == 0 + assert conn.cursors[0].queries # UPDATE executed + +def test_finish_run_rollback_on_exception(): + conn = DummyConn(raise_on_execute=True) + with pytest.raises(RuntimeError): + dbpg.finish_run(conn, "r1") + assert conn._rollbacks == 1 + assert conn._commits == 0 + +def test__jsonify_variants(monkeypatch): + # Wrap psycopg2.extras.Json to observe value passed in + captured = {"value": None} + def fake_json(v): + captured["value"] = v + return ("JsonWrapped", v) + + monkeypatch.setattr(dbpg.psycopg2.extras, "Json", fake_json, raising=True) + + # string with valid JSON → parsed dict + j = dbpg._jsonify('{"a":1}') + assert j == ("JsonWrapped", {"a": 1}) + assert captured["value"] == {"a": 1} + + # plain string → {"raw": "..."} + j2 = dbpg._jsonify("hello") + assert j2 == ("JsonWrapped", {"raw": "hello"}) + assert captured["value"] == {"raw": "hello"} + + # dict passes through + j3 = dbpg._jsonify({"k": 3}) + assert j3 == ("JsonWrapped", {"k": 3}) + assert captured["value"] == {"k": 3} + +def test_upsert_file_aggregate_success(monkeypatch): + # Make Json a pass-through so psycopg2.extras.Json(v) -> v + monkeypatch.setattr(dbpg.psycopg2.extras, "Json", lambda x: x, raising=True) + + conn = DummyConn() + dbpg.upsert_file_aggregate(conn, { + "run_id":"r1","file_id":123, + "head_probs_json":{"car":0.9},"head_pred_label":"car","head_pred_prob":0.9, + "head_unknown_threshold":0.55,"head_is_another":False, + "num_windows":3,"agg_mode":"mean","processing_ms":123 + }) + assert conn._commits == 1 + assert conn._rollbacks == 0 + +def test_upsert_file_aggregate_accepts_string_json_and_rollback_on_exception(monkeypatch): + # Json wrapper + monkeypatch.setattr(dbpg.psycopg2.extras, "Json", lambda x: x, raising=True) + + # connection that will fail during execute + conn = DummyConn(raise_on_execute=True) + with pytest.raises(RuntimeError): + dbpg.upsert_file_aggregate(conn, { + "run_id":"r1","file_id":123, + "head_probs_json":'{"car":0.9}', # string json + "head_pred_label":"car","head_pred_prob":0.9, + "head_unknown_threshold":0.55,"head_is_another":False, + "num_windows":3,"agg_mode":"mean","processing_ms":123 + }) + assert conn._rollbacks == 1 + assert conn._commits == 0 diff --git a/services/sounds/sounds_classifier/tests/test_db_utils.py b/services/sounds_classifier/tests/test_db_utils.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_db_utils.py rename to services/sounds_classifier/tests/test_db_utils.py index 02b7f0e06..bbe6e8f43 100644 --- a/services/sounds/sounds_classifier/tests/test_db_utils.py +++ b/services/sounds_classifier/tests/test_db_utils.py @@ -1,165 +1,165 @@ -import os -import re -import pytest - -import classification.core.db_utils as dbu - -# ----------------------------- -# Fake psycopg2 connection/cursor -# ----------------------------- -class FakeCursor: - def __init__(self, script_recorder): - self.script_recorder = script_recorder - self._fetchone = None # single value returned by fetchone() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def execute(self, query, params=None): - # Record statements and params for later assertions - self.script_recorder.append(("EXEC", str(query), params)) - - def fetchone(self): - return self._fetchone - - # helper for tests to set what fetchone should return - def set_fetchone(self, value): - self._fetchone = value - - -class FakeConn: - def __init__(self): - self.autocommit = False - self.closed = False - self._script = [] - self._cursor = FakeCursor(self._script) - - def cursor(self): - return self._cursor - - def commit(self): - self._script.append(("COMMIT", None, None)) - - def rollback(self): - self._script.append(("ROLLBACK", None, None)) - - def close(self): - self.closed = True - - # test helper - def script(self): - return list(self._script) - - -# ----------------------------- -# Fixture -# ----------------------------- -@pytest.fixture -def fake_conn(monkeypatch): - """ - Patch psycopg2.connect to return our FakeConn. - Also ensure env vars exist with harmless defaults. - """ - fc = FakeConn() - - def fake_connect(**kwargs): - return fc - - # Patch psycopg2.connect inside db_utils module - monkeypatch.setattr(dbu.psycopg2, "connect", fake_connect) - - # Minimal env for db_utils.open_db() - monkeypatch.setenv("DB_HOST", "postgres") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_NAME", "missions_db") - monkeypatch.setenv("DB_USER", "missions_user") - monkeypatch.setenv("DB_PASSWORD", "pg123") - monkeypatch.setenv("DB_SCHEMA", "agcloud_audio") - - return fc - - -# ----------------------------- -# Tests for open_db() -# ----------------------------- -def test_open_db_sets_search_path(fake_conn): - conn = dbu.open_db() - assert conn is fake_conn - # We expect one SET search_path statement with our schema - script = conn.script() - execs = [s for s in script if s[0] == "EXEC"] - assert any( - ("SET search_path TO" in q) and ("agcloud_audio" in q) - for _, q, _ in execs - ), f"Expected SET search_path to agcloud_audio, got: {execs}" - # autocommit should remain False - assert conn.autocommit is False - - -# ----------------------------- -# Tests for ensure_run() -# ----------------------------- -def test_ensure_run_inserts_and_commits(fake_conn, monkeypatch): - # Provide env to fill NOT NULL columns - monkeypatch.setenv("MODEL_NAME", "panns_cnn14") - monkeypatch.setenv("CHECKPOINT", "ckpt.pth") - monkeypatch.setenv("HEAD", "/tmp/head.joblib") - monkeypatch.setenv("LABELS_CSV", "") - monkeypatch.setenv("WINDOW_SEC", "10") - monkeypatch.setenv("HOP_SEC", "10") - monkeypatch.setenv("PAD_LAST", "true") - monkeypatch.setenv("AGG", "mean") - monkeypatch.setenv("TOPK", "3") - monkeypatch.setenv("DEVICE", "cpu") - monkeypatch.setenv("CODE_VERSION", "test") - monkeypatch.setenv("RUN_NOTES", "unit-test") - - dbu.ensure_run(fake_conn, run_id="run-123") - # We expect one INSERT and one COMMIT - script = fake_conn.script() - insert_calls = [s for s in script if s[0] == "EXEC" and "INSERT INTO runs" in s[1]] - assert len(insert_calls) == 1, f"expected single INSERT, got: {insert_calls}" - assert ("COMMIT", None, None) in script - - -# ----------------------------- -# Tests for resolve_file_id() -# ----------------------------- -def test_resolve_file_id_by_file_id_ok(fake_conn): - # Simulate existing file_id - fake_conn._cursor.set_fetchone((42,)) - file_id = dbu.resolve_file_id(fake_conn, file_id=42) - assert file_id == 42 - # Verify the WHERE clause by id appeared - execs = [s for s in fake_conn.script() if s[0] == "EXEC"] - assert any("WHERE file_id = %s" in q for _, q, _ in execs) - -def test_resolve_file_id_by_file_id_not_found(fake_conn): - fake_conn._cursor.set_fetchone(None) - with pytest.raises(ValueError) as ex: - dbu.resolve_file_id(fake_conn, file_id=999) - assert "file_id 999 not found" in str(ex.value) - -def test_resolve_file_id_by_bucket_key_ok(fake_conn): - # Simulate a row found by (bucket, object_key) - fake_conn._cursor.set_fetchone((321,)) - out = dbu.resolve_file_id(fake_conn, bucket="b", object_key="k") - assert out == 321 - execs = [s for s in fake_conn.script() if s[0] == "EXEC"] - assert any("WHERE bucket = %s AND object_key = %s" in q for _, q, _ in execs) - -def test_resolve_file_id_by_bucket_key_not_found(fake_conn): - fake_conn._cursor.set_fetchone(None) - with pytest.raises(ValueError) as ex: - dbu.resolve_file_id(fake_conn, bucket="b", object_key="k") - msg = str(ex.value) - # Be flexible about formatting: just assert key parts of the message exist - assert "not found in public.files" in msg - assert "s3://b/k" in msg or ("bucket" in msg and "object_key" in msg and "b" in msg and "k" in msg) - -def test_resolve_file_id_requires_params(fake_conn): - with pytest.raises(ValueError): - dbu.resolve_file_id(fake_conn) # neither file_id nor (bucket,key) +import os +import re +import pytest + +import classification.core.db_utils as dbu + +# ----------------------------- +# Fake psycopg2 connection/cursor +# ----------------------------- +class FakeCursor: + def __init__(self, script_recorder): + self.script_recorder = script_recorder + self._fetchone = None # single value returned by fetchone() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, query, params=None): + # Record statements and params for later assertions + self.script_recorder.append(("EXEC", str(query), params)) + + def fetchone(self): + return self._fetchone + + # helper for tests to set what fetchone should return + def set_fetchone(self, value): + self._fetchone = value + + +class FakeConn: + def __init__(self): + self.autocommit = False + self.closed = False + self._script = [] + self._cursor = FakeCursor(self._script) + + def cursor(self): + return self._cursor + + def commit(self): + self._script.append(("COMMIT", None, None)) + + def rollback(self): + self._script.append(("ROLLBACK", None, None)) + + def close(self): + self.closed = True + + # test helper + def script(self): + return list(self._script) + + +# ----------------------------- +# Fixture +# ----------------------------- +@pytest.fixture +def fake_conn(monkeypatch): + """ + Patch psycopg2.connect to return our FakeConn. + Also ensure env vars exist with harmless defaults. + """ + fc = FakeConn() + + def fake_connect(**kwargs): + return fc + + # Patch psycopg2.connect inside db_utils module + monkeypatch.setattr(dbu.psycopg2, "connect", fake_connect) + + # Minimal env for db_utils.open_db() + monkeypatch.setenv("DB_HOST", "postgres") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_NAME", "missions_db") + monkeypatch.setenv("DB_USER", "missions_user") + monkeypatch.setenv("DB_PASSWORD", "pg123") + monkeypatch.setenv("DB_SCHEMA", "agcloud_audio") + + return fc + + +# ----------------------------- +# Tests for open_db() +# ----------------------------- +def test_open_db_sets_search_path(fake_conn): + conn = dbu.open_db() + assert conn is fake_conn + # We expect one SET search_path statement with our schema + script = conn.script() + execs = [s for s in script if s[0] == "EXEC"] + assert any( + ("SET search_path TO" in q) and ("agcloud_audio" in q) + for _, q, _ in execs + ), f"Expected SET search_path to agcloud_audio, got: {execs}" + # autocommit should remain False + assert conn.autocommit is False + + +# ----------------------------- +# Tests for ensure_run() +# ----------------------------- +def test_ensure_run_inserts_and_commits(fake_conn, monkeypatch): + # Provide env to fill NOT NULL columns + monkeypatch.setenv("MODEL_NAME", "panns_cnn14") + monkeypatch.setenv("CHECKPOINT", "ckpt.pth") + monkeypatch.setenv("HEAD", "/tmp/head.joblib") + monkeypatch.setenv("LABELS_CSV", "") + monkeypatch.setenv("WINDOW_SEC", "10") + monkeypatch.setenv("HOP_SEC", "10") + monkeypatch.setenv("PAD_LAST", "true") + monkeypatch.setenv("AGG", "mean") + monkeypatch.setenv("TOPK", "3") + monkeypatch.setenv("DEVICE", "cpu") + monkeypatch.setenv("CODE_VERSION", "test") + monkeypatch.setenv("RUN_NOTES", "unit-test") + + dbu.ensure_run(fake_conn, run_id="run-123") + # We expect one INSERT and one COMMIT + script = fake_conn.script() + insert_calls = [s for s in script if s[0] == "EXEC" and "INSERT INTO runs" in s[1]] + assert len(insert_calls) == 1, f"expected single INSERT, got: {insert_calls}" + assert ("COMMIT", None, None) in script + + +# ----------------------------- +# Tests for resolve_file_id() +# ----------------------------- +def test_resolve_file_id_by_file_id_ok(fake_conn): + # Simulate existing file_id + fake_conn._cursor.set_fetchone((42,)) + file_id = dbu.resolve_file_id(fake_conn, file_id=42) + assert file_id == 42 + # Verify the WHERE clause by id appeared + execs = [s for s in fake_conn.script() if s[0] == "EXEC"] + assert any("WHERE file_id = %s" in q for _, q, _ in execs) + +def test_resolve_file_id_by_file_id_not_found(fake_conn): + fake_conn._cursor.set_fetchone(None) + with pytest.raises(ValueError) as ex: + dbu.resolve_file_id(fake_conn, file_id=999) + assert "file_id 999 not found" in str(ex.value) + +def test_resolve_file_id_by_bucket_key_ok(fake_conn): + # Simulate a row found by (bucket, object_key) + fake_conn._cursor.set_fetchone((321,)) + out = dbu.resolve_file_id(fake_conn, bucket="b", object_key="k") + assert out == 321 + execs = [s for s in fake_conn.script() if s[0] == "EXEC"] + assert any("WHERE bucket = %s AND object_key = %s" in q for _, q, _ in execs) + +def test_resolve_file_id_by_bucket_key_not_found(fake_conn): + fake_conn._cursor.set_fetchone(None) + with pytest.raises(ValueError) as ex: + dbu.resolve_file_id(fake_conn, bucket="b", object_key="k") + msg = str(ex.value) + # Be flexible about formatting: just assert key parts of the message exist + assert "not found in public.files" in msg + assert "s3://b/k" in msg or ("bucket" in msg and "object_key" in msg and "b" in msg and "k" in msg) + +def test_resolve_file_id_requires_params(fake_conn): + with pytest.raises(ValueError): + dbu.resolve_file_id(fake_conn) # neither file_id nor (bucket,key) diff --git a/services/sounds/sounds_classifier/tests/test_model_io.py b/services/sounds_classifier/tests/test_model_io.py similarity index 97% rename from services/sounds/sounds_classifier/tests/test_model_io.py rename to services/sounds_classifier/tests/test_model_io.py index c635ffd19..4634a9804 100644 --- a/services/sounds/sounds_classifier/tests/test_model_io.py +++ b/services/sounds_classifier/tests/test_model_io.py @@ -1,258 +1,258 @@ -import numpy as np -from pathlib import Path -import types -import pytest -import soundfile as sf - -import classification.core.model_io as mio -from classification.core.model_io import ( - aggregate_matrix, - segment_waveform, - load_audio, - SAMPLE_RATE, - ensure_checkpoint, - ensure_numpy_1d, - _to_numpy, -) - -# ---------- aggregate_matrix: happy paths and error branches ---------- - -def test_aggregate_matrix_mean_and_max_non_empty(): - X = np.array([[0.1, 0.9], [0.7, 0.3], [0.2, 0.8]], dtype=np.float32) - m = aggregate_matrix(X, mode="mean") - M = aggregate_matrix(X, mode="max") - assert np.allclose(m, X.mean(0)) - assert np.allclose(M, X.max(0)) - -def test_aggregate_matrix_empty_raises(): - X = np.zeros((0, 2), dtype=np.float32) - with pytest.raises(ValueError) as e: - aggregate_matrix(X, mode="mean") - assert "empty window matrix" in str(e.value) - -def test_aggregate_matrix_zero_classes_raises(): - X = np.zeros((3, 0), dtype=np.float32) - with pytest.raises(ValueError) as e: - aggregate_matrix(X, mode="mean") - assert "num_classes > 0" in str(e.value) - -def test_aggregate_matrix_wrong_type_and_ndim(): - with pytest.raises(TypeError): - aggregate_matrix([[1, 2], [3, 4]], mode="mean") # not np.ndarray - with pytest.raises(ValueError): - aggregate_matrix(np.array([1, 2, 3], dtype=np.float32), mode="mean") # 1D - with pytest.raises(ValueError): - aggregate_matrix(np.zeros((2, 2, 2), dtype=np.float32), mode="mean") # 3D - -def test_aggregate_matrix_unsupported_mode(): - X = np.ones((2, 2), dtype=np.float32) - with pytest.raises(ValueError): - aggregate_matrix(X, mode="median") - -def test_aggregate_matrix_nan_and_infs_are_handled(): - X = np.array([[np.nan, 1.0, np.inf, -np.inf], - [2.0, np.nan, 3.0, -5.0 ]], dtype=np.float32) - out_mean = aggregate_matrix(X, mode="mean") - out_max = aggregate_matrix(X, mode="max") - assert out_mean.dtype == np.float32 - assert out_max.dtype == np.float32 - # NaNs should be treated as missing, and inf/-inf should be clamped via nan_to_num. - assert np.isfinite(out_mean).all() - assert np.isfinite(out_max).all() - -# ---------- ensure_numpy_1d & _to_numpy ---------- - -def test_ensure_numpy_1d_duck_typed_torch_like(): - class DuckTensor: - def __init__(self, arr): self._arr = np.asarray(arr, dtype=np.float32) - def detach(self): return self - def cpu(self): return self - def numpy(self): return self._arr - x = DuckTensor([[1.0], [2.0], [3.0]]) - y = ensure_numpy_1d(x) - assert isinstance(y, np.ndarray) - assert y.shape == (3,) - assert y.dtype == np.float32 - -def test_ensure_numpy_1d_object_with_numpy_method_only(): - class OnlyNumpy: - def __init__(self, arr): self._arr = np.asarray(arr, dtype=np.float32) - def numpy(self): return self._arr - x = OnlyNumpy([[1.0, 2.0, 3.0]]) - y = ensure_numpy_1d(x) - assert y.shape == (3,) - assert y.dtype == np.float32 - -def test__to_numpy_handles_2d_shapes_and_casts_to_float32(): - x = np.array([[1.0], [2.0], [3.0]], dtype=np.float64) - y = _to_numpy(x) - assert y.dtype == np.float32 - assert y.shape == (3,) - -@pytest.mark.skipif(mio.torch is None, reason="torch is not available") -def test__to_numpy_with_real_torch_tensor_when_available(): - import torch - x = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) - try: - y = _to_numpy(x) - except RuntimeError as e: - # Some builds of torch raise when calling .numpy() if NumPy is mismatched/not present. - assert "Numpy is not available" in str(e) - else: - assert isinstance(y, np.ndarray) - assert y.dtype == np.float32 - assert y.shape == (3,) - -# ---------- segment_waveform edge cases ---------- - -def test_segment_waveform_empty_and_padding_logic(): - # empty wav returns [] - assert segment_waveform(np.array([], dtype=np.float32), sr=SAMPLE_RATE) == [] - - # shorter than one window, pad_last=True adds one padded segment - short = np.ones(100, dtype=np.float32) - segs = segment_waveform(short, sr=SAMPLE_RATE, window_sec=0.01, hop_sec=0.005, pad_last=True) - assert len(segs) >= 1 - for s in segs: - assert s.ndim == 1 and s.dtype == np.float32 - - # shorter than one window, pad_last=False should not add segment - segs2 = segment_waveform(short, sr=SAMPLE_RATE, window_sec=0.01, hop_sec=0.005, pad_last=False) - # Depending on rounding, there may be 0 or >0 windows; enforce that no tail padding is added: - # if the first while loop didn't fit any full window, with pad_last=False we expect 0 segments. - assert len(segs2) in (0,) - -def test_segment_waveform_overlap_and_types(): - wav = np.arange(1000, dtype=np.int16) # non-float input should be casted - segs = segment_waveform(wav, sr=1000, window_sec=0.1, hop_sec=0.05, pad_last=True) - assert len(segs) >= 1 - assert all(s.dtype == np.float32 for s in segs) - -# ---------- load_audio: WAV roundtrip & padding ---------- - -def _sine(sr: int, seconds: float) -> np.ndarray: - t = np.linspace(0, seconds, int(sr * seconds), endpoint=False, dtype=np.float32) - return np.sin(2 * np.pi * 440.0 * t).astype(np.float32) - -def test_load_audio_roundtrip_and_padding(tmp_path: Path): - wav = _sine(SAMPLE_RATE, 0.25) # shorter than MIN_SAMPLES (16000) - p = tmp_path / "a.wav" - sf.write(p, wav, samplerate=SAMPLE_RATE, subtype="PCM_16") - out = load_audio(str(p), SAMPLE_RATE) - assert out.dtype == np.float32 - assert out.ndim == 1 - assert out.size >= mio.MIN_SAMPLES # padded - -def test_load_audio_resample_path(monkeypatch, tmp_path: Path): - # Write at sr=16000; target is 32000. Force librosa.resample path. - wav16 = _sine(16000, 0.2).astype(np.float32) - p = tmp_path / "b.wav" - sf.write(p, wav16, samplerate=16000, subtype="PCM_16") - - called = {"resample": False} - def fake_resample(y, orig_sr, target_sr): - called["resample"] = True - # return float64 to test cast back to float32 - return np.asarray(np.repeat(y, 2), dtype=np.float64) - - monkeypatch.setattr(mio.librosa, "resample", fake_resample, raising=True) - out = load_audio(str(p), SAMPLE_RATE) - assert called["resample"] is True - assert out.dtype == np.float32 - assert out.ndim == 1 - -# ---------- load_audio: HARD_EXTS (mp3/m4a/...) branch + ffmpeg fallback ---------- - -def test_load_audio_hard_ext_uses_librosa_success(monkeypatch, tmp_path: Path): - # Fake an mp3 file; we'll mock librosa.load to return data, so the bytes don't matter. - p = tmp_path / "x.mp3" - p.write_bytes(b"ID3") - - def fake_librosa_load(path, sr, mono): - assert str(path) == str(p) - return np.ones(100, dtype=np.float32), sr - - monkeypatch.setattr(mio.librosa, "load", fake_librosa_load, raising=True) - out = load_audio(str(p), SAMPLE_RATE) - assert out.dtype == np.float32 - assert out.ndim == 1 - assert out.size >= mio.MIN_SAMPLES - -def test_load_audio_hard_ext_librosa_fail_ffmpeg_fallback(monkeypatch, tmp_path: Path): - p = tmp_path / "y.m4a" - p.write_bytes(b"\x00\x00\x00\x20ftypM4A ") # header-ish; not used - - def fake_librosa_fail(*a, **k): raise RuntimeError("boom") - monkeypatch.setattr(mio.librosa, "load", fake_librosa_fail, raising=True) - monkeypatch.setattr(mio, "has_ffmpeg", lambda: True, raising=True) - # make ffmpeg path return short buffer to test padding - monkeypatch.setattr(mio, "decode_with_ffmpeg_to_float32_mono", - lambda path, target_sr: np.ones(10, dtype=np.float32), raising=True) - out = load_audio(str(p), SAMPLE_RATE) - assert out.size >= mio.MIN_SAMPLES - assert out.dtype == np.float32 - -def test_load_audio_all_fail_then_ffmpeg(monkeypatch, tmp_path: Path): - # For non-HARD ext: raise in soundfile.read, then raise in librosa.load, then use ffmpeg - p = tmp_path / "z.ogg" - p.write_bytes(b"OggS...") - - def fake_sf_read(*a, **k): raise RuntimeError("sf fail") - def fake_librosa_load(*a, **k): raise RuntimeError("librosa fail") - - monkeypatch.setattr(sf, "read", fake_sf_read, raising=True) - monkeypatch.setattr(mio.librosa, "load", fake_librosa_load, raising=True) - monkeypatch.setattr(mio, "has_ffmpeg", lambda: True, raising=True) - monkeypatch.setattr(mio, "decode_with_ffmpeg_to_float32_mono", - lambda path, target_sr: np.ones(33, dtype=np.float32), raising=True) - out = load_audio(str(p), SAMPLE_RATE) - assert out.ndim == 1 and out.dtype == np.float32 - -# ---------- ffmpeg helpers ---------- - -def test_has_ffmpeg_uses_shutil_which(monkeypatch): - # Ensure function depends on shutil.which - monkeypatch.setattr(mio.shutil, "which", lambda _: "/usr/bin/ffmpeg") - assert mio.has_ffmpeg() is True - monkeypatch.setattr(mio.shutil, "which", lambda _: None) - assert mio.has_ffmpeg() is False - -def test_decode_with_ffmpeg_to_float32_mono_monkeypatched(monkeypatch, tmp_path: Path): - # We won't call real ffmpeg; we mock subprocess.run to return a small f32 buffer. - p = tmp_path / "q.ogg" - p.write_bytes(b"OggS...") - - class DummyProc: - def __init__(self, data): self.stdout = data; self.stderr = b"" - # 4 float32 numbers -> 16 bytes; MIN_SAMPLES will trigger padding. - raw = (np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)).tobytes() - - monkeypatch.setattr( - mio.subprocess, "run", - lambda cmd, stdout, stderr, check: DummyProc(raw), - raising=True - ) - - out = mio.decode_with_ffmpeg_to_float32_mono(str(p), mio.SAMPLE_RATE) - assert out.dtype == np.float32 - assert out.ndim == 1 - assert out.size >= mio.MIN_SAMPLES - -# ---------- ensure_checkpoint (patch urllib at global module level) ---------- - -def test_ensure_checkpoint_downloads_when_missing(monkeypatch, tmp_path): - target = tmp_path / "models" / "panns_data" / "m.pth" - called = {"ok": False} - def fake_urlretrieve(url, dst): - called["ok"] = True - Path(dst).write_bytes(b"ok") - monkeypatch.setattr("urllib.request.urlretrieve", fake_urlretrieve, raising=True) - path = ensure_checkpoint(str(target), "http://example.com/m.pth") - assert called["ok"] is True - assert Path(path).exists() - -def test_ensure_numpy_and_to_numpy_helpers(): - x = np.array([[1.0], [2.0], [3.0]], dtype=np.float32) - assert ensure_numpy_1d(x).shape == (3,) - y = _to_numpy(x.reshape(1, -1)) - assert y.ndim == 1 +import numpy as np +from pathlib import Path +import types +import pytest +import soundfile as sf + +import classification.core.model_io as mio +from classification.core.model_io import ( + aggregate_matrix, + segment_waveform, + load_audio, + SAMPLE_RATE, + ensure_checkpoint, + ensure_numpy_1d, + _to_numpy, +) + +# ---------- aggregate_matrix: happy paths and error branches ---------- + +def test_aggregate_matrix_mean_and_max_non_empty(): + X = np.array([[0.1, 0.9], [0.7, 0.3], [0.2, 0.8]], dtype=np.float32) + m = aggregate_matrix(X, mode="mean") + M = aggregate_matrix(X, mode="max") + assert np.allclose(m, X.mean(0)) + assert np.allclose(M, X.max(0)) + +def test_aggregate_matrix_empty_raises(): + X = np.zeros((0, 2), dtype=np.float32) + with pytest.raises(ValueError) as e: + aggregate_matrix(X, mode="mean") + assert "empty window matrix" in str(e.value) + +def test_aggregate_matrix_zero_classes_raises(): + X = np.zeros((3, 0), dtype=np.float32) + with pytest.raises(ValueError) as e: + aggregate_matrix(X, mode="mean") + assert "num_classes > 0" in str(e.value) + +def test_aggregate_matrix_wrong_type_and_ndim(): + with pytest.raises(TypeError): + aggregate_matrix([[1, 2], [3, 4]], mode="mean") # not np.ndarray + with pytest.raises(ValueError): + aggregate_matrix(np.array([1, 2, 3], dtype=np.float32), mode="mean") # 1D + with pytest.raises(ValueError): + aggregate_matrix(np.zeros((2, 2, 2), dtype=np.float32), mode="mean") # 3D + +def test_aggregate_matrix_unsupported_mode(): + X = np.ones((2, 2), dtype=np.float32) + with pytest.raises(ValueError): + aggregate_matrix(X, mode="median") + +def test_aggregate_matrix_nan_and_infs_are_handled(): + X = np.array([[np.nan, 1.0, np.inf, -np.inf], + [2.0, np.nan, 3.0, -5.0 ]], dtype=np.float32) + out_mean = aggregate_matrix(X, mode="mean") + out_max = aggregate_matrix(X, mode="max") + assert out_mean.dtype == np.float32 + assert out_max.dtype == np.float32 + # NaNs should be treated as missing, and inf/-inf should be clamped via nan_to_num. + assert np.isfinite(out_mean).all() + assert np.isfinite(out_max).all() + +# ---------- ensure_numpy_1d & _to_numpy ---------- + +def test_ensure_numpy_1d_duck_typed_torch_like(): + class DuckTensor: + def __init__(self, arr): self._arr = np.asarray(arr, dtype=np.float32) + def detach(self): return self + def cpu(self): return self + def numpy(self): return self._arr + x = DuckTensor([[1.0], [2.0], [3.0]]) + y = ensure_numpy_1d(x) + assert isinstance(y, np.ndarray) + assert y.shape == (3,) + assert y.dtype == np.float32 + +def test_ensure_numpy_1d_object_with_numpy_method_only(): + class OnlyNumpy: + def __init__(self, arr): self._arr = np.asarray(arr, dtype=np.float32) + def numpy(self): return self._arr + x = OnlyNumpy([[1.0, 2.0, 3.0]]) + y = ensure_numpy_1d(x) + assert y.shape == (3,) + assert y.dtype == np.float32 + +def test__to_numpy_handles_2d_shapes_and_casts_to_float32(): + x = np.array([[1.0], [2.0], [3.0]], dtype=np.float64) + y = _to_numpy(x) + assert y.dtype == np.float32 + assert y.shape == (3,) + +@pytest.mark.skipif(mio.torch is None, reason="torch is not available") +def test__to_numpy_with_real_torch_tensor_when_available(): + import torch + x = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) + try: + y = _to_numpy(x) + except RuntimeError as e: + # Some builds of torch raise when calling .numpy() if NumPy is mismatched/not present. + assert "Numpy is not available" in str(e) + else: + assert isinstance(y, np.ndarray) + assert y.dtype == np.float32 + assert y.shape == (3,) + +# ---------- segment_waveform edge cases ---------- + +def test_segment_waveform_empty_and_padding_logic(): + # empty wav returns [] + assert segment_waveform(np.array([], dtype=np.float32), sr=SAMPLE_RATE) == [] + + # shorter than one window, pad_last=True adds one padded segment + short = np.ones(100, dtype=np.float32) + segs = segment_waveform(short, sr=SAMPLE_RATE, window_sec=0.01, hop_sec=0.005, pad_last=True) + assert len(segs) >= 1 + for s in segs: + assert s.ndim == 1 and s.dtype == np.float32 + + # shorter than one window, pad_last=False should not add segment + segs2 = segment_waveform(short, sr=SAMPLE_RATE, window_sec=0.01, hop_sec=0.005, pad_last=False) + # Depending on rounding, there may be 0 or >0 windows; enforce that no tail padding is added: + # if the first while loop didn't fit any full window, with pad_last=False we expect 0 segments. + assert len(segs2) in (0,) + +def test_segment_waveform_overlap_and_types(): + wav = np.arange(1000, dtype=np.int16) # non-float input should be casted + segs = segment_waveform(wav, sr=1000, window_sec=0.1, hop_sec=0.05, pad_last=True) + assert len(segs) >= 1 + assert all(s.dtype == np.float32 for s in segs) + +# ---------- load_audio: WAV roundtrip & padding ---------- + +def _sine(sr: int, seconds: float) -> np.ndarray: + t = np.linspace(0, seconds, int(sr * seconds), endpoint=False, dtype=np.float32) + return np.sin(2 * np.pi * 440.0 * t).astype(np.float32) + +def test_load_audio_roundtrip_and_padding(tmp_path: Path): + wav = _sine(SAMPLE_RATE, 0.25) # shorter than MIN_SAMPLES (16000) + p = tmp_path / "a.wav" + sf.write(p, wav, samplerate=SAMPLE_RATE, subtype="PCM_16") + out = load_audio(str(p), SAMPLE_RATE) + assert out.dtype == np.float32 + assert out.ndim == 1 + assert out.size >= mio.MIN_SAMPLES # padded + +def test_load_audio_resample_path(monkeypatch, tmp_path: Path): + # Write at sr=16000; target is 32000. Force librosa.resample path. + wav16 = _sine(16000, 0.2).astype(np.float32) + p = tmp_path / "b.wav" + sf.write(p, wav16, samplerate=16000, subtype="PCM_16") + + called = {"resample": False} + def fake_resample(y, orig_sr, target_sr): + called["resample"] = True + # return float64 to test cast back to float32 + return np.asarray(np.repeat(y, 2), dtype=np.float64) + + monkeypatch.setattr(mio.librosa, "resample", fake_resample, raising=True) + out = load_audio(str(p), SAMPLE_RATE) + assert called["resample"] is True + assert out.dtype == np.float32 + assert out.ndim == 1 + +# ---------- load_audio: HARD_EXTS (mp3/m4a/...) branch + ffmpeg fallback ---------- + +def test_load_audio_hard_ext_uses_librosa_success(monkeypatch, tmp_path: Path): + # Fake an mp3 file; we'll mock librosa.load to return data, so the bytes don't matter. + p = tmp_path / "x.mp3" + p.write_bytes(b"ID3") + + def fake_librosa_load(path, sr, mono): + assert str(path) == str(p) + return np.ones(100, dtype=np.float32), sr + + monkeypatch.setattr(mio.librosa, "load", fake_librosa_load, raising=True) + out = load_audio(str(p), SAMPLE_RATE) + assert out.dtype == np.float32 + assert out.ndim == 1 + assert out.size >= mio.MIN_SAMPLES + +def test_load_audio_hard_ext_librosa_fail_ffmpeg_fallback(monkeypatch, tmp_path: Path): + p = tmp_path / "y.m4a" + p.write_bytes(b"\x00\x00\x00\x20ftypM4A ") # header-ish; not used + + def fake_librosa_fail(*a, **k): raise RuntimeError("boom") + monkeypatch.setattr(mio.librosa, "load", fake_librosa_fail, raising=True) + monkeypatch.setattr(mio, "has_ffmpeg", lambda: True, raising=True) + # make ffmpeg path return short buffer to test padding + monkeypatch.setattr(mio, "decode_with_ffmpeg_to_float32_mono", + lambda path, target_sr: np.ones(10, dtype=np.float32), raising=True) + out = load_audio(str(p), SAMPLE_RATE) + assert out.size >= mio.MIN_SAMPLES + assert out.dtype == np.float32 + +def test_load_audio_all_fail_then_ffmpeg(monkeypatch, tmp_path: Path): + # For non-HARD ext: raise in soundfile.read, then raise in librosa.load, then use ffmpeg + p = tmp_path / "z.ogg" + p.write_bytes(b"OggS...") + + def fake_sf_read(*a, **k): raise RuntimeError("sf fail") + def fake_librosa_load(*a, **k): raise RuntimeError("librosa fail") + + monkeypatch.setattr(sf, "read", fake_sf_read, raising=True) + monkeypatch.setattr(mio.librosa, "load", fake_librosa_load, raising=True) + monkeypatch.setattr(mio, "has_ffmpeg", lambda: True, raising=True) + monkeypatch.setattr(mio, "decode_with_ffmpeg_to_float32_mono", + lambda path, target_sr: np.ones(33, dtype=np.float32), raising=True) + out = load_audio(str(p), SAMPLE_RATE) + assert out.ndim == 1 and out.dtype == np.float32 + +# ---------- ffmpeg helpers ---------- + +def test_has_ffmpeg_uses_shutil_which(monkeypatch): + # Ensure function depends on shutil.which + monkeypatch.setattr(mio.shutil, "which", lambda _: "/usr/bin/ffmpeg") + assert mio.has_ffmpeg() is True + monkeypatch.setattr(mio.shutil, "which", lambda _: None) + assert mio.has_ffmpeg() is False + +def test_decode_with_ffmpeg_to_float32_mono_monkeypatched(monkeypatch, tmp_path: Path): + # We won't call real ffmpeg; we mock subprocess.run to return a small f32 buffer. + p = tmp_path / "q.ogg" + p.write_bytes(b"OggS...") + + class DummyProc: + def __init__(self, data): self.stdout = data; self.stderr = b"" + # 4 float32 numbers -> 16 bytes; MIN_SAMPLES will trigger padding. + raw = (np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)).tobytes() + + monkeypatch.setattr( + mio.subprocess, "run", + lambda cmd, stdout, stderr, check: DummyProc(raw), + raising=True + ) + + out = mio.decode_with_ffmpeg_to_float32_mono(str(p), mio.SAMPLE_RATE) + assert out.dtype == np.float32 + assert out.ndim == 1 + assert out.size >= mio.MIN_SAMPLES + +# ---------- ensure_checkpoint (patch urllib at global module level) ---------- + +def test_ensure_checkpoint_downloads_when_missing(monkeypatch, tmp_path): + target = tmp_path / "models" / "panns_data" / "m.pth" + called = {"ok": False} + def fake_urlretrieve(url, dst): + called["ok"] = True + Path(dst).write_bytes(b"ok") + monkeypatch.setattr("urllib.request.urlretrieve", fake_urlretrieve, raising=True) + path = ensure_checkpoint(str(target), "http://example.com/m.pth") + assert called["ok"] is True + assert Path(path).exists() + +def test_ensure_numpy_and_to_numpy_helpers(): + x = np.array([[1.0], [2.0], [3.0]], dtype=np.float32) + assert ensure_numpy_1d(x).shape == (3,) + y = _to_numpy(x.reshape(1, -1)) + assert y.ndim == 1 diff --git a/services/sounds_flink/Dockerfile b/services/sounds_flink/Dockerfile new file mode 100644 index 000000000..9fc4a5a5c --- /dev/null +++ b/services/sounds_flink/Dockerfile @@ -0,0 +1,155 @@ +# syntax=docker/dockerfile:1 +FROM flink:1.19.3-scala_2.12-java11 + +USER root +WORKDIR /opt/app + +# ----------------------------- +# System CA & Python toolchain +# ----------------------------- +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates wget curl python3 python3-venv python3-pip jq gnupg \ + && update-ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# ----------------------------- +# Optional NetFree CAs (empty dir is OK) +# If you have custom CAs, put *.crt in ./certs and they will be added. +# ----------------------------- +COPY certs/ /usr/local/share/ca-certificates/ +RUN update-ca-certificates || true + +# ----------------------------- +# SSL env for pip/requests (works with/without NetFree) +# ----------------------------- +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_CERT=/etc/ssl/certs/ca-certificates.crt \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +# ----------------------------- +# Optional PyPI mirror (for NetFree or internal mirror) +# Pass --build-arg PIP_INDEX_URL=https:///simple +# ----------------------------- +ARG PIP_INDEX_URL="" +RUN set -eux; \ + mkdir -p /etc/pip; \ + { \ + echo "[global]"; \ + echo "cert = /etc/ssl/certs/ca-certificates.crt"; \ + echo "trusted-host = pypi.org"; \ + echo "trusted-host = files.pythonhosted.org"; \ + } > /etc/pip/pip.conf; \ + if [ -n "$PIP_INDEX_URL" ]; then \ + echo "index-url = $PIP_INDEX_URL" >> /etc/pip/pip.conf; \ + fi + +# ----------------------------- +# Python venv (expose system packages as fallback) +# ----------------------------- +RUN python3 -m venv /opt/venv --system-site-packages +ENV PATH="/opt/venv/bin:${PATH}" + +# ----------------------------- +# Copy requirements and install (pip first, apt fallback for some libs) +# ----------------------------- +COPY requirements.txt /opt/app/requirements.txt + +# Optional: allow passing a pre-downloaded PyFlink wheel (for NetFree) +ARG PYFLINK_WHEEL_URL="" +# Make sure shell sees a defined variable even with `set -u` +ENV PYFLINK_WHEEL_URL="${PYFLINK_WHEEL_URL}" + +RUN set -eux; \ + python -m pip install --no-cache-dir --upgrade \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + --cert /etc/ssl/certs/ca-certificates.crt \ + pip setuptools wheel; \ + echo ">>> Installing Python deps via pip (requirements.txt)"; \ + if ! python -m pip install --no-cache-dir \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + --cert /etc/ssl/certs/ca-certificates.crt \ + -r /opt/app/requirements.txt; then \ + echo "WARN: pip install failed or blocked; trying apt fallback for core libs"; \ + apt-get update && apt-get install -y --no-install-recommends \ + python3-yaml python3-protobuf python3-grpcio \ + && rm -rf /var/lib/apt/lists/*; \ + fi; \ + echo '>>> Ensuring requests is installed (pip first, apt fallback)'; \ + if ! python -m pip install --no-cache-dir \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + --cert /etc/ssl/certs/ca-certificates.crt \ + requests==2.32.3; then \ + echo 'WARN: pip blocked for requests; falling back to apt'; \ + apt-get update && apt-get install -y --no-install-recommends python3-requests || true; \ + rm -rf /var/lib/apt/lists/*; \ + fi; \ + echo ">>> Enforcing PyFlink presence (wheel or pip)"; \ + if [ -n "${PYFLINK_WHEEL_URL:-}" ]; then \ + python -m pip install --no-cache-dir \ + --cert /etc/ssl/certs/ca-certificates.crt "${PYFLINK_WHEEL_URL}" \ + || (echo 'FATAL: PyFlink wheel install failed' && exit 1); \ + else \ + python -m pip install --no-cache-dir \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + --cert /etc/ssl/certs/ca-certificates.crt \ + apache-flink==1.19.3 \ + || (echo 'FATAL: apache-flink install failed' && exit 1); \ + fi; \ + echo ">>> Forcing critical runtime libs into venv (pip first, apt fallback)"; \ + if ! python -m pip install --no-cache-dir \ + --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + --cert /etc/ssl/certs/ca-certificates.crt \ + protobuf==4.25.3 googleapis-common-protos==1.63.0 grpcio==1.60.0; then \ + echo "WARN: pip blocked for key libs; using apt fallback"; \ + apt-get update && apt-get install -y --no-install-recommends \ + python3-protobuf python3-grpcio || true; \ + rm -rf /var/lib/apt/lists/*; \ + fi; \ + python - <<'PY' +import sys +print("Python:", sys.version) +# hard fail if imports are not available inside the venv +for mod in ("requests", "urllib3", "google", "google.protobuf", "grpc", "pyflink"): + try: + __import__(mod) + print(mod, "OK") + except Exception as e: + print("FATAL import check:", mod, "->", e) + raise SystemExit(1) +PY + + +# ----------------------------- +# Flink Kafka connector jars (REQUIRED for Kafka in cluster mode) +# Version aligned to Flink 1.19 +# ----------------------------- +RUN mkdir -p /opt/flink/lib && \ + wget -qO /opt/flink/lib/flink-connector-kafka-3.2.0-1.19.jar \ + https://repo1.maven.org/maven2/org/apache/flink/flink-connector-kafka/3.2.0-1.19/flink-connector-kafka-3.2.0-1.19.jar && \ + wget -qO /opt/flink/lib/kafka-clients-3.7.0.jar \ + https://repo1.maven.org/maven2/org/apache/kafka/kafka-clients/3.7.0/kafka-clients-3.7.0.jar + +# ----------------------------- +# Copy app code (keep small, configurable by ENV) +# ----------------------------- +COPY config.py /opt/app/config.py +COPY processor.py /opt/app/processor.py +COPY flink_job.py /opt/app/flink_job.py + +# ----------------------------- +# Runtime ENV defaults (can be overridden in docker-compose) +# ----------------------------- +ENV PYTHONPATH=/opt/app \ + PYFLINK_CLIENT_EXECUTABLE=python \ + PYFLINK_PYTHON=python \ + KAFKA_BROKERS=kafka:9092 \ + SOURCE_TOPIC=sound.new.sounds \ + SINK_TOPIC=classified.sounds \ + GROUP_ID=flink-classifier-sounds \ + KAFKA_START=earliest + +# ----------------------------- +# Default CMD is noop; compose sets jobmanager/taskmanager/submitter commands +# ----------------------------- +CMD ["bash", "-lc", "echo 'Flink image ready (NetFree/Non-NetFree compatible)'; tail -f /dev/null"] diff --git a/services/sounds_flink/config.py b/services/sounds_flink/config.py new file mode 100644 index 000000000..989f8677e --- /dev/null +++ b/services/sounds_flink/config.py @@ -0,0 +1,23 @@ +import os + +# Kafka / topics +KAFKA_BROKERS = os.getenv("KAFKA_BROKERS", "kafka:9092") +SOURCE_TOPIC = os.getenv("SOURCE_TOPIC", "sound_new_sounds_connections") +SINK_TOPIC = os.getenv("SINK_TOPIC", "") # empty = print to stdout only +GROUP_ID = os.getenv("GROUP_ID", "flink-classifier-sounds") +KAFKA_START = os.getenv("KAFKA_START", "earliest") # earliest|latest + +# HTTP classifier +CLASSIFIER_HTTP_URL = os.getenv("CLASSIFIER_HTTP_URL", "http://sounds_classifier:8088/classify") +REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "5.0")) +RETRIES_TOTAL = int(os.getenv("RETRIES_TOTAL", "3")) +BACKOFF_FACTOR = float(os.getenv("BACKOFF_FACTOR", "0.5")) + +# Flink +DEFAULT_PARALLELISM = int(os.getenv("DEFAULT_PARALLELISM", "1")) +CHECKPOINT_MS = int(os.getenv("CHECKPOINT_MS", "10000")) # 10s +DELIVERY_GUARANTEE = os.getenv("DELIVERY_GUARANTEE", "AT_LEAST_ONCE") # AT_LEAST_ONCE|NONE +TRANSACTION_TIMEOUT_MS = os.getenv("TRANSACTION_TIMEOUT_MS", "600000") # 10 min + +# Optional default bucket to use when input only carries an object key +DEFAULT_BUCKET = os.getenv("DEFAULT_BUCKET", "sound") \ No newline at end of file diff --git a/services/sounds_flink/flink_job.py b/services/sounds_flink/flink_job.py new file mode 100644 index 000000000..de2429cc9 --- /dev/null +++ b/services/sounds_flink/flink_job.py @@ -0,0 +1,83 @@ +""" +Flink Python DataStream job: +- Kafka source (JSON notifications) +- Per-record HTTP classification via pooled Session (processor.process_json_line) +- Optional Kafka sink; if SINK_TOPIC is empty -> print to stdout +""" + +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.datastream.connectors.kafka import ( + KafkaSource, KafkaSink, KafkaRecordSerializationSchema, DeliveryGuarantee +) +from pyflink.common.serialization import SimpleStringSchema +from pyflink.common.watermark_strategy import WatermarkStrategy +from pyflink.datastream.checkpointing_mode import CheckpointingMode +from pyflink.common import Types +from processor import process_json_line + +from config import ( + KAFKA_BROKERS, + SOURCE_TOPIC, + SINK_TOPIC, + GROUP_ID, + KAFKA_START, + DEFAULT_PARALLELISM, + CHECKPOINT_MS, + DELIVERY_GUARANTEE, + TRANSACTION_TIMEOUT_MS, +) + +def main(): + env = StreamExecutionEnvironment.get_execution_environment() + env.set_parallelism(DEFAULT_PARALLELISM) + env.enable_checkpointing(CHECKPOINT_MS, CheckpointingMode.EXACTLY_ONCE) + + source = ( + KafkaSource.builder() + .set_bootstrap_servers(KAFKA_BROKERS) + .set_topics(SOURCE_TOPIC) + .set_group_id(GROUP_ID) + .set_property("auto.offset.reset", KAFKA_START) + .set_value_only_deserializer(SimpleStringSchema()) + .build() + ) + + stream = env.from_source( + source, + WatermarkStrategy.no_watermarks(), + f"source-{SOURCE_TOPIC}", + ) + + mapped = stream.map(process_json_line, output_type=Types.STRING()) + filtered = mapped.filter(lambda s: bool(s and s.strip())) + + # Always print for quick debugging + filtered.name("stdout-preview").print() + + # Optional Kafka sink + if SINK_TOPIC: + guarantee = ( + DeliveryGuarantee.AT_LEAST_ONCE + if DELIVERY_GUARANTEE.upper() == "AT_LEAST_ONCE" + else DeliveryGuarantee.NONE + ) + sink = ( + KafkaSink.builder() + .set_bootstrap_servers(KAFKA_BROKERS) + .set_record_serializer( + KafkaRecordSerializationSchema.builder() + .set_topic(SINK_TOPIC) + .set_value_serialization_schema(SimpleStringSchema()) + .build() + ) + .set_delivery_guarantee(guarantee) + .set_property("transaction.timeout.ms", TRANSACTION_TIMEOUT_MS) + .build() + ) + filtered.sink_to(sink).name(f"sink-{SINK_TOPIC}") + + env.execute("flink-http-classifier") + + +if __name__ == "__main__": + main() diff --git a/services/sounds_flink/processor.py b/services/sounds_flink/processor.py new file mode 100644 index 000000000..1b4c77814 --- /dev/null +++ b/services/sounds_flink/processor.py @@ -0,0 +1,136 @@ +import json +import logging +from datetime import datetime +from typing import Tuple, Optional, Dict +from urllib.parse import unquote, unquote_plus + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from config import ( + CLASSIFIER_HTTP_URL, + REQUEST_TIMEOUT, + RETRIES_TOTAL, + BACKOFF_FACTOR, + DEFAULT_BUCKET, +) + +# Reusable HTTP session with retries/backoff +_session = requests.Session() +_retries = Retry( + total=RETRIES_TOTAL, + backoff_factor=BACKOFF_FACTOR, + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=["GET", "POST"], + respect_retry_after_header=True, +) +_session.mount("http://", HTTPAdapter(max_retries=_retries)) +_session.mount("https://", HTTPAdapter(max_retries=_retries)) + + +def _try_json(raw: str) -> Optional[Dict]: + try: + return json.loads(raw) + except Exception: + return None + + +def _extract_bucket_key(event: Dict) -> Tuple[Optional[str], Optional[str]]: + """ + Extract (bucket, key) from multiple possible MinIO/S3 event shapes. + Supports: + - short link format: {"file_name": "...", "key": "/", "linked_time": "..."} + - flat: {"Bucket": "...", "Key": "..."} + - Records[0].s3.bucket.name / Records[0].s3.object.key + """ + bucket: Optional[str] = None + key: Optional[str] = None + + # 1) Short link format (from *_connections topics): key="/" + if isinstance(event.get("key"), str): + k = event["key"].strip() + k = unquote_plus(unquote(k)) + if "/" in k: + bucket, key = k.split("/", 1) + else: + key = k # no bucket provided here + + # 2) Flat shape + if (bucket is None or key is None) and event.get("Bucket") and event.get("Key"): + bucket = bucket or event.get("Bucket") + key = key or event.get("Key") + + # 3) Records[...] S3-style + if bucket is None or key is None: + records = event.get("Records") or [] + if records: + r0 = records[0] + s3 = r0.get("s3", {}) + b = s3.get("bucket", {}) + o = s3.get("object", {}) + bucket = bucket or b.get("name") + key = key or o.get("key") + + # Normalize/URL-decode + if isinstance(key, str) and key: + key = unquote_plus(unquote(key)) + + return bucket, key + + +def _classify(bucket: Optional[str], key: Optional[str]) -> Optional[Dict]: + """ + Call the classifier service with the resolved (bucket, key). + The classifier expects: + { "s3_bucket": "...", "s3_key": "..." } + """ + if not key: + return None + + # Prefer provided bucket, otherwise fallback to DEFAULT_BUCKET if configured + eff_bucket = bucket or (DEFAULT_BUCKET if DEFAULT_BUCKET else None) + if not eff_bucket: + # Without a bucket we cannot call the classifier + return None + + payload = { + "s3_bucket": eff_bucket, + "s3_key": key, + } + + try: + resp = _session.post(CLASSIFIER_HTTP_URL, json=payload, timeout=REQUEST_TIMEOUT) + if resp.status_code >= 400: + logging.warning("Classifier returned %s for key=%s", resp.status_code, key) + return None + return resp.json() + except Exception as e: + logging.warning("Classifier request failed for key=%s: %s", key, e) + return None + + +def process_json_line(raw: str) -> str: + """ + Map function: input raw JSON string -> output JSON string or "" to skip. + 1) Parse JSON + 2) Extract (bucket, key) + 3) Call classifier (payload: s3_bucket/s3_key) + 4) Return compact JSON result or "" to drop + """ + event = _try_json(raw) + if not event: + return "" + + bucket, key = _extract_bucket_key(event) + result = _classify(bucket, key) + if not result: + return "" + + out = { + "s3_bucket": bucket or DEFAULT_BUCKET or "", + "s3_key": key, + "result": result, + "received_at": datetime.utcnow().isoformat(timespec="seconds") + "Z", + } + return json.dumps(out, separators=(",", ":")) diff --git a/services/sounds_flink/requirements.txt b/services/sounds_flink/requirements.txt new file mode 100644 index 000000000..f5de43b68 --- /dev/null +++ b/services/sounds_flink/requirements.txt @@ -0,0 +1,6 @@ +apache-flink==1.19.3 +requests==2.32.3 +urllib3==2.2.3 +protobuf==4.25.3 +googleapis-common-protos==1.63.0 +grpcio==1.60.0 diff --git a/services/weed_detection/Dockerfile b/services/weed_detection/Dockerfile new file mode 100644 index 000000000..5821909cc --- /dev/null +++ b/services/weed_detection/Dockerfile @@ -0,0 +1,42 @@ +# ---- Base: lightweight Python + Torch CPU ---- +FROM python:3.10-slim + +ARG DEBIAN_FRONTEND=noninteractive +ENV PIP_NO_CACHE_DIR=1 PYTHONUNBUFFERED=1 + +# System and build tools +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl build-essential gcc libpq5 git && \ + rm -rf /var/lib/apt/lists/* + +# NetFree CA +COPY certs/netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +RUN update-ca-certificates + +# Make pip/requests use the system CA bundle +ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \ + REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt \ + PIP_CERT=/etc/ssl/certs/ca-certificates.crt + +# NEW: install certifi and replace it with the system bundle +RUN python -m pip install --upgrade pip certifi && python - <<'PY' +import certifi, shutil, os +src = "/etc/ssl/certs/ca-certificates.crt" +dst = certifi.where() +os.makedirs(os.path.dirname(dst), exist_ok=True) +shutil.copyfile(src, dst) +print("certifi bundle replaced:", dst) +PY + +# Install PyTorch CPU (2.9.0) + torchvision (0.24.0) from the PyTorch index +RUN pip install "torch==2.9.0+cpu" "torchvision==0.24.0+cpu" --index-url https://download.pytorch.org/whl/cpu + +# Other dependencies — excluding torch/torchvision/torchaudio +WORKDIR /app +COPY requirements.txt /tmp/requirements.txt +RUN sed -i '/^torch/d;/^torchvision/d;/^torchaudio/d' /tmp/requirements.txt && \ + pip install --no-cache-dir -r /tmp/requirements.txt + +# Code and execution +COPY . . +CMD ["python", "-m", "scripts.run_detection", "--storage", "minio"] diff --git a/services/weed_detection/README.md b/services/weed_detection/README.md new file mode 100644 index 000000000..2d1de8659 --- /dev/null +++ b/services/weed_detection/README.md @@ -0,0 +1,140 @@ +# 🌱 Weed Detection Pipeline — MinIO → PostgreSQL + +## Overview +This project implements a **weed detection and analysis pipeline** that automatically: +1. Retrieves images from **MinIO** (S3-compatible object storage). +2. Runs **weed detection** using a combination of heuristic image analysis and machine learning (MobileNetV3). +3. Writes detection results and statistics into a **Relational Database (PostgreSQL)**. + +It’s designed for **automated weekly or on-demand runs** using Docker or local Python execution. + +--- + +## 🧠 Architecture + +**Flow:** +`MinIO (images) → Local cache → Weed Detection (Heuristic + ML) → PostgreSQL` + +### Main Steps +1. **Data Input** + - Images are loaded from MinIO using the credentials defined in `.env`. + - Supports both local and remote (S3-compatible) backends. + +2. **Processing** + - **Heuristic Detection** using Excess Green (ExG) and Otsu thresholding. + - **ML Refinement** with a small MobileNetV3 model (`ml_model.py`) to improve detection accuracy. + - Output: weed masks, bounding boxes, and anomaly scores. + +3. **Database Output** + - Results are inserted into PostgreSQL tables (`tile_stats`, `anomalies`, `qa_runs`) via SQLAlchemy. + - Geometry data is stored as WKT (PostGIS-compatible). + +--- + +## 🧩 Project Structure + +``` +project_root/ +├── scripts/ +│ └── run_detection.py # Main entry point for batch processing +├── src/ +│ ├── detectors/ # Weed and disease detection logic +│ ├── pipeline/ # Database and utility modules +│ └── models/ # ML models (e.g., MobileNetV3) +├── data/ # Local image cache +├── Dockerfile +├── docker-compose.yml +├── .env +└── run_weekly.ps1 # Windows PowerShell automation script +``` + +--- + +## ⚙️ Technologies Used + +| Component | Description | +|------------------|-------------| +| **Python 3.10+** | Core language | +| **PyTorch** | ML inference (MobileNetV3) | +| **OpenCV** | Image preprocessing and segmentation | +| **SQLAlchemy** | ORM and database connection | +| **MinIO SDK** | S3-compatible data access | +| **Docker Compose** | Service orchestration | +| **PostgreSQL + PostGIS** | Result storage and spatial data handling | + +--- + +## 🧾 Environment Configuration + +The `.env` file defines all key environment variables: +```ini +DB_URL=postgresql+psycopg2://user:password@db:5432/missions_db +STORAGE_BACKEND=minio +MINIO_ENDPOINT=minio-hot:9000 +MINIO_ACCESS_KEY=minioadmin +MINIO_SECRET_KEY=minioadmin123 +MINIO_BUCKET=ground +MINIO_SECURE=false +BATCH_SIZE=64 +MAX_WORKERS=4 +MIN_BBOX_AREA=150 +MIN_COMPONENT_AREA=200 +``` + +--- + +## 🚀 Running the Project + +### Option 1: Run via Docker +```bash +docker compose up -d --build +docker compose logs -f weed-detector +``` + +### Option 2: Run Locally (Python) +```bash +python -m venv .venv +source .venv/bin/activate # or .venv\Scripts\activate on Windows +pip install -r requirements.txt +python -m scripts.run_detection --storage minio +``` + +--- + +## 🕒 Scheduled Execution (Windows) + +The `run_weekly.ps1` script automates weekly runs using **Task Scheduler**. +It: +- Ensures Docker is running +- Executes `docker compose run` for the detector +- Logs output to `C:\logs\weed-weekly.log` +- Prevents concurrent runs using a lock mechanism + +--- + +## 🗄️ Database Schema (Simplified) + +| Table | Purpose | +|----------------|----------| +| **anomalies** | Stores detected weed events and metadata | +| **tile_stats** | Aggregated scores per image/tile | +| **qa_runs** | Logs of detection runs for debugging and QA | + +> Requires PostgreSQL with PostGIS enabled for geometry operations. + +--- + +## 🧰 Troubleshooting + +| Issue | Cause | Fix | +|-------|--------|-----| +| **Torch model not found** | blocked download or cache missing | Manually place model in `~/.cache/torch/hub/checkpoints` | +| **UniqueViolation on tile_stats** | duplicate tile_id/mission_id | Add `ON CONFLICT DO NOTHING` or adjust mission IDs | +| **Slow performance** | batch size too high | Lower `BATCH_SIZE` and `MAX_WORKERS` | +| **SSL errors** | missing CA certificate | Verify `CA_CERT_PATH` or disable `MINIO_SECURE` if local | + +--- + +## 🏁 Summary + +This project provides an **end-to-end pipeline** for automated weed detection — from image retrieval to database integration — built for scalable, repeatable, and containerized deployment. diff --git a/services/weed_detection/config/config.yaml b/services/weed_detection/config/config.yaml new file mode 100644 index 000000000..e69de29bb diff --git a/services/weed_detection/docker-compose.yml b/services/weed_detection/docker-compose.yml new file mode 100644 index 000000000..4e8ced77f --- /dev/null +++ b/services/weed_detection/docker-compose.yml @@ -0,0 +1,13 @@ +services: + weed-detector: + build: . + container_name: weed-detector + restart: unless-stopped + env_file: + - .env + volumes: + - ./data_minio_cache:/app/data_minio_cache + command: ["python", "-m", "scripts.run_detection", "--storage", "minio"] + # (Optional) if you have a proxy/certificates: + # environment: + # - REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt diff --git a/services/weed_detection/migrations/001_init.sql b/services/weed_detection/migrations/001_init.sql new file mode 100644 index 000000000..e69de29bb diff --git a/services/weed_detection/models/__init__.py b/services/weed_detection/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/weed_detection/models/dataset.py b/services/weed_detection/models/dataset.py new file mode 100644 index 000000000..93b75ef3f --- /dev/null +++ b/services/weed_detection/models/dataset.py @@ -0,0 +1,159 @@ +import os +from typing import List, Optional, Tuple, Dict +import pandas as pd +import numpy as np +from PIL import Image +import torch +from torch.utils.data import Dataset + +class WeedsFromTables(Dataset): + """ + Supports two modes: + 1) Bounding boxes in tables (image_path + xmin, ymin, xmax, ymax [+ label]) -> builds a mask from boxes. + 2) Without boxes (only Filename/Label) -> loads a corresponding PNG mask file from the masks folder. + + Assumptions: + - If box_cols=None: a masks directory is required with one mask file per image (same basename, PNG). + - If class_col exists and is not binary, the mask values should contain matching class indices. + """ + + def __init__( + self, + root_dir: str, + table_files: List[str], + labels_file: Optional[str] = None, + image_col: str = "image_path", + box_cols: Optional[Tuple[str, str, str, str]] = ("xmin", "ymin", "xmax", "ymax"), + class_col: Optional[str] = "label", + image_transform=None, + mask_transform=None, + masks_dir: str = "masks", # relative to root_dir + ): + self.root_dir = root_dir + self.image_col = image_col + self.box_cols = box_cols + self.class_col = class_col + self.image_transform = image_transform + self.mask_transform = mask_transform + self.masks_dir = os.path.join(root_dir, masks_dir) + + # Load and merge tables + dfs = [] + for path in table_files: + ext = os.path.splitext(path)[1].lower() + if ext in (".xlsx", ".xls"): + df = pd.read_excel(path) + elif ext == ".csv": + df = pd.read_csv(path) + else: + raise ValueError(f"Unsupported table extension: {ext}") + dfs.append(df) + if not dfs: + raise ValueError("No tables loaded.") + df = pd.concat(dfs, ignore_index=True) + df.columns = [str(c).strip() for c in df.columns] + + # Flexibility for common column names: + # If the expected image_col does not exist, try 'Filename'/'filename' + if self.image_col not in df.columns: + for cand in ("Filename", "filename", "file_name", "image", "img"): + if cand in df.columns: + self.image_col = cand + break + if self.image_col not in df.columns: + raise KeyError(f"Missing image column. Expected '{self.image_col}' or a common alternative (e.g., 'Filename').") + + # Drop rows without a valid path/name + df = df.dropna(subset=[self.image_col]).copy() + + # If there are no boxes → use mask files mode + self.use_file_masks = self.box_cols is None + + # Optional label mapping + self.label2id: Optional[Dict[str, int]] = None + if self.class_col and self.class_col in df.columns: + if labels_file: + lex = os.path.splitext(labels_file)[1].lower() + ldf = pd.read_excel(labels_file) if lex in (".xlsx", ".xls") else pd.read_csv(labels_file) + cols = {c.lower(): c for c in ldf.columns} + id_col = cols.get("id") or cols.get("class_id") or list(ldf.columns)[0] + name_col = cols.get("name") or cols.get("label") or cols.get("class") or list(ldf.columns)[1] + self.label2id = {str(r[name_col]): int(r[id_col]) for _, r in ldf.iterrows()} + else: + uniq = sorted(set(map(str, df[self.class_col].unique()))) + self.label2id = {name: i + 1 for i, name in enumerate(uniq)} # 0 = background + + # Normalize image path: if the value is only a filename, search for it under root_dir/images/ + def resolve_path(p: str) -> str: + p = str(p) + if os.path.isabs(p): + return p + cand = os.path.join(self.root_dir, p) + if os.path.exists(cand): + return cand + return os.path.join(self.root_dir, "images", os.path.basename(p)) + + df[self.image_col] = df[self.image_col].map(resolve_path) + + # Save image list and group by image (if boxes exist) + self.images = list(df[self.image_col].unique()) + self.by_image = None + if not self.use_file_masks: + # Ensure bounding box columns exist + for c in (self.box_cols or ()): + if c not in df.columns: + raise KeyError(f"Missing bbox column '{c}' in tables.") + self.by_image = {img: subdf for img, subdf in df.groupby(self.image_col, sort=False)} + + def __len__(self): + return len(self.images) + + @staticmethod + def _clip_box(x1, y1, x2, y2, W, H): + x1 = int(np.clip(x1, 0, W)) + y1 = int(np.clip(y1, 0, H)) + x2 = int(np.clip(x2, 0, W)) + y2 = int(np.clip(y2, 0, H)) + if x2 < x1: x1, x2 = x2, x1 + if y2 < y1: y1, y2 = y2, y1 + return x1, y1, x2, y2 + + def __getitem__(self, idx): + img_path = self.images[idx] + image = Image.open(img_path).convert("RGB") + W, H = image.size + + if self.use_file_masks: + # Load mask from a PNG file with the same name as the image + mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png" + mask_path = os.path.join(self.masks_dir, mask_name) + if not os.path.exists(mask_path): + raise FileNotFoundError(f"Mask file not found for image: {mask_path}") + mask = Image.open(mask_path).convert("L") + else: + # Build mask from bounding boxes + mask_np = np.zeros((H, W), dtype=np.uint8) + rows = self.by_image[self.images[idx]] + for _, row in rows.iterrows(): + x1, y1, x2, y2 = (row[self.box_cols[0]], row[self.box_cols[1]], + row[self.box_cols[2]], row[self.box_cols[3]]) + x1, y1, x2, y2 = self._clip_box(x1, y1, x2, y2, W, H) + if self.label2id is not None and self.class_col and self.class_col in row: + cls_id = self.label2id.get(str(row[self.class_col]), 1) + else: + cls_id = 1 + mask_np[y1:y2, x1:x2] = cls_id + mask = Image.fromarray(mask_np, mode="L") + + # Transformations + if self.image_transform: + image = self.image_transform(image) + else: + image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0 + + if self.mask_transform: + mask = self.mask_transform(mask) + else: + mask = torch.from_numpy(np.array(mask, dtype=np.uint8)).long() + + return image, mask diff --git a/services/weed_detection/models/evaluate.py b/services/weed_detection/models/evaluate.py new file mode 100644 index 000000000..21016e67d --- /dev/null +++ b/services/weed_detection/models/evaluate.py @@ -0,0 +1,323 @@ +# models/evaluate.py +# --------------------------------------------------------------------- +# Lightweight & robust evaluation for a binary UNet segmentation model. +# - Safe on Windows (no anonymous lambdas; workers=0 by default) +# - Normalizes mask to {0,1} so BCEWithLogitsLoss behaves correctly +# - Matches logits size to mask to avoid shape errors +# - Clamps logits to [-20, 20] to prevent numerical blow-ups in BCE +# - Tiny default IMG_SIZE (16x16) for very fast CPU sanity checks +# - Prints ETA, IoU/Dice, Dice-Loss, one-batch profile, and optional +# best-threshold sweep on the cached validation subset +# --------------------------------------------------------------------- + +import os +import time +import random +import argparse +from typing import Tuple, List, Dict, Any + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Subset +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +from models.unet_model import UNet +from models.dataset import WeedsFromTables as DeepWeedsDataset + +# Keep machine responsive (mirror your train.py) +torch.set_num_threads(1) + +# ----------------------- Defaults ------------------------------------ + +ROOT = "data" +LABELS_DIR = os.path.join(ROOT, "labels") +MASKS_DIR = "masks" +IMAGE_COL = "Filename" +IMG_SIZE_DEFAULT: Tuple[int, int] = (16, 16) # tiny & fast (matches your train.py) +WEIGHTS_DEFAULT = "models/unet_weedseg_best.pth" +VAL_PREFIX_DEFAULT = "val_subset" + +FAST_DEBUG_DEFAULT = True +SUBSET_SIZE_VAL_DEFAULT = 200 +MAX_STEPS_EVAL_DEFAULT = 100 +PRINT_EVERY_DEFAULT = 10 + +BATCH_SIZE_DEFAULT = 1 # safe/light on CPU +WORKERS_DEFAULT = 0 # Windows-safe (no pickling issues) + +CLAMP_LOGITS: Tuple[float, float] = (-20.0, 20.0) # stabilize BCE + +# ----------------------- Picklable transforms ------------------------- + +class MaskTo01(object): + """ + Convert [1,H,W] uint8 {0,255} mask from PILToTensor to [H,W] long {0,1}. + Kept as a top-level class so it's picklable on Windows. + """ + def __call__(self, t: torch.Tensor) -> torch.Tensor: + t = t.squeeze(0) # [H,W] uint8 + return (t > 0).to(torch.long) # [H,W] long {0,1} + +def build_transforms(img_size: Tuple[int, int]): + image_tf = transforms.Compose([ + transforms.Resize(img_size, interpolation=InterpolationMode.BILINEAR), + transforms.ToTensor(), # [C,H,W] float in [0,1] + ]) + mask_tf = transforms.Compose([ + transforms.Resize(img_size, interpolation=InterpolationMode.NEAREST), + transforms.PILToTensor(), # [1,H,W] uint8 (0 or 255) + MaskTo01(), # [H,W] long {0,1} + ]) + return image_tf, mask_tf + +# ----------------------- Dataset utils -------------------------------- + +def collect_tables(labels_dir: str, prefix: str) -> List[str]: + files = [] + for f in os.listdir(labels_dir): + name = f.lower() + if name.startswith(prefix.lower()) and (name.endswith(".csv") or name.endswith(".xlsx")): + files.append(os.path.join(labels_dir, f)) + if not files: + raise RuntimeError(f"No files with prefix '{prefix}' found in {labels_dir}") + return sorted(files) + +def make_subset(ds, k: int, use_subset: bool, seed: int = 1337): + if not use_subset: + return ds + k = min(k, len(ds)) + rng = random.Random(seed) + idx = rng.sample(range(len(ds)), k) + print(f"FAST_DEBUG: evaluating on subset of {k}/{len(ds)} examples") + return Subset(ds, idx) + +# ----------------------- Core helpers --------------------------------- + +def _match_logits_to_mask(logits: torch.Tensor, mask_hw: Tuple[int, int]) -> torch.Tensor: + """Ensure the logits spatial size matches the target mask.""" + if tuple(logits.shape[-2:]) != tuple(mask_hw): + logits = F.interpolate(logits, size=mask_hw, mode="bilinear", align_corners=False) + return logits + +def iou_dice(pred01: torch.Tensor, tgt01: torch.Tensor): + """pred01,tgt01: [H,W] {0,1} uint8/long.""" + inter = (pred01 & tgt01).sum().item() + union = (pred01 | tgt01).sum().item() + iou = inter / (union + 1e-8) + dice = (2 * inter) / (pred01.sum().item() + tgt01.sum().item() + 1e-8) + return float(iou), float(dice) + +def dice_loss_from_probs(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> float: + """ + prob,target: [B,1,H,W] float in [0,1] + Returns mean Dice loss over the batch. + """ + inter = (prob * target).sum(dim=(1, 2, 3)) + denom = prob.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) + dl = 1.0 - (2.0 * inter + eps) / (denom + eps) + return float(dl.mean().item()) + +def find_best_threshold(probs_list: List, masks_list: List) -> Tuple[float, float]: + """Grid-search a good probability threshold on cached samples.""" + import numpy as np + ths = np.linspace(0.1, 0.9, 17) + best_t, best_dice = 0.5, -1.0 + for t in ths: + dices = [] + for p, m in zip(probs_list, masks_list): + pred01 = (p > t).astype("uint8") + inter = (pred01 & m).sum() + dices.append((2 * inter) / (pred01.sum() + m.sum() + 1e-8)) + md = float(np.mean(dices)) if dices else -1.0 + if md > best_dice: + best_dice, best_t = md, float(t) + return best_t, best_dice + +# ----------------------- Profiling ------------------------------------ + +@torch.no_grad() +def profile_one_batch(model, loader, device) -> Dict[str, Any]: + """Quick timing of one batch to sense where the time goes.""" + t0 = time.time() + images, masks = next(iter(loader)) + t1 = time.time() + images = images.to(device) + masks = masks.unsqueeze(1).float().to(device) + t2 = time.time() + logits = model(images) + logits = _match_logits_to_mask(logits, masks.shape[-2:]) + t3 = time.time() + return { + "load_s": t1 - t0, + "to_device_s": t2 - t1, + "forward_s": t3 - t2, + "total_s": t3 - t0, + "img_shape": tuple(images.shape), + "logits_shape": tuple(logits.shape), + "mask_unique": torch.unique(masks[0,0].cpu()) + } + +# ----------------------- Evaluation ----------------------------------- + +@torch.inference_mode() +def evaluate(model: torch.nn.Module, + loader: DataLoader, + device: torch.device, + print_every: int = 10, + max_steps: int | None = None, + do_threshold_sweep: bool = True) -> None: + criterion = torch.nn.BCEWithLogitsLoss() + + total_bce = 0.0 + total_dice_loss = 0.0 + iou_sum, dice_sum = 0.0, 0.0 + n_valid = 0 + + t0 = time.time() + total_steps = len(loader) if max_steps is None else min(len(loader), max_steps) + + # cache a small set for threshold sweep + probs_cache, masks_cache = [], [] + + for step, (imgs, masks) in enumerate(loader, 1): + imgs = imgs.to(device) # [B,3,H,W] + masks = masks.unsqueeze(1).float().to(device) # [B,1,H,W] float {0,1} + + logits = model(imgs) # [B,1,h,w] + logits = _match_logits_to_mask(logits, masks.shape[-2:]) + + if step == 1: + print("shapes:", tuple(imgs.shape), tuple(logits.shape), tuple(masks.shape)) + print("mask unique (should be {0.,1.}):", torch.unique(masks[0,0].cpu())) + # Logits diagnostics + print("logits stats -> min/max/mean/std:", + logits.min().item(), logits.max().item(), + logits.mean().item(), logits.std().item()) + + # Stabilize BCE against extreme logits + logits = torch.clamp(logits, CLAMP_LOGITS[0], CLAMP_LOGITS[1]) + + bce = criterion(logits, masks) + if not torch.isfinite(bce): + print(f" step {step}: non-finite loss -> skipped") + continue + + total_bce += float(bce.item()) + n_valid += 1 + + prob = torch.sigmoid(logits) + total_dice_loss += dice_loss_from_probs(prob, masks) + + # threshold 0.5 metrics + pred01 = (prob > 0.5).to(torch.uint8)[0, 0].cpu() + tgt01 = masks[0, 0].to(torch.uint8).cpu() + iou, dice = iou_dice(pred01, tgt01) + iou_sum += iou; dice_sum += dice + + # cache for sweep + if do_threshold_sweep and len(probs_cache) < (total_steps if max_steps else 200): + probs_cache.append(prob[0,0].cpu().numpy()) + masks_cache.append(tgt01.cpu().numpy()) + + if step == 1 or step % print_every == 0: + elapsed = time.time() - t0 + eta = elapsed / step * (total_steps - step) + print(f" step {step}/{total_steps} | BCE={total_bce/max(1,n_valid):.4f} " + f"| DiceLoss={total_dice_loss/max(1,n_valid):.4f} " + f"| IoU@0.5={iou_sum/max(1,n_valid):.4f} | Dice@0.5={dice_sum/max(1,n_valid):.4f} " + f"| ETA={int(eta//60)}m {int(eta%60)}s") + + if (max_steps is not None) and (step >= max_steps): + break + + if n_valid == 0: + print("No valid samples evaluated.") + return + + print("\n===== VALIDATION SUMMARY =====") + print(f"BCE={total_bce/n_valid:.4f} | DiceLoss={total_dice_loss/n_valid:.4f} " + f"| IoU@0.5={iou_sum/n_valid:.4f} | Dice@0.5={dice_sum/n_valid:.4f} | N={n_valid}") + + # Optional: threshold sweep report + if do_threshold_sweep and probs_cache: + try: + best_t, best_dice = find_best_threshold(probs_cache, masks_cache) + print(f"Best threshold on VAL (grid 0.1..0.9): t={best_t:.2f} | Dice={best_dice:.4f}") + except Exception as e: + print(f"(threshold sweep skipped: {e})") + +# ----------------------- CLI/Main ------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="Lightweight evaluation for UNet weed segmentation.") + p.add_argument("--root", type=str, default=ROOT) + p.add_argument("--labels_dir", type=str, default=LABELS_DIR) + p.add_argument("--val_prefix", type=str, default=VAL_PREFIX_DEFAULT) + p.add_argument("--masks_dir", type=str, default=MASKS_DIR) + p.add_argument("--image_col", type=str, default=IMAGE_COL) + p.add_argument("--img_size", type=int, nargs=2, default=list(IMG_SIZE_DEFAULT)) # H W + p.add_argument("--weights", type=str, default=WEIGHTS_DEFAULT) + + p.add_argument("--fast_debug", action="store_true", default=FAST_DEBUG_DEFAULT) + p.add_argument("--subset", type=int, default=SUBSET_SIZE_VAL_DEFAULT) + p.add_argument("--max_steps", type=int, default=MAX_STEPS_EVAL_DEFAULT) + + p.add_argument("--batch_size", type=int, default=BATCH_SIZE_DEFAULT) + p.add_argument("--workers", type=int, default=WORKERS_DEFAULT) # keep 0 on Windows + p.add_argument("--print_every", type=int, default=PRINT_EVERY_DEFAULT) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--no_thresh_sweep", action="store_true", help="Disable best-threshold sweep") + return p.parse_args() + +def main(): + args = parse_args() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Device:", device) + + image_tf, mask_tf = build_transforms(tuple(args.img_size)) + val_tables = collect_tables(args.labels_dir, args.val_prefix) + + val_full = DeepWeedsDataset( + root_dir=args.root, + table_files=val_tables, + labels_file=None, + image_col=args.image_col, + box_cols=None, # using mask files + class_col=None, # binary from mask + image_transform=image_tf, + mask_transform=mask_tf, + masks_dir=args.masks_dir, + ) + + val_dataset = make_subset(val_full, args.subset, args.fast_debug, seed=args.seed) + + val_loader = DataLoader( + val_dataset, + batch_size=max(1, args.batch_size), + shuffle=False, + num_workers=max(0, args.workers), # 0 by default → Windows-safe + pin_memory=False, # CPU path + persistent_workers=False, + ) + + model = UNet(in_channels=3, out_channels=1).to(device) + state = torch.load(args.weights, map_location=device) + model.load_state_dict(state, strict=True) + model.eval() + + # one-batch profile (helps detect bottlenecks quickly) + prof = profile_one_batch(model, val_loader, device) + print(f"PROFILE one batch -> load={prof['load_s']:.3f}s | to_device={prof['to_device_s']:.3f}s | " + f"forward={prof['forward_s']:.3f}s | total={prof['total_s']:.3f}s | " + f"img={prof['img_shape']} | logits={prof['logits_shape']} | mask_unique={prof['mask_unique']}") + + evaluate( + model, val_loader, device, + print_every=args.print_every, + max_steps=(args.max_steps if args.fast_debug else None), + do_threshold_sweep=(not args.no_thresh_sweep) + ) + +if __name__ == "__main__": + main() diff --git a/services/weed_detection/models/ml_model.py b/services/weed_detection/models/ml_model.py new file mode 100644 index 000000000..5472e192e --- /dev/null +++ b/services/weed_detection/models/ml_model.py @@ -0,0 +1,75 @@ +""" +Optional ML model for weed detection (patch classification / region scoring). +If no weights are available, fallback to heuristic detections is used. +""" + +import torch +import torch.nn as nn +import torchvision.transforms as T +from torchvision.models import mobilenet_v3_small +import numpy as np +import cv2 +from typing import Dict, List, Tuple + +class WeedNet(nn.Module): + def __init__(self, num_classes=2): + super().__init__() + self.backbone = mobilenet_v3_small(weights="DEFAULT") + in_feats = self.backbone.classifier[3].in_features + self.backbone.classifier[3] = nn.Linear(in_feats, num_classes) + + def forward(self, x): + return self.backbone(x) + +class MLWeedDetector: + def __init__(self, weights_path: str | None = None, device: str = "cpu"): + self.device = device + self.model = WeedNet().to(self.device) + self.model.eval() + self.transform = T.Compose([ + T.ToTensor(), + T.Resize((224,224)), + T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) + ]) + self.active = False + if weights_path: + try: + state = torch.load(weights_path, map_location=self.device) + self.model.load_state_dict(state, strict=False) + self.active = True + except Exception: + self.active = False # fallback + + @torch.inference_mode() + def score_mask(self, bgr: np.ndarray, coarse_mask: np.ndarray) -> np.ndarray: + """ + Optionally refine heuristic mask by classifying sampled patches. + Returns refined binary mask (uint8). + """ + bgr = np.ascontiguousarray(bgr) + coarse_mask= np.ascontiguousarray(coarse_mask) + if not self.active: + return coarse_mask + mask = coarse_mask.copy() + ys, xs = np.where(coarse_mask > 0) + if len(ys) == 0: + return coarse_mask + + # sample up to N points for refinement + N = min(200, len(ys)) + idx = np.random.choice(len(ys), N, replace=False) + H, W = bgr.shape[:2] + for i in idx: + y, x = ys[i], xs[i] + y0, x0 = max(0, y-16), max(0, x-16) + y1, x1 = min(H, y+16), min(W, x+16) + patch = cv2.cvtColor(bgr[y0:y1, x0:x1], cv2.COLOR_BGR2RGB) + patch = np.ascontiguousarray(patch) + if patch.size == 0: + continue + inp = self.transform(patch).unsqueeze(0).to(self.device) + logits = self.model(inp) + prob = torch.softmax(logits, dim=1)[0,1].item() # class 1 = weed + if prob < 0.5: + mask[y, x] = 0 + return mask diff --git a/services/weed_detection/models/train.py b/services/weed_detection/models/train.py new file mode 100644 index 000000000..e9cd358a5 --- /dev/null +++ b/services/weed_detection/models/train.py @@ -0,0 +1,333 @@ +# models/train.py +# --------------------------------------------------------------------- +# Lightweight UNet training loop (CPU/Windows friendly). +# Key fixes: +# - Masks are normalized to {0,1} (not {0,255}) via a picklable transform. +# - Size-mismatch safety: logits are resized to target HxW before loss. +# - Optional BCE+Dice combined loss (helps class imbalance). +# - Gradient clipping to prevent exploding updates. +# Defaults mirror your "light" config: tiny IMG_SIZE, batch_size=1, workers=0. +# --------------------------------------------------------------------- + +import os +import time +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Subset +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +from models.unet_model import UNet +from models.dataset import WeedsFromTables as DeepWeedsDataset + +print(torch.cuda.is_available()) + +# Keep machine responsive on CPU +torch.set_num_threads(1) + +# ===================== Run Config ===================== +ROOT = "data" +LABELS_DIR = os.path.join(ROOT, "labels") + +# Fast debug to verify pipeline end-to-end +FAST_DEBUG = False # set to False for full training +SUBSET_SIZE_TRAIN = 500 +SUBSET_SIZE_VAL = 200 +MAX_STEPS_TRAIN = 20 +MAX_STEPS_VAL = 100 +PRINT_EVERY = 10 + +# Table column names +IMAGE_COL = "Filename" +BOX_COLS = None +CLASS_COL = None + +# Small input for fast CPU sanity checks (can raise to 128x128 later) +IMG_SIZE = (64, 64) + +# Training knobs +LR = 1e-3 +WEIGHT_DECAY = 1e-4 +GRAD_CLIP_NORM = 5.0 # set None to disable +USE_DICE_MIX = True # True -> (BCE + DiceLoss)/2 +SAVE_DIR = "models" +BEST_WEIGHTS_PATH = os.path.join(SAVE_DIR, "unet_weedseg_best.pth") +LAST_WEIGHTS_PATH = os.path.join(SAVE_DIR, "unet_weedseg_last.pth") +SEED = 1337 +# ====================================================== + +# ------------------- Reproducibility ------------------- +def set_seed(seed: int = 1337): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) +set_seed(SEED) + +# ------------------- Table collection ------------------ +def collect(prefix: str): + files = [] + for f in os.listdir(LABELS_DIR): + name = f.lower() + if name.startswith(prefix.lower()) and (name.endswith(".xlsx") or name.endswith(".csv")): + files.append(os.path.join(LABELS_DIR, f)) + if not files: + raise RuntimeError(f"No {prefix} files found in {LABELS_DIR}") + return sorted(files) + +# ------------------- Subset helper --------------------- +def make_subset(ds, k: int): + n = len(ds) + if k is None or n <= k: + return ds + idxs = random.sample(range(n), k) + return Subset(ds, idxs) + +# ------------------- Transforms ------------------------ +class MaskTo01(object): + """ + Convert [1,H,W] uint8 {0,255} from PILToTensor to [H,W] long {0,1}. + Kept as a top-level class so it's picklable on Windows. + """ + def __call__(self, t: torch.Tensor) -> torch.Tensor: + t = t.squeeze(0) # [H,W] uint8 + return (t > 0).to(torch.long) # [H,W] 0/1 + +image_tf = transforms.Compose([ + transforms.Resize(IMG_SIZE, interpolation=InterpolationMode.BILINEAR), + transforms.ToTensor(), # -> [C,H,W], float in [0,1] +]) + +mask_tf = transforms.Compose([ + transforms.Resize(IMG_SIZE, interpolation=InterpolationMode.NEAREST), + transforms.PILToTensor(), # -> [1,H,W] uint8 (0 or 255) + MaskTo01(), # -> [H,W] long {0,1} +]) + +# ------------------- Datasets -------------------------- +train_tables = collect("train_subset") +val_tables = collect("val_subset") +labels_file = None + +train_dataset_full = DeepWeedsDataset( + root_dir=ROOT, + table_files=train_tables, + labels_file=labels_file, + image_col=IMAGE_COL, + box_cols=BOX_COLS, + class_col=CLASS_COL, + image_transform=image_tf, + mask_transform=mask_tf, + masks_dir="masks", +) +val_dataset_full = DeepWeedsDataset( + root_dir=ROOT, + table_files=val_tables, + labels_file=labels_file, + image_col=IMAGE_COL, + box_cols=BOX_COLS, + class_col=CLASS_COL, + image_transform=image_tf, + mask_transform=mask_tf, + masks_dir="masks", +) + +if FAST_DEBUG: + print("FAST_DEBUG: using small subsets") + train_dataset = make_subset(train_dataset_full, SUBSET_SIZE_TRAIN) + val_dataset = make_subset(val_dataset_full, SUBSET_SIZE_VAL) +else: + train_dataset = train_dataset_full + val_dataset = val_dataset_full + +# ------------------- DataLoaders ----------------------- +# Windows-safe: workers=0; batch_size=1 for low CPU pressure +train_loader = DataLoader( + train_dataset, batch_size=1, shuffle=True, + num_workers=0, pin_memory=False +) +val_loader = DataLoader( + val_dataset, batch_size=1, shuffle=False, + num_workers=0, pin_memory=False +) + +# ------------------- Model & Optimizer ----------------- +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = UNet(in_channels=3, out_channels=1).to(device) + +bce = nn.BCEWithLogitsLoss() +optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY) + +# ------------------- Helpers --------------------------- +def _match_logits_to_mask(logits: torch.Tensor, mask_hw): + """Resize logits to [*,1,H,W] to match target before loss.""" + if logits.shape[-2:] != mask_hw: + logits = F.interpolate(logits, size=mask_hw, mode="bilinear", align_corners=False) + return logits + +def dice_loss_from_logits(logits: torch.Tensor, target01: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + logits: [B,1,H,W] (raw) + target01: [B,1,H,W] float {0,1} + returns scalar tensor + """ + prob = torch.sigmoid(logits) + inter = (prob * target01).sum(dim=(1,2,3)) + denom = prob.sum(dim=(1,2,3)) + target01.sum(dim=(1,2,3)) + dl = 1.0 - (2.0 * inter + eps) / (denom + eps) + return dl.mean() + +@torch.no_grad() +def pixel_accuracy_from_logits(logits, target01, threshold=0.5): + """ + logits: [B,1,H,W] - פלט המודל לפני sigmoid + target01: [B,1,H,W] - מסכה עם 0/1 + threshold: סף בינארי (לרוב 0.5) + """ + probs = torch.sigmoid(logits) + preds = (probs > threshold).float() + correct = (preds == target01).float().sum() + total = target01.numel() + return (correct / total).item() + +def combined_loss(logits: torch.Tensor, target01: torch.Tensor) -> torch.Tensor: + if USE_DICE_MIX: + return 0.5 * bce(logits, target01) + 0.5 * dice_loss_from_logits(logits, target01) + else: + return bce(logits, target01) + +# ------------------- One-batch profile ----------------- +@torch.no_grad() +def profile_one_batch(model, loader, device): + model.eval() + t0 = time.time() + images, masks = next(iter(loader)) + t1 = time.time() + images = images.to(device) + masks = masks.unsqueeze(1).float().to(device) + t2 = time.time() + logits = model(images) + logits = _match_logits_to_mask(logits, masks.shape[-2:]) + t3 = time.time() + # quick sanity check + print("PROFILE shapes:", tuple(images.shape), tuple(masks.shape), tuple(logits.shape)) + print("mask unique (should be {0.,1.}):", torch.unique(masks[0,0].detach().cpu())) + return { + "load_s": t1 - t0, + "to_device_s": t2 - t1, + "forward_s": t3 - t2, + "total_s": t3 - t0, + "img_shape": tuple(images.shape), + } + +# ------------------- Train / Val loops ----------------- +_printed_debug_shapes = False + +def train_one_epoch(model, loader, optimizer, device, max_steps=None): + global _printed_debug_shapes + model.train() + running = 0.0 + t_start = time.time() + + for step, (images, masks) in enumerate(loader, 1): + if not _printed_debug_shapes: + print("DEBUG shapes:", tuple(images.shape), tuple(masks.shape)) # (B,3,H,W), (B,H,W) + _printed_debug_shapes = True + + images = images.to(device) # [B,3,H,W] + masks = masks.unsqueeze(1).float().to(device) # [B,1,H,W] float {0,1} + + optimizer.zero_grad(set_to_none=True) + logits = model(images) # [B,1,h,w] + logits = _match_logits_to_mask(logits, masks.shape[-2:]) + + loss = combined_loss(logits, masks) + loss.backward() + + if GRAD_CLIP_NORM is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM) + + optimizer.step() + + running += loss.item() + + if step % PRINT_EVERY == 0 or step == 1: + print(f" step {step}/{len(loader)} | loss={running/step:.4f}") + + if (max_steps is not None) and (step >= max_steps): + break + + denom = min(len(loader), max_steps) if max_steps else len(loader) + epoch_time = time.time() - t_start + print(f" epoch avg step: {epoch_time/denom:.3f}s | epoch total: {epoch_time:.1f}s") + return running / max(1, denom) +@torch.no_grad() +def validate(model, loader, device, max_steps=None): + model.eval() + running_loss = 0.0 + running_acc = 0.0 + count = 0 + t_start = time.time() + + for step, (images, masks) in enumerate(loader, 1): + images = images.to(device) + masks = masks.unsqueeze(1).float().to(device) + logits = model(images) + logits = _match_logits_to_mask(logits, masks.shape[-2:]) + loss = combined_loss(logits, masks) + acc = pixel_accuracy_from_logits(logits, masks) + + running_loss += loss.item() + running_acc += acc + count += 1 + + if step % PRINT_EVERY == 0 or step == 1: + print(f" [val] step {step}/{len(loader)} | loss={running_loss/count:.4f} | acc={running_acc/count:.4f}") + + if (max_steps is not None) and (step >= max_steps): + break + + denom = max(1, count) + epoch_time = time.time() - t_start + print(f" [val] epoch avg step: {epoch_time/denom:.3f}s | epoch total: {epoch_time:.1f}s") + avg_loss = running_loss / denom + avg_acc = running_acc / denom + return avg_loss, avg_acc + +# ------------------- Main ------------------------------ +def main(): + epochs = 3 if FAST_DEBUG else 20 + best_val = float("inf") + os.makedirs(SAVE_DIR, exist_ok=True) + + # One-batch profile to understand timing + prof = profile_one_batch(model, train_loader, device) + print(f"PROFILE one batch -> load={prof['load_s']:.3f}s | to_device={prof['to_device_s']:.3f}s | " + f"forward={prof['forward_s']:.3f}s | total={prof['total_s']:.3f}s | shape={prof['img_shape']}") + + for epoch in range(1, epochs + 1): + print(f"\nEpoch {epoch}/{epochs}") + + train_loss = train_one_epoch( + model, train_loader, optimizer, device, + max_steps=(MAX_STEPS_TRAIN if FAST_DEBUG else None) + ) + + val_loss, val_acc = validate( + model, val_loader, device, + max_steps=(MAX_STEPS_VAL if FAST_DEBUG else None) + ) + + print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}") + + if val_loss < best_val: + best_val = val_loss + torch.save(model.state_dict(), BEST_WEIGHTS_PATH) + print(f"✓ Saved best -> {BEST_WEIGHTS_PATH} (val={best_val:.4f})") + + torch.save(model.state_dict(), LAST_WEIGHTS_PATH) + print(f"✓ Saved last -> {LAST_WEIGHTS_PATH}") + +if __name__ == "__main__": + main() diff --git a/services/weed_detection/models/train_ml_refiner.py b/services/weed_detection/models/train_ml_refiner.py new file mode 100644 index 000000000..4ee5d6749 --- /dev/null +++ b/services/weed_detection/models/train_ml_refiner.py @@ -0,0 +1,107 @@ +""" +Train a small classifier to refine weed mask using pseudo-labels (heuristic or GT). +Saves weights to WEIGHTS_OUT (default: ./weights_refiner.pth). + +Usage: + python -m src.train_ml_refiner +Env (.env): + INPUT_DIR=... # same as batch + GT_DIR=... # optional, if you have GT + USE_GT=0/1 # 1 to use GT masks if available + EPOCHS=3 + BATCH_SIZE=64 + LR=1e-3 + WEIGHTS_OUT=./weights_refiner.pth + SAMPLES_PER_IMAGE=64 + LIMIT_IMAGES=0 # 0 = no limit + MAX_STEPS_PER_EPOCH=0 # 0 = no limit +""" + +import os +from dotenv import load_dotenv +import torch +from torch.utils.data import DataLoader, random_split +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +from .data_ml import PatchRefineDataset +from .ml_model import WeedNet # mobilenet_v3_small head -> 2 classes + + +def main(): + load_dotenv() + images_dir = os.getenv("INPUT_DIR", "./data/images") + gt_dir = os.getenv("GT_DIR", "./data/labels") + use_gt = os.getenv("USE_GT", "0") == "1" + epochs = int(os.getenv("EPOCHS", "3")) + bs = int(os.getenv("BATCH_SIZE", "64")) + lr = float(os.getenv("LR", "1e-3")) + weights_out= os.getenv("WEIGHTS_OUT", "./weights_refiner.pth") + + samples_per_image = int(os.getenv("SAMPLES_PER_IMAGE", "64")) + limit_images = int(os.getenv("LIMIT_IMAGES", "0")) + max_steps_per_epoch = int(os.getenv("MAX_STEPS_PER_EPOCH", "0")) + + # Load dataset + ds = PatchRefineDataset(images_dir, gt_dir, use_gt=use_gt, + samples_per_image=samples_per_image, patch_radius=16) + + if limit_images > 0: + ds.images = ds.images[:limit_images] + + n = len(ds) + n_train = int(0.9 * n) + n_val = n - n_train + train_ds, val_ds = random_split(ds, [n_train, n_val]) + + train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=0, pin_memory=True) + val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, num_workers=0, pin_memory=True) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = WeedNet(num_classes=2).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + best_acc = 0.0 + for epoch in range(1, epochs+1): + model.train() + total = correct = 0 + pbar = tqdm(train_loader, desc=f"Train {epoch}/{epochs}") + for step, (x, y) in enumerate(pbar, start=1): + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + logits = model(x) + loss = criterion(logits, y) + loss.backward() + optimizer.step() + pred = logits.argmax(1) + correct += (pred == y).sum().item() + total += y.numel() + pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{(correct/total)*100:.1f}%") + + if max_steps_per_epoch and step >= max_steps_per_epoch: + break + + # validation + model.eval() + v_total = v_correct = 0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + logits = model(x) + pred = logits.argmax(1) + v_correct += (pred == y).sum().item() + v_total += y.numel() + v_acc = v_correct / max(1, v_total) + if v_acc > best_acc: + best_acc = v_acc + torch.save(model.state_dict(), weights_out) + + print(f"[VAL] acc={v_acc:.4f} (best={best_acc:.4f})") + + print(f"[DONE] Saved best weights to: {weights_out}") + + +if __name__ == "__main__": + main() diff --git a/services/weed_detection/models/unet_model.py b/services/weed_detection/models/unet_model.py new file mode 100644 index 000000000..0952c47b8 --- /dev/null +++ b/services/weed_detection/models/unet_model.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class UNet(nn.Module): + def __init__(self, in_channels=3, out_channels=1): + super(UNet, self).__init__() + + # Encoder + self.enc1 = self.conv_block(in_channels, 64) + self.enc2 = self.conv_block(64, 128) + self.enc3 = self.conv_block(128, 256) + self.enc4 = self.conv_block(256, 512) + + # Bottleneck + self.bottleneck = self.conv_block(512, 1024) + + # Decoder + self.dec1 = self.deconv_block(1024, 512) + self.dec2 = self.deconv_block(512, 256) + self.dec3 = self.deconv_block(256, 128) + self.dec4 = self.deconv_block(128, 64) + + # Output layer + self.output = nn.Conv2d(64, out_channels, kernel_size=1) + + def conv_block(self, in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True) + ) + + def deconv_block(self, in_channels, out_channels): + return nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + enc1 = self.enc1(x) + enc2 = self.enc2(enc1) + enc3 = self.enc3(enc2) + enc4 = self.enc4(enc3) + bottleneck = self.bottleneck(enc4) + + dec1 = self.dec1(bottleneck) + dec2 = self.dec2(dec1) + dec3 = self.dec3(dec2) + dec4 = self.dec4(dec3) + + output = self.output(dec4) + return output diff --git a/services/weed_detection/models/weights_refiner.pth b/services/weed_detection/models/weights_refiner.pth new file mode 100644 index 000000000..05c31cbcd Binary files /dev/null and b/services/weed_detection/models/weights_refiner.pth differ diff --git a/services/weed_detection/requirements.txt b/services/weed_detection/requirements.txt new file mode 100644 index 000000000..0d011903c --- /dev/null +++ b/services/weed_detection/requirements.txt @@ -0,0 +1,10 @@ +--extra-index-url https://pypi.org/simple +# torch==2.9.0 +# torchvision==0.24.0 +kafka-python==2.2.2 +minio==7.2.18 +psycopg2-binary==2.9.11 +opencv-python-headless==4.10.0.84 +Pillow +SQLAlchemy==2.0.36 + diff --git a/services/weed_detection/run_weekly.ps1 b/services/weed_detection/run_weekly.ps1 new file mode 100644 index 000000000..9f10eee66 --- /dev/null +++ b/services/weed_detection/run_weekly.ps1 @@ -0,0 +1,84 @@ +# run_weekly.ps1 +# ------------------------------------------- +# Runs the docker-compose service once a week in a clean and stable manner +# ------------------------------------------- + +# ===== Settings ===== +$ProjectDir = "C:\Users\user\Documents\weed-baseline\AgCloud\weed" # update if needed +$Service = "weed-detector" +$LogFile = "C:\logs\weed-weekly.log" +$LockFile = "C:\temp\weed-weekly.lock" +$WaitForDockerMinutes = 5 + +# Create folders for log/lock +New-Item -ItemType Directory -Force -Path (Split-Path $LogFile) | Out-Null +New-Item -ItemType Directory -Force -Path (Split-Path $LockFile) | Out-Null + +# Prevent overlap +if (Test-Path $LockFile) { exit 0 } +New-Item -ItemType File -Path $LockFile -Force | Out-Null + +# Short logging function with timestamp +function Write-Log($msg) { + $stamp = (Get-Date -Format 'yyyy-MM-dd HH:mm:ss') + "$stamp | $msg" | Tee-Object -FilePath $LogFile -Append +} + +try { + Write-Log "JOB START" + + # Start Docker Desktop if it’s not running + if (-not (Get-Process -Name "Docker Desktop" -ErrorAction SilentlyContinue)) { + $dockerExe = "C:\Program Files\Docker\Docker\Docker Desktop.exe" + if (Test-Path $dockerExe) { + Write-Log "Starting Docker Desktop..." + Start-Process $dockerExe | Out-Null + } else { + Write-Log "Docker Desktop not found at $dockerExe" + } + } + + # Wait for Docker engine to start + $deadline = (Get-Date).AddMinutes($WaitForDockerMinutes) + do { + try { docker info | Out-Null; $up=$true } catch { Start-Sleep -Seconds 5 } + } until ($up -or (Get-Date) -gt $deadline) + if (-not $up) { + Write-Log "Docker engine did not become ready within $WaitForDockerMinutes minutes." + exit 98 + } + + # Correct context (harmless if already correct) + docker context use desktop-linux 2>$null | Out-Null + + Push-Location $ProjectDir + + # Header for execution + Write-Log "BEGIN build+run for service '$Service'" + + # Important: do not treat stderr output as a fatal error + $prev = $ErrorActionPreference + $ErrorActionPreference = "Continue" + + # Build and run: returns the container’s own exit code + # --no-deps: only this service; --abort-on-container-exit: stops when the service finishes + "===== $(Get-Date -Format 'yyyy-MM-dd HH:mm:ss') :: DOCKER START =====" | Tee-Object -FilePath $LogFile -Append + docker compose up --no-deps --build --abort-on-container-exit --exit-code-from $Service $Service ` + 2>&1 | Tee-Object -FilePath $LogFile -Append + "===== $(Get-Date -Format 'yyyy-MM-dd HH:mm:ss') :: DOCKER END =====" | Tee-Object -FilePath $LogFile -Append + + $code = $LASTEXITCODE + $ErrorActionPreference = $prev + + if ($code -ne 0) { + Write-Log "JOB FAILED with exit code $code" + exit $code + } else { + Write-Log "JOB SUCCEEDED (exit code 0)" + } + +} finally { + Pop-Location 2>$null + Remove-Item $LockFile -ErrorAction SilentlyContinue + Write-Log "JOB END" +} diff --git a/services/weed_detection/scripts/cron_job_config.yaml b/services/weed_detection/scripts/cron_job_config.yaml new file mode 100644 index 000000000..55006c170 --- /dev/null +++ b/services/weed_detection/scripts/cron_job_config.yaml @@ -0,0 +1,15 @@ +apiVersion: batch/v1 +kind: CronJob +metadata: + name: weed-detection-job +spec: + schedule: "0 0 * * 0" # every Sunday night + jobTemplate: + spec: + template: + spec: + containers: + - name: weed-detection + image: weed-detection-image + command: ["/bin/bash", "-c", "python /scripts/run_detection.py"] + restartPolicy: OnFailure diff --git a/services/weed_detection/scripts/make_masks_auto.py b/services/weed_detection/scripts/make_masks_auto.py new file mode 100644 index 000000000..0c1b03810 --- /dev/null +++ b/services/weed_detection/scripts/make_masks_auto.py @@ -0,0 +1,40 @@ +# scripts/make_masks_auto.py +import os +import cv2 +import numpy as np + +IM_DIR = "data/images" +MASK_DIR = "data/masks" +os.makedirs(MASK_DIR, exist_ok=True) + +def exg_mask(bgr): + b, g, r = cv2.split(bgr.astype(np.float32)) + # Simple Excess Green: ExG = 2G - R - B + exg = 2*g - r - b + exg = cv2.normalize(exg, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + # Automatic threshold (Otsu) + thr_val, thr = cv2.threshold(exg, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + # Cleanup: opening/closing + kernel = np.ones((3,3), np.uint8) + thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN, kernel, iterations=1) + thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1) + return (thr > 0).astype(np.uint8) # 0/1 + +def process_one(path): + bgr = cv2.imread(path) + if bgr is None: + print(f"[warn] cannot read: {path}") + return + mask01 = exg_mask(bgr) # [H,W] uint8 0/1 + out = (mask01 * 255).astype(np.uint8) + name = os.path.splitext(os.path.basename(path))[0] + ".png" + cv2.imwrite(os.path.join(MASK_DIR, name), out) + +def main(): + for fn in os.listdir(IM_DIR): + if fn.lower().endswith((".jpg",".jpeg",".png",".bmp","tif","tiff")): + process_one(os.path.join(IM_DIR, fn)) + print("done. masks saved to:", MASK_DIR) + +if __name__ == "__main__": + main() diff --git a/services/weed_detection/scripts/run_detection.py b/services/weed_detection/scripts/run_detection.py new file mode 100644 index 000000000..e5bc9f234 --- /dev/null +++ b/services/weed_detection/scripts/run_detection.py @@ -0,0 +1,137 @@ +""" +run_batch.py + +Purpose: +- Run the disease-detection batch pipeline either from a LOCAL folder of images + or from a MinIO bucket (objects are first downloaded to a local cache dir, + then processed exactly like local files). + +Usage examples: +1) Local folder (backward-compatible): + python -m agri_baseline.scripts.run_batch --storage local --images ./data/images + +2) MinIO (reads config from ENV and optional CLI flags): + python -m agri_baseline.scripts.run_batch --storage minio --minio-prefix "" + +Environment variables (typical .env): +- STORAGE_BACKEND=minio|local +- MINIO_ENDPOINT=127.0.0.1:9000 +- MINIO_ACCESS_KEY=minioadmin +- MINIO_SECRET_KEY=minioadmin +- MINIO_BUCKET=leaves +- MINIO_SECURE=false +- MINIO_PREFIX=mission-123/ (optional) +- MINIO_CACHE_DIR=./data/_minio_cache +""" + +import argparse +import os +from pathlib import Path + +from src.pipeline.logging_setup import setup_logging +from src.pipeline import config +from src.batch_runner import BatchRunner + +# MinIO helpers provided in your project +from services.minio_client import load_minio_config # loads config from ENV +from services.minio_sync import download_prefix_to_dir, ensure_bucket + + +def run_local(images_dir: Path) -> None: + """ + LOCAL mode: + - Run the batch pipeline over a local folder of images. + - This preserves the original behavior for backward compatibility. + """ + runner = BatchRunner() + runner.run_folder(images_dir) + + +def run_minio(prefix: str, cache_dir: Path) -> None: + """ + MINIO mode: + - Pull objects from a MinIO bucket (based on ENV config). + - Download them to a local cache directory. + - Run the batch pipeline over the downloaded files. + """ + cfg = load_minio_config() + ensure_bucket(cfg) # Safety: create the bucket if it doesn't exist + + cache_dir.mkdir(parents=True, exist_ok=True) + + # Download objects under 'prefix' into the local cache folder + downloaded = download_prefix_to_dir(cfg, prefix=prefix, local_dir=cache_dir) + if not downloaded: + raise SystemExit( + f"No objects found in bucket '{cfg.bucket}' with prefix '{prefix}'." + ) + + runner = BatchRunner() + runner.run_folder(cache_dir) + + +def parse_args() -> argparse.Namespace: + """ + Parse CLI arguments and provide sensible defaults from ENV where applicable. + """ + ap = argparse.ArgumentParser(description="Run batch pipeline (local/minio).") + + # Backward-compatible local images folder + ap.add_argument( + "--images", + default=config.IMAGES_DIR, + help="Folder of input images (LOCAL mode)", + ) + + # Storage backend selector + ap.add_argument( + "--storage", + choices=["local", "minio"], + default=os.getenv("STORAGE_BACKEND", "local").lower(), + help="Where to read images from (local|minio).", + ) + + # MinIO options (with ENV fallbacks) + ap.add_argument( + "--minio-prefix", + default=os.getenv("MINIO_PREFIX", ""), + help="Object prefix inside the bucket (e.g. 'mission-123/').", + ) + ap.add_argument( + "--minio-cache", + default=os.getenv("MINIO_CACHE_DIR", "./data/_minio_cache"), + help="Local temp folder used to download MinIO objects before processing.", + ) + + return ap.parse_args() + + +def main() -> None: + """ + Entry point: + - Logs chosen backend. + - Dispatches to local/minio flows. + - Keeps logs concise and informative for CI/ops. + """ + log = setup_logging() + args = parse_args() + + log.info(f"Storage backend: {args.storage}") + + if args.storage == "local": + images_dir = Path(args.images) + log.info(f"Starting batch over LOCAL folder: {images_dir}") + run_local(images_dir) + log.info("Batch done (local).") + else: + cache_dir = Path(args.minio_cache) + log.info( + "Starting batch over MINIO: " + f"bucket from ENV, prefix='{args.minio_prefix}', cache='{cache_dir}'" + ) + run_minio(prefix=args.minio_prefix, cache_dir=cache_dir) + log.info("Batch done (minio).") + + +if __name__ == "__main__": + main() diff --git a/services/weed_detection/services/io.py b/services/weed_detection/services/io.py new file mode 100644 index 000000000..57014b783 --- /dev/null +++ b/services/weed_detection/services/io.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import json +import logging +from typing import Tuple, Iterable, Dict, Any, List + +import pandas as pd +from sqlalchemy import create_engine, text + +LOGGER = logging.getLogger(__name__) + +# --------------------------------------------------------------------- +# Postgres sources: anomalies / anomaly_types / regions +# --------------------------------------------------------------------- + +_BASE_SQLS: Dict[str, str] = { + "device": """ + SELECT a.ts AS "timestamp", + a.device_id AS entity_id, + at.code AS disease_type, + COALESCE(a.severity::double precision, 0.0) AS severity, + 0.0 AS affected_area + FROM public.anomalies a + JOIN public.anomaly_types at ON at.anomaly_type_id = a.anomaly_type_id + WHERE a.ts IS NOT NULL + {AND_CODE_FILTER} + {AND_TIME_RANGE} + """, + "mission": """ + SELECT a.ts AS "timestamp", + a.mission_id::text AS entity_id, + at.code AS disease_type, + COALESCE(a.severity::double precision, 0.0) AS severity, + 0.0 AS affected_area + FROM public.anomalies a + JOIN public.anomaly_types at ON at.anomaly_type_id = a.anomaly_type_id + WHERE a.ts IS NOT NULL + {AND_CODE_FILTER} + {AND_TIME_RANGE} + """, + "region": """ + SELECT a.ts AS "timestamp", + r.id::text AS entity_id, + at.code AS disease_type, + COALESCE(a.severity::double precision, 0.0) AS severity, + {AREA_EXPR} AS affected_area + FROM public.anomalies a + JOIN public.anomaly_types at ON at.anomaly_type_id = a.anomaly_type_id + JOIN public.regions r ON ST_Contains(r.geom, a.geom) + WHERE a.ts IS NOT NULL AND a.geom IS NOT NULL + {AND_CODE_FILTER} + {AND_TIME_RANGE} + """, +} + + +def _build_sql( + entity_dim: str, + area_strategy: str, + codes: List[str] | None, + start: str | None, + end: str | None, +) -> tuple[str, dict]: + """ + Build parametrized SQL for reading anomalies with chosen entity dimension and area strategy. + """ + sql = _BASE_SQLS[entity_dim] + area_expr = "0.0" + if entity_dim == "region" and area_strategy == "region_area": + area_expr = "ST_Area(r.geom::geography)::double precision" + + and_code = "" + params: Dict[str, Any] = {} + if codes: + and_code = "AND at.code = ANY(:codes)" + params["codes"] = codes + + and_time = "" + if start: + and_time += " AND a.ts >= :start_time" + params["start_time"] = start + if end: + and_time += " AND a.ts < :end_time" + params["end_time"] = end + + sql = ( + sql.replace("{AREA_EXPR}", area_expr) + .replace("{AND_CODE_FILTER}", and_code) + .replace("{AND_TIME_RANGE}", and_time) + ) + return sql, params + + +# --------------------------------------------------------------------- +# Postgres input (canonical) +# --------------------------------------------------------------------- + +def load_inputs_from_postgres(pg_url: str, tz: str, cfg: dict) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Load inputs from Postgres (public.anomalies/anomaly_types/regions). + Controlled by cfg['source_mapping'] (entity_dim, area_strategy, filters, codes). + Returns: + det: columns [timestamp, entity_id, disease_type, severity, affected_area] + reg: columns [entity_id, entity_type] + """ + edim = cfg["source_mapping"]["entity_dim"] + area = cfg["source_mapping"].get("area_strategy", "none") + codes = cfg["source_mapping"].get("anomaly_codes") + filters = cfg["source_mapping"].get("filters") or {} + start = filters.get("start_time") + end = filters.get("end_time") + + sql, params = _build_sql(edim, area, codes, start, end) + + eng = create_engine(pg_url) + with eng.begin() as conn: + det = pd.read_sql(text(sql), conn, params=params) + reg = det[["entity_id"]].drop_duplicates().assign(entity_type=edim) + + det["timestamp"] = pd.to_datetime(det["timestamp"], utc=True).dt.tz_convert(tz) + + required = {"timestamp", "entity_id", "disease_type", "severity", "affected_area"} + if not required.issubset(det.columns): + missing = required - set(det.columns) + raise ValueError(f"det: missing {missing}") + if not {"entity_id", "entity_type"}.issubset(reg.columns): + raise ValueError("reg: missing cols") + + return det, reg + + +# --------------------------------------------------------------------- +# Aggregation +# --------------------------------------------------------------------- + +def aggregate(det: pd.DataFrame, freq: str) -> pd.DataFrame: + """ + Aggregate by entity_id + window and compute disease_count, avg_severity, affected_area. + """ + df = det.copy() + + # Normalize tz: drop tz-info to use pandas period-based bucketing safely + if pd.api.types.is_datetime64tz_dtype(df["timestamp"]): + df["timestamp"] = df["timestamp"].dt.tz_convert("UTC").dt.tz_localize(None) + + df["window"] = df["timestamp"].dt.to_period(freq).dt.start_time + grp = df.groupby(["entity_id", "window"], as_index=False).agg( + disease_count=("disease_type", "count"), + avg_severity=("severity", "mean"), + affected_area=("affected_area", "sum"), + ) + grp["window_end"] = grp["window"] + pd.tseries.frequencies.to_offset(freq) + return grp + + +# --------------------------------------------------------------------- +# Alerts: Postgres backend +# --------------------------------------------------------------------- + + +def fetch_open_alerts_pg(pg_url: str) -> pd.DataFrame: + eng = create_engine(pg_url) + sql = """ + SELECT id, entity_id, rule, window_start, window_end, score, + first_seen, last_seen, status, meta_json + FROM public.alerts + WHERE status IN ('OPEN','ACK') + """ + with eng.begin() as conn: + df = pd.read_sql(text(sql), conn) + if not df.empty: + for c in ("first_seen", "last_seen", "window_start", "window_end"): + # make tz-aware UTC then drop tz -> naive UTC + s = pd.to_datetime(df[c], utc=True) + df[c] = s.dt.tz_convert("UTC").dt.tz_localize(None) + + return df + + +def upsert_alerts_pg(pg_url: str, alerts: Iterable[Dict[str, Any]]) -> None: + rows = list(alerts) + if not rows: + return + eng = create_engine(pg_url) + sql = """ + INSERT INTO public.alerts + (entity_id, rule, window_start, window_end, score, + first_seen, last_seen, status, meta_json) + VALUES + (:entity_id, :rule, :window_start, :window_end, :score, + :first_seen, :last_seen, :status, CAST(:meta_json AS jsonb)) + """ + payload = [{ + "entity_id": a["entity_id"], + "rule": a["rule"], + "window_start": a["window_start"], + "window_end": a["window_end"], + "score": float(a["score"]), + "first_seen": a["first_seen"], + "last_seen": a["last_seen"], + "status": a["status"], + "meta_json": json.dumps(a["meta"], ensure_ascii=False), + } for a in rows] + + with eng.begin() as conn: + conn.execute(text(sql), payload) + LOGGER.info("Inserted %d alerts into Postgres.", len(rows)) diff --git a/services/weed_detection/services/minio_client.py b/services/weed_detection/services/minio_client.py new file mode 100644 index 000000000..dd5effd69 --- /dev/null +++ b/services/weed_detection/services/minio_client.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from minio import Minio + + +@dataclass(frozen=True) +class MinioConfig: + endpoint: str + access_key: str + secret_key: str + bucket: str + secure: bool + + +def load_minio_config() -> MinioConfig: + endpoint = os.getenv("MINIO_ENDPOINT", "localhost:9000") + access_key = os.getenv("MINIO_ACCESS_KEY", "") + secret_key = os.getenv("MINIO_SECRET_KEY", "") + bucket = os.getenv("MINIO_BUCKET", "my-bucket") + secure = os.getenv("MINIO_SECURE", "false").lower() == "true" + + if not access_key or not secret_key: + raise ValueError("Missing MINIO_ACCESS_KEY / MINIO_SECRET_KEY.") + return MinioConfig(endpoint, access_key, secret_key, bucket, secure) + + +def build_client(cfg: MinioConfig) -> Minio: + return Minio( + endpoint=cfg.endpoint, + access_key=cfg.access_key, + secret_key=cfg.secret_key, + secure=cfg.secure, + ) diff --git a/services/weed_detection/services/minio_sync.py b/services/weed_detection/services/minio_sync.py new file mode 100644 index 000000000..2f64195a5 --- /dev/null +++ b/services/weed_detection/services/minio_sync.py @@ -0,0 +1,86 @@ + +# services/minio_sync.py +from __future__ import annotations + +from io import BytesIO +from pathlib import Path +from typing import Iterable + +from .minio_client import MinioConfig, build_client + +ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} + + +def ensure_bucket(cfg: MinioConfig) -> None: + """ + Ensure the target bucket exists; create it if it does not. + """ + client = build_client(cfg) + if not client.bucket_exists(cfg.bucket): + client.make_bucket(cfg.bucket) + + +def download_prefix_to_dir(cfg: MinioConfig, prefix: str, local_dir: Path) -> list[Path]: + """ + Download all objects under the given `prefix` to the local directory. + Filters by ALLOWED_EXTS. Returns a list of local file paths. + """ + client = build_client(cfg) + local_dir.mkdir(parents=True, exist_ok=True) + + print(f"[minio] bucket={cfg.bucket} prefix='{prefix}' -> {local_dir}") + downloaded: list[Path] = [] + listed = 0 + objs = list(client.list_objects(cfg.bucket, prefix=prefix, recursive=True)) + print("ALL OBJECT KEYS (raw):", [o.object_name for o in objs]) + for obj in client.list_objects(cfg.bucket, prefix=prefix, recursive=True): + name = obj.object_name + if not name or name.endswith("/"): + continue # “תיקיות” מדומות + listed += 1 + + suf = Path(name).suffix.lower() + if suf not in ALLOWED_EXTS: + # אפשר להשתיק אם לא רוצים לוגים על סיומות + # print(f"[skip-ext] {name}") + continue + + target = local_dir / Path(name).name + response = client.get_object(cfg.bucket, name) + try: + data = response.read() + finally: + try: + response.close() + response.release_conn() + except Exception: + pass + + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(data) + downloaded.append(target) + print(f"[dl] {name} -> {target}") + + print(f"[minio] listed={listed}, downloaded={len(downloaded)}") + return downloaded + + +def upload_dir_to_prefix(cfg: MinioConfig, local_dir: Path, prefix: str) -> list[str]: + """ + Upload all files from `local_dir` under `prefix`. + Returns a list of object names uploaded. + """ + client = build_client(cfg) + ensure_bucket(cfg) + + uploaded: list[str] = [] + for path in local_dir.rglob("*"): + if not path.is_file(): + continue + rel = path.relative_to(local_dir).as_posix() + object_name = f"{prefix.rstrip('/')}/{rel}" + data = path.read_bytes() + bio = BytesIO(data) + client.put_object(cfg.bucket, object_name, bio, length=len(data)) + uploaded.append(object_name) + return uploaded diff --git a/services/weed_detection/src/__init__.py b/services/weed_detection/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/weed_detection/src/batch_runner.py b/services/weed_detection/src/batch_runner.py new file mode 100644 index 000000000..361e85ac2 --- /dev/null +++ b/services/weed_detection/src/batch_runner.py @@ -0,0 +1,214 @@ +# agri_baseline/src/batch_runner.py +# Max line length: 100 + +from __future__ import annotations + +import json +from dataclasses import asdict, is_dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Tuple + +from src.pipeline.utils import ( + load_image, + image_id_from_path, + clamp_bbox, +) +from src.pipeline.db import ( + get_engine, + INSERT_DET, + INSERT_COUNT, + INSERT_QA, +) +from src.detectors.disease_model import DiseaseDetector + + +class BatchRunner: + """ + End-to-end runner: + - Load image + - Run disease detector + - Normalize detections + - Write anomalies / counts / QA to RelDB + """ + + def __init__(self, mission_id: int = 1, device_id: str = "device-1") -> None: + self.mission_id = mission_id + self.device_id = device_id # TEXT FK per schema v2 + self.engine = get_engine() + self.detector = DiseaseDetector() + + # ---------------------------- + # Public API + # ---------------------------- + + def run_folder(self, folder: Path | str) -> None: + """ + Run pipeline on all images within a folder (non-recursive). + Skips non-image files; prints minimal info. + """ + folder = Path(folder) + assert folder.exists(), f"Folder not found: {folder.resolve()}" + + image_paths = sorted( + p for p in folder.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"} + ) + + total = 0 + total_dets = 0 + for img_path in image_paths: + try: + n = self.process_image(img_path) + total += 1 + total_dets += n + except Exception as ex: + # Keep output tidy; prefer structured logging in production + print(f"[WARN] Failed on {img_path.name}: {ex}") + + # Record a small QA summary + qa = { + "images_processed": total, + "detections_total": total_dets, + "ts": datetime.now(timezone.utc).isoformat(timespec="seconds"), + } + with self.engine.begin() as conn: + conn.execute(INSERT_QA, {"details": json.dumps(qa)}) + + def process_image(self, img_path: Path | str) -> int: + """ + Run pipeline on a single image, write detections and a simple per-image score. + Returns number of detections written. + """ + img_path = Path(img_path) + img, W, H = load_image(img_path) + + image_id = image_id_from_path(img_path) + dets = self.detector.run(img) + + print(f"{image_id}: found {len(dets)} disease spots") + + # Write detections as anomalies + written = 0 + for d in dets: + x, y, w, h = self._extract_bbox(d) + x, y, w, h = clamp_bbox(int(x), int(y), int(w), int(h), W, H) + cx = x + w / 2.0 + cy = y + h / 2.0 + + area = float(getattr(d, "area", w * h)) + label = str(getattr(d, "label", "disease")) + conf = float(getattr(d, "confidence", 1.0)) + + details = { + "image_id": image_id, + "label": label, + "bbox": [x, y, w, h], + "area": area, + "confidence": conf, + } + if is_dataclass(d): + details["raw_detection"] = asdict(d) + + with self.engine.begin() as conn: + conn.execute( + INSERT_DET, + dict( + mission_id=self.mission_id, + device_id=self.device_id, # TEXT FK + ts=datetime.now(timezone.utc), + anomaly_type_id=1, # seeded below + severity=conf, + details=json.dumps(details), + wkt_geom=f"POINT({cx} {cy})", + ), + ) + written += 1 + + # Per-image score → tile_stats (tile_id TEXT, geom POLYGON) + if dets: + anomaly_score = float(len(dets)) + poly_wkt = self._make_square_polygon_wkt(W / 2.0, H / 2.0, size=1.0) + with self.engine.begin() as conn: + conn.execute( + INSERT_COUNT, + dict( + mission_id=self.mission_id, + tile_id=image_id, # TEXT per schema v2 + anomaly_score=anomaly_score, + wkt_geom=poly_wkt, # POLYGON + ), + ) + + return written + + # ---------------------------- + # Internals + # ---------------------------- + + @staticmethod + def _extract_bbox(d) -> Tuple[float, float, float, float]: + """ + Normalize bbox to (x, y, w, h). Supports: + - d.x, d.y, d.w, d.h + - d.bbox == (x, y, w, h) + - d.xmin, d.ymin, d.xmax, d.ymax + - d.left, d.top, d.width, d.height + """ + if all(hasattr(d, a) for a in ("x", "y", "w", "h")): + return float(d.x), float(d.y), float(d.w), float(d.h) + + if hasattr(d, "bbox"): + bx = list(d.bbox) + if len(bx) != 4: + raise ValueError(f"Unexpected bbox length: {len(bx)} in {bx}") + x, y, w, h = map(float, bx) + return x, y, w, h + + if all(hasattr(d, a) for a in ("xmin", "ymin", "xmax", "ymax")): + x1, y1, x2, y2 = float(d.xmin), float(d.ymin), float(d.xmax), float(d.ymax) + return x1, y1, max(0.0, x2 - x1), max(0.0, y2 - y1) + + if all(hasattr(d, a) for a in ("left", "top", "width", "height")): + return float(d.left), float(d.top), float(d.width), float(d.height) + + raise AttributeError( + "Detection bbox fields missing. Supported: " + "(x,y,w,h) or bbox or (xmin,ymin,xmax,ymax) or (left,top,width,height)." + ) + + @staticmethod + def _make_square_polygon_wkt(cx: float, cy: float, size: float = 1.0) -> str: + """ + Build a tiny square Polygon around (cx, cy) in WKT, closed ring. + PostGIS expects Polygon for tile_stats.geom (SRID 4326). + """ + x1, y1 = cx - size, cy - size + x2, y2 = cx + size, cy + size + return f"POLYGON(({x1} {y1}, {x2} {y1}, {x2} {y2}, {x1} {y2}, {x1} {y1}))" + + +# ------------- CLI helper ------------- + +def main() -> None: + """ + Local runner: + python -m agri_baseline.src.batch_runner --input + """ + import argparse + + parser = argparse.ArgumentParser(description="Run disease detection pipeline.") + parser.add_argument("--input", type=str, required=True, help="Image file or folder") + parser.add_argument("--mission", type=int, default=1, help="Numeric mission ID") + parser.add_argument("--device", type=str, default="device-1", help="Text device ID") + args = parser.parse_args() + + runner = BatchRunner(mission_id=args.mission, device_id=args.device) + in_path = Path(args.input) + if in_path.is_dir(): + runner.run_folder(in_path) + else: + runner.process_image(in_path) + + +if __name__ == "__main__": + main() diff --git a/services/weed_detection/src/detectors/__init__.py b/services/weed_detection/src/detectors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/weed_detection/src/detectors/disease_model.py b/services/weed_detection/src/detectors/disease_model.py new file mode 100644 index 000000000..2428dc4b0 --- /dev/null +++ b/services/weed_detection/src/detectors/disease_model.py @@ -0,0 +1,471 @@ +# /src/detectors/disease_model.py +from __future__ import annotations +from dataclasses import dataclass +from pathlib import Path +from typing import List, Tuple, Union +import os +import numpy as np +import torch +import cv2 + +# If your import path is different (e.g., src.models.ml_model), update here: +from models.ml_model import MLWeedDetector + + +@dataclass +class Detection: + x: float + y: float + w: float + h: float + area: float + confidence: float + label: str = "disease" # You can change to "weed" if that's the desired name + + +class DiseaseDetector: + """ + Flow: + 1) Build a coarse mask + 2) Refine the mask using MLWeedDetector (your MobileNetV3 model) + 3) Convert to detections (bounding boxes) for DB writing + """ + + def __init__(self) -> None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # ---- Load weights ---- + weights_env = os.getenv("WEIGHTS_REFINER", "").strip() + if weights_env: + weights_path = Path(weights_env) + else: + # Smart search inside the models directory + project_root = Path(__file__).resolve().parents[2] + models_dir = project_root / "models" + candidates = [ + models_dir / "weights_refiner.pth", + models_dir / "mobilenetv3_best.pth", + models_dir / "best.pth", + models_dir / "last.pth", + ] + found = [p for p in candidates if p.exists()] + if not found: + raise FileNotFoundError( + f"Could not find weights file. Set WEIGHTS_REFINER or put a weights .pth under {models_dir}" + ) + weights_path = found[0] + + # MLWeedDetector already handles Resize(224) + ImageNet normalization + self.refiner = MLWeedDetector(weights_path=str(weights_path), device=str(self.device)) + + # ---- Parameters configurable via .env ---- + self.min_component_area = int(os.getenv("MIN_COMPONENT_AREA", "200")) # Filter out small connected components + self.min_bbox_area = int(os.getenv("MIN_BBOX_AREA", "150")) # Filter out small bounding boxes + self.bin_thresh = int(os.getenv("REFINED_BIN_THRESH", "128")) # Binarization threshold after refinement + self.conf_area_norm = float(os.getenv("CONF_AREA_NORM", "10000")) # Normalize confidence by area + self.coarse_method = os.getenv("COARSE_METHOD", "OTSU").upper() # OTSU / HSV_GREEN + self.max_infer_side = int(os.getenv("MAX_INFER_SIDE", "0")) # 0 = no global downscale + + + # ---------- Main API ---------- + + def run( + self, + bgr_img: np.ndarray, + return_mask: bool = False + ) -> Union[List[Detection], Tuple[np.ndarray, List[Detection]]]: + """ + :param bgr_img: OpenCV image in BGR format + :param return_mask: If True, also return the refined mask (uint8 0/255) + """ + # Ensure contiguous to prevent negative strides + bgr = np.ascontiguousarray(bgr_img) + + bgr = self._maybe_downscale(bgr) + coarse = self._make_coarse(bgr) + # Ensure contiguous before refinement + coarse = np.ascontiguousarray(coarse) + + refined = self._refine_with_model(bgr, coarse) # 0..255 + refined_bin = self._binarize(refined, self.bin_thresh) # 0/255 + refined_bin = self._remove_small(refined_bin, self.min_component_area) + + detections = self._mask_to_detections(refined_bin) + + if return_mask: + return refined_bin, detections + return detections + + # ---------- Processing helpers ---------- + + def _maybe_downscale(self, bgr: np.ndarray) -> np.ndarray: + if not self.max_infer_side or self.max_infer_side <= 0: + return bgr + h, w = bgr.shape[:2] + m = max(h, w) + if m <= self.max_infer_side: + return bgr + scale = self.max_infer_side / float(m) + new_w, new_h = int(w * scale), int(h * scale) + out = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_AREA) + return np.ascontiguousarray(out) + + def _make_coarse(self, bgr: np.ndarray) -> np.ndarray: + """ + Coarse mask: + - OTSU (default) + - or HSV_GREEN if COARSE_METHOD=HSV_GREEN + """ + if self.coarse_method == "HSV_GREEN": + hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) + # Generic green range; you can calibrate according to your data: + lower = np.array([35, 40, 20], dtype=np.uint8) + upper = np.array([85, 255, 255], dtype=np.uint8) + mask = cv2.inRange(hsv, lower, upper) # 0/255 + return np.ascontiguousarray(mask) + + # OTSU (default) + gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) + _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + return np.ascontiguousarray(mask) + + def _refine_with_model(self, bgr: np.ndarray, coarse: np.ndarray) -> np.ndarray: + """ + Refinement using your trained model. + MLWeedDetector.score_mask(bgr, coarse) -> mask 0..255 + """ + # Ensure contiguous to avoid negative strides inside the refiner + bgr = np.ascontiguousarray(bgr) + coarse = np.ascontiguousarray(coarse) + + refined = self.refiner.score_mask(bgr, coarse) + if refined.dtype != np.uint8: + refined = np.clip(refined, 0, 255).astype(np.uint8) + return refined + + def _binarize(self, mask: np.ndarray, thresh: int) -> np.ndarray: + if mask.dtype != np.uint8: + mask = np.clip(mask, 0, 255).astype(np.uint8) + _, mask_bin = cv2.threshold(mask, thresh, 255, cv2.THRESH_BINARY) + return np.ascontiguousarray(mask_bin) + + def _remove_small(self, mask_bin: np.ndarray, min_area: int) -> np.ndarray: + if min_area <= 1: + return mask_bin + m01 = (mask_bin > 0).astype(np.uint8) + num_labels, labels = cv2.connectedComponents(m01) + out = np.zeros_like(m01) + for i in range(1, num_labels): # 0 = background + comp = (labels == i) + if int(comp.sum()) >= min_area: + out[comp] = 1 + return (out * 255).astype(np.uint8) + + def _mask_to_detections(self, mask_bin: np.ndarray) -> List[Detection]: + num, labels, stats, _ = cv2.connectedComponentsWithStats( + (mask_bin > 0).astype(np.uint8), connectivity=8 + ) + dets: List[Detection] = [] + for i in range(1, num): # 0 = background + x, y, w, h, area = stats[i] + if area < self.min_component_area: + continue + if (w * h) < self.min_bbox_area: + continue + conf = float(min(1.0, area / max(1.0, self.conf_area_norm))) + dets.append( + Detection( + x=float(x), y=float(y), w=float(w), h=float(h), + area=float(area), confidence=conf, label="disease" + ) + ) + return dets + + +# # /src/detectors/disease_model.py +# from __future__ import annotations +# from dataclasses import dataclass +# from pathlib import Path +# from typing import List, Tuple, Union, Optional, Dict, Any +# import os +# import uuid +# from datetime import datetime, timezone + +# import numpy as np +# import torch +# import cv2 +# import requests + +# # If your import path is different (e.g., src.models.ml_model), update here: +# from models.ml_model import MLWeedDetector + + +# @dataclass +# class Detection: +# x: float +# y: float +# w: float +# h: float +# area: float +# confidence: float +# label: str = "disease" # You can change to "weed" if that's the desired name + + +# class DiseaseDetector: +# """ +# Flow: +# 1) Build a coarse mask +# 2) Refine the mask using MLWeedDetector (your MobileNetV3 model) +# 3) Convert to detections (bounding boxes) for DB writing +# 4) Compute severity from detections and POST JSON alert if severity > threshold +# """ + +# def __init__(self) -> None: +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # ---- Load weights ---- +# weights_env = os.getenv("WEIGHTS_REFINER", "").strip() +# if weights_env: +# weights_path = Path(weights_env) +# else: +# # Smart search inside the models directory +# project_root = Path(__file__).resolve().parents[2] +# models_dir = project_root / "models" +# candidates = [ +# models_dir / "weights_refiner.pth", +# models_dir / "mobilenetv3_best.pth", +# models_dir / "best.pth", +# models_dir / "last.pth", +# ] +# found = [p for p in candidates if p.exists()] +# if not found: +# raise FileNotFoundError( +# f"Could not find weights file. Set WEIGHTS_REFINER or put a weights .pth under {models_dir}" +# ) +# weights_path = found[0] + +# # MLWeedDetector already handles Resize(224) + ImageNet normalization +# self.refiner = MLWeedDetector(weights_path=str(weights_path), device=str(self.device)) + +# # ---- Parameters configurable via .env ---- +# self.min_component_area = int(os.getenv("MIN_COMPONENT_AREA", "200")) # Filter out small connected components +# self.min_bbox_area = int(os.getenv("MIN_BBOX_AREA", "150")) # Filter out small bounding boxes +# self.bin_thresh = int(os.getenv("REFINED_BIN_THRESH", "128")) # Binarization threshold after refinement +# self.conf_area_norm = float(os.getenv("CONF_AREA_NORM", "10000")) # Normalize confidence by area +# self.coarse_method = os.getenv("COARSE_METHOD", "OTSU").upper() # OTSU / HSV_GREEN +# self.max_infer_side = int(os.getenv("MAX_INFER_SIDE", "0")) # 0 = no global downscale + +# # ---- Alerting (JSON POST) ---- +# self.alert_thresh = float(os.getenv("ALERT_SEVERITY_THRESH", "0.3")) +# self.alert_url = os.getenv("ALERT_URL", "http://localhost:8090/alerts") +# self.device_id = os.getenv("DEVICE_ID", "camera-12") +# self.area_name = os.getenv("AREA_NAME", "") or None # optional + +# # ---------- Main API ---------- + +# def run( +# self, +# bgr_img: np.ndarray, +# return_mask: bool = False, +# *, +# # Optional context to include in the alert JSON if available: +# image_url: Optional[str] = None, +# lat: Optional[float] = None, +# lon: Optional[float] = None, +# alert_type: str = "disease_detected", +# meta: Optional[Dict[str, Any]] = None, +# ) -> Union[List[Detection], Tuple[np.ndarray, List[Detection]]]: +# """ +# :param bgr_img: OpenCV image in BGR format +# :param return_mask: If True, also return the refined mask (uint8 0/255) +# :param image_url: Optional image URL for the alert JSON +# :param lat: Optional latitude for the alert JSON +# :param lon: Optional longitude for the alert JSON +# :param alert_type: Alert type string for the alert JSON +# :param meta: Optional metadata dict to include in the alert JSON +# """ +# # Ensure contiguous to prevent negative strides +# bgr = np.ascontiguousarray(bgr_img) + +# bgr = self._maybe_downscale(bgr) +# coarse = self._make_coarse(bgr) +# # Ensure contiguous before refinement +# coarse = np.ascontiguousarray(coarse) + +# refined = self._refine_with_model(bgr, coarse) # 0..255 +# refined_bin = self._binarize(refined, self.bin_thresh) # 0/255 +# refined_bin = self._remove_small(refined_bin, self.min_component_area) + +# detections = self._mask_to_detections(refined_bin) + +# # ---- Compute severity from detections and POST JSON if above threshold ---- +# severity = self._severity_from_detections(detections) +# if severity > self.alert_thresh: +# # For "confidence" field, we send the same scalar as severity by default. +# # Replace if you prefer a different definition. +# self._post_alert_json( +# severity=severity, +# alert_type=alert_type, +# image_url=image_url, +# lat=lat, +# lon=lon, +# confidence=severity, +# meta=meta, +# ) + +# if return_mask: +# return refined_bin, detections +# return detections + +# # ---------- Processing helpers ---------- + +# def _maybe_downscale(self, bgr: np.ndarray) -> np.ndarray: +# if not self.max_infer_side or self.max_infer_side <= 0: +# return bgr +# h, w = bgr.shape[:2] +# m = max(h, w) +# if m <= self.max_infer_side: +# return bgr +# scale = self.max_infer_side / float(m) +# new_w, new_h = int(w * scale), int(h * scale) +# out = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_AREA) +# return np.ascontiguousarray(out) + +# def _make_coarse(self, bgr: np.ndarray) -> np.ndarray: +# """ +# Coarse mask: +# - OTSU (default) +# - or HSV_GREEN if COARSE_METHOD=HSV_GREEN +# """ +# if self.coarse_method == "HSV_GREEN": +# hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) +# # Generic green range; you can calibrate according to your data: +# lower = np.array([35, 40, 20], dtype=np.uint8) +# upper = np.array([85, 255, 255], dtype=np.uint8) +# mask = cv2.inRange(hsv, lower, upper) # 0/255 +# return np.ascontiguousarray(mask) + +# # OTSU (default) +# gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) +# _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) +# return np.ascontiguousarray(mask) + +# def _refine_with_model(self, bgr: np.ndarray, coarse: np.ndarray) -> np.ndarray: +# """ +# Refinement using your trained model. +# MLWeedDetector.score_mask(bgr, coarse) -> mask 0..255 +# """ +# # Ensure contiguous to avoid negative strides inside the refiner +# bgr = np.ascontiguousarray(bgr) +# coarse = np.ascontiguousarray(coarse) + +# refined = self.refiner.score_mask(bgr, coarse) +# if refined.dtype != np.uint8: +# refined = np.clip(refined, 0, 255).astype(np.uint8) +# return refined + +# def _binarize(self, mask: np.ndarray, thresh: int) -> np.ndarray: +# if mask.dtype != np.uint8: +# mask = np.clip(mask, 0, 255).astype(np.uint8) +# _, mask_bin = cv2.threshold(mask, thresh, 255, cv2.THRESH_BINARY) +# return np.ascontiguousarray(mask_bin) + +# def _remove_small(self, mask_bin: np.ndarray, min_area: int) -> np.ndarray: +# if min_area <= 1: +# return mask_bin +# m01 = (mask_bin > 0).astype(np.uint8) +# num_labels, labels = cv2.connectedComponents(m01) +# out = np.zeros_like(m01) +# for i in range(1, num_labels): # 0 = background +# comp = (labels == i) +# if int(comp.sum()) >= min_area: +# out[comp] = 1 +# return (out * 255).astype(np.uint8) + +# def _mask_to_detections(self, mask_bin: np.ndarray) -> List[Detection]: +# num, labels, stats, _ = cv2.connectedComponentsWithStats( +# (mask_bin > 0).astype(np.uint8), connectivity=8 +# ) +# dets: List[Detection] = [] +# for i in range(1, num): # 0 = background +# x, y, w, h, area = stats[i] +# if area < self.min_component_area: +# continue +# if (w * h) < self.min_bbox_area: +# continue +# conf = float(min(1.0, area / max(1.0, self.conf_area_norm))) +# dets.append( +# Detection( +# x=float(x), y=float(y), w=float(w), h=float(h), +# area=float(area), confidence=conf, label="disease" +# ) +# ) +# return dets + +# # ---------- Severity & Alert JSON ---------- + +# def _severity_from_detections(self, dets: List[Detection]) -> float: +# """ +# Define severity from detections. +# Default: max confidence over all detections. +# You can change this to sum/mean/etc if needed. +# """ +# return max((d.confidence for d in dets), default=0.0) + +# def _post_alert_json( +# self, +# *, +# severity: float, +# alert_type: str, +# image_url: Optional[str], +# lat: Optional[float], +# lon: Optional[float], +# confidence: Optional[float], +# meta: Optional[Dict[str, Any]], +# ) -> None: +# """ +# POST a JSON alert to the configured Alertmanager URL. +# If your Alertmanager expects a list of alerts, send [payload] instead of payload. +# """ +# payload: Dict[str, Any] = { +# # --- Required fields --- +# "alert_id": str(uuid.uuid4()), +# "alert_type": alert_type, +# "device_id": self.device_id, +# "started_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), +# # --- Optional / dynamics --- +# "severity": float(severity), +# } +# if confidence is not None: +# payload["confidence"] = float(confidence) +# if self.area_name: +# payload["area"] = self.area_name +# if lat is not None: +# payload["lat"] = float(lat) +# if lon is not None: +# payload["lon"] = float(lon) +# if image_url: +# payload["image_url"] = image_url +# if meta: +# payload["meta"] = meta + +# # Some setups require a list of alerts: requests.post(self.alert_url, json=[payload], timeout=5) +# requests.post(self.alert_url, json=payload, timeout=5) + + +# # Optional: quick local test (won't run in production compose) +# if __name__ == "__main__": +# # Minimal smoke test for severity/POST path (won't run inference here). +# dets = [ +# Detection(x=0, y=0, w=10, h=10, area=4000, confidence=0.2), +# Detection(x=15, y=15, w=20, h=20, area=9000, confidence=0.6), +# ] +# dd = DiseaseDetector() +# sev = dd._severity_from_detections(dets) +# print("Severity example:", sev) +# if sev > dd.alert_thresh: +# dd._post_alert_json( +# severity=sev, alert_type="disease_detected", +# image_url=None, lat=None, lon=None, +# confidence=sev, meta={"note": "dry run"}, +# ) diff --git a/services/weed_detection/src/pipeline/__init__.py b/services/weed_detection/src/pipeline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/weed_detection/src/pipeline/config.py b/services/weed_detection/src/pipeline/config.py new file mode 100644 index 000000000..18d696e0a --- /dev/null +++ b/services/weed_detection/src/pipeline/config.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import os +from pathlib import Path + +# Try to load env files both from project root and from agri_baseline/.env +try: + from dotenv import load_dotenv # type: ignore + load_dotenv(dotenv_path=Path("agri_baseline/.env"), override=False) + load_dotenv(override=False) +except Exception: + pass + +# Prefer standard name DATABASE_URL; fallback to DB_URL; finally default to localhost:5432 +DB_URL: str = ( + os.getenv("DATABASE_URL") + or os.getenv("DB_URL") + or "postgresql+psycopg2://missions_user:pg123@localhost:5432/missions_db" +) + +IMAGES_DIR = os.getenv("IMAGES_DIR", "./data/images") +BATCH_SIZE = int(os.getenv("BATCH_SIZE", 64)) +MAX_WORKERS = int(os.getenv("MAX_WORKERS", 4)) +MIN_BBOX_AREA = int(os.getenv("MIN_BBOX_AREA", 60)) +MIN_COMPONENT_AREA = int(os.getenv("MIN_COMPONENT_AREA", 200)) diff --git a/services/weed_detection/src/pipeline/db.py b/services/weed_detection/src/pipeline/db.py new file mode 100644 index 000000000..5fb271655 --- /dev/null +++ b/services/weed_detection/src/pipeline/db.py @@ -0,0 +1,28 @@ +# /src/pipeline/db.py +from __future__ import annotations +import os +from sqlalchemy import create_engine, text + +def get_engine(): + db_url = os.getenv("DB_URL") + if not db_url: + raise RuntimeError("DB_URL is not set in environment") + # echo=False לשקט; אפשר True לדיבוג + return create_engine(db_url, future=True) + +# משפטי INSERT בהתאם לשדות שה-Runner מזין +# הערה: מיועד ל-PostgreSQL + PostGIS. אם SQLite – צריך להתאים (לשמור WKT כ-TEXT). +INSERT_DET = text(""" +INSERT INTO anomalies (mission_id, device_id, ts, anomaly_type_id, severity, details, geom) +VALUES (:mission_id, :device_id, :ts, :anomaly_type_id, :severity, CAST(:details AS JSONB), + ST_GeomFromText(:wkt_geom, 4326)) +""") + +INSERT_COUNT = text(""" +INSERT INTO tile_stats (mission_id, tile_id, anomaly_score, geom) +VALUES (:mission_id, :tile_id, :anomaly_score, ST_GeomFromText(:wkt_geom, 4326)) +""") + +INSERT_QA = text(""" +INSERT INTO qa_runs (details) VALUES (CAST(:details AS JSONB)) +""") diff --git a/services/weed_detection/src/pipeline/logging_setup.py b/services/weed_detection/src/pipeline/logging_setup.py new file mode 100644 index 000000000..9904581c8 --- /dev/null +++ b/services/weed_detection/src/pipeline/logging_setup.py @@ -0,0 +1,11 @@ +# /src/pipeline/logging_setup.py +import logging +import os + +def setup_logging(): + level = os.getenv("LOG_LEVEL", "INFO").upper() + logging.basicConfig( + level=getattr(logging, level, logging.INFO), + format="%(asctime)s | %(levelname)s | %(message)s", + ) + return logging.getLogger("agri-baseline") diff --git a/services/weed_detection/src/pipeline/utils.py b/services/weed_detection/src/pipeline/utils.py new file mode 100644 index 000000000..9dfdc8bb6 --- /dev/null +++ b/services/weed_detection/src/pipeline/utils.py @@ -0,0 +1,25 @@ +# /src/pipeline/utils.py +from __future__ import annotations +from pathlib import Path +import cv2 +import numpy as np + + +def load_image(path: str | Path): + p = Path(path) + img = cv2.imread(str(p), cv2.IMREAD_COLOR) + if img is None: + raise FileNotFoundError(f"Failed to load image: {p}") + img = np.ascontiguousarray(img) + h, w = img.shape[:2] + return img, w, h # זה הפורמט שה-Runner שלך מצפה לו + +def image_id_from_path(p: str | Path) -> str: + return Path(p).stem + +def clamp_bbox(x: int, y: int, w: int, h: int, W: int, H: int): + x = max(0, min(x, W - 1)) + y = max(0, min(y, H - 1)) + w = max(0, min(w, W - x)) + h = max(0, min(h, H - y)) + return x, y, w, h diff --git a/simulators/.env.example b/simulators/.env.example new file mode 100644 index 000000000..429fc70fc --- /dev/null +++ b/simulators/.env.example @@ -0,0 +1,19 @@ +# --- General --- +IMAGES_DIR=/data/images +META_DIR=/data/metadata +CAMERA_ID=drone-01 + +# --- Data broker (Minio bridge) --- +MQTT_HOST_DATA=large-mosquitto +MQTT_PORT_DATA=1885 +MQTT_TOPIC_DATA=MQTT/imagery/air + +# --- Meta broker (kafka bridge) --- +MQTT_HOST_META=mosquitto +MQTT_PORT_META=1883 +MQTT_TOPIC_META=mqtt/aerial/images/metadata + +# --- Publishing behavior --- +INTERVAL_CHECK=10 +INTERVAL_PUBLISH=10 +MQTT_QOS=1 diff --git a/simulators/Dockerfile b/simulators/Dockerfile index aecc9517e..7eeed2e60 100644 --- a/simulators/Dockerfile +++ b/simulators/Dockerfile @@ -2,7 +2,7 @@ FROM python:3.12-slim # Copy the NetFree certificate into the container -COPY netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt +COPY certs/*.crt /usr/local/share/ca-certificates/ # Install system dependencies, add the certificate, and clean cache RUN apt-get update && \ diff --git a/simulators/data/sound/metadata/mic-1_20250916T191623Z.json b/simulators/data/sound/metadata/mic-1_20250916T191623Z.json new file mode 100644 index 000000000..f101dc592 --- /dev/null +++ b/simulators/data/sound/metadata/mic-1_20250916T191623Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-1_20250916T191623Z.wav", + "device_id": "mic-1", + "capture_time": "2025-09-16T19:16:23Z", + "duration_sec": 5.0, + "done": false, + "sample_rate_hz": 16000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 31.89561, + "longitude": 34.9681 + } +} \ No newline at end of file diff --git a/simulators/data/sound/metadata/mic-1_20250917T162528Z.json b/simulators/data/sound/metadata/mic-1_20250917T162528Z.json new file mode 100644 index 000000000..1025f570f --- /dev/null +++ b/simulators/data/sound/metadata/mic-1_20250917T162528Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-1_20250917T162528Z.wav", + "device_id": "mic-1", + "capture_time": "2025-09-17T16:25:28Z", + "duration_sec": 5.0, + "done": false, + "sample_rate_hz": 44100, + "channels": 2, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 31.89561, + "longitude": 34.9681 + } +} \ No newline at end of file diff --git a/simulators/data/sound/metadata/mic-2_20250917T182119Z.json b/simulators/data/sound/metadata/mic-2_20250917T182119Z.json new file mode 100644 index 000000000..254663350 --- /dev/null +++ b/simulators/data/sound/metadata/mic-2_20250917T182119Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-2_20250917T182119Z.wav", + "device_id": "mic-2", + "capture_time": "2025-09-17T18:21:19Z", + "duration_sec": 5.0, + "done": false, + "sample_rate_hz": 16000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 31.89561, + "longitude": 34.9681 + } +} \ No newline at end of file diff --git a/simulators/data/sound/metadata/mic-33_20250917T162407Z.json b/simulators/data/sound/metadata/mic-33_20250917T162407Z.json new file mode 100644 index 000000000..084bfbe9b --- /dev/null +++ b/simulators/data/sound/metadata/mic-33_20250917T162407Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-33_20250917T162407Z.wav", + "device_id": "mic-33", + "capture_time": "2025-09-17T16:24:07Z", + "duration_sec": 50.0, + "done": false, + "sample_rate_hz": 44100, + "channels": 2, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 31.89561, + "longitude": 34.9681 + } +} \ No newline at end of file diff --git a/simulators/data/sound/sounds/mic-1_20250916T191623Z.wav b/simulators/data/sound/sounds/mic-1_20250916T191623Z.wav new file mode 100644 index 000000000..ec92fed17 Binary files /dev/null and b/simulators/data/sound/sounds/mic-1_20250916T191623Z.wav differ diff --git a/simulators/data/sound/sounds/mic-1_20250917T162528Z.wav b/simulators/data/sound/sounds/mic-1_20250917T162528Z.wav new file mode 100644 index 000000000..16161e7ec Binary files /dev/null and b/simulators/data/sound/sounds/mic-1_20250917T162528Z.wav differ diff --git a/simulators/data/sound/sounds/mic-2_20250917T182119Z.wav b/simulators/data/sound/sounds/mic-2_20250917T182119Z.wav new file mode 100644 index 000000000..6752a67fe Binary files /dev/null and b/simulators/data/sound/sounds/mic-2_20250917T182119Z.wav differ diff --git a/simulators/data/sound/sounds/mic-33_20250917T162407Z.wav b/simulators/data/sound/sounds/mic-33_20250917T162407Z.wav new file mode 100644 index 000000000..2c40326b8 Binary files /dev/null and b/simulators/data/sound/sounds/mic-33_20250917T162407Z.wav differ diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251003T120500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251003T120500Z.json new file mode 100644 index 000000000..43ce7d2a8 --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251003T120500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251003T120500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-10-03T12:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251101T120500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251101T120500Z.json new file mode 100644 index 000000000..d6ca0d3fc --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251101T120500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251101T120500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-11-01T12:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251102T120500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251102T120500Z.json new file mode 100644 index 000000000..d5a7b2360 --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251102T120500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251102T120500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-11-02T12:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251102T140500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251102T140500Z.json new file mode 100644 index 000000000..fdae67d8b --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251102T140500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251102T140500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-11-02T14:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251103T120500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251103T120500Z.json new file mode 100644 index 000000000..51251b6f5 --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251103T120500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251103T120500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-11-03T12:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251104T120500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251104T120500Z.json new file mode 100644 index 000000000..268fa727e --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251104T120500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251104T120500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-11-04T12:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/metadata/mic-u-2_20251105T120500Z.json b/simulators/data/ultra-sound/metadata/mic-u-2_20251105T120500Z.json new file mode 100644 index 000000000..41fc118b5 --- /dev/null +++ b/simulators/data/ultra-sound/metadata/mic-u-2_20251105T120500Z.json @@ -0,0 +1,14 @@ +{ + "file_name": "mic-u-2_20251105T120500Z.wav", + "device_id": "mic-u-2", + "capture_time": "2025-11-05T12:05:00Z", + "duration_sec": 0.002, + "done": false, + "sample_rate_hz": 500000, + "channels": 1, + "content_type": "audio/wav", + "gis_origin": { + "latitude": 32.89561, + "longitude": 30.9681 + } +} \ No newline at end of file diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251003T120500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251003T120500Z.wav new file mode 100644 index 000000000..6e9edf1e3 Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251003T120500Z.wav differ diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251101T120500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251101T120500Z.wav new file mode 100644 index 000000000..e7b83b3eb Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251101T120500Z.wav differ diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251102T120500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251102T120500Z.wav new file mode 100644 index 000000000..576ca0159 Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251102T120500Z.wav differ diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251102T140500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251102T140500Z.wav new file mode 100644 index 000000000..3e2cad0f2 Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251102T140500Z.wav differ diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251103T120500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251103T120500Z.wav new file mode 100644 index 000000000..f527d8120 Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251103T120500Z.wav differ diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251104T120500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251104T120500Z.wav new file mode 100644 index 000000000..4d0c9f01c Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251104T120500Z.wav differ diff --git a/simulators/data/ultra-sound/sounds/mic-u-2_20251105T120500Z.wav b/simulators/data/ultra-sound/sounds/mic-u-2_20251105T120500Z.wav new file mode 100644 index 000000000..cf0edce9e Binary files /dev/null and b/simulators/data/ultra-sound/sounds/mic-u-2_20251105T120500Z.wav differ diff --git a/simulators/data_publisher.py b/simulators/data_publisher.py index e24fd453e..0144d28f1 100644 --- a/simulators/data_publisher.py +++ b/simulators/data_publisher.py @@ -18,7 +18,7 @@ MQTT_HOST_META = os.getenv("MQTT_HOST_META", "mosquitto") MQTT_PORT_META = int(os.getenv("MQTT_PORT_META", "1883")) -MQTT_TOPIC_META = os.getenv("MQTT_TOPIC_META", "dev-aerial-images-keys") +MQTT_TOPIC_META = os.getenv("MQTT_TOPIC_META", "mqtt/aerial/images/metadata") CAMERA_ID = os.getenv("CAMERA_ID", "DRN-482A") INTERVAL_CHECK = int(os.getenv("INTERVAL_CHECK", "10")) @@ -31,6 +31,8 @@ client_images.loop_start() client_meta = mqtt.Client(client_id=f"drone-simulator-meta-{uuid.uuid4().hex[:6]}") +print(f"[MQTT] DATA -> {MQTT_HOST_DATA}:{MQTT_PORT_DATA}") +print(f"[MQTT] META -> {MQTT_HOST_META}:{MQTT_PORT_META}") client_meta.connect(MQTT_HOST_META, MQTT_PORT_META, keepalive=60) client_meta.loop_start() @@ -80,10 +82,11 @@ def publish_image(image_path): payload = json.dumps(meta_data, ensure_ascii=False) client_meta.publish(MQTT_TOPIC_META, payload, qos=QOS) - print(f"Published image: {new_file_name} | topic: {topic} | type: {guessed_type}") + print(f"Published file: {new_file_name} | topic: {topic} | type: {guessed_type}") def get_all_images(): - exts = {".jpg", ".jpeg", ".png", ".tif"} + exts = {".jpg", ".jpeg", ".png", ".tif", + ".wav", ".mp3", ".flac", ".ogg", ".m4a"} return [os.path.join(IMAGES_DIR, f) for f in sorted(os.listdir(IMAGES_DIR)) if os.path.splitext(f)[1].lower() in exts] diff --git a/simulators/docker-compose.yml b/simulators/docker-compose.yml index e5c24eec4..36098b6b1 100644 --- a/simulators/docker-compose.yml +++ b/simulators/docker-compose.yml @@ -8,7 +8,25 @@ services: - ./data/air/metadata:/data/metadata:ro command: ["python", "-u", "/app/data_publisher.py"] + sound-publisher: + build: . + container_name: sound-publisher + env_file: .env.sound + volumes: + - ./data/sound/sounds:/data/sound/sounds:ro + - ./data/sound/metadata:/data/sound/metadata:ro + command: ["python", "-u", "/app/data_publisher.py"] + + ultra-sound-publisher: + build: . + container_name: ultra-sound-publisher + env_file: .env.ultra + volumes: + - ./data/ultra/sounds:/data/ultra/sounds:ro + - ./data/ultra/metadata:/data/metadata:ro + command: ["python", "-u", "/app/data_publisher.py"] + networks: default: external: true - name: agcloud_ag_cloud + name: ag_cloud diff --git a/simulators/readme b/simulators/readme index bb86a5db1..d49d81d35 100644 --- a/simulators/readme +++ b/simulators/readme @@ -34,81 +34,4 @@ MINIO_ROOT_PASSWORD = minioadmin123 ### 2. Monitor Kafka messages Run the following command to view messages sent to the metadata topic: ```bash -docker exec -it kafka /opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic dev-aerial-images-keys --from-beginning -``` -# README – Running the Device Simulator - -## Steps to Run - -1. **Start all project services** - ```bash - docker compose up -d - ``` - -2. **Start the device simulator** - From the simulator’s directory: - ```bash - docker compose up -d - ``` - ---- - -## Verifying the System - -### 1. View uploaded images in MinIO -Open in your browser: -``` -http://localhost:9002 -``` - -**Default credentials:** -``` -MINIO_ROOT_USER = minioadmin -MINIO_ROOT_PASSWORD = minioadmin123 -``` - ---- - -### 2. Monitor Kafka messages -Run the following command to view messages sent to the metadata topic: -```bash -docker exec -it kafka /opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic dev-aerial-images-keys --from-beginning -``` -# README – Running the Device Simulator - -## Steps to Run - -1. **Start all project services** - ```bash - docker compose up -d - ``` - -2. **Start the device simulator** - From the simulator’s directory: - ```bash - docker compose up -d - ``` - ---- - -## Verifying the System - -### 1. View uploaded images in MinIO -Open in your browser: -``` -http://localhost:9002 -``` - -**Default credentials:** -``` -MINIO_ROOT_USER = minioadmin -MINIO_ROOT_PASSWORD = minioadmin123 -``` - ---- - -### 2. Monitor Kafka messages -Run the following command to view messages sent to the metadata topic: -```bash -docker exec -it kafka /opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic dev-aerial-images-keys --from-beginning -``` +docker exec -it kafka /opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic aerial_images_metadata --from-beginning``` diff --git a/storage_with_mqtt/README.md b/storage_with_mqtt/README.md index 67165162d..362a1480a 100644 --- a/storage_with_mqtt/README.md +++ b/storage_with_mqtt/README.md @@ -42,9 +42,9 @@ Default credentials: ## What the bootstrap (`init.sh`) does - Configures `mc` aliases for Hot/Cold -- Ensures buckets `imagery` and `telemetry` exist +- Ensures buckets `imagery` and `sound` exist - Enables Versioning on those buckets -- Creates remote tiers: `COLD_IMAGERY`, `COLD_TELEMETRY` +- Creates remote tiers: `COLD_IMAGERY`, `COLD_SOUND` - Applies default ILM (7 days → Cold) --- @@ -58,7 +58,7 @@ docker compose exec mc-bootstrap sh -lc 'mc ls hot && mc admin tier ls hot' Check ILM: ```bash -docker compose exec mc-bootstrap sh -lc 'mc ilm ls hot/imagery && mc ilm ls hot/telemetry' +docker compose exec mc-bootstrap sh -lc 'mc ilm ls hot/imagery && mc ilm ls hot/sound' ``` Upload a test file: @@ -78,7 +78,7 @@ docker compose exec mc-bootstrap sh -lc 'mc stat hot/imagery/test.txt' ```bash docker compose exec mc-bootstrap sh -lc ' -for b in imagery telemetry; do +for b in imagery sound; do ids=$(mc ilm export hot/$b 2>/dev/null | sed -n "s/.*\"ID\" *: *\"\([^\"]\+\)\".*/\1/p" || true) for id in $ids; do mc ilm rule rm hot/$b --id "$id" || true; done done diff --git a/storage_with_mqtt/mqtt_images/mqtt_ingest/Dockerfile b/storage_with_mqtt/mqtt_images/mqtt_ingest/Dockerfile index 728ae3f2f..bcdb44863 100644 --- a/storage_with_mqtt/mqtt_images/mqtt_ingest/Dockerfile +++ b/storage_with_mqtt/mqtt_images/mqtt_ingest/Dockerfile @@ -10,9 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates RUN if [ "$USE_NETFREE" = "true" ]; then \ echo "Configuring NetFree certificates..."; \ # Check if the certificate file exists - if [ -f ./netfree-ca.crt ]; then \ - cp ./netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt; \ - chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt; \ + if [ -n "$(ls ./*.crt 2>/dev/null)" ]; then \ + cp ./*.crt /usr/local/share/ca-certificates/; \ + chmod 644 /usr/local/share/ca-certificates/*.crt; \ update-ca-certificates; \ else \ echo "No NetFree certificate found, skipping"; \ diff --git a/storage_with_mqtt/mqtt_images/mqtt_ingest/app.py b/storage_with_mqtt/mqtt_images/mqtt_ingest/app.py index 2baf035da..8fc9344d4 100644 --- a/storage_with_mqtt/mqtt_images/mqtt_ingest/app.py +++ b/storage_with_mqtt/mqtt_images/mqtt_ingest/app.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # ---------- Imports ---------- import os, io, time, hashlib, threading, queue, signal, json, uuid, errno, pathlib, mimetypes from datetime import datetime, timezone @@ -58,6 +57,7 @@ # ---------- Media Prefixes ---------- CAMERA_PREFIX = os.getenv("CAMERA_PREFIX", "camera") MICROPHONE_PREFIX = os.getenv("MICROPHONE_PREFIX", "microphone") +ULTRA_DIR_PREFIX = os.getenv("ULTRA_DIR_PREFIX", "plants") # ---------- S3 ---------- s3 = boto3.client( @@ -117,94 +117,102 @@ def normalize_content_type(ctype: str, filename: str) -> str: def parse_topic(topic: str) -> dict: parts = [p for p in topic.split("/") if p] + parts_lower = [p.lower() for p in parts] now = now_ms() result = { - "camera": DEFAULT_PREFIX, + "camera": DEFAULT_PREFIX, "publish_ts_ms": now, "content_type": "application/octet-stream", "filename": f"{now}.bin", "media_type": "image", } - try: - i = parts.index("imagery") - except ValueError: - i = -1 - - if i != -1: - # camera - if len(parts) > i + 1 and parts[i + 1]: - result["camera"] = parts[i + 1] - - # publish_ts_ms - if len(parts) > i + 2 and parts[i + 2]: + # --- detect namespace and offsets --- + ns = None + idx = -1 + # for cand in ("imagery", "sounds"): + # if cand in parts: + # ns, idx = cand, parts.index(cand) + # break + if "imagery" in parts_lower: + ns, idx = "imagery", parts_lower.index("imagery") + elif any(p.startswith("sounds_ultra") for p in parts_lower): + ns, idx = "sounds_ultra", next(i for i, p in enumerate(parts_lower) if p.startswith("sounds_ultra")) + elif "sounds" in parts_lower: + ns, idx = "sounds", parts_lower.index("sounds") + + if ns == "imagery": + # format: MQTT/imagery//// + if len(parts) > idx + 1 and parts[idx + 1]: + result["camera"] = parts[idx + 1] + if len(parts) > idx + 2 and parts[idx + 2]: try: - ts = int(parts[i + 2]) + ts = int(parts[idx + 2]) if ts > 0: result["publish_ts_ms"] = ts except ValueError: - pass - - if len(parts) > i + 3 and parts[i + 3]: - result["content_type"] = parts[i + 3].replace("_", "/") - - if len(parts) > i + 4 and parts[i + 4]: - result["filename"] = parts[i + 4] + pass + if len(parts) > idx + 3 and parts[idx + 3]: + result["content_type"] = parts[idx + 3].replace("_", "/") + if len(parts) > idx + 4 and parts[idx + 4]: + result["filename"] = parts[idx + 4] + + elif ns in ("sounds", "sounds_ultra"): + if len(parts) > idx + 1 and parts[idx + 1]: + try: + ts = int(parts[idx + 1]) + if ts > 0: + result["publish_ts_ms"] = ts + except ValueError: + pass + if len(parts) > idx + 2 and parts[idx + 2]: + result["content_type"] = parts[idx + 2].replace("_", "/") + if len(parts) > idx + 3 and parts[idx + 3]: + result["filename"] = parts[idx + 3] + # normalize + media type detect result["content_type"] = normalize_content_type(result["content_type"], result["filename"]) - - # Detect media_type from content_type ctype = result["content_type"].lower() if ctype.startswith("image/"): result["media_type"] = "image" elif ctype.startswith("video/"): - result["media_type"] = "image" - elif ctype.startswith("audio/") or "sound" in ctype or "wav" in ctype or "mp3" in ctype: - result["media_type"] = "sound" + result["media_type"] = "image" + elif ctype.startswith("audio/") or "sounds" in ctype or "wav" in ctype or "mp3" in ctype: + result["media_type"] = "sounds" else: - # Fallback: check filename extension - ext = result["filename"].lower().split(".")[-1] - if ext in ("jpg", "jpeg", "png", "gif", "bmp", "tiff", "webp"): + ext = result["filename"].lower().rsplit(".", 1)[-1] if "." in result["filename"] else "" + if ext in ("jpg","jpeg","png","gif","bmp","tiff","webp"): result["media_type"] = "image" - elif ext in ("wav", "mp3", "ogg", "flac", "aac", "m4a"): - result["media_type"] = "sound" + elif ext in ("wav","mp3","ogg","flac","aac","m4a"): + result["media_type"] = "sounds" else: - result["media_type"] = "image" # default - - # Build key with media_type prefix and appropriate device naming + result["media_type"] = "image" + date_part = datetime.fromtimestamp(result["publish_ts_ms"] / 1000, tz=timezone.utc).strftime("%Y-%m-%d") - - # Extract device ID from camera field device_id = result["camera"] - - # Determine device prefix based on media type - if result["media_type"] == "sound": - # For sound files, use microphone- prefix + + if result["media_type"] == "sounds": if device_id.startswith(f"{CAMERA_PREFIX}-"): - # Replace camera- with microphone- device_name = device_id.replace(f"{CAMERA_PREFIX}-", f"{MICROPHONE_PREFIX}-", 1) elif device_id.startswith(f"{MICROPHONE_PREFIX}-"): - # Already has microphone prefix device_name = device_id else: - # No recognized prefix, add microphone- device_name = f"{MICROPHONE_PREFIX}-{device_id}" else: - # For image/video files, ensure camera- prefix if device_id.startswith(f"{CAMERA_PREFIX}-"): - # Already has camera prefix device_name = device_id elif device_id.startswith(f"{MICROPHONE_PREFIX}-"): - # Replace microphone- with camera- device_name = device_id.replace(f"{MICROPHONE_PREFIX}-", f"{CAMERA_PREFIX}-", 1) else: - # No recognized prefix, add camera- device_name = f"{CAMERA_PREFIX}-{device_id}" - - key = f"{result['media_type']}/{device_name}/{date_part}/{result['publish_ts_ms']}/{result['filename']}" - + + # key = f"{result['media_type']}/{device_name}/{date_part}/{result['publish_ts_ms']}/{result['filename']}" + is_ultra = ns == "sounds_ultra" + topdir = ULTRA_DIR_PREFIX if is_ultra else result["media_type"] + key = f"{topdir}/{device_name}/{date_part}/{result['publish_ts_ms']}/{result['filename']}" + result["key"] = key - result["device_id"] = device_name # Use the renamed device + result["device_id"] = device_name result["image_id"] = stem(result["filename"]) or uuid.uuid4().hex result["capture_ts_iso"] = iso_utc(result["publish_ts_ms"]) return result diff --git a/storage_with_mqtt/mqtt_images/mqtt_publisher/Dockerfile b/storage_with_mqtt/mqtt_images/mqtt_publisher/Dockerfile index 87506eab8..b980af88d 100644 --- a/storage_with_mqtt/mqtt_images/mqtt_publisher/Dockerfile +++ b/storage_with_mqtt/mqtt_images/mqtt_publisher/Dockerfile @@ -11,9 +11,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates RUN if [ "$USE_NETFREE" = "true" ]; then \ echo "Configuring NetFree certificates..."; \ # Check if the certificate file exists - if [ -f ./netfree-ca.crt ]; then \ - cp ./netfree-ca.crt /usr/local/share/ca-certificates/netfree-ca.crt; \ - chmod 644 /usr/local/share/ca-certificates/netfree-ca.crt; \ + if [ -f ./*.crt ]; then \ + cp ./*.crt /usr/local/share/ca-certificates/; \ + chmod 644 /usr/local/share/ca-certificates/; \ update-ca-certificates; \ else \ echo "No NetFree certificate found, skipping"; \ diff --git a/storage_with_mqtt/storage/Lifecycle_rules/data/config/lifecycle-sound.json b/storage_with_mqtt/storage/Lifecycle_rules/data/config/lifecycle-sound.json new file mode 100644 index 000000000..b9aab2eb2 --- /dev/null +++ b/storage_with_mqtt/storage/Lifecycle_rules/data/config/lifecycle-sound.json @@ -0,0 +1,15 @@ +{ + "Rules": [ + { + "ID": "seven-days-sound", + "Status": "Enabled", + "Filter": { + "Prefix": "" + }, + "Transition": { + "Days": 7, + "StorageClass": "COLD_SOUND" + } + } + ] +} diff --git a/storage_with_mqtt/storage/Lifecycle_rules/data/config/lifecycle-telemetry.json b/storage_with_mqtt/storage/Lifecycle_rules/data/config/lifecycle-telemetry.json deleted file mode 100644 index 0cc1905d9..000000000 --- a/storage_with_mqtt/storage/Lifecycle_rules/data/config/lifecycle-telemetry.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "Rules": [ - { - "ID": "seven-days-telemetry", - "Status": "Enabled", - "Filter": { "Prefix": "" }, - "Transition": { "Days": 7, "StorageClass": "COLD_TELEMETRY" } - } - ] - } - \ No newline at end of file diff --git a/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/Dockerfile b/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/Dockerfile index d80ec4866..d5b4d94f9 100644 --- a/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/Dockerfile +++ b/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/Dockerfile @@ -1,186 +1,27 @@ +# ============================ +# Stage 1: copy mc binary +# ============================ +FROM minio/mc:latest AS mc-source +# ============================ +# Stage 2: main image +# ============================ FROM alpine:3.19 -# ===== Deps & Binaries ===== -RUN apk add --no-cache bash curl ca-certificates supervisor && \ +# ===== Install dependencies ===== +# include dos2unix so line endings are normalized automatically +RUN apk add --no-cache bash curl ca-certificates netcat-openbsd dos2unix && \ update-ca-certificates - -# Download MinIO binaries with retry logic -RUN for i in 1 2 3 4 5; do \ - curl -k -fsSL https://dl.min.io/server/minio/release/linux-amd64/minio -o /usr/local/bin/minio && break || \ - (echo "Retry $i/5 for minio..." && sleep 5); \ - done && \ - chmod +x /usr/local/bin/minio - -RUN for i in 1 2 3 4 5; do \ - curl -k -fsSL https://dl.min.io/client/mc/release/linux-amd64/mc -o /usr/local/bin/mc && break || \ - (echo "Retry $i/5 for mc..." && sleep 5); \ - done && \ - chmod +x /usr/local/bin/mc -# ===== Dirs ===== -RUN mkdir -p /data/hot /data/cold /entrypoint /config /var/log/supervisor -# ===== init.sh (Bootstrap) ===== -RUN cat > /entrypoint/init.sh <<'SH' && chmod +x /entrypoint/init.sh -#!/usr/bin/env bash -set -euo pipefail -: "${MINIO_ROOT_USER:?Missing MINIO_ROOT_USER}" -: "${MINIO_ROOT_PASSWORD:?Missing MINIO_ROOT_PASSWORD}" -: "${MC_ALIAS_HOT:=hot}" -: "${MC_ALIAS_COLD:=cold}" -: "${HOT_ENDPOINT:=http://minio-hot:9000}" -: "${COLD_ENDPOINT:=http://minio-cold:9000}" -# Configure mc aliases (idempotent) -mc alias set "${MC_ALIAS_HOT}" "${HOT_ENDPOINT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" || true -mc alias set "${MC_ALIAS_COLD}" "${COLD_ENDPOINT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" || true -echo "[bootstrap] checking HTTP readiness..." -until curl -sf "${HOT_ENDPOINT}/minio/health/live" >/dev/null; do sleep 1; done -until curl -sf "${COLD_ENDPOINT}/minio/health/live" >/dev/null; do sleep 1; done -echo "[bootstrap] waiting for hot (${HOT_ENDPOINT})..." -until mc ls "${MC_ALIAS_HOT}" >/dev/null 2>&1; do sleep 2; done -echo "[bootstrap] waiting for cold (${COLD_ENDPOINT})..." -until mc ls "${MC_ALIAS_COLD}" >/dev/null 2>&1; do sleep 2; done - -# Enable versioning on HOT buckets -echo "[bootstrap] enabling versioning..." -mc version enable "${MC_ALIAS_HOT}/imagery" || true -mc version enable "${MC_ALIAS_HOT}/telemetry" || true - -echo "[bootstrap] ensuring remote tiers exist on HOT..." -# imagery → COLD/imagery -if ! mc ilm tier ls "${MC_ALIAS_HOT}" --json | grep -q '"Name":"COLD_IMAGERY"'; then - mc ilm tier add s3 "${MC_ALIAS_HOT}" COLD_IMAGERY \ - --endpoint "${COLD_ENDPOINT}" \ - --access-key "${MINIO_ROOT_USER}" \ - --secret-key "${MINIO_ROOT_PASSWORD}" \ - --bucket imagery \ - --region us-east-1 || true -fi -# telemetry → COLD/telemetry -if ! mc ilm tier ls "${MC_ALIAS_HOT}" --json | grep -q '"Name":"COLD_TELEMETRY"'; then - mc ilm tier add s3 "${MC_ALIAS_HOT}" COLD_TELEMETRY \ - --endpoint "${COLD_ENDPOINT}" \ - --access-key "${MINIO_ROOT_USER}" \ - --secret-key "${MINIO_ROOT_PASSWORD}" \ - --bucket telemetry \ - --region us-east-1 || true -fi - -echo "[bootstrap] applying lifecycle policies (per bucket)..." -# --- IMAGERY --- -mc ilm rule rm "${MC_ALIAS_HOT}/imagery" --all --force || true -if [ -s "/config/lifecycle-imagery.json" ]; then - mc ilm import "${MC_ALIAS_HOT}/imagery" < "/config/lifecycle-imagery.json" || true -else - echo "[bootstrap] /config/lifecycle-imagery.json not found; applying fallback" - mc ilm rule add "${MC_ALIAS_HOT}/imagery" \ - --transition-days 7 \ - --transition-tier COLD_IMAGERY || true -fi -# --- TELEMETRY --- -mc ilm rule rm "${MC_ALIAS_HOT}/telemetry" --all --force || true -if [ -s "/config/lifecycle-telemetry.json" ]; then - mc ilm import "${MC_ALIAS_HOT}/telemetry" < "/config/lifecycle-telemetry.json" || true -else - echo "[bootstrap] /config/lifecycle-telemetry.json not found; applying fallback" - mc ilm rule add "${MC_ALIAS_HOT}/telemetry" \ - --transition-days 7 \ - --transition-tier COLD_TELEMETRY || true -fi - -echo "[bootstrap] current rules:" -mc ilm rule ls "${MC_ALIAS_HOT}/imagery" || true -mc ilm rule ls "${MC_ALIAS_HOT}/telemetry" || true - -# =============================================== -# Kafka Event Notifications -# =============================================== -echo "[bootstrap] waiting for Kafka broker..." -until nc -z kafka 9092 >/dev/null 2>&1; do sleep 3; done -echo "[bootstrap] Kafka is reachable." - -echo "[bootstrap] waiting for MinIO Kafka notifiers to load..." -until mc admin config get "${MC_ALIAS_HOT}" notify_kafka | grep -q "notify_kafka:primary" && \ - mc admin config get "${MC_ALIAS_HOT}" notify_kafka | grep -q "notify_kafka:images"; do - echo "[bootstrap] Kafka notifiers not ready yet... retrying..." - sleep 3 -done -echo "[bootstrap] Kafka notifiers 'primary' and 'images' are loaded." - -echo "[bootstrap] removing old event rules..." -mc event remove "${MC_ALIAS_HOT}/imagery" --force >/dev/null 2>&1 || true -mc event remove "${MC_ALIAS_HOT}/telemetry" --force >/dev/null 2>&1 || true - -echo "[bootstrap] adding Kafka event rules..." -# Rule 1: sound/ prefix → topic sound.new (via notifier 'primary') -mc event add "${MC_ALIAS_HOT}/imagery" \ - arn:minio:sqs::primary:kafka \ - --event put \ - --prefix "sound/" || echo "[WARN] Failed to add sound/ rule" - -# Rule 2: image/ prefix → topic image.new (via notifier 'images') -mc event add "${MC_ALIAS_HOT}/imagery" \ - arn:minio:sqs::images:kafka \ - --event put \ - --prefix "image/" || echo "[WARN] Failed to add image/ rule" - -# Rule 3: telemetry bucket → sound.new -mc event add "${MC_ALIAS_HOT}/telemetry" \ - arn:minio:sqs::primary:kafka \ - --event put || echo "[WARN] Failed to add telemetry rule" - -echo "[bootstrap] verifying event rules..." -mc event list "${MC_ALIAS_HOT}/imagery" || true -mc event list "${MC_ALIAS_HOT}/telemetry" || true - -echo "[bootstrap] DONE." -SH - -# --- Ensure init.sh has UNIX line endings --- -RUN apk add --no-cache dos2unix && dos2unix /entrypoint/init.sh - -# ===== supervisord.conf ===== -RUN cat > /etc/supervisord.conf <<'CONF' -[supervisord] -nodaemon=true -logfile=/var/log/supervisor/supervisord.log -[program:minio_hot] -command=/usr/local/bin/minio server --address %(ENV_HOT_ADDRESS)s --console-address %(ENV_HOT_CONSOLE)s /data/hot -stdout_logfile=/var/log/supervisor/minio_hot.log -stderr_logfile=/var/log/supervisor/minio_hot.err -autorestart=true -priority=10 -[program:minio_cold] -command=/usr/local/bin/minio server --address %(ENV_COLD_ADDRESS)s --console-address %(ENV_COLD_CONSOLE)s /data/cold -stdout_logfile=/var/log/supervisor/minio_cold.log -stderr_logfile=/var/log/supervisor/minio_cold.err -autorestart=true -priority=10 -[program:bootstrap] -command=/bin/bash -lc "/entrypoint/init.sh" -stdout_logfile=/var/log/supervisor/bootstrap.log -stderr_logfile=/var/log/supervisor/bootstrap.err -autorestart=false -startsecs=0 -priority=20 -CONF -EXPOSE 9000 9001 9100 9101 -CMD ["/usr/bin/supervisord", "-c", "/etc/supervisord.conf"] - - - - - - - - - - - - - - - - - - - - - +# ===== Add NetFree CA ===== +COPY certs/*.crt /usr/local/share/ca-certificates/ +RUN update-ca-certificates +# ===== Copy mc from the official image ===== +COPY --from=mc-source /usr/bin/mc /usr/local/bin/mc +RUN chmod +x /usr/local/bin/mc +# ===== Create directories ===== +RUN mkdir -p /entrypoint /config +# ===== Copy init script ===== +COPY entrypoint/init.sh /entrypoint/init.sh +# ===== Normalize and ensure execution permissions ===== +# (this guarantees LF endings + execute permission for everyone) +RUN dos2unix /entrypoint/init.sh && chmod 755 /entrypoint/init.sh +# ===== Entry point ===== +CMD ["/entrypoint/init.sh"] diff --git a/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/entrypoint/init.sh b/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/entrypoint/init.sh index d7c1c094c..16b388157 100644 --- a/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/entrypoint/init.sh +++ b/storage_with_mqtt/storage/Lifecycle_rules/minio-bootstrap/entrypoint/init.sh @@ -1,76 +1,212 @@ -#!/bin/bash +#!/usr/bin/env bash set -euo pipefail -# === Wait for MinIO HOT === -echo "Waiting for MinIO HOT..." -until curl -sf http://minio-hot:9000/minio/health/ready >/dev/null; do - sleep 2 -done +: "${MINIO_ROOT_USER:?Missing MINIO_ROOT_USER}" +: "${MINIO_ROOT_PASSWORD:?Missing MINIO_ROOT_PASSWORD}" +: "${MC_ALIAS_HOT:=hot}" +: "${MC_ALIAS_COLD:=cold}" +: "${HOT_ENDPOINT:=http://minio-hot:9000}" +: "${COLD_ENDPOINT:=http://minio-cold:9000}" -# === Wait for MinIO COLD === -echo "Waiting for MinIO COLD..." -until curl -sf http://minio-cold:9000/minio/health/ready >/dev/null; do - sleep 2 -done +mc alias set "${MC_ALIAS_HOT}" "${HOT_ENDPOINT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" || true +mc alias set "${MC_ALIAS_COLD}" "${COLD_ENDPOINT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" || true -# === Configure aliases === -echo "Configuring MinIO aliases..." -until mc alias set hot http://minio-hot:9000 "$MINIO_ROOT_USER" "$MINIO_ROOT_PASSWORD" >/dev/null 2>&1; do - sleep 2 -done -until mc alias set cold http://minio-cold:9000 "$MINIO_ROOT_USER" "$MINIO_ROOT_PASSWORD" >/dev/null 2>&1; do - sleep 2 -done +echo "[bootstrap] Checking HTTP availability..." +until curl -sf "${HOT_ENDPOINT}/minio/health/live" >/dev/null; do sleep 1; done +until curl -sf "${COLD_ENDPOINT}/minio/health/live" >/dev/null; do sleep 1; done -# === Create buckets === -for BUCKET in "$BUCKET_IMAGERY" "$BUCKET_TELEMETRY"; do - if ! mc ls hot | grep -q "$BUCKET"; then - echo "Creating bucket: $BUCKET" - mc mb hot/$BUCKET || true - else - echo "Bucket $BUCKET already exists." - fi -done +echo "[bootstrap] Waiting for HOT (${HOT_ENDPOINT})..." +until mc ls "${MC_ALIAS_HOT}" >/dev/null 2>&1; do sleep 2; done +echo "[bootstrap] Waiting for COLD (${COLD_ENDPOINT})..." +until mc ls "${MC_ALIAS_COLD}" >/dev/null 2>&1; do sleep 2; done -# === Wait for Kafka broker === -echo "Waiting for Kafka broker..." -until nc -z kafka 9092 >/dev/null 2>&1; do +echo "[bootstrap] Creating buckets..." +mc mb "${MC_ALIAS_HOT}/imagery" || true +mc mb "${MC_ALIAS_HOT}/sound" || true +mc mb "${MC_ALIAS_COLD}/imagery" || true +mc mb "${MC_ALIAS_COLD}/sound" || true + +echo "[bootstrap] Enabling versioning..." +mc version enable "${MC_ALIAS_HOT}/imagery" || true +mc version enable "${MC_ALIAS_HOT}/sound" || true + +echo "[bootstrap] Waiting for Kafka broker..." +until nc -z kafka 9092 >/dev/null 2>&1; do + echo "[bootstrap] Kafka not ready, retrying..." sleep 3 done +echo "[bootstrap] Kafka is accessible." -# === Wait for MinIO Kafka notifier (fully loaded) === -echo "Waiting for all Kafka notifiers to load..." -until mc admin config get hot notify_kafka | grep -q "notify_kafka:primary" && \ - mc admin config get hot notify_kafka | grep -q "notify_kafka:images"; do - echo "Kafka notifiers not ready yet... retrying..." - sleep 3 +echo "[bootstrap] Configuring all Kafka notifiers..." + + +# Configure IMAGE notifiers +echo "[bootstrap] → aerial" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:aerial \ + brokers="kafka:9092" \ + topic="image.new.aerial" + +echo "[bootstrap] → air" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:air \ + brokers="kafka:9092" \ + topic="image.new.air" + +echo "[bootstrap] → fruits" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:fruits \ + brokers="kafka:9092" \ + topic="image.new.fruits" + +echo "[bootstrap] → leaves" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:leaves \ + brokers="kafka:9092" \ + topic="image.new.leaves" + +echo "[bootstrap] → ground" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:ground \ + brokers="kafka:9092" \ + topic="image.new.ground" + +echo "[bootstrap] → field" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:field \ + brokers="kafka:9092" \ + topic="image.new.field" + +echo "[bootstrap] → security" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:security \ + brokers="kafka:9092" \ + topic="image.new.security" + + +# Configure SOUND notifiers +echo "[bootstrap] → plants" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:plants \ + brokers="kafka:9092" \ + topic="sound.new.plants" + +echo "[bootstrap] → sounds" +mc admin config set "${MC_ALIAS_HOT}" notify_kafka:sounds \ + brokers="kafka:9092" \ + topic="sound.new.sounds" + +echo "[bootstrap] ✅ All 7 notifiers configured" +echo "[bootstrap] ⚠️ Restarting MinIO to apply notifier changes..." +mc admin service restart "${MC_ALIAS_HOT}" --json || true + +# Wait for MinIO restart with retry instead of fixed sleep +echo "[bootstrap] Waiting for MinIO to come back online (with retries)..." +max_retries=${MAX_MINIO_RETRIES:-60} # Default: 60 attempts +retry_interval=${MINIO_RETRY_INTERVAL:-5} # Default: 5 seconds +i=0 +until mc ls "${MC_ALIAS_HOT}" >/dev/null 2>&1; do + i=$((i+1)) + if [ "$i" -ge "$max_retries" ]; then + echo "[bootstrap] ERROR: MinIO did not become ready after $((max_retries * retry_interval)) seconds" + break + fi + echo "[bootstrap] MinIO not ready, attempt $i/$max_retries (waiting ${retry_interval}s)..." + sleep "$retry_interval" done -# === Clean old event rules === -mc event remove hot/$BUCKET_IMAGERY --force >/dev/null 2>&1 || true -mc event remove hot/$BUCKET_TELEMETRY --force >/dev/null 2>&1 || true +if mc ls "${MC_ALIAS_HOT}" >/dev/null 2>&1; then + echo "[bootstrap] ✅ MinIO is back online" +else + echo "[bootstrap] ⚠️ Continuing even though MinIO did not fully recover (check logs)" +fi + +echo "[bootstrap] Verifying Kafka notifiers..." +mc admin config get "${MC_ALIAS_HOT}" notify_kafka + +echo "[bootstrap] Ensuring remote tiers exist in HOT..." +if ! mc ilm tier ls "${MC_ALIAS_HOT}" --json | grep -q '"Name":"COLD_IMAGERY"'; then + mc ilm tier add s3 "${MC_ALIAS_HOT}" COLD_IMAGERY \ + --endpoint "${COLD_ENDPOINT}" \ + --access-key "${MINIO_ROOT_USER}" \ + --secret-key "${MINIO_ROOT_PASSWORD}" \ + --bucket imagery \ + --region us-east-1 || true +fi + +if ! mc ilm tier ls "${MC_ALIAS_HOT}" --json | grep -q '"Name":"COLD_SOUND"'; then + mc ilm tier add s3 "${MC_ALIAS_HOT}" COLD_SOUND \ + --endpoint "${COLD_ENDPOINT}" \ + --access-key "${MINIO_ROOT_USER}" \ + --secret-key "${MINIO_ROOT_PASSWORD}" \ + --bucket sound \ + --region us-east-1 || true +fi + +echo "[bootstrap] Applying lifecycle policies..." +mc ilm rule rm "${MC_ALIAS_HOT}/imagery" --all --force || true +if [ -s "/config/lifecycle-imagery.json" ]; then + mc ilm import "${MC_ALIAS_HOT}/imagery" < "/config/lifecycle-imagery.json" || true +else + mc ilm rule add "${MC_ALIAS_HOT}/imagery" \ + --transition-days 7 --transition-tier COLD_IMAGERY || true +fi + +mc ilm rule rm "${MC_ALIAS_HOT}/sound" --all --force || true +if [ -s "/config/lifecycle-sound.json" ]; then + mc ilm import "${MC_ALIAS_HOT}/sound" < "/config/lifecycle-sound.json" || true +else + mc ilm rule add "${MC_ALIAS_HOT}/sound" \ + --transition-days 7 --transition-tier COLD_SOUND || true +fi + +echo "[bootstrap] Removing old event rules..." +mc event remove "${MC_ALIAS_HOT}/imagery" --force >/dev/null 2>&1 || true +mc event remove "${MC_ALIAS_HOT}/sound" --force >/dev/null 2>&1 || true + +sleep 3 +echo "[bootstrap] Adding Kafka event rules for IMAGERY..." +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::aerial:kafka \ + --event put \ + --prefix "aerial/" + +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::aerial:kafka \ + --event put \ + --prefix "image/camera-air/" + +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::fruits:kafka \ + --event put \ + --prefix "fruits/" + +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::leaves:kafka \ + --event put \ + --prefix "leaves/" -# === Add new Kafka event rules === -echo "Adding new Kafka event rules..." -# Rule 1: sound/ prefix → topic sound.new (via notifier 'primary') -mc event add hot/$BUCKET_IMAGERY \ - arn:minio:sqs::primary:kafka \ +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::ground:kafka \ --event put \ - --prefix "sound/" + --prefix "ground/" -# Rule 2: image/ prefix → topic image.new (via notifier 'images') -mc event add hot/$BUCKET_IMAGERY \ - arn:minio:sqs::images:kafka \ +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::field:kafka \ --event put \ - --prefix "image/" + --prefix "field/" -# Rule 3: telemetry bucket → sound.new -mc event add hot/$BUCKET_TELEMETRY \ - arn:minio:sqs::primary:kafka \ - --event put +mc event add "${MC_ALIAS_HOT}/imagery" \ + arn:minio:sqs::security:kafka \ + --event put \ + --prefix "security/" -# === Verify === -mc event list hot/$BUCKET_IMAGERY || true -mc event list hot/$BUCKET_TELEMETRY || true +echo "[bootstrap] Adding Kafka event rules for SOUND..." +mc event add "${MC_ALIAS_HOT}/sound" \ + arn:minio:sqs::plants:kafka \ + --event put \ + --prefix "plants/" + +mc event add "${MC_ALIAS_HOT}/sound" \ + arn:minio:sqs::sounds:kafka \ + --event put \ + --prefix "sounds/" +echo "[bootstrap] Validating event rules..." +mc event list "${MC_ALIAS_HOT}/imagery" || true +mc event list "${MC_ALIAS_HOT}/sound" || true +echo "[bootstrap] ✅ Done." +echo "[bootstrap] Keeping container alive..." tail -f /dev/null diff --git a/storage_with_mqtt/storage/minio-storage/create_buckets.py b/storage_with_mqtt/storage/minio-storage/create_buckets.py index ad5912207..0f5722c83 100644 --- a/storage_with_mqtt/storage/minio-storage/create_buckets.py +++ b/storage_with_mqtt/storage/minio-storage/create_buckets.py @@ -29,7 +29,7 @@ raise Exception("MinIO not ready after waiting") # Creating the buckets -for bucket in ["imagery", "telemetry"]: +for bucket in ["imagery", "sound"]: if not client.bucket_exists(bucket): client.make_bucket(bucket) print(f"✅ Created bucket: {bucket}") diff --git a/streaming/flink/jobs/http_dispatcher.py b/streaming/flink/jobs/http_dispatcher.py index bef7adf4e..5fcdfb553 100644 --- a/streaming/flink/jobs/http_dispatcher.py +++ b/streaming/flink/jobs/http_dispatcher.py @@ -40,7 +40,7 @@ def parse_args(): p.add_argument("--input-topic", default=os.getenv("INPUT_TOPIC", "imagery.new.fruit")) p.add_argument("--team", default=os.getenv("TEAM", "fruit")) p.add_argument("--http-url", - default=os.getenv("HTTP_URL", "http://fruit-inference-http:8000/infer_json")) + default=os.getenv("HTTP_URL", "http://fruit-inference-http:8004/infer_json")) p.add_argument("--dlq-topic", default=os.getenv("DLQ_TOPIC", "dlq.inference.http")) p.add_argument("--group-id", default=os.getenv("GROUP_ID", "http-dispatcher-fruit")) diff --git a/templates/templates.yml b/templates/templates.yml new file mode 100644 index 000000000..aae4b76db --- /dev/null +++ b/templates/templates.yml @@ -0,0 +1,20 @@ +templates: + smoke_detected: + category: environmental + summary: "🚨 Smoke detected by ${device_id} near ${area} (confidence ${confidence})" + recommendation: "Inspect the ${area} immediately. If fire is confirmed, contact emergency services." + + masked_person: + category: security + summary: "Person wearing a mask detected by ${device_id} at ${timestamp}" + recommendation: "Verify the person’s authorization using the live feed." + + fruit_ripeness_high: + category: agriculture + summary: "🍓 High fruit ripeness detected by ${device_id} (${confidence} ≥ ${threshold})" + recommendation: "Harvest or inspect the plantation area. Ripe/overripe fruits ratio: ${description}" + + plant_drought_detected: + category: agriculture + summary: "🌿 Plant drought detected by ${device_id} in ${area} (severity ${severity}, confidence ${confidence})" + recommendation: "Check irrigation in ${area}. Status: ${watering_status}. Audio file: ${file}"