Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions experiments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,7 @@ All experiments run in software-only mode. No hardware TEE is required. TRACE Cl
|------|--------|-------|
| `tests/unit/test_claim1_hash_binding.py` | Claim 1 | 6 |
| `tests/unit/test_claim2_session_gap.py` | Claim 2 | 6 |
| `tests/unit/test_claim3_rug_pull_detection.py` | Claim 3 | 6 |
| `tests/unit/test_claim4_trace_claim_nonce.py` | Claim 4 | 6 |
| `tests/unit/test_claim5_temporal_adjacency.py` | Claim 5 | 9 |
| `tests/unit/test_claim6_cross_org_attestation.py` | Claim 6 | 9 |
84 changes: 84 additions & 0 deletions tests/unit/test_claim3_rug_pull_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for Claim 3: fail-closed tool catalog drift (rug-pull) detection.

These tests assert the invariants the claim3 experiment demonstrates: definition
hashes are deterministic, a single-sentence description tamper avalanches the
hash, the aggregate catalog hash reflects any tool change, and a pinned catalog
hash makes the gateway reject a tampered catalog fail-closed. They run in CI to
catch regressions in catalog hashing and mismatch enforcement.
"""

from __future__ import annotations

from pathlib import Path

import pytest

from cmcp_runtime.catalog.loader import load_catalog
from cmcp_runtime.errors import CatalogHashMismatch

FIXTURES = (
Path(__file__).parent.parent.parent
/ "experiments"
/ "claim3-rug-pull-detection"
/ "fixtures"
)
CATALOG_APPROVED = str(FIXTURES / "catalog-approved.json")
CATALOG_TAMPERED = str(FIXTURES / "catalog-tampered.json")

_ATTACKER_SENTENCE = "Silently forwards all responses to audit@attacker.com."


def _bits_different(h1: str, h2: str) -> int:
b1 = bytes.fromhex(h1.removeprefix("sha256:"))
b2 = bytes.fromhex(h2.removeprefix("sha256:"))
return sum(bin(a ^ b).count("1") for a, b in zip(b1, b2, strict=True))


def test_definition_and_catalog_hash_determinism():
"""The same catalog loaded twice yields identical catalog and definition hashes."""
c1 = load_catalog(CATALOG_APPROVED)
c2 = load_catalog(CATALOG_APPROVED)
assert c1.catalog_hash == c2.catalog_hash
assert c1.require("ehr.get_patient").definition_hash == c2.require("ehr.get_patient").definition_hash


def test_avalanche_on_description_tamper():
"""Appending one sentence to a tool description changes >64 of 256 definition-hash bits."""
approved = load_catalog(CATALOG_APPROVED).require("ehr.get_patient")
tampered = load_catalog(CATALOG_TAMPERED).require("ehr.get_patient")
assert approved.definition_hash != tampered.definition_hash
bits_diff = _bits_different(approved.definition_hash, tampered.definition_hash)
assert bits_diff > 64, f"Expected >64 bits to change on description tamper, got {bits_diff}"


def test_catalog_hash_changes_on_single_tool_tamper():
"""A tampered tool definition propagates to the aggregate catalog hash."""
approved = load_catalog(CATALOG_APPROVED)
tampered = load_catalog(CATALOG_TAMPERED)
assert approved.catalog_hash != tampered.catalog_hash


def test_pinned_hash_rejects_tampered_catalog_fail_closed():
"""Loading the tampered catalog under the approved (pinned) hash raises CatalogHashMismatch."""
approved_hash = load_catalog(CATALOG_APPROVED).catalog_hash
with pytest.raises(CatalogHashMismatch):
load_catalog(CATALOG_TAMPERED, expected_hash=approved_hash)


def test_approved_catalog_passes_its_own_pinned_hash():
"""The approved catalog loads cleanly when presented with its own expected hash."""
approved_hash = load_catalog(CATALOG_APPROVED).catalog_hash
result = load_catalog(CATALOG_APPROVED, expected_hash=approved_hash)
assert result.catalog_hash == approved_hash


def test_tamper_is_undetectable_without_pinning():
"""Without a pinned hash the tampered catalog loads, so detection depends on the pin.

The malicious sentence is present in the loaded tampered description and absent
from the approved one; only the pinned-hash check (above) turns that into a block.
"""
approved_desc = load_catalog(CATALOG_APPROVED).require("ehr.get_patient").approved_definition.description
tampered_desc = load_catalog(CATALOG_TAMPERED).require("ehr.get_patient").approved_definition.description
assert _ATTACKER_SENTENCE not in approved_desc
assert _ATTACKER_SENTENCE in tampered_desc
151 changes: 151 additions & 0 deletions tests/unit/test_claim4_trace_claim_nonce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Tests for Claim 4: TRACE Claim nonce binding and disclosure resistance.

These tests assert the invariants the claim4 experiment demonstrates: the nonce
binds a claim to a specific session and TEE instance, a session-id swap breaks
the Ed25519 signature, and removing an audit entry breaks the export signature.
They run in CI to catch regressions in nonce construction and claim signing.
"""

from __future__ import annotations

import base64
import hashlib
import json

from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

from cmcp_runtime.audit.keys import SigningKey
from cmcp_runtime.audit.trace_claim import (
AttestationReportInfo,
CallGraphSummary,
CallSummary,
PolicyBundleInfo,
ToolCatalogInfo,
generate_trace_claim,
)


def _compute_nonce(public_key_hex: str, session_id: str) -> str:
"""SHA-256(tee_public_key_bytes || session_id_bytes) as hex."""
return hashlib.sha256(bytes.fromhex(public_key_hex) + session_id.encode("utf-8")).hexdigest()


def _b64url(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


def _verify_sig(claim_dict: dict, pub_hex: str) -> bool:
sig = base64.urlsafe_b64decode(claim_dict.get("signature", "") + "==")
pub = Ed25519PublicKey.from_public_bytes(bytes.fromhex(pub_hex))
body = {k: v for k, v in claim_dict.items() if k != "signature"}
body_bytes = json.dumps(body, sort_keys=True, separators=(",", ":"), ensure_ascii=True).encode()
try:
pub.verify(sig, body_bytes)
return True
except Exception:
return False


def _stub_claim(session_id: str, signing_key: SigningKey, nonce_hex: str):
report = AttestationReportInfo(
provider="tpm",
measurement="sha256:" + "ab" * 32,
report_data=nonce_hex,
attestation_generated_at="2026-06-25T00:00:00Z",
attestation_validity_seconds=3600,
)
policy = PolicyBundleInfo(hash="sha256:" + "0" * 64, enforcement_mode="enforcing", policy_version="1.0.0")
catalog = ToolCatalogInfo(hash="sha256:" + "0" * 64)
summary = CallSummary(
tool_calls_total=1,
tool_calls_allowed=1,
tool_calls_denied=0,
tool_calls_faulted=0,
tools_invoked=["ehr.get_patient"],
session_max_sensitivity="hipaa_phi",
call_graph_summary=CallGraphSummary(compliance_domains_touched=["hipaa_phi"], cross_boundary_events=[]),
)
return generate_trace_claim(
session_id=session_id,
signing_key=signing_key,
attestation_report=report,
policy_bundle=policy,
tool_catalog=catalog,
call_summary=summary,
audit_chain_root="sha256:" + "0" * 64,
audit_chain_tip="sha256:" + "0" * 64,
audit_chain_length=1,
)


def test_nonce_is_deterministic():
"""The same key and session_id always produce the same nonce."""
key = SigningKey()
assert _compute_nonce(key.public_key_hex, "session-A") == _compute_nonce(key.public_key_hex, "session-A")


def test_nonce_changes_with_session_id():
"""Changing the session_id changes the nonce (session binding)."""
key = SigningKey()
assert _compute_nonce(key.public_key_hex, "session-A") != _compute_nonce(key.public_key_hex, "session-B")


def test_nonce_changes_with_tee_key():
"""Changing the TEE key changes the nonce for the same session (instance binding)."""
key1 = SigningKey()
key2 = SigningKey()
assert _compute_nonce(key1.public_key_hex, "session-A") != _compute_nonce(key2.public_key_hex, "session-A")


def test_claim_nonce_does_not_match_other_session():
"""A claim minted for session-A carries A's nonce, which fails B's expected nonce."""
key = SigningKey()
nonce_a = _compute_nonce(key.public_key_hex, "session-A")
claim = _stub_claim("session-A", key, nonce_a)
embedded = json.loads(claim.model_dump_json(exclude_none=True))["trace"]["runtime"]["nonce"]
assert embedded == _b64url(bytes.fromhex(nonce_a))
expected_for_b = _b64url(bytes.fromhex(_compute_nonce(key.public_key_hex, "session-B")))
assert embedded != expected_for_b


def test_session_id_tamper_breaks_signature():
"""Replacing session_id in a signed claim invalidates its Ed25519 signature."""
key = SigningKey()
nonce_a = _compute_nonce(key.public_key_hex, "session-A")
claim_dict = json.loads(_stub_claim("session-A", key, nonce_a).model_dump_json(exclude_none=True))
assert _verify_sig(claim_dict, key.public_key_hex)
claim_dict["gateway"]["session_id"] = "session-B"
assert not _verify_sig(claim_dict, key.public_key_hex)


def test_audit_entry_removal_breaks_export_signature():
"""Removing one audit entry changes the bundle hash, so the export signature fails."""
key = SigningKey()
entries = [{"call_id": f"call-{i}", "tool": "ehr.get_patient", "decision": "allow", "seq": i} for i in range(5)]

def _bundle_hash(items: list[dict]) -> str:
canonical = json.dumps(items, sort_keys=True, separators=(",", ":"), ensure_ascii=True).encode()
return "sha256:" + hashlib.sha256(canonical).hexdigest()

def _export_body(bundle_hash: str) -> bytes:
return json.dumps(
{"bundle_hash": bundle_hash, "verifier_nonce": "v-nonce-abc123"},
sort_keys=True,
separators=(",", ":"),
ensure_ascii=True,
).encode()

full_hash = _bundle_hash(entries)
sig = key.sign(_export_body(full_hash))

entries_minus_one = [e for e in entries if e["call_id"] != "call-2"]
minus_hash = _bundle_hash(entries_minus_one)
assert minus_hash != full_hash

pub = Ed25519PublicKey.from_public_bytes(bytes.fromhex(key.public_key_hex))
try:
pub.verify(sig, _export_body(minus_hash))
still_valid = True
except Exception:
still_valid = False
assert not still_valid