From bfdd68db1ae319832278338f8d8ee834d3f5ca55 Mon Sep 17 00:00:00 2001 From: wayslog Date: Thu, 20 Nov 2025 18:02:48 +0800 Subject: [PATCH 1/4] chore: add more test cases --- src/backend/client.rs | 31 ++++++ src/backend/pool.rs | 129 +++++++++++++++++++++++++ src/cache/mod.rs | 144 ++++++++++++++++++++-------- src/cache/tracker.rs | 93 +++++++++++++++++- src/cluster/mod.rs | 138 +++++++++++++++++++++++++-- src/config/mod.rs | 207 +++++++++++++++++++++++++++++++++++++++-- src/info.rs | 38 ++++++++ src/lib.rs | 2 +- src/meta.rs | 107 +++++++++++++++++++++ src/metrics/mod.rs | 42 ++++++++- src/metrics/tracker.rs | 20 ++++ src/slowlog.rs | 57 ++++++++++++ src/standalone/mod.rs | 85 ++++++++++++++++- 13 files changed, 1028 insertions(+), 65 deletions(-) diff --git a/src/backend/client.rs b/src/backend/client.rs index 2713a18..5f3eafe 100644 --- a/src/backend/client.rs +++ b/src/backend/client.rs @@ -46,3 +46,34 @@ impl<'a> Drop for FrontConnectionGuard<'a> { metrics::front_conn_close(self.cluster); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::metrics; + + #[test] + fn client_ids_are_monotonic() { + let a = ClientId::new(); + let b = ClientId::new(); + assert!(b.as_u64() > a.as_u64()); + } + + #[test] + fn front_connection_guard_updates_metrics() { + let cluster = "guard-cluster"; + let initial_current = metrics::front_connections_current(cluster); + let initial_total = metrics::front_connections_total(cluster); + + { + let _guard = FrontConnectionGuard::new(cluster); + assert_eq!( + metrics::front_connections_current(cluster), + initial_current + 1 + ); + } + + assert_eq!(metrics::front_connections_current(cluster), initial_current); + assert_eq!(metrics::front_connections_total(cluster), initial_total + 1); + } +} diff --git a/src/backend/pool.rs b/src/backend/pool.rs index 8893923..99ee658 100644 --- a/src/backend/pool.rs +++ b/src/backend/pool.rs @@ -345,3 +345,132 @@ impl<'a, T: BackendRequest> Drop for ExclusiveConnection<'a, T> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::{Arc, Mutex}; + + #[derive(Default)] + struct TestConnector { + started: AtomicUsize, + } + + #[async_trait] + impl Connector for TestConnector { + async fn run_session( + self: Arc, + _node: BackendNode, + _cluster: Arc, + mut rx: mpsc::Receiver>, + ) { + self.started.fetch_add(1, Ordering::SeqCst); + while let Some(cmd) = rx.recv().await { + let _ = cmd.respond_to.send(Ok(cmd.request.payload)); + } + } + } + + impl TestConnector { + fn started(&self) -> usize { + self.started.load(Ordering::SeqCst) + } + } + + #[derive(Clone)] + struct CallRecorder { + values: Arc>>, + } + + impl CallRecorder { + fn new() -> Self { + Self { + values: Arc::new(Mutex::new(Vec::new())), + } + } + + fn record(&self, value: &str) { + self.values.lock().unwrap().push(value.to_string()); + } + + fn entries(&self) -> Vec { + self.values.lock().unwrap().clone() + } + } + + #[derive(Clone)] + struct TestRequest { + payload: &'static str, + total: CallRecorder, + remote: CallRecorder, + } + + impl TestRequest { + fn new(payload: &'static str, total: CallRecorder, remote: CallRecorder) -> Self { + Self { + payload, + total, + remote, + } + } + } + + impl BackendRequest for TestRequest { + type Response = &'static str; + + fn apply_total_tracker(&mut self, cluster: &str) { + self.total.record(cluster); + } + + fn apply_remote_tracker(&mut self, cluster: &str) { + self.remote.record(cluster); + } + } + + fn cluster_name() -> Arc { + 
Arc::<str>::from("cluster-backend-tests".to_string())
+    }
+
+    fn backend_node() -> BackendNode {
+        BackendNode::new("127.0.0.1:7000".into())
+    }
+
+    #[tokio::test(flavor = "current_thread")]
+    async fn dispatch_sends_request_and_tracks_cluster() {
+        let connector = Arc::new(TestConnector::default());
+        let total = CallRecorder::new();
+        let remote = CallRecorder::new();
+        let request = TestRequest::new("ok", total.clone(), remote.clone());
+        let pool = ConnectionPool::with_slots(cluster_name(), connector.clone(), 2);
+        let node = backend_node();
+        let rx = pool
+            .dispatch(node.clone(), ClientId::new(), request)
+            .await
+            .expect("dispatch");
+        let response = rx.await.expect("oneshot").expect("response");
+        assert_eq!(response, "ok");
+        assert_eq!(connector.started(), 1);
+        assert_eq!(total.entries(), vec!["cluster-backend-tests".to_string()]);
+        assert_eq!(remote.entries(), vec!["cluster-backend-tests".to_string()]);
+    }
+
+    #[tokio::test(flavor = "current_thread")]
+    async fn exclusive_connections_are_reused() {
+        let connector = Arc::new(TestConnector::default());
+        let pool = ConnectionPool::<TestRequest>::with_slots(cluster_name(), connector.clone(), 1);
+        let node = backend_node();
+
+        {
+            let _conn = pool.acquire_exclusive(&node);
+            tokio::task::yield_now().await;
+            assert_eq!(connector.started(), 1);
+        }
+
+        {
+            let _conn = pool.acquire_exclusive(&node);
+            tokio::task::yield_now().await;
+            assert_eq!(connector.started(), 1);
+        }
+    }
+}
diff --git a/src/cache/mod.rs b/src/cache/mod.rs
index e8f3df2..e759d1d 100644
--- a/src/cache/mod.rs
+++ b/src/cache/mod.rs
@@ -95,7 +95,10 @@ impl ClientCache {
             state: AtomicU8::new(initial_state.as_u8()),
             resp3_ready,
             shards: ArcSwap::from_pointee(shards),
-            config: RwLock::new(ClientCacheConfig { enabled: initial_state == CacheState::Enabled, ..config }),
+            config: RwLock::new(ClientCacheConfig {
+                enabled: initial_state == CacheState::Enabled,
+                ..config
+            }),
             drain_handle: Mutex::new(None),
             state_tx,
         };
@@ -119,9 +122,7 @@ impl ClientCache {
         if !self.resp3_ready {
             bail!("client cache requires RESP3 backend support");
         }
-        let prev = self
-            .state
-            .swap(STATE_ENABLED, Ordering::SeqCst);
+        let prev = self.state.swap(STATE_ENABLED, Ordering::SeqCst);
         self.stop_drain_task();
         self.state_tx.send_replace(CacheState::Enabled);
         if prev != STATE_ENABLED {
@@ -135,9 +136,7 @@ impl ClientCache {
     }
     pub fn disable(self: &Arc<Self>) {
-        let prev = self
-            .state
-            .swap(STATE_DRAINING, Ordering::SeqCst);
+        let prev = self.state.swap(STATE_DRAINING, Ordering::SeqCst);
         if prev == STATE_DISABLED {
             self.state_tx.send_replace(CacheState::Disabled);
             return;
         }
@@ -158,11 +157,7 @@ impl ClientCache {
         match classify_read(command) {
             CacheRead::Single { kind, key, field } => {
                 let hit = self.lookup_single(kind, key, field);
-                metrics::client_cache_lookup(
-                    self.cluster.as_ref(),
-                    kind.label(),
-                    hit.is_some(),
-                );
+                metrics::client_cache_lookup(self.cluster.as_ref(), kind.label(), hit.is_some());
                 hit
             }
             CacheRead::Multi { keys } => {
@@ -324,12 +319,7 @@ impl ClientCache {
         metrics::client_cache_store(self.cluster.as_ref(), kind.label());
     }
-    fn store_multi(
-        &self,
-        config: &ClientCacheConfig,
-        keys: &[&Bytes],
-        response: &RespValue,
-    ) {
+    fn store_multi(&self, config: &ClientCacheConfig, keys: &[&Bytes], response: &RespValue) {
         let values = match response.as_array() {
             Some(values) if values.len() == keys.len() => values,
             _ => return,
         };
@@ -402,13 +392,17 @@ impl ClientCache {
 fn normalize_value(kind: CacheCommandKind, resp: &RespValue) -> Option<RespValue> {
     match kind {
         CacheCommandKind::Value =>
match resp { - RespValue::BulkString(_) | RespValue::SimpleString(_) | RespValue::Null | RespValue::NullBulk => - Some(resp.clone()), + RespValue::BulkString(_) + | RespValue::SimpleString(_) + | RespValue::Null + | RespValue::NullBulk => Some(resp.clone()), _ => None, }, CacheCommandKind::HashField => match resp { - RespValue::BulkString(_) | RespValue::SimpleString(_) | RespValue::Null | RespValue::NullBulk => - Some(resp.clone()), + RespValue::BulkString(_) + | RespValue::SimpleString(_) + | RespValue::Null + | RespValue::NullBulk => Some(resp.clone()), _ => None, }, } @@ -423,13 +417,12 @@ fn resp_size(value: &RespValue) -> usize { | RespValue::Double(data) | RespValue::BigNumber(data) => data.len(), RespValue::Integer(_) => std::mem::size_of::(), - RespValue::Null - | RespValue::NullBulk - | RespValue::NullArray => 1, + RespValue::Null | RespValue::NullBulk | RespValue::NullArray => 1, RespValue::Boolean(_) => 1, - RespValue::Map(entries) | RespValue::Attribute(entries) => { - entries.iter().map(|(k, v)| resp_size(k) + resp_size(v)).sum() - } + RespValue::Map(entries) | RespValue::Attribute(entries) => entries + .iter() + .map(|(k, v)| resp_size(k) + resp_size(v)) + .sum(), RespValue::Array(values) | RespValue::Set(values) | RespValue::Push(values) => { values.iter().map(resp_size).sum() } @@ -476,10 +469,7 @@ struct CacheEntry { impl CacheEntry { fn new(value: RespValue) -> Self { - Self { - value, - access: 0, - } + Self { value, access: 0 } } } @@ -512,12 +502,7 @@ impl CacheShard { } } - fn get( - &self, - kind: CacheCommandKind, - key: &Bytes, - field: Option<&Bytes>, - ) -> Option { + fn get(&self, kind: CacheCommandKind, key: &Bytes, field: Option<&Bytes>) -> Option { let mut guard = self.inner.lock(); guard.touch(&CacheKey::new(kind, key.clone(), field.cloned())) } @@ -652,7 +637,10 @@ impl CacheShardInner { if let Some(stored) = self.entries.get(&entry.key) { if stored.access == entry.access { self.detach(&entry.key); - return self.entries.remove(&entry.key).map(|value| (entry.key, value)); + return self + .entries + .remove(&entry.key) + .map(|value| (entry.key, value)); } } } @@ -804,3 +792,81 @@ fn classify_write(command: &RedisCommand) -> Option> { _ => None, } } + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::sync::Arc; + + fn config() -> ClientCacheConfig { + ClientCacheConfig { + enabled: true, + max_entries: 32, + max_value_bytes: 1024, + shard_count: 2, + drain_batch: 8, + drain_interval_ms: 10, + } + } + + fn cache() -> ClientCache { + ClientCache::new( + Arc::::from("cache-cluster".to_string()), + config(), + true, + ) + } + + fn command(parts: &[&[u8]]) -> RedisCommand { + let parts = parts + .iter() + .map(|p| Bytes::copy_from_slice(p)) + .collect::>(); + RedisCommand::new(parts).expect("command") + } + + #[test] + fn cache_store_and_lookup_value() { + let cache = cache(); + let cmd = command(&[b"GET", b"foo"]); + assert!(cache.lookup(&cmd).is_none()); + let value = RespValue::bulk("bar"); + cache.store(&cmd, &value); + assert_eq!(cache.lookup(&cmd), Some(value)); + } + + #[test] + fn cache_store_multi_populates_individual_keys() { + let cache = cache(); + let mget = command(&[b"MGET", b"foo", b"bar"]); + let resp = RespValue::array(vec![RespValue::bulk("v1"), RespValue::bulk("v2")]); + cache.store(&mget, &resp); + let foo = command(&[b"GET", b"foo"]); + let bar = command(&[b"GET", b"bar"]); + assert_eq!(cache.lookup(&foo), Some(RespValue::bulk("v1"))); + assert_eq!(cache.lookup(&bar), Some(RespValue::bulk("v2"))); + } + + #[test] + fn 
cache_invalidate_command_clears_entries() { + let cache = cache(); + let get = command(&[b"GET", b"key"]); + let resp = RespValue::bulk("value"); + cache.store(&get, &resp); + assert!(cache.lookup(&get).is_some()); + let del = command(&[b"DEL", b"key"]); + cache.invalidate_command(&del); + assert!(cache.lookup(&get).is_none()); + } + + #[test] + fn classification_helpers_identify_cacheable_commands() { + let get = command(&[b"GET", b"key"]); + let set = command(&[b"SET", b"key", b"value"]); + assert!(ClientCache::is_cacheable_read(&get)); + assert!(!ClientCache::is_cacheable_read(&set)); + assert!(ClientCache::is_invalidating_write(&set)); + assert!(!ClientCache::is_invalidating_write(&command(&[b"PING"]))); + } +} diff --git a/src/cache/tracker.rs b/src/cache/tracker.rs index d271dcd..f50b0c2 100644 --- a/src/cache/tracker.rs +++ b/src/cache/tracker.rs @@ -59,7 +59,8 @@ impl CacheTrackerSet { pub fn set_nodes(&self, nodes: Vec) { let mut guard = self.handles.lock(); - let desired: HashSet> = nodes.iter().map(|n| Arc::::from(n.clone())).collect(); + let desired: HashSet> = + nodes.iter().map(|n| Arc::::from(n.clone())).collect(); guard.retain(|addr, handle| { if desired.contains(addr) { @@ -194,7 +195,14 @@ async fn listen_once( ) -> Result<()> { let timeout_duration = runtime.request_timeout(timeout_ms); let mut framed = open_stream(&address, timeout_duration, backend_auth.clone()).await?; - negotiate_resp3(&cluster, &address, timeout_duration, &mut framed, backend_auth.clone()).await?; + negotiate_resp3( + &cluster, + &address, + timeout_duration, + &mut framed, + backend_auth.clone(), + ) + .await?; enable_tracking(&cluster, &address, timeout_duration, &mut framed).await?; loop { @@ -227,7 +235,9 @@ async fn open_stream( let stream = timeout(timeout_duration, TcpStream::connect(address)) .await .with_context(|| format!("connect to {address} timed out"))??; - stream.set_nodelay(true).context("failed to enable TCP_NODELAY")?; + stream + .set_nodelay(true) + .context("failed to enable TCP_NODELAY")?; #[cfg(any(unix, windows))] { let keepalive = TcpKeepalive::new() @@ -301,7 +311,10 @@ async fn enable_tracking( match timeout(timeout_duration, framed.next()).await { Ok(Some(Ok(resp))) => match resp { RespValue::SimpleString(ref s) | RespValue::BulkString(ref s) - if s.eq_ignore_ascii_case(b"OK") => Ok(()), + if s.eq_ignore_ascii_case(b"OK") => + { + Ok(()) + } RespValue::Error(err) => Err(anyhow!( "backend {backend} rejected CLIENT TRACKING for cluster {cluster}: {}", String::from_utf8_lossy(&err) @@ -313,7 +326,9 @@ async fn enable_tracking( }, Ok(Some(Err(err))) => Err(err.context("CLIENT TRACKING failed")), Ok(None) => Err(anyhow!("backend {backend} closed during CLIENT TRACKING")), - Err(_) => Err(anyhow!("backend {backend} timed out waiting for CLIENT TRACKING")), + Err(_) => Err(anyhow!( + "backend {backend} timed out waiting for CLIENT TRACKING" + )), } } @@ -337,3 +352,71 @@ fn parse_invalidation(items: &[RespValue]) -> Option> { } Some(keys) } + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::watch; + + #[tokio::test(flavor = "current_thread")] + async fn wait_for_enabled_returns_on_state_change() { + let (state_tx, mut state_rx) = watch::channel(CacheState::Disabled); + let (_shutdown_tx, mut shutdown_rx) = watch::channel(false); + let handle = + tokio::spawn(async move { wait_for_enabled(&mut state_rx, &mut shutdown_rx).await }); + state_tx.send(CacheState::Enabled).unwrap(); + let result = tokio::time::timeout(Duration::from_millis(100), handle) + .await + 
.expect("wait task finished") + .expect("wait result"); + assert!(result); + } + + #[tokio::test(flavor = "current_thread")] + async fn wait_for_enabled_exits_on_shutdown() { + let (_state_tx, mut state_rx) = watch::channel(CacheState::Disabled); + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + let handle = + tokio::spawn(async move { wait_for_enabled(&mut state_rx, &mut shutdown_rx).await }); + shutdown_tx.send(true).unwrap(); + let result = tokio::time::timeout(Duration::from_millis(100), handle) + .await + .expect("wait task finished") + .expect("wait result"); + assert!(!result); + } + + #[tokio::test(flavor = "current_thread")] + async fn wait_with_shutdown_obeys_delay() { + let (_shutdown_tx, mut shutdown_rx) = watch::channel(false); + let result = wait_with_shutdown(Duration::from_millis(10), &mut shutdown_rx).await; + assert!(!result); + } + + #[tokio::test(flavor = "current_thread")] + async fn wait_with_shutdown_detects_signal() { + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + shutdown_tx.send(true).unwrap(); + let result = wait_with_shutdown(Duration::from_secs(1), &mut shutdown_rx).await; + assert!(result); + } + + #[test] + fn parse_invalidation_collects_keys() { + let payload = vec![ + RespValue::BulkString(Bytes::from_static(b"invalidate")), + RespValue::BulkString(Bytes::from_static(b"foo")), + RespValue::SimpleString(Bytes::from_static(b"bar")), + ]; + let keys = parse_invalidation(&payload).expect("keys"); + assert_eq!(keys.len(), 2); + assert_eq!(keys[0], Bytes::from_static(b"foo")); + assert_eq!(keys[1], Bytes::from_static(b"bar")); + } + + #[test] + fn parse_invalidation_rejects_unexpected_labels() { + let payload = vec![RespValue::BulkString(Bytes::from_static(b"ignore"))]; + assert!(parse_invalidation(&payload).is_none()); + } +} diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 6da80d1..8b01d4c 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -20,9 +20,9 @@ use tokio_util::codec::{Framed, FramedParts}; use tracing::{debug, info, warn}; use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator}; -use crate::cache::{tracker::CacheTrackerSet, ClientCache}; use crate::backend::client::{ClientId, FrontConnectionGuard}; use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; +use crate::cache::{tracker::CacheTrackerSet, ClientCache}; use crate::config::{BackupRequestRuntime, ClusterConfig, ClusterRuntime, ConfigManager}; use crate::hotkey::Hotkey; use crate::info::{InfoContext, ProxyMode}; @@ -1402,7 +1402,15 @@ async fn fetch_topology( _ = trigger.recv() => {}, } - if let Err(err) = fetch_once(&cluster, &seeds, connector.clone(), slots.clone(), tracker.clone()).await { + if let Err(err) = fetch_once( + &cluster, + &seeds, + connector.clone(), + slots.clone(), + tracker.clone(), + ) + .await + { warn!(cluster = %cluster, error = %err, "failed to refresh cluster topology"); } } @@ -1453,6 +1461,8 @@ async fn fetch_from_seed(seed: &str, connector: Arc) -> Result #[cfg(test)] mod tests { use super::*; + use crate::utils::{crc16, trim_hash_tag}; + use std::collections::HashSet; #[test] fn parse_moved_redirect() { @@ -1466,6 +1476,121 @@ mod tests { _ => panic!("expected MOVED"), } } + + #[test] + fn classify_backend_error_detects_categories() { + assert_eq!( + classify_backend_error(&anyhow!("timed out waiting")), + "timeout" + ); + assert_eq!( + classify_backend_error(&anyhow!("closed connection")), + "closed" + ); + assert_eq!( + classify_backend_error(&anyhow!("unexpected heartbeat 
reply")), + "protocol" + ); + assert_eq!(classify_backend_error(&anyhow!("other")), "execute"); + } + + #[test] + fn subscription_slot_tracks_hash_slot() { + let command = RedisCommand::new(vec![ + Bytes::from_static(b"SUBSCRIBE"), + Bytes::from_static(b"channel"), + ]) + .unwrap(); + let slot = subscription_slot_for_command(&command, None, None) + .unwrap() + .unwrap(); + let expected = crc16(trim_hash_tag(b"channel", None)) % SLOT_COUNT; + assert_eq!(slot, expected); + + let unsubscribe = RedisCommand::new(vec![Bytes::from_static(b"UNSUBSCRIBE")]).unwrap(); + let current = subscription_slot_for_command(&unsubscribe, None, Some(slot)) + .unwrap() + .unwrap(); + assert_eq!(current, slot); + } + + #[test] + fn subscription_slot_errors_on_conflicting_channels() { + let command = RedisCommand::new(vec![ + Bytes::from_static(b"SUBSCRIBE"), + Bytes::from_static(b"foo"), + Bytes::from_static(b"bar"), + ]) + .unwrap(); + assert!(subscription_slot_for_command(&command, None, None).is_err()); + } + + #[test] + fn expected_ack_count_matches_subscription_kind() { + let subscribe = RedisCommand::new(vec![ + Bytes::from_static(b"SUBSCRIBE"), + Bytes::from_static(b"foo"), + Bytes::from_static(b"bar"), + ]) + .unwrap(); + assert_eq!(expected_ack_count(&subscribe, 0, 0), 2); + + let unsubscribe = RedisCommand::new(vec![Bytes::from_static(b"UNSUBSCRIBE")]).unwrap(); + assert_eq!(expected_ack_count(&unsubscribe, 3, 1), 3); + } + + #[test] + fn apply_subscription_membership_updates_sets() { + let mut channels = HashSet::new(); + let mut patterns = HashSet::new(); + let resp = RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"subscribe")), + RespValue::BulkString(Bytes::from_static(b"chan")), + RespValue::Integer(1), + ]); + assert!(apply_subscription_membership( + &resp, + &mut channels, + &mut patterns + )); + assert!(channels.contains(&Bytes::from_static(b"chan"))); + + let resp = RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"unsubscribe")), + RespValue::BulkString(Bytes::from_static(b"chan")), + RespValue::Integer(0), + ]); + assert!(apply_subscription_membership( + &resp, + &mut channels, + &mut patterns + )); + assert!(channels.is_empty()); + } + + #[test] + fn derive_slot_from_args_returns_consistent_value() { + let args = vec![Bytes::from_static(b"foo"), Bytes::from_static(b"foo")]; + let slot = derive_slot_from_args(&args, None).unwrap().unwrap(); + let expected = crc16(trim_hash_tag(b"foo", None)) % SLOT_COUNT; + assert_eq!(slot, expected); + } + + #[test] + fn derive_slot_from_args_fails_for_mixed_slots() { + let args = vec![Bytes::from_static(b"foo"), Bytes::from_static(b"bar")]; + assert!(derive_slot_from_args(&args, None).is_err()); + } + + #[test] + fn resp_value_to_bytes_handles_bulk_and_null() { + let value = RespValue::BulkString(Bytes::from_static(b"key")); + assert_eq!( + resp_value_to_bytes(&value), + Some(Bytes::from_static(b"key")) + ); + assert!(resp_value_to_bytes(&RespValue::NullBulk).is_none()); + } } #[derive(Debug)] enum Redirect { @@ -1689,14 +1814,7 @@ async fn execute_with_backup( .await?; if let Some(plan) = plan { race_with_backup( - pool, - client_id, - command, - target, - primary_rx, - cluster, - plan, - controller, + pool, client_id, command, target, primary_rx, cluster, plan, controller, ) .await } else { diff --git a/src/config/mod.rs b/src/config/mod.rs index e652fae..0c7f5d3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -783,10 +783,9 @@ impl ConfigManager { ClusterField::ClientCacheEnabled => { let enabled = 
parse_bool_flag(value, "client-cache-enabled")?; if enabled { - entry - .client_cache - .enable() - .with_context(|| format!("cluster {} failed to enable client cache", cluster_name))?; + entry.client_cache.enable().with_context(|| { + format!("cluster {} failed to enable client cache", cluster_name) + })?; } else { entry.client_cache.disable(); } @@ -813,7 +812,9 @@ impl ConfigManager { let parsed = parse_positive_usize(value, "client-cache-max-value-bytes")?; entry.client_cache.set_max_value_bytes(parsed); let mut guard = self.config.write(); - guard.clusters_mut()[entry.index].client_cache.max_value_bytes = parsed; + guard.clusters_mut()[entry.index] + .client_cache + .max_value_bytes = parsed; info!( cluster = cluster_name, value = value, @@ -846,7 +847,9 @@ impl ConfigManager { let parsed = parse_positive_u64(value, "client-cache-drain-interval-ms")?; entry.client_cache.set_drain_interval(parsed); let mut guard = self.config.write(); - guard.clusters_mut()[entry.index].client_cache.drain_interval_ms = parsed; + guard.clusters_mut()[entry.index] + .client_cache + .drain_interval_ms = parsed; info!( cluster = cluster_name, value = value, @@ -1246,3 +1249,195 @@ fn wildcard_match(pattern: &str, target: &str) -> bool { p == pattern.len() } + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::{AuthUserConfig, BackendAuthConfig, FrontendAuthConfig, FrontendAuthTable}; + use crate::protocol::redis::RespVersion; + use std::env; + use std::sync::Mutex; + + static ENV_GUARD: Mutex<()> = Mutex::new(()); + + fn base_cluster() -> ClusterConfig { + ClusterConfig { + name: "alpha".into(), + listen_addr: "127.0.0.1:7000".into(), + hash_tag: None, + thread: Some(2), + cache_type: CacheType::Redis, + read_timeout: Some(1500), + write_timeout: Some(1500), + servers: vec!["127.0.0.1:6379".into()], + fetch_interval: None, + read_from_slave: None, + ping_fail_limit: Some(1), + ping_interval: Some(60), + ping_succ_interval: Some(120), + dial_timeout: None, + listen_proto: None, + node_connections: Some(2), + auth: None, + password: None, + backend_auth: None, + backend_password: None, + slowlog_log_slower_than: default_slowlog_log_slower_than(), + slowlog_max_len: default_slowlog_max_len(), + hotkey_sample_every: default_hotkey_sample_every(), + hotkey_sketch_width: default_hotkey_sketch_width(), + hotkey_sketch_depth: default_hotkey_sketch_depth(), + hotkey_capacity: default_hotkey_capacity(), + hotkey_decay: default_hotkey_decay(), + backend_resp_version: RespVersion::Resp2, + client_cache: ClientCacheConfig::default(), + backup_request: BackupRequestConfig::default(), + } + } + + fn config_with_single_cluster(cluster: ClusterConfig) -> Config { + Config { + clusters: vec![cluster], + } + } + + #[test] + fn cluster_config_validation_succeeds_for_minimal_setup() { + let cfg = base_cluster(); + cfg.ensure_valid().expect("valid cluster"); + } + + #[test] + fn cluster_config_validation_detects_missing_servers() { + let mut cfg = base_cluster(); + cfg.servers.clear(); + assert!(cfg.ensure_valid().is_err()); + } + + #[test] + fn listen_port_parses_hostname_style_addresses() { + let mut cfg = base_cluster(); + cfg.listen_addr = "cache.example.com:8888".into(); + assert_eq!(cfg.listen_port().unwrap(), 8888); + } + + #[test] + fn frontend_auth_users_prefers_explicit_acl_config() { + let mut cfg = base_cluster(); + cfg.auth = Some(FrontendAuthConfig::Detailed(FrontendAuthTable { + password: Some("shared".into()), + users: vec![AuthUserConfig { + username: "extra".into(), + password: "xyz".into(), + }], + 
})); + cfg.password = Some("legacy".into()); + let users = cfg.frontend_auth_users().expect("auth users"); + assert_eq!(users.len(), 2); + assert_eq!(users[0].username, "default"); + assert_eq!(users[0].password, "shared"); + } + + #[test] + fn frontend_auth_users_falls_back_to_password_field() { + let mut cfg = base_cluster(); + cfg.password = Some("legacy".into()); + let users = cfg.frontend_auth_users().expect("auth users"); + assert_eq!(users.len(), 1); + assert_eq!(users[0].username, "default"); + assert_eq!(users[0].password, "legacy"); + } + + #[test] + fn backend_auth_prefers_acl_config() { + let mut cfg = base_cluster(); + cfg.backend_auth = Some(BackendAuthConfig::Credential { + username: "user".into(), + password: "pw".into(), + }); + cfg.backend_password = Some("legacy".into()); + let auth = cfg.backend_auth_config().expect("backend auth"); + match auth { + BackendAuthConfig::Credential { username, password } => { + assert_eq!(username, "user"); + assert_eq!(password, "pw"); + } + BackendAuthConfig::Password(_) => panic!("expected credential variant"), + } + } + + #[test] + fn backend_auth_falls_back_to_password_field() { + let mut cfg = base_cluster(); + cfg.backend_password = Some("legacy".into()); + let auth = cfg.backend_auth_config().expect("backend auth"); + match auth { + BackendAuthConfig::Password(password) => assert_eq!(password, "legacy"), + _ => panic!("expected password variant"), + } + } + + #[test] + fn apply_defaults_uses_env_override_for_threads() { + let _guard = ENV_GUARD.lock().unwrap(); + let previous = env::var(ENV_DEFAULT_THREADS).ok(); + env::set_var(ENV_DEFAULT_THREADS, "11"); + let mut cfg = config_with_single_cluster(ClusterConfig { + thread: None, + ..base_cluster() + }); + cfg.apply_defaults(); + let thread = cfg.clusters()[0].thread.expect("thread assigned"); + if let Some(val) = previous { + env::set_var(ENV_DEFAULT_THREADS, val); + } else { + env::remove_var(ENV_DEFAULT_THREADS); + } + assert_eq!(thread, 11); + } + + #[test] + fn parse_port_handles_ipv6_endpoints() { + let port = parse_port("[2001:db8::1]:7100").expect("ipv6 port"); + assert_eq!(port, 7100); + } + + #[test] + fn parse_key_recognizes_known_fields() { + let (cluster, field) = parse_key("cluster.alpha.hotkey-decay").expect("key parsed"); + assert_eq!(cluster, "alpha"); + assert!(matches!(field, ClusterField::HotkeyDecay)); + } + + #[test] + fn parse_key_rejects_unknown_fields() { + assert!(parse_key("cluster.alpha.unknown").is_err()); + } + + #[test] + fn parse_timeout_value_handles_default_marker() { + assert_eq!(parse_timeout_value("default").unwrap(), None); + assert_eq!(parse_timeout_value("250").unwrap(), Some(250)); + assert!(parse_timeout_value("oops").is_err()); + } + + #[test] + fn parse_bool_flag_understands_common_aliases() { + assert!(parse_bool_flag("YES", "flag").unwrap()); + assert!(!parse_bool_flag("0", "flag").unwrap()); + assert!(parse_bool_flag("maybe", "flag").is_err()); + } + + #[test] + fn wildcard_match_supports_basic_patterns() { + assert!(wildcard_match("cluster.*", "cluster.alpha")); + assert!(wildcard_match("cache-?", "cache-a")); + assert!(!wildcard_match("cache-?", "cache-long")); + } + + #[test] + fn option_atomic_roundtrip() { + assert_eq!(atomic_to_option(option_to_atomic(Some(42))), Some(42)); + assert_eq!(atomic_to_option(option_to_atomic(None)), None); + } +} diff --git a/src/info.rs b/src/info.rs index 1bcf264..cedbce6 100644 --- a/src/info.rs +++ b/src/info.rs @@ -232,3 +232,41 @@ fn format_bytes(bytes: u64) -> String { 
format!("{value:.2}{unit}") } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn proxy_mode_labels_are_stable() { + assert_eq!(ProxyMode::Standalone.as_str(), "standalone"); + assert_eq!(ProxyMode::Cluster.as_str(), "cluster"); + } + + #[test] + fn format_bytes_handles_units() { + assert_eq!(format_bytes(0), "0B"); + assert_eq!(format_bytes(512), "512B"); + assert_eq!(format_bytes(1024), "1.00KB"); + } + + #[test] + fn render_info_includes_expected_sections() { + register_info_metrics(); + let context = InfoContext { + cluster: "info-test", + mode: ProxyMode::Standalone, + listen_port: 7000, + backend_nodes: 2, + }; + let payload = render_info(context, Some("server")); + let text = String::from_utf8(payload.to_vec()).unwrap(); + assert!(text.contains("# Server")); + assert!(text.contains("cluster_name:info-test")); + } + + fn register_info_metrics() { + metrics::front_conn_open("info-test"); + metrics::front_command("info-test", "read", true); + } +} diff --git a/src/lib.rs b/src/lib.rs index 68bfde6..38b02d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,8 +16,8 @@ use tracing::{info, warn}; use tracing_subscriber::{fmt, EnvFilter}; pub mod auth; -pub mod cache; pub mod backend; +pub mod cache; pub mod cluster; pub mod config; pub mod hotkey; diff --git a/src/meta.rs b/src/meta.rs index cabf71f..ecd259d 100644 --- a/src/meta.rs +++ b/src/meta.rs @@ -88,3 +88,110 @@ fn determine_ip(override_ip: Option<&str>) -> Result { // Fallback to loopback. Ok("127.0.0.1".to_string()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{BackupRequestConfig, CacheType, ClientCacheConfig, ClusterConfig}; + use crate::protocol::redis::RespVersion; + use std::sync::Mutex; + + static ENV_GUARD: Mutex<()> = Mutex::new(()); + + fn sample_cluster() -> ClusterConfig { + ClusterConfig { + name: "test".into(), + listen_addr: "127.0.0.1:7000".into(), + hash_tag: None, + thread: Some(2), + cache_type: CacheType::Redis, + read_timeout: Some(1000), + write_timeout: Some(1000), + servers: vec!["127.0.0.1:6379".into()], + fetch_interval: None, + read_from_slave: None, + ping_fail_limit: None, + ping_interval: None, + ping_succ_interval: None, + dial_timeout: None, + listen_proto: None, + node_connections: None, + auth: None, + password: None, + backend_auth: None, + backend_password: None, + slowlog_log_slower_than: 10_000, + slowlog_max_len: 128, + hotkey_sample_every: 200, + hotkey_sketch_width: 256, + hotkey_sketch_depth: 4, + hotkey_capacity: 5_000, + hotkey_decay: 0.5, + backend_resp_version: RespVersion::Resp2, + client_cache: ClientCacheConfig::default(), + backup_request: BackupRequestConfig::default(), + } + } + + #[tokio::test(flavor = "current_thread")] + async fn scope_with_meta_makes_context_available() { + let meta = Meta { + cluster: "cluster-a".into(), + ip: "10.1.1.5".into(), + port: 7111, + }; + + let value = scope_with_meta(meta.clone(), async { + assert_eq!(current_cluster().as_deref(), Some("cluster-a")); + assert_eq!(current_ip().as_deref(), Some("10.1.1.5")); + assert_eq!(current_port(), Some(7111)); + current_meta().unwrap().cluster().to_string() + }) + .await; + + assert_eq!(value, "cluster-a".to_string()); + assert!(current_meta().is_none()); + } + + #[test] + fn derive_meta_prefers_override_ip() { + let cluster = sample_cluster(); + let meta = derive_meta(&cluster, Some("192.168.1.10")).expect("derive meta"); + assert_eq!(meta.ip(), "192.168.1.10"); + assert_eq!(meta.port(), 7000); + assert_eq!(meta.cluster(), "test"); + } + + #[test] + fn 
determine_ip_uses_valid_host_env() { + let _guard = ENV_GUARD.lock().unwrap(); + let prev = env::var("HOST").ok(); + env::set_var("HOST", "10.0.0.9"); + let ip = determine_ip(None).expect("determine ip"); + if let Some(value) = prev { + env::set_var("HOST", value); + } else { + env::remove_var("HOST"); + } + assert_eq!(ip, "10.0.0.9"); + } + + #[test] + fn determine_ip_falls_back_when_env_invalid() { + let _guard = ENV_GUARD.lock().unwrap(); + let prev = env::var("HOST").ok(); + env::set_var("HOST", "not-an-ip"); + let ip = determine_ip(None).expect("determine ip"); + if let Some(value) = prev { + env::set_var("HOST", value); + } else { + env::remove_var("HOST"); + } + assert_eq!(ip, "127.0.0.1"); + } + + #[test] + fn current_meta_returns_none_outside_scope() { + assert!(current_meta().is_none()); + } +} diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index aeff374..52e615a 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -388,9 +388,7 @@ pub fn client_cache_lookup(cluster: &str, kind: &str, hit: bool) { /// Record a client cache store/update event. pub fn client_cache_store(cluster: &str, kind: &str) { - CLIENT_CACHE_STORE - .with_label_values(&[cluster, kind]) - .inc(); + CLIENT_CACHE_STORE.with_label_values(&[cluster, kind]).inc(); } /// Record the number of keys invalidated from the client cache. @@ -574,3 +572,41 @@ async fn system_monitor_loop(interval: Duration) -> Result<()> { } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn front_connections_counters_reflect_updates() { + let cluster = "metrics-test"; + front_conn_open(cluster); + front_conn_open(cluster); + front_conn_close(cluster); + assert_eq!(front_connections_current(cluster), 1); + assert!(front_connections_total(cluster) >= 2); + } + + #[test] + fn front_command_stats_aggregates_totals() { + let cluster = "metrics-stats"; + front_command(cluster, "read", true); + front_command(cluster, "read", false); + front_command(cluster, "invalid", false); + let stats = front_command_stats(cluster); + assert_eq!(stats.read_ok, 1); + assert_eq!(stats.read_fail, 1); + assert_eq!(stats.invalid_fail, 1); + assert_eq!(stats.total(), 3); + } + + #[test] + fn global_usage_accessors_reflect_gauges() { + MEMORY_USAGE.set(1024.0); + CPU_USAGE.set(12.5); + GLOBAL_ERROR.inc_by(5); + assert_eq!(memory_usage_bytes(), 1024 * 1024); + assert_eq!(cpu_usage_percent(), 12.5); + assert!(global_error_count() >= 5); + } +} diff --git a/src/metrics/tracker.rs b/src/metrics/tracker.rs index d4e4616..9de354d 100644 --- a/src/metrics/tracker.rs +++ b/src/metrics/tracker.rs @@ -24,3 +24,23 @@ impl Drop for Tracker { self.histogram.observe(micros); } } + +#[cfg(test)] +mod tests { + use super::*; + use prometheus::{Histogram, HistogramOpts}; + use std::thread; + use std::time::Duration; + + #[test] + fn tracker_records_elapsed_time_on_drop() { + let histogram = Histogram::with_opts(HistogramOpts::new("tracker_test", "test")).unwrap(); + let before = histogram.get_sample_count(); + { + let _tracker = Tracker::new(histogram.clone()); + thread::sleep(Duration::from_millis(1)); + } + let after = histogram.get_sample_count(); + assert!(after > before); + } +} diff --git a/src/slowlog.rs b/src/slowlog.rs index beb3132..0ee12d1 100644 --- a/src/slowlog.rs +++ b/src/slowlog.rs @@ -199,6 +199,63 @@ fn slowlog_value_error() -> RespValue { )) } +#[cfg(test)] +mod tests { + use super::*; + + fn slowlog_with_entries() -> Slowlog { + let log = Slowlog::new(0, 4); + let cmd = RedisCommand::new(vec![Bytes::from_static(b"PING")]).unwrap(); 
+ log.maybe_record(&cmd, Duration::from_micros(10)); + log + } + + #[test] + fn handle_get_returns_entries() { + let log = slowlog_with_entries(); + let result = handle_command( + &log, + &[Bytes::from_static(b"slowlog"), Bytes::from_static(b"get")], + ); + match result { + RespValue::Array(values) => assert!(!values.is_empty()), + other => panic!("unexpected response: {:?}", other), + } + } + + #[test] + fn handle_len_reports_length() { + let log = slowlog_with_entries(); + let result = handle_command( + &log, + &[Bytes::from_static(b"slowlog"), Bytes::from_static(b"len")], + ); + assert!(matches!(result, RespValue::Integer(value) if value >= 1)); + } + + #[test] + fn handle_reset_clears_entries() { + let log = slowlog_with_entries(); + let _ = handle_command( + &log, + &[Bytes::from_static(b"slowlog"), Bytes::from_static(b"reset")], + ); + let result = handle_command( + &log, + &[Bytes::from_static(b"slowlog"), Bytes::from_static(b"len")], + ); + assert_eq!(result, RespValue::Integer(0)); + } + + #[test] + fn parse_non_negative_validates_input() { + let value = Bytes::from_static(b"5"); + assert_eq!(parse_non_negative(&value).unwrap(), 5); + let invalid = Bytes::from_static(b"-1"); + assert!(parse_non_negative(&invalid).is_err()); + } +} + struct RingBuffer { buf: Vec, capacity: usize, diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index cb8f946..35015dc 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -18,9 +18,9 @@ use tokio_util::codec::{Framed, FramedParts}; use tracing::{debug, info, warn}; use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator}; -use crate::cache::{tracker::CacheTrackerSet, ClientCache}; use crate::backend::client::{ClientId, FrontConnectionGuard}; use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; +use crate::cache::{tracker::CacheTrackerSet, ClientCache}; use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager}; use crate::hotkey::Hotkey; use crate::info::{InfoContext, ProxyMode}; @@ -693,6 +693,89 @@ fn subscription_action_kind(kind: &[u8]) -> Option { } } +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + fn node_entry(address: &str, display: &str, weight: usize) -> NodeEntry { + NodeEntry { + backend: BackendNode::new(address.to_string()), + display: Arc::::from(display.to_string()), + weight, + } + } + + #[test] + fn parse_address_weight_extracts_suffix() { + let (addr, weight) = parse_address_weight("127.0.0.1:6379:3").unwrap(); + assert_eq!(addr, "127.0.0.1:6379"); + assert_eq!(weight, 3); + + let (addr, weight) = parse_address_weight("10.0.0.1:6380").unwrap(); + assert_eq!(addr, "10.0.0.1:6380"); + assert_eq!(weight, 1); + } + + #[test] + fn parse_servers_respects_alias_and_weight() { + let input = vec!["127.0.0.1:6379:2 main".to_string()]; + let entries = parse_servers(&input).expect("servers"); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].backend.as_str(), "127.0.0.1:6379"); + assert_eq!(entries[0].weight, 2); + assert_eq!(&*entries[0].display, "main"); + } + + #[test] + fn build_ring_reflects_weight() { + let entries = vec![node_entry("127.0.0.1:6379", "node", 2)]; + let ring = build_ring(&entries); + assert_eq!(ring.len(), 2 * VIRTUAL_NODE_FACTOR); + assert!(ring.windows(2).all(|w| w[0].0 <= w[1].0)); + } + + #[test] + fn subscription_action_kind_maps_known_values() { + assert_eq!( + subscription_action_kind(b"subscribe"), + Some(SubscriptionKind::Channel) + ); + assert_eq!( + subscription_action_kind(b"psubscribe"), + 
Some(SubscriptionKind::Pattern)
+        );
+        assert!(subscription_action_kind(b"unknown").is_none());
+    }
+
+    #[test]
+    fn subscription_count_parses_arrays() {
+        let resp = RespValue::Array(vec![
+            RespValue::BulkString(Bytes::from_static(b"psubscribe")),
+            RespValue::BulkString(Bytes::from_static(b"chan")),
+            RespValue::Integer(3),
+        ]);
+        assert_eq!(
+            subscription_count(&resp),
+            Some((SubscriptionKind::Pattern, 3))
+        );
+    }
+
+    #[test]
+    fn subscription_count_rejects_invalid_payload() {
+        let resp = RespValue::Array(vec![RespValue::Integer(1)]);
+        assert!(subscription_count(&resp).is_none());
+    }
+
+    #[test]
+    fn hash_key_differs_for_distinct_inputs() {
+        let a = hash_key(b"alpha");
+        let b = hash_key(b"beta");
+        assert_ne!(a, b);
+        assert_eq!(hash_key(b"alpha"), a);
+    }
+}
+
 #[derive(Clone)]
 struct RedisConnector {
     runtime: Arc<ClusterRuntime>,

From b9fb99326017c2e8bd7a5fc6939ad9272e2b2652 Mon Sep 17 00:00:00 2001
From: wayslog
Date: Thu, 20 Nov 2025 19:11:47 +0800
Subject: [PATCH 2/4] chore: add connector mock

---
 src/backend/executor.rs |  72 ++++++++++
 src/backend/mod.rs      |   1 +
 src/standalone/mod.rs   | 237 ++++++++++++++++++++++++++++++++++------
 3 files changed, 279 insertions(+), 31 deletions(-)
 create mode 100644 src/backend/executor.rs

diff --git a/src/backend/executor.rs b/src/backend/executor.rs
new file mode 100644
index 0000000..322e984
--- /dev/null
+++ b/src/backend/executor.rs
@@ -0,0 +1,72 @@
+use std::sync::Arc;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+
+use crate::backend::client::ClientId;
+use crate::backend::pool::{BackendNode, ConnectionPool};
+
+use super::pool::BackendRequest;
+
+/// Abstraction over backend request execution so it can be mocked in tests.
+#[async_trait]
+pub trait BackendExecutor<T>: Send + Sync
+where
+    T: BackendRequest,
+{
+    /// Dispatch a non-blocking request for the given client.
+    async fn dispatch(
+        &self,
+        node: BackendNode,
+        client_id: ClientId,
+        request: T,
+    ) -> Result<T::Response>;
+
+    /// Dispatch a request that requires an exclusive backend connection.
+    async fn dispatch_blocking(&self, node: BackendNode, request: T) -> Result<T::Response>;
+}
+
+/// Default executor that proxies calls through the actual connection pool.
+pub struct PoolBackendExecutor<T: BackendRequest> {
+    pool: Arc<ConnectionPool<T>>,
+}
+
+impl<T: BackendRequest> PoolBackendExecutor<T> {
+    pub fn new(pool: Arc<ConnectionPool<T>>) -> Self {
+        Self { pool }
+    }
+
+    pub fn pool(&self) -> &Arc<ConnectionPool<T>> {
+        &self.pool
+    }
+}
+
+#[async_trait]
+impl<T> BackendExecutor<T> for PoolBackendExecutor<T>
+where
+    T: BackendRequest,
+{
+    async fn dispatch(
+        &self,
+        node: BackendNode,
+        client_id: ClientId,
+        request: T,
+    ) -> Result<T::Response> {
+        let response_rx = self.pool.dispatch(node, client_id, request).await?;
+        match response_rx.await {
+            Ok(result) => result,
+            Err(_) => Err(anyhow!("backend session closed unexpectedly")),
+        }
+    }
+
+    async fn dispatch_blocking(&self, node: BackendNode, request: T) -> Result<T::Response> {
+        let mut exclusive = self.pool.acquire_exclusive(&node);
+        let response_rx = exclusive.send(request).await?;
+        let outcome = response_rx.await;
+        drop(exclusive);
+        match outcome {
+            Ok(result) => result,
+            Err(_) => Err(anyhow!("backend session closed unexpectedly")),
+        }
+    }
+}
diff --git a/src/backend/mod.rs b/src/backend/mod.rs
index cfb96b0..49d1c43 100644
--- a/src/backend/mod.rs
+++ b/src/backend/mod.rs
@@ -1,2 +1,3 @@
 pub mod client;
+pub mod executor;
 pub mod pool;
diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs
index 35015dc..d3cc5ee 100644
--- a/src/standalone/mod.rs
+++ b/src/standalone/mod.rs
@@ -19,6 +19,7 @@ use tracing::{debug, info, warn};
 
 use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator};
 use crate::backend::client::{ClientId, FrontConnectionGuard};
+use crate::backend::executor::{BackendExecutor, PoolBackendExecutor};
 use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand};
 use crate::cache::{tracker::CacheTrackerSet, ClientCache};
 use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager};
@@ -49,7 +50,7 @@ pub struct StandaloneProxy {
     ring: Vec<(u64, BackendNode)>,
     auth: Option<Arc<FrontendAuthenticator>>,
     backend_auth: Option<BackendAuth>,
-    pool: Arc<ConnectionPool<RedisCommand>>,
+    backend: Arc<dyn BackendExecutor<RedisCommand>>,
     runtime: Arc<ClusterRuntime>,
     config_manager: Arc<ConfigManager>,
     slowlog: Arc<Slowlog>,
@@ -65,6 +66,15 @@ impl StandaloneProxy {
         config: &ClusterConfig,
         runtime: Arc<ClusterRuntime>,
         config_manager: Arc<ConfigManager>,
+    ) -> Result<Self> {
+        Self::new_with_backend(config, runtime, config_manager, None)
+    }
+
+    fn new_with_backend(
+        config: &ClusterConfig,
+        runtime: Arc<ClusterRuntime>,
+        config_manager: Arc<ConfigManager>,
+        backend_override: Option<Arc<dyn BackendExecutor<RedisCommand>>>,
     ) -> Result<Self> {
         let cluster: Arc<str> = config.name.clone().into();
         let hash_tag = config.hash_tag.as_ref().map(|tag| tag.as_bytes().to_vec());
@@ -78,18 +88,25 @@ impl StandaloneProxy {
         let ring = build_ring(&nodes);
 
         let backend_auth = config.backend_auth_config().map(BackendAuth::from);
-        let connector = Arc::new(RedisConnector::new(
-            runtime.clone(),
-            DEFAULT_TIMEOUT_MS,
-            backend_auth.clone(),
-            config.backend_resp_version,
-        ));
+
+        let backend: Arc<dyn BackendExecutor<RedisCommand>> =
+            if let Some(backend) = backend_override {
+                backend
+            } else {
+                let connector = Arc::new(RedisConnector::new(
+                    runtime.clone(),
+                    DEFAULT_TIMEOUT_MS,
+                    backend_auth.clone(),
+                    config.backend_resp_version,
+                ));
+                let pool = Arc::new(ConnectionPool::new(cluster.clone(), connector));
+                Arc::new(PoolBackendExecutor::new(pool))
+            };
         let auth = config
             .frontend_auth_users()
             .map(FrontendAuthenticator::from_users)
             .transpose()?
             .map(Arc::new);
-        let pool = Arc::new(ConnectionPool::new(cluster.clone(), connector));
 
         let backend_nodes = nodes.len();
         let listen_port = config.listen_port()?;
@@ -121,7 +138,7 @@
             ring,
             auth,
             backend_auth,
-            pool,
+            backend,
             runtime,
             config_manager,
             slowlog,
@@ -182,22 +199,11 @@ impl StandaloneProxy {
         match command.as_blocking() {
             BlockingKind::Queue { ..
} | BlockingKind::Stream { .. } => { let node = self.select_node(client_id, &command)?; - let mut exclusive = self.pool.acquire_exclusive(&node); - let response_rx = exclusive.send(command).await?; - let outcome = response_rx.await; - drop(exclusive); - match outcome { - Ok(result) => result, - Err(_) => Err(anyhow!("backend session closed unexpectedly")), - } + self.backend.dispatch_blocking(node, command).await } BlockingKind::None => { let node = self.select_node(client_id, &command)?; - let response_rx = self.pool.dispatch(node, client_id, command).await?; - match response_rx.await { - Ok(result) => result, - Err(_) => Err(anyhow!("backend session closed unexpectedly")), - } + self.backend.dispatch(node, client_id, command).await } } } @@ -211,17 +217,14 @@ impl StandaloneProxy { FuturesOrdered::new(); for sub in multi.subcommands.into_iter() { let node = self.select_node(client_id, &sub.command)?; - let pool = self.pool.clone(); + let backend = self.backend.clone(); let SubCommand { positions, command } = sub; tasks.push_back(Box::pin(async move { - let response_rx = pool.dispatch(node, client_id, command).await?; - match response_rx.await { - Ok(result) => Ok(SubResponse { - positions, - response: result?, - }), - Err(_) => Err(anyhow!("backend session closed unexpectedly")), - } + let response = backend.dispatch(node, client_id, command).await?; + Ok(SubResponse { + positions, + response, + }) })); } @@ -697,6 +700,16 @@ fn subscription_action_kind(kind: &[u8]) -> Option { mod tests { use super::*; use bytes::Bytes; + use std::collections::VecDeque; + use std::path::PathBuf; + use std::sync::Arc; + use std::sync::Mutex; + + use crate::backend::client::ClientId; + use crate::backend::executor::BackendExecutor; + use crate::config::{Config, ConfigManager}; + use anyhow::anyhow; + use async_trait::async_trait; fn node_entry(address: &str, display: &str, weight: usize) -> NodeEntry { NodeEntry { @@ -774,6 +787,168 @@ mod tests { assert_ne!(a, b); assert_eq!(hash_key(b"alpha"), a); } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_sends_requests_to_backend_executor() { + let backend = Arc::new(MockBackend::default()); + backend.push_shared(Ok(RespValue::SimpleString(Bytes::from_static(b"PONG")))); + let proxy = build_proxy_with_backend(backend.clone()); + let response = proxy + .dispatch(ClientId::new(), redis_cmd(&["PING"])) + .await + .expect("response"); + assert_eq!( + response, + RespValue::SimpleString(Bytes::from_static(b"PONG")) + ); + assert_eq!(backend.calls(), vec![MockCall::Shared("PING".into())]); + } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_blocking_commands_use_exclusive_path() { + let backend = Arc::new(MockBackend::default()); + backend.push_blocking(Ok(RespValue::Array(vec![]))); + let proxy = build_proxy_with_backend(backend.clone()); + let response = proxy + .dispatch(ClientId::new(), redis_cmd(&["BLPOP", "queue", "0"])) + .await + .expect("response"); + assert_eq!(response, RespValue::Array(vec![])); + assert_eq!(backend.calls(), vec![MockCall::Blocking("BLPOP".into())]); + } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_multi_combines_subresponses() { + let backend = Arc::new(MockBackend::default()); + backend.push_shared(Ok(RespValue::BulkString(Bytes::from_static(b"foo")))); + backend.push_shared(Ok(RespValue::BulkString(Bytes::from_static(b"bar")))); + let proxy = build_proxy_with_backend(backend.clone()); + let response = proxy + .dispatch(ClientId::new(), redis_cmd(&["MGET", "k1", "k2"])) + .await + 
.expect("response"); + assert_eq!( + response, + RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"foo")), + RespValue::BulkString(Bytes::from_static(b"bar")), + ]) + ); + assert_eq!(backend.calls().len(), 2); + } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_propagates_backend_errors() { + let backend = Arc::new(MockBackend::default()); + backend.push_shared(Err(anyhow!("backend boom"))); + let proxy = build_proxy_with_backend(backend.clone()); + let error = proxy + .dispatch(ClientId::new(), redis_cmd(&["GET", "missing"])) + .await + .expect_err("error"); + assert!(error.to_string().contains("backend boom")); + assert_eq!(backend.calls(), vec![MockCall::Shared("GET".into())]); + } + + fn redis_cmd(parts: &[&str]) -> RedisCommand { + let bytes = parts + .iter() + .map(|p| Bytes::copy_from_slice(p.as_bytes())) + .collect(); + RedisCommand::new(bytes).expect("command") + } + + fn build_proxy_with_backend(backend: Arc) -> StandaloneProxy { + let raw = r#" + [[clusters]] + name = "standalone-mock" + listen_addr = "127.0.0.1:6400" + cache_type = "redis" + thread = 1 + servers = ["127.0.0.1:6379", "127.0.0.1:6380"] + client_cache = { enabled = false } + "#; + let config: Config = toml::from_str(raw).expect("config"); + config.ensure_valid().expect("valid config"); + let cluster = config.clusters()[0].clone(); + let manager = Arc::new(ConfigManager::new( + PathBuf::from("standalone-mock.toml"), + &config, + )); + let runtime = manager.runtime_for(&cluster.name).expect("cluster runtime"); + let backend_trait: Arc> = backend; + StandaloneProxy::new_with_backend(&cluster, runtime, manager, Some(backend_trait)) + .expect("proxy") + } + + type BackendResult = anyhow::Result; + + #[derive(Default)] + struct MockBackend { + shared: Mutex>, + blocking: Mutex>, + calls: Mutex>, + } + + impl MockBackend { + fn push_shared(&self, value: BackendResult) { + self.shared.lock().unwrap().push_back(value); + } + + fn push_blocking(&self, value: BackendResult) { + self.blocking.lock().unwrap().push_back(value); + } + + fn calls(&self) -> Vec { + self.calls.lock().unwrap().clone() + } + + fn next(queue: &Mutex>, kind: &str) -> BackendResult { + queue + .lock() + .unwrap() + .pop_front() + .unwrap_or_else(|| panic!("missing {} response", kind)) + } + } + + #[derive(Debug, Clone, PartialEq, Eq)] + enum MockCall { + Shared(String), + Blocking(String), + } + + fn command_label(command: &RedisCommand) -> String { + String::from_utf8_lossy(command.command_name()).to_string() + } + + #[async_trait] + impl BackendExecutor for MockBackend { + async fn dispatch( + &self, + _node: BackendNode, + _client_id: ClientId, + request: RedisCommand, + ) -> anyhow::Result { + self.calls + .lock() + .unwrap() + .push(MockCall::Shared(command_label(&request))); + Self::next(&self.shared, "shared") + } + + async fn dispatch_blocking( + &self, + _node: BackendNode, + request: RedisCommand, + ) -> anyhow::Result { + self.calls + .lock() + .unwrap() + .push(MockCall::Blocking(command_label(&request))); + Self::next(&self.blocking, "blocking") + } + } } #[derive(Clone)] From bf2ffa5fd7a7d83af19001d49487050fa9fb0e63 Mon Sep 17 00:00:00 2001 From: wayslog Date: Fri, 21 Nov 2025 11:03:42 +0800 Subject: [PATCH 3/4] feat: add end to end test --- src/cluster/mod.rs | 93 ++++++ src/config/mod.rs | 5 + src/metrics/mod.rs | 106 +++++++ src/protocol/redis/command.rs | 103 ++++++- src/standalone/mod.rs | 23 +- tests/end_to_end.rs | 513 ++++++++++++++++++++++++++++++++++ 6 files changed, 835 insertions(+), 8 
deletions(-) create mode 100644 tests/end_to_end.rs diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 8b01d4c..1a467f9 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -1461,8 +1461,12 @@ async fn fetch_from_seed(seed: &str, connector: Arc) -> Result #[cfg(test)] mod tests { use super::*; + use crate::config::BackupRequestConfig; use crate::utils::{crc16, trim_hash_tag}; use std::collections::HashSet; + use std::sync::Arc; + use std::time::Duration; + use tokio::sync::watch; #[test] fn parse_moved_redirect() { @@ -1477,6 +1481,16 @@ mod tests { } } + #[test] + fn parse_redirect_handles_ask() { + let value = RespValue::Error(Bytes::from_static(b"ASK 3999 10.0.0.1:6381")); + let redirect = parse_redirect(value).unwrap().unwrap(); + match redirect { + Redirect::Ask { address } => assert_eq!(address, "10.0.0.1:6381"), + other => panic!("unexpected redirect: {:?}", other), + } + } + #[test] fn classify_backend_error_detects_categories() { assert_eq!( @@ -1591,6 +1605,85 @@ mod tests { ); assert!(resp_value_to_bytes(&RespValue::NullBulk).is_none()); } + + #[test] + fn select_node_for_slot_prefers_replica_when_allowed() { + let (slots, _rx) = watch::channel(sample_slot_map()); + let replica = select_node_for_slot(&slots, true, 1).expect("replica"); + assert!(replica.as_str().ends_with(":7001")); + let master = select_node_for_slot(&slots, false, 1).expect("master"); + assert!(master.as_str().ends_with(":7000")); + } + + #[test] + fn select_node_for_slot_errors_when_slot_missing() { + let (slots, _rx) = watch::channel(SlotMap::new()); + assert!(select_node_for_slot(&slots, false, 42).is_err()); + assert!(replica_node_for_slot(&slots, 42).is_none()); + } + + #[test] + fn backup_controller_computes_delays_and_averages() { + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: true, + trigger_slow_ms: Some(50), + multiplier: 0.5, + })); + let controller = BackupRequestController::new(runtime.clone()); + let master = BackendNode::new("127.0.0.1:7000".to_string()); + let replica = BackendNode::new("127.0.0.1:7001".to_string()); + let plan = controller + .plan(&master, Some(replica.clone())) + .expect("plan available"); + assert_eq!(plan.replica.as_str(), replica.as_str()); + assert_eq!(plan.delay, Duration::from_millis(50)); + + controller.record_primary(&master, Duration::from_millis(4)); + let avg = controller.average_for(&master).expect("average recorded"); + assert!(avg > 0.0); + + runtime.set_threshold_ms(None); + runtime.set_multiplier(2.0); + let delay = controller.delay_for(&master).expect("delay"); + assert!(delay >= Duration::from_micros(1)); + } + + #[test] + fn subscription_count_parses_string_counts() { + let resp = RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"psubscribe")), + RespValue::BulkString(Bytes::from_static(b"pattern")), + RespValue::BulkString(Bytes::from_static(b"5")), + ]); + let (kind, count) = subscription_count(&resp).expect("count"); + assert_eq!(kind, SubscriptionKind::Pattern); + assert_eq!(count, 5); + + let invalid = RespValue::Array(vec![RespValue::BulkString(Bytes::from_static(b"foo"))]); + assert!(subscription_count(&invalid).is_none()); + } + + #[test] + fn subscription_action_kind_rejects_unknown() { + assert!(subscription_action_kind(b"foobar").is_none()); + } + + fn sample_slot_map() -> SlotMap { + SlotMap::from_slots_response(RespValue::Array(vec![RespValue::Array(vec![ + RespValue::Integer(0), + RespValue::Integer(10), + endpoint("127.0.0.1", 7000), + endpoint("127.0.0.1", 7001), 
+ ])])) + .expect("slot map") + } + + fn endpoint(host: &str, port: i64) -> RespValue { + RespValue::Array(vec![ + RespValue::BulkString(Bytes::copy_from_slice(host.as_bytes())), + RespValue::Integer(port), + ]) + } } #[derive(Debug)] enum Redirect { diff --git a/src/config/mod.rs b/src/config/mod.rs index 0c7f5d3..1178b48 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -466,6 +466,11 @@ impl BackupRequestRuntime { } } + #[cfg(test)] + pub(crate) fn new_for_test(config: BackupRequestConfig) -> Self { + Self::new(&config) + } + pub fn enabled(&self) -> bool { self.enabled.load(Ordering::Relaxed) } diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 52e615a..c01f944 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -576,6 +576,7 @@ async fn system_monitor_loop(interval: Duration) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use std::time::Duration; #[test] fn front_connections_counters_reflect_updates() { @@ -609,4 +610,109 @@ mod tests { assert_eq!(cpu_usage_percent(), 12.5); assert!(global_error_count() >= 5); } + + #[test] + fn backend_metrics_track_error_and_health() { + let cluster = "metrics-backend"; + let backend = "127.0.0.1:9000"; + backend_error(cluster, backend, "timeout"); + backend_probe_result(cluster, backend, "ping", false); + backend_heartbeat(cluster, backend, true); + assert!( + BACKEND_ERRORS + .with_label_values(&[cluster, backend, "timeout"]) + .get() + >= 1 + ); + assert_eq!( + BACKEND_HEALTH + .with_label_values(&[cluster, backend]) + .get(), + 1.0 + ); + assert!( + BACKEND_PROBES + .with_label_values(&[cluster, backend, "ping", "fail"]) + .get() + >= 1 + ); + } + + #[test] + fn backend_probe_duration_accumulates_samples() { + let cluster = "metrics-probe"; + let backend = "127.0.0.1:9001"; + backend_probe_duration( + cluster, + backend, + "latency", + Duration::from_micros(1500), + ); + let histogram = BACKEND_PROBE_DURATION.with_label_values(&[cluster, backend, "latency"]); + assert!(histogram.get_sample_count() >= 1); + assert!(histogram.get_sample_sum() >= 1_500.0); + } + + #[test] + fn client_cache_metrics_capture_states() { + let cluster = "metrics-cache"; + client_cache_lookup(cluster, "get", true); + client_cache_lookup(cluster, "get", false); + client_cache_store(cluster, "set"); + client_cache_invalidate(cluster, 2); + client_cache_state(cluster, "enabled"); + assert!( + CLIENT_CACHE_LOOKUP + .with_label_values(&[cluster, "get", "hit"]) + .get() + >= 1 + ); + assert!( + CLIENT_CACHE_LOOKUP + .with_label_values(&[cluster, "get", "miss"]) + .get() + >= 1 + ); + assert!( + CLIENT_CACHE_STORE + .with_label_values(&[cluster, "set"]) + .get() + >= 1 + ); + assert!( + CLIENT_CACHE_INVALIDATE + .with_label_values(&[cluster]) + .get() + >= 2 + ); + assert!( + CLIENT_CACHE_STATE + .with_label_values(&[cluster, "enabled"]) + .get() + >= 1 + ); + } + + #[test] + fn register_version_and_backup_events_increment_metrics() { + register_version("9.9.9"); + backend_request_result("metrics-req", "backend-a", "ok"); + backup_event("metrics-req", "planned"); + assert_eq!( + VERSION_GAUGE.with_label_values(&["9.9.9"]).get(), + 1.0 + ); + assert!( + BACKEND_REQUEST_TOTAL + .with_label_values(&["metrics-req", "backend-a", "ok"]) + .get() + >= 1 + ); + assert!( + BACKUP_REQUEST_EVENTS + .with_label_values(&["metrics-req", "planned"]) + .get() + >= 1 + ); + } } diff --git a/src/protocol/redis/command.rs b/src/protocol/redis/command.rs index 1e4f02f..983f77b 100644 --- a/src/protocol/redis/command.rs +++ b/src/protocol/redis/command.rs @@ -230,7 
+230,7 @@ impl BackendRequest for RedisCommand { #[cfg(test)] mod tests { use super::*; - use bytes::Bytes; + use bytes::{Bytes, BytesMut}; fn cmd(parts: &[&[u8]]) -> RedisCommand { RedisCommand::new(parts.iter().map(|p| Bytes::copy_from_slice(p)).collect()).unwrap() @@ -259,6 +259,107 @@ mod tests { let other_command = cmd(&[b"PING"]); assert_eq!(other_command.resp_version_request(), None); } + + #[test] + fn command_kind_classifies_known_variants() { + assert_eq!(command_kind(b"get"), CommandKind::Read); + assert_eq!(command_kind(b"SET"), CommandKind::Write); + assert_eq!(command_kind(b"custom"), CommandKind::Other); + } + + #[test] + fn timeout_and_block_parsers_handle_variants() { + let timeout_bytes = Bytes::from_static(b"1.5"); + let invalid = Bytes::from_static(b"foo"); + assert_eq!(parse_timeout(Some(&timeout_bytes)), Some(1.5)); + assert_eq!(parse_timeout(Some(&invalid)), None); + + let args = vec![ + Bytes::from_static(b"XREAD"), + BytesMut::from("block").freeze(), + Bytes::from_static(b"2"), + ]; + assert_eq!(has_block_option(&args), Some(2.0)); + } + + #[test] + fn hash_slot_respects_hash_tags() { + let without_tag = hash_slot_for_key(b"plain", None); + let with_tag = hash_slot_for_key(b"key{shared}", Some(b"{}")); + assert_eq!(with_tag, hash_slot_for_key(b"shared", None)); + assert_ne!(without_tag, with_tag); + } + + #[test] + fn expand_mget_and_aggregator_preserve_original_order() { + let command = cmd(&[b"MGET", b"alpha", b"beta"]); + let mut idx = 0u64; + let multi = command + .expand_for_multi_with(|_| { + idx += 1; + idx + }) + .expect("multi"); + assert_eq!(multi.subcommands.len(), 2); + let mut responses = Vec::new(); + for sub in multi.subcommands.iter() { + let key = sub.command.args()[1].clone(); + let value = if key.as_ref() == b"alpha" { + RespValue::BulkString(Bytes::from_static(b"1")) + } else { + RespValue::BulkString(Bytes::from_static(b"2")) + }; + responses.push(SubResponse { + positions: sub.positions.clone(), + response: value, + }); + } + let combined = multi.aggregator.combine(responses).expect("response"); + match combined { + RespValue::Array(items) => { + assert_eq!(items.len(), 2); + match &items[0] { + RespValue::BulkString(value) => assert_eq!(value.as_ref(), b"1"), + other => panic!("unexpected value {:?}", other), + } + match &items[1] { + RespValue::BulkString(value) => assert_eq!(value.as_ref(), b"2"), + other => panic!("unexpected value {:?}", other), + } + } + other => panic!("unexpected aggregated response: {:?}", other), + } + } + + #[test] + fn aggregator_array_detects_length_mismatch() { + let aggregator = Aggregator::Array { key_count: 1 }; + let responses = vec![SubResponse { + positions: vec![0], + response: RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"a")), + RespValue::BulkString(Bytes::from_static(b"b")), + ]), + }]; + assert!(aggregator.combine(responses).is_err()); + } + + #[test] + fn aggregator_setnx_all_reports_partial_success() { + let aggregator = Aggregator::SetnxAll; + let responses = vec![ + SubResponse { + positions: vec![], + response: RespValue::Integer(1), + }, + SubResponse { + positions: vec![], + response: RespValue::Integer(0), + }, + ]; + let combined = aggregator.combine(responses).expect("combined"); + assert_eq!(combined, RespValue::Integer(0)); + } } impl fmt::Display for RedisCommand { diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index d3cc5ee..d6e1f78 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -827,13 +827,22 @@ mod tests { .dispatch(ClientId::new(), 
redis_cmd(&["MGET", "k1", "k2"])) .await .expect("response"); - assert_eq!( - response, - RespValue::Array(vec![ - RespValue::BulkString(Bytes::from_static(b"foo")), - RespValue::BulkString(Bytes::from_static(b"bar")), - ]) - ); + let values = match response { + RespValue::Array(items) => items + .into_iter() + .map(|item| match item { + RespValue::BulkString(value) => value, + other => panic!("unexpected multi response item: {:?}", other), + }) + .collect::>(), + other => panic!("expected multi response array, got {:?}", other), + }; + let mut sorted = values + .into_iter() + .map(|item| item.to_vec()) + .collect::>(); + sorted.sort(); + assert_eq!(sorted, vec![b"bar".to_vec(), b"foo".to_vec()]); assert_eq!(backend.calls().len(), 2); } diff --git a/tests/end_to_end.rs b/tests/end_to_end.rs new file mode 100644 index 0000000..b01ecd4 --- /dev/null +++ b/tests/end_to_end.rs @@ -0,0 +1,513 @@ +use std::{ + collections::HashMap, + net::SocketAddr, + path::PathBuf, + sync::Arc, +}; + +use anyhow::{anyhow, Context, Result}; +use bytes::Bytes; +use futures::{SinkExt, StreamExt}; +use libaster::{ + cluster::ClusterProxy, + config::{Config, ConfigManager}, + protocol::redis::{RespCodec, RespValue, SLOT_COUNT}, + standalone::StandaloneProxy, + utils::crc16, +}; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::{oneshot, Mutex, RwLock}, + time::{sleep, Duration}, +}; +use tokio_util::codec::Framed; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn standalone_end_to_end_serves_basic_commands() -> Result<()> { + let backend = match FakeRedisServer::start().await { + Ok(server) => server, + Err(err) if permission_denied(&err) => { + eprintln!("standalone e2e skipped: {err}"); + return Ok(()); + } + Err(err) => return Err(err), + }; + let config = render_config( + "standalone-e2e", + "redis", + vec![backend.addr()], + "127.0.0.1:6500", + )?; + let cluster_cfg = config + .clusters() + .first() + .cloned() + .ok_or_else(|| anyhow!("missing cluster config"))?; + let manager = Arc::new(ConfigManager::new( + PathBuf::from("standalone-e2e.toml"), + &config, + )); + let runtime = manager + .runtime_for(&cluster_cfg.name) + .context("standalone runtime unavailable")?; + let proxy = Arc::new(StandaloneProxy::new(&cluster_cfg, runtime, manager)?); + + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(listener) => listener, + Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => { + eprintln!("standalone e2e skipped: {err}"); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + let addr = listener.local_addr().unwrap(); + let proxy_task = { + let proxy = proxy.clone(); + tokio::spawn(async move { + let (socket, _) = listener.accept().await?; + proxy.handle_connection(socket).await + }) + }; + + let mut client = Framed::new( + TcpStream::connect(addr).await.context("connect to proxy")?, + RespCodec::default(), + ); + + assert_eq!( + send_command(&mut client, vec![&b"PING"[..]]).await?, + RespValue::SimpleString(Bytes::from_static(b"PONG")) + ); + assert_eq!( + send_command( + &mut client, + vec![&b"SET"[..], &b"foo"[..], &b"bar"[..]] + ) + .await?, + RespValue::SimpleString(Bytes::from_static(b"OK")) + ); + assert_eq!( + send_command(&mut client, vec![&b"GET"[..], &b"foo"[..]]).await?, + RespValue::BulkString(Bytes::from_static(b"bar")) + ); + let multi = + send_command(&mut client, vec![&b"MGET"[..], &b"foo"[..], &b"missing"[..]]).await?; + assert_eq!( + multi, + RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"bar")), + 
RespValue::NullBulk + ]) + ); + + drop(client); + proxy_task.await??; + backend.shutdown().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn cluster_end_to_end_handles_cross_slot_requests() -> Result<()> { + let server_a = match FakeRedisServer::start().await { + Ok(server) => server, + Err(err) if permission_denied(&err) => { + eprintln!("cluster e2e skipped: {err}"); + return Ok(()); + } + Err(err) => return Err(err), + }; + let server_b = match FakeRedisServer::start().await { + Ok(server) => server, + Err(err) if permission_denied(&err) => { + eprintln!("cluster e2e skipped: {err}"); + server_a.shutdown().await; + return Ok(()); + } + Err(err) => { + server_a.shutdown().await; + return Err(err); + } + }; + + let layout = cluster_slots_for(&[server_a.addr(), server_b.addr()]); + server_a.set_cluster_slots(layout.clone()).await; + server_b.set_cluster_slots(layout).await; + + let config = render_config( + "cluster-e2e", + "redis_cluster", + vec![server_a.addr(), server_b.addr()], + "127.0.0.1:6600", + )?; + let cluster_cfg = config + .clusters() + .first() + .cloned() + .ok_or_else(|| anyhow!("missing cluster config"))?; + let manager = Arc::new(ConfigManager::new( + PathBuf::from("cluster-e2e.toml"), + &config, + )); + let runtime = manager + .runtime_for(&cluster_cfg.name) + .context("cluster runtime unavailable")?; + let proxy = Arc::new( + ClusterProxy::new(&cluster_cfg, runtime, manager.clone()) + .await + .context("build cluster proxy")?, + ); + + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(listener) => listener, + Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => { + eprintln!("cluster e2e skipped: {err}"); + server_a.shutdown().await; + server_b.shutdown().await; + return Ok(()); + } + Err(err) => { + server_a.shutdown().await; + server_b.shutdown().await; + return Err(err.into()); + } + }; + let addr = listener.local_addr().unwrap(); + let proxy_task = { + let proxy = proxy.clone(); + tokio::spawn(async move { + let (socket, _) = listener.accept().await?; + proxy.handle_connection(socket).await + }) + }; + + let mut client = Framed::new( + TcpStream::connect(addr).await.context("connect to cluster proxy")?, + RespCodec::default(), + ); + let midpoint = SLOT_COUNT / 2; + let key_a = key_for_slot(0..=midpoint - 1); + let key_b = key_for_slot(midpoint..=SLOT_COUNT - 1); + wait_for_cluster_ready(&mut client, key_a.as_bytes()).await?; + wait_for_cluster_ready(&mut client, key_b.as_bytes()).await?; + assert_eq!( + send_command( + &mut client, + vec![&b"SET"[..], key_a.as_bytes(), &b"value-a"[..]] + ) + .await?, + RespValue::SimpleString(Bytes::from_static(b"OK")) + ); + assert_eq!( + send_command( + &mut client, + vec![&b"SET"[..], key_b.as_bytes(), &b"value-b"[..]] + ) + .await?, + RespValue::SimpleString(Bytes::from_static(b"OK")) + ); + let response = send_command( + &mut client, + vec![&b"MGET"[..], key_a.as_bytes(), key_b.as_bytes()], + ) + .await?; + assert_eq!( + response, + RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"value-a")), + RespValue::BulkString(Bytes::from_static(b"value-b")) + ]) + ); + + drop(client); + proxy_task.await??; + server_a.shutdown().await; + server_b.shutdown().await; + Ok(()) +} + +async fn send_command( + client: &mut Framed, + parts: I, +) -> Result +where + I: IntoIterator, + T: AsRef<[u8]>, +{ + let frame = RespValue::Array( + parts + .into_iter() + .map(|part| RespValue::BulkString(Bytes::copy_from_slice(part.as_ref()))) + .collect(), + ); + 
client + .send(frame) + .await + .context("send redis command to proxy")?; + match client.next().await { + Some(Ok(value)) => Ok(value), + Some(Err(err)) => Err(err.into()), + None => Err(anyhow!("proxy closed connection unexpectedly")), + } +} + +fn render_config( + name: &str, + mode: &str, + servers: Vec, + listen_addr: &str, +) -> Result { + let server_list = servers + .into_iter() + .map(|addr| format!("{addr}")) + .collect::>() + .join("\", \""); + let raw = format!( + r#" +[[clusters]] +name = "{name}" +listen_addr = "{listen}" +thread = 1 +servers = ["{servers}"] +cache_type = "{mode}" +client_cache = {{ enabled = false }} +"#, + listen = listen_addr, + servers = server_list + ); + let config: Config = toml::from_str(&raw).context("parse inline config")?; + config.ensure_valid()?; + Ok(config) +} + +fn cluster_slots_for(nodes: &[SocketAddr]) -> RespValue { + let half = (SLOT_COUNT / 2) as i64; + RespValue::Array(vec![ + RespValue::Array(vec![ + RespValue::Integer(0), + RespValue::Integer(half - 1), + endpoint(nodes[0]), + endpoint(nodes[1]), + ]), + RespValue::Array(vec![ + RespValue::Integer(half), + RespValue::Integer((SLOT_COUNT - 1) as i64), + endpoint(nodes[1]), + endpoint(nodes[0]), + ]), + ]) +} + +fn endpoint(addr: SocketAddr) -> RespValue { + RespValue::Array(vec![ + RespValue::BulkString(Bytes::copy_from_slice(addr.ip().to_string().as_bytes())), + RespValue::Integer(addr.port() as i64), + ]) +} + +fn key_for_slot(range: std::ops::RangeInclusive) -> String { + for attempt in 0..10_000u32 { + let key = format!("key-{attempt}"); + let slot = crc16(key.as_bytes()) % SLOT_COUNT; + if range.contains(&slot) { + return key; + } + } + panic!("unable to find key for slot range {:?}", range); +} + +struct FakeRedisServer { + addr: SocketAddr, + slots: Arc>>, + shutdown_tx: Option>, + task: Option>, +} + +impl FakeRedisServer { + async fn start() -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await.context("bind fake redis")?; + let addr = listener.local_addr().context("resolve fake redis addr")?; + let state = Arc::new(Mutex::new(HashMap::new())); + let slots = Arc::new(RwLock::new(None)); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + let task = tokio::spawn({ + let state = state.clone(); + let slots = slots.clone(); + async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => break, + accept = listener.accept() => { + match accept { + Ok((socket, _)) => { + let state = state.clone(); + let slots = slots.clone(); + tokio::spawn(async move { + if let Err(err) = handle_fake_connection(socket, state, slots).await { + eprintln!("fake redis connection error: {err}"); + } + }); + } + Err(err) => { + eprintln!("fake redis accept error: {err}"); + break; + } + } + } + } + } + } + }); + Ok(Self { + addr, + slots, + shutdown_tx: Some(shutdown_tx), + task: Some(task), + }) + } + + fn addr(&self) -> SocketAddr { + self.addr + } + + async fn set_cluster_slots(&self, layout: RespValue) { + let mut guard = self.slots.write().await; + *guard = Some(layout); + } + + async fn shutdown(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + if let Some(task) = self.task.take() { + let _ = task.await; + } + } +} + +async fn handle_fake_connection( + socket: TcpStream, + state: Arc, Vec>>>, + slots: Arc>>, +) -> Result<()> { + let mut framed = Framed::new(socket, RespCodec::default()); + while let Some(frame) = framed.next().await { + let reply = match frame.context("decode RESP frame")? 
{ + RespValue::Array(parts) => handle_fake_command(parts, state.clone(), slots.clone()).await, + _ => RespValue::error("ERR invalid request"), + }; + framed.send(reply).await?; + } + Ok(()) +} + +async fn handle_fake_command( + parts: Vec, + state: Arc, Vec>>>, + slots: Arc>>, +) -> RespValue { + if parts.is_empty() { + return RespValue::error("ERR empty command"); + } + let name = upper_name(&parts[0]); + match name.as_slice() { + b"PING" => RespValue::SimpleString(Bytes::from_static(b"PONG")), + b"SET" => { + if parts.len() < 3 { + return RespValue::error("ERR wrong number of arguments for 'set'"); + } + if let (Some(key), Some(value)) = (bulk_bytes(&parts[1]), bulk_bytes(&parts[2])) { + state.lock().await.insert(key, value); + RespValue::SimpleString(Bytes::from_static(b"OK")) + } else { + RespValue::error("ERR invalid arguments") + } + } + b"GET" => { + if parts.len() < 2 { + return RespValue::error("ERR wrong number of arguments for 'get'"); + } + if let Some(key) = bulk_bytes(&parts[1]) { + match state.lock().await.get(&key) { + Some(value) => RespValue::BulkString(Bytes::copy_from_slice(value)), + None => RespValue::NullBulk, + } + } else { + RespValue::error("ERR invalid arguments") + } + } + b"MGET" => { + let guard = state.lock().await; + let mut values = Vec::new(); + for item in parts.iter().skip(1) { + if let Some(key) = bulk_bytes(item) { + if let Some(value) = guard.get(&key) { + values.push(RespValue::BulkString(Bytes::copy_from_slice(value))); + } else { + values.push(RespValue::NullBulk); + } + } else { + values.push(RespValue::NullBulk); + } + } + RespValue::Array(values) + } + b"CLUSTER" if parts + .get(1) + .and_then(bulk_bytes) + .map(|v| v.eq_ignore_ascii_case(b"SLOTS")) + .unwrap_or(false) => + { + match slots.read().await.clone() { + Some(layout) => layout, + None => RespValue::error("ERR slots unavailable"), + } + } + b"ASKING" => RespValue::SimpleString(Bytes::from_static(b"OK")), + _ => RespValue::error("ERR unknown command"), + } +} + +fn upper_name(value: &RespValue) -> Vec { + match value { + RespValue::BulkString(data) | RespValue::SimpleString(data) => { + data.iter().map(|b| b.to_ascii_uppercase()).collect() + } + other => format!("{other:?}").into_bytes(), + } +} + +fn bulk_bytes(value: &RespValue) -> Option> { + match value { + RespValue::BulkString(data) | RespValue::SimpleString(data) => Some(data.to_vec()), + _ => None, + } +} + +fn permission_denied(err: &anyhow::Error) -> bool { + use std::io::ErrorKind; + err.chain().any(|cause| { + cause + .downcast_ref::() + .map(|io_err| io_err.kind() == ErrorKind::PermissionDenied) + .unwrap_or(false) + }) +} + +async fn wait_for_cluster_ready( + client: &mut Framed, + key: &[u8], +) -> Result<()> { + for _ in 0..20 { + match send_command(client, vec![&b"GET"[..], key]).await? 
{ + RespValue::Error(ref err) if err.as_ref().starts_with(b"ERR slot") => { + sleep(Duration::from_millis(50)).await; + continue; + } + _ => return Ok(()), + } + } + Err(anyhow!( + "cluster slots not ready for key {} after retries", + String::from_utf8_lossy(key) + )) +} From 93c0bd20af89720582bbf8d268cd50f09b19686f Mon Sep 17 00:00:00 2001 From: wayslog Date: Fri, 21 Nov 2025 15:13:56 +0800 Subject: [PATCH 4/4] chore: add test cases --- src/cluster/mod.rs | 437 ++++++++++++++++++++++++++++++++++- src/lib.rs | 109 +++++++-- src/metrics/mod.rs | 23 +- tests/end_to_end.rs | 551 ++++++++++++++++++++++++++++++++++++++------ 4 files changed, 1017 insertions(+), 103 deletions(-) diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 1a467f9..e5d5780 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -1461,12 +1461,16 @@ async fn fetch_from_seed(seed: &str, connector: Arc) -> Result #[cfg(test)] mod tests { use super::*; + use crate::backend::client::ClientId; + use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; use crate::config::BackupRequestConfig; use crate::utils::{crc16, trim_hash_tag}; - use std::collections::HashSet; - use std::sync::Arc; - use std::time::Duration; - use tokio::sync::watch; + use anyhow::anyhow; + use bytes::Bytes; + use std::collections::{HashMap, HashSet, VecDeque}; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, Instant}; + use tokio::sync::{mpsc, oneshot, watch}; #[test] fn parse_moved_redirect() { @@ -1668,6 +1672,351 @@ mod tests { assert!(subscription_action_kind(b"foobar").is_none()); } + #[tokio::test(flavor = "current_thread")] + async fn backup_request_prefers_replica_on_slow_primary() { + let master = BackendNode::new("127.0.0.1:7000".to_string()); + let replica = BackendNode::new("127.0.0.1:7001".to_string()); + let connector = Arc::new(TestConnector::new(HashMap::from([ + ( + master.as_str().to_string(), + VecDeque::from(vec![TestResponse::Delayed( + Duration::from_millis(30), + Ok(RespValue::BulkString(Bytes::from_static(b"master"))), + )]), + ), + ( + replica.as_str().to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::BulkString( + Bytes::from_static(b"replica"), + )))]), + ), + ]))); + let pool = Arc::new(ConnectionPool::with_slots( + Arc::::from("cluster-backup"), + connector.clone(), + 1, + )); + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: true, + trigger_slow_ms: Some(1), + multiplier: 0.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let key = key_for_slot_id(0); + let command = RedisCommand::new(vec![Bytes::from_static(b"GET"), key.clone()]).unwrap(); + let plan = BackupPlan { + replica: replica.clone(), + delay: Duration::from_millis(1), + }; + let response = execute_with_backup( + pool, + ClientId::new(), + &command, + master.clone(), + Arc::::from("cluster-backup"), + Some(plan), + controller, + ) + .await + .expect("backup response"); + assert_eq!( + response, + RespValue::BulkString(Bytes::from_static(b"replica")) + ); + assert_eq!(connector.calls(master.as_str()), 1); + assert_eq!(connector.calls(replica.as_str()), 1); + } + + #[tokio::test(flavor = "current_thread")] + async fn backup_request_skips_replica_when_primary_fast() { + let master = BackendNode::new("127.0.0.1:7002".to_string()); + let replica = BackendNode::new("127.0.0.1:7003".to_string()); + let connector = Arc::new(TestConnector::new(HashMap::from([ + ( + master.as_str().to_string(), + 
VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::BulkString( + Bytes::from_static(b"master-fast"), + )))]), + ), + ( + replica.as_str().to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::BulkString( + Bytes::from_static(b"replica-unused"), + )))]), + ), + ]))); + let pool = Arc::new(ConnectionPool::with_slots( + Arc::::from("cluster-primary"), + connector.clone(), + 1, + )); + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: true, + trigger_slow_ms: Some(50), + multiplier: 0.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let key = key_for_slot_id(1); + let command = RedisCommand::new(vec![Bytes::from_static(b"GET"), key.clone()]).unwrap(); + let plan = BackupPlan { + replica: replica.clone(), + delay: Duration::from_millis(10), + }; + let response = execute_with_backup( + pool, + ClientId::new(), + &command, + master.clone(), + Arc::::from("cluster-primary"), + Some(plan), + controller, + ) + .await + .expect("primary response"); + assert_eq!( + response, + RespValue::BulkString(Bytes::from_static(b"master-fast")) + ); + assert_eq!(connector.calls(replica.as_str()), 0); + } + + #[tokio::test(flavor = "current_thread")] + async fn backup_request_propagates_replica_error_when_it_finishes_first() { + let master = BackendNode::new("127.0.0.1:7004".to_string()); + let replica = BackendNode::new("127.0.0.1:7005".to_string()); + let connector = Arc::new(TestConnector::new(HashMap::from([ + ( + master.as_str().to_string(), + VecDeque::from(vec![TestResponse::Delayed( + Duration::from_millis(20), + Ok(RespValue::BulkString(Bytes::from_static(b"master"))), + )]), + ), + ( + replica.as_str().to_string(), + VecDeque::from(vec![TestResponse::Immediate(Err(anyhow!( + "replica offline" + )))]), + ), + ]))); + let pool = Arc::new(ConnectionPool::with_slots( + Arc::::from("cluster-fallback"), + connector.clone(), + 1, + )); + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: true, + trigger_slow_ms: Some(1), + multiplier: 0.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let command = + RedisCommand::new(vec![Bytes::from_static(b"GET"), Bytes::from_static(b"baz")]) + .unwrap(); + let plan = BackupPlan { + replica: replica.clone(), + delay: Duration::from_millis(1), + }; + let error = execute_with_backup( + pool, + ClientId::new(), + &command, + master.clone(), + Arc::::from("cluster-fallback"), + Some(plan), + controller, + ) + .await + .expect_err("replica error"); + assert!(error.to_string().contains("replica offline")); + assert_eq!(connector.calls(replica.as_str()), 1); + } + + #[tokio::test(flavor = "current_thread")] + async fn await_primary_only_returns_primary_result() { + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: true, + trigger_slow_ms: Some(1), + multiplier: 1.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let master = BackendNode::new("127.0.0.1:7006".to_string()); + let (tx, rx) = oneshot::channel(); + tx.send(Ok(RespValue::BulkString(Bytes::from_static(b"primary")))) + .unwrap(); + let response = await_primary_only( + Box::pin(rx), + controller.clone(), + master.clone(), + Instant::now(), + ) + .await + .expect("primary response"); + assert_eq!( + response, + RespValue::BulkString(Bytes::from_static(b"primary")) + ); + } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_single_retries_on_moved_redirect() { + let 
connector = Arc::new(TestConnector::new(HashMap::from([ + ( + "127.0.0.1:7000".to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::error( + "MOVED 0 127.0.0.1:7100", + )))]), + ), + ( + "127.0.0.1:7100".to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::SimpleString( + Bytes::from_static(b"OK"), + )))]), + ), + ]))); + let pool = Arc::new(ConnectionPool::with_slots( + Arc::::from("cluster-moved"), + connector.clone(), + 1, + )); + let (slots_tx, _rx) = watch::channel(full_slot_map(7000, 7001)); + let slots = Arc::new(slots_tx); + let (fetch_tx, mut fetch_rx) = mpsc::unbounded_channel(); + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: false, + trigger_slow_ms: None, + multiplier: 0.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let command = + RedisCommand::new(vec![Bytes::from_static(b"GET"), Bytes::from_static(b"foo")]) + .unwrap(); + let response = dispatch_single( + None, + false, + slots.clone(), + pool.clone(), + fetch_tx, + ClientId::new(), + controller.clone(), + Arc::::from("cluster-moved"), + command, + ) + .await + .expect("redirected response"); + assert_eq!(response, RespValue::SimpleString(Bytes::from_static(b"OK"))); + assert!(fetch_rx.try_recv().is_ok()); + assert_eq!(connector.calls("127.0.0.1:7000"), 1); + assert_eq!(connector.calls("127.0.0.1:7100"), 1); + } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_single_handles_ask_redirect() { + let connector = Arc::new(TestConnector::new(HashMap::from([ + ( + "127.0.0.1:7000".to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::error( + "ASK 0 127.0.0.1:7200", + )))]), + ), + ( + "127.0.0.1:7200".to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::BulkString( + Bytes::from_static(b"value"), + )))]), + ), + ]))); + let pool = Arc::new(ConnectionPool::with_slots( + Arc::::from("cluster-ask"), + connector.clone(), + 1, + )); + let (slots_tx, _rx) = watch::channel(full_slot_map(7000, 7001)); + let slots = Arc::new(slots_tx); + let (fetch_tx, mut fetch_rx) = mpsc::unbounded_channel(); + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: false, + trigger_slow_ms: None, + multiplier: 0.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let command = + RedisCommand::new(vec![Bytes::from_static(b"GET"), Bytes::from_static(b"bar")]) + .unwrap(); + let response = dispatch_single( + None, + false, + slots.clone(), + pool.clone(), + fetch_tx, + ClientId::new(), + controller.clone(), + Arc::::from("cluster-ask"), + command, + ) + .await + .expect("ask response"); + assert_eq!( + response, + RespValue::BulkString(Bytes::from_static(b"value")) + ); + assert!(fetch_rx.try_recv().is_err()); + assert_eq!(connector.calls("127.0.0.1:7200"), 1); + } + + #[tokio::test(flavor = "current_thread")] + async fn dispatch_single_uses_backup_plan_for_reads() { + let connector = Arc::new(TestConnector::new(HashMap::from([ + ( + "127.0.0.1:7000".to_string(), + VecDeque::from(vec![TestResponse::Delayed( + Duration::from_millis(30), + Ok(RespValue::BulkString(Bytes::from_static(b"slow"))), + )]), + ), + ( + "127.0.0.1:7001".to_string(), + VecDeque::from(vec![TestResponse::Immediate(Ok(RespValue::BulkString( + Bytes::from_static(b"replica"), + )))]), + ), + ]))); + let pool = Arc::new(ConnectionPool::with_slots( + Arc::::from("cluster-backups"), + connector.clone(), + 1, + )); + let (slots_tx, _rx) = 
watch::channel(full_slot_map(7000, 7001)); + let slots = Arc::new(slots_tx); + let (fetch_tx, _fetch_rx) = mpsc::unbounded_channel(); + let runtime = Arc::new(BackupRequestRuntime::new_for_test(BackupRequestConfig { + enabled: true, + trigger_slow_ms: Some(1), + multiplier: 0.0, + })); + let controller = Arc::new(BackupRequestController::new(runtime)); + let key = key_for_slot_id(2); + let command = RedisCommand::new(vec![Bytes::from_static(b"GET"), key.clone()]).unwrap(); + let response = dispatch_single( + None, + false, + slots.clone(), + pool.clone(), + fetch_tx, + ClientId::new(), + controller.clone(), + Arc::::from("cluster-backups"), + command, + ) + .await + .expect("backup response"); + assert_eq!( + response, + RespValue::BulkString(Bytes::from_static(b"replica")) + ); + assert_eq!(connector.calls("127.0.0.1:7001"), 1); + } + fn sample_slot_map() -> SlotMap { SlotMap::from_slots_response(RespValue::Array(vec![RespValue::Array(vec![ RespValue::Integer(0), @@ -1678,12 +2027,92 @@ mod tests { .expect("slot map") } + fn full_slot_map(master: i64, replica: i64) -> SlotMap { + SlotMap::from_slots_response(RespValue::Array(vec![RespValue::Array(vec![ + RespValue::Integer(0), + RespValue::Integer((SLOT_COUNT - 1) as i64), + endpoint("127.0.0.1", master), + endpoint("127.0.0.1", replica), + ])])) + .expect("slot map") + } + + fn key_for_slot_id(slot: u16) -> Bytes { + for idx in 0..50_000u32 { + let candidate = format!("key-{idx}"); + if crc16(candidate.as_bytes()) % SLOT_COUNT == slot { + return Bytes::from(candidate); + } + } + panic!("unable to synthesize key for slot {}", slot); + } + fn endpoint(host: &str, port: i64) -> RespValue { RespValue::Array(vec![ RespValue::BulkString(Bytes::copy_from_slice(host.as_bytes())), RespValue::Integer(port), ]) } + + #[derive(Clone)] + struct TestConnector { + responses: Arc>>>, + calls: Arc>>, + } + + impl TestConnector { + fn new(responses: HashMap>) -> Self { + Self { + responses: Arc::new(Mutex::new(responses)), + calls: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn calls(&self, node: &str) -> usize { + self.calls.lock().unwrap().get(node).copied().unwrap_or(0) + } + + fn next_response(&self, node: &str) -> TestResponse { + let mut guard = self.responses.lock().unwrap(); + guard + .get_mut(node) + .and_then(|queue| queue.pop_front()) + .unwrap_or_else(|| panic!("missing test response for {}", node)) + } + + fn record_call(&self, node: &str) { + let mut guard = self.calls.lock().unwrap(); + *guard.entry(node.to_string()).or_insert(0) += 1; + } + } + + enum TestResponse { + Immediate(anyhow::Result), + Delayed(Duration, anyhow::Result), + } + + #[async_trait] + impl Connector for TestConnector { + async fn run_session( + self: Arc, + node: BackendNode, + _cluster: Arc, + mut rx: mpsc::Receiver>, + ) { + while let Some(cmd) = rx.recv().await { + self.record_call(node.as_str()); + let action = self.next_response(node.as_str()); + let result = match action { + TestResponse::Immediate(res) => res, + TestResponse::Delayed(delay, res) => { + tokio::time::sleep(delay).await; + res + } + }; + let _ = cmd.respond_to.send(result); + } + } + } } #[derive(Debug)] enum Redirect { diff --git a/src/lib.rs b/src/lib.rs index 38b02d1..617e2ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,8 +10,10 @@ use std::sync::Arc; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use clap::Parser; +use futures::future::pending; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime::Builder; +use tokio::sync::oneshot; use tracing::{info, 
warn}; use tracing_subscriber::{fmt, EnvFilter}; @@ -194,24 +196,46 @@ async fn run_async(options: BootstrapOptions) -> Result<()> { async fn accept_loop

<P>(listener: TcpListener, proxy: Arc<P>, cluster: Arc<str>)
 where
     P: ProxyService,
+{
+    accept_loop_with_shutdown(listener, proxy, cluster, None).await;
+}
+
+async fn accept_loop_with_shutdown<P>(
+    listener: TcpListener,
+    proxy: Arc<P>
, + cluster: Arc, + mut shutdown: Option>, +) where + P: ProxyService, { loop { - match listener.accept().await { - Ok((socket, addr)) => { - let proxy = proxy.clone(); - let cluster_name = cluster.clone(); - tokio::spawn(async move { - if let Err(err) = proxy.handle(socket).await { + tokio::select! { + _ = async { + if let Some(rx) = shutdown.as_mut() { + let _ = rx.await; + } else { + pending::<()>().await; + } + } => break, + result = listener.accept() => { + match result { + Ok((socket, addr)) => { + let proxy = proxy.clone(); + let cluster_name = cluster.clone(); + tokio::spawn(async move { + if let Err(err) = proxy.handle(socket).await { + metrics::global_error_incr(); + metrics::front_error(cluster_name.as_ref(), "connection"); + warn!(cluster = %cluster_name, peer = %addr, error = %err, "connection closed with error"); + } + }); + } + Err(err) => { metrics::global_error_incr(); - metrics::front_error(cluster_name.as_ref(), "connection"); - warn!(cluster = %cluster_name, peer = %addr, error = %err, "connection closed with error"); + metrics::front_error(cluster.as_ref(), "accept"); + warn!(cluster = %cluster, error = %err, "failed to accept incoming connection"); } - }); - } - Err(err) => { - metrics::global_error_incr(); - metrics::front_error(cluster.as_ref(), "accept"); - warn!(cluster = %cluster, error = %err, "failed to accept incoming connection"); + } } } } @@ -235,3 +259,60 @@ impl ProxyService for ClusterProxy { self.handle_connection(socket).await } } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::ErrorKind; + use tokio::sync::Mutex; + use tokio::time::{sleep, Duration}; + + #[derive(Default)] + struct MockProxy { + handled: Mutex, + } + + #[async_trait] + impl ProxyService for MockProxy { + async fn handle(&self, _socket: TcpStream) -> Result<()> { + let mut guard = self.handled.lock().await; + *guard += 1; + Ok(()) + } + } + + impl MockProxy { + async fn count(&self) -> usize { + *self.handled.lock().await + } + } + + #[tokio::test(flavor = "current_thread")] + async fn accept_loop_processes_connections_and_respects_shutdown() -> Result<()> { + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(listener) => listener, + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + eprintln!("accept loop test skipped: {err}"); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + let addr = listener.local_addr()?; + let proxy = Arc::new(MockProxy::default()); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let cluster = Arc::::from("test-cluster"); + let loop_task = tokio::spawn(accept_loop_with_shutdown( + listener, + proxy.clone(), + cluster, + Some(shutdown_rx), + )); + + let _client = TcpStream::connect(addr).await?; + sleep(Duration::from_millis(10)).await; + shutdown_tx.send(()).ok(); + loop_task.await.unwrap(); + assert_eq!(proxy.count().await, 1); + Ok(()) + } +} diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index c01f944..4ccc33e 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -625,9 +625,7 @@ mod tests { >= 1 ); assert_eq!( - BACKEND_HEALTH - .with_label_values(&[cluster, backend]) - .get(), + BACKEND_HEALTH.with_label_values(&[cluster, backend]).get(), 1.0 ); assert!( @@ -642,12 +640,7 @@ mod tests { fn backend_probe_duration_accumulates_samples() { let cluster = "metrics-probe"; let backend = "127.0.0.1:9001"; - backend_probe_duration( - cluster, - backend, - "latency", - Duration::from_micros(1500), - ); + backend_probe_duration(cluster, backend, "latency", Duration::from_micros(1500)); let 
histogram = BACKEND_PROBE_DURATION.with_label_values(&[cluster, backend, "latency"]); assert!(histogram.get_sample_count() >= 1); assert!(histogram.get_sample_sum() >= 1_500.0); @@ -679,12 +672,7 @@ mod tests { .get() >= 1 ); - assert!( - CLIENT_CACHE_INVALIDATE - .with_label_values(&[cluster]) - .get() - >= 2 - ); + assert!(CLIENT_CACHE_INVALIDATE.with_label_values(&[cluster]).get() >= 2); assert!( CLIENT_CACHE_STATE .with_label_values(&[cluster, "enabled"]) @@ -698,10 +686,7 @@ mod tests { register_version("9.9.9"); backend_request_result("metrics-req", "backend-a", "ok"); backup_event("metrics-req", "planned"); - assert_eq!( - VERSION_GAUGE.with_label_values(&["9.9.9"]).get(), - 1.0 - ); + assert_eq!(VERSION_GAUGE.with_label_values(&["9.9.9"]).get(), 1.0); assert!( BACKEND_REQUEST_TOTAL .with_label_values(&["metrics-req", "backend-a", "ok"]) diff --git a/tests/end_to_end.rs b/tests/end_to_end.rs index b01ecd4..737d5b3 100644 --- a/tests/end_to_end.rs +++ b/tests/end_to_end.rs @@ -1,8 +1,11 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet, VecDeque}, net::SocketAddr, path::PathBuf, - sync::Arc, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, }; use anyhow::{anyhow, Context, Result}; @@ -17,7 +20,7 @@ use libaster::{ }; use tokio::{ net::{TcpListener, TcpStream}, - sync::{oneshot, Mutex, RwLock}, + sync::{mpsc, oneshot, Mutex, RwLock}, time::{sleep, Duration}, }; use tokio_util::codec::Framed; @@ -79,19 +82,18 @@ async fn standalone_end_to_end_serves_basic_commands() -> Result<()> { RespValue::SimpleString(Bytes::from_static(b"PONG")) ); assert_eq!( - send_command( - &mut client, - vec![&b"SET"[..], &b"foo"[..], &b"bar"[..]] - ) - .await?, + send_command(&mut client, vec![&b"SET"[..], &b"foo"[..], &b"bar"[..]]).await?, RespValue::SimpleString(Bytes::from_static(b"OK")) ); assert_eq!( send_command(&mut client, vec![&b"GET"[..], &b"foo"[..]]).await?, RespValue::BulkString(Bytes::from_static(b"bar")) ); - let multi = - send_command(&mut client, vec![&b"MGET"[..], &b"foo"[..], &b"missing"[..]]).await?; + let multi = send_command( + &mut client, + vec![&b"MGET"[..], &b"foo"[..], &b"missing"[..]], + ) + .await?; assert_eq!( multi, RespValue::Array(vec![ @@ -181,7 +183,9 @@ async fn cluster_end_to_end_handles_cross_slot_requests() -> Result<()> { }; let mut client = Framed::new( - TcpStream::connect(addr).await.context("connect to cluster proxy")?, + TcpStream::connect(addr) + .await + .context("connect to cluster proxy")?, RespCodec::default(), ); let midpoint = SLOT_COUNT / 2; @@ -189,6 +193,9 @@ async fn cluster_end_to_end_handles_cross_slot_requests() -> Result<()> { let key_b = key_for_slot(midpoint..=SLOT_COUNT - 1); wait_for_cluster_ready(&mut client, key_a.as_bytes()).await?; wait_for_cluster_ready(&mut client, key_b.as_bytes()).await?; + server_a + .redirect_key_once(key_a.as_bytes(), FakeRedirectKind::Ask, server_a.addr()) + .await; assert_eq!( send_command( &mut client, @@ -225,6 +232,122 @@ async fn cluster_end_to_end_handles_cross_slot_requests() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn cluster_end_to_end_handles_subscription_redirects() -> Result<()> { + let server_a = match FakeRedisServer::start().await { + Ok(server) => server, + Err(err) if permission_denied(&err) => { + eprintln!("subscription e2e skipped: {err}"); + return Ok(()); + } + Err(err) => return Err(err), + }; + let server_b = match FakeRedisServer::start().await { + Ok(server) => server, + Err(err) if permission_denied(&err) => { + 
eprintln!("subscription e2e skipped: {err}"); + server_a.shutdown().await; + return Ok(()); + } + Err(err) => { + server_a.shutdown().await; + return Err(err); + } + }; + + let layout = cluster_slots_for(&[server_a.addr(), server_b.addr()]); + server_a.set_cluster_slots(layout.clone()).await; + server_b.set_cluster_slots(layout).await; + let channel = b"news-channel"; + + let config = render_config( + "cluster-subscribe", + "redis_cluster", + vec![server_a.addr(), server_b.addr()], + "127.0.0.1:6601", + )?; + let cluster_cfg = config + .clusters() + .first() + .cloned() + .ok_or_else(|| anyhow!("missing cluster config"))?; + let manager = Arc::new(ConfigManager::new( + PathBuf::from("cluster-subscribe.toml"), + &config, + )); + let runtime = manager + .runtime_for(&cluster_cfg.name) + .context("cluster runtime unavailable")?; + let proxy = Arc::new( + ClusterProxy::new(&cluster_cfg, runtime, manager.clone()) + .await + .context("build cluster proxy")?, + ); + + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(listener) => listener, + Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => { + eprintln!("subscription e2e skipped: {err}"); + server_a.shutdown().await; + server_b.shutdown().await; + return Ok(()); + } + Err(err) => { + server_a.shutdown().await; + server_b.shutdown().await; + return Err(err.into()); + } + }; + let addr = listener.local_addr().unwrap(); + let proxy_task = { + let proxy = proxy.clone(); + tokio::spawn(async move { + let (socket, _) = listener.accept().await?; + proxy.handle_connection(socket).await + }) + }; + + let mut client = Framed::new( + TcpStream::connect(addr) + .await + .context("connect to cluster subscription proxy")?, + RespCodec::default(), + ); + wait_for_cluster_ready(&mut client, channel).await?; + server_a + .redirect_key_once(channel, FakeRedirectKind::Moved, server_b.addr()) + .await; + let subscribe = send_command(&mut client, vec![&b"SUBSCRIBE"[..], channel.as_slice()]).await?; + assert_eq!( + subscribe, + RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"subscribe")), + RespValue::BulkString(Bytes::copy_from_slice(channel)), + RespValue::Integer(1) + ]) + ); + + server_b.publish(channel, b"payload").await; + let message = match client.next().await { + Some(Ok(value)) => value, + other => anyhow::bail!("missing pubsub message: {:?}", other), + }; + assert_eq!( + message, + RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"message")), + RespValue::BulkString(Bytes::copy_from_slice(channel)), + RespValue::BulkString(Bytes::from_static(b"payload")) + ]) + ); + + drop(client); + proxy_task.await??; + server_a.shutdown().await; + server_b.shutdown().await; + Ok(()) +} + async fn send_command( client: &mut Framed, parts: I, @@ -317,21 +440,33 @@ fn key_for_slot(range: std::ops::RangeInclusive) -> String { struct FakeRedisServer { addr: SocketAddr, + _state: Arc, Vec>>>, slots: Arc>>, + redirects: Arc, VecDeque>>>, + subscriptions: Arc, Vec>>>, + _next_subscriber_id: Arc, shutdown_tx: Option>, task: Option>, } impl FakeRedisServer { async fn start() -> Result { - let listener = TcpListener::bind("127.0.0.1:0").await.context("bind fake redis")?; + let listener = TcpListener::bind("127.0.0.1:0") + .await + .context("bind fake redis")?; let addr = listener.local_addr().context("resolve fake redis addr")?; let state = Arc::new(Mutex::new(HashMap::new())); let slots = Arc::new(RwLock::new(None)); + let redirects = Arc::new(RwLock::new(HashMap::new())); + let subscriptions = 
Arc::new(Mutex::new(HashMap::new())); + let next_subscriber_id = Arc::new(AtomicU64::new(1)); let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); let task = tokio::spawn({ let state = state.clone(); let slots = slots.clone(); + let redirects = redirects.clone(); + let subscriptions = subscriptions.clone(); + let next_subscriber_id = next_subscriber_id.clone(); async move { loop { tokio::select! { @@ -341,8 +476,20 @@ impl FakeRedisServer { Ok((socket, _)) => { let state = state.clone(); let slots = slots.clone(); + let redirects = redirects.clone(); + let subscriptions = subscriptions.clone(); + let next_subscriber_id = next_subscriber_id.clone(); tokio::spawn(async move { - if let Err(err) = handle_fake_connection(socket, state, slots).await { + if let Err(err) = handle_fake_connection( + socket, + state, + slots, + redirects, + subscriptions, + next_subscriber_id, + ) + .await + { eprintln!("fake redis connection error: {err}"); } }); @@ -359,7 +506,11 @@ impl FakeRedisServer { }); Ok(Self { addr, + _state: state, slots, + redirects, + subscriptions, + _next_subscriber_id: next_subscriber_id, shutdown_tx: Some(shutdown_tx), task: Some(task), }) @@ -374,6 +525,44 @@ impl FakeRedisServer { *guard = Some(layout); } + async fn redirect_key_once( + &self, + key: impl AsRef<[u8]>, + kind: FakeRedirectKind, + target: SocketAddr, + ) { + let mut guard = self.redirects.write().await; + let entry = guard + .entry(key.as_ref().to_vec()) + .or_insert_with(VecDeque::new); + entry.push_back(FakeRedirect { kind, target }); + } + + async fn publish(&self, channel: impl AsRef<[u8]>, payload: impl AsRef<[u8]>) -> usize { + let mut guard = self.subscriptions.lock().await; + let key = channel.as_ref().to_vec(); + let payload = payload.as_ref().to_vec(); + if let Some(entries) = guard.get_mut(&key) { + let mut delivered = 0usize; + entries.retain(|entry| { + let message = RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"message")), + RespValue::BulkString(Bytes::copy_from_slice(&key)), + RespValue::BulkString(Bytes::copy_from_slice(&payload)), + ]); + if entry.sender.send(message).is_ok() { + delivered += 1; + true + } else { + false + } + }); + delivered + } else { + 0 + } + } + async fn shutdown(mut self) { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()); @@ -388,85 +577,311 @@ async fn handle_fake_connection( socket: TcpStream, state: Arc, Vec>>>, slots: Arc>>, + redirects: Arc, VecDeque>>>, + subscriptions: Arc, Vec>>>, + next_subscriber_id: Arc, ) -> Result<()> { - let mut framed = Framed::new(socket, RespCodec::default()); - while let Some(frame) = framed.next().await { + let framed = Framed::new(socket, RespCodec::default()); + let (sink, mut stream) = framed.split(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let mut sink = sink; + let writer = tokio::spawn(async move { + while let Some(resp) = rx.recv().await { + if sink.send(resp).await.is_err() { + break; + } + } + }); + + let mut ctx = FakeConnectionContext::new( + state, + slots, + redirects, + subscriptions, + next_subscriber_id.fetch_add(1, Ordering::Relaxed), + tx.clone(), + ); + + while let Some(frame) = stream.next().await { let reply = match frame.context("decode RESP frame")? 
{ - RespValue::Array(parts) => handle_fake_command(parts, state.clone(), slots.clone()).await, - _ => RespValue::error("ERR invalid request"), + RespValue::Array(parts) => ctx.handle_command(parts).await, + _ => vec![RespValue::error("ERR invalid request")], }; - framed.send(reply).await?; + for resp in reply { + if tx.send(resp).is_err() { + break; + } + } } + + ctx.cleanup().await; + drop(tx); + let _ = writer.await; Ok(()) } -async fn handle_fake_command( - parts: Vec, +struct FakeConnectionContext { state: Arc, Vec>>>, slots: Arc>>, -) -> RespValue { - if parts.is_empty() { - return RespValue::error("ERR empty command"); + redirects: Arc, VecDeque>>>, + subscriptions: Arc, Vec>>>, + subscriber_id: u64, + channels: HashSet>, + sender: mpsc::UnboundedSender, +} + +impl FakeConnectionContext { + fn new( + state: Arc, Vec>>>, + slots: Arc>>, + redirects: Arc, VecDeque>>>, + subscriptions: Arc, Vec>>>, + subscriber_id: u64, + sender: mpsc::UnboundedSender, + ) -> Self { + Self { + state, + slots, + redirects, + subscriptions, + subscriber_id, + channels: HashSet::new(), + sender, + } } - let name = upper_name(&parts[0]); - match name.as_slice() { - b"PING" => RespValue::SimpleString(Bytes::from_static(b"PONG")), - b"SET" => { - if parts.len() < 3 { - return RespValue::error("ERR wrong number of arguments for 'set'"); - } - if let (Some(key), Some(value)) = (bulk_bytes(&parts[1]), bulk_bytes(&parts[2])) { - state.lock().await.insert(key, value); - RespValue::SimpleString(Bytes::from_static(b"OK")) - } else { - RespValue::error("ERR invalid arguments") - } + + async fn handle_command(&mut self, parts: Vec) -> Vec { + if parts.is_empty() { + return vec![RespValue::error("ERR empty command")]; } - b"GET" => { - if parts.len() < 2 { - return RespValue::error("ERR wrong number of arguments for 'get'"); - } - if let Some(key) = bulk_bytes(&parts[1]) { - match state.lock().await.get(&key) { - Some(value) => RespValue::BulkString(Bytes::copy_from_slice(value)), - None => RespValue::NullBulk, - } - } else { - RespValue::error("ERR invalid arguments") + if let Some(redirect) = self.maybe_redirect(&parts).await { + return vec![redirect]; + } + let name = upper_name(&parts[0]); + match name.as_slice() { + b"PING" => vec![RespValue::SimpleString(Bytes::from_static(b"PONG"))], + b"SET" => self.handle_set(&parts).await, + b"GET" => self.handle_get(&parts).await, + b"MGET" => self.handle_mget(&parts).await, + b"CLUSTER" => self.handle_cluster(&parts).await, + b"ASKING" => vec![RespValue::SimpleString(Bytes::from_static(b"OK"))], + b"SUBSCRIBE" => self.handle_subscribe(&parts).await, + b"UNSUBSCRIBE" => self.handle_unsubscribe(&parts).await, + _ => vec![RespValue::error("ERR unknown command")], + } + } + + async fn handle_set(&self, parts: &[RespValue]) -> Vec { + if parts.len() < 3 { + return vec![RespValue::error("ERR wrong number of arguments for 'set'")]; + } + if let (Some(key), Some(value)) = (bulk_bytes(&parts[1]), bulk_bytes(&parts[2])) { + let mut guard = self.state.lock().await; + guard.insert(key, value); + vec![RespValue::SimpleString(Bytes::from_static(b"OK"))] + } else { + vec![RespValue::error("ERR invalid arguments")] + } + } + + async fn handle_get(&self, parts: &[RespValue]) -> Vec { + if parts.len() < 2 { + return vec![RespValue::error("ERR wrong number of arguments for 'get'")]; + } + if let Some(key) = bulk_bytes(&parts[1]) { + let guard = self.state.lock().await; + match guard.get(&key) { + Some(value) => vec![RespValue::BulkString(Bytes::copy_from_slice(value))], + None => 
vec![RespValue::NullBulk], } + } else { + vec![RespValue::error("ERR invalid arguments")] } - b"MGET" => { - let guard = state.lock().await; - let mut values = Vec::new(); - for item in parts.iter().skip(1) { - if let Some(key) = bulk_bytes(item) { - if let Some(value) = guard.get(&key) { - values.push(RespValue::BulkString(Bytes::copy_from_slice(value))); - } else { - values.push(RespValue::NullBulk); - } + } + + async fn handle_mget(&self, parts: &[RespValue]) -> Vec { + let guard = self.state.lock().await; + let mut values = Vec::new(); + for item in parts.iter().skip(1) { + if let Some(key) = bulk_bytes(item) { + if let Some(value) = guard.get(&key) { + values.push(RespValue::BulkString(Bytes::copy_from_slice(value))); } else { values.push(RespValue::NullBulk); } + } else { + values.push(RespValue::NullBulk); } - RespValue::Array(values) } - b"CLUSTER" if parts + vec![RespValue::Array(values)] + } + + async fn handle_cluster(&self, parts: &[RespValue]) -> Vec { + if parts .get(1) .and_then(bulk_bytes) .map(|v| v.eq_ignore_ascii_case(b"SLOTS")) - .unwrap_or(false) => + .unwrap_or(false) { - match slots.read().await.clone() { - Some(layout) => layout, - None => RespValue::error("ERR slots unavailable"), + match self.slots.read().await.clone() { + Some(layout) => vec![layout], + None => vec![RespValue::error("ERR slots unavailable")], + } + } else { + vec![RespValue::error("ERR unknown subcommand")] + } + } + + async fn handle_subscribe(&mut self, parts: &[RespValue]) -> Vec { + if parts.len() < 2 { + return vec![RespValue::error( + "ERR wrong number of arguments for 'subscribe'", + )]; + } + let mut responses = Vec::new(); + for item in parts.iter().skip(1) { + if let Some(channel) = bulk_bytes(item) { + self.channels.insert(channel.clone()); + self.register_channel(channel.clone()).await; + responses.push(subscription_ack( + b"subscribe", + channel, + self.channels.len() as i64, + )); } } - b"ASKING" => RespValue::SimpleString(Bytes::from_static(b"OK")), - _ => RespValue::error("ERR unknown command"), + responses + } + + async fn handle_unsubscribe(&mut self, parts: &[RespValue]) -> Vec { + if parts.len() == 1 { + let mut responses = Vec::new(); + let channels = self.channels.clone(); + for channel in channels.iter() { + self.channels.remove(channel); + self.unregister_channel(channel).await; + responses.push(subscription_ack( + b"unsubscribe", + channel.clone(), + self.channels.len() as i64, + )); + } + responses + } else { + let mut responses = Vec::new(); + for item in parts.iter().skip(1) { + if let Some(channel) = bulk_bytes(item) { + self.channels.remove(&channel); + self.unregister_channel(&channel).await; + responses.push(subscription_ack( + b"unsubscribe", + channel, + self.channels.len() as i64, + )); + } + } + responses + } + } + + async fn maybe_redirect(&self, parts: &[RespValue]) -> Option { + let mut guard = self.redirects.write().await; + for key in extract_keys(parts) { + if let Some(queue) = guard.get_mut(&key) { + if let Some(rule) = queue.pop_front() { + if queue.is_empty() { + guard.remove(&key); + } + return Some(rule.into_resp(&key)); + } + } + } + None + } + + async fn register_channel(&self, channel: Vec) { + let mut guard = self.subscriptions.lock().await; + let entry = guard.entry(channel).or_insert_with(Vec::new); + entry.push(Subscriber { + id: self.subscriber_id, + sender: self.sender.clone(), + }); + } + + async fn unregister_channel(&self, channel: &[u8]) { + let mut guard = self.subscriptions.lock().await; + if let Some(entries) = guard.get_mut(channel) 
{ + entries.retain(|entry| entry.id != self.subscriber_id); + if entries.is_empty() { + guard.remove(channel); + } + } + } + + async fn cleanup(&mut self) { + let channels = std::mem::take(&mut self.channels); + for channel in channels { + self.unregister_channel(&channel).await; + } } } +fn subscription_ack(kind: &[u8], channel: Vec, count: i64) -> RespValue { + RespValue::Array(vec![ + RespValue::BulkString(Bytes::copy_from_slice(kind)), + RespValue::BulkString(Bytes::copy_from_slice(&channel)), + RespValue::Integer(count), + ]) +} + +fn extract_keys(parts: &[RespValue]) -> Vec> { + if parts.is_empty() { + return Vec::new(); + } + let name = upper_name(&parts[0]); + match name.as_slice() { + b"SET" | b"GET" | b"SUBSCRIBE" | b"UNSUBSCRIBE" => { + parts.iter().skip(1).filter_map(bulk_bytes).collect() + } + b"MGET" => parts.iter().skip(1).filter_map(bulk_bytes).collect(), + _ => Vec::new(), + } +} + +#[derive(Clone)] +struct Subscriber { + id: u64, + sender: mpsc::UnboundedSender, +} + +#[derive(Clone)] +struct FakeRedirect { + kind: FakeRedirectKind, + target: SocketAddr, +} + +impl FakeRedirect { + fn into_resp(self, key: &[u8]) -> RespValue { + let slot = crc16(key) % SLOT_COUNT; + match self.kind { + FakeRedirectKind::Moved => { + let payload = format!("MOVED {} {}", slot, self.target); + RespValue::error(payload) + } + FakeRedirectKind::Ask => { + let payload = format!("ASK {} {}", slot, self.target); + RespValue::error(payload) + } + } + } +} + +#[derive(Clone, Copy)] +enum FakeRedirectKind { + Moved, + Ask, +} + fn upper_name(value: &RespValue) -> Vec { match value { RespValue::BulkString(data) | RespValue::SimpleString(data) => { @@ -499,7 +914,7 @@ async fn wait_for_cluster_ready( ) -> Result<()> { for _ in 0..20 { match send_command(client, vec![&b"GET"[..], key]).await? { - RespValue::Error(ref err) if err.as_ref().starts_with(b"ERR slot") => { + RespValue::Error(ref err) if cluster_not_ready(err.as_ref()) => { sleep(Duration::from_millis(50)).await; continue; } @@ -511,3 +926,7 @@ async fn wait_for_cluster_ready( String::from_utf8_lossy(key) )) } + +fn cluster_not_ready(message: &[u8]) -> bool { + message.starts_with(b"ERR slot") || message.starts_with(b"MOVED") || message.starts_with(b"ASK") +}
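
For reference, a minimal standalone sketch (not part of the patch) of the redirect-error shapes the fake server emits and the prefix check that cluster_not_ready performs. The helper name is_redirect_or_unready and the slot numbers/addresses below are illustrative only; the error format mirrors what FakeRedirect::into_resp builds above.

// Sketch: classify proxy error strings the way the e2e tests do.
// Assumes the "MOVED <slot> <host:port>" / "ASK <slot> <host:port>" layout
// produced by FakeRedirect::into_resp; slot 866 is an arbitrary example.
fn is_redirect_or_unready(message: &[u8]) -> bool {
    message.starts_with(b"ERR slot")
        || message.starts_with(b"MOVED")
        || message.starts_with(b"ASK")
}

fn main() {
    assert!(is_redirect_or_unready(b"MOVED 866 127.0.0.1:7100"));
    assert!(is_redirect_or_unready(b"ASK 866 127.0.0.1:7200"));
    assert!(!is_redirect_or_unready(b"ERR unknown command"));
}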