Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,9 @@ jobs:
# ── CHECK 8: Rule regression tests (MockAzureClient, no Azure creds) ──
- name: Rule regression tests
id: rule_tests
env:
DATABASE_URL: "postgresql://ci:ci@localhost/ci_db"
run: |
echo "=== Running rule regression tests ==="
pytest tests/test_rules_*.py -v --tb=short
pytest tests/test_rules_*.py tests/test_clean_scan.py -v --tb=short

# ── Final summary — always runs, shows per-check pass/fail ────────
- name: CI Summary
Expand Down
24 changes: 16 additions & 8 deletions api/models/finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,11 @@ def get_findings(self, filters: Optional[Dict[str, Any]] = None) -> List[Dict[st
clauses.append("scan_id = %s")
params.append(filters["scan_id"])
else:
# Default to the latest scan so historical findings do not inflate counts
# Default to the latest completed scan (includes clean scans with 0 findings)
clauses.append(
"scan_id = (SELECT scan_id FROM scans WHERE total_findings > 0 ORDER BY started_at DESC LIMIT 1)"
"scan_id = (SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1)"
)

where = "WHERE " + " AND ".join(clauses) if clauses else ""
sql = f"SELECT * FROM findings {where} ORDER BY detected_at DESC LIMIT 1000"

Expand Down Expand Up @@ -502,9 +502,10 @@ def get_score(self) -> int:
SELECT severity, COUNT(*)
FROM findings
WHERE scan_id = (
SELECT scan_id FROM scans WHERE total_findings > 0 ORDER BY started_at DESC LIMIT 1
SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1
)
GROUP BY severity
GROUP BY severity
"""
)
rows = cur.fetchall()
Expand All @@ -529,7 +530,7 @@ def get_cve_summary(self) -> Dict[str, Any]:
FROM scans s
LEFT JOIN findings f ON s.scan_id = f.scan_id
WHERE s.scan_id = (
SELECT scan_id FROM scans WHERE total_findings > 0 ORDER BY started_at DESC LIMIT 1
SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1
)
GROUP BY s.cve_enrichment_status
""")
Expand Down Expand Up @@ -577,10 +578,17 @@ def get_compliance_score(self, framework: str) -> Dict[str, Any]:

controls = framework_data.get("controls", {})

# Get rule IDs that have at least one finding
# Get rule IDs that fired in the latest completed scan only
conn = self._get_conn()
with conn.cursor() as cur:
cur.execute("SELECT DISTINCT rule_id FROM findings")
cur.execute(
"""
SELECT DISTINCT rule_id FROM findings
WHERE scan_id = (
SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1
)
"""
)
failed_rule_ids = {row[0] for row in cur.fetchall()}

results = []
Expand All @@ -606,4 +614,4 @@ def get_compliance_score(self, framework: str) -> Dict[str, Any]:
"failed": failed,
"score_percent": score_pct,
"controls": results,
}
}
185 changes: 185 additions & 0 deletions tests/test_clean_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Tests proving that a clean (zero-finding) completed scan is shown by posture
endpoints instead of falling back to stale data from an older scan."""

from unittest.mock import MagicMock, patch, call
import pytest

from api.models.finding import DatabaseManager, FRAMEWORK_FILE_MAP


# ── helpers ──────────────────────────────────────────────────────────────────

def _db() -> DatabaseManager:
"""Return a DatabaseManager with a mock DSN (no real connection used)."""
db = DatabaseManager.__new__(DatabaseManager)
db.dsn = "postgresql://mock/mock"
db.conn = None
return db


def _mock_cursor(rows):
"""Return a context-manager cursor mock that yields *rows* on fetchall()."""
cur = MagicMock()
cur.__enter__ = lambda s: s
cur.__exit__ = MagicMock(return_value=False)
cur.fetchall.return_value = rows
cur.fetchone.return_value = rows[0] if rows else None
return cur


# ── get_findings ──────────────────────────────────────────────────────────────

def test_get_findings_uses_completed_status_not_total_findings():
"""get_findings() must filter on status='completed', not total_findings > 0."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

