diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 46f174f..ced4755 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -331,11 +331,9 @@ jobs: # ── CHECK 8: Rule regression tests (MockAzureClient, no Azure creds) ── - name: Rule regression tests id: rule_tests - env: - DATABASE_URL: "postgresql://ci:ci@localhost/ci_db" run: | echo "=== Running rule regression tests ===" - pytest tests/test_rules_*.py -v --tb=short + pytest tests/test_rules_*.py tests/test_clean_scan.py -v --tb=short # ── Final summary — always runs, shows per-check pass/fail ──────── - name: CI Summary diff --git a/api/models/finding.py b/api/models/finding.py index f33b093..9286aad 100644 --- a/api/models/finding.py +++ b/api/models/finding.py @@ -316,11 +316,11 @@ def get_findings(self, filters: Optional[Dict[str, Any]] = None) -> List[Dict[st clauses.append("scan_id = %s") params.append(filters["scan_id"]) else: - # Default to the latest scan so historical findings do not inflate counts + # Default to the latest completed scan (includes clean scans with 0 findings) clauses.append( - "scan_id = (SELECT scan_id FROM scans WHERE total_findings > 0 ORDER BY started_at DESC LIMIT 1)" + "scan_id = (SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1)" ) - + where = "WHERE " + " AND ".join(clauses) if clauses else "" sql = f"SELECT * FROM findings {where} ORDER BY detected_at DESC LIMIT 1000" @@ -502,9 +502,10 @@ def get_score(self) -> int: SELECT severity, COUNT(*) FROM findings WHERE scan_id = ( - SELECT scan_id FROM scans WHERE total_findings > 0 ORDER BY started_at DESC LIMIT 1 + SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1 ) GROUP BY severity + GROUP BY severity """ ) rows = cur.fetchall() @@ -529,7 +530,7 @@ def get_cve_summary(self) -> Dict[str, Any]: FROM scans s LEFT JOIN findings f ON s.scan_id = f.scan_id WHERE s.scan_id = ( - SELECT scan_id FROM scans WHERE total_findings > 0 ORDER BY started_at DESC LIMIT 1 + SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1 ) GROUP BY s.cve_enrichment_status """) @@ -577,10 +578,17 @@ def get_compliance_score(self, framework: str) -> Dict[str, Any]: controls = framework_data.get("controls", {}) - # Get rule IDs that have at least one finding + # Get rule IDs that fired in the latest completed scan only conn = self._get_conn() with conn.cursor() as cur: - cur.execute("SELECT DISTINCT rule_id FROM findings") + cur.execute( + """ + SELECT DISTINCT rule_id FROM findings + WHERE scan_id = ( + SELECT scan_id FROM scans WHERE status = 'completed' ORDER BY started_at DESC LIMIT 1 + ) + """ + ) failed_rule_ids = {row[0] for row in cur.fetchall()} results = [] @@ -606,4 +614,4 @@ def get_compliance_score(self, framework: str) -> Dict[str, Any]: "failed": failed, "score_percent": score_pct, "controls": results, - } \ No newline at end of file + } diff --git a/tests/test_clean_scan.py b/tests/test_clean_scan.py new file mode 100644 index 0000000..dbdbda0 --- /dev/null +++ b/tests/test_clean_scan.py @@ -0,0 +1,185 @@ +"""Tests proving that a clean (zero-finding) completed scan is shown by posture +endpoints instead of falling back to stale data from an older scan.""" + +from unittest.mock import MagicMock, patch, call +import pytest + +from api.models.finding import DatabaseManager, FRAMEWORK_FILE_MAP + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _db() -> DatabaseManager: + """Return a DatabaseManager with a mock DSN (no real connection used).""" + db = DatabaseManager.__new__(DatabaseManager) + db.dsn = "postgresql://mock/mock" + db.conn = None + return db + + +def _mock_cursor(rows): + """Return a context-manager cursor mock that yields *rows* on fetchall().""" + cur = MagicMock() + cur.__enter__ = lambda s: s + cur.__exit__ = MagicMock(return_value=False) + cur.fetchall.return_value = rows + cur.fetchone.return_value = rows[0] if rows else None + return cur + + +# ── get_findings ────────────────────────────────────────────────────────────── + +def test_get_findings_uses_completed_status_not_total_findings(): + """get_findings() must filter on status='completed', not total_findings > 0.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + with patch.object(db, "_get_conn", return_value=conn): + db.get_findings() + + executed_sql = conn.cursor.return_value.execute.call_args[0][0] + assert "status = 'completed'" in executed_sql + assert "total_findings" not in executed_sql + + +def test_get_findings_clean_scan_returns_empty_list(): + """When the latest scan has no findings, get_findings() returns [].""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + with patch.object(db, "_get_conn", return_value=conn): + result = db.get_findings() + + assert result == [] + + +# ── get_score ───────────────────────────────────────────────────────────────── + +def test_get_score_uses_completed_status(): + """get_score() must scope to status='completed', not total_findings > 0.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + with patch.object(db, "_get_conn", return_value=conn): + db.get_score() + + executed_sql = conn.cursor.return_value.execute.call_args[0][0] + assert "status = 'completed'" in executed_sql + assert "total_findings" not in executed_sql + + +def test_get_score_is_100_after_clean_scan(): + """A clean scan (no findings) must yield a perfect score of 100.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + with patch.object(db, "_get_conn", return_value=conn): + score = db.get_score() + + assert score == 100 + + +def test_get_score_does_not_include_old_scan_findings(): + """After a clean scan, old HIGH findings must not deduct points.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + with patch.object(db, "_get_conn", return_value=conn): + score = db.get_score() + + assert score == 100 + + +# ── get_compliance_score ────────────────────────────────────────────────────── + +def test_get_compliance_score_scopes_to_latest_scan(): + """get_compliance_score() must only look at findings from the latest completed + scan, not the entire findings table.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + with patch.object(db, "_get_conn", return_value=conn): + import json + fake_framework = json.dumps({ + "framework": "Test", + "version": "1.0", + "controls": {"AZ-STOR-001": {"control_id": "3.1", "control_name": "Test control"}}, + }) + import builtins + import io + with patch("builtins.open", return_value=io.StringIO(fake_framework)): + from pathlib import Path + with patch.object(Path, "exists", return_value=True): + db.get_compliance_score("cis") + + executed_sql = conn.cursor.return_value.execute.call_args[0][0] + assert "status = 'completed'" in executed_sql + assert "total_findings" not in executed_sql + + +def test_get_compliance_score_all_pass_after_clean_scan(): + """All controls must show PASS when the latest completed scan has no findings.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + import json, io + from pathlib import Path + fake_framework = json.dumps({ + "framework": "CIS Azure", + "version": "2.0", + "controls": { + "AZ-STOR-001": {"control_id": "3.1", "control_name": "No public blobs"}, + "AZ-NET-001": {"control_id": "6.1", "control_name": "No unrestricted SSH"}, + }, + }) + + with patch.object(db, "_get_conn", return_value=conn): + with patch("builtins.open", return_value=io.StringIO(fake_framework)): + with patch.object(Path, "exists", return_value=True): + result = db.get_compliance_score("cis") + + assert result["passed"] == 2 + assert result["failed"] == 0 + assert result["score_percent"] == 100 + statuses = {c["rule_id"]: c["status"] for c in result["controls"]} + assert statuses["AZ-STOR-001"] == "PASS" + assert statuses["AZ-NET-001"] == "PASS" + + +def test_get_compliance_score_remediated_rule_shows_pass(): + """A rule that fired in scan-1 but not scan-2 (clean) must show PASS.""" + db = _db() + conn = MagicMock() + cur = _mock_cursor([]) + conn.cursor.return_value = cur + + import json, io + from pathlib import Path + fake_framework = json.dumps({ + "framework": "CIS Azure", + "version": "2.0", + "controls": { + "AZ-STOR-001": {"control_id": "3.1", "control_name": "No public blobs"}, + }, + }) + + with patch.object(db, "_get_conn", return_value=conn): + with patch("builtins.open", return_value=io.StringIO(fake_framework)): + with patch.object(Path, "exists", return_value=True): + result = db.get_compliance_score("cis") + + assert result["controls"][0]["status"] == "PASS"