Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 130 additions & 15 deletions api/models/finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,10 @@ def __init__(self, dsn: Optional[str] = None) -> None:
# ------------------------------------------------------------------ #

def connect(self) -> None:
"""Open a persistent database connection and set the search path."""
"""Open a persistent database connection."""
self.conn = psycopg2.connect(self.dsn)
self.conn.autocommit = True # Set to True for schema management
with self.conn.cursor() as cur:
# Ensure the openshield schema exists and is preferred in the search path.
# This avoids 'permission denied for schema public' in restricted environments.
cur.execute("CREATE SCHEMA IF NOT EXISTS openshield;")
cur.execute("SET search_path TO openshield, public;")
self.conn.autocommit = False
logger.info("Database connection established (schema: openshield)")
logger.info("Database connection established")

def _get_conn(self) -> Any:
if self.conn is None or self.conn.closed:
Expand Down Expand Up @@ -134,10 +128,13 @@ def create_tables(self) -> None:
scan_id UUID PRIMARY KEY,
subscription_id TEXT NOT NULL,
started_at TIMESTAMPTZ NOT NULL,
claimed_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
total_findings INTEGER DEFAULT 0,
score INTEGER DEFAULT NULL,
cve_enrichment_status TEXT DEFAULT 'PENDING'
cve_enrichment_status TEXT DEFAULT 'PENDING',
status TEXT DEFAULT 'pending',
error_message TEXT
);
""")
cur.execute("""
Expand Down Expand Up @@ -206,8 +203,19 @@ def run_migrations(self) -> None:
""")
cur.execute("""
ALTER TABLE scans
ADD COLUMN IF NOT EXISTS cve_enrichment_status TEXT DEFAULT 'PENDING'
ADD COLUMN IF NOT EXISTS cve_enrichment_status TEXT DEFAULT 'COMPLETED',
ADD COLUMN IF NOT EXISTS status TEXT DEFAULT 'completed',
ADD COLUMN IF NOT EXISTS error_message TEXT,
ADD COLUMN IF NOT EXISTS claimed_at TIMESTAMPTZ
""")
# Fix: If status already existed but was backfilled as 'pending' (e.g. from
# a previous buggy deploy), force it to 'completed' for all historical
# scans that have already finished.
cur.execute("UPDATE scans SET status = 'completed' WHERE status = 'pending' AND completed_at IS NOT NULL")

# Backfill claimed_at for any currently running scans so they don't get
# immediately marked as stale by the new recovery logic.
cur.execute("UPDATE scans SET claimed_at = started_at WHERE status = 'running' AND claimed_at IS NULL")
conn.commit()
logger.info("CVE migrations applied successfully")
except Exception as e:
Expand All @@ -221,21 +229,30 @@ def run_migrations(self) -> None:
def save_scan(self, scan_result: Dict[str, Any]) -> None:
"""Persist a full scan result (scan header + all findings)."""
conn = self._get_conn()
from datetime import datetime, timezone
completed_at = scan_result.get("completed_at") or datetime.now(timezone.utc).isoformat()
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO scans (scan_id, subscription_id, started_at, completed_at, total_findings, score, cve_enrichment_status)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (scan_id) DO NOTHING
INSERT INTO scans (scan_id, subscription_id, started_at, completed_at, total_findings, score, cve_enrichment_status, status, error_message)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (scan_id) DO UPDATE SET
completed_at = EXCLUDED.completed_at,
total_findings = EXCLUDED.total_findings,
score = EXCLUDED.score,
status = EXCLUDED.status,
error_message = EXCLUDED.error_message
""",
(
scan_result["scan_id"],
scan_result["subscription_id"],
scan_result["started_at"],
scan_result["completed_at"],
scan_result["total_findings"],
completed_at,
scan_result.get("total_findings", 0),
scan_result.get("score"),
scan_result.get("cve_enrichment_status", "PENDING"),
scan_result.get("status", "completed"),
scan_result.get("error_message"),
),
)
for f in scan_result.get("findings", []):
Expand Down Expand Up @@ -362,6 +379,104 @@ def update_scan_enrichment_status(self, scan_id: str, status: str) -> None:
conn.commit()
logger.info("Updated scan %s enrichment status to %s", scan_id, status)