with patch.object(db, "_get_conn", return_value=conn):
db.get_findings()

executed_sql = conn.cursor.return_value.execute.call_args[0][0]
assert "status = 'completed'" in executed_sql
assert "total_findings" not in executed_sql


def test_get_findings_clean_scan_returns_empty_list():
"""When the latest scan has no findings, get_findings() returns []."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

with patch.object(db, "_get_conn", return_value=conn):
result = db.get_findings()

assert result == []


# ── get_score ─────────────────────────────────────────────────────────────────

def test_get_score_uses_completed_status():
"""get_score() must scope to status='completed', not total_findings > 0."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

with patch.object(db, "_get_conn", return_value=conn):
db.get_score()

executed_sql = conn.cursor.return_value.execute.call_args[0][0]
assert "status = 'completed'" in executed_sql
assert "total_findings" not in executed_sql


def test_get_score_is_100_after_clean_scan():
"""A clean scan (no findings) must yield a perfect score of 100."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

with patch.object(db, "_get_conn", return_value=conn):
score = db.get_score()

assert score == 100


def test_get_score_does_not_include_old_scan_findings():
"""After a clean scan, old HIGH findings must not deduct points."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

with patch.object(db, "_get_conn", return_value=conn):
score = db.get_score()

assert score == 100


# ── get_compliance_score ──────────────────────────────────────────────────────

def test_get_compliance_score_scopes_to_latest_scan():
"""get_compliance_score() must only look at findings from the latest completed
scan, not the entire findings table."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

with patch.object(db, "_get_conn", return_value=conn):
import json
fake_framework = json.dumps({
"framework": "Test",
"version": "1.0",
"controls": {"AZ-STOR-001": {"control_id": "3.1", "control_name": "Test control"}},
})
import builtins
import io
with patch("builtins.open", return_value=io.StringIO(fake_framework)):
from pathlib import Path
with patch.object(Path, "exists", return_value=True):
db.get_compliance_score("cis")

executed_sql = conn.cursor.return_value.execute.call_args[0][0]
assert "status = 'completed'" in executed_sql
assert "total_findings" not in executed_sql


def test_get_compliance_score_all_pass_after_clean_scan():
"""All controls must show PASS when the latest completed scan has no findings."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

import json, io
from pathlib import Path
fake_framework = json.dumps({
"framework": "CIS Azure",
"version": "2.0",
"controls": {
"AZ-STOR-001": {"control_id": "3.1", "control_name": "No public blobs"},
"AZ-NET-001": {"control_id": "6.1", "control_name": "No unrestricted SSH"},
},
})

with patch.object(db, "_get_conn", return_value=conn):
with patch("builtins.open", return_value=io.StringIO(fake_framework)):
with patch.object(Path, "exists", return_value=True):
result = db.get_compliance_score("cis")

assert result["passed"] == 2
assert result["failed"] == 0
assert result["score_percent"] == 100
statuses = {c["rule_id"]: c["status"] for c in result["controls"]}
assert statuses["AZ-STOR-001"] == "PASS"
assert statuses["AZ-NET-001"] == "PASS"


def test_get_compliance_score_remediated_rule_shows_pass():
"""A rule that fired in scan-1 but not scan-2 (clean) must show PASS."""
db = _db()
conn = MagicMock()
cur = _mock_cursor([])
conn.cursor.return_value = cur

import json, io
from pathlib import Path
fake_framework = json.dumps({
"framework": "CIS Azure",
"version": "2.0",
"controls": {
"AZ-STOR-001": {"control_id": "3.1", "control_name": "No public blobs"},
},
})

with patch.object(db, "_get_conn", return_value=conn):
with patch("builtins.open", return_value=io.StringIO(fake_framework)):
with patch.object(Path, "exists", return_value=True):
result = db.get_compliance_score("cis")

assert result["controls"][0]["status"] == "PASS"
Loading