diff --git a/Cargo.toml b/Cargo.toml index cafd299..b2102ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,11 @@ name = "macp-runtime" version = "0.4.0" edition = "2021" +[features] +default = [] +rocksdb-backend = ["dep:rocksdb"] +redis-backend = ["dep:redis"] + [dependencies] tokio = { version = "1", features = ["full"] } tonic = { version = "0.14", features = ["transport", "tls-ring"] } @@ -20,6 +25,8 @@ async-trait = "0.1" uuid = { version = "1", features = ["v4"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +rocksdb = { version = "0.22", optional = true } +redis = { version = "0.27", features = ["tokio-comp", "aio"], optional = true } [dev-dependencies] tempfile = "3" diff --git a/src/log_store.rs b/src/log_store.rs index 7eb2579..b1cba1c 100644 --- a/src/log_store.rs +++ b/src/log_store.rs @@ -5,6 +5,7 @@ use tokio::sync::RwLock; pub enum EntryKind { Incoming, Internal, + Checkpoint, } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] diff --git a/src/main.rs b/src/main.rs index 31a9f90..0366c83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,13 +34,36 @@ async fn main() -> Result<(), Box> { PathBuf::from(std::env::var("MACP_DATA_DIR").unwrap_or_else(|_| ".macp-data".into())); let strict_recovery = std::env::var("MACP_STRICT_RECOVERY").ok().as_deref() == Some("1"); + let backend_name = std::env::var("MACP_STORAGE_BACKEND").unwrap_or_else(|_| "file".into()); let storage: Arc = if memory_only { Arc::new(MemoryBackend) } else { - std::fs::create_dir_all(&data_dir)?; - migrate_if_needed(&data_dir)?; - cleanup_temp_files(&data_dir); - Arc::new(FileBackend::new(data_dir.clone())?) + match backend_name.as_str() { + "file" => { + std::fs::create_dir_all(&data_dir)?; + migrate_if_needed(&data_dir)?; + cleanup_temp_files(&data_dir); + Arc::new(FileBackend::new(data_dir.clone())?) 
+ } + #[cfg(feature = "rocksdb-backend")] + "rocksdb" => { + let path = std::env::var("MACP_ROCKSDB_PATH") + .unwrap_or_else(|_| data_dir.join("rocksdb").to_string_lossy().to_string()); + Arc::new(macp_runtime::storage::RocksDbBackend::open(&path)?) + } + #[cfg(feature = "redis-backend")] + "redis" => { + let url = std::env::var("MACP_REDIS_URL") + .unwrap_or_else(|_| "redis://127.0.0.1:6379".into()); + Arc::new(macp_runtime::storage::RedisBackend::connect(&url, "macp").await?) + } + other => { + return Err(format!( + "unknown storage backend: {other}. Valid: file, rocksdb, redis" + ) + .into()); + } + } }; // Load persisted state into in-memory caches @@ -49,76 +72,69 @@ async fn main() -> Result<(), Box> { let mode_registry = Arc::new(ModeRegistry::build_default()); if !memory_only { - // Enumerate session directories and replay from logs - let sessions_dir = data_dir.join("sessions"); + // Replay sessions from logs + let session_ids = storage.list_session_ids().await?; let mut recovered = 0usize; - if tokio::fs::metadata(&sessions_dir).await.is_ok() { - let mut entries = tokio::fs::read_dir(&sessions_dir).await?; - while let Some(entry) = entries.next_entry().await? 
{ - if !entry.file_type().await?.is_dir() { - continue; + for session_id in session_ids { + let log_entries = match storage.load_log(&session_id).await { + Ok(entries) => entries, + Err(e) if strict_recovery => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("strict recovery: failed to load log for {session_id}: {e}"), + ) + .into()); } - let session_id = entry.file_name().to_string_lossy().to_string(); - let log_entries = match storage.load_log(&session_id).await { - Ok(entries) => entries, - Err(e) if strict_recovery => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("strict recovery: failed to load log for {session_id}: {e}"), - ) - .into()); - } - Err(e) => { - tracing::warn!( - session_id = %session_id, - error = %e, - "failed to load session log; skipping" - ); - continue; - } - }; - if log_entries.is_empty() { + Err(e) => { + tracing::warn!( + session_id = %session_id, + error = %e, + "failed to load session log; skipping" + ); continue; } + }; + if log_entries.is_empty() { + continue; + } - match replay_session(&session_id, &log_entries, &mode_registry) { - Ok(session) => { - if let Err(e) = storage.save_session(&session).await { - if strict_recovery { - return Err(io::Error::other(format!( - "strict recovery: failed to persist recovered session {session_id}: {e}" - )) - .into()); - } - tracing::warn!( - session_id = %session_id, - error = %e, - "failed to persist recovered session" - ); - } - - log_store.create_session_log(&session_id).await; - for log_entry in &log_entries { - log_store.append(&session_id, log_entry.clone()).await; + match replay_session(&session_id, &log_entries, &mode_registry) { + Ok(session) => { + if let Err(e) = storage.save_session(&session).await { + if strict_recovery { + return Err(io::Error::other(format!( + "strict recovery: failed to persist recovered session {session_id}: {e}" + )) + .into()); } - - registry.insert_recovered_session(session_id, session).await; - recovered += 1; - 
} - Err(e) if strict_recovery => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("strict recovery: failed to replay session {session_id}: {e}"), - ) - .into()); - } - Err(e) => { tracing::warn!( session_id = %session_id, error = %e, - "failed to replay session; skipping" + "failed to persist recovered session" ); } + + log_store.create_session_log(&session_id).await; + for log_entry in &log_entries { + log_store.append(&session_id, log_entry.clone()).await; + } + + registry.insert_recovered_session(session_id, session).await; + recovered += 1; + } + Err(e) if strict_recovery => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("strict recovery: failed to replay session {session_id}: {e}"), + ) + .into()); + } + Err(e) => { + tracing::warn!( + session_id = %session_id, + error = %e, + "failed to replay session; skipping" + ); } } } diff --git a/src/replay.rs b/src/replay.rs index aa854a3..3b38bc2 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -2,6 +2,7 @@ use crate::error::MacpError; use crate::log_store::{EntryKind, LogEntry}; use crate::mode_registry::ModeRegistry; use crate::pb::Envelope; +use crate::registry::PersistedSession; use crate::session::{ extract_ttl_ms, parse_session_start_payload, validate_canonical_session_start_payload, Session, SessionState, @@ -9,13 +10,117 @@ use crate::session::{ /// Rebuild a `Session` from its append-only log. /// -/// This replays the same mode callbacks (`on_session_start`, `on_message`) in -/// order so mode state, dedup state, and session lifecycle are reconstructed -/// identically to how they were built during live processing. +/// If the log contains `Checkpoint` entries, replay starts from the last +/// checkpoint (restoring the serialized session state) and only replays +/// subsequent entries. Otherwise, a full replay from `SessionStart` is +/// performed. 
pub fn replay_session( session_id: &str, log_entries: &[LogEntry], registry: &ModeRegistry, +) -> Result { + // Try checkpoint-based fast path first + if let Some(session) = try_replay_from_checkpoint(session_id, log_entries, registry)? { + return Ok(session); + } + + replay_from_start(session_id, log_entries, registry) +} + +/// Attempt to restore from the last checkpoint entry and replay remaining entries. +/// Returns `Ok(None)` if no checkpoint exists. +fn try_replay_from_checkpoint( + session_id: &str, + log_entries: &[LogEntry], + registry: &ModeRegistry, +) -> Result, MacpError> { + let checkpoint_idx = log_entries + .iter() + .rposition(|e| e.entry_kind == EntryKind::Checkpoint); + + let idx = match checkpoint_idx { + Some(idx) => idx, + None => return Ok(None), + }; + + let checkpoint = &log_entries[idx]; + let persisted: PersistedSession = + serde_json::from_slice(&checkpoint.raw_payload).map_err(|_| MacpError::InvalidPayload)?; + let mut session = Session::from(persisted); + session.session_id = session_id.into(); + + let mode = registry + .get_mode(&session.mode) + .ok_or(MacpError::UnknownMode)?; + + // Replay entries after the checkpoint + for entry in &log_entries[idx + 1..] { + replay_entry(&mut session, session_id, entry, &mode)?; + } + + Ok(Some(session)) +} + +/// Replay a single log entry onto a session. 
+fn replay_entry( + session: &mut Session, + session_id: &str, + entry: &LogEntry, + mode: &crate::mode_registry::ModeRef<'_>, +) -> Result<(), MacpError> { + match entry.entry_kind { + EntryKind::Incoming => { + let replay_env = Envelope { + macp_version: if entry.macp_version.is_empty() { + "1.0".into() + } else { + entry.macp_version.clone() + }, + mode: if entry.mode.is_empty() { + session.mode.clone() + } else { + entry.mode.clone() + }, + message_type: entry.message_type.clone(), + message_id: entry.message_id.clone(), + session_id: session_id.into(), + sender: entry.sender.clone(), + timestamp_unix_ms: entry.received_at_ms, + payload: entry.raw_payload.clone(), + }; + + if session.state != SessionState::Open { + if !replay_env.message_id.is_empty() { + session.seen_message_ids.insert(replay_env.message_id); + } + return Ok(()); + } + + mode.authorize_sender(session, &replay_env)?; + let response = mode.on_message(session, &replay_env)?; + session.apply_mode_response(response); + if !replay_env.message_id.is_empty() { + session.seen_message_ids.insert(replay_env.message_id); + } + } + EntryKind::Internal => match entry.message_type.as_str() { + "TtlExpired" | "SessionCancel" => { + session.state = SessionState::Expired; + } + _ => {} + }, + EntryKind::Checkpoint => { + // Skip intermediate checkpoints when replaying from an earlier one + } + } + Ok(()) +} + +/// Full replay from the SessionStart entry. +fn replay_from_start( + session_id: &str, + log_entries: &[LogEntry], + registry: &ModeRegistry, ) -> Result { // 1. Find the SessionStart entry let start_entry = log_entries @@ -98,59 +203,7 @@ pub fn replay_session( // 5. 
Replay subsequent entries for entry in log_entries.iter().skip(1) { - match entry.entry_kind { - EntryKind::Incoming => { - let replay_env = Envelope { - macp_version: if entry.macp_version.is_empty() { - "1.0".into() - } else { - entry.macp_version.clone() - }, - mode: if entry.mode.is_empty() { - session.mode.clone() - } else { - entry.mode.clone() - }, - message_type: entry.message_type.clone(), - message_id: entry.message_id.clone(), - session_id: session_id.into(), - sender: entry.sender.clone(), - timestamp_unix_ms: entry.received_at_ms, - payload: entry.raw_payload.clone(), - }; - - if session.state != SessionState::Open { - // Session already resolved/expired, just rebuild dedup - if !replay_env.message_id.is_empty() { - session.seen_message_ids.insert(replay_env.message_id); - } - continue; - } - - // Replay through the same authorization and mode callbacks used during - // live processing. Accepted history that no longer replays cleanly must - // fail recovery instead of silently drifting session state. 
- mode.authorize_sender(&session, &replay_env)?; - let response = mode.on_message(&session, &replay_env)?; - session.apply_mode_response(response); - if !replay_env.message_id.is_empty() { - session.seen_message_ids.insert(replay_env.message_id); - } - } - EntryKind::Internal => { - match entry.message_type.as_str() { - "TtlExpired" => { - session.state = SessionState::Expired; - } - "SessionCancel" => { - session.state = SessionState::Expired; - } - _ => { - // Unknown internal event — skip - } - } - } - } + replay_entry(&mut session, session_id, entry, &mode)?; } Ok(session) @@ -367,4 +420,92 @@ mod tests { assert_eq!(entry.mode, ""); assert_eq!(entry.macp_version, ""); } + + #[test] + fn replay_from_checkpoint_restores_state() { + use crate::registry::PersistedSession; + + let registry = make_registry(); + + // Build a session via normal replay first + let proposal = ProposalPayload { + proposal_id: "p1".into(), + option: "deploy".into(), + rationale: "ready".into(), + supporting_data: vec![], + } + .encode_to_vec(); + + let full_entries = vec![ + incoming_entry( + "m1", + "SessionStart", + "agent://orchestrator", + start_payload_bytes(), + 1000, + ), + incoming_entry( + "m2", + "Proposal", + "agent://orchestrator", + proposal.clone(), + 2000, + ), + ]; + let full_session = replay_session("s1", &full_entries, ®istry).unwrap(); + + // Create a checkpoint from the replayed session state + let persisted = PersistedSession::from(&full_session); + let checkpoint_payload = serde_json::to_vec(&persisted).unwrap(); + let checkpoint = LogEntry { + message_id: String::new(), + received_at_ms: 3000, + sender: "_runtime".into(), + message_type: "Checkpoint".into(), + raw_payload: checkpoint_payload, + entry_kind: EntryKind::Checkpoint, + session_id: "s1".into(), + mode: "macp.mode.decision.v1".into(), + macp_version: "1.0".into(), + }; + + // A vote after the checkpoint + let vote = VotePayload { + proposal_id: "p1".into(), + vote: "approve".into(), + reason: "lgtm".into(), 
+ } + .encode_to_vec(); + + // Log: SessionStart, Proposal, Checkpoint, Vote + let entries_with_checkpoint = vec![ + full_entries[0].clone(), + full_entries[1].clone(), + checkpoint, + incoming_entry("m3", "Vote", "agent://fraud", vote, 4000), + ]; + + let session = replay_session("s1", &entries_with_checkpoint, ®istry).unwrap(); + assert_eq!(session.state, SessionState::Open); + // Should have dedup from checkpoint (m1, m2) plus newly replayed m3 + assert!(session.seen_message_ids.contains("m1")); + assert!(session.seen_message_ids.contains("m2")); + assert!(session.seen_message_ids.contains("m3")); + } + + #[test] + fn replay_without_checkpoint_still_works() { + // Ensure logs without checkpoints replay correctly (backward compat) + let registry = make_registry(); + let entries = vec![incoming_entry( + "m1", + "SessionStart", + "agent://orchestrator", + start_payload_bytes(), + 1000, + )]; + let session = replay_session("s1", &entries, ®istry).unwrap(); + assert_eq!(session.state, SessionState::Open); + assert!(session.seen_message_ids.contains("m1")); + } } diff --git a/src/runtime.rs b/src/runtime.rs index b82f876..7bb1cc5 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -27,6 +27,7 @@ pub struct Runtime { stream_bus: Arc, mode_registry: Arc, metrics: Arc, + checkpoint_interval: usize, } impl Runtime { @@ -49,6 +50,10 @@ impl Runtime { log_store: Arc, mode_registry: Arc, ) -> Self { + let checkpoint_interval = std::env::var("MACP_CHECKPOINT_INTERVAL") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(0); // 0 = disabled by default Self { storage, registry, @@ -56,6 +61,7 @@ impl Runtime { stream_bus: Arc::new(SessionStreamBus::default()), mode_registry, metrics: Arc::new(RuntimeMetrics::new()), + checkpoint_interval, } } @@ -368,8 +374,13 @@ impl Runtime { tracing::info!(session_id = %env.session_id, mode = %session.mode, "session resolved"); } - // 3. Best-effort session save + // 3. 
Best-effort session save + checkpoint self.save_session_to_storage(session).await; + if result_state == SessionState::Resolved { + self.maybe_compact_log(&env.session_id, session).await; + } else { + self.maybe_insert_checkpoint(&env.session_id, session).await; + } self.publish_accepted_envelope(env); Ok(ProcessResult { @@ -441,6 +452,7 @@ impl Runtime { self.log_store.append(session_id, cancel_entry).await; session.state = SessionState::Expired; self.save_session_to_storage(session).await; + self.maybe_compact_log(session_id, session).await; self.metrics.record_session_cancelled(&session.mode); tracing::info!(session_id, reason, "session cancelled"); @@ -449,6 +461,62 @@ impl Runtime { duplicate: false, }) } + + /// Best-effort log compaction for terminal sessions. + async fn maybe_compact_log(&self, session_id: &str, session: &Session) { + if let Err(e) = + crate::storage::compaction::compact_session_log(&*self.storage, session_id, session) + .await + { + tracing::debug!( + session_id, + error = %e, + "log compaction skipped (backend may not support it)" + ); + } + } + + /// Insert a checkpoint entry if the log has reached the configured interval. 
+ async fn maybe_insert_checkpoint(&self, session_id: &str, session: &Session) { + if self.checkpoint_interval == 0 { + return; + } + let log_len = self + .log_store + .get_log(session_id) + .await + .map(|l| l.len()) + .unwrap_or(0); + // Only checkpoint at interval boundaries, and not on the first entry + if log_len < self.checkpoint_interval || log_len % self.checkpoint_interval != 0 { + return; + } + let persisted = crate::registry::PersistedSession::from(session); + let raw_payload = match serde_json::to_vec(&persisted) { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!(session_id, error = %e, "failed to serialize checkpoint"); + return; + } + }; + let checkpoint = LogEntry { + message_id: String::new(), + received_at_ms: Utc::now().timestamp_millis(), + sender: "_runtime".into(), + message_type: "Checkpoint".into(), + raw_payload, + entry_kind: EntryKind::Checkpoint, + session_id: session_id.into(), + mode: session.mode.clone(), + macp_version: String::new(), + }; + if let Err(e) = self.storage.append_log_entry(session_id, &checkpoint).await { + tracing::warn!(session_id, error = %e, "failed to write checkpoint"); + return; + } + self.log_store.append(session_id, checkpoint).await; + tracing::debug!(session_id, log_len, "checkpoint inserted"); + } } #[cfg(test)] @@ -1031,6 +1099,12 @@ mod tests { async fn load_all_sessions(&self) -> io::Result> { Ok(vec![]) } + async fn delete_session(&self, _: &str) -> io::Result<()> { + Ok(()) + } + async fn list_session_ids(&self) -> io::Result> { + Ok(vec![]) + } async fn append_log_entry(&self, _: &str, _: &LogEntry) -> io::Result<()> { Err(io::Error::other("disk full")) } @@ -1084,6 +1158,12 @@ mod tests { async fn load_all_sessions(&self) -> io::Result> { Ok(vec![]) } + async fn delete_session(&self, _: &str) -> io::Result<()> { + Ok(()) + } + async fn list_session_ids(&self) -> io::Result> { + Ok(vec![]) + } async fn append_log_entry(&self, _: &str, _: &LogEntry) -> io::Result<()> { let n = self.count.fetch_add(1, 
Ordering::SeqCst); if n >= 1 { @@ -1171,6 +1251,12 @@ mod tests { async fn load_all_sessions(&self) -> io::Result> { Ok(vec![]) } + async fn delete_session(&self, _: &str) -> io::Result<()> { + Ok(()) + } + async fn list_session_ids(&self) -> io::Result> { + Ok(vec![]) + } async fn append_log_entry(&self, _: &str, _: &LogEntry) -> io::Result<()> { let n = self.count.fetch_add(1, Ordering::SeqCst); if n >= 1 { diff --git a/src/storage.rs b/src/storage.rs deleted file mode 100644 index 684ccf6..0000000 --- a/src/storage.rs +++ /dev/null @@ -1,676 +0,0 @@ -use crate::log_store::LogEntry; -use crate::registry::PersistedSession; -use crate::session::Session; -use std::collections::HashMap; -use std::fs; -use std::io; -use std::path::{Path, PathBuf}; -use tokio::fs as tfs; -use tokio::io::AsyncWriteExt; - -const STORAGE_VERSION: u32 = 3; - -// --------------------------------------------------------------------------- -// StorageBackend trait -// --------------------------------------------------------------------------- - -#[async_trait::async_trait] -pub trait StorageBackend: Send + Sync { - async fn save_session(&self, session: &Session) -> io::Result<()>; - async fn load_session(&self, session_id: &str) -> io::Result>; - async fn load_all_sessions(&self) -> io::Result>; - async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()>; - async fn load_log(&self, session_id: &str) -> io::Result>; - async fn create_session_storage(&self, session_id: &str) -> io::Result<()>; -} - -// --------------------------------------------------------------------------- -// MemoryBackend — used for MACP_MEMORY_ONLY=1 and tests -// --------------------------------------------------------------------------- - -pub struct MemoryBackend; - -#[async_trait::async_trait] -impl StorageBackend for MemoryBackend { - async fn save_session(&self, _session: &Session) -> io::Result<()> { - Ok(()) - } - - async fn load_session(&self, _session_id: &str) -> io::Result> { - 
Ok(None) - } - - async fn load_all_sessions(&self) -> io::Result> { - Ok(vec![]) - } - - async fn append_log_entry(&self, _session_id: &str, _entry: &LogEntry) -> io::Result<()> { - Ok(()) - } - - async fn load_log(&self, _session_id: &str) -> io::Result> { - Ok(vec![]) - } - - async fn create_session_storage(&self, _session_id: &str) -> io::Result<()> { - Ok(()) - } -} - -// --------------------------------------------------------------------------- -// FileBackend — per-session directory structure with append-only JSONL logs -// --------------------------------------------------------------------------- - -pub struct FileBackend { - base_dir: PathBuf, -} - -impl FileBackend { - pub fn new(base_dir: PathBuf) -> io::Result { - fs::create_dir_all(base_dir.join("sessions"))?; - Ok(Self { base_dir }) - } - - fn session_dir(&self, session_id: &str) -> PathBuf { - self.base_dir.join("sessions").join(session_id) - } - - fn session_file(&self, session_id: &str) -> PathBuf { - self.session_dir(session_id).join("session.json") - } - - fn log_file(&self, session_id: &str) -> PathBuf { - self.session_dir(session_id).join("log.jsonl") - } - - async fn atomic_write(path: &Path, data: &[u8]) -> io::Result<()> { - let tmp_path = path.with_extension("json.tmp"); - tfs::write(&tmp_path, data).await?; - tfs::rename(&tmp_path, path).await - } -} - -#[async_trait::async_trait] -impl StorageBackend for FileBackend { - async fn create_session_storage(&self, session_id: &str) -> io::Result<()> { - tfs::create_dir_all(self.session_dir(session_id)).await - } - - async fn save_session(&self, session: &Session) -> io::Result<()> { - let persisted = PersistedSession::from(session); - let bytes = serde_json::to_vec_pretty(&persisted) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Self::atomic_write(&self.session_file(&session.session_id), &bytes).await - } - - async fn load_session(&self, session_id: &str) -> io::Result> { - let path = self.session_file(session_id); - if 
tfs::metadata(&path).await.is_err() { - return Ok(None); - } - let bytes = tfs::read(&path).await?; - let persisted: PersistedSession = serde_json::from_slice(&bytes) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok(Some(Session::from(persisted))) - } - - async fn load_all_sessions(&self) -> io::Result> { - let sessions_dir = self.base_dir.join("sessions"); - if tfs::metadata(&sessions_dir).await.is_err() { - return Ok(vec![]); - } - let mut sessions = Vec::new(); - let mut entries = tfs::read_dir(&sessions_dir).await?; - while let Some(entry) = entries.next_entry().await? { - if !entry.file_type().await?.is_dir() { - continue; - } - let session_file = entry.path().join("session.json"); - if tfs::metadata(&session_file).await.is_err() { - continue; - } - let bytes = tfs::read(&session_file).await?; - match serde_json::from_slice::(&bytes) { - Ok(persisted) => sessions.push(Session::from(persisted)), - Err(e) => { - eprintln!( - "warning: failed to deserialize session from {}: {e}; skipping", - session_file.display() - ); - } - } - } - Ok(sessions) - } - - async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()> { - let path = self.log_file(session_id); - let mut line = serde_json::to_string(entry) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - line.push('\n'); - - let mut file = tfs::OpenOptions::new() - .create(true) - .append(true) - .open(&path) - .await?; - file.write_all(line.as_bytes()).await?; - file.sync_data().await?; - Ok(()) - } - - async fn load_log(&self, session_id: &str) -> io::Result> { - let path = self.log_file(session_id); - if tfs::metadata(&path).await.is_err() { - return Ok(vec![]); - } - let content = tfs::read_to_string(&path).await?; - let mut entries = Vec::new(); - for (line_num, line) in content.lines().enumerate() { - if line.trim().is_empty() { - continue; - } - match serde_json::from_str::(line) { - Ok(entry) => entries.push(entry), - Err(e) => { - eprintln!( - 
"warning: failed to parse log entry at {}:{}: {e}; skipping", - path.display(), - line_num + 1 - ); - } - } - } - Ok(entries) - } -} - -// --------------------------------------------------------------------------- -// storage_version.json -// --------------------------------------------------------------------------- - -#[derive(serde::Serialize, serde::Deserialize)] -struct StorageVersion { - version: u32, -} - -pub fn write_storage_version(base_dir: &Path) -> io::Result<()> { - let sv = StorageVersion { - version: STORAGE_VERSION, - }; - let bytes = serde_json::to_vec_pretty(&sv) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - fs::write(base_dir.join("storage_version.json"), bytes) -} - -pub fn read_storage_version(base_dir: &Path) -> io::Result> { - let path = base_dir.join("storage_version.json"); - if !path.exists() { - return Ok(None); - } - let bytes = fs::read(&path)?; - let sv: StorageVersion = serde_json::from_slice(&bytes) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok(Some(sv.version)) -} - -// --------------------------------------------------------------------------- -// Migration from legacy monolithic format (v1 → v2) -// --------------------------------------------------------------------------- - -pub fn migrate_if_needed(base_dir: &Path) -> io::Result<()> { - let sessions_dir = base_dir.join("sessions"); - let legacy_sessions = base_dir.join("sessions.json"); - let legacy_logs = base_dir.join("logs.json"); - - // Already at current version or fresh install (no legacy files, sessions dir exists) - let current_version = read_storage_version(base_dir)?; - if sessions_dir.exists() && !legacy_sessions.exists() && !legacy_logs.exists() { - // v2 → v3: no-op data migration, just bump version. New LogEntry fields - // use #[serde(default)] so existing v2 JSONL lines deserialize fine. 
- if current_version.unwrap_or(0) < STORAGE_VERSION { - write_storage_version(base_dir)?; - } - return Ok(()); - } - - if !legacy_sessions.exists() && !legacy_logs.exists() && !sessions_dir.exists() { - write_storage_version(base_dir)?; - return Ok(()); - } - - // Already migrated from v1 - if sessions_dir.exists() && !legacy_sessions.exists() && !legacy_logs.exists() { - write_storage_version(base_dir)?; - return Ok(()); - } - - println!("Migrating legacy storage format to per-session directories..."); - - // Load legacy sessions - let sessions: HashMap = if legacy_sessions.exists() { - let bytes = fs::read(&legacy_sessions)?; - serde_json::from_slice(&bytes).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("failed to parse legacy sessions.json: {e}"), - ) - })? - } else { - HashMap::new() - }; - - // Load legacy logs - let logs: HashMap> = if legacy_logs.exists() { - let bytes = fs::read(&legacy_logs)?; - serde_json::from_slice(&bytes).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("failed to parse legacy logs.json: {e}"), - ) - })? 
- } else { - HashMap::new() - }; - - fs::create_dir_all(&sessions_dir)?; - - // Migrate each session - for (session_id, persisted) in &sessions { - let dir = sessions_dir.join(session_id); - fs::create_dir_all(&dir)?; - - // Write session.json - let session_bytes = serde_json::to_vec_pretty(persisted) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - fs::write(dir.join("session.json"), session_bytes)?; - - // Write log.jsonl - if let Some(entries) = logs.get(session_id) { - let mut log_data = String::new(); - for entry in entries { - let line = serde_json::to_string(entry) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - log_data.push_str(&line); - log_data.push('\n'); - } - fs::write(dir.join("log.jsonl"), log_data)?; - } - } - - // Also migrate logs for sessions that only appear in logs (not in sessions.json) - for (session_id, entries) in &logs { - if sessions.contains_key(session_id) { - continue; - } - let dir = sessions_dir.join(session_id); - fs::create_dir_all(&dir)?; - let mut log_data = String::new(); - for entry in entries { - let line = serde_json::to_string(entry) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - log_data.push_str(&line); - log_data.push('\n'); - } - fs::write(dir.join("log.jsonl"), log_data)?; - } - - // Backup old files instead of deleting - if legacy_sessions.exists() { - fs::rename(&legacy_sessions, base_dir.join("sessions.json.migrated"))?; - } - if legacy_logs.exists() { - fs::rename(&legacy_logs, base_dir.join("logs.json.migrated"))?; - } - - write_storage_version(base_dir)?; - println!( - "Migration complete: {} sessions, {} log sets migrated.", - sessions.len(), - logs.len() - ); - Ok(()) -} - -// --------------------------------------------------------------------------- -// Crash recovery -// --------------------------------------------------------------------------- - -pub fn recover_session(session: &mut Session, log_entries: &[LogEntry]) { - // Ensure all log entry message 
IDs are in the session's dedup set. - // If the runtime crashed after writing a log entry but before persisting - // the session snapshot, there will be entries in the log not reflected - // in seen_message_ids. - let mut recovered = 0usize; - for entry in log_entries { - if !entry.message_id.is_empty() && session.seen_message_ids.insert(entry.message_id.clone()) - { - recovered += 1; - } - } - if recovered > 0 { - eprintln!( - "recovery: session '{}' reconciled {} log entries into dedup state", - session.session_id, recovered - ); - } -} - -pub fn cleanup_temp_files(base_dir: &Path) { - let sessions_dir = base_dir.join("sessions"); - if !sessions_dir.exists() { - return; - } - if let Ok(entries) = fs::read_dir(&sessions_dir) { - for entry in entries.flatten() { - if !entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) { - continue; - } - let dir = entry.path(); - if let Ok(files) = fs::read_dir(&dir) { - for file in files.flatten() { - let path = file.path(); - if path.extension().and_then(|e| e.to_str()) == Some("tmp") { - eprintln!("recovery: removing orphaned temp file {}", path.display()); - let _ = fs::remove_file(&path); - } - } - } - } - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use crate::log_store::EntryKind; - use crate::session::SessionState; - use std::collections::HashSet; - - fn sample_session(id: &str) -> Session { - Session { - session_id: id.into(), - state: SessionState::Open, - ttl_expiry: 61_000, - ttl_ms: 60_000, - started_at_unix_ms: 1_000, - resolution: None, - mode: "macp.mode.decision.v1".into(), - mode_state: vec![1, 2, 3], - participants: vec!["alice".into(), "bob".into()], - seen_message_ids: HashSet::from(["m1".into()]), - intent: "test intent".into(), - mode_version: "1.0.0".into(), - configuration_version: "cfg-1".into(), - policy_version: 
"pol-1".into(), - context: vec![9], - roots: vec![crate::pb::Root { - uri: "root://1".into(), - name: "r1".into(), - }], - initiator_sender: "alice".into(), - } - } - - fn sample_entry(id: &str) -> LogEntry { - LogEntry { - message_id: id.into(), - received_at_ms: 1_700_000_000_000, - sender: "alice".into(), - message_type: "Message".into(), - raw_payload: vec![], - entry_kind: EntryKind::Incoming, - session_id: String::new(), - mode: String::new(), - macp_version: String::new(), - } - } - - #[tokio::test] - async fn file_backend_session_round_trip() { - let dir = tempfile::tempdir().unwrap(); - let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); - - let session = sample_session("s1"); - backend.create_session_storage("s1").await.unwrap(); - backend.save_session(&session).await.unwrap(); - - let loaded = backend.load_session("s1").await.unwrap().unwrap(); - assert_eq!(loaded.session_id, "s1"); - assert_eq!(loaded.ttl_ms, 60_000); - assert_eq!(loaded.mode_version, "1.0.0"); - assert!(loaded.seen_message_ids.contains("m1")); - assert_eq!(loaded.participants, vec!["alice", "bob"]); - } - - #[tokio::test] - async fn file_backend_log_append_and_load() { - let dir = tempfile::tempdir().unwrap(); - let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); - - backend.create_session_storage("s1").await.unwrap(); - backend - .append_log_entry("s1", &sample_entry("m1")) - .await - .unwrap(); - backend - .append_log_entry("s1", &sample_entry("m2")) - .await - .unwrap(); - backend - .append_log_entry("s1", &sample_entry("m3")) - .await - .unwrap(); - - let log = backend.load_log("s1").await.unwrap(); - assert_eq!(log.len(), 3); - assert_eq!(log[0].message_id, "m1"); - assert_eq!(log[1].message_id, "m2"); - assert_eq!(log[2].message_id, "m3"); - } - - #[tokio::test] - async fn file_backend_load_all_sessions() { - let dir = tempfile::tempdir().unwrap(); - let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); - - for id in ["s1", "s2", "s3"] { 
- backend.create_session_storage(id).await.unwrap(); - backend.save_session(&sample_session(id)).await.unwrap(); - } - - let mut sessions = backend.load_all_sessions().await.unwrap(); - sessions.sort_by(|a, b| a.session_id.cmp(&b.session_id)); - assert_eq!(sessions.len(), 3); - assert_eq!(sessions[0].session_id, "s1"); - assert_eq!(sessions[1].session_id, "s2"); - assert_eq!(sessions[2].session_id, "s3"); - } - - #[tokio::test] - async fn memory_backend_is_noop() { - let backend = MemoryBackend; - backend.create_session_storage("s1").await.unwrap(); - backend.save_session(&sample_session("s1")).await.unwrap(); - assert!(backend.load_session("s1").await.unwrap().is_none()); - assert!(backend.load_all_sessions().await.unwrap().is_empty()); - } - - #[tokio::test] - async fn append_only_no_full_rewrite() { - let dir = tempfile::tempdir().unwrap(); - let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); - backend.create_session_storage("s1").await.unwrap(); - - // Append 100 entries - for i in 0..100 { - backend - .append_log_entry("s1", &sample_entry(&format!("m{}", i))) - .await - .unwrap(); - } - - // Verify file has 100 lines - let content = fs::read_to_string(backend.log_file("s1")).unwrap(); - let line_count = content.lines().count(); - assert_eq!(line_count, 100); - - // Verify all entries are loadable - let log = backend.load_log("s1").await.unwrap(); - assert_eq!(log.len(), 100); - } - - #[test] - fn crash_recovery_reconciles_dedup_state() { - let mut session = sample_session("s1"); - // session has "m1" in seen_message_ids - assert!(session.seen_message_ids.contains("m1")); - assert!(!session.seen_message_ids.contains("m2")); - assert!(!session.seen_message_ids.contains("m3")); - - let entries = vec![ - sample_entry("m1"), // already in dedup set - sample_entry("m2"), // missing from dedup set - sample_entry("m3"), // missing from dedup set - ]; - - recover_session(&mut session, &entries); - - assert!(session.seen_message_ids.contains("m1")); - 
assert!(session.seen_message_ids.contains("m2")); - assert!(session.seen_message_ids.contains("m3")); - } - - #[tokio::test] - async fn migration_from_legacy_format() { - let dir = tempfile::tempdir().unwrap(); - let base = dir.path(); - - // Create legacy format files - let session = sample_session("s1"); - let persisted = PersistedSession::from(&session); - let sessions_map: HashMap = - [("s1".into(), persisted)].into_iter().collect(); - fs::write( - base.join("sessions.json"), - serde_json::to_vec_pretty(&sessions_map).unwrap(), - ) - .unwrap(); - - let entries = vec![sample_entry("m1"), sample_entry("m2")]; - let logs_map: HashMap> = - [("s1".into(), entries)].into_iter().collect(); - fs::write( - base.join("logs.json"), - serde_json::to_vec_pretty(&logs_map).unwrap(), - ) - .unwrap(); - - // Run migration - migrate_if_needed(base).unwrap(); - - // Verify per-session directories created - assert!(base.join("sessions/s1/session.json").exists()); - assert!(base.join("sessions/s1/log.jsonl").exists()); - - // Verify old files renamed - assert!(base.join("sessions.json.migrated").exists()); - assert!(base.join("logs.json.migrated").exists()); - assert!(!base.join("sessions.json").exists()); - assert!(!base.join("logs.json").exists()); - - // Verify storage version - assert_eq!(read_storage_version(base).unwrap(), Some(STORAGE_VERSION)); - - // Verify data is loadable via FileBackend - let backend = FileBackend::new(base.to_path_buf()).unwrap(); - let loaded = backend.load_session("s1").await.unwrap().unwrap(); - assert_eq!(loaded.session_id, "s1"); - assert_eq!(loaded.ttl_ms, 60_000); - - let log = backend.load_log("s1").await.unwrap(); - assert_eq!(log.len(), 2); - } - - #[tokio::test] - async fn ttl_ms_backward_compat_deserialization() { - let dir = tempfile::tempdir().unwrap(); - let base = dir.path(); - let backend = FileBackend::new(base.to_path_buf()).unwrap(); - backend.create_session_storage("s1").await.unwrap(); - - // Write a session JSON without ttl_ms 
(simulating old format) - let json = serde_json::json!({ - "session_id": "s1", - "state": "Open", - "ttl_expiry": 61000, - "started_at_unix_ms": 1000, - "resolution": null, - "mode": "macp.mode.decision.v1", - "mode_state": [], - "participants": ["alice"], - "seen_message_ids": [], - "intent": "", - "mode_version": "1.0.0", - "configuration_version": "cfg", - "policy_version": "pol", - "context": [], - "roots": [], - "initiator_sender": "alice" - }); - fs::write( - backend.session_file("s1"), - serde_json::to_vec_pretty(&json).unwrap(), - ) - .unwrap(); - - let loaded = backend.load_session("s1").await.unwrap().unwrap(); - // ttl_ms should be computed from ttl_expiry - started_at_unix_ms - assert_eq!(loaded.ttl_ms, 60_000); - } - - #[test] - fn cleanup_temp_files_removes_orphans() { - let dir = tempfile::tempdir().unwrap(); - let base = dir.path(); - let sessions_dir = base.join("sessions").join("s1"); - fs::create_dir_all(&sessions_dir).unwrap(); - - // Create an orphaned temp file - fs::write(sessions_dir.join("session.json.tmp"), b"partial").unwrap(); - assert!(sessions_dir.join("session.json.tmp").exists()); - - cleanup_temp_files(base); - - assert!(!sessions_dir.join("session.json.tmp").exists()); - } - - #[tokio::test] - async fn write_ordering_log_before_session() { - let dir = tempfile::tempdir().unwrap(); - let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); - backend.create_session_storage("s1").await.unwrap(); - - // Write log entry (commit point) - backend - .append_log_entry("s1", &sample_entry("m1")) - .await - .unwrap(); - - // Verify log is durable even without session save - let log = backend.load_log("s1").await.unwrap(); - assert_eq!(log.len(), 1); - assert_eq!(log[0].message_id, "m1"); - - // Session load should return None (not yet saved) - assert!(backend.load_session("s1").await.unwrap().is_none()); - - // Now save session - backend.save_session(&sample_session("s1")).await.unwrap(); - 
assert!(backend.load_session("s1").await.unwrap().is_some()); - } -} diff --git a/src/storage/compaction.rs b/src/storage/compaction.rs new file mode 100644 index 0000000..0baab70 --- /dev/null +++ b/src/storage/compaction.rs @@ -0,0 +1,35 @@ +use crate::log_store::{EntryKind, LogEntry}; +use crate::registry::PersistedSession; +use crate::session::Session; +use std::io; + +use super::StorageBackend; + +/// Compact a session's log into a single checkpoint entry. +/// +/// This replaces all existing log entries with a single `Checkpoint` entry +/// containing the serialized session state. Should only be called on sessions +/// in terminal state (Resolved/Expired/Cancelled). +pub async fn compact_session_log( + storage: &dyn StorageBackend, + session_id: &str, + session: &Session, +) -> io::Result<()> { + let persisted = PersistedSession::from(session); + let raw_payload = serde_json::to_vec(&persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let checkpoint = LogEntry { + message_id: String::new(), + received_at_ms: chrono::Utc::now().timestamp_millis(), + sender: "_runtime".into(), + message_type: "Checkpoint".into(), + raw_payload, + entry_kind: EntryKind::Checkpoint, + session_id: session_id.into(), + mode: session.mode.clone(), + macp_version: String::new(), + }; + + storage.replace_log(session_id, &[checkpoint]).await +} diff --git a/src/storage/file.rs b/src/storage/file.rs new file mode 100644 index 0000000..ee13dcb --- /dev/null +++ b/src/storage/file.rs @@ -0,0 +1,401 @@ +use crate::log_store::LogEntry; +use crate::registry::PersistedSession; +use crate::session::Session; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; +use tokio::fs as tfs; +use tokio::io::AsyncWriteExt; + +use super::StorageBackend; + +pub struct FileBackend { + base_dir: PathBuf, +} + +impl FileBackend { + pub fn new(base_dir: PathBuf) -> io::Result { + fs::create_dir_all(base_dir.join("sessions"))?; + Ok(Self { base_dir }) + } + + fn 
session_dir(&self, session_id: &str) -> PathBuf { + self.base_dir.join("sessions").join(session_id) + } + + pub(crate) fn session_file(&self, session_id: &str) -> PathBuf { + self.session_dir(session_id).join("session.json") + } + + pub(crate) fn log_file(&self, session_id: &str) -> PathBuf { + self.session_dir(session_id).join("log.jsonl") + } + + async fn atomic_write(path: &Path, data: &[u8]) -> io::Result<()> { + let tmp_path = path.with_extension("json.tmp"); + tfs::write(&tmp_path, data).await?; + tfs::rename(&tmp_path, path).await + } +} + +#[async_trait::async_trait] +impl StorageBackend for FileBackend { + async fn create_session_storage(&self, session_id: &str) -> io::Result<()> { + tfs::create_dir_all(self.session_dir(session_id)).await + } + + async fn save_session(&self, session: &Session) -> io::Result<()> { + let persisted = PersistedSession::from(session); + let bytes = serde_json::to_vec_pretty(&persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Self::atomic_write(&self.session_file(&session.session_id), &bytes).await + } + + async fn load_session(&self, session_id: &str) -> io::Result> { + let path = self.session_file(session_id); + if tfs::metadata(&path).await.is_err() { + return Ok(None); + } + let bytes = tfs::read(&path).await?; + let persisted: PersistedSession = serde_json::from_slice(&bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(Session::from(persisted))) + } + + async fn load_all_sessions(&self) -> io::Result> { + let ids = self.list_session_ids().await?; + let mut sessions = Vec::new(); + for id in ids { + match self.load_session(&id).await { + Ok(Some(s)) => sessions.push(s), + Ok(None) => {} + Err(e) => { + eprintln!("warning: failed to load session {id}: {e}; skipping"); + } + } + } + Ok(sessions) + } + + async fn delete_session(&self, session_id: &str) -> io::Result<()> { + let dir = self.session_dir(session_id); + if tfs::metadata(&dir).await.is_ok() { + 
tfs::remove_dir_all(&dir).await?; + } + Ok(()) + } + + async fn list_session_ids(&self) -> io::Result> { + let sessions_dir = self.base_dir.join("sessions"); + if tfs::metadata(&sessions_dir).await.is_err() { + return Ok(vec![]); + } + let mut ids = Vec::new(); + let mut entries = tfs::read_dir(&sessions_dir).await?; + while let Some(entry) = entries.next_entry().await? { + if !entry.file_type().await?.is_dir() { + continue; + } + ids.push(entry.file_name().to_string_lossy().to_string()); + } + Ok(ids) + } + + async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()> { + let path = self.log_file(session_id); + let mut line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + line.push('\n'); + + let mut file = tfs::OpenOptions::new() + .create(true) + .append(true) + .open(&path) + .await?; + file.write_all(line.as_bytes()).await?; + file.sync_data().await?; + Ok(()) + } + + async fn load_log(&self, session_id: &str) -> io::Result> { + let path = self.log_file(session_id); + if tfs::metadata(&path).await.is_err() { + return Ok(vec![]); + } + let content = tfs::read_to_string(&path).await?; + let mut entries = Vec::new(); + for (line_num, line) in content.lines().enumerate() { + if line.trim().is_empty() { + continue; + } + match serde_json::from_str::(line) { + Ok(entry) => entries.push(entry), + Err(e) => { + eprintln!( + "warning: failed to parse log entry at {}:{}: {e}; skipping", + path.display(), + line_num + 1 + ); + } + } + } + Ok(entries) + } + + async fn replace_log(&self, session_id: &str, entries: &[LogEntry]) -> io::Result<()> { + let path = self.log_file(session_id); + let mut data = String::new(); + for entry in entries { + let line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + data.push_str(&line); + data.push('\n'); + } + let tmp_path = path.with_extension("jsonl.tmp"); + tfs::write(&tmp_path, data.as_bytes()).await?; + 
tfs::rename(&tmp_path, &path).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::log_store::EntryKind; + use crate::session::SessionState; + use std::collections::HashSet; + + fn sample_session(id: &str) -> Session { + Session { + session_id: id.into(), + state: SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + mode_state: vec![1, 2, 3], + participants: vec!["alice".into(), "bob".into()], + seen_message_ids: HashSet::from(["m1".into()]), + intent: "test intent".into(), + mode_version: "1.0.0".into(), + configuration_version: "cfg-1".into(), + policy_version: "pol-1".into(), + context: vec![9], + roots: vec![crate::pb::Root { + uri: "root://1".into(), + name: "r1".into(), + }], + initiator_sender: "alice".into(), + } + } + + fn sample_entry(id: &str) -> LogEntry { + LogEntry { + message_id: id.into(), + received_at_ms: 1_700_000_000_000, + sender: "alice".into(), + message_type: "Message".into(), + raw_payload: vec![], + entry_kind: EntryKind::Incoming, + session_id: String::new(), + mode: String::new(), + macp_version: String::new(), + } + } + + #[tokio::test] + async fn file_backend_session_round_trip() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + let session = sample_session("s1"); + backend.create_session_storage("s1").await.unwrap(); + backend.save_session(&session).await.unwrap(); + + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.session_id, "s1"); + assert_eq!(loaded.ttl_ms, 60_000); + assert_eq!(loaded.mode_version, "1.0.0"); + assert!(loaded.seen_message_ids.contains("m1")); + assert_eq!(loaded.participants, vec!["alice", "bob"]); + } + + #[tokio::test] + async fn file_backend_log_append_and_load() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + 
backend.create_session_storage("s1").await.unwrap(); + backend + .append_log_entry("s1", &sample_entry("m1")) + .await + .unwrap(); + backend + .append_log_entry("s1", &sample_entry("m2")) + .await + .unwrap(); + backend + .append_log_entry("s1", &sample_entry("m3")) + .await + .unwrap(); + + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 3); + assert_eq!(log[0].message_id, "m1"); + assert_eq!(log[1].message_id, "m2"); + assert_eq!(log[2].message_id, "m3"); + } + + #[tokio::test] + async fn file_backend_load_all_sessions() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + for id in ["s1", "s2", "s3"] { + backend.create_session_storage(id).await.unwrap(); + backend.save_session(&sample_session(id)).await.unwrap(); + } + + let mut sessions = backend.load_all_sessions().await.unwrap(); + sessions.sort_by(|a, b| a.session_id.cmp(&b.session_id)); + assert_eq!(sessions.len(), 3); + assert_eq!(sessions[0].session_id, "s1"); + assert_eq!(sessions[1].session_id, "s2"); + assert_eq!(sessions[2].session_id, "s3"); + } + + #[tokio::test] + async fn append_only_no_full_rewrite() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + for i in 0..100 { + backend + .append_log_entry("s1", &sample_entry(&format!("m{}", i))) + .await + .unwrap(); + } + + let content = fs::read_to_string(backend.log_file("s1")).unwrap(); + let line_count = content.lines().count(); + assert_eq!(line_count, 100); + + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 100); + } + + #[tokio::test] + async fn write_ordering_log_before_session() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + backend + .append_log_entry("s1", &sample_entry("m1")) + .await + 
.unwrap(); + + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 1); + assert_eq!(log[0].message_id, "m1"); + + assert!(backend.load_session("s1").await.unwrap().is_none()); + + backend.save_session(&sample_session("s1")).await.unwrap(); + assert!(backend.load_session("s1").await.unwrap().is_some()); + } + + #[tokio::test] + async fn delete_session_removes_directory() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + backend.create_session_storage("s1").await.unwrap(); + backend.save_session(&sample_session("s1")).await.unwrap(); + backend + .append_log_entry("s1", &sample_entry("m1")) + .await + .unwrap(); + + assert!(backend.load_session("s1").await.unwrap().is_some()); + + backend.delete_session("s1").await.unwrap(); + assert!(backend.load_session("s1").await.unwrap().is_none()); + assert!(backend.load_log("s1").await.unwrap().is_empty()); + + // Idempotent + backend.delete_session("s1").await.unwrap(); + } + + #[tokio::test] + async fn list_session_ids_returns_directories() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + + for id in ["s1", "s2", "s3"] { + backend.create_session_storage(id).await.unwrap(); + } + + let mut ids = backend.list_session_ids().await.unwrap(); + ids.sort(); + assert_eq!(ids, vec!["s1", "s2", "s3"]); + } + + #[tokio::test] + async fn replace_log_atomically_overwrites() { + let dir = tempfile::tempdir().unwrap(); + let backend = FileBackend::new(dir.path().to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + for i in 0..10 { + backend + .append_log_entry("s1", &sample_entry(&format!("m{i}"))) + .await + .unwrap(); + } + assert_eq!(backend.load_log("s1").await.unwrap().len(), 10); + + let replacement = vec![sample_entry("checkpoint")]; + backend.replace_log("s1", &replacement).await.unwrap(); + + let log = backend.load_log("s1").await.unwrap(); + 
assert_eq!(log.len(), 1); + assert_eq!(log[0].message_id, "checkpoint"); + } + + #[tokio::test] + async fn ttl_ms_backward_compat_deserialization() { + let dir = tempfile::tempdir().unwrap(); + let base = dir.path(); + let backend = FileBackend::new(base.to_path_buf()).unwrap(); + backend.create_session_storage("s1").await.unwrap(); + + let json = serde_json::json!({ + "session_id": "s1", + "state": "Open", + "ttl_expiry": 61000, + "started_at_unix_ms": 1000, + "resolution": null, + "mode": "macp.mode.decision.v1", + "mode_state": [], + "participants": ["alice"], + "seen_message_ids": [], + "intent": "", + "mode_version": "1.0.0", + "configuration_version": "cfg", + "policy_version": "pol", + "context": [], + "roots": [], + "initiator_sender": "alice" + }); + fs::write( + backend.session_file("s1"), + serde_json::to_vec_pretty(&json).unwrap(), + ) + .unwrap(); + + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.ttl_ms, 60_000); + } +} diff --git a/src/storage/memory.rs b/src/storage/memory.rs new file mode 100644 index 0000000..b755b8d --- /dev/null +++ b/src/storage/memory.rs @@ -0,0 +1,83 @@ +use crate::log_store::LogEntry; +use crate::session::Session; +use std::io; + +use super::StorageBackend; + +pub struct MemoryBackend; + +#[async_trait::async_trait] +impl StorageBackend for MemoryBackend { + async fn save_session(&self, _session: &Session) -> io::Result<()> { + Ok(()) + } + + async fn load_session(&self, _session_id: &str) -> io::Result> { + Ok(None) + } + + async fn load_all_sessions(&self) -> io::Result> { + Ok(vec![]) + } + + async fn delete_session(&self, _session_id: &str) -> io::Result<()> { + Ok(()) + } + + async fn list_session_ids(&self) -> io::Result> { + Ok(vec![]) + } + + async fn append_log_entry(&self, _session_id: &str, _entry: &LogEntry) -> io::Result<()> { + Ok(()) + } + + async fn load_log(&self, _session_id: &str) -> io::Result> { + Ok(vec![]) + } + + async fn create_session_storage(&self, _session_id: 
&str) -> io::Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_session() -> Session { + use crate::session::SessionState; + use std::collections::HashSet; + + Session { + session_id: "s1".into(), + state: SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + mode_state: vec![], + participants: vec!["alice".into()], + seen_message_ids: HashSet::new(), + intent: "".into(), + mode_version: "1.0.0".into(), + configuration_version: "cfg-1".into(), + policy_version: "pol-1".into(), + context: vec![], + roots: vec![], + initiator_sender: "alice".into(), + } + } + + #[tokio::test] + async fn memory_backend_is_noop() { + let backend = MemoryBackend; + backend.create_session_storage("s1").await.unwrap(); + backend.save_session(&sample_session()).await.unwrap(); + assert!(backend.load_session("s1").await.unwrap().is_none()); + assert!(backend.load_all_sessions().await.unwrap().is_empty()); + assert!(backend.list_session_ids().await.unwrap().is_empty()); + backend.delete_session("s1").await.unwrap(); + } +} diff --git a/src/storage/migration.rs b/src/storage/migration.rs new file mode 100644 index 0000000..52d0563 --- /dev/null +++ b/src/storage/migration.rs @@ -0,0 +1,240 @@ +use crate::log_store::LogEntry; +use crate::registry::PersistedSession; +use std::collections::HashMap; +use std::fs; +use std::io; +use std::path::Path; + +const STORAGE_VERSION: u32 = 3; + +#[derive(serde::Serialize, serde::Deserialize)] +struct StorageVersion { + version: u32, +} + +pub fn write_storage_version(base_dir: &Path) -> io::Result<()> { + let sv = StorageVersion { + version: STORAGE_VERSION, + }; + let bytes = serde_json::to_vec_pretty(&sv) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + fs::write(base_dir.join("storage_version.json"), bytes) +} + +pub fn read_storage_version(base_dir: &Path) -> io::Result> { + let path = 
base_dir.join("storage_version.json"); + if !path.exists() { + return Ok(None); + } + let bytes = fs::read(&path)?; + let sv: StorageVersion = serde_json::from_slice(&bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(sv.version)) +} + +pub fn migrate_if_needed(base_dir: &Path) -> io::Result<()> { + let sessions_dir = base_dir.join("sessions"); + let legacy_sessions = base_dir.join("sessions.json"); + let legacy_logs = base_dir.join("logs.json"); + + // Already at current version or fresh install (no legacy files, sessions dir exists) + let current_version = read_storage_version(base_dir)?; + if sessions_dir.exists() && !legacy_sessions.exists() && !legacy_logs.exists() { + // v2 → v3: no-op data migration, just bump version. New LogEntry fields + // use #[serde(default)] so existing v2 JSONL lines deserialize fine. + if current_version.unwrap_or(0) < STORAGE_VERSION { + write_storage_version(base_dir)?; + } + return Ok(()); + } + + if !legacy_sessions.exists() && !legacy_logs.exists() && !sessions_dir.exists() { + write_storage_version(base_dir)?; + return Ok(()); + } + + // Already migrated from v1 + if sessions_dir.exists() && !legacy_sessions.exists() && !legacy_logs.exists() { + write_storage_version(base_dir)?; + return Ok(()); + } + + println!("Migrating legacy storage format to per-session directories..."); + + // Load legacy sessions + let sessions: HashMap = if legacy_sessions.exists() { + let bytes = fs::read(&legacy_sessions)?; + serde_json::from_slice(&bytes).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to parse legacy sessions.json: {e}"), + ) + })? + } else { + HashMap::new() + }; + + // Load legacy logs + let logs: HashMap> = if legacy_logs.exists() { + let bytes = fs::read(&legacy_logs)?; + serde_json::from_slice(&bytes).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to parse legacy logs.json: {e}"), + ) + })? 
+ } else { + HashMap::new() + }; + + fs::create_dir_all(&sessions_dir)?; + + // Migrate each session + for (session_id, persisted) in &sessions { + let dir = sessions_dir.join(session_id); + fs::create_dir_all(&dir)?; + + // Write session.json + let session_bytes = serde_json::to_vec_pretty(persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + fs::write(dir.join("session.json"), session_bytes)?; + + // Write log.jsonl + if let Some(entries) = logs.get(session_id) { + let mut log_data = String::new(); + for entry in entries { + let line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + log_data.push_str(&line); + log_data.push('\n'); + } + fs::write(dir.join("log.jsonl"), log_data)?; + } + } + + // Also migrate logs for sessions that only appear in logs (not in sessions.json) + for (session_id, entries) in &logs { + if sessions.contains_key(session_id) { + continue; + } + let dir = sessions_dir.join(session_id); + fs::create_dir_all(&dir)?; + let mut log_data = String::new(); + for entry in entries { + let line = serde_json::to_string(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + log_data.push_str(&line); + log_data.push('\n'); + } + fs::write(dir.join("log.jsonl"), log_data)?; + } + + // Backup old files instead of deleting + if legacy_sessions.exists() { + fs::rename(&legacy_sessions, base_dir.join("sessions.json.migrated"))?; + } + if legacy_logs.exists() { + fs::rename(&legacy_logs, base_dir.join("logs.json.migrated"))?; + } + + write_storage_version(base_dir)?; + println!( + "Migration complete: {} sessions, {} log sets migrated.", + sessions.len(), + logs.len() + ); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::log_store::EntryKind; + use crate::session::{Session, SessionState}; + use crate::storage::StorageBackend; + use std::collections::HashSet; + + fn sample_session(id: &str) -> Session { + Session { + session_id: id.into(), + state: 
SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + mode_state: vec![1, 2, 3], + participants: vec!["alice".into(), "bob".into()], + seen_message_ids: HashSet::from(["m1".into()]), + intent: "test intent".into(), + mode_version: "1.0.0".into(), + configuration_version: "cfg-1".into(), + policy_version: "pol-1".into(), + context: vec![9], + roots: vec![crate::pb::Root { + uri: "root://1".into(), + name: "r1".into(), + }], + initiator_sender: "alice".into(), + } + } + + fn sample_entry(id: &str) -> LogEntry { + LogEntry { + message_id: id.into(), + received_at_ms: 1_700_000_000_000, + sender: "alice".into(), + message_type: "Message".into(), + raw_payload: vec![], + entry_kind: EntryKind::Incoming, + session_id: String::new(), + mode: String::new(), + macp_version: String::new(), + } + } + + #[tokio::test] + async fn migration_from_legacy_format() { + let dir = tempfile::tempdir().unwrap(); + let base = dir.path(); + + let session = sample_session("s1"); + let persisted = PersistedSession::from(&session); + let sessions_map: HashMap = + [("s1".into(), persisted)].into_iter().collect(); + fs::write( + base.join("sessions.json"), + serde_json::to_vec_pretty(&sessions_map).unwrap(), + ) + .unwrap(); + + let entries = vec![sample_entry("m1"), sample_entry("m2")]; + let logs_map: HashMap> = + [("s1".into(), entries)].into_iter().collect(); + fs::write( + base.join("logs.json"), + serde_json::to_vec_pretty(&logs_map).unwrap(), + ) + .unwrap(); + + migrate_if_needed(base).unwrap(); + + assert!(base.join("sessions/s1/session.json").exists()); + assert!(base.join("sessions/s1/log.jsonl").exists()); + + assert!(base.join("sessions.json.migrated").exists()); + assert!(base.join("logs.json.migrated").exists()); + assert!(!base.join("sessions.json").exists()); + assert!(!base.join("logs.json").exists()); + + assert_eq!(read_storage_version(base).unwrap(), Some(STORAGE_VERSION)); + + 
let backend = crate::storage::FileBackend::new(base.to_path_buf()).unwrap(); + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.session_id, "s1"); + assert_eq!(loaded.ttl_ms, 60_000); + + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 2); + } +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..081e541 --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,48 @@ +mod file; +mod memory; +mod migration; +mod recovery; + +#[cfg(feature = "rocksdb-backend")] +pub mod rocksdb; +#[cfg(feature = "rocksdb-backend")] +pub use self::rocksdb::RocksDbBackend; + +#[cfg(feature = "redis-backend")] +pub mod redis_backend; +#[cfg(feature = "redis-backend")] +pub use redis_backend::RedisBackend; + +pub mod compaction; + +pub use file::FileBackend; +pub use memory::MemoryBackend; +pub use migration::migrate_if_needed; +pub use recovery::{cleanup_temp_files, recover_session}; + +use crate::log_store::LogEntry; +use crate::session::Session; +use std::io; + +// --------------------------------------------------------------------------- +// StorageBackend trait +// --------------------------------------------------------------------------- + +#[async_trait::async_trait] +pub trait StorageBackend: Send + Sync { + async fn save_session(&self, session: &Session) -> io::Result<()>; + async fn load_session(&self, session_id: &str) -> io::Result>; + async fn load_all_sessions(&self) -> io::Result>; + async fn delete_session(&self, session_id: &str) -> io::Result<()>; + async fn list_session_ids(&self) -> io::Result>; + async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()>; + async fn load_log(&self, session_id: &str) -> io::Result>; + async fn create_session_storage(&self, session_id: &str) -> io::Result<()>; + + async fn replace_log(&self, _session_id: &str, _entries: &[LogEntry]) -> io::Result<()> { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "compaction 
not supported by this backend", + )) + } +} diff --git a/src/storage/recovery.rs b/src/storage/recovery.rs new file mode 100644 index 0000000..38953a9 --- /dev/null +++ b/src/storage/recovery.rs @@ -0,0 +1,123 @@ +use crate::log_store::LogEntry; +use crate::session::Session; +use std::fs; +use std::path::Path; + +pub fn recover_session(session: &mut Session, log_entries: &[LogEntry]) { + // Ensure all log entry message IDs are in the session's dedup set. + // If the runtime crashed after writing a log entry but before persisting + // the session snapshot, there will be entries in the log not reflected + // in seen_message_ids. + let mut recovered = 0usize; + for entry in log_entries { + if !entry.message_id.is_empty() && session.seen_message_ids.insert(entry.message_id.clone()) + { + recovered += 1; + } + } + if recovered > 0 { + eprintln!( + "recovery: session '{}' reconciled {} log entries into dedup state", + session.session_id, recovered + ); + } +} + +pub fn cleanup_temp_files(base_dir: &Path) { + let sessions_dir = base_dir.join("sessions"); + if !sessions_dir.exists() { + return; + } + if let Ok(entries) = fs::read_dir(&sessions_dir) { + for entry in entries.flatten() { + if !entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) { + continue; + } + let dir = entry.path(); + if let Ok(files) = fs::read_dir(&dir) { + for file in files.flatten() { + let path = file.path(); + if path.extension().and_then(|e| e.to_str()) == Some("tmp") { + eprintln!("recovery: removing orphaned temp file {}", path.display()); + let _ = fs::remove_file(&path); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::log_store::EntryKind; + use crate::session::SessionState; + use std::collections::HashSet; + + fn sample_session() -> Session { + Session { + session_id: "s1".into(), + state: SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + mode_state: 
vec![], + participants: vec!["alice".into()], + seen_message_ids: HashSet::from(["m1".into()]), + intent: "".into(), + mode_version: "1.0.0".into(), + configuration_version: "cfg-1".into(), + policy_version: "pol-1".into(), + context: vec![], + roots: vec![], + initiator_sender: "alice".into(), + } + } + + fn sample_entry(id: &str) -> LogEntry { + LogEntry { + message_id: id.into(), + received_at_ms: 1_700_000_000_000, + sender: "alice".into(), + message_type: "Message".into(), + raw_payload: vec![], + entry_kind: EntryKind::Incoming, + session_id: String::new(), + mode: String::new(), + macp_version: String::new(), + } + } + + #[test] + fn crash_recovery_reconciles_dedup_state() { + let mut session = sample_session(); + assert!(session.seen_message_ids.contains("m1")); + assert!(!session.seen_message_ids.contains("m2")); + assert!(!session.seen_message_ids.contains("m3")); + + let entries = vec![sample_entry("m1"), sample_entry("m2"), sample_entry("m3")]; + + recover_session(&mut session, &entries); + + assert!(session.seen_message_ids.contains("m1")); + assert!(session.seen_message_ids.contains("m2")); + assert!(session.seen_message_ids.contains("m3")); + } + + #[test] + fn cleanup_temp_files_removes_orphans() { + let dir = tempfile::tempdir().unwrap(); + let base = dir.path(); + let sessions_dir = base.join("sessions").join("s1"); + fs::create_dir_all(&sessions_dir).unwrap(); + + fs::write(sessions_dir.join("session.json.tmp"), b"partial").unwrap(); + assert!(sessions_dir.join("session.json.tmp").exists()); + + cleanup_temp_files(base); + + assert!(!sessions_dir.join("session.json.tmp").exists()); + } +} diff --git a/src/storage/redis_backend.rs b/src/storage/redis_backend.rs new file mode 100644 index 0000000..97d32c5 --- /dev/null +++ b/src/storage/redis_backend.rs @@ -0,0 +1,283 @@ +use crate::log_store::LogEntry; +use crate::registry::PersistedSession; +use crate::session::Session; +use redis::AsyncCommands; +use std::io; + +use super::StorageBackend; + +pub 
struct RedisBackend { + conn: redis::aio::MultiplexedConnection, + prefix: String, +} + +impl RedisBackend { + pub async fn connect(url: &str, prefix: &str) -> io::Result { + let client = redis::Client::open(url).map_err(io::Error::other)?; + let conn = client + .get_multiplexed_async_connection() + .await + .map_err(io::Error::other)?; + Ok(Self { + conn, + prefix: prefix.into(), + }) + } + + fn session_key(&self, session_id: &str) -> String { + format!("{}:session:{}", self.prefix, session_id) + } + + fn log_key(&self, session_id: &str) -> String { + format!("{}:log:{}", self.prefix, session_id) + } + + fn index_key(&self) -> String { + format!("{}:sessions", self.prefix) + } +} + +#[async_trait::async_trait] +impl StorageBackend for RedisBackend { + async fn create_session_storage(&self, session_id: &str) -> io::Result<()> { + let mut conn = self.conn.clone(); + conn.sadd::<_, _, ()>(self.index_key(), session_id) + .await + .map_err(io::Error::other) + } + + async fn save_session(&self, session: &Session) -> io::Result<()> { + let mut conn = self.conn.clone(); + let persisted = PersistedSession::from(session); + let bytes = serde_json::to_vec(&persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + conn.set::<_, _, ()>(&self.session_key(&session.session_id), bytes) + .await + .map_err(io::Error::other) + } + + async fn load_session(&self, session_id: &str) -> io::Result> { + let mut conn = self.conn.clone(); + let bytes: Option> = conn + .get(self.session_key(session_id)) + .await + .map_err(io::Error::other)?; + match bytes { + Some(b) => { + let persisted: PersistedSession = serde_json::from_slice(&b) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(Session::from(persisted))) + } + None => Ok(None), + } + } + + async fn load_all_sessions(&self) -> io::Result> { + let ids = self.list_session_ids().await?; + let mut sessions = Vec::new(); + for id in ids { + if let Some(s) = self.load_session(&id).await? 
{ + sessions.push(s); + } + } + Ok(sessions) + } + + async fn delete_session(&self, session_id: &str) -> io::Result<()> { + let mut conn = self.conn.clone(); + redis::pipe() + .del(self.session_key(session_id)) + .del(self.log_key(session_id)) + .srem(self.index_key(), session_id) + .exec_async(&mut conn) + .await + .map_err(io::Error::other) + } + + async fn list_session_ids(&self) -> io::Result> { + let mut conn = self.conn.clone(); + conn.smembers(self.index_key()) + .await + .map_err(io::Error::other) + } + + async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()> { + let mut conn = self.conn.clone(); + let bytes = + serde_json::to_vec(entry).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + conn.rpush::<_, _, ()>(self.log_key(session_id), bytes) + .await + .map_err(io::Error::other) + } + + async fn load_log(&self, session_id: &str) -> io::Result> { + let mut conn = self.conn.clone(); + let items: Vec> = conn + .lrange(self.log_key(session_id), 0, -1) + .await + .map_err(io::Error::other)?; + let mut entries = Vec::with_capacity(items.len()); + for item in items { + let entry: LogEntry = serde_json::from_slice(&item) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + entries.push(entry); + } + Ok(entries) + } + + async fn replace_log(&self, session_id: &str, entries: &[LogEntry]) -> io::Result<()> { + let mut conn = self.conn.clone(); + let key = self.log_key(session_id); + + // Delete existing list + conn.del::<_, ()>(&key).await.map_err(io::Error::other)?; + + // Push new entries + for entry in entries { + let bytes = serde_json::to_vec(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + conn.rpush::<_, _, ()>(&key, bytes) + .await + .map_err(io::Error::other)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::log_store::EntryKind; + use crate::session::SessionState; + use std::collections::HashSet; + + fn sample_session(id: &str) -> Session { + 
Session { + session_id: id.into(), + state: SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + mode_state: vec![1, 2, 3], + participants: vec!["alice".into(), "bob".into()], + seen_message_ids: HashSet::from(["m1".into()]), + intent: "test".into(), + mode_version: "1.0.0".into(), + configuration_version: "cfg-1".into(), + policy_version: "pol-1".into(), + context: vec![9], + roots: vec![], + initiator_sender: "alice".into(), + } + } + + fn sample_entry(id: &str) -> LogEntry { + LogEntry { + message_id: id.into(), + received_at_ms: 1_700_000_000_000, + sender: "alice".into(), + message_type: "Message".into(), + raw_payload: vec![], + entry_kind: EntryKind::Incoming, + session_id: String::new(), + mode: String::new(), + macp_version: String::new(), + } + } + + async fn make_backend() -> Option { + let url = std::env::var("MACP_TEST_REDIS_URL").ok()?; + let prefix = format!("macp_test_{}", uuid::Uuid::new_v4()); + RedisBackend::connect(&url, &prefix).await.ok() + } + + async fn cleanup(backend: &RedisBackend) { + // Clean up all test keys + let mut conn = backend.conn.clone(); + let ids: Vec = conn.smembers(backend.index_key()).await.unwrap_or_default(); + for id in &ids { + let _ = redis::pipe() + .del(backend.session_key(id)) + .del(backend.log_key(id)) + .exec_async(&mut conn) + .await; + } + let _: Result<(), _> = conn.del(backend.index_key()).await; + } + + #[tokio::test] + async fn redis_session_round_trip() { + let Some(backend) = make_backend().await else { + eprintln!("skipping redis test: MACP_TEST_REDIS_URL not set"); + return; + }; + backend.create_session_storage("s1").await.unwrap(); + backend.save_session(&sample_session("s1")).await.unwrap(); + let loaded = backend.load_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.session_id, "s1"); + assert_eq!(loaded.ttl_ms, 60_000); + cleanup(&backend).await; + } + + #[tokio::test] + async fn 
redis_log_append_and_load() { + let Some(backend) = make_backend().await else { + return; + }; + backend.create_session_storage("s1").await.unwrap(); + for id in ["m1", "m2", "m3"] { + backend + .append_log_entry("s1", &sample_entry(id)) + .await + .unwrap(); + } + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 3); + assert_eq!(log[0].message_id, "m1"); + assert_eq!(log[2].message_id, "m3"); + cleanup(&backend).await; + } + + #[tokio::test] + async fn redis_list_and_delete() { + let Some(backend) = make_backend().await else { + return; + }; + for id in ["s1", "s2"] { + backend.create_session_storage(id).await.unwrap(); + backend.save_session(&sample_session(id)).await.unwrap(); + } + let mut ids = backend.list_session_ids().await.unwrap(); + ids.sort(); + assert_eq!(ids, vec!["s1", "s2"]); + + backend.delete_session("s1").await.unwrap(); + assert!(backend.load_session("s1").await.unwrap().is_none()); + + // Idempotent + backend.delete_session("s1").await.unwrap(); + cleanup(&backend).await; + } + + #[tokio::test] + async fn redis_replace_log() { + let Some(backend) = make_backend().await else { + return; + }; + backend.create_session_storage("s1").await.unwrap(); + for i in 0..5 { + backend + .append_log_entry("s1", &sample_entry(&format!("m{i}"))) + .await + .unwrap(); + } + let replacement = vec![sample_entry("checkpoint")]; + backend.replace_log("s1", &replacement).await.unwrap(); + let log = backend.load_log("s1").await.unwrap(); + assert_eq!(log.len(), 1); + assert_eq!(log[0].message_id, "checkpoint"); + cleanup(&backend).await; + } +} diff --git a/src/storage/rocksdb.rs b/src/storage/rocksdb.rs new file mode 100644 index 0000000..00c6326 --- /dev/null +++ b/src/storage/rocksdb.rs @@ -0,0 +1,398 @@ +use crate::log_store::LogEntry; +use crate::registry::PersistedSession; +use crate::session::Session; +use rocksdb::{ColumnFamilyDescriptor, IteratorMode, Options, WriteBatchWithTransaction, DB}; +use std::io; +use std::path::Path; + +use 
super::StorageBackend; + +const CF_SESSIONS: &str = "sessions"; +const CF_LOGS: &str = "logs"; + +pub struct RocksDbBackend { + db: DB, +} + +impl RocksDbBackend { + pub fn open>(path: P) -> io::Result { + let mut opts = Options::default(); + opts.create_if_missing(true); + opts.create_missing_column_families(true); + + let cf_sessions = ColumnFamilyDescriptor::new(CF_SESSIONS, Options::default()); + let cf_logs = ColumnFamilyDescriptor::new(CF_LOGS, Options::default()); + + let db = DB::open_cf_descriptors(&opts, path, vec![cf_sessions, cf_logs]) + .map_err(io::Error::other)?; + + Ok(Self { db }) + } + + /// Composite key for log entries: `{session_id}\x00{seq:08}` + fn log_key(session_id: &str, seq: u64) -> Vec { + format!("{session_id}\x00{seq:08}").into_bytes() + } + + /// Prefix for iterating all log entries of a session: `{session_id}\x00` + fn log_prefix(session_id: &str) -> Vec { + format!("{session_id}\x00").into_bytes() + } + + /// Find the next sequence number for a session's log entries. + fn next_seq(&self, session_id: &str) -> io::Result { + let cf = self + .db + .cf_handle(CF_LOGS) + .ok_or_else(|| io::Error::other("missing logs CF"))?; + + let prefix = Self::log_prefix(session_id); + + // Build the upper bound: increment the last byte before \x00 + // so the iterator only covers keys with this prefix. 
+ let mut upper_bound = prefix.clone(); + // Replace trailing \x00 with \x01 to create an exclusive upper bound + if let Some(last) = upper_bound.last_mut() { + *last = 0x01; + } + + let mut read_opts = rocksdb::ReadOptions::default(); + read_opts.set_iterate_upper_bound(upper_bound); + + let mut iter = self.db.iterator_cf_opt(&cf, read_opts, IteratorMode::End); + + if let Some(item) = iter.next() { + let (key, _) = item.map_err(io::Error::other)?; + if key.starts_with(&prefix) { + let key_str = String::from_utf8_lossy(&key); + if let Some(seq_str) = key_str.rsplit('\x00').next() { + if let Ok(seq) = seq_str.parse::() { + return Ok(seq + 1); + } + } + } + } + Ok(0) + } +} + +#[async_trait::async_trait] +impl StorageBackend for RocksDbBackend { + async fn create_session_storage(&self, _session_id: &str) -> io::Result<()> { + // Column families are created at DB open time; nothing to do per-session. + Ok(()) + } + + async fn save_session(&self, session: &Session) -> io::Result<()> { + let cf = self + .db + .cf_handle(CF_SESSIONS) + .ok_or_else(|| io::Error::other("missing sessions CF"))?; + let persisted = PersistedSession::from(session); + let bytes = serde_json::to_vec(&persisted) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + self.db + .put_cf(&cf, session.session_id.as_bytes(), &bytes) + .map_err(io::Error::other) + } + + async fn load_session(&self, session_id: &str) -> io::Result> { + let cf = self + .db + .cf_handle(CF_SESSIONS) + .ok_or_else(|| io::Error::other("missing sessions CF"))?; + match self + .db + .get_cf(&cf, session_id.as_bytes()) + .map_err(io::Error::other)? 
+ { + Some(bytes) => { + let persisted: PersistedSession = serde_json::from_slice(&bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(Session::from(persisted))) + } + None => Ok(None), + } + } + + async fn load_all_sessions(&self) -> io::Result> { + let ids = self.list_session_ids().await?; + let mut sessions = Vec::new(); + for id in ids { + if let Some(s) = self.load_session(&id).await? { + sessions.push(s); + } + } + Ok(sessions) + } + + async fn delete_session(&self, session_id: &str) -> io::Result<()> { + let sessions_cf = self + .db + .cf_handle(CF_SESSIONS) + .ok_or_else(|| io::Error::other("missing sessions CF"))?; + let logs_cf = self + .db + .cf_handle(CF_LOGS) + .ok_or_else(|| io::Error::other("missing logs CF"))?; + + let mut batch = WriteBatchWithTransaction::::default(); + + // Delete session record + batch.delete_cf(&sessions_cf, session_id.as_bytes()); + + // Delete all log entries with this session's prefix + let prefix = Self::log_prefix(session_id); + let iter = self.db.prefix_iterator_cf(&logs_cf, &prefix); + for item in iter { + let (key, _) = item.map_err(io::Error::other)?; + if !key.starts_with(&prefix) { + break; + } + batch.delete_cf(&logs_cf, &key); + } + + self.db.write(batch).map_err(io::Error::other) + } + + async fn list_session_ids(&self) -> io::Result> { + let cf = self + .db + .cf_handle(CF_SESSIONS) + .ok_or_else(|| io::Error::other("missing sessions CF"))?; + let mut ids = Vec::new(); + let iter = self.db.iterator_cf(&cf, IteratorMode::Start); + for item in iter { + let (key, _) = item.map_err(io::Error::other)?; + ids.push(String::from_utf8_lossy(&key).to_string()); + } + Ok(ids) + } + + async fn append_log_entry(&self, session_id: &str, entry: &LogEntry) -> io::Result<()> { + let cf = self + .db + .cf_handle(CF_LOGS) + .ok_or_else(|| io::Error::other("missing logs CF"))?; + let seq = self.next_seq(session_id)?; + let key = Self::log_key(session_id, seq); + let bytes = + 
serde_json::to_vec(entry).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + self.db.put_cf(&cf, &key, &bytes).map_err(io::Error::other) + } + + async fn load_log(&self, session_id: &str) -> io::Result> { + let cf = self + .db + .cf_handle(CF_LOGS) + .ok_or_else(|| io::Error::other("missing logs CF"))?; + let prefix = Self::log_prefix(session_id); + let mut entries = Vec::new(); + let iter = self.db.prefix_iterator_cf(&cf, &prefix); + for item in iter { + let (key, value) = item.map_err(io::Error::other)?; + if !key.starts_with(&prefix) { + break; + } + let entry: LogEntry = serde_json::from_slice(&value) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + entries.push(entry); + } + Ok(entries) + } + + async fn replace_log(&self, session_id: &str, entries: &[LogEntry]) -> io::Result<()> { + let cf = self + .db + .cf_handle(CF_LOGS) + .ok_or_else(|| io::Error::other("missing logs CF"))?; + + let mut batch = WriteBatchWithTransaction::::default(); + + // Delete all existing log entries + let prefix = Self::log_prefix(session_id); + let iter = self.db.prefix_iterator_cf(&cf, &prefix); + for item in iter { + let (key, _) = item.map_err(io::Error::other)?; + if !key.starts_with(&prefix) { + break; + } + batch.delete_cf(&cf, &key); + } + + // Insert new entries + for (i, entry) in entries.iter().enumerate() { + let key = Self::log_key(session_id, i as u64); + let bytes = serde_json::to_vec(entry) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + batch.put_cf(&cf, &key, &bytes); + } + + self.db.write(batch).map_err(io::Error::other) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::log_store::EntryKind; + use crate::session::SessionState; + use std::collections::HashSet; + + fn sample_session(id: &str) -> Session { + Session { + session_id: id.into(), + state: SessionState::Open, + ttl_expiry: 61_000, + ttl_ms: 60_000, + started_at_unix_ms: 1_000, + resolution: None, + mode: "macp.mode.decision.v1".into(), + 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::log_store::EntryKind;
    use crate::session::SessionState;
    use std::collections::HashSet;

    /// Fixture session whose dedup set already contains "m1".
    fn sample_session(id: &str) -> Session {
        Session {
            session_id: id.into(),
            state: SessionState::Open,
            ttl_expiry: 61_000,
            ttl_ms: 60_000,
            started_at_unix_ms: 1_000,
            resolution: None,
            mode: "macp.mode.decision.v1".into(),
            mode_state: vec![1, 2, 3],
            participants: vec!["alice".into(), "bob".into()],
            seen_message_ids: HashSet::from(["m1".into()]),
            intent: "test".into(),
            mode_version: "1.0.0".into(),
            configuration_version: "cfg-1".into(),
            policy_version: "pol-1".into(),
            context: vec![9],
            roots: vec![],
            initiator_sender: "alice".into(),
        }
    }

    /// Fixture incoming log entry identified by `id`.
    fn sample_entry(id: &str) -> LogEntry {
        LogEntry {
            message_id: id.into(),
            received_at_ms: 1_700_000_000_000,
            sender: "alice".into(),
            message_type: "Message".into(),
            raw_payload: vec![],
            entry_kind: EntryKind::Incoming,
            session_id: String::new(),
            mode: String::new(),
            macp_version: String::new(),
        }
    }

    /// Opens a backend rooted in a fresh temp directory. The TempDir handle
    /// must stay alive for the duration of the test or the DB path vanishes.
    fn fresh_backend() -> (tempfile::TempDir, RocksDbBackend) {
        let dir = tempfile::tempdir().unwrap();
        let store = RocksDbBackend::open(dir.path().join("db")).unwrap();
        (dir, store)
    }

    #[tokio::test]
    async fn rocksdb_session_round_trip() {
        let (_dir, store) = fresh_backend();

        store.create_session_storage("s1").await.unwrap();
        store.save_session(&sample_session("s1")).await.unwrap();

        let loaded = store
            .load_session("s1")
            .await
            .unwrap()
            .expect("session should round-trip");
        assert_eq!(loaded.session_id, "s1");
        assert_eq!(loaded.ttl_ms, 60_000);
        assert!(loaded.seen_message_ids.contains("m1"));
    }

    #[tokio::test]
    async fn rocksdb_log_append_and_load() {
        let (_dir, store) = fresh_backend();

        store.create_session_storage("s1").await.unwrap();
        for id in ["m1", "m2", "m3"] {
            store
                .append_log_entry("s1", &sample_entry(id))
                .await
                .unwrap();
        }

        let log = store.load_log("s1").await.unwrap();
        let ids: Vec<&str> = log.iter().map(|e| e.message_id.as_str()).collect();
        assert_eq!(ids, ["m1", "m2", "m3"]);
    }

    #[tokio::test]
    async fn rocksdb_list_and_delete() {
        let (_dir, store) = fresh_backend();

        for id in ["s1", "s2"] {
            store.save_session(&sample_session(id)).await.unwrap();
            store
                .append_log_entry(id, &sample_entry("m1"))
                .await
                .unwrap();
        }

        let mut ids = store.list_session_ids().await.unwrap();
        ids.sort();
        assert_eq!(ids, vec!["s1", "s2"]);

        store.delete_session("s1").await.unwrap();
        assert!(store.load_session("s1").await.unwrap().is_none());
        assert!(store.load_log("s1").await.unwrap().is_empty());

        // Deleting an already-deleted session must not error.
        store.delete_session("s1").await.unwrap();

        assert_eq!(store.list_session_ids().await.unwrap(), vec!["s2"]);
    }

    #[tokio::test]
    async fn rocksdb_replace_log() {
        let (_dir, store) = fresh_backend();

        for i in 0..5 {
            store
                .append_log_entry("s1", &sample_entry(&format!("m{i}")))
                .await
                .unwrap();
        }
        assert_eq!(store.load_log("s1").await.unwrap().len(), 5);

        store
            .replace_log("s1", &[sample_entry("checkpoint")])
            .await
            .unwrap();

        let log = store.load_log("s1").await.unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0].message_id, "checkpoint");
    }

    #[tokio::test]
    async fn rocksdb_load_all_sessions() {
        let (_dir, store) = fresh_backend();

        for id in ["s1", "s2", "s3"] {
            store.save_session(&sample_session(id)).await.unwrap();
        }

        let mut sessions = store.load_all_sessions().await.unwrap();
        sessions.sort_by(|a, b| a.session_id.cmp(&b.session_id));
        assert_eq!(sessions.len(), 3);
        assert_eq!(sessions[0].session_id, "s1");
    }

    #[tokio::test]
    async fn rocksdb_logs_isolated_between_sessions() {
        let (_dir, store) = fresh_backend();

        store
            .append_log_entry("s1", &sample_entry("a"))
            .await
            .unwrap();
        store
            .append_log_entry("s2", &sample_entry("b"))
            .await
            .unwrap();

        for (sid, expected) in [("s1", "a"), ("s2", "b")] {
            let log = store.load_log(sid).await.unwrap();
            assert_eq!(log.len(), 1);
            assert_eq!(log[0].message_id, expected);
        }
    }
}