From 77aa62d129e1c187e8eb1dd277bfd6e23f1331e7 Mon Sep 17 00:00:00 2001 From: Ajit Koti Date: Tue, 17 Mar 2026 20:04:02 -0700 Subject: [PATCH] Update the persistence layer --- Cargo.toml | 1 + README.md | 9 +- docs/README.md | 7 +- docs/architecture.md | 23 +- docs/protocol.md | 9 +- src/lib.rs | 1 + src/log_store.rs | 65 ---- src/main.rs | 78 +++-- src/mode/decision.rs | 1 + src/mode/handoff.rs | 19 ++ src/mode/multi_round.rs | 1 + src/mode/proposal.rs | 1 + src/mode/quorum.rs | 1 + src/mode/task.rs | 34 ++- src/registry.rs | 101 +++++-- src/runtime.rs | 516 ++++++++++++++++++++++--------- src/security.rs | 69 +---- src/server.rs | 81 +++-- src/session.rs | 1 + src/storage.rs | 656 ++++++++++++++++++++++++++++++++++++++++ 20 files changed, 1325 insertions(+), 349 deletions(-) create mode 100644 src/storage.rs diff --git a/Cargo.toml b/Cargo.toml index f7218da..163db37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ serde_json = "1" tokio-stream = "0.1" futures-core = "0.3" async-stream = "0.3" +async-trait = "0.1" [dev-dependencies] tempfile = "3" diff --git a/README.md b/README.md index a26987c..b429142 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,9 @@ This runtime implements the current MACP core/service surface, the five standard - payload size limits - rate limiting - **Durable local persistence** - - session registry snapshots - - accepted-history log snapshots - - dedup state survives restart + - per-session append-only log files and session snapshots via `FileBackend` + - crash recovery with dedup state reconciliation + - atomic writes (tmp file + rename) prevent partial-write corruption - **Unary freeze profile** - `StreamSession` is intentionally disabled in this profile - `WatchModeRegistry` and `WatchRoots` remain unimplemented @@ -192,7 +192,8 @@ runtime/ │ ├── security.rs # auth config, sender derivation, rate limiting │ ├── session.rs # canonical SessionStart validation and session model │ ├── registry.rs # session store with 
optional persistence -│ ├── log_store.rs # accepted-history log store with optional persistence +│ ├── log_store.rs # in-memory accepted-history log cache +│ ├── storage.rs # storage backend trait, FileBackend persistence, crash recovery │ ├── mode/ # mode implementations │ └── bin/ # local development example clients ├── docs/ diff --git a/docs/README.md b/docs/README.md index 382f103..fb7a3e9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -65,10 +65,11 @@ In dev mode, example clients attach `x-macp-agent-id` metadata and may use plain ## Persistence model -By default the runtime persists snapshots under `.macp-data/`: +By default the runtime persists state under `.macp-data/` via `FileBackend`: -- `sessions.json` -- `logs.json` +- per-session directories containing `session.json` and append-only `log.jsonl` +- crash recovery reconciles dedup state from the log on startup +- atomic writes (tmp file + rename) prevent partial-write corruption If a snapshot file contains corrupt or incompatible JSON, the runtime logs a warning to stderr and starts with empty state. diff --git a/docs/architecture.md b/docs/architecture.md index b16ff3e..c5ee203 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -43,9 +43,16 @@ Implemented modes: ## 4. Storage layer +### Storage backend (`src/storage.rs`) + +Provides the `StorageBackend` trait with two implementations: + +- `FileBackend` — per-session directories containing `session.json` and append-only `log.jsonl`, with crash recovery and atomic writes +- `MemoryBackend` — no-op backend for `MACP_MEMORY_ONLY=1` + ### Session registry (`src/registry.rs`) -Stores: +In-memory cache of all sessions, loaded from `FileBackend` on startup. Stores: - session metadata - bound versions @@ -53,24 +60,16 @@ Stores: - dedup state - current session state -Supports: - -- in-memory mode -- file-backed snapshot persistence - -Both stores log a warning and fall back to empty state if snapshot deserialization fails. 
+Supports optional file-backed snapshot persistence for backward compatibility. ### Log store (`src/log_store.rs`) -Stores: +In-memory cache of accepted-history logs. Stores: - accepted incoming envelopes - runtime-generated internal events such as TTL expiry and session cancellation -Supports: - -- in-memory mode -- file-backed snapshot persistence +On-disk persistence is handled by `FileBackend`, not by LogStore. ## 5. Security layer (`src/security.rs`) diff --git a/docs/protocol.md b/docs/protocol.md index 155366a..022301e 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -67,12 +67,13 @@ Local development profile: ## Persistence profile -By default the runtime persists snapshots of: +By default the runtime persists state via `FileBackend` under `MACP_DATA_DIR`: -- session registry -- accepted-history log store +- per-session `session.json` and append-only `log.jsonl` files +- crash recovery reconciles dedup state from the log on startup +- atomic writes (tmp file + rename) prevent partial-write corruption -This gives restart recovery for session metadata, dedup state, and accepted-history inspection. Corrupt or incompatible snapshot files produce a warning on stderr; the runtime falls back to empty state instead of refusing to start. +This gives restart recovery for session metadata, dedup state, and accepted-history inspection. Corrupt or incompatible files produce a warning on stderr; the runtime falls back to empty state instead of refusing to start. 
## Commitment validation diff --git a/src/lib.rs b/src/lib.rs index f20f263..689f39f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,5 +28,6 @@ pub mod mode; pub mod registry; pub mod runtime; pub mod session; +pub mod storage; pub mod security; diff --git a/src/log_store.rs b/src/log_store.rs index 458bb2d..e41cf9a 100644 --- a/src/log_store.rs +++ b/src/log_store.rs @@ -1,6 +1,4 @@ use std::collections::HashMap; -use std::fs; -use std::path::{Path, PathBuf}; use tokio::sync::RwLock; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] @@ -21,7 +19,6 @@ pub struct LogEntry { pub struct LogStore { logs: RwLock>>, - persistence_path: Option, } impl Default for LogStore { @@ -34,60 +31,17 @@ impl LogStore { pub fn new() -> Self { Self { logs: RwLock::new(HashMap::new()), - persistence_path: None, } } - pub fn with_persistence>(dir: P) -> std::io::Result { - let dir = dir.as_ref().to_path_buf(); - fs::create_dir_all(&dir)?; - let path = dir.join("logs.json"); - let logs = if path.exists() { - match serde_json::from_slice(&fs::read(&path)?) 
{ - Ok(v) => v, - Err(e) => { - eprintln!("warning: failed to deserialize logs from {}: {e}; starting with empty state", path.display()); - HashMap::new() - } - } - } else { - HashMap::new() - }; - Ok(Self { - logs: RwLock::new(logs), - persistence_path: Some(path), - }) - } - - fn persist_map(path: &Path, logs: &HashMap>) -> std::io::Result<()> { - let bytes = serde_json::to_vec_pretty(logs)?; - let tmp_path = path.with_extension("json.tmp"); - fs::write(&tmp_path, bytes)?; - fs::rename(&tmp_path, path) - } - - async fn persist_locked(&self, logs: &HashMap>) -> std::io::Result<()> { - if let Some(path) = &self.persistence_path { - Self::persist_map(path, logs)?; - } - Ok(()) - } - - pub async fn persist_snapshot(&self) -> std::io::Result<()> { - let guard = self.logs.read().await; - self.persist_locked(&guard).await - } - pub async fn create_session_log(&self, session_id: &str) { let mut guard = self.logs.write().await; guard.entry(session_id.to_string()).or_default(); - let _ = self.persist_locked(&guard).await; } pub async fn append(&self, session_id: &str, entry: LogEntry) { let mut guard = self.logs.write().await; guard.entry(session_id.to_string()).or_default().push(entry); - let _ = self.persist_locked(&guard).await; } pub async fn get_log(&self, session_id: &str) -> Option> { @@ -99,7 +53,6 @@ impl LogStore { #[cfg(test)] mod tests { use super::*; - use std::time::{SystemTime, UNIX_EPOCH}; fn entry(id: &str, kind: EntryKind) -> LogEntry { LogEntry { @@ -124,22 +77,4 @@ mod tests { assert_eq!(log[0].message_id, "m1"); assert_eq!(log[1].message_id, "m2"); } - - #[tokio::test] - async fn persistent_log_store_round_trip() { - let base = std::env::temp_dir().join(format!( - "macp-log-test-{}", - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos() - )); - let store = LogStore::with_persistence(&base).unwrap(); - store.append("s1", entry("m1", EntryKind::Incoming)).await; - - let reopened = LogStore::with_persistence(&base).unwrap(); - let 
log = reopened.get_log("s1").await.unwrap(); - assert_eq!(log.len(), 1); - assert_eq!(log[0].message_id, "m1"); - } } diff --git a/src/main.rs b/src/main.rs index 0ef966b..b945310 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,10 @@ use macp_runtime::pb; use macp_runtime::registry::SessionRegistry; use macp_runtime::runtime::Runtime; use macp_runtime::security::SecurityLayer; +use macp_runtime::storage::{ + cleanup_temp_files, migrate_if_needed, recover_session, FileBackend, MemoryBackend, + StorageBackend, +}; use server::MacpServer; use std::path::PathBuf; use std::sync::Arc; @@ -20,21 +24,55 @@ async fn main() -> Result<(), Box> { let data_dir = PathBuf::from(std::env::var("MACP_DATA_DIR").unwrap_or_else(|_| ".macp-data".into())); - let registry = Arc::new(if memory_only { - SessionRegistry::new() + let storage: Arc = if memory_only { + Arc::new(MemoryBackend) } else { - SessionRegistry::with_persistence(&data_dir)? - }); - let log_store = Arc::new(if memory_only { - LogStore::new() - } else { - LogStore::with_persistence(&data_dir)? - }); + std::fs::create_dir_all(&data_dir)?; + migrate_if_needed(&data_dir)?; + cleanup_temp_files(&data_dir); + Arc::new(FileBackend::new(data_dir.clone())?) 
+ }; + + // Load persisted state into in-memory caches + let registry = Arc::new(SessionRegistry::new()); + let log_store = Arc::new(LogStore::new()); + + if !memory_only { + let mut sessions = storage.load_all_sessions().await?; + for session in &mut sessions { + let log_entries = storage.load_log(&session.session_id).await?; + + // Crash recovery: reconcile dedup state from log + recover_session(session, &log_entries); + + // Populate in-memory log store + log_store.create_session_log(&session.session_id).await; + for entry in &log_entries { + log_store.append(&session.session_id, entry.clone()).await; + } - let registry_ref = Arc::clone(®istry); - let log_store_ref = Arc::clone(&log_store); + // Persist recovered session state if it changed + if let Err(e) = storage.save_session(session).await { + eprintln!( + "warning: failed to persist recovered session '{}': {e}", + session.session_id + ); + } - let runtime = Arc::new(Runtime::new(registry, log_store)); + registry + .insert_session_for_test(session.session_id.clone(), session.clone()) + .await; + } + if !sessions.is_empty() { + println!("Loaded {} sessions from storage.", sessions.len()); + } + } + + let runtime = Arc::new(Runtime::new( + Arc::clone(&storage), + Arc::clone(®istry), + Arc::clone(&log_store), + )); let security = SecurityLayer::from_env()?; let svc = MacpServer::new(runtime, security); @@ -78,12 +116,16 @@ async fn main() -> Result<(), Box> { } } - // Persist final state on shutdown - if let Err(e) = registry_ref.persist_snapshot().await { - eprintln!("warning: failed to persist session registry: {}", e); - } - if let Err(e) = log_store_ref.persist_snapshot().await { - eprintln!("warning: failed to persist log store: {}", e); + // Final snapshot: persist all sessions to storage + if !memory_only { + for session in registry.get_all_sessions().await { + if let Err(e) = storage.save_session(&session).await { + eprintln!( + "warning: failed to persist session '{}' on shutdown: {e}", + 
session.session_id + ); + } + } } println!("State persisted. Goodbye."); diff --git a/src/mode/decision.rs b/src/mode/decision.rs index 1fbfdd9..14b42aa 100644 --- a/src/mode/decision.rs +++ b/src/mode/decision.rs @@ -223,6 +223,7 @@ mod tests { session_id: "s1".into(), state: SessionState::Open, ttl_expiry: i64::MAX, + ttl_ms: 60_000, started_at_unix_ms: 0, resolution: None, mode: "macp.mode.decision.v1".into(), diff --git a/src/mode/handoff.rs b/src/mode/handoff.rs index 800e18e..4fe2519 100644 --- a/src/mode/handoff.rs +++ b/src/mode/handoff.rs @@ -81,6 +81,13 @@ impl Mode for HandoffMode { if session.participants.len() < 2 { return Err(MacpError::InvalidPayload); } + if !session + .participants + .iter() + .any(|p| p == &session.initiator_sender) + { + return Err(MacpError::InvalidPayload); + } Ok(ModeResponse::PersistState(Self::encode_state( &HandoffState::default(), ))) @@ -214,6 +221,7 @@ mod tests { session_id: "s1".into(), state: SessionState::Open, ttl_expiry: i64::MAX, + ttl_ms: 60_000, started_at_unix_ms: 0, resolution: None, mode: "macp.mode.handoff.v1".into(), @@ -330,6 +338,17 @@ mod tests { assert_eq!(err.to_string(), "InvalidPayload"); } + #[test] + fn session_start_rejects_when_initiator_not_participant() { + let mode = HandoffMode; + let mut session = base_session(); + session.participants = vec!["target".into(), "other".into()]; // owner not included + let err = mode + .on_session_start(&session, &env("owner", "SessionStart", vec![])) + .unwrap_err(); + assert_eq!(err.to_string(), "InvalidPayload"); + } + // --- HandoffOffer --- #[test] diff --git a/src/mode/multi_round.rs b/src/mode/multi_round.rs index b921310..aab684b 100644 --- a/src/mode/multi_round.rs +++ b/src/mode/multi_round.rs @@ -128,6 +128,7 @@ mod tests { session_id: "s1".into(), state: SessionState::Open, ttl_expiry: i64::MAX, + ttl_ms: 60_000, started_at_unix_ms: 0, resolution: None, mode: "multi_round".into(), diff --git a/src/mode/proposal.rs b/src/mode/proposal.rs index 
eaaea1d..b2632fd 100644 --- a/src/mode/proposal.rs +++ b/src/mode/proposal.rs @@ -241,6 +241,7 @@ mod tests { session_id: "s1".into(), state: SessionState::Open, ttl_expiry: i64::MAX, + ttl_ms: 60_000, started_at_unix_ms: 0, resolution: None, mode: "macp.mode.proposal.v1".into(), diff --git a/src/mode/quorum.rs b/src/mode/quorum.rs index 2dd35bb..85ae90a 100644 --- a/src/mode/quorum.rs +++ b/src/mode/quorum.rs @@ -213,6 +213,7 @@ mod tests { session_id: "s1".into(), state: SessionState::Open, ttl_expiry: i64::MAX, + ttl_ms: 60_000, started_at_unix_ms: 0, resolution: None, mode: "macp.mode.quorum.v1".into(), diff --git a/src/mode/task.rs b/src/mode/task.rs index 1032559..493284e 100644 --- a/src/mode/task.rs +++ b/src/mode/task.rs @@ -113,7 +113,14 @@ impl Mode for TaskMode { session: &Session, _env: &Envelope, ) -> Result { - if session.participants.is_empty() { + if session.participants.len() < 2 { + return Err(MacpError::InvalidPayload); + } + if !session + .participants + .iter() + .any(|p| p == &session.initiator_sender) + { return Err(MacpError::InvalidPayload); } Ok(ModeResponse::PersistState(Self::encode_state( @@ -288,6 +295,7 @@ mod tests { session_id: "s1".into(), state: SessionState::Open, ttl_expiry: i64::MAX, + ttl_ms: 60_000, started_at_unix_ms: 0, resolution: None, mode: "macp.mode.task.v1".into(), @@ -420,7 +428,18 @@ mod tests { } #[test] - fn session_start_requires_participants() { + fn session_start_requires_at_least_two_participants() { + let mode = TaskMode; + let mut session = base_session(); + session.participants = vec!["planner".into()]; // only 1 + let err = mode + .on_session_start(&session, &env("planner", "SessionStart", vec![])) + .unwrap_err(); + assert_eq!(err.to_string(), "InvalidPayload"); + } + + #[test] + fn session_start_rejects_empty_participants() { let mode = TaskMode; let mut session = base_session(); session.participants.clear(); @@ -430,6 +449,17 @@ mod tests { assert_eq!(err.to_string(), "InvalidPayload"); } + #[test] + 
fn session_start_rejects_when_initiator_not_participant() { + let mode = TaskMode; + let mut session = base_session(); + session.participants = vec!["worker".into(), "other".into()]; // planner not included + let err = mode + .on_session_start(&session, &env("planner", "SessionStart", vec![])) + .unwrap_err(); + assert_eq!(err.to_string(), "InvalidPayload"); + } + // --- TaskRequest --- #[test] diff --git a/src/registry.rs b/src/registry.rs index a75286a..3f5e43c 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -5,37 +5,47 @@ use std::path::{Path, PathBuf}; use tokio::sync::RwLock; #[derive(serde::Serialize, serde::Deserialize)] -struct PersistedRoot { - uri: String, - name: String, +pub(crate) struct PersistedRoot { + pub uri: String, + pub name: String, } #[derive(serde::Serialize, serde::Deserialize)] -struct PersistedSession { - session_id: String, - state: crate::session::SessionState, - ttl_expiry: i64, - started_at_unix_ms: i64, - resolution: Option>, - mode: String, - mode_state: Vec, - participants: Vec, - seen_message_ids: Vec, - intent: String, - mode_version: String, - configuration_version: String, - policy_version: String, - context: Vec, - roots: Vec, - initiator_sender: String, +pub(crate) struct PersistedSession { + #[serde(default = "default_schema_version")] + pub schema_version: u32, + pub session_id: String, + pub state: crate::session::SessionState, + pub ttl_expiry: i64, + #[serde(default)] + pub ttl_ms: i64, + pub started_at_unix_ms: i64, + pub resolution: Option>, + pub mode: String, + pub mode_state: Vec, + pub participants: Vec, + pub seen_message_ids: Vec, + pub intent: String, + pub mode_version: String, + pub configuration_version: String, + pub policy_version: String, + pub context: Vec, + pub roots: Vec, + pub initiator_sender: String, +} + +fn default_schema_version() -> u32 { + 2 } impl From<&Session> for PersistedSession { fn from(session: &Session) -> Self { Self { + schema_version: 2, session_id: session.session_id.clone(), 
state: session.state.clone(), ttl_expiry: session.ttl_expiry, + ttl_ms: session.ttl_ms, started_at_unix_ms: session.started_at_unix_ms, resolution: session.resolution.clone(), mode: session.mode.clone(), @@ -62,10 +72,19 @@ impl From<&Session> for PersistedSession { impl From for Session { fn from(session: PersistedSession) -> Self { + let ttl_ms = if session.ttl_ms > 0 { + session.ttl_ms + } else { + // Backward compatibility: compute from absolute timestamps + session + .ttl_expiry + .saturating_sub(session.started_at_unix_ms) + }; Self { session_id: session.session_id, state: session.state, ttl_expiry: session.ttl_expiry, + ttl_ms, started_at_unix_ms: session.started_at_unix_ms, resolution: session.resolution, mode: session.mode, @@ -169,6 +188,11 @@ impl SessionRegistry { guard.get(session_id).cloned() } + pub async fn get_all_sessions(&self) -> Vec { + let guard = self.sessions.read().await; + guard.values().cloned().collect() + } + pub async fn insert_session_for_test(&self, session_id: String, session: Session) { let mut guard = self.sessions.write().await; guard.insert(session_id, session); @@ -176,12 +200,14 @@ impl SessionRegistry { } pub async fn count_open_sessions_for_initiator(&self, sender: &str) -> usize { + let now = chrono::Utc::now().timestamp_millis(); let guard = self.sessions.read().await; guard .values() .filter(|session| { session.initiator_sender == sender && session.state == crate::session::SessionState::Open + && now <= session.ttl_expiry }) .count() } @@ -199,6 +225,7 @@ mod tests { session_id: id.into(), state: SessionState::Open, ttl_expiry: 10, + ttl_ms: 9, started_at_unix_ms: 1, resolution: None, mode: "macp.mode.decision.v1".into(), @@ -218,6 +245,40 @@ mod tests { } } + #[tokio::test] + async fn expired_sessions_not_counted_against_limit() { + let registry = SessionRegistry::new(); + let now = chrono::Utc::now().timestamp_millis(); + // Insert a session with TTL already expired + let mut expired = sample_session("expired-s1"); + 
expired.initiator_sender = "agent://alice".into(); + expired.ttl_expiry = now - 1000; // expired 1 second ago + expired.state = SessionState::Open; // still Open but TTL is past + registry + .insert_session_for_test("expired-s1".into(), expired) + .await; + + // Should not count the expired-but-open session + let count = registry + .count_open_sessions_for_initiator("agent://alice") + .await; + assert_eq!(count, 0); + + // Insert a session that is still valid + let mut active = sample_session("active-s1"); + active.initiator_sender = "agent://alice".into(); + active.ttl_expiry = now + 60_000; // expires in 60s + active.state = SessionState::Open; + registry + .insert_session_for_test("active-s1".into(), active) + .await; + + let count = registry + .count_open_sessions_for_initiator("agent://alice") + .await; + assert_eq!(count, 1); + } + #[tokio::test] async fn persistent_registry_round_trip() { let base = std::env::temp_dir().join(format!( diff --git a/src/runtime.rs b/src/runtime.rs index 8fa556a..cbee836 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -17,6 +17,7 @@ use crate::session::{ extract_ttl_ms, is_standard_mode, parse_session_start_payload, validate_standard_session_start_payload, Session, SessionState, }; +use crate::storage::StorageBackend; const EXPERIMENTAL_DEFAULT_TTL_MS: i64 = 60_000; @@ -27,13 +28,18 @@ pub struct ProcessResult { } pub struct Runtime { + pub storage: Arc, pub registry: Arc, pub log_store: Arc, modes: HashMap>, } impl Runtime { - pub fn new(registry: Arc, log_store: Arc) -> Self { + pub fn new( + storage: Arc, + registry: Arc, + log_store: Arc, + ) -> Self { let mut modes: HashMap> = HashMap::new(); modes.insert("macp.mode.decision.v1".into(), Box::new(DecisionMode)); modes.insert("macp.mode.proposal.v1".into(), Box::new(ProposalMode)); @@ -43,6 +49,7 @@ impl Runtime { modes.insert("macp.mode.multi_round.v1".into(), Box::new(MultiRoundMode)); Self { + storage, registry, log_store, modes, @@ -95,33 +102,46 @@ impl Runtime { } } 
- async fn persist_sessions(&self, sessions: &HashMap) { - if let Err(err) = self.registry.persist_locked(sessions).await { - eprintln!("warning: failed to persist session registry: {err}"); + async fn save_session_to_storage(&self, session: &Session) { + if let Err(err) = self.storage.save_session(session).await { + eprintln!( + "warning: failed to persist session '{}': {err}", + session.session_id + ); } } async fn maybe_expire_session(&self, session_id: &str, session: &mut Session) -> bool { let now = Utc::now().timestamp_millis(); if session.state == SessionState::Open && now > session.ttl_expiry { - self.log_store - .append(session_id, Self::make_internal_entry("TtlExpired", b"")) - .await; + let entry = Self::make_internal_entry("TtlExpired", b""); + if let Err(e) = self.storage.append_log_entry(session_id, &entry).await { + eprintln!("warning: failed to persist TTL expiry log for '{session_id}': {e}"); + } + self.log_store.append(session_id, entry).await; session.state = SessionState::Expired; return true; } false } - pub async fn process(&self, env: &Envelope) -> Result { + pub async fn process( + &self, + env: &Envelope, + max_open_sessions: Option, + ) -> Result { match env.message_type.as_str() { - "SessionStart" => self.process_session_start(env).await, + "SessionStart" => self.process_session_start(env, max_open_sessions).await, "Signal" => self.process_signal(env).await, _ => self.process_message(env).await, } } - async fn process_session_start(&self, env: &Envelope) -> Result { + async fn process_session_start( + &self, + env: &Envelope, + max_open_sessions: Option, + ) -> Result { if env.mode.trim().is_empty() { return Err(MacpError::InvalidEnvelope); } @@ -153,12 +173,31 @@ impl Runtime { return Err(MacpError::DuplicateSession); } + // Enforce max_open_sessions atomically under the write lock to + // prevent TOCTOU races where concurrent SessionStart requests + // both pass a read-lock count check before either is inserted. 
+ if let Some(max_open) = max_open_sessions { + let now = Utc::now().timestamp_millis(); + let count = guard + .values() + .filter(|s| { + s.initiator_sender == env.sender + && s.state == SessionState::Open + && now <= s.ttl_expiry + }) + .count(); + if count >= max_open { + return Err(MacpError::RateLimited); + } + } + let accepted_at = Utc::now().timestamp_millis(); let ttl_expiry = accepted_at.saturating_add(ttl_ms); let session = Session { session_id: env.session_id.clone(), state: SessionState::Open, ttl_expiry, + ttl_ms, started_at_unix_ms: accepted_at, resolution: None, mode: mode_name.to_string(), @@ -176,10 +215,28 @@ impl Runtime { let response = mode.on_session_start(&session, env)?; + // 1. Create storage directory and write log entry (COMMIT POINT) + if let Err(e) = self.storage.create_session_storage(&env.session_id).await { + eprintln!( + "warning: failed to create session storage for '{}': {e}", + env.session_id + ); + } + let incoming_entry = Self::make_incoming_entry(env); + if let Err(e) = self + .storage + .append_log_entry(&env.session_id, &incoming_entry) + .await + { + eprintln!( + "warning: failed to persist log entry for '{}': {e}", + env.session_id + ); + } + + // 2. Update in-memory caches self.log_store.create_session_log(&env.session_id).await; - self.log_store - .append(&env.session_id, Self::make_incoming_entry(env)) - .await; + self.log_store.append(&env.session_id, incoming_entry).await; let mut session = session; session.seen_message_ids.insert(env.message_id.clone()); @@ -199,8 +256,9 @@ impl Runtime { } let result_state = session.state.clone(); + // 3. Best-effort session save + self.save_session_to_storage(&session).await; guard.insert(env.session_id.clone(), session); - self.persist_sessions(&guard).await; Ok(ProcessResult { session_state: result_state, @@ -221,8 +279,15 @@ impl Runtime { }); } + // Validate that the envelope mode matches the session's bound mode. 
+ // This prevents a token scoped to mode X from sending messages into + // a session bound to mode Y (server.rs authorizes against env.mode). + if env.mode != session.mode { + return Err(MacpError::InvalidEnvelope); + } + if self.maybe_expire_session(&env.session_id, session).await { - self.persist_sessions(&guard).await; + self.save_session_to_storage(session).await; return Err(MacpError::TtlExpired); } @@ -237,13 +302,27 @@ impl Runtime { mode.authorize_sender(session, env)?; let response = mode.on_message(session, env)?; + // 1. COMMIT POINT: write log entry to disk + let incoming_entry = Self::make_incoming_entry(env); + if let Err(e) = self + .storage + .append_log_entry(&env.session_id, &incoming_entry) + .await + { + eprintln!( + "warning: failed to persist log entry for '{}': {e}", + env.session_id + ); + } + + // 2. Update in-memory state + self.log_store.append(&env.session_id, incoming_entry).await; session.seen_message_ids.insert(env.message_id.clone()); - self.log_store - .append(&env.session_id, Self::make_incoming_entry(env)) - .await; Self::apply_mode_response(session, response); let result_state = session.state.clone(); - self.persist_sessions(&guard).await; + + // 3. 
Best-effort session save + self.save_session_to_storage(session).await; Ok(ProcessResult { session_state: result_state, @@ -270,7 +349,9 @@ impl Runtime { return None; }; if changed { - self.persist_sessions(&guard).await; + if let Some(session) = guard.get(session_id) { + self.save_session_to_storage(session).await; + } } guard.get(session_id).cloned() } @@ -287,21 +368,24 @@ impl Runtime { if session.state == SessionState::Resolved || session.state == SessionState::Expired { let result_state = session.state.clone(); - self.persist_sessions(&guard).await; + self.save_session_to_storage(session).await; return Ok(ProcessResult { session_state: result_state, duplicate: false, }); } - self.log_store - .append( - session_id, - Self::make_internal_entry("SessionCancel", reason.as_bytes()), - ) - .await; + let cancel_entry = Self::make_internal_entry("SessionCancel", reason.as_bytes()); + if let Err(e) = self + .storage + .append_log_entry(session_id, &cancel_entry) + .await + { + eprintln!("warning: failed to persist cancel log for '{session_id}': {e}"); + } + self.log_store.append(session_id, cancel_entry).await; session.state = SessionState::Expired; - self.persist_sessions(&guard).await; + self.save_session_to_storage(session).await; Ok(ProcessResult { session_state: SessionState::Expired, @@ -318,9 +402,10 @@ mod tests { use prost::Message; fn make_runtime() -> Runtime { + let storage: Arc = Arc::new(crate::storage::MemoryBackend); let registry = Arc::new(SessionRegistry::new()); let log_store = Arc::new(LogStore::new()); - Runtime::new(registry, log_store) + Runtime::new(storage, registry, log_store) } fn session_start(participants: Vec) -> Vec { @@ -366,14 +451,17 @@ mod tests { } .encode_to_vec(); let err = rt - .process(&env( - "macp.mode.decision.v1", - "SessionStart", - "m1", - "s1", - "agent://orchestrator", - bad, - )) + .process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + bad, + ), + None, + ) .await 
.unwrap_err(); assert!(matches!( @@ -386,14 +474,17 @@ mod tests { async fn empty_mode_is_rejected() { let rt = make_runtime(); let err = rt - .process(&env( - "", - "SessionStart", - "m1", - "s1", - "agent://orchestrator", - session_start(vec!["agent://fraud".into()]), - )) + .process( + &env( + "", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + session_start(vec!["agent://fraud".into()]), + ), + None, + ) .await .unwrap_err(); assert_eq!(err.to_string(), "InvalidEnvelope"); @@ -402,26 +493,32 @@ mod tests { #[tokio::test] async fn rejected_messages_do_not_enter_dedup_state() { let rt = make_runtime(); - rt.process(&env( - "macp.mode.decision.v1", - "SessionStart", - "m1", - "s1", - "agent://orchestrator", - session_start(vec!["agent://fraud".into()]), - )) + rt.process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + session_start(vec!["agent://fraud".into()]), + ), + None, + ) .await .unwrap(); let bad = rt - .process(&env( - "macp.mode.decision.v1", - "Proposal", - "m2", - "s1", - "agent://fraud", - b"not-protobuf".to_vec(), - )) + .process( + &env( + "macp.mode.decision.v1", + "Proposal", + "m2", + "s1", + "agent://fraud", + b"not-protobuf".to_vec(), + ), + None, + ) .await .unwrap_err(); assert_eq!(bad.to_string(), "InvalidPayload"); @@ -434,14 +531,17 @@ mod tests { } .encode_to_vec(); let result = rt - .process(&env( - "macp.mode.decision.v1", - "Proposal", - "m2", - "s1", - "agent://orchestrator", - good, - )) + .process( + &env( + "macp.mode.decision.v1", + "Proposal", + "m2", + "s1", + "agent://orchestrator", + good, + ), + None, + ) .await .unwrap(); assert!(!result.duplicate); @@ -461,14 +561,17 @@ mod tests { roots: vec![], } .encode_to_vec(); - rt.process(&env( - "macp.mode.decision.v1", - "SessionStart", - "m1", - "s1", - "agent://orchestrator", - payload, - )) + rt.process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + payload, + ), + 
None, + ) .await .unwrap(); tokio::time::sleep(std::time::Duration::from_millis(5)).await; @@ -484,14 +587,17 @@ mod tests { ..Default::default() } .encode_to_vec(); - rt.process(&env( - "macp.mode.multi_round.v1", - "SessionStart", - "m1", - "s1", - "creator", - payload, - )) + rt.process( + &env( + "macp.mode.multi_round.v1", + "SessionStart", + "m1", + "s1", + "creator", + payload, + ), + None, + ) .await .unwrap(); let session = rt.get_session_checked("s1").await.unwrap(); @@ -502,29 +608,78 @@ mod tests { async fn duplicate_session_start_message_id_returns_duplicate() { let rt = make_runtime(); let payload = session_start(vec!["agent://fraud".into()]); - rt.process(&env( - "macp.mode.decision.v1", - "SessionStart", - "m1", - "s1", - "agent://orchestrator", - payload.clone(), - )) + rt.process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + payload.clone(), + ), + None, + ) .await .unwrap(); let result = rt - .process(&env( + .process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + payload, + ), + None, + ) + .await + .unwrap(); + assert!(result.duplicate); + } + + #[tokio::test] + async fn non_start_mode_mismatch_rejected() { + let rt = make_runtime(); + // Start a decision session + rt.process( + &env( "macp.mode.decision.v1", "SessionStart", "m1", "s1", "agent://orchestrator", - payload, - )) + session_start(vec!["agent://fraud".into()]), + ), + None, + ) + .await + .unwrap(); + + // Send a message with a different mode to the same session + let proposal = ProposalPayload { + proposal_id: "p1".into(), + option: "step-up".into(), + rationale: "risk".into(), + supporting_data: vec![], + } + .encode_to_vec(); + let err = rt + .process( + &env( + "macp.mode.task.v1", // wrong mode + "Proposal", + "m2", + "s1", + "agent://orchestrator", + proposal, + ), + None, + ) .await - .unwrap(); - assert!(result.duplicate); + .unwrap_err(); + assert_eq!(err.to_string(), 
"InvalidEnvelope"); } #[tokio::test] @@ -541,14 +696,17 @@ mod tests { roots: vec![], } .encode_to_vec(); - rt.process(&env( - "macp.mode.decision.v1", - "SessionStart", - "m1", - "s1", - "agent://orchestrator", - payload, - )) + rt.process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + payload, + ), + None, + ) .await .unwrap(); tokio::time::sleep(std::time::Duration::from_millis(5)).await; @@ -559,14 +717,17 @@ mod tests { #[tokio::test] async fn commitment_versions_are_carried_into_resolution() { let rt = make_runtime(); - rt.process(&env( - "macp.mode.proposal.v1", - "SessionStart", - "m1", - "s1", - "agent://buyer", - session_start(vec!["agent://buyer".into(), "agent://seller".into()]), - )) + rt.process( + &env( + "macp.mode.proposal.v1", + "SessionStart", + "m1", + "s1", + "agent://buyer", + session_start(vec!["agent://buyer".into(), "agent://seller".into()]), + ), + None, + ) .await .unwrap(); @@ -578,14 +739,17 @@ mod tests { tags: vec![], } .encode_to_vec(); - rt.process(&env( - "macp.mode.proposal.v1", - "Proposal", - "m2", - "s1", - "agent://seller", - proposal, - )) + rt.process( + &env( + "macp.mode.proposal.v1", + "Proposal", + "m2", + "s1", + "agent://seller", + proposal, + ), + None, + ) .await .unwrap(); let accept = crate::proposal_pb::AcceptPayload { @@ -593,24 +757,30 @@ mod tests { reason: String::new(), } .encode_to_vec(); - rt.process(&env( - "macp.mode.proposal.v1", - "Accept", - "m3", - "s1", - "agent://seller", - accept.clone(), - )) + rt.process( + &env( + "macp.mode.proposal.v1", + "Accept", + "m3", + "s1", + "agent://seller", + accept.clone(), + ), + None, + ) .await .unwrap(); - rt.process(&env( - "macp.mode.proposal.v1", - "Accept", - "m4", - "s1", - "agent://buyer", - accept, - )) + rt.process( + &env( + "macp.mode.proposal.v1", + "Accept", + "m4", + "s1", + "agent://buyer", + accept, + ), + None, + ) .await .unwrap(); let commitment = CommitmentPayload { @@ -624,16 +794,70 @@ mod 
tests { } .encode_to_vec(); let result = rt - .process(&env( - "macp.mode.proposal.v1", - "Commitment", - "m5", - "s1", - "agent://buyer", - commitment, - )) + .process( + &env( + "macp.mode.proposal.v1", + "Commitment", + "m5", + "s1", + "agent://buyer", + commitment, + ), + None, + ) .await .unwrap(); assert_eq!(result.session_state, SessionState::Resolved); } + + #[tokio::test] + async fn max_open_sessions_enforced_under_write_lock() { + let rt = make_runtime(); + // First session succeeds with limit=1 + rt.process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m1", + "s1", + "agent://orchestrator", + session_start(vec!["agent://fraud".into()]), + ), + Some(1), + ) + .await + .unwrap(); + + // Second session from the same sender should fail with RateLimited + let err = rt + .process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m2", + "s2", + "agent://orchestrator", + session_start(vec!["agent://fraud".into()]), + ), + Some(1), + ) + .await + .unwrap_err(); + assert!(matches!(err, MacpError::RateLimited)); + + // A different sender should still succeed + rt.process( + &env( + "macp.mode.decision.v1", + "SessionStart", + "m3", + "s3", + "agent://other", + session_start(vec!["agent://fraud".into()]), + ), + Some(1), + ) + .await + .unwrap(); + } } diff --git a/src/security.rs b/src/security.rs index b8a9c8e..6ef95e2 100644 --- a/src/security.rs +++ b/src/security.rs @@ -53,7 +53,6 @@ struct RateBucket { pub struct SecurityLayer { identities: Arc>, rate_bucket: Arc, - auth_required: bool, allow_dev_sender_header: bool, pub max_payload_bytes: usize, session_start_rate: RateLimitConfig, @@ -65,7 +64,6 @@ impl SecurityLayer { Self { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: true, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -113,7 +111,6 @@ impl SecurityLayer { None }; - let auth_required = raw.is_some() || !allow_dev_sender_header; 
let identities = raw .map(|json| Self::parse_identities(&json)) .transpose()? @@ -122,7 +119,6 @@ impl SecurityLayer { Ok(Self { identities: Arc::new(identities), rate_bucket: Arc::new(RateBucket::default()), - auth_required, allow_dev_sender_header, max_payload_bytes, session_start_rate, @@ -194,16 +190,7 @@ impl SecurityLayer { } } - if self.auth_required { - Err(MacpError::Unauthenticated) - } else { - Ok(AuthIdentity { - sender: "agent://anonymous".into(), - allowed_modes: None, - can_start_sessions: true, - max_open_sessions: None, - }) - } + Err(MacpError::Unauthenticated) } pub fn authorize_mode( @@ -277,7 +264,6 @@ mod tests { SecurityLayer { identities: Arc::new(identities), rate_bucket: Arc::new(RateBucket::default()), - auth_required: true, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -296,7 +282,6 @@ mod tests { SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -315,15 +300,11 @@ mod tests { // --------------------------------------------------------------- #[test] - fn dev_mode_does_not_require_auth() { + fn dev_mode_requires_dev_header() { let layer = SecurityLayer::dev_mode(); let meta = MetadataMap::new(); - let id = layer - .authenticate_metadata(&meta) - .expect("should succeed without auth"); - assert_eq!(id.sender, "agent://anonymous"); - assert!(id.allowed_modes.is_none()); - assert!(id.can_start_sessions); + let err = layer.authenticate_metadata(&meta).unwrap_err(); + assert!(matches!(err, MacpError::Unauthenticated)); } #[test] @@ -348,16 +329,8 @@ mod tests { #[test] fn from_env_defaults_without_env_vars() { - // from_env reads live env vars, so we verify the code path indirectly: - // When no token JSON/file is set AND allow_dev_sender_header is false, - // auth_required = (!allow_dev_sender_header) = true. 
- // But if no tokens AND no dev header => auth_required = true. - // - // We can verify that default max_payload, rate limits, etc. are sane - // by constructing through from_env in a controlled subprocess, but that - // is fragile. Instead we test the exact same logic through direct construction. + // Verify default configuration via direct construction. let layer = insecure_layer(); - assert!(!layer.auth_required); assert_eq!(layer.max_payload_bytes, 1_048_576); } @@ -481,7 +454,6 @@ mod tests { let layer = SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: true, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -506,11 +478,10 @@ mod tests { #[test] fn dev_sender_header_ignored_when_not_allowed() { - // auth_required=true, allow_dev_sender_header=false, no tokens + // allow_dev_sender_header=false, no tokens let layer = SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: true, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -538,7 +509,6 @@ mod tests { let layer = SecurityLayer { identities: Arc::new(identities), rate_bucket: Arc::new(RateBucket::default()), - auth_required: true, allow_dev_sender_header: true, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -672,7 +642,6 @@ mod tests { let layer = SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -703,7 +672,6 @@ mod tests { let layer = SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -731,7 +699,6 @@ mod 
tests { let layer = SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -757,7 +724,6 @@ mod tests { let layer = SecurityLayer { identities: Arc::new(HashMap::new()), rate_bucket: Arc::new(RateBucket::default()), - auth_required: false, allow_dev_sender_header: false, max_payload_bytes: 1_048_576, session_start_rate: RateLimitConfig { @@ -785,22 +751,17 @@ mod tests { // --------------------------------------------------------------- #[test] - fn anonymous_fallback_when_no_auth_required() { + fn no_anonymous_fallback_even_when_auth_not_required() { let layer = insecure_layer(); let meta = MetadataMap::new(); - let id = layer - .authenticate_metadata(&meta) - .expect("should return anonymous"); - assert_eq!(id.sender, "agent://anonymous"); - assert!(id.allowed_modes.is_none()); - assert!(id.can_start_sessions); - assert!(id.max_open_sessions.is_none()); + let err = layer.authenticate_metadata(&meta).unwrap_err(); + assert!(matches!(err, MacpError::Unauthenticated)); } #[test] fn no_anonymous_fallback_when_auth_required() { let json = r#"[{"token":"t","sender":"agent://real"}]"#; - let layer = layer_with_tokens(json); // auth_required = true + let layer = layer_with_tokens(json); let meta = MetadataMap::new(); let err = layer.authenticate_metadata(&meta).unwrap_err(); @@ -808,13 +769,13 @@ mod tests { } #[test] - fn dev_mode_anonymous_fallback_with_empty_metadata() { - // dev_mode: auth_required=false, allow_dev_sender_header=true - // With no headers at all, falls through to anonymous + fn dev_mode_no_fallback_with_empty_metadata() { + // dev_mode: allow_dev_sender_header=true + // With no headers at all, returns Unauthenticated (no anonymous fallback) let layer = SecurityLayer::dev_mode(); let meta = MetadataMap::new(); - let id = layer.authenticate_metadata(&meta).expect("should succeed"); - 
assert_eq!(id.sender, "agent://anonymous"); + let err = layer.authenticate_metadata(&meta).unwrap_err(); + assert!(matches!(err, MacpError::Unauthenticated)); } // --------------------------------------------------------------- diff --git a/src/server.rs b/src/server.rs index f01bc9e..aa64c99 100644 --- a/src/server.rs +++ b/src/server.rs @@ -88,7 +88,7 @@ impl MacpServer { &self, request: &Request, env: Envelope, - ) -> Result { + ) -> Result<(Envelope, Option), MacpError> { let identity = self.security.authenticate_metadata(request.metadata())?; let env = Self::apply_authenticated_sender(&identity, env)?; let is_session_start = env.message_type == "SessionStart"; @@ -97,20 +97,15 @@ impl MacpServer { self.security .enforce_rate_limit(&identity.sender, is_session_start) .await?; - if is_session_start { - if let Some(max_open) = identity.max_open_sessions { - if self - .runtime - .registry - .count_open_sessions_for_initiator(&identity.sender) - .await - >= max_open - { - return Err(MacpError::RateLimited); - } - } - } - Ok(env) + // max_open_sessions is passed to runtime.process() where it is + // enforced atomically under the session write lock, avoiding a + // TOCTOU race between the count check and session insertion. 
+ let max_open = if is_session_start { + identity.max_open_sessions + } else { + None + }; + Ok((env, max_open)) } async fn authenticate_session_access( @@ -202,9 +197,9 @@ impl MacpRuntimeService for MacpServer { let result = async { self.validate_envelope_shape(&env)?; - let env = self.authenticate_send_request(&request, env).await?; + let (env, max_open) = self.authenticate_send_request(&request, env).await?; self.runtime - .process(&env) + .process(&env, max_open) .await .map(|process_result| (env, process_result)) } @@ -262,9 +257,20 @@ impl MacpRuntimeService for MacpServer { request: Request, ) -> Result, Status> { let session_id = request.get_ref().session_id.clone(); - let _identity = self - .authenticate_session_access(&request, &session_id) - .await?; + let identity = self + .security + .authenticate_metadata(request.metadata()) + .map_err(Self::status_from_error)?; + let session = self + .runtime + .get_session_checked(&session_id) + .await + .ok_or_else(|| Status::not_found(format!("Session '{}' not found", session_id)))?; + if identity.sender != session.initiator_sender { + return Err(Status::permission_denied( + "FORBIDDEN: only the session initiator can cancel", + )); + } let req = request.into_inner(); match self .runtime @@ -392,9 +398,11 @@ mod tests { use prost::Message; fn make_server() -> (MacpServer, Arc) { + let storage: Arc = + Arc::new(macp_runtime::storage::MemoryBackend); let registry = Arc::new(SessionRegistry::new()); let log_store = Arc::new(LogStore::new()); - let runtime = Arc::new(Runtime::new(registry, log_store)); + let runtime = Arc::new(Runtime::new(storage, registry, log_store)); let server = MacpServer::new(runtime.clone(), SecurityLayer::dev_mode()); (server, runtime) } @@ -591,6 +599,37 @@ mod tests { assert_eq!(ack.session_state, PbSessionState::Expired as i32); } + #[tokio::test] + async fn participant_cannot_cancel_session() { + let (server, _) = make_server(); + let ack = do_send( + &server, + "agent://orchestrator", + 
Envelope { + macp_version: "1.0".into(), + mode: "macp.mode.decision.v1".into(), + message_type: "SessionStart".into(), + message_id: "m1".into(), + session_id: "s1".into(), + sender: String::new(), + timestamp_unix_ms: Utc::now().timestamp_millis(), + payload: start_payload(), + }, + ) + .await; + assert!(ack.ok); + + // Participant (not initiator) tries to cancel + let mut req = Request::new(CancelSessionRequest { + session_id: "s1".into(), + reason: "I want to cancel".into(), + }); + req.metadata_mut() + .insert("x-macp-agent-id", "agent://fraud".parse().unwrap()); + let err = server.cancel_session(req).await.unwrap_err(); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + #[tokio::test] async fn cancel_session_unknown_session_returns_error() { let (server, _) = make_server(); diff --git a/src/session.rs b/src/session.rs index 461f03c..8bfa5ca 100644 --- a/src/session.rs +++ b/src/session.rs @@ -17,6 +17,7 @@ pub struct Session { pub session_id: String, pub state: SessionState, pub ttl_expiry: i64, + pub ttl_ms: i64, pub started_at_unix_ms: i64, pub resolution: Option>, pub mode: String, diff --git a/src/storage.rs b/src/storage.rs new file mode 100644 index 0000000..f411cda --- /dev/null +++ b/src/storage.rs @@ -0,0 +1,656 @@ +use crate::log_store::LogEntry; +use crate::registry::PersistedSession; +use crate::session::Session; +use std::collections::HashMap; +use std::fs; +use std::io::{self, BufRead, Write}; +use std::path::{Path, PathBuf}; + +const STORAGE_VERSION: u32 = 2; + +// --------------------------------------------------------------------------- +// StorageBackend trait +// --------------------------------------------------------------------------- + +#[async_trait::async_trait] +pub trait StorageBackend: Send + Sync { + async fn save_session(&self, session: &Session) -> io::Result<()>; + async fn load_session(&self, session_id: &str) -> io::Result>; + async fn load_all_sessions(&self) -> io::Result>; + async fn append_log_entry(&self, 
session_id: &str, entry: &LogEntry) -> io::Result<()>; + async fn load_log(&self, session_id: &str) -> io::Result>; + async fn create_session_storage(&self, session_id: &str) -> io::Result<()>; +} + +// --------------------------------------------------------------------------- +// MemoryBackend — used for MACP_MEMORY_ONLY=1 and tests +// --------------------------------------------------------------------------- + +pub struct MemoryBackend; + +#[async_trait::async_trait] +impl StorageBackend for MemoryBackend { + async fn save_session(&self, _session: &Session) -> io::Result<()> { + Ok(()) + } + + async fn load_session(&self, _session_id: &str) -> io::Result> { + Ok(None) + } + + async fn load_all_sessions(&self) -> io::Result> { + Ok(vec![]) + } + + async fn append_log_entry(&self, _session_id: &str, _entry: &LogEntry) -> io::Result<()> { + Ok(()) + } + + async fn load_log(&self, _session_id: &str) -> io::Result> { + Ok(vec![]) + } + + async fn create_session_storage(&self, _session_id: &str) -> io::Result<()> { + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// FileBackend — per-session directory structure with append-only JSONL logs +// --------------------------------------------------------------------------- + +pub struct FileBackend { + base_dir: PathBuf, +} + +impl FileBackend { + pub fn new(base_dir: PathBuf) -> io::Result { + fs::create_dir_all(base_dir.join("sessions"))?; + Ok(Self { base_dir }) + } + + fn session_dir(&self, session_id: &str) -> PathBuf { + self.base_dir.join("sessions").join(session_id) + } + + fn session_file(&self, session_id: &str) -> PathBuf { + self.session_dir(session_id).join("session.json") + } + + fn log_file(&self, session_id: &str) -> PathBuf { + self.session_dir(session_id).join("log.jsonl") + } + + fn atomic_write(path: &Path, data: &[u8]) -> io::Result<()> { + let tmp_path = path.with_extension("json.tmp"); + fs::write(&tmp_path, data)?; + fs::rename(&tmp_path, path) + } 
+} + +#[async_trait::async_trait] +impl StorageBackend for FileBackend { + async fn create_session_storage(&self, session_id: &str) -> io::Result<()> { + fs::create_dir_all(self.session_dir(session_id)) + } + + async fn save_session(&self, session: &Session) -> io::Result<()> { + let persisted = PersistedSession::from(session); + let bytes = serde_json::to_vec_pretty(&persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Self::atomic_write(&self.session_file(&session.session_id), &bytes) + } + + async fn load_session(&self, session_id: &str) -> io::Result> { + let path = self.session_file(session_id); + if !path.exists() { + return Ok(None); + } + let bytes = fs::read(&path)?; + let persisted: PersistedSession = serde_json::from_slice(&bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(Session::from(persisted))) + } + + async fn load_all_sessions(&self) -> io::Result> { + let sessions_dir = self.base_dir.join("sessions"); + if !sessions_dir.exists() { + return Ok(vec![]); + } + let mut sessions = Vec::new(); + for entry in fs::read_dir(&sessions_dir)? 
{ + let entry = entry?; + if !entry.file_type()?.is_dir() { + continue; + } + let session_file = entry.path().join("session.json"); + if !session_file.exists() { + continue; + } + let bytes = fs::read(&session_file)?; + match serde_json::from_slice::(&bytes) { + Ok(persisted) => sessions.push(Session::from(persisted)), + Err(e) => { + eprintln!( + "warning: failed to deserialize session from {}: {e}; skipping", + session_file.display() + ); + } + } + } + Ok(sessions) + } + + async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()> { + let path = self.log_file(session_id); + let mut line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + line.push('\n'); + + let mut file = fs::OpenOptions::new() + .create(true) + .append(true) + .open(&path)?; + file.write_all(line.as_bytes())?; + file.sync_data()?; + Ok(()) + } + + async fn load_log(&self, session_id: &str) -> io::Result> { + let path = self.log_file(session_id); + if !path.exists() { + return Ok(vec![]); + } + let file = fs::File::open(&path)?; + let reader = io::BufReader::new(file); + let mut entries = Vec::new(); + for (line_num, line) in reader.lines().enumerate() { + let line = line?; + if line.trim().is_empty() { + continue; + } + match serde_json::from_str::(&line) { + Ok(entry) => entries.push(entry), + Err(e) => { + eprintln!( + "warning: failed to parse log entry at {}:{}: {e}; skipping", + path.display(), + line_num + 1 + ); + } + } + } + Ok(entries) + } +} + +// --------------------------------------------------------------------------- +// storage_version.json +// --------------------------------------------------------------------------- + +#[derive(serde::Serialize, serde::Deserialize)] +struct StorageVersion { + version: u32, +} + +pub fn write_storage_version(base_dir: &Path) -> io::Result<()> { + let sv = StorageVersion { + version: STORAGE_VERSION, + }; + let bytes = serde_json::to_vec_pretty(&sv) + .map_err(|e| 
io::Error::new(io::ErrorKind::InvalidData, e))?; + fs::write(base_dir.join("storage_version.json"), bytes) +} + +pub fn read_storage_version(base_dir: &Path) -> io::Result> { + let path = base_dir.join("storage_version.json"); + if !path.exists() { + return Ok(None); + } + let bytes = fs::read(&path)?; + let sv: StorageVersion = serde_json::from_slice(&bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(sv.version)) +} + +// --------------------------------------------------------------------------- +// Migration from legacy monolithic format (v1 → v2) +// --------------------------------------------------------------------------- + +pub fn migrate_if_needed(base_dir: &Path) -> io::Result<()> { + let sessions_dir = base_dir.join("sessions"); + let legacy_sessions = base_dir.join("sessions.json"); + let legacy_logs = base_dir.join("logs.json"); + + // Already migrated or fresh install + if sessions_dir.exists() || (!legacy_sessions.exists() && !legacy_logs.exists()) { + write_storage_version(base_dir)?; + return Ok(()); + } + + println!("Migrating legacy storage format to per-session directories..."); + + // Load legacy sessions + let sessions: HashMap = if legacy_sessions.exists() { + let bytes = fs::read(&legacy_sessions)?; + serde_json::from_slice(&bytes).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to parse legacy sessions.json: {e}"), + ) + })? + } else { + HashMap::new() + }; + + // Load legacy logs + let logs: HashMap> = if legacy_logs.exists() { + let bytes = fs::read(&legacy_logs)?; + serde_json::from_slice(&bytes).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to parse legacy logs.json: {e}"), + ) + })? 
+ } else { + HashMap::new() + }; + + fs::create_dir_all(&sessions_dir)?; + + // Migrate each session + for (session_id, persisted) in &sessions { + let dir = sessions_dir.join(session_id); + fs::create_dir_all(&dir)?; + + // Write session.json + let session_bytes = serde_json::to_vec_pretty(persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + fs::write(dir.join("session.json"), session_bytes)?; + + // Write log.jsonl + if let Some(entries) = logs.get(session_id) { + let mut log_data = String::new(); + for entry in entries { + let line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + log_data.push_str(&line); + log_data.push('\n'); + } + fs::write(dir.join("log.jsonl"), log_data)?; + } + } + + // Also migrate logs for sessions that only appear in logs (not in sessions.json) + for (session_id, entries) in &logs { + if sessions.contains_key(session_id) { + continue; + } + let dir = sessions_dir.join(session_id); + fs::create_dir_all(&dir)?; + let mut log_data = String::new(); + for entry in entries { + let line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + log_data.push_str(&line); + log_data.push('\n'); + } + fs::write(dir.join("log.jsonl"), log_data)?; + } + + // Backup old files instead of deleting + if legacy_sessions.exists() { + fs::rename(&legacy_sessions, base_dir.join("sessions.json.migrated"))?; + } + if legacy_logs.exists() { + fs::rename(&legacy_logs, base_dir.join("logs.json.migrated"))?; + } + + write_storage_version(base_dir)?; + println!( + "Migration complete: {} sessions, {} log sets migrated.", + sessions.len(), + logs.len() + ); + Ok(()) +} + +// --------------------------------------------------------------------------- +// Crash recovery +// --------------------------------------------------------------------------- + +pub fn recover_session(session: &mut Session, log_entries: &[LogEntry]) { + // Ensure all log entry message 
IDs are in the session's dedup set. + // If the runtime crashed after writing a log entry but before persisting + // the session snapshot, there will be entries in the log not reflected + // in seen_message_ids. + let mut recovered = 0usize; + for entry in log_entries { + if !entry.message_id.is_empty() && session.seen_message_ids.insert(entry.message_id.clone()) + { + recovered += 1; + } + } + if recovered > 0 { + eprintln!( + "recovery: session '{}' reconciled {} log entries into dedup state", + session.session_id, recovered + ); + } +} + +pub fn cleanup_temp_files(base_dir: &Path) { + let sessions_dir = base_dir.join("sessions"); + if !sessions_dir.exists() { + return; + } + if let Ok(entries) = fs::read_dir(&sessions_dir) { + for entry in entries.flatten() { + if !entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) { + continue; + } + let dir = entry.path(); + if let Ok(files) = fs::read_dir(&dir) { + for file in files.flatten() { + let path = file.path(); + if path.extension().and_then(|e| e.to_str()) == Some("tmp") { + eprintln!("recovery: removing orphaned temp file {}", path.display()); + let _ = fs::remove_file(&path); + } + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::log_store::EntryKind; + use crate::session::SessionState; + use std::collections::HashSet; + + fn sample_session(id: &str) -> Session { + Session { + session_id: id.into(), + state: SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + mode_state: vec![1, 2, 3], + participants: vec!["alice".into(), "bob".into()], + seen_message_ids: HashSet::from(["m1".into()]), + intent: "test intent".into(), + mode_version: "1.0.0".into(), + configuration_version: "cfg-1".into(), + policy_version: 
"pol-1".into(), + context: vec![9], + roots: vec![crate::pb::Root { + uri: "root://1".into(), + name: "r1".into(), + }], + initiator_sender: "alice".into(), + } + } + + fn sample_entry(id: &str) -> LogEntry { + LogEntry { + message_id: id.into(), + received_at_ms: 1_700_000_000_000, + sender: "alice".into(), + message_type: "Message".into(), + raw_payload: vec![], + entry_kind: EntryKind::Incoming, + } + } + + #[tokio::test] + async fn file_backend_session_round_trip() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + let session = sample_session("s1"); + backend.create_session_storage("s1").await.unwrap(); + backend.save_session(&session).await.unwrap(); + + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.session_id, "s1"); + assert_eq!(loaded.ttl_ms, 60_000); + assert_eq!(loaded.mode_version, "1.0.0"); + assert!(loaded.seen_message_ids.contains("m1")); + assert_eq!(loaded.participants, vec!["alice", "bob"]); + } + + #[tokio::test] + async fn file_backend_log_append_and_load() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + backend.create_session_storage("s1").await.unwrap(); + backend + .append_log_entry("s1", &sample_entry("m1")) + .await + .unwrap(); + backend + .append_log_entry("s1", &sample_entry("m2")) + .await + .unwrap(); + backend + .append_log_entry("s1", &sample_entry("m3")) + .await + .unwrap(); + + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 3); + assert_eq!(log[0].message_id, "m1"); + assert_eq!(log[1].message_id, "m2"); + assert_eq!(log[2].message_id, "m3"); + } + + #[tokio::test] + async fn file_backend_load_all_sessions() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + for id in ["s1", "s2", "s3"] { + backend.create_session_storage(id).await.unwrap(); + 
backend.save_session(&sample_session(id)).await.unwrap(); + } + + let mut sessions = backend.load_all_sessions().await.unwrap(); + sessions.sort_by(|a, b| a.session_id.cmp(&b.session_id)); + assert_eq!(sessions.len(), 3); + assert_eq!(sessions[0].session_id, "s1"); + assert_eq!(sessions[1].session_id, "s2"); + assert_eq!(sessions[2].session_id, "s3"); + } + + #[tokio::test] + async fn memory_backend_is_noop() { + let backend = MemoryBackend; + backend.create_session_storage("s1").await.unwrap(); + backend.save_session(&sample_session("s1")).await.unwrap(); + assert!(backend.load_session("s1").await.unwrap().is_none()); + assert!(backend.load_all_sessions().await.unwrap().is_empty()); + } + + #[tokio::test] + async fn append_only_no_full_rewrite() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + // Append 100 entries + for i in 0..100 { + backend + .append_log_entry("s1", &sample_entry(&format!("m{}", i))) + .await + .unwrap(); + } + + // Verify file has 100 lines + let content = fs::read_to_string(backend.log_file("s1")).unwrap(); + let line_count = content.lines().count(); + assert_eq!(line_count, 100); + + // Verify all entries are loadable + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 100); + } + + #[test] + fn crash_recovery_reconciles_dedup_state() { + let mut session = sample_session("s1"); + // session has "m1" in seen_message_ids + assert!(session.seen_message_ids.contains("m1")); + assert!(!session.seen_message_ids.contains("m2")); + assert!(!session.seen_message_ids.contains("m3")); + + let entries = vec![ + sample_entry("m1"), // already in dedup set + sample_entry("m2"), // missing from dedup set + sample_entry("m3"), // missing from dedup set + ]; + + recover_session(&mut session, &entries); + + assert!(session.seen_message_ids.contains("m1")); + assert!(session.seen_message_ids.contains("m2")); + 
assert!(session.seen_message_ids.contains("m3")); + } + + #[tokio::test] + async fn migration_from_legacy_format() { + let dir = tempfile::tempdir().unwrap(); + let base = dir.path(); + + // Create legacy format files + let session = sample_session("s1"); + let persisted = PersistedSession::from(&session); + let sessions_map: HashMap = + [("s1".into(), persisted)].into_iter().collect(); + fs::write( + base.join("sessions.json"), + serde_json::to_vec_pretty(&sessions_map).unwrap(), + ) + .unwrap(); + + let entries = vec![sample_entry("m1"), sample_entry("m2")]; + let logs_map: HashMap> = + [("s1".into(), entries)].into_iter().collect(); + fs::write( + base.join("logs.json"), + serde_json::to_vec_pretty(&logs_map).unwrap(), + ) + .unwrap(); + + // Run migration + migrate_if_needed(base).unwrap(); + + // Verify per-session directories created + assert!(base.join("sessions/s1/session.json").exists()); + assert!(base.join("sessions/s1/log.jsonl").exists()); + + // Verify old files renamed + assert!(base.join("sessions.json.migrated").exists()); + assert!(base.join("logs.json.migrated").exists()); + assert!(!base.join("sessions.json").exists()); + assert!(!base.join("logs.json").exists()); + + // Verify storage version + assert_eq!(read_storage_version(base).unwrap(), Some(2)); + + // Verify data is loadable via FileBackend + let backend = FileBackend::new(base.to_path_buf()).unwrap(); + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.session_id, "s1"); + assert_eq!(loaded.ttl_ms, 60_000); + + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 2); + } + + #[tokio::test] + async fn ttl_ms_backward_compat_deserialization() { + let dir = tempfile::tempdir().unwrap(); + let base = dir.path(); + let backend = FileBackend::new(base.to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + // Write a session JSON without ttl_ms (simulating old format) + let json = serde_json::json!({ + 
"session_id": "s1", + "state": "Open", + "ttl_expiry": 61000, + "started_at_unix_ms": 1000, + "resolution": null, + "mode": "macp.mode.decision.v1", + "mode_state": [], + "participants": ["alice"], + "seen_message_ids": [], + "intent": "", + "mode_version": "1.0.0", + "configuration_version": "cfg", + "policy_version": "pol", + "context": [], + "roots": [], + "initiator_sender": "alice" + }); + fs::write( + backend.session_file("s1"), + serde_json::to_vec_pretty(&json).unwrap(), + ) + .unwrap(); + + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + // ttl_ms should be computed from ttl_expiry - started_at_unix_ms + assert_eq!(loaded.ttl_ms, 60_000); + } + + #[test] + fn cleanup_temp_files_removes_orphans() { + let dir = tempfile::tempdir().unwrap(); + let base = dir.path(); + let sessions_dir = base.join("sessions").join("s1"); + fs::create_dir_all(&sessions_dir).unwrap(); + + // Create an orphaned temp file + fs::write(sessions_dir.join("session.json.tmp"), b"partial").unwrap(); + assert!(sessions_dir.join("session.json.tmp").exists()); + + cleanup_temp_files(base); + + assert!(!sessions_dir.join("session.json.tmp").exists()); + } + + #[tokio::test] + async fn write_ordering_log_before_session() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + // Write log entry (commit point) + backend + .append_log_entry("s1", &sample_entry("m1")) + .await + .unwrap(); + + // Verify log is durable even without session save + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 1); + assert_eq!(log[0].message_id, "m1"); + + // Session load should return None (not yet saved) + assert!(backend.load_session("s1").await.unwrap().is_none()); + + // Now save session + backend.save_session(&sample_session("s1")).await.unwrap(); + assert!(backend.load_session("s1").await.unwrap().is_some()); + } +}