diff --git a/README.md b/README.md index 512476d..6c53dd0 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,8 @@ Token JSON may be either a raw list or an object with a `tokens` array. Example: "allowed_modes": [ "macp.mode.task.v1" ], - "can_start_sessions": false + "can_start_sessions": false, + "can_manage_mode_registry": false } ] } diff --git a/src/metrics.rs b/src/metrics.rs index 0b3273d..6a2ea4f 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::RwLock; +use std::sync::{Arc, RwLock}; pub struct ModeMetrics { pub messages_accepted: AtomicU64, @@ -35,7 +35,7 @@ impl Default for ModeMetrics { } pub struct RuntimeMetrics { - per_mode: RwLock>, + per_mode: RwLock>>, } impl RuntimeMetrics { @@ -93,23 +93,20 @@ impl RuntimeMetrics { .fetch_add(1, Ordering::Relaxed); } - fn get_or_create(&self, mode: &str) -> &ModeMetrics { - // Fast path: read lock + fn get_or_create(&self, mode: &str) -> Arc { { let guard = self.per_mode.read().unwrap(); - if guard.contains_key(mode) { - // SAFETY: We never remove entries and HashMap doesn't move values - // on insert of other keys. The reference is valid for the lifetime - // of RuntimeMetrics. - let ptr = guard.get(mode).unwrap() as *const ModeMetrics; - return unsafe { &*ptr }; + if let Some(metrics) = guard.get(mode) { + return Arc::clone(metrics); } } - // Slow path: write lock to insert + let mut guard = self.per_mode.write().unwrap(); - guard.entry(mode.to_string()).or_default(); - let ptr = guard.get(mode).unwrap() as *const ModeMetrics; - unsafe { &*ptr } + Arc::clone( + guard + .entry(mode.to_string()) + .or_insert_with(|| Arc::new(ModeMetrics::default())), + ) } pub fn snapshot(&self) -> Vec<(String, MetricsSnapshot)> { diff --git a/src/mode_registry.rs b/src/mode_registry.rs index de44174..fa950e1 100644 --- a/src/mode_registry.rs +++ b/src/mode_registry.rs @@ -86,6 +86,30 @@ impl ModeRegistry { } } + fn ordered_standard_names(entries: &HashMap) -> Vec { + let mut names: Vec = STANDARD_MODE_NAMES + .iter() + .filter(|name| { + entries + .get(**name) + .map(|entry| entry.standards_track) + .unwrap_or(false) + }) + .map(|name| (*name).to_string()) + .collect(); + + let mut promoted: Vec = entries + .iter() + .filter(|(name, entry)| { + entry.standards_track && !STANDARD_MODE_NAMES.contains(&name.as_str()) + }) + .map(|(name, _)| name.clone()) + .collect(); + promoted.sort(); + names.extend(promoted); + names + } + pub fn get_mode(&self, name: &str) -> Option> { let guard = self.entries.read().expect("mode registry lock poisoned"); if guard.contains_key(name) { @@ -109,50 +133,54 @@ impl ModeRegistry { pub fn standard_mode_names(&self) -> Vec { let guard = self.entries.read().expect("mode registry lock poisoned"); - STANDARD_MODE_NAMES - .iter() - .filter(|name| guard.contains_key(**name)) - .map(|name| (*name).to_string()) - .collect() + Self::ordered_standard_names(&guard) } pub fn standard_mode_descriptors(&self) -> Vec { let guard = self.entries.read().expect("mode registry lock poisoned"); - STANDARD_MODE_NAMES - .iter() - .filter_map(|name| guard.get(*name).and_then(|e| e.descriptor.clone())) + Self::ordered_standard_names(&guard) + .into_iter() + .filter_map(|name| guard.get(&name).and_then(|entry| entry.descriptor.clone())) .collect() } pub fn extension_mode_names(&self) -> Vec { let guard = self.entries.read().expect("mode registry lock poisoned"); - guard + let mut names: Vec = guard .iter() .filter(|(_, e)| !e.standards_track) .map(|(name, _)| name.clone()) - .collect() + .collect(); + names.sort(); + names } pub fn extension_mode_descriptors(&self) -> Vec { let guard = self.entries.read().expect("mode registry lock poisoned"); - guard + let mut descriptors: Vec = guard .iter() .filter(|(_, e)| !e.standards_track) .filter_map(|(_, e)| e.descriptor.clone()) - .collect() + .collect(); + descriptors.sort_by(|a, b| a.mode.cmp(&b.mode)); + descriptors } pub fn all_mode_names(&self) -> Vec { let guard = self.entries.read().expect("mode registry lock poisoned"); - guard.keys().cloned().collect() + let mut names: Vec = guard.keys().cloned().collect(); + names.sort(); + names } pub fn all_mode_descriptors(&self) -> Vec { let guard = self.entries.read().expect("mode registry lock poisoned"); - guard + let mut descriptors: Vec = guard .values() .filter_map(|e| e.descriptor.clone()) - .collect() + .collect(); + descriptors.sort_by(|a, b| a.mode.cmp(&b.mode)); + descriptors } pub fn is_standard_mode(&self, name: &str) -> bool { @@ -481,6 +509,32 @@ mod tests { assert!(registry.is_standard_mode("ext.keep.v1")); } + #[test] + fn promoted_mode_appears_in_standard_mode_names_and_descriptors() { + let registry = ModeRegistry::build_default(); + let descriptor = ModeDescriptor { + mode: "ext.promoted.v1".into(), + mode_version: "1.0.0".into(), + title: "Promoted".into(), + message_types: vec!["SessionStart".into(), "Commitment".into()], + ..Default::default() + }; + registry.register_extension(descriptor).unwrap(); + registry + .promote_mode("ext.promoted.v1", Some("macp.mode.promoted.v1")) + .unwrap(); + + let standard_names = registry.standard_mode_names(); + assert!(standard_names.contains(&"macp.mode.promoted.v1".to_string())); + + let standard_modes: Vec = registry + .standard_mode_descriptors() + .into_iter() + .map(|d| d.mode) + .collect(); + assert!(standard_modes.contains(&"macp.mode.promoted.v1".to_string())); + } + #[test] fn promote_already_standard_fails() { let registry = ModeRegistry::build_default(); diff --git a/src/replay.rs b/src/replay.rs index ddc7e04..5d9e923 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -3,8 +3,8 @@ use crate::log_store::{EntryKind, LogEntry}; use crate::mode_registry::ModeRegistry; use crate::pb::Envelope; use crate::session::{ - extract_ttl_ms, parse_session_start_payload, requires_strict_session_start, - validate_strict_session_start_payload, Session, SessionState, + extract_ttl_ms, parse_session_start_payload, validate_canonical_session_start_payload, Session, + SessionState, }; /// Rebuild a `Session` from its append-only log. @@ -35,15 +35,18 @@ pub fn replay_session( let mode = registry.get_mode(mode_name).ok_or(MacpError::UnknownMode)?; // 2. Parse SessionStartPayload - let start_payload = - if start_entry.raw_payload.is_empty() && !requires_strict_session_start(mode_name) { - crate::pb::SessionStartPayload::default() - } else { - parse_session_start_payload(&start_entry.raw_payload)? - }; - validate_strict_session_start_payload(mode_name, &start_payload)?; + let require_complete_start = + registry.is_standard_mode(mode_name) || mode_name == "ext.multi_round.v1"; + let start_payload = if start_entry.raw_payload.is_empty() && !require_complete_start { + crate::pb::SessionStartPayload::default() + } else { + parse_session_start_payload(&start_entry.raw_payload)? + }; + if require_complete_start { + validate_canonical_session_start_payload(&start_payload)?; + } - let ttl_ms = if !requires_strict_session_start(mode_name) && start_payload.ttl_ms == 0 { + let ttl_ms = if !require_complete_start && start_payload.ttl_ms == 0 { // Legacy experimental modes may have 0 ttl_ms 60_000i64 } else { @@ -125,11 +128,12 @@ pub fn replay_session( continue; } - // Replay through mode — errors during replay are not fatal, - // the message was already accepted in the original run - if let Ok(response) = mode.on_message(&session, &replay_env) { - session.apply_mode_response(response); - } + // Replay through the same authorization and mode callbacks used during + // live processing. Accepted history that no longer replays cleanly must + // fail recovery instead of silently drifting session state. + mode.authorize_sender(&session, &replay_env)?; + let response = mode.on_message(&session, &replay_env)?; + session.apply_mode_response(response); if !replay_env.message_id.is_empty() { session.seen_message_ids.insert(replay_env.message_id); } @@ -318,6 +322,36 @@ mod tests { assert_eq!(session.state, SessionState::Expired); } + #[test] + fn replay_fails_when_accepted_history_no_longer_applies() { + let registry = make_registry(); + let vote = VotePayload { + proposal_id: "p1".into(), + vote: "approve".into(), + reason: String::new(), + } + .encode_to_vec(); + let entries = vec![ + incoming_entry( + "m1", + "SessionStart", + "agent://orchestrator", + start_payload_bytes(), + 1000, + ), + incoming_entry("m2", "Vote", "agent://fraud", vote, 2000), + ]; + + let err = replay_session("s1", &entries, ®istry).unwrap_err(); + // The exact error variant depends on which check fails first (authorize_sender + // or on_message); what matters is that replay does NOT silently succeed. + let msg = err.to_string(); + assert!( + msg == "InvalidTransition" || msg == "InvalidPayload" || msg == "Forbidden", + "unexpected error: {msg}" + ); + } + #[test] fn replay_empty_log_returns_error() { let registry = make_registry(); diff --git a/src/runtime.rs b/src/runtime.rs index 5b630dc..f3e6f82 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,8 +8,8 @@ use crate::mode_registry::ModeRegistry; use crate::pb::{Envelope, ModeDescriptor}; use crate::registry::SessionRegistry; use crate::session::{ - extract_ttl_ms, parse_session_start_payload, validate_session_id_for_acceptance, - validate_strict_session_start_payload, Session, SessionState, + extract_ttl_ms, parse_session_start_payload, validate_canonical_session_start_payload, + validate_session_id_for_acceptance, Session, SessionState, }; use crate::storage::StorageBackend; use crate::stream_bus::SessionStreamBus; @@ -204,7 +204,11 @@ impl Runtime { .ok_or(MacpError::UnknownMode)?; let start_payload = parse_session_start_payload(&env.payload)?; - validate_strict_session_start_payload(mode_name, &start_payload)?; + let require_complete_start = + self.mode_registry.is_standard_mode(mode_name) || mode_name == "ext.multi_round.v1"; + if require_complete_start { + validate_canonical_session_start_payload(&start_payload)?; + } let ttl_ms = extract_ttl_ms(&start_payload)?; let mut guard = self.registry.sessions.write().await; diff --git a/src/security.rs b/src/security.rs index 6ef95e2..6b2c3a0 100644 --- a/src/security.rs +++ b/src/security.rs @@ -13,6 +13,7 @@ pub struct AuthIdentity { pub allowed_modes: Option>, pub can_start_sessions: bool, pub max_open_sessions: Option, + pub can_manage_mode_registry: bool, } #[derive(Clone, Debug, serde::Deserialize)] @@ -24,6 +25,8 @@ struct RawIdentity { #[serde(default = "default_true")] can_start_sessions: bool, max_open_sessions: Option, + #[serde(default)] + can_manage_mode_registry: bool, } #[derive(Clone, Debug, serde::Deserialize)] @@ -147,6 +150,7 @@ impl SecurityLayer { }, can_start_sessions: item.can_start_sessions, max_open_sessions: item.max_open_sessions, + can_manage_mode_registry: item.can_manage_mode_registry, }, ); } @@ -186,6 +190,7 @@ impl SecurityLayer { allowed_modes: None, can_start_sessions: true, max_open_sessions: None, + can_manage_mode_registry: true, }); } } @@ -210,6 +215,14 @@ impl SecurityLayer { Ok(()) } + pub fn authorize_mode_registry(&self, identity: &AuthIdentity) -> Result<(), MacpError> { + if identity.can_manage_mode_registry { + Ok(()) + } else { + Err(MacpError::Forbidden) + } + } + async fn check_bucket( bucket: &Mutex>>, sender: &str, @@ -543,6 +556,7 @@ mod tests { allowed_modes: None, can_start_sessions: true, max_open_sessions: None, + can_manage_mode_registry: false, }; assert!(layer .authorize_mode(&id, "macp.mode.decision.v1", false) @@ -562,6 +576,7 @@ mod tests { allowed_modes: Some(allowed), can_start_sessions: true, max_open_sessions: None, + can_manage_mode_registry: false, }; assert!(layer .authorize_mode(&id, "macp.mode.decision.v1", false) @@ -580,6 +595,7 @@ mod tests { allowed_modes: None, can_start_sessions: false, max_open_sessions: None, + can_manage_mode_registry: false, }; let err = layer .authorize_mode(&id, "macp.mode.decision.v1", true) @@ -595,6 +611,7 @@ mod tests { allowed_modes: None, can_start_sessions: false, max_open_sessions: None, + can_manage_mode_registry: false, }; // Regular messages (not session start) should succeed assert!(layer @@ -613,6 +630,7 @@ mod tests { allowed_modes: Some(allowed), can_start_sessions: false, max_open_sessions: None, + can_manage_mode_registry: false, }; // Cannot start sessions (checked first) @@ -633,6 +651,29 @@ mod tests { .is_ok()); } + #[test] + fn authorize_mode_registry_requires_explicit_privilege() { + let layer = SecurityLayer::dev_mode(); + let id = AuthIdentity { + sender: "agent://no-admin".into(), + allowed_modes: None, + can_start_sessions: true, + max_open_sessions: None, + can_manage_mode_registry: false, + }; + let err = layer.authorize_mode_registry(&id).unwrap_err(); + assert!(matches!(err, MacpError::Forbidden)); + } + + #[test] + fn dev_sender_header_can_manage_mode_registry() { + let layer = SecurityLayer::dev_mode(); + let mut meta = MetadataMap::new(); + meta.insert("x-macp-agent-id", "agent://dev-admin".parse().unwrap()); + let id = layer.authenticate_metadata(&meta).unwrap(); + assert!(layer.authorize_mode_registry(&id).is_ok()); + } + // --------------------------------------------------------------- // 6. enforce_rate_limit() with session_start and message categories // --------------------------------------------------------------- diff --git a/src/server.rs b/src/server.rs index b256f47..dfee79c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -669,6 +669,13 @@ impl MacpRuntimeService for MacpServer { &self, request: Request, ) -> Result, Status> { + let identity = self + .security + .authenticate_metadata(request.metadata()) + .map_err(Self::status_from_error)?; + self.security + .authorize_mode_registry(&identity) + .map_err(Self::status_from_error)?; let req = request.into_inner(); let descriptor = req .descriptor @@ -689,6 +696,13 @@ impl MacpRuntimeService for MacpServer { &self, request: Request, ) -> Result, Status> { + let identity = self + .security + .authenticate_metadata(request.metadata()) + .map_err(Self::status_from_error)?; + self.security + .authorize_mode_registry(&identity) + .map_err(Self::status_from_error)?; let req = request.into_inner(); match self.runtime.unregister_extension(&req.mode) { Ok(()) => Ok(Response::new(UnregisterExtModeResponse { @@ -706,6 +720,13 @@ impl MacpRuntimeService for MacpServer { &self, request: Request, ) -> Result, Status> { + let identity = self + .security + .authenticate_metadata(request.metadata()) + .map_err(Self::status_from_error)?; + self.security + .authorize_mode_registry(&identity) + .map_err(Self::status_from_error)?; let req = request.into_inner(); let new_name = if req.promoted_mode_name.is_empty() { None @@ -853,12 +874,35 @@ mod tests { assert_eq!(err.code(), tonic::Code::PermissionDenied); } + #[tokio::test] + async fn register_ext_mode_requires_authenticated_registry_permission() { + let storage: Arc = + Arc::new(macp_runtime::storage::MemoryBackend); + let registry = Arc::new(SessionRegistry::new()); + let log_store = Arc::new(LogStore::new()); + let runtime = Arc::new(Runtime::new(storage, registry, log_store)); + let security = SecurityLayer::from_env().unwrap_or_else(|_| SecurityLayer::dev_mode()); + let server = MacpServer::new(runtime, security); + + let req = Request::new(RegisterExtModeRequest { + descriptor: Some(macp_runtime::pb::ModeDescriptor { + mode: "ext.custom.v1".into(), + mode_version: "1.0.0".into(), + message_types: vec!["SessionStart".into(), "Commitment".into()], + ..Default::default() + }), + }); + let err = server.register_ext_mode(req).await.unwrap_err(); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } + fn stream_identity(sender: &str) -> AuthIdentity { AuthIdentity { sender: sender.into(), allowed_modes: None, can_start_sessions: true, max_open_sessions: None, + can_manage_mode_registry: false, } } diff --git a/src/session.rs b/src/session.rs index 957ecf8..81de2b0 100644 --- a/src/session.rs +++ b/src/session.rs @@ -80,15 +80,10 @@ pub fn extract_ttl_ms(payload: &SessionStartPayload) -> Result { Ok(payload.ttl_ms) } -/// Enforce the strict SessionStart binding contract for standards-track and qualifying extension modes. -pub fn validate_strict_session_start_payload( - mode: &str, +/// Validate the complete canonical SessionStart binding contract. +pub fn validate_canonical_session_start_payload( payload: &SessionStartPayload, ) -> Result<(), MacpError> { - if !requires_strict_session_start(mode) { - return Ok(()); - } - extract_ttl_ms(payload)?; if payload.mode_version.trim().is_empty() || payload.configuration_version.trim().is_empty() { @@ -110,6 +105,18 @@ pub fn validate_strict_session_start_payload( Ok(()) } +/// Enforce the strict SessionStart binding contract for standards-track and qualifying extension modes. +pub fn validate_strict_session_start_payload( + mode: &str, + payload: &SessionStartPayload, +) -> Result<(), MacpError> { + if !requires_strict_session_start(mode) { + return Ok(()); + } + + validate_canonical_session_start_payload(payload) +} + /// Validate that a session ID meets the acceptance policy. /// /// Accepts: