diff --git a/agentmint/__init__.py b/agentmint/__init__.py index f942911..f9c741a 100644 --- a/agentmint/__init__.py +++ b/agentmint/__init__.py @@ -30,6 +30,7 @@ from .types import DelegationStatus, DelegationResult, EnforceMode from .decorator import ( AuthorizationError, + notarise, require_receipt, set_receipt, get_receipt, @@ -58,6 +59,7 @@ "ReplayError", "DeniedError", "AuthorizationError", + "notarise", # Decorator "require_receipt", "set_receipt", diff --git a/agentmint/chain.py b/agentmint/chain.py new file mode 100644 index 0000000..30092bb --- /dev/null +++ b/agentmint/chain.py @@ -0,0 +1,62 @@ +"""Receipt chain utilities.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence + +from .patterns import matches_pattern + + +def intersect_scopes( + parent_scope: Sequence[str], + requested: Sequence[str], +) -> tuple[str, ...]: + """Return the effective delegated scope.""" + + result = [] + for child in requested: + for parent in parent_scope: + if child == parent: + if child not in result: + result.append(child) + elif matches_pattern(child, parent): + if child not in result: + result.append(child) + elif matches_pattern(parent, child): + if parent not in result: + result.append(parent) + return tuple(result) + + +@dataclass(frozen=True) +class ChainVerification: + """Result of verifying an ordered receipt chain.""" + + valid: bool + length: int + root_hash: str + break_at_index: Optional[int] = None + reason: str = "" + + +def verify_chain(receipts: Sequence[object]) -> ChainVerification: + """Verify chain linkage using previous receipt hashes.""" + + if not receipts: + return ChainVerification(valid=True, length=0, root_hash="") + + previous_hash = None + for index, receipt in enumerate(receipts): + current_previous = getattr(receipt, "previous_receipt_hash", None) + if current_previous != previous_hash: + return ChainVerification( + valid=False, + length=len(receipts), + root_hash="", + break_at_index=index, + reason="chain break at index %d" % index, + ) + previous_hash = getattr(receipt, "canonical_hash")() + + return ChainVerification(valid=True, length=len(receipts), root_hash=previous_hash or "") diff --git a/agentmint/decorator.py b/agentmint/decorator.py index 95aaae9..6232dea 100644 --- a/agentmint/decorator.py +++ b/agentmint/decorator.py @@ -1,4 +1,4 @@ -"""Decorator for protecting functions with receipts.""" +"""Decorator helpers for AgentMint authorization and notarisation.""" from __future__ import annotations from contextvars import ContextVar @@ -82,3 +82,44 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return wrapper return decorator + + +def notarise( + notary, + action: Optional[str] = None, + plan=None, + agent: Optional[str] = None, + evidence=None, + enable_timestamp: bool = True, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """Decorator that records a receipt after a successful function call.""" + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + result = func(*args, **kwargs) + if callable(evidence): + receipt_evidence = evidence(*args, **kwargs, result=result) + elif evidence is None: + receipt_evidence = { + "function": func.__name__, + "args": list(args), + "kwargs": kwargs, + } + else: + receipt_evidence = dict(evidence) + + receipt_action = action or func.__name__ + wrapper.last_receipt = notary.notarise( + action=receipt_action, + agent=agent, + plan=plan, + evidence=receipt_evidence, + enable_timestamp=enable_timestamp, + ) + return result + + wrapper.last_receipt = None # type: ignore[attr-defined] + return wrapper + + return decorator diff --git a/agentmint/notary.py b/agentmint/notary.py index a0462e8..9e5e04b 100644 --- a/agentmint/notary.py +++ b/agentmint/notary.py @@ -1,27 +1,8 @@ -""" -AgentMint Notary — passive evidence signing for AI agent actions. - -AgentMint is a notary, not a gatekeeper. It never touches API calls. -It observes what happened after the fact and produces cryptographically -signed, independently timestamped evidence receipts. - -A receipt proves: - - What action was taken (evidence hash, extracted fields) - - Whether it was within policy (scope evaluation result) - - When it was observed (RFC 3161 timestamp via FreeTSA) - - Who approved the policy (chain to plan receipt) - - Chain integrity (SHA-256 hash of previous receipt) - -Verification requires only OpenSSL. No AgentMint software or account. - -AIUC-1 control mapping: - E015 Log model activity — receipt IS the signed log entry - D003 Restrict unsafe calls — in_policy proves evaluation happened - B001 Adversarial testing — evidence package proves controls tested -""" +"""AgentMint notary for signed AERF evidence receipts.""" from __future__ import annotations +import asyncio import base64 import hashlib import json @@ -32,24 +13,27 @@ import uuid import zipfile from collections import deque -from dataclasses import dataclass, field, replace -from datetime import datetime, timezone, timedelta +from dataclasses import dataclass, replace +from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Final, Optional, Sequence +from typing import Any, Final, Mapping, Optional, Sequence from nacl.encoding import HexEncoder -from nacl.signing import SigningKey, VerifyKey from nacl.exceptions import BadSignatureError +from nacl.signing import SigningKey, VerifyKey -from .patterns import matches_pattern, in_scope +from .chain import ChainVerification, intersect_scopes, verify_chain +from .patterns import matches_pattern +from .policy import PolicyDecision, ScopeMatchPolicy, evaluate_policy +from .plan import Plan +from .providers.keys import DEFAULT_KEY_DIR, FileKeyProvider +from .providers.redactors import NoRedactor +from .providers.serializers import JCSSerializer +from .providers.sinks import FileSink +from .providers.timestamp import NoTimestamper, RFC3161Timestamper, TimestampRecord +from .receipt import Receipt +from .timestamp import fetch_ca_certs from .types import EnforceMode -from .timestamp import ( - TimestampResult, - TimestampError, - timestamp as ts_timestamp, - fetch_ca_certs, - verify as ts_verify, -) __all__ = [ "Notary", @@ -61,51 +45,48 @@ "ChainVerification", "verify_chain", "intersect_scopes", + "evaluate_policy", + "_canonical_json", + "_public_key_pem", ] -# ── Constants ────────────────────────────────────────────── - MAX_ACTION_LEN: Final[int] = 128 MAX_IDENTITY_LEN: Final[int] = 256 MAX_EVIDENCE_BYTES: Final[int] = 1024 * 1024 DEFAULT_TTL: Final[int] = 300 MAX_TTL: Final[int] = 3600 MIN_TTL: Final[int] = 1 - AIUC_CONTROLS: Final[tuple[str, ...]] = ("E015", "D003", "B001") - -# Ed25519 SPKI prefix (RFC 8410): 302a300506032b6570032100 +DEFAULT_TSA_URLS: Final[list[str]] = ["https://freetsa.org/tsr"] +_CHAIN_STATE_FILE = "chain_state.json" _SPKI_PREFIX: Final[bytes] = bytes.fromhex("302a300506032b6570032100") -# Default TSA URLs — improvement 4.5 -DEFAULT_TSA_URLS: Final[list[str]] = [ - "https://freetsa.org/tsr", -] - - -# ── Errors ───────────────────────────────────────────────── +PlanReceipt = Plan +NotarisedReceipt = Receipt +PolicyEvaluation = PolicyDecision class NotaryError(Exception): - """Raised when notarisation fails. Message is always actionable.""" - - pass + """Raised when notarisation fails.""" -# ── Validation ───────────────────────────────────────────── +def _utc_now() -> datetime: + return datetime.now(timezone.utc) def _require_non_empty_string(value: str, name: str, max_len: int) -> str: if not isinstance(value, str): - raise NotaryError(f"{name} must be a string, got {type(value).__name__}") + raise NotaryError("%s must be a string, got %s" % (name, type(value).__name__)) stripped = value.strip() if not stripped: - raise NotaryError(f"{name} must not be empty") + raise NotaryError("%s must not be empty" % name) if len(stripped) > max_len: - raise NotaryError(f"{name} must be at most {max_len} characters, got {len(stripped)}") - if any(ord(c) < 32 for c in stripped): - raise NotaryError(f"{name} contains control characters") + raise NotaryError( + "%s must be at most %d characters, got %d" % (name, max_len, len(stripped)) + ) + if any(ord(char) < 32 for char in stripped): + raise NotaryError("%s contains control characters" % name) return stripped @@ -113,396 +94,232 @@ def _require_string_list(value: Sequence[str] | None, name: str) -> tuple[str, . if value is None: return () if not isinstance(value, (list, tuple)): - raise NotaryError(f"{name} must be a list, got {type(value).__name__}") + raise NotaryError("%s must be a list, got %s" % (name, type(value).__name__)) result = [] - for i, item in enumerate(value): + for index, item in enumerate(value): if not isinstance(item, str) or not item.strip(): - raise NotaryError(f"{name}[{i}] must be a non-empty string") + raise NotaryError("%s[%d] must be a non-empty string" % (name, index)) result.append(item.strip()) return tuple(result) def _require_evidence(evidence: Any) -> dict[str, Any]: if not isinstance(evidence, dict): - raise NotaryError(f"evidence must be a dict, got {type(evidence).__name__}") + raise NotaryError("evidence must be a dict, got %s" % type(evidence).__name__) try: - raw = json.dumps(evidence, sort_keys=True).encode("utf-8") - except (TypeError, ValueError) as e: - raise NotaryError(f"evidence must be JSON-serializable: {e}") from e + raw = _canonical_json(evidence) + except (TypeError, ValueError) as exc: + raise NotaryError("evidence must be JSON-serializable: %s" % exc) from exc if len(raw) > MAX_EVIDENCE_BYTES: raise NotaryError( - f"serialized evidence is {len(raw):,} bytes, max is {MAX_EVIDENCE_BYTES:,}" + "serialized evidence is %d bytes, max is %d" % (len(raw), MAX_EVIDENCE_BYTES) ) - return evidence + return dict(evidence) def _clamp_ttl(ttl: int) -> int: return max(MIN_TTL, min(MAX_TTL, ttl)) -# ── PEM helper ───────────────────────────────────────────── - - -def _public_key_pem(verify_key: VerifyKey) -> str: - """Encode an Ed25519 public key as SPKI PEM (RFC 8410).""" - der = _SPKI_PREFIX + bytes(verify_key) - b64 = base64.b64encode(der).decode() - lines = [b64[i : i + 64] for i in range(0, len(b64), 64)] - return f"-----BEGIN PUBLIC KEY-----\n" + "\n".join(lines) + f"\n-----END PUBLIC KEY-----\n" - - -# ── Policy evaluation ───────────────────────────────────── - - -@dataclass(frozen=True) -class PolicyEvaluation: - """Result of evaluating an action against a plan's policy rules.""" - - in_policy: bool - reason: str - - -def evaluate_policy( - action: str, - agent: str, - plan_scope: Sequence[str], - plan_checkpoints: Sequence[str], - plan_delegates: Sequence[str], - plan_expired: bool, -) -> PolicyEvaluation: - """Evaluate whether an action is within policy. Pure function.""" - if plan_expired: - return PolicyEvaluation(False, "plan expired") - if plan_delegates and agent not in plan_delegates: - return PolicyEvaluation(False, f"agent '{agent}' not in delegates_to") - for pattern in plan_checkpoints: - if matches_pattern(action, pattern): - return PolicyEvaluation(False, f"matched checkpoint {pattern}") - for pattern in plan_scope: - if matches_pattern(action, pattern): - return PolicyEvaluation(True, f"matched scope {pattern}") - return PolicyEvaluation(False, "no scope pattern matched") - +def _canonical_json(data: Mapping[str, Any]) -> bytes: + return JCSSerializer().canonicalize(data) -# ── Signing ──────────────────────────────────────────────── - -def _canonical_json(data: dict[str, Any]) -> bytes: - return json.dumps(data, sort_keys=True, separators=(",", ":")).encode("utf-8") +def _derive_key_id(verify_key: VerifyKey) -> str: + return hashlib.sha256(bytes(verify_key)).hexdigest()[:16] -def _sign(key: SigningKey, data: dict[str, Any]) -> str: - return key.sign(_canonical_json(data)).signature.hex() +def _public_key_pem(verify_key: VerifyKey) -> str: + der = _SPKI_PREFIX + bytes(verify_key) + b64 = base64.b64encode(der).decode("ascii") + lines = [b64[index : index + 64] for index in range(0, len(b64), 64)] + return "-----BEGIN PUBLIC KEY-----\n%s\n-----END PUBLIC KEY-----\n" % "\n".join(lines) -def _derive_key_id(verify_key: VerifyKey) -> str: - """First 8 bytes of SHA-256(public_key), hex. Stable across restarts.""" - return hashlib.sha256(bytes(verify_key)).hexdigest()[:16] +def _sign_payload( + signing_key: SigningKey, serializer: JCSSerializer, payload: Mapping[str, Any] +) -> str: + return signing_key.sign(serializer.canonicalize(payload)).signature.hex() -def _verify_signature(verify_key: VerifyKey, data: dict[str, Any], signature_hex: str) -> bool: +def _verify_signature( + verify_key: VerifyKey, + serializer: JCSSerializer, + payload: Mapping[str, Any], + signature_hex: str, +) -> bool: try: - verify_key.verify(_canonical_json(data), bytes.fromhex(signature_hex)) + verify_key.verify(serializer.canonicalize(payload), bytes.fromhex(signature_hex)) return True except (BadSignatureError, ValueError): return False -# ── Data classes ─────────────────────────────────────────── - - -@dataclass(frozen=True) -class PlanReceipt: - """Signed plan defining what actions are allowed.""" - - id: str - user: str - action: str - scope: tuple[str, ...] - checkpoints: tuple[str, ...] - delegates_to: tuple[str, ...] - issued_at: str - expires_at: str - signature: str - key_id: str = "" - - @property - def short_id(self) -> str: - return self.id[:8] +def _compute_policy_hash(plan: Plan) -> str: + policy_data = { + "scope": list(plan.scope), + "checkpoints": list(plan.checkpoints), + "delegates_to": list(plan.delegates_to), + } + return hashlib.sha256(_canonical_json(policy_data)).hexdigest() - @property - def is_expired(self) -> bool: - return _utc_now() >= datetime.fromisoformat(self.expires_at) - def signable_dict(self) -> dict[str, Any]: - return { - "id": self.id, - "type": "plan", - "user": self.user, - "action": self.action, - "scope": list(self.scope), - "checkpoints": list(self.checkpoints), - "delegates_to": list(self.delegates_to), - "issued_at": self.issued_at, - "expires_at": self.expires_at, - "key_id": self.key_id, - } +class _EphemeralKeyProvider: + def __init__(self) -> None: + self._signing_key = SigningKey.generate() - def to_dict(self) -> dict[str, Any]: - d = self.signable_dict() - d["signature"] = self.signature - return d - - -@dataclass(frozen=True) -class NotarisedReceipt: - """Signed, timestamped evidence receipt for a single agent action.""" - - id: str - plan_id: str - agent: str - action: str - in_policy: bool - policy_reason: str - evidence_hash: str - evidence: dict[str, Any] - observed_at: str - signature: str - # Chain linking - previous_receipt_hash: Optional[str] = None - timestamp_result: Optional[TimestampResult] = None - aiuc_controls: tuple[str, ...] = AIUC_CONTROLS - # Plan signature for receipt→plan linkage - plan_signature: str = "" - key_id: str = "" - agent_signature: str = "" - agent_key_id: str = "" - # Policy + output hashes for post-hoc analysis - policy_hash: str = "" - output_hash: str = "" - # Session context - session_id: str = "" - session_trajectory: tuple[dict[str, Any], ...] = () - session_escalation: Optional[str] = None - # Reasoning capture - reasoning_hash: Optional[str] = None - # Enforcement mode - mode: str = "enforce" - original_verdict: Optional[bool] = None + def signing_key(self) -> SigningKey: + return self._signing_key - @property - def short_id(self) -> str: - return self.id[:8] - - def signable_dict(self) -> dict[str, Any]: - d = { - "id": self.id, - "type": "notarised_evidence", - "plan_id": self.plan_id, - "agent": self.agent, - "action": self.action, - "in_policy": self.in_policy, - "policy_reason": self.policy_reason, - "evidence_hash_sha512": self.evidence_hash, - "evidence": self.evidence, - "observed_at": self.observed_at, - "aiuc_controls": list(self.aiuc_controls), - "key_id": self.key_id, - "agent_key_id": self.agent_key_id, - } - # Policy + output hashes - if self.policy_hash: - d["policy_hash"] = self.policy_hash - if self.output_hash: - d["output_hash"] = self.output_hash - # Session context - if self.session_id: - d["session_id"] = self.session_id - if self.session_trajectory: - d["session_trajectory"] = list(self.session_trajectory) - if self.session_escalation: - d["session_escalation"] = self.session_escalation - # Reasoning hash - if self.reasoning_hash: - d["reasoning_hash"] = self.reasoning_hash - if self.mode != "enforce": - d["mode"] = self.mode - if self.original_verdict is not None: - d["original_verdict"] = self.original_verdict - # Chain hash is included in signature if present - if self.previous_receipt_hash is not None: - d["previous_receipt_hash"] = self.previous_receipt_hash - # Plan signature - if self.plan_signature: - d["plan_signature"] = self.plan_signature - return d - - def to_dict(self) -> dict[str, Any]: - d = self.signable_dict() - d["signature"] = self.signature - if self.timestamp_result: - d["timestamp"] = { - "tsa_url": self.timestamp_result.tsa_url, - "digest_hex": self.timestamp_result.digest_hex, - } - return d + def verify_key(self) -> VerifyKey: + return self._signing_key.verify_key - def to_json(self, indent: int = 2) -> str: - return json.dumps(self.to_dict(), indent=indent, sort_keys=False) + def key_id(self) -> str: + return _derive_key_id(self.verify_key()) + def public_key(self) -> bytes: + return bytes(self.verify_key()) -# ── Chain verification ───────────────────────────────────── +class _MemoryPlanStore: + def __init__(self) -> None: + self._plans: dict[str, Mapping[str, Any]] = {} -@dataclass(frozen=True) -class ChainVerification: - """Result of verifying receipt chain integrity.""" + def save(self, plan_id: str, payload: Mapping[str, Any]) -> None: + self._plans[plan_id] = dict(payload) - valid: bool - length: int - root_hash: str - break_at_index: Optional[int] = None - reason: str = "" + def load(self, plan_id: str) -> Optional[Mapping[str, Any]]: + return self._plans.get(plan_id) -def verify_chain(receipts: list[NotarisedReceipt]) -> ChainVerification: - """Verify receipt chain integrity. +class _MemoryChainStore: + def __init__(self) -> None: + self._hashes: dict[str, Optional[str]] = {} - Checks: - 1. First receipt has previous_receipt_hash == None - 2. Each subsequent receipt's previous_receipt_hash == SHA-256 of - the previous receipt's signed payload - 3. Returns root_hash: the hash of the final receipt in the chain + def previous_hash(self, plan_id: str) -> Optional[str]: + return self._hashes.get(plan_id) - The root_hash is a single value summarizing the entire chain. - Publishing it externally creates an anchoring commitment. - """ - if not receipts: - return ChainVerification(valid=True, length=0, root_hash="") + def append(self, plan_id: str, receipt_hash: Optional[str]) -> None: + self._hashes[plan_id] = receipt_hash - if receipts[0].previous_receipt_hash is not None: - return ChainVerification( - valid=False, - length=len(receipts), - root_hash="", - break_at_index=0, - reason="first receipt has non-null chain hash", - ) - prev_hash: Optional[str] = None - for i, receipt in enumerate(receipts): - if receipt.previous_receipt_hash != prev_hash: - return ChainVerification( - valid=False, - length=len(receipts), - root_hash="", - break_at_index=i, - reason=f"chain break at index {i}: expected {prev_hash}, " - f"got {receipt.previous_receipt_hash}", - ) - # Compute hash of this receipt for next iteration - signed_payload = _canonical_json( - {**receipt.signable_dict(), "signature": receipt.signature} - ) - prev_hash = hashlib.sha256(signed_payload).hexdigest() +class _FileChainStore: + def __init__(self, key_dir: Path) -> None: + self.path = key_dir / _CHAIN_STATE_FILE + self._hashes = self._load() - return ChainVerification(valid=True, length=len(receipts), root_hash=prev_hash or "") + def _load(self) -> dict[str, Optional[str]]: + if not self.path.exists(): + return {} + try: + data = json.loads(self.path.read_text()) + except (OSError, json.JSONDecodeError): + return {} + if not isinstance(data, dict): + return {} + result = {} + for key, value in data.items(): + if isinstance(key, str) and (value is None or isinstance(value, str)): + result[key] = value + return result + + def _save(self) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_name = tempfile.mkstemp(dir=str(self.path.parent), prefix=self.path.name + ".") + try: + with os.fdopen(fd, "w") as handle: + json.dump(self._hashes, handle, indent=2) + os.chmod(tmp_name, 0o600) + os.replace(tmp_name, self.path) + finally: + if os.path.exists(tmp_name): + os.unlink(tmp_name) + def previous_hash(self, plan_id: str) -> Optional[str]: + return self._hashes.get(plan_id) -# ── Evidence package ─────────────────────────────────────── + def append(self, plan_id: str, receipt_hash: Optional[str]) -> None: + self._hashes[plan_id] = receipt_hash + self._save() class EvidencePackage: - """Collects receipts into a portable, verifiable zip. - - Contents: - receipt_index.json Table of contents (with chain root) - plan.json The signed plan receipt - public_key.pem Ed25519 public key (SPKI PEM, RFC 8410) - receipts/{id}.json Individual signed receipts - receipts/{id}.tsq Timestamp queries - receipts/{id}.tsr Timestamp responses - chain_root.tsq/tsr Chain root timestamp (if available) - freetsa_cacert.pem CA certificate for verification - freetsa_tsa.crt TSA certificate for verification - VERIFY.sh Checks RFC 3161 timestamps (pure OpenSSL) - verify_sigs.py Checks Ed25519 signatures (needs pynacl) - """ + """Collect receipts into a portable zip package.""" __slots__ = ("_plan", "_receipts", "_public_key_pem", "_key", "_tsa_urls") def __init__( self, - plan: PlanReceipt, + plan: Plan, public_key_pem: str = "", signing_key: Optional[SigningKey] = None, tsa_urls: Optional[list[str]] = None, ) -> None: self._plan = plan - self._receipts: list[NotarisedReceipt] = [] + self._receipts: list[Receipt] = [] self._public_key_pem = public_key_pem self._key = signing_key self._tsa_urls = tsa_urls or DEFAULT_TSA_URLS @property - def plan(self) -> PlanReceipt: + def plan(self) -> Plan: return self._plan @property - def receipts(self) -> list[NotarisedReceipt]: + def receipts(self) -> list[Receipt]: return list(self._receipts) - def add(self, receipt: NotarisedReceipt) -> None: + def add(self, receipt: Receipt) -> None: self._receipts.append(receipt) def export(self, output_dir: Path, certs_dir: Optional[Path] = None) -> Path: output_dir.mkdir(parents=True, exist_ok=True) - ts = _utc_now().strftime("%Y%m%d_%H%M%S") - zip_path = output_dir / f"agentmint_evidence_{ts}.zip" - + zip_path = output_dir / ("agentmint_evidence_%s.zip" % _utc_now().strftime("%Y%m%d_%H%M%S")) certs_dir = certs_dir or Path(tempfile.mkdtemp(prefix="agentmint_certs_")) ca_paths = self._fetch_certs_safe(certs_dir) - with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: - self._write_plan(zf) - self._write_receipts(zf) - self._write_index(zf) - self._write_public_key(zf) - self._write_certs(zf, ca_paths) - self._write_verify_script(zf) - self._write_verify_sigs_script(zf) + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive: + self._write_plan(archive) + self._write_receipts(archive) + self._write_index(archive) + self._write_public_key(archive) + self._write_certs(archive, ca_paths) + archive.writestr("VERIFY.sh", _build_verify_script(self._receipts)) + archive.writestr("verify_sigs.py", _VERIFY_SIGS_PY) self._set_verify_executable(zip_path) return zip_path - def _write_plan(self, zf: zipfile.ZipFile) -> None: - zf.writestr("plan.json", json.dumps(self._plan.to_dict(), indent=2)) - - def _write_receipts(self, zf: zipfile.ZipFile) -> None: - for r in self._receipts: - zf.writestr(f"receipts/{r.id}.json", r.to_json()) - if r.timestamp_result: - zf.writestr(f"receipts/{r.id}.tsq", r.timestamp_result.tsq) - zf.writestr(f"receipts/{r.id}.tsr", r.timestamp_result.tsr) - - def _write_index(self, zf: zipfile.ZipFile) -> None: - in_count = sum(1 for r in self._receipts if r.in_policy) - out_count = len(self._receipts) - in_count - + def _write_plan(self, archive: zipfile.ZipFile) -> None: + archive.writestr("plan.json", self._plan.to_json()) + + def _write_receipts(self, archive: zipfile.ZipFile) -> None: + for receipt in self._receipts: + archive.writestr("receipts/%s.json" % receipt.id, receipt.to_json()) + if ( + receipt.timestamp_result + and receipt.timestamp_result.tsq + and receipt.timestamp_result.tsr + ): + archive.writestr("receipts/%s.tsq" % receipt.id, receipt.timestamp_result.tsq) + archive.writestr("receipts/%s.tsr" % receipt.id, receipt.timestamp_result.tsr) + + def _write_index(self, archive: zipfile.ZipFile) -> None: + in_count = sum(1 for receipt in self._receipts if receipt.in_policy) entries = [] - for r in self._receipts: - has_ts = r.timestamp_result is not None + for receipt in self._receipts: + has_ts = bool(receipt.timestamp_result and receipt.timestamp_result.tsr) entries.append( { - "receipt_id": r.id, - "short_id": r.short_id, - "action": r.action, - "agent": r.agent, - "in_policy": r.in_policy, - "policy_reason": r.policy_reason, - "observed_at": r.observed_at, - "previous_receipt_hash": r.previous_receipt_hash, - "tsr_file": f"receipts/{r.id}.tsr" if has_ts else None, + "receipt_id": receipt.id, + "short_id": receipt.short_id, + "action": receipt.action, + "agent": receipt.agent, + "in_policy": receipt.in_policy, + "policy_reason": receipt.policy_reason, + "observed_at": receipt.observed_at, + "previous_receipt_hash": receipt.previous_receipt_hash, + "tsr_file": "receipts/%s.tsr" % receipt.id if has_ts else None, } ) @@ -513,67 +330,42 @@ def _write_index(self, zf: zipfile.ZipFile) -> None: "key_id": self._plan.key_id, "total_receipts": len(self._receipts), "in_policy_count": in_count, - "out_of_policy_count": out_count, + "out_of_policy_count": len(self._receipts) - in_count, "aiuc_controls": list(AIUC_CONTROLS), "receipts": entries, } - # Chain root hash + signature + timestamp chain_result = verify_chain(self._receipts) chain_info: dict[str, Any] = { "valid": chain_result.valid, "length": chain_result.length, "root_hash": chain_result.root_hash, } - - if chain_result.root_hash and self._key: - chain_info["root_signature"] = _sign( - self._key, - { - "type": "chain_root", - "root_hash": chain_result.root_hash, - "length": chain_result.length, - "plan_id": self._plan.id, - }, - ) - - # Optional: timestamp the chain root - try: - root_bytes = chain_result.root_hash.encode() - ts_result = _timestamp_with_fallback(root_bytes, self._tsa_urls) - zf.writestr("chain_root.tsq", ts_result.tsq) - zf.writestr("chain_root.tsr", ts_result.tsr) - chain_info["root_timestamp"] = { - "tsa_url": ts_result.tsa_url, - "tsq_file": "chain_root.tsq", - "tsr_file": "chain_root.tsr", - } - except (TimestampError, Exception): - pass # graceful degradation - + if chain_result.root_hash and self._key is not None: + chain_payload = { + "type": "chain_root", + "root_hash": chain_result.root_hash, + "length": chain_result.length, + "plan_id": self._plan.id, + } + chain_info["root_signature"] = _sign_payload(self._key, JCSSerializer(), chain_payload) index["chain"] = chain_info - zf.writestr("receipt_index.json", json.dumps(index, indent=2)) + archive.writestr("receipt_index.json", json.dumps(index, indent=2)) - def _write_public_key(self, zf: zipfile.ZipFile) -> None: + def _write_public_key(self, archive: zipfile.ZipFile) -> None: if self._public_key_pem: - zf.writestr("public_key.pem", self._public_key_pem) + archive.writestr("public_key.pem", self._public_key_pem) def _write_certs( self, - zf: zipfile.ZipFile, + archive: zipfile.ZipFile, ca_paths: Optional[tuple[Path, Path]], ) -> None: if not ca_paths: return cacert, tsa_cert = ca_paths - zf.write(str(cacert), "freetsa_cacert.pem") - zf.write(str(tsa_cert), "freetsa_tsa.crt") - - def _write_verify_script(self, zf: zipfile.ZipFile) -> None: - zf.writestr("VERIFY.sh", _build_verify_script(self._receipts)) - - def _write_verify_sigs_script(self, zf: zipfile.ZipFile) -> None: - zf.writestr("verify_sigs.py", _VERIFY_SIGS_PY) + archive.write(str(cacert), "freetsa_cacert.pem") + archive.write(str(tsa_cert), "freetsa_tsa.crt") @staticmethod def _fetch_certs_safe(certs_dir: Path) -> Optional[tuple[Path, Path]]: @@ -585,221 +377,126 @@ def _fetch_certs_safe(certs_dir: Path) -> Optional[tuple[Path, Path]]: @staticmethod def _set_verify_executable(zip_path: Path) -> None: tmp_path = zip_path.with_suffix(".tmp.zip") - with zipfile.ZipFile(zip_path, "r") as zin: - with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zout: - for item in zin.infolist(): - data = zin.read(item.filename) + with zipfile.ZipFile(zip_path, "r") as source: + with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as target: + for item in source.infolist(): + data = source.read(item.filename) if item.filename == "VERIFY.sh": perms = ( stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH ) item.external_attr = perms << 16 - zout.writestr(item, data) + target.writestr(item, data) shutil.move(str(tmp_path), str(zip_path)) -# ── Timestamp with fallback ─────────────────────────────── - - -def _timestamp_with_fallback( - data: bytes, - tsa_urls: Optional[list[str]] = None, -) -> TimestampResult: - """Try each TSA URL in order, return first success.""" - urls = tsa_urls or DEFAULT_TSA_URLS - if len(urls) == 1: - # Fast path — no fallback needed - return ts_timestamp(data, url=urls[0]) - last_error: Optional[Exception] = None - for url in urls: - try: - return ts_timestamp(data, url=url) - except TimestampError as e: - last_error = e - continue - raise TimestampError(f"all TSA endpoints failed, last error: {last_error}") - - -# ── Notary ───────────────────────────────────────────────── - -_CHAIN_STATE_FILE = "chain_state.json" - - -# ── Policy hash ─────────────────────────────────────────── - - -def _compute_policy_hash(plan: PlanReceipt) -> str: - """SHA-256 of canonical(scope + checkpoints + delegates_to).""" - policy_data = { - "scope": list(plan.scope), - "checkpoints": list(plan.checkpoints), - "delegates_to": list(plan.delegates_to), - } - return hashlib.sha256(_canonical_json(policy_data)).hexdigest() - - -# ── Scope intersection for delegation ───────────────────── - - -def intersect_scopes( - parent_scope: Sequence[str], - requested: Sequence[str], -) -> tuple[str, ...]: - """Compute the intersection of parent and requested scopes. - - Rules: - - Exact match: keep - - Child more specific than parent wildcard: keep child - - Parent more specific than child wildcard: keep parent - - No overlap: skip - - Returns empty tuple if no intersection (= deny). - """ - result: list[str] = [] - for child in requested: - for parent in parent_scope: - if child == parent: - if child not in result: - result.append(child) - elif matches_pattern(child, parent): - # child is more specific, parent is wildcard — keep child - if child not in result: - result.append(child) - elif matches_pattern(parent, child): - # parent is more specific, child is wildcard — keep parent - if parent not in result: - result.append(parent) - return tuple(result) - - -def _load_chain_state(key_dir: Optional[Path]) -> dict[str, Optional[str]]: - """Load persisted chain hashes. Returns empty dict if ephemeral or missing.""" - if key_dir is None: - return {} - path = key_dir / _CHAIN_STATE_FILE - if not path.exists(): - return {} - try: - data = json.loads(path.read_text()) - if not isinstance(data, dict): - return {} - # Validate: all keys are strings, all values are str or None - return { - k: v - for k, v in data.items() - if isinstance(k, str) and (v is None or isinstance(v, str)) - } - except (json.JSONDecodeError, OSError): - return {} - - -def _save_chain_state(key_dir: Optional[Path], chain_hashes: dict[str, Optional[str]]) -> None: - """Atomic write of chain state. No-op in ephemeral mode.""" - if key_dir is None: - return - key_dir.mkdir(parents=True, exist_ok=True) - path = key_dir / _CHAIN_STATE_FILE - tmp = path.with_suffix(".tmp") - tmp.write_text(json.dumps(chain_hashes, indent=2)) - os.chmod(tmp, 0o600) - os.replace(tmp, path) - - class Notary: - """Observe, evaluate, sign, timestamp. - - Usage: - notary = Notary() - plan = notary.create_plan(user="admin@co.com", ...) - receipt = notary.notarise(action="tts:standard:abc", ...) - zip_path = notary.export_evidence(Path("./evidence")) - - Improvement 4.1: key parameter for persistent keys. - Improvement 4.2: per-plan chain isolation. - Improvement 4.5: tsa_urls for fallback TSA. - """ + """Observe actions and produce signed evidence receipts.""" __slots__ = ( - "_key", - "_vk", - "_key_id", + "_plan", + "_agent", + "_key_provider", "_key_dir", - "_package", - "_chain_hashes", - "_tsa_urls", - "_circuit_breaker", "_sink", - "_mode", + "_policy", + "_timestamper", + "_serializer", + "_redactor", + "_plan_store", + "_chain_store", + "_package", "_session_id", "_session_policy", "_session_counters", "_session_trajectory", + "_mode", "_child_plans", ) def __init__( self, - key: str | Path | None = None, - tsa_urls: list[str] | None = None, - circuit_breaker: Any = None, + plan: Optional[Plan] = None, + agent: str = "default-agent", + key: Any = None, sink: Any = None, + policy: Any = None, + timestamper: Any = None, + serializer: Any = None, + redactor: Any = None, + plan_store: Any = None, + chain_store: Any = None, session_policy: Optional[dict[str, Any]] = None, mode: EnforceMode | str = EnforceMode.ENFORCE, + tsa_urls: Optional[list[str]] = None, ) -> None: - # Key persistence via KeyStore - if key is None: - # Ephemeral — for demos and quickstart - self._key = SigningKey.generate() - self._key_dir: Optional[Path] = None - elif isinstance(key, (str, Path)): - from .keystore import KeyStore - - self._key_dir = Path(key) - ks = KeyStore(self._key_dir) - self._key = ks.signing_key + self._agent = agent + self._key_provider, self._key_dir = self._coerce_key_provider(key) + self._sink = sink or FileSink() + self._policy = policy or ScopeMatchPolicy() + self._serializer = serializer or JCSSerializer() + self._redactor = redactor or NoRedactor() + if timestamper is not None: + self._timestamper = timestamper + elif tsa_urls: + self._timestamper = RFC3161Timestamper(tsa_urls[0]) else: - raise NotaryError(f"key must be a string path or None, got {type(key).__name__}") - - self._vk = self._key.verify_key - self._key_id = _derive_key_id(self._vk) - self._package: Optional[EvidencePackage] = None - # Per-plan chain isolation - self._chain_hashes: dict[str, Optional[str]] = _load_chain_state(self._key_dir) - self._tsa_urls = tsa_urls or DEFAULT_TSA_URLS - self._mode = EnforceMode(mode) if isinstance(mode, str) else mode - self._circuit_breaker = circuit_breaker - # Sink: normalize to list, isolate failures between sinks - if sink is None: - self._sink: list[Any] = [] - elif isinstance(sink, (list, tuple)): - self._sink = list(sink) + self._timestamper = NoTimestamper() + self._plan_store = plan_store or _MemoryPlanStore() + if chain_store is not None: + self._chain_store = chain_store + elif self._key_dir is not None: + self._chain_store = _FileChainStore(self._key_dir) else: - self._sink = [sink] - # Session context - self._session_id: str = str(uuid.uuid4()) - self._session_policy: Optional[dict[str, Any]] = session_policy + self._chain_store = _MemoryChainStore() + self._plan = plan + self._package: Optional[EvidencePackage] = None + self._session_id = str(uuid.uuid4()) + self._session_policy = session_policy or {} self._session_counters: dict[str, int] = {} - self._session_trajectory: deque = deque(maxlen=20) - # Child plan tracking (delegation) + self._session_trajectory: deque[dict[str, Any]] = deque(maxlen=20) + self._mode = EnforceMode(mode) if isinstance(mode, str) else mode self._child_plans: dict[str, list[str]] = {} - @property - def key_id(self) -> str: - """Stable key identifier for revocation support.""" - return self._key_id + @classmethod + def from_keystore(cls, path: str | os.PathLike[str], **overrides: Any) -> "Notary": + return cls(key=Path(path), **overrides) @property - def mode(self) -> "EnforceMode": - """Current enforcement mode.""" - return self._mode + def key_id(self) -> str: + return self._key_provider.key_id() @property def verify_key(self) -> VerifyKey: - return self._vk + return self._key_provider.verify_key() @property def verify_key_hex(self) -> str: - return self._vk.encode(encoder=HexEncoder).decode("ascii") + return self.verify_key.encode(encoder=HexEncoder).decode("ascii") + + @property + def session_id(self) -> str: + return self._session_id + + @property + def mode(self) -> EnforceMode: + return self._mode + + def _coerce_key_provider(self, key: Any) -> tuple[Any, Optional[Path]]: + if key is None: + return _EphemeralKeyProvider(), None + if hasattr(key, "signing_key") and hasattr(key, "verify_key"): + return key, getattr(key, "path", None) + if isinstance(key, (str, Path)): + provider = FileKeyProvider(Path(key)) + return provider, provider.path + raise NotaryError("key must be a key provider, path, or None") + + def _load_or_create_default_plan(self) -> Plan: + if self._plan is not None: + return self._plan + self._plan = self.create_plan(user=self._agent, action="default-plan", scope=["*"]) + return self._plan def create_plan( self, @@ -809,336 +506,221 @@ def create_plan( checkpoints: list[str] | None = None, delegates_to: list[str] | None = None, ttl_seconds: int = DEFAULT_TTL, - ) -> PlanReceipt: - """Create a signed plan receipt. Initializes the chain for this plan.""" + ) -> Plan: user = _require_non_empty_string(user, "user", MAX_IDENTITY_LEN) action = _require_non_empty_string(action, "action", MAX_ACTION_LEN) - scope_t = _require_string_list(scope, "scope") - checkpoints_t = _require_string_list(checkpoints, "checkpoints") - delegates_t = _require_string_list(delegates_to, "delegates_to") - ttl = _clamp_ttl(ttl_seconds) - - now = _utc_now() - plan_id = str(uuid.uuid4()) - issued_at = now.isoformat() - expires_at = (now + timedelta(seconds=ttl)).isoformat() - - # Build plan with placeholder signature — signable_dict() is - # the single source of truth for what gets signed. - unsigned = PlanReceipt( - id=plan_id, + scope_tuple = _require_string_list(scope, "scope") + checkpoints_tuple = _require_string_list(checkpoints, "checkpoints") + delegates_tuple = _require_string_list(delegates_to, "delegates_to") + expires_at = (_utc_now() + timedelta(seconds=_clamp_ttl(ttl_seconds))).isoformat() + + plan = Plan.create( + name=action, + scope=scope_tuple, + key_provider=self._key_provider, + delegates_to=delegates_tuple, + expires_at=expires_at, user=user, action=action, - scope=scope_t, - checkpoints=checkpoints_t, - delegates_to=delegates_t, - issued_at=issued_at, - expires_at=expires_at, - signature="", - key_id=self._key_id, + checkpoints=checkpoints_tuple, ) - - signature = _sign(self._key, unsigned.signable_dict()) - - plan = replace(unsigned, signature=signature) - - # Initialize chain for this plan - self._chain_hashes[plan_id] = None - _save_chain_state(self._key_dir, self._chain_hashes) + self._plan_store.save(plan.id, plan.to_dict()) + self._chain_store.append(plan.id, None) + self._plan = plan self._package = EvidencePackage( plan, - _public_key_pem(self._vk), - signing_key=self._key, - tsa_urls=self._tsa_urls, + _public_key_pem(self.verify_key), + signing_key=self._key_provider.signing_key(), ) return plan - def notarise( + def _policy_with_session( self, action: str, agent: str, - plan: PlanReceipt, - evidence: dict[str, Any], + evidence: Mapping[str, Any], + plan: Plan, + ) -> tuple[bool, str, Optional[str], Optional[bool]]: + decision = self._policy.evaluate(action, {**evidence, "_agent": agent}, plan) + session_escalation = None + for pattern, limits in self._session_policy.items(): + if not hasattr(limits, "get"): + continue + if matches_pattern(action, pattern): + count = self._session_counters.get(pattern, 0) + deny_after = limits.get("deny_after") + escalate_after = limits.get("escalate_after") + if deny_after is not None and count >= deny_after: + session_escalation = "denied:%s:%d/%d" % (pattern, count, deny_after) + elif escalate_after is not None and count >= escalate_after: + session_escalation = "escalate:%s:%d/%d" % (pattern, count, escalate_after) + + final_in_policy = decision.in_policy + final_reason = decision.reason + if session_escalation and session_escalation.startswith("denied:"): + final_in_policy = False + final_reason = session_escalation + + original_verdict: Optional[bool] = None + if self._mode is not EnforceMode.ENFORCE: + original_verdict = final_in_policy + final_in_policy = True + if original_verdict is False: + final_reason = "%s:%s" % (self._mode.value, final_reason) + + return final_in_policy, final_reason, session_escalation, original_verdict + + def _advance_session( + self, action: str, agent: str, in_policy: bool, observed_at: str + ) -> tuple[dict[str, Any], ...]: + entry = { + "action": action, + "agent": agent, + "in_policy": in_policy, + "observed_at": observed_at, + } + self._session_trajectory.append(entry) + for pattern in self._session_policy: + if matches_pattern(action, pattern): + self._session_counters[pattern] = self._session_counters.get(pattern, 0) + 1 + return tuple(list(self._session_trajectory)[-5:]) + + def _timestamp_for( + self, payload: Mapping[str, Any], signature: str, enable_timestamp: bool + ) -> Optional[TimestampRecord]: + if not enable_timestamp: + return None + signed_payload = dict(payload) + signed_payload["signature"] = signature + return self._timestamper.timestamp(self._serializer.canonicalize(signed_payload)) + + def notarise( + self, + action: str, + agent: Optional[str] = None, + plan: Optional[Plan] = None, + evidence: Optional[Mapping[str, Any]] = None, enable_timestamp: bool = True, agent_key: Optional[SigningKey] = None, - output: Optional[dict[str, Any]] = None, + output: Optional[Mapping[str, Any]] = None, reasoning: Optional[str] = None, - ) -> NotarisedReceipt: - """Observe an action and produce signed evidence. - - Each receipt includes the SHA-256 hash of the previous receipt's - signed payload, forming a tamper-evident chain per plan. - """ + ) -> Receipt: + if evidence is None and isinstance(agent, Mapping): + evidence = agent + agent = None + plan = plan or self._load_or_create_default_plan() + agent = _require_non_empty_string(agent or self._agent, "agent", MAX_IDENTITY_LEN) action = _require_non_empty_string(action, "action", MAX_ACTION_LEN) - agent = _require_non_empty_string(agent, "agent", MAX_IDENTITY_LEN) - evidence = _require_evidence(evidence) - - # Circuit breaker — check before policy eval - if self._circuit_breaker is not None: - br = self._circuit_breaker.check(agent) - if not br.is_allowed: - # Short-circuit: build a denied receipt without policy eval - return self._make_denied_receipt( - action, - agent, - plan, - evidence, - f"circuit_breaker:{br.reason}", - enable_timestamp, - ) - - evaluation = evaluate_policy( - action=action, - agent=agent, - plan_scope=plan.scope, - plan_checkpoints=plan.checkpoints, - plan_delegates=plan.delegates_to, - plan_expired=plan.is_expired, + redacted_evidence, _modified_paths = self._redactor.redact( + _require_evidence(evidence or {}) ) - - evidence_bytes = _canonical_json(evidence) + evidence_bytes = self._serializer.canonicalize(redacted_evidence) evidence_hash = hashlib.sha512(evidence_bytes).hexdigest() observed_at = _utc_now().isoformat() - receipt_id = str(uuid.uuid4()) - - # Per-plan chain linking - prev_hash = self._chain_hashes.get(plan.id) + in_policy, policy_reason, session_escalation, original_verdict = self._policy_with_session( + action, + agent, + redacted_evidence, + plan, + ) + previous_hash = self._chain_store.previous_hash(plan.id) + if previous_hash == "": + previous_hash = None - # Agent co-signature: agent signs the evidence hash - agent_sig = "" - agent_kid = "" + agent_signature = "" + agent_key_id = "" if agent_key is not None: - agent_sig = agent_key.sign(evidence_bytes).signature.hex() - agent_kid = _derive_key_id(agent_key.verify_key) + agent_signature = agent_key.sign(evidence_bytes).signature.hex() + agent_key_id = _derive_key_id(agent_key.verify_key) - # Policy + output hashes - policy_hash = _compute_policy_hash(plan) output_hash = "" if output is not None: - output_bytes = _canonical_json(output) - output_hash = hashlib.sha256(output_bytes).hexdigest() + output_hash = hashlib.sha256(self._serializer.canonicalize(output)).hexdigest() - # Reasoning hash - reasoning_hash: Optional[str] = None + reasoning_hash = None if reasoning is not None: reasoning_hash = hashlib.sha256(reasoning.encode("utf-8")).hexdigest() - # Session escalation check - session_escalation: Optional[str] = None - if self._session_policy: - for pattern, limits in self._session_policy.items(): - if matches_pattern(action, pattern): - count = self._session_counters.get(pattern, 0) - deny_after = limits.get("deny_after") - escalate_after = limits.get("escalate_after") - if deny_after is not None and count >= deny_after: - session_escalation = f"denied:{pattern}:{count}/{deny_after}" - elif escalate_after is not None and count >= escalate_after: - session_escalation = f"escalate:{pattern}:{count}/{escalate_after}" - - # Session deny overrides policy evaluation - is_session_denied = session_escalation is not None and session_escalation.startswith( - "denied:" - ) - final_in_policy = False if is_session_denied else evaluation.in_policy - final_reason = session_escalation if is_session_denied else evaluation.reason - - # Enforcement mode: shadow/warn evaluate fully but never block - mode_str = self._mode.value - original_verdict: Optional[bool] = None - if self._mode is not EnforceMode.ENFORCE: - original_verdict = final_in_policy - final_in_policy = True - if not original_verdict: - final_reason = f"{mode_str}:{final_reason}" - - # Build trajectory entry - trajectory_entry = { - "action": action, - "agent": agent, - "in_policy": final_in_policy, - "observed_at": observed_at, - } - self._session_trajectory.append(trajectory_entry) - recent_trajectory = tuple(self._session_trajectory)[-5:] - - # Build receipt with placeholder signature — signable_dict() is - # the single source of truth for what gets signed. - unsigned = NotarisedReceipt( - id=receipt_id, + trajectory = self._advance_session(action, agent, in_policy, observed_at) + receipt = Receipt( + id=str(uuid.uuid4()), plan_id=plan.id, agent=agent, action=action, - in_policy=final_in_policy, - policy_reason=final_reason, - evidence_hash=evidence_hash, - evidence=evidence, + in_policy=in_policy, + policy_reason=policy_reason, + evidence_hash_sha512=evidence_hash, + evidence=redacted_evidence, observed_at=observed_at, + key_id=self.key_id, signature="", - previous_receipt_hash=prev_hash, + previous_receipt_hash=previous_hash, plan_signature=plan.signature, - key_id=self._key_id, - agent_signature=agent_sig, - agent_key_id=agent_kid, - policy_hash=policy_hash, + agent_signature=agent_signature, + agent_key_id=agent_key_id, + policy_hash=_compute_policy_hash(plan), output_hash=output_hash, session_id=self._session_id, - session_trajectory=tuple(recent_trajectory), + session_trajectory=trajectory, session_escalation=session_escalation, reasoning_hash=reasoning_hash, - mode=mode_str, + compliance_tags=AIUC_CONTROLS, + aiuc_controls=AIUC_CONTROLS, + mode=self._mode.value, original_verdict=original_verdict, ) + signature = _sign_payload( + self._key_provider.signing_key(), + self._serializer, + receipt.signable_dict(), + ) + timestamp = self._timestamp_for(receipt.signable_dict(), signature, enable_timestamp) + receipt = replace(receipt, signature=signature, timestamp=timestamp) - signature = _sign(self._key, unsigned.signable_dict()) - - ts_result = None - if enable_timestamp: - signed_payload = _canonical_json({**unsigned.signable_dict(), "signature": signature}) - try: - ts_result = _timestamp_with_fallback(signed_payload, self._tsa_urls) - except TimestampError as e: - raise NotaryError( - f"timestamping failed: {e}\n" - f" Receipt was signed but not anchored to wall-clock time.\n" - f" Pass enable_timestamp=False to skip." - ) from e - - # Reconstruct with real signature (frozen dataclass) - receipt = replace(unsigned, signature=signature, timestamp_result=ts_result) - - # Update chain hash - signed_payload_bytes = _canonical_json({**unsigned.signable_dict(), "signature": signature}) - self._chain_hashes[plan.id] = hashlib.sha256(signed_payload_bytes).hexdigest() - _save_chain_state(self._key_dir, self._chain_hashes) + self._sink.write(receipt.id, receipt.to_json().encode("utf-8"), {"plan_id": plan.id}) + self._chain_store.append(plan.id, receipt.canonical_hash()) if self._package and self._package.plan.id == plan.id: self._package.add(receipt) - # Record call in circuit breaker - if self._circuit_breaker is not None: - self._circuit_breaker.record(agent) - - # Emit to sink - for _sink in self._sink: - try: - _sink.emit(receipt) - except Exception: - pass - - # Update session counters - if self._session_policy: - for pattern in self._session_policy: - if matches_pattern(action, pattern): - self._session_counters[pattern] = self._session_counters.get(pattern, 0) + 1 - return receipt - def verify_receipt(self, receipt: NotarisedReceipt) -> bool: - return _verify_signature(self._vk, receipt.signable_dict(), receipt.signature) - - def verify_plan(self, plan: PlanReceipt) -> bool: - return _verify_signature(self._vk, plan.signable_dict(), plan.signature) - - def _make_denied_receipt( - self, - action: str, - agent: str, - plan: PlanReceipt, - evidence: dict[str, Any], - reason: str, - enable_timestamp: bool, - ) -> NotarisedReceipt: - """Build a denied receipt (circuit breaker or session deny).""" - evidence_bytes = _canonical_json(evidence) - evidence_hash = hashlib.sha512(evidence_bytes).hexdigest() - observed_at = _utc_now().isoformat() - receipt_id = str(uuid.uuid4()) - prev_hash = self._chain_hashes.get(plan.id) - policy_hash = _compute_policy_hash(plan) - - mode_str = self._mode.value - _den_verdict: Optional[bool] = None - _den_policy = False - _den_reason = reason - if self._mode is not EnforceMode.ENFORCE: - _den_verdict = False - _den_policy = True - _den_reason = f"{mode_str}:{reason}" + async def anotarise(self, *args: Any, **kwargs: Any) -> Receipt: + return await asyncio.to_thread(self.notarise, *args, **kwargs) - unsigned = NotarisedReceipt( - id=receipt_id, - plan_id=plan.id, - agent=agent, - action=action, - in_policy=_den_policy, - policy_reason=_den_reason, - evidence_hash=evidence_hash, - evidence=evidence, - observed_at=observed_at, - signature="", - previous_receipt_hash=prev_hash, - plan_signature=plan.signature, - key_id=self._key_id, - policy_hash=policy_hash, - session_id=self._session_id, - mode=mode_str, - original_verdict=_den_verdict, + def verify_receipt(self, receipt: Receipt) -> bool: + return _verify_signature( + self.verify_key, + self._serializer, + receipt.signable_dict(), + receipt.signature, ) - signature = _sign(self._key, unsigned.signable_dict()) - - ts_result = None - if enable_timestamp: - signed_payload = _canonical_json({**unsigned.signable_dict(), "signature": signature}) - try: - ts_result = _timestamp_with_fallback(signed_payload, self._tsa_urls) - except TimestampError: - pass # graceful degradation for denied receipts - - receipt = replace(unsigned, signature=signature, timestamp_result=ts_result) - - signed_payload_bytes = _canonical_json({**unsigned.signable_dict(), "signature": signature}) - self._chain_hashes[plan.id] = hashlib.sha256(signed_payload_bytes).hexdigest() - _save_chain_state(self._key_dir, self._chain_hashes) - - if self._package and self._package.plan.id == plan.id: - self._package.add(receipt) - - for _sink in self._sink: - try: - _sink.emit(receipt) - except Exception: - pass - return receipt - - # Multi-agent delegation + def verify_plan(self, plan: Plan) -> bool: + return _verify_signature( + self.verify_key, + self._serializer, + plan.signable_dict(), + plan.signature, + ) def delegate_to_agent( self, - parent_plan: PlanReceipt, + parent_plan: Plan, child_agent: str, requested_scope: list[str], action: str = "", checkpoints: list[str] | None = None, ttl_seconds: int = DEFAULT_TTL, - ) -> PlanReceipt: - """Create a child plan with scope intersected from parent. - - Returns a new PlanReceipt whose scope is the intersection of - parent_plan.scope and requested_scope. Raises NotaryError if - the intersection is empty (no delegable permissions). - """ + ) -> Plan: child_agent = _require_non_empty_string(child_agent, "child_agent", MAX_IDENTITY_LEN) - requested_t = _require_string_list(requested_scope, "requested_scope") - - effective_scope = intersect_scopes(parent_plan.scope, requested_t) + requested_tuple = _require_string_list(requested_scope, "requested_scope") + effective_scope = intersect_scopes(parent_plan.scope, requested_tuple) if not effective_scope: raise NotaryError( - f"scope intersection is empty — parent scope {list(parent_plan.scope)} " - f"does not overlap with requested {list(requested_t)}" + "scope intersection is empty — parent scope %s does not overlap with requested %s" + % (list(parent_plan.scope), list(requested_tuple)) ) - child_plan = self.create_plan( user=parent_plan.user, action=action or parent_plan.action, @@ -1147,45 +729,36 @@ def delegate_to_agent( delegates_to=[child_agent], ttl_seconds=ttl_seconds, ) - - # Track parent → child relationship - if parent_plan.id not in self._child_plans: - self._child_plans[parent_plan.id] = [] - self._child_plans[parent_plan.id].append(child_plan.id) - + self._child_plans.setdefault(parent_plan.id, []).append(child_plan.id) return child_plan def audit_tree(self, plan_id: str) -> dict[str, Any]: - """Return the delegation tree rooted at plan_id.""" - children = self._child_plans.get(plan_id, []) return { "plan_id": plan_id, - "children": [self.audit_tree(cid) for cid in children], + "children": [ + self.audit_tree(child_id) for child_id in self._child_plans.get(plan_id, []) + ], } - @property - def session_id(self) -> str: - """Current session identifier.""" - return self._session_id + def bootstrap(self) -> None: + if self._key_dir is not None: + self._key_dir.mkdir(parents=True, exist_ok=True) + self._key_provider.signing_key() + if self._plan is None: + self._load_or_create_default_plan() + if self._key_dir is not None and self._plan is not None: + config_path = self._key_dir / "agentmint.json" + config_path.write_text( + json.dumps({"default_plan_id": self._plan.id, "key_id": self.key_id}, indent=2) + ) - def export_evidence( - self, - output_dir: Path, - certs_dir: Optional[Path] = None, - ) -> Path: + def export_evidence(self, output_dir: Path, certs_dir: Optional[Path] = None) -> Path: if not self._package: raise NotaryError("no plan created — call create_plan() first") return self._package.export(output_dir, certs_dir) -# ── VERIFY.sh (timestamps only — pure OpenSSL, zero dependencies) ── - - -def _build_verify_script(receipts: list[NotarisedReceipt]) -> str: - """Generate VERIFY.sh — checks RFC 3161 timestamps with OpenSSL. - - For Ed25519 signature verification, see verify_sigs.py in the same package. - """ +def _build_verify_script(receipts: list[Receipt]) -> str: lines = [ "#!/bin/bash", "# AgentMint Evidence Verification — RFC 3161 Timestamps", @@ -1202,37 +775,34 @@ def _build_verify_script(receipts: list[NotarisedReceipt]) -> str: "", ] - for r in receipts: - rid = r.id - has_ts = r.timestamp_result is not None - - lines.append(f'echo "── Receipt {r.short_id} ──"') - lines.append(f'echo " Action: {r.action}"') - lines.append(f'echo " Agent: {r.agent}"') - lines.append(f'echo " In Policy: {r.in_policy}"') - lines.append(f'echo " Observed: {r.observed_at}"') - - if not r.in_policy: - reason_escaped = r.policy_reason.replace('"', '\\"').replace("'", "'\\''") - lines.append(f'echo " ⚠ FLAGGED: {reason_escaped}"') + for receipt in receipts: + lines.append('echo "── Receipt %s ──"' % receipt.short_id) + lines.append('echo " Action: %s"' % receipt.action) + lines.append('echo " Agent: %s"' % receipt.agent) + lines.append('echo " In Policy: %s"' % receipt.in_policy) + lines.append('echo " Observed: %s"' % receipt.observed_at) + if not receipt.in_policy: + lines.append('echo " ⚠ FLAGGED: %s"' % receipt.policy_reason.replace('"', '\\"')) lines.append("FLAGGED=$((FLAGGED + 1))") - - if has_ts: - lines.append(f"if openssl ts -verify \\") - lines.append(f' -in "receipts/{rid}.tsr" \\') - lines.append(f' -queryfile "receipts/{rid}.tsq" \\') - lines.append(f' -CAfile "freetsa_cacert.pem" \\') - lines.append(f' -untrusted "freetsa_tsa.crt" \\') - lines.append(f" > /dev/null 2>&1; then") - lines.append(f' echo " Timestamp: ✓ verified"') - lines.append(f" VERIFIED=$((VERIFIED + 1))") - lines.append(f"else") - lines.append(f' echo " Timestamp: ✗ FAILED"') - lines.append(f" FAILED=$((FAILED + 1))") - lines.append(f"fi") + if receipt.timestamp_result and receipt.timestamp_result.tsr: + lines.extend( + [ + "if openssl ts -verify \\", + ' -in "receipts/%s.tsr" \\' % receipt.id, + ' -queryfile "receipts/%s.tsq" \\' % receipt.id, + ' -CAfile "freetsa_cacert.pem" \\', + ' -untrusted "freetsa_tsa.crt" \\', + " > /dev/null 2>&1; then", + ' echo " Timestamp: ✓ verified"', + " VERIFIED=$((VERIFIED + 1))", + "else", + ' echo " Timestamp: ✗ FAILED"', + " FAILED=$((FAILED + 1))", + "fi", + ] + ) else: lines.append('echo " Timestamp: (not requested)"') - lines.append("TOTAL=$((TOTAL + 1))") lines.append('echo ""') @@ -1249,12 +819,9 @@ def _build_verify_script(receipts: list[NotarisedReceipt]) -> str: "exit 0", ] ) - return "\n".join(lines) + "\n" -# ── verify_sigs.py (Ed25519 signatures — needs pynacl) ──── - _VERIFY_SIGS_PY = '''\ #!/usr/bin/env python3 """Verify Ed25519 signatures on all receipts. Requires: pip install pynacl""" @@ -1268,14 +835,14 @@ def _build_verify_script(receipts: list[NotarisedReceipt]) -> str: print("Install pynacl: pip install pynacl") sys.exit(1) -def canonical(d): - return json.dumps(d, sort_keys=True, separators=(",", ":")).encode() +def canonical(value): + from agentmint.providers.serializers import JCSSerializer + return JCSSerializer().canonicalize(value) def load_pem_public_key(path): lines = path.read_text().strip().split("\\n") b64 = "".join(lines[1:-1]) der = base64.b64decode(b64) - # SPKI prefix is 12 bytes, Ed25519 key is last 32 return VerifyKey(der[12:]) here = Path(__file__).parent @@ -1289,7 +856,6 @@ def load_pem_public_key(path): for rfile in sorted((here / "receipts").glob("*.json")): receipt = json.loads(rfile.read_text()) sig = bytes.fromhex(receipt["signature"]) - # Reconstruct signable dict (everything except signature and timestamp) signable = {k: v for k, v in receipt.items() if k not in ("signature", "timestamp")} try: vk.verify(canonical(signable), sig) @@ -1304,10 +870,3 @@ def load_pem_public_key(path): print(f"\\nSignatures: {ok} verified, {fail} failed") sys.exit(1 if fail else 0) ''' - - -# ── Utilities ────────────────────────────────────────────── - - -def _utc_now() -> datetime: - return datetime.now(timezone.utc) diff --git a/agentmint/plan.py b/agentmint/plan.py index 1d7d31e..58a9909 100644 --- a/agentmint/plan.py +++ b/agentmint/plan.py @@ -1,34 +1,150 @@ -"""Plan lifecycle model for the build spec. - -The current implementation lives in :class:`agentmint.notary.PlanReceipt`. -This stub reserves the future plan API; PR 2 will migrate lifecycle behavior -here while preserving the existing public APIs. -""" +"""Plan lifecycle model.""" from __future__ import annotations -from dataclasses import dataclass, field -from datetime import datetime -from typing import Any, Mapping, Optional +import json +import uuid +from dataclasses import dataclass, field, replace +from datetime import datetime, timezone +from typing import Any, Mapping, Optional, Sequence, Tuple + +from nacl.exceptions import BadSignatureError + +from .providers.serializers import JCSSerializer + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) @dataclass(frozen=True) class Plan: - """Placeholder plan model for scoped authorization lifecycle.""" + """Signed plan describing allowed scope.""" id: str - version: str + name: str + version: int parent_plan_id: Optional[str] - scope: Mapping[str, Any] = field(default_factory=dict) - created_at: datetime = field(default_factory=datetime.utcnow) - expires_at: Optional[datetime] = None + scope: Tuple[str, ...] = field(default_factory=tuple) + delegates_to: Tuple[str, ...] = field(default_factory=tuple) + created_at: str = field(default_factory=lambda: _utc_now().isoformat()) + expires_at: Optional[str] = None + signature: str = "" + key_id: str = "" + user: str = "" + action: str = "" + checkpoints: Tuple[str, ...] = field(default_factory=tuple) + + @property + def is_expired(self) -> bool: + return self.expires_at is not None and _utc_now() >= datetime.fromisoformat(self.expires_at) + + @property + def short_id(self) -> str: + return self.id[:8] + + @property + def issued_at(self) -> str: + return self.created_at + + @classmethod + def create( + cls, + name: str, + scope: Sequence[str], + key_provider: Any, + parent_plan_id: Optional[str] = None, + delegates_to: Optional[Sequence[str]] = None, + expires_at: Optional[str] = None, + user: str = "", + action: str = "", + checkpoints: Optional[Sequence[str]] = None, + ) -> "Plan": + plan = cls( + id=str(uuid.uuid4()), + name=name, + version=1, + parent_plan_id=parent_plan_id, + scope=tuple(scope), + delegates_to=tuple(delegates_to or ()), + created_at=_utc_now().isoformat(), + expires_at=expires_at, + signature="", + key_id=key_provider.key_id(), + user=user, + action=action, + checkpoints=tuple(checkpoints or ()), + ) + serializer = JCSSerializer() + signature = ( + key_provider.signing_key().sign(serializer.canonicalize(plan.signable_dict())).signature + ) + return replace(plan, signature=signature.hex()) + + @classmethod + def derive( + cls, + parent: "Plan", + key_provider: Any, + scope: Optional[Sequence[str]] = None, + delegates_to: Optional[Sequence[str]] = None, + name: Optional[str] = None, + ) -> "Plan": + child = cls.create( + name=name or parent.name, + scope=scope or parent.scope, + key_provider=key_provider, + parent_plan_id=parent.id, + delegates_to=delegates_to if delegates_to is not None else parent.delegates_to, + expires_at=parent.expires_at, + user=parent.user, + action=parent.action, + checkpoints=parent.checkpoints, + ) + child = replace(child, version=parent.version + 1, signature="") + serializer = JCSSerializer() + signature = ( + key_provider.signing_key() + .sign(serializer.canonicalize(child.signable_dict())) + .signature + ) + return replace(child, signature=signature.hex()) - def signable_payload(self) -> Mapping[str, Any]: - """Return the deterministic payload that will be signed.""" + def signable_dict(self) -> Mapping[str, Any]: + payload = { + "id": self.id, + "name": self.name, + "version": self.version, + "parent_plan_id": self.parent_plan_id, + "scope": list(self.scope), + "delegates_to": list(self.delegates_to), + "created_at": self.created_at, + "expires_at": self.expires_at, + "key_id": self.key_id, + } + if self.user: + payload["user"] = self.user + if self.action: + payload["action"] = self.action + if self.checkpoints: + payload["checkpoints"] = list(self.checkpoints) + return payload - raise NotImplementedError("Plan signing payload migration comes in PR 2") + def to_dict(self) -> Mapping[str, Any]: + payload = dict(self.signable_dict()) + payload["signature"] = self.signature + return payload - def is_expired(self, now: Optional[datetime] = None) -> bool: - """Return whether the plan has expired at the given time.""" + def to_json(self, indent: int = 2) -> str: + return json.dumps(self.to_dict(), indent=indent) - raise NotImplementedError("Plan expiry behavior migration comes in PR 2") + def verify(self, key_provider: Any) -> bool: + serializer = JCSSerializer() + try: + key_provider.verify_key().verify( + serializer.canonicalize(self.signable_dict()), + bytes.fromhex(self.signature), + ) + return True + except (BadSignatureError, ValueError): + return False diff --git a/agentmint/policy.py b/agentmint/policy.py new file mode 100644 index 0000000..cc1580e --- /dev/null +++ b/agentmint/policy.py @@ -0,0 +1,60 @@ +"""Policy helpers for receipt evaluation.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping, Sequence + +from .patterns import matches_pattern + + +@dataclass(frozen=True) +class PolicyDecision: + """Result of evaluating an action against a plan.""" + + in_policy: bool + reason: str + + +class ScopeMatchPolicy: + """Default action policy based on scope, checkpoints, and delegates.""" + + def evaluate(self, action: str, evidence: Mapping[str, Any], plan: Any) -> PolicyDecision: + if getattr(plan, "is_expired", False): + return PolicyDecision(False, "plan expired") + + delegates_to = tuple(getattr(plan, "delegates_to", ()) or ()) + agent = evidence.get("_agent") + if delegates_to and agent not in delegates_to: + return PolicyDecision(False, "agent '%s' not in delegates_to" % agent) + + checkpoints = tuple(getattr(plan, "checkpoints", ()) or ()) + for pattern in checkpoints: + if matches_pattern(action, pattern): + return PolicyDecision(False, "matched checkpoint %s" % pattern) + + scope = tuple(getattr(plan, "scope", ()) or ()) + for pattern in scope: + if matches_pattern(action, pattern): + return PolicyDecision(True, "matched scope %s" % pattern) + + return PolicyDecision(False, "no scope pattern matched") + + +def evaluate_policy( + action: str, + agent: str, + plan_scope: Sequence[str], + plan_checkpoints: Sequence[str], + plan_delegates: Sequence[str], + plan_expired: bool, +) -> PolicyDecision: + """Compatibility helper retained for legacy callers and tests.""" + + class _Plan: + scope = tuple(plan_scope) + checkpoints = tuple(plan_checkpoints) + delegates_to = tuple(plan_delegates) + is_expired = plan_expired + + return ScopeMatchPolicy().evaluate(action, {"_agent": agent}, _Plan()) diff --git a/agentmint/protocols.py b/agentmint/protocols.py index 565ea00..a58bdec 100644 --- a/agentmint/protocols.py +++ b/agentmint/protocols.py @@ -1,28 +1,26 @@ -"""Protocol interfaces for AgentMint's next architecture. - -This module defines the extensibility boundary described in the build spec: -core runtime code should depend on small protocols for keys, sinks, policy, -timestamping, serialization, stores, and redaction. Implementations will be -ported in later PRs without changing the public runtime behavior in this -foundation PR. -""" +"""Protocol interfaces for the receipt runtime.""" from __future__ import annotations -from typing import Any, Mapping, Optional, Protocol, Sequence +from typing import Any, Mapping, Optional, Protocol, Sequence, Tuple + +from nacl.signing import SigningKey, VerifyKey class KeyProvider(Protocol): """Provide signing and verification material to receipt producers.""" + def signing_key(self) -> SigningKey: + """Return the active Ed25519 signing key.""" + + def verify_key(self) -> VerifyKey: + """Return the active Ed25519 verification key.""" + def key_id(self) -> str: """Return a stable, audit-safe identifier for the active signing key.""" - def sign(self, payload: bytes) -> bytes: - """Sign canonical payload bytes and return the detached signature.""" - def public_key(self) -> bytes: - """Return public verification bytes suitable for offline verification.""" + """Return raw public verification bytes suitable for offline verification.""" class Sink(Protocol): @@ -31,27 +29,33 @@ class Sink(Protocol): def write(self, name: str, payload: bytes, metadata: Optional[Mapping[str, Any]] = None) -> str: """Persist payload bytes and return an implementation-specific locator.""" + def flush(self) -> None: + """Flush any buffered state.""" + + def close(self) -> None: + """Release held resources.""" + class Policy(Protocol): """Evaluate whether a requested action is allowed by the active scope.""" - def evaluate(self, action: str, evidence: Mapping[str, Any]) -> bool: - """Return whether the action and evidence satisfy policy.""" + def evaluate(self, action: str, evidence: Mapping[str, Any], plan: Any) -> Any: + """Return a policy decision object for the action and evidence.""" class Timestamper(Protocol): """Attach optional independent time evidence to receipt payloads.""" - def timestamp(self, digest: bytes) -> bytes: - """Return timestamp evidence for a canonical digest.""" - - def verify(self, digest: bytes, token: bytes) -> bool: - """Return whether a timestamp token verifies for the digest.""" + def timestamp(self, payload: bytes) -> Any: + """Return timestamp evidence for canonical payload bytes.""" class Serializer(Protocol): """Encode and decode receipt payloads using deterministic canonical forms.""" + def canonicalize(self, payload: Any) -> bytes: + """Serialize a payload to canonical bytes.""" + def dumps(self, payload: Mapping[str, Any]) -> bytes: """Serialize a payload to deterministic bytes.""" @@ -65,8 +69,8 @@ class PlanStore(Protocol): def save(self, plan_id: str, payload: Mapping[str, Any]) -> None: """Persist a plan payload.""" - def load(self, plan_id: str) -> Mapping[str, Any]: - """Load a plan payload or raise an implementation-specific error.""" + def load(self, plan_id: str) -> Optional[Mapping[str, Any]]: + """Load a plan payload if present.""" class ChainStore(Protocol): @@ -75,12 +79,12 @@ class ChainStore(Protocol): def previous_hash(self, plan_id: str) -> Optional[str]: """Return the previous receipt hash for a plan, if any.""" - def append(self, plan_id: str, receipt_hash: str) -> None: + def append(self, plan_id: str, receipt_hash: Optional[str]) -> None: """Record the newest receipt hash for a plan chain.""" class Redactor(Protocol): """Remove or transform sensitive fields before evidence is serialized.""" - def redact(self, evidence: Mapping[str, Any], fields: Sequence[str]) -> Mapping[str, Any]: - """Return evidence with requested fields redacted.""" + def redact(self, evidence: Mapping[str, Any]) -> Tuple[Mapping[str, Any], Sequence[str]]: + """Return redacted evidence and a list of modified paths.""" diff --git a/agentmint/providers/keys.py b/agentmint/providers/keys.py index 5c137c6..8879ab0 100644 --- a/agentmint/providers/keys.py +++ b/agentmint/providers/keys.py @@ -1,32 +1,147 @@ -"""Key provider stubs for the build spec provider layer.""" +"""Key providers.""" from __future__ import annotations +import base64 +import hashlib +import os +import tempfile +import threading +from contextlib import contextmanager from pathlib import Path +from nacl.signing import SigningKey, VerifyKey + from agentmint.protocols import KeyProvider +try: + import fcntl +except ImportError: # pragma: no cover + fcntl = None # type: ignore[assignment] + + +DEFAULT_KEY_DIR = Path.home() / ".agentmint" / "keys" +PRIVATE_KEY_FILE = "ed25519-private.key" +PUBLIC_KEY_FILE = "ed25519-public.key" +_LOCK = threading.Lock() + + +def _key_id_from_public_key(public_key: bytes) -> str: + return hashlib.sha256(public_key).hexdigest()[:16] + -class FileKeyProvider: - """Placeholder file-backed implementation of :class:`KeyProvider`.""" +@contextmanager +def _file_lock(lock_path: Path): + lock_path.parent.mkdir(parents=True, exist_ok=True) + with lock_path.open("a+b") as handle: + if fcntl is not None: + fcntl.flock(handle.fileno(), fcntl.LOCK_EX) + try: + yield + finally: + if fcntl is not None: + fcntl.flock(handle.fileno(), fcntl.LOCK_UN) - def __init__(self, path: str | Path) -> None: + +def _atomic_write(path: Path, payload: bytes, mode: int) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), prefix=path.name + ".") + try: + with os.fdopen(fd, "wb") as handle: + handle.write(payload) + os.chmod(tmp_name, mode) + os.replace(tmp_name, path) + finally: + if os.path.exists(tmp_name): + os.unlink(tmp_name) + + +class FileKeyProvider(KeyProvider): + """File-backed Ed25519 signing key provider.""" + + def __init__(self, path: str | Path = DEFAULT_KEY_DIR) -> None: self.path = Path(path) + self._signing_key: SigningKey | None = None + self._verify_key: VerifyKey | None = None - def key_id(self) -> str: - """Return the active key identifier.""" + def _load_or_generate(self) -> None: + if self._signing_key is not None and self._verify_key is not None: + return + + private_path = self.path / PRIVATE_KEY_FILE + public_path = self.path / PUBLIC_KEY_FILE + lock_path = self.path / ".keys.lock" - raise NotImplementedError("FileKeyProvider will be implemented in PR 2") + with _LOCK: + with _file_lock(lock_path): + if private_path.exists(): + seed = private_path.read_bytes() + signing_key = SigningKey(seed) + else: + signing_key = SigningKey.generate() + _atomic_write(private_path, bytes(signing_key), 0o600) + _atomic_write(public_path, bytes(signing_key.verify_key), 0o644) + self._signing_key = signing_key + self._verify_key = signing_key.verify_key - def sign(self, payload: bytes) -> bytes: - """Sign payload bytes with the file-backed key.""" + def signing_key(self) -> SigningKey: + self._load_or_generate() + assert self._signing_key is not None + return self._signing_key - raise NotImplementedError("FileKeyProvider will be implemented in PR 2") + def verify_key(self) -> VerifyKey: + self._load_or_generate() + assert self._verify_key is not None + return self._verify_key + + def key_id(self) -> str: + return _key_id_from_public_key(self.public_key()) def public_key(self) -> bytes: - """Return public key bytes for offline verification.""" + return bytes(self.verify_key()) + + +class EnvKeyProvider(KeyProvider): + """Environment-backed Ed25519 signing key provider.""" + + def __init__(self, env_var: str = "AGENTMINT_PRIVATE_KEY") -> None: + self.env_var = env_var + self._signing_key: SigningKey | None = None + + def _decode(self, value: str) -> bytes: + stripped = value.strip() + if len(stripped) == 64: + try: + return bytes.fromhex(stripped) + except ValueError: + pass + try: + decoded = base64.b64decode(stripped, validate=True) + if len(decoded) == 32: + return decoded + except Exception: + pass + raw = stripped.encode("utf-8") + if len(raw) == 32: + return raw + raise ValueError("environment key must be 32 raw bytes, 64 hex chars, or base64") - raise NotImplementedError("FileKeyProvider will be implemented in PR 2") + def signing_key(self) -> SigningKey: + if self._signing_key is None: + value = os.environ.get(self.env_var) + if not value: + raise RuntimeError("missing private key environment variable %s" % self.env_var) + self._signing_key = SigningKey(self._decode(value)) + return self._signing_key + + def verify_key(self) -> VerifyKey: + return self.signing_key().verify_key + + def key_id(self) -> str: + return _key_id_from_public_key(self.public_key()) + + def public_key(self) -> bytes: + return bytes(self.verify_key()) -__all__ = ["FileKeyProvider", "KeyProvider"] +__all__ = ["EnvKeyProvider", "FileKeyProvider", "KeyProvider"] diff --git a/agentmint/providers/redactors.py b/agentmint/providers/redactors.py index bbce0b1..8f4405c 100644 --- a/agentmint/providers/redactors.py +++ b/agentmint/providers/redactors.py @@ -1,19 +1,65 @@ -"""Redactor provider stubs for profile-specific evidence minimization.""" +"""Redactor providers.""" from __future__ import annotations -from typing import Any, Mapping, Sequence +import hashlib +import json +from typing import Any, Mapping, MutableMapping, Sequence from agentmint.protocols import Redactor class FieldRedactor: - """Placeholder field-based implementation of :class:`Redactor`.""" + """Hash or drop configured fields recursively.""" - def redact(self, evidence: Mapping[str, Any], fields: Sequence[str]) -> Mapping[str, Any]: - """Return evidence with selected fields redacted.""" + def __init__( + self, + always_hash: Sequence[str] | None = None, + always_drop: Sequence[str] | None = None, + ) -> None: + self.always_hash = frozenset(always_hash or ()) + self.always_drop = frozenset(always_drop or ()) - raise NotImplementedError("FieldRedactor will be implemented in PR 2") + def _hash_value(self, value: Any) -> str: + payload = json.dumps(value, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + def _walk( + self, evidence: Mapping[str, Any], prefix: str, modified: list[str] + ) -> dict[str, Any]: + result: MutableMapping[str, Any] = {} + for key, value in evidence.items(): + path = "%s.%s" % (prefix, key) if prefix else str(key) + if key in self.always_drop: + modified.append(path) + continue + if key in self.always_hash: + result[str(key)] = self._hash_value(value) + modified.append(path) + continue + if isinstance(value, Mapping): + result[str(key)] = self._walk(value, path, modified) + elif isinstance(value, list): + result[str(key)] = [ + self._walk(item, "%s[%d]" % (path, index), modified) + if isinstance(item, Mapping) + else item + for index, item in enumerate(value) + ] + else: + result[str(key)] = value + return dict(result) -__all__ = ["FieldRedactor", "Redactor"] + def redact(self, evidence: Mapping[str, Any]): + modified: list[str] = [] + return self._walk(evidence, "", modified), modified + + +class NoRedactor: + """Pass-through redactor.""" + + def redact(self, evidence: Mapping[str, Any]): + return dict(evidence), [] + + +__all__ = ["FieldRedactor", "NoRedactor", "Redactor"] diff --git a/agentmint/providers/serializers.py b/agentmint/providers/serializers.py index 71f7f5d..4598e4b 100644 --- a/agentmint/providers/serializers.py +++ b/agentmint/providers/serializers.py @@ -1,24 +1,69 @@ -"""Serializer provider stubs for deterministic receipt encoding.""" +"""Serializer providers.""" from __future__ import annotations +import json +import math from typing import Any, Mapping from agentmint.protocols import Serializer class JCSSerializer: - """Placeholder JSON Canonicalization Scheme serializer.""" + """Small RFC 8785-style canonical JSON serializer.""" - def dumps(self, payload: Mapping[str, Any]) -> bytes: - """Serialize payload data to canonical JSON bytes.""" + def canonicalize(self, payload: Any) -> bytes: + return self._encode(payload).encode("utf-8") - raise NotImplementedError("JCSSerializer will be implemented in PR 2") + def _encode(self, payload: Any) -> str: + if payload is None: + return "null" + if payload is True: + return "true" + if payload is False: + return "false" + if isinstance(payload, str): + return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + if isinstance(payload, int) and not isinstance(payload, bool): + return str(payload) + if isinstance(payload, float): + return self._encode_float(payload) + if isinstance(payload, Mapping): + items = [] + for key in sorted(payload.keys()): + items.append( + "%s:%s" + % ( + json.dumps(str(key), ensure_ascii=False, separators=(",", ":")), + self._encode(payload[key]), + ) + ) + return "{%s}" % ",".join(items) + if isinstance(payload, (list, tuple)): + return "[%s]" % ",".join(self._encode(item) for item in payload) + raise TypeError("value of type %s is not JSON serializable" % type(payload).__name__) - def loads(self, payload: bytes) -> Mapping[str, Any]: - """Deserialize canonical JSON bytes.""" + def _encode_float(self, value: float) -> str: + if not math.isfinite(value): + raise ValueError("non-finite numbers are not permitted in canonical JSON") + if value == 0: + return "0" + if value.is_integer() and abs(value) < 1e21: + return str(int(value)) + text = json.dumps(value, ensure_ascii=False, allow_nan=False, separators=(",", ":")) + if text.endswith(".0") and "e" not in text and "E" not in text: + text = text[:-2] + text = text.replace("E", "e") + return text - raise NotImplementedError("JCSSerializer will be implemented in PR 2") + def dumps(self, payload: Mapping[str, Any]) -> bytes: + return self.canonicalize(payload) + + def loads(self, payload: bytes) -> Mapping[str, Any]: + loaded = json.loads(payload.decode("utf-8")) + if not isinstance(loaded, dict): + raise TypeError("canonical payload must decode to an object") + return loaded __all__ = ["JCSSerializer", "Serializer"] diff --git a/agentmint/providers/sinks.py b/agentmint/providers/sinks.py index 0fcabc5..376c7c9 100644 --- a/agentmint/providers/sinks.py +++ b/agentmint/providers/sinks.py @@ -1,7 +1,12 @@ -"""Sink provider stubs for exported AgentMint evidence.""" +"""Sink providers.""" from __future__ import annotations +import os +import tempfile +import threading +from collections import deque +from datetime import datetime, timezone from pathlib import Path from typing import Any, Mapping, Optional @@ -9,27 +14,53 @@ class FileSink: - """Placeholder file sink implementation of :class:`Sink`.""" + """Date-partitioned file sink.""" - def __init__(self, root: str | Path) -> None: + def __init__(self, root: str | Path = Path("./receipts")) -> None: self.root = Path(root) + self._lock = threading.Lock() def write(self, name: str, payload: bytes, metadata: Optional[Mapping[str, Any]] = None) -> str: - """Persist payload bytes under the configured root.""" + del metadata + date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d") + target_dir = self.root / date_str + target_dir.mkdir(parents=True, exist_ok=True) + target_path = target_dir / ("%s.json" % name) - raise NotImplementedError("FileSink will be implemented in PR 2") + with self._lock: + fd, tmp_name = tempfile.mkstemp(dir=str(target_dir), prefix=name + ".", suffix=".tmp") + try: + with os.fdopen(fd, "wb") as handle: + handle.write(payload) + os.replace(tmp_name, target_path) + finally: + if os.path.exists(tmp_name): + os.unlink(tmp_name) + + return str(target_path) + + def flush(self) -> None: + return None + + def close(self) -> None: + return None class MemorySink: - """Placeholder in-memory sink implementation of :class:`Sink`.""" + """Non-durable FIFO in-memory sink.""" def __init__(self) -> None: - self.records: list[tuple[str, bytes, Optional[Mapping[str, Any]]]] = [] + self.records: deque[tuple[str, bytes, Optional[Mapping[str, Any]]]] = deque() def write(self, name: str, payload: bytes, metadata: Optional[Mapping[str, Any]] = None) -> str: - """Persist payload bytes in memory and return a locator.""" + self.records.append((name, payload, metadata)) + return "memory://%s" % name + + def flush(self) -> None: + return None - raise NotImplementedError("MemorySink will be implemented in PR 2") + def close(self) -> None: + return None __all__ = ["FileSink", "MemorySink", "Sink"] diff --git a/agentmint/providers/timestamp.py b/agentmint/providers/timestamp.py index 9a1654b..d3858ac 100644 --- a/agentmint/providers/timestamp.py +++ b/agentmint/providers/timestamp.py @@ -1,43 +1,79 @@ -"""Timestamp provider stubs for the build spec provider layer. - -PR 2 will port the concrete RFC 3161 behavior from :mod:`agentmint.timestamp` -into this protocol-shaped provider module. -""" +"""Timestamp providers.""" from __future__ import annotations -from agentmint.protocols import Timestamper - - -class NoTimestamper: - """Placeholder timestamper that will represent disabled timestamping.""" +import logging +from dataclasses import dataclass +from datetime import datetime, timezone - def timestamp(self, digest: bytes) -> bytes: - """Return empty timestamp evidence for a digest.""" +from agentmint.protocols import Timestamper +from agentmint.timestamp import TimestampError, verify as verify_token +from agentmint.timestamp import timestamp as issue_timestamp - raise NotImplementedError("NoTimestamper will be implemented in PR 2") - def verify(self, digest: bytes, token: bytes) -> bool: - """Verify empty timestamp evidence.""" +LOGGER = logging.getLogger(__name__) - raise NotImplementedError("NoTimestamper will be implemented in PR 2") +@dataclass(frozen=True) +class TimestampRecord: + """Timestamp payload stored on a receipt.""" -class RFC3161Timestamper: - """Placeholder RFC 3161 timestamper implementation.""" + observed_at: str + source: str + proof: bytes = b"" + tsq: bytes = b"" + tsr: bytes = b"" + digest_hex: str = "" + tsa_url: str = "" - def __init__(self, url: str) -> None: - self.url = url + def to_dict(self) -> dict[str, str]: + data = { + "observed_at": self.observed_at, + "source": self.source, + } + if self.digest_hex: + data["digest_hex"] = self.digest_hex + if self.tsa_url: + data["tsa_url"] = self.tsa_url + return data - def timestamp(self, digest: bytes) -> bytes: - """Return RFC 3161 timestamp evidence for a digest.""" - raise NotImplementedError("RFC3161Timestamper will be implemented in PR 2") +class NoTimestamper(Timestamper): + """Self-reported UTC timestamps with no network dependency.""" - def verify(self, digest: bytes, token: bytes) -> bool: - """Verify RFC 3161 timestamp evidence.""" + def timestamp(self, payload: bytes) -> TimestampRecord: + del payload + observed_at = datetime.now(timezone.utc).isoformat() + return TimestampRecord(observed_at=observed_at, source="self") - raise NotImplementedError("RFC3161Timestamper will be implemented in PR 2") +class RFC3161Timestamper(Timestamper): + """RFC 3161 timestamper with graceful self-reported fallback.""" -__all__ = ["NoTimestamper", "RFC3161Timestamper", "Timestamper"] + def __init__(self, url: str, timeout_seconds: int = 5) -> None: + self.url = url + self.timeout_seconds = timeout_seconds + self._fallback = NoTimestamper() + + def timestamp(self, payload: bytes) -> TimestampRecord: + del self.timeout_seconds + try: + result = issue_timestamp(payload, url=self.url) + return TimestampRecord( + observed_at=datetime.now(timezone.utc).isoformat(), + source=self.url, + proof=result.tsr, + tsq=result.tsq, + tsr=result.tsr, + digest_hex=result.digest_hex, + tsa_url=result.tsa_url, + ) + except TimestampError as exc: + LOGGER.warning("TSA unreachable, falling back to self timestamp: %s", exc) + return self._fallback.timestamp(payload) + + def verify(self, tsq_path, tsr_path, cacert_path, tsa_cert_path): # pragma: no cover + return verify_token(tsq_path, tsr_path, cacert_path, tsa_cert_path) + + +__all__ = ["NoTimestamper", "RFC3161Timestamper", "TimestampRecord", "Timestamper"] diff --git a/agentmint/receipt.py b/agentmint/receipt.py new file mode 100644 index 0000000..9a78348 --- /dev/null +++ b/agentmint/receipt.py @@ -0,0 +1,135 @@ +"""Receipt data model.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from typing import Any, Mapping, Optional, Tuple + +from nacl.exceptions import BadSignatureError + +from .providers.serializers import JCSSerializer +from .providers.timestamp import TimestampRecord + + +def _to_plain(value: Any) -> Any: + if isinstance(value, dict): + return {key: _to_plain(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_plain(item) for item in value] + if hasattr(value, "to_dict"): + return _to_plain(value.to_dict()) + return value + + +@dataclass(frozen=True) +class Receipt: + """Signed AERF evidence receipt.""" + + id: str + plan_id: str + agent: str + action: str + in_policy: bool + policy_reason: str + evidence_hash_sha512: str + evidence: Mapping[str, Any] + observed_at: str + key_id: str + signature: str + previous_receipt_hash: Optional[str] = None + plan_signature: str = "" + agent_signature: str = "" + agent_key_id: str = "" + policy_hash: str = "" + output_hash: str = "" + session_id: str = "" + session_trajectory: Optional[Tuple[Mapping[str, Any], ...]] = None + session_escalation: Optional[str] = None + reasoning_hash: Optional[str] = None + compliance_tags: Optional[Tuple[str, ...]] = None + timestamp: Optional[TimestampRecord] = None + aiuc_controls: Tuple[str, ...] = field(default_factory=tuple) + mode: str = "enforce" + original_verdict: Optional[bool] = None + + @property + def evidence_hash(self) -> str: + return self.evidence_hash_sha512 + + @property + def timestamp_result(self) -> Optional[TimestampRecord]: + return self.timestamp + + @property + def short_id(self) -> str: + return self.id[:8] + + def signable_dict(self) -> dict[str, Any]: + data = { + "id": self.id, + "type": "notarised_evidence", + "plan_id": self.plan_id, + "agent": self.agent, + "action": self.action, + "in_policy": self.in_policy, + "policy_reason": self.policy_reason, + "evidence_hash_sha512": self.evidence_hash_sha512, + "evidence": _to_plain(self.evidence), + "observed_at": self.observed_at, + "key_id": self.key_id, + } + + optionals = ( + ("previous_receipt_hash", self.previous_receipt_hash), + ("plan_signature", self.plan_signature), + ("agent_signature", self.agent_signature), + ("agent_key_id", self.agent_key_id), + ("policy_hash", self.policy_hash), + ("output_hash", self.output_hash), + ("session_id", self.session_id), + ( + "session_trajectory", + _to_plain(self.session_trajectory) if self.session_trajectory else None, + ), + ("session_escalation", self.session_escalation), + ("reasoning_hash", self.reasoning_hash), + ("compliance_tags", list(self.compliance_tags) if self.compliance_tags else None), + ("aiuc_controls", list(self.aiuc_controls) if self.aiuc_controls else None), + ) + for key, value in optionals: + if value not in (None, "", (), []): + data[key] = value + + if self.mode != "enforce": + data["mode"] = self.mode + if self.original_verdict is not None: + data["original_verdict"] = self.original_verdict + + return data + + def to_dict(self) -> dict[str, Any]: + payload = dict(self.signable_dict()) + payload["signature"] = self.signature + if self.timestamp is not None: + payload["timestamp"] = self.timestamp.to_dict() + return payload + + def to_json(self, indent: int = 2) -> str: + return json.dumps(self.to_dict(), indent=indent) + + def verify(self, key_provider: Any) -> bool: + serializer = JCSSerializer() + try: + key_provider.verify_key().verify( + serializer.canonicalize(self.signable_dict()), + bytes.fromhex(self.signature), + ) + return True + except (BadSignatureError, ValueError): + return False + + def canonical_hash(self) -> str: + serializer = JCSSerializer() + return hashlib.sha256(serializer.canonicalize(self.signable_dict())).hexdigest() diff --git a/tests/test_aerf_conformance.py b/tests/test_aerf_conformance.py index dcb9304..d857030 100644 --- a/tests/test_aerf_conformance.py +++ b/tests/test_aerf_conformance.py @@ -23,7 +23,6 @@ def _format_validation_errors(errors: list[object]) -> str: return "\n".join(lines) -@pytest.mark.xfail(strict=True, reason="Current Notary receipts intentionally drift from AERF v0.1") def test_notary_receipt_matches_aerf_v01_schema() -> None: """Produce a receipt through Notary.notarise and validate it against AERF v0.1.""" @@ -33,12 +32,12 @@ def test_notary_receipt_matches_aerf_v01_schema() -> None: notary = Notary() plan = notary.create_plan( user="auditor@example.com", - action="files/read", - scope=["files/*"], + action="files:read", + scope=["files:*"], ttl_seconds=60, ) receipt = notary.notarise( - action="files/read", + action="files:read", agent="conformance-agent", plan=plan, evidence={"path": "/tmp/demo.txt", "operation": "read"}, diff --git a/tests/test_protocol_providers.py b/tests/test_protocol_providers.py new file mode 100644 index 0000000..fc220e6 --- /dev/null +++ b/tests/test_protocol_providers.py @@ -0,0 +1,79 @@ +"""Tests for protocol-oriented provider implementations.""" + +from __future__ import annotations + +import os + +from agentmint.providers.keys import EnvKeyProvider, FileKeyProvider +from agentmint.providers.redactors import FieldRedactor, NoRedactor +from agentmint.providers.serializers import JCSSerializer +from agentmint.providers.sinks import FileSink, MemorySink +from agentmint.providers.timestamp import NoTimestamper + + +class TestJCSSerializer: + def test_orders_keys_and_strips_whitespace(self) -> None: + serializer = JCSSerializer() + assert serializer.canonicalize({"b": 2, "a": 1}) == b'{"a":1,"b":2}' + + def test_normalizes_basic_numbers(self) -> None: + serializer = JCSSerializer() + payload = {"a": 1.0, "b": 4.50, "c": 1e30, "d": 0.002, "e": -0.0} + assert serializer.canonicalize(payload) == b'{"a":1,"b":4.5,"c":1e+30,"d":0.002,"e":0}' + + def test_preserves_utf8_strings(self) -> None: + serializer = JCSSerializer() + assert serializer.canonicalize({"greeting": "hello", "snowman": "\u2603"}) == ( + '{"greeting":"hello","snowman":"☃"}'.encode("utf-8") + ) + + +class TestKeyProviders: + def test_file_key_provider_persists_and_reloads(self, tmp_path) -> None: + provider1 = FileKeyProvider(tmp_path) + provider2 = FileKeyProvider(tmp_path) + assert provider1.key_id() == provider2.key_id() + assert provider1.public_key() == provider2.public_key() + + def test_env_key_provider_reads_hex_seed(self, monkeypatch) -> None: + seed = bytes(range(32)) + monkeypatch.setenv("AGENTMINT_PRIVATE_KEY", seed.hex()) + provider = EnvKeyProvider() + assert provider.public_key() + assert len(provider.key_id()) == 16 + + +class TestRedactors: + def test_field_redactor_hashes_and_drops_recursively(self) -> None: + redactor = FieldRedactor(always_hash=["token"], always_drop=["secret"]) + evidence, modified = redactor.redact( + {"token": "abc", "nested": {"secret": "gone", "keep": 1}} + ) + assert evidence["token"] != "abc" + assert "nested.secret" in modified + assert "secret" not in evidence["nested"] + + def test_no_redactor_is_passthrough(self) -> None: + evidence = {"a": 1} + result, modified = NoRedactor().redact(evidence) + assert result == evidence + assert modified == [] + + +class TestSinksAndTimestamp: + def test_memory_sink_is_fifo(self) -> None: + sink = MemorySink() + sink.write("one", b"1") + sink.write("two", b"2") + assert list(sink.records)[0][0] == "one" + assert list(sink.records)[1][0] == "two" + + def test_file_sink_writes_date_partition(self, tmp_path) -> None: + sink = FileSink(tmp_path) + path = sink.write("receipt-id", b"{}", None) + assert os.path.exists(path) + assert path.endswith("receipt-id.json") + + def test_no_timestamper_returns_self_source(self) -> None: + result = NoTimestamper().timestamp(b"payload") + assert result.source == "self"