|
| 1 | +import importlib |
| 2 | +import json |
| 3 | +import sys |
| 4 | +from types import ModuleType |
| 5 | + |
| 6 | + |
| 7 | +def test_metrics_store_inc_and_rule_counters(tmp_path, monkeypatch): |
| 8 | + monkeypatch.chdir(tmp_path) |
| 9 | + import src.metrics.store as store |
| 10 | + importlib.reload(store) |
| 11 | + |
| 12 | + store.inc("allowed", 2) |
| 13 | + store.inc_rule_action("r1", "block", 3) |
| 14 | + data = store.get_all() |
| 15 | + assert data.get("allowed") == 2 and data.get("rule:r1:block") == 3 |
| 16 | + rc = store.get_rule_counters("r1") |
| 17 | + assert rc == {"block": 3} |
| 18 | + |
| 19 | + |
| 20 | +def test_decisions_append_and_tail_limit(tmp_path, monkeypatch): |
| 21 | + monkeypatch.chdir(tmp_path) |
| 22 | + import src.metrics.decisions as dec |
| 23 | + importlib.reload(dec) |
| 24 | + dec.append({"action": "allow", "rule_id": "rA"}) |
| 25 | + dec.append({"action": "block", "rule_id": "rB"}) |
| 26 | + # Inject a bad line |
| 27 | + p = tmp_path / "data/metrics/decisions.jsonl" |
| 28 | + p.write_text(p.read_text(encoding="utf-8") + "bad\n", encoding="utf-8") |
| 29 | + items = dec.tail(limit=3) |
| 30 | + assert any(it.get("rule_id") == "rB" for it in items) |
| 31 | + assert all("ts" in it for it in items if isinstance(it, dict)) |
| 32 | + |
| 33 | + |
| 34 | +def test_policy_loader_valid_and_invalid(tmp_path, monkeypatch): |
| 35 | + monkeypatch.chdir(tmp_path) |
| 36 | + rules = [ |
| 37 | + {"id": "v1", "target": "table", "selector": "dbo.T", "action": "allow"}, |
| 38 | + {"id": "bad-tgt", "target": "xxx", "selector": "dbo.T", "action": "allow"}, |
| 39 | + {"id": "bad-act", "target": "table", "selector": "dbo.T", "action": "noop"}, |
| 40 | + {"id": "bad-sel", "target": "table", "selector": None, "action": "block"}, |
| 41 | + ] |
| 42 | + (tmp_path / "config").mkdir(parents=True, exist_ok=True) |
| 43 | + (tmp_path / "config/rules.json").write_text(json.dumps(rules), encoding="utf-8") |
| 44 | + from src.policy.loader import load_rules |
| 45 | + out = load_rules(str(tmp_path / "config/rules.json")) |
| 46 | + assert [r.id for r in out] == ["v1"] |
| 47 | + |
| 48 | + |
| 49 | +def test_policy_engine_matching_and_env(): |
| 50 | + from src.policy.engine import PolicyEngine, Event, Rule |
| 51 | + |
| 52 | + rules = [ |
| 53 | + Rule(id="t", target="table", selector="dbo.T", action="block", reason="tbl"), |
| 54 | + Rule(id="c1", target="column", selector="Email", action="autocorrect", reason="col"), |
| 55 | + Rule(id="p", target="pattern", selector="INSERT INTO", action="block", reason="pat"), |
| 56 | + Rule(id="env", target="table", selector="dbo.Env", action="block", reason="env", enabled=True), |
| 57 | + ] |
| 58 | + pe = PolicyEngine(rules, environment="prod") |
| 59 | + # table match |
| 60 | + d1 = pe.decide(Event(database=None, user=None, sql_text=None, table="dbo.T", column=None, value=None)) |
| 61 | + assert d1.action == "block" and d1.rule_id == "t" |
| 62 | + # column-only match matches both dotted and parameter forms |
| 63 | + d2 = pe.decide(Event(None, None, None, None, "dbo.Users.Email", None)) |
| 64 | + d3 = pe.decide(Event(None, None, None, None, "@Email", None)) |
| 65 | + assert d2.action == d3.action == "autocorrect" |
| 66 | + # pattern match |
| 67 | + d4 = pe.decide(Event(None, None, "insert into dbo.x values(1)", None, None, None)) |
| 68 | + assert d4.rule_id == "p" |
| 69 | + # get_rule |
| 70 | + assert pe.get_rule("t").id == "t" |
| 71 | + |
| 72 | + |
| 73 | +def test_metrics_prom_fallback_lines(tmp_path, monkeypatch): |
| 74 | + # Force fallback path by inserting a dummy prometheus_client without required symbols |
| 75 | + dummy = ModuleType("prometheus_client") |
| 76 | + sys.modules["prometheus_client"] = dummy |
| 77 | + monkeypatch.chdir(tmp_path) |
| 78 | + # Prepare metrics |
| 79 | + mdir = tmp_path / "data/metrics" |
| 80 | + mdir.mkdir(parents=True) |
| 81 | + (mdir / "metrics.json").write_text(json.dumps({"allowed": 5, "rule:rX:block": 2}), encoding="utf-8") |
| 82 | + |
| 83 | + api = importlib.import_module("src.api") |
| 84 | + importlib.reload(api) |
| 85 | + resp = api.metrics_prom() |
| 86 | + content = resp.content if hasattr(resp, "content") else resp # shim in minimal env |
| 87 | + text = content.decode("utf-8") if hasattr(content, "decode") else str(content) |
| 88 | + assert "sqlumai_metric{key=\"allowed\"} 5" in text |
| 89 | + assert "sqlumai_metric{key=\"rule\",rule=\"rX\",action=\"block\"} 2" in text |
0 commit comments