def create_pending_scan(self, scan_id: str, subscription_id: str) -> None:
"""Create a scan record in the 'pending' state."""
conn = self._get_conn()
from datetime import datetime, timezone
started_at = datetime.now(timezone.utc).isoformat()
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO scans (scan_id, subscription_id, started_at, status)
VALUES (%s, %s, %s, 'pending')
""",
(scan_id, subscription_id, started_at),
)
conn.commit()
logger.info("Created pending scan %s for %s", scan_id, subscription_id)

def update_scan_status(self, scan_id: str, status: str, error_message: Optional[str] = None) -> None:
"""Update the status of a scan (running, completed, failed)."""
conn = self._get_conn()
from datetime import datetime, timezone
with conn.cursor() as cur:
if status == "completed":
completed_at = datetime.now(timezone.utc).isoformat()
cur.execute(
"UPDATE scans SET status = %s, completed_at = %s WHERE scan_id = %s",
(status, completed_at, scan_id),
)
else:
cur.execute(
"UPDATE scans SET status = %s, error_message = %s WHERE scan_id = %s",
(status, error_message, scan_id),
)
conn.commit()
logger.info("Updated scan %s status to %s", scan_id, status)

def claim_next_pending_scan(self) -> Optional[Dict[str, Any]]:
"""Atomically claim the next pending scan using SKIP LOCKED."""
conn = self._get_conn()
from datetime import datetime, timezone
claimed_at = datetime.now(timezone.utc).isoformat()
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"""
UPDATE scans
SET status = 'running', claimed_at = %s
WHERE scan_id = (
SELECT scan_id
FROM scans
WHERE status = 'pending'
ORDER BY started_at ASC
FOR UPDATE SKIP LOCKED
LIMIT 1
)
RETURNING *
""",
(claimed_at,)
)
row = cur.fetchone()
if row:
conn.commit()
return dict(row)
return None

def recover_stale_scans(self, timeout_minutes: int = 60) -> int:
"""Mark scans that have been 'running' for too long as 'failed'."""
conn = self._get_conn()
with conn.cursor() as cur:
cur.execute(
"""
UPDATE scans
SET status = 'failed',
error_message = 'Scan timed out after remaining in running state for too long.'
WHERE status = 'running'
AND claimed_at < (CURRENT_TIMESTAMP - INTERVAL '%s minutes')
""",
(timeout_minutes,)
)
count = cur.rowcount
conn.commit()
if count > 0:
logger.info("Recovered %d stale 'running' scans", count)
return count

def get_pending_scans(self) -> List[Dict[str, Any]]:
"""Return all scans in the 'pending' state."""
conn = self._get_conn()
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute("SELECT * FROM scans WHERE status = 'pending' ORDER BY started_at ASC")
return [dict(row) for row in cur.fetchall()]

def get_scan(self, scan_id: str) -> Optional[Dict[str, Any]]:
"""Return a single scan record by its UUID."""
conn = self._get_conn()
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute("SELECT * FROM scans WHERE scan_id = %s", (scan_id,))
row = cur.fetchone()
return dict(row) if row else None

def get_scans(self) -> List[Dict[str, Any]]:
"""Return all scan records ordered by most recent first."""
conn = self._get_conn()
Expand Down
50 changes: 27 additions & 23 deletions api/routes/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import uuid
from flask import Blueprint, g, jsonify, request

from api.models.finding import DatabaseManager
Expand Down Expand Up @@ -33,21 +34,29 @@ def list_scans():
return jsonify({"error": "Failed to retrieve scans", "detail": str(exc)}), 500


@scans_bp.get("/api/scans/<scan_id>")
def get_scan_status(scan_id):
"""Return the details and status of a specific scan."""
try:
db = _get_db()
scan = db.get_scan(scan_id)
if not scan:
return jsonify({"error": "Scan not found"}), 404
return jsonify(scan)
except Exception as exc:
logger.error("Failed to get scan status: %s", exc)
return jsonify({"error": "Database error", "detail": str(exc)}), 500


@scans_bp.post("/api/scans/trigger")
def trigger_scan():
"""Trigger a synchronous scan against the configured subscription.
"""Trigger an asynchronous scan against the configured subscription.

Accepts an optional JSON body with ``subscription_id``. Falls back to the
``AZURE_SUBSCRIPTION_ID`` environment variable if not provided.

Note: For production use, replace this with an async task queue (e.g.
Celery or Azure Functions) to avoid request timeouts on large subscriptions.
Returns 202 Accepted with the scan_id immediately.
"""
try:
from scanner.engine import ScanEngine
except ImportError:
return jsonify({"error": "Scanner module is not available"}), 500

try:
body = request.get_json(silent=True) or {}
subscription_id = body.get("subscription_id") or os.environ.get(
Expand All @@ -57,26 +66,21 @@ def trigger_scan():
if not subscription_id:
return jsonify({"error": "subscription_id is required"}), 400

logger.info("Scan triggered for subscription %s", subscription_id)

try:
engine = ScanEngine(subscription_id)
result = engine.run_scan()
except Exception as exc:
logger.error("Scan engine execution failed: %s", exc, exc_info=True)
return jsonify({"error": "Scan failed", "detail": str(exc)}), 500

if not isinstance(result, dict) or "scan_id" not in result:
return jsonify({"error": "Invalid scan result returned"}), 500
scan_id = str(uuid.uuid4())
logger.info("Async scan triggered for subscription %s (id: %s)", subscription_id, scan_id)

try:
db = _get_db()
db.save_scan(result)
db.create_pending_scan(scan_id, subscription_id)
except Exception as exc:
logger.error("Failed to save scan result: %s", exc, exc_info=True)
return jsonify({"error": "Database save failed", "detail": str(exc)}), 500
logger.error("Failed to create pending scan: %s", exc, exc_info=True)
return jsonify({"error": "Database error", "detail": str(exc)}), 500

return jsonify(result), 201
return jsonify({
"scan_id": scan_id,
"status": "pending",
"message": "Scan has been queued and will start shortly."
}), 202

except Exception as exc:
logger.error("Critical error in trigger_scan route: %s", exc, exc_info=True)
Expand Down
54 changes: 26 additions & 28 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,32 @@ Example response:

---

## GET /api/scans/&lt;scan_id&gt;

Returns the details and current status of a specific scan.

Path parameters: `scan_id` — UUID of the scan.

Example response:

```json
{
"scan_id": "6f4a08ac-7d3a-4d9a-a4b4-2a26e5f63c8a",
"subscription_id": "00000000-0000-0000-0000-000000000000",
"status": "completed",
"started_at": "2026-05-09T12:00:00Z",
"completed_at": "2026-05-09T12:02:00Z",
"total_findings": 3,
"score": 85,
"error_message": null
}
```

---

## POST /api/scans/trigger

Runs a synchronous scan and saves the result to PostgreSQL. The request body may include `subscription_id`; otherwise the API uses `AZURE_SUBSCRIPTION_ID`.
Triggers an asynchronous scan against the configured subscription. Returns `202 Accepted` with the `scan_id` immediately. The actual scan execution happens in a background worker process.

Request body:

Expand All @@ -150,32 +173,8 @@ Example response:
```json
{
"scan_id": "6f4a08ac-7d3a-4d9a-a4b4-2a26e5f63c8a",
"subscription_id": "00000000-0000-0000-0000-000000000000",
"started_at": "2026-05-09T12:00:00+00:00",
"completed_at": "2026-05-09T12:02:00+00:00",
"total_findings": 1,
"findings": [
{
"rule_id": "AZ-STOR-001",
"rule_name": "Public Blob Access Enabled on Storage Account",
"severity": "HIGH",
"category": "Storage",
"resource_id": "/subscriptions/example/resourceGroups/rg/providers/Microsoft.Storage/storageAccounts/example",
"resource_name": "example",
"resource_type": "Microsoft.Storage/storageAccounts",
"description": "Storage accounts with public blob access enabled allow unauthenticated read access to blob data over the internet.",
"remediation": "Disable public blob access on the storage account.",
"playbook": "playbooks/cli/fix_az_stor_001.sh",
"frameworks": {
"CIS": "3.5",
"NIST": "PR.AC-3",
"ISO27001": "A.9.4.1"
},
"metadata": {},
"detected_at": "2026-05-09T12:00:00+00:00",
"scan_id": "6f4a08ac-7d3a-4d9a-a4b4-2a26e5f63c8a"
}
]
"status": "pending",
"message": "Scan has been queued and will start shortly."
}
```

Expand Down Expand Up @@ -437,4 +436,3 @@ The following endpoints are called by the frontend but have no backend implement
| Endpoint | Used by | Status |
|---|---|---|
| `GET /api/monitoring` | Monitoring page — score trend chart, category distribution | Deferred. Score and findings data come from `GET /api/score` and `GET /api/findings` instead. |
| `GET /api/scans/<scan_id>` | Header scan poller | Deferred. The frontend falls back to `GET /api/scans` and matches by `scan_id` in the response list. The poller is rarely entered because `POST /api/scans/trigger` now returns `status: completed` immediately. |
Loading
Loading