From 8e05f1763e643b5b1e528f794fc4954bdb7f9f57 Mon Sep 17 00:00:00 2001 From: wayslog Date: Fri, 17 Oct 2025 15:55:20 +0800 Subject: [PATCH] feat: add config command --- src/auth/mod.rs | 10 +- src/cluster/mod.rs | 88 ++++++++--- src/config/mod.rs | 360 +++++++++++++++++++++++++++++++++++++++++- src/lib.rs | 19 ++- src/standalone/mod.rs | 82 +++++++--- 5 files changed, 491 insertions(+), 68 deletions(-) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index c4f8086..06aa855 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -5,7 +5,7 @@ use std::time::Duration; use anyhow::{anyhow, bail, Context, Result}; use bytes::Bytes; use futures::{SinkExt, StreamExt}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_util::codec::Framed; @@ -16,7 +16,7 @@ use crate::protocol::redis::{RedisCommand, RespCodec, RespValue}; pub const DEFAULT_USER: &str = "default"; /// 前端 ACL 配置,兼容旧版简单密码写法。 -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(untagged)] pub enum FrontendAuthConfig { /// 旧版 `password = "xxx"` 样式。 @@ -25,7 +25,7 @@ pub enum FrontendAuthConfig { Detailed(FrontendAuthTable), } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct FrontendAuthTable { /// 简化写法:纯密码等价于 default 用户。 #[serde(default)] @@ -35,7 +35,7 @@ pub struct FrontendAuthTable { pub users: Vec, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct AuthUserConfig { pub username: String, pub password: String, @@ -64,7 +64,7 @@ impl FrontendAuthConfig { } /// 后端认证配置,支持 ACL 写法与旧式密码。 -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(untagged)] pub enum BackendAuthConfig { Password(String), diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 6989526..438e98a 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -20,7 +20,7 @@ use tracing::{debug, info, warn}; use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator}; use crate::backend::client::{ClientId, FrontConnectionGuard}; use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; -use crate::config::ClusterConfig; +use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager}; use crate::info::{InfoContext, ProxyMode}; use crate::metrics; use crate::protocol::redis::{ @@ -50,13 +50,18 @@ pub struct ClusterProxy { slots: Arc>, pool: Arc>, fetch_trigger: mpsc::UnboundedSender<()>, - backend_timeout: Duration, + runtime: Arc, + config_manager: Arc, listen_port: u16, seed_nodes: usize, } impl ClusterProxy { - pub async fn new(config: &ClusterConfig) -> Result { + pub async fn new( + config: &ClusterConfig, + runtime: Arc, + config_manager: Arc, + ) -> Result { let cluster: Arc = config.name.clone().into(); let hash_tag = config.hash_tag.as_ref().map(|tag| tag.as_bytes().to_vec()); let read_from_slave = config.read_from_slave.unwrap_or(false); @@ -64,13 +69,10 @@ impl ClusterProxy { let (slot_tx, _slot_rx) = watch::channel(SlotMap::new()); let (trigger_tx, trigger_rx) = mpsc::unbounded_channel(); - let timeout_ms = config - .read_timeout - .or(config.write_timeout) - .unwrap_or(REQUEST_TIMEOUT_MS); let backend_auth = config.backend_auth_config().map(BackendAuth::from); let connector = Arc::new(ClusterConnector::new( - Duration::from_millis(timeout_ms), + runtime.clone(), + REQUEST_TIMEOUT_MS, backend_auth.clone(), )); let pool = Arc::new(ConnectionPool::new(cluster.clone(), connector.clone())); @@ -90,7 +92,8 @@ impl ClusterProxy { slots: Arc::new(slot_tx), pool: pool.clone(), fetch_trigger: trigger_tx.clone(), - backend_timeout: Duration::from_millis(timeout_ms), + runtime, + config_manager, listen_port, seed_nodes: config.servers.len(), }; @@ -268,6 +271,19 @@ impl ClusterProxy { } } } + if let Some(response) = self.try_handle_config(&cmd).await { + let kind_label = cmd.kind_label(); + let success = !response.is_error(); + metrics::front_command( + self.cluster.as_ref(), + kind_label, + success, + ); + let fut = async move { response }; + pending.push_back(Box::pin(fut)); + inflight += 1; + continue; + } if let Some(response) = self.try_handle_info(&cmd) { metrics::front_command( self.cluster.as_ref(), @@ -319,6 +335,10 @@ impl ClusterProxy { Ok(()) } + async fn try_handle_config(&self, command: &RedisCommand) -> Option { + self.config_manager.handle_command(command).await + } + fn try_handle_info(&self, command: &RedisCommand) -> Option { if !command.command_name().eq_ignore_ascii_case(b"INFO") { return None; @@ -600,7 +620,8 @@ impl ClusterProxy { node: &BackendNode, ) -> Result> { let addr = node.as_str().to_string(); - let stream = timeout(self.backend_timeout, TcpStream::connect(&addr)) + let timeout_duration = self.runtime.request_timeout(REQUEST_TIMEOUT_MS); + let stream = timeout(timeout_duration, TcpStream::connect(&addr)) .await .with_context(|| format!("connect to {} timed out", addr))??; stream @@ -608,7 +629,7 @@ impl ClusterProxy { .with_context(|| format!("failed to set TCP_NODELAY on {}", addr))?; let mut framed = Framed::new(stream, RespCodec::default()); if let Some(auth) = &self.backend_auth { - auth.apply_to_stream(&mut framed, self.backend_timeout, &addr) + auth.apply_to_stream(&mut framed, timeout_duration, &addr) .await?; } Ok(framed) @@ -819,31 +840,33 @@ fn resp_value_to_bytes(value: &RespValue) -> Option { #[derive(Clone)] struct ClusterConnector { - timeout: Duration, + runtime: Arc, + default_timeout_ms: u64, backend_auth: Option, heartbeat_interval: Duration, - slow_response_threshold: Duration, reconnect_base_delay: Duration, max_reconnect_attempts: usize, } impl ClusterConnector { - fn new(timeout: Duration, backend_auth: Option) -> Self { - let slow_response_threshold = timeout - .checked_mul(3) - .unwrap_or_else(|| Duration::from_secs(5)); + fn new( + runtime: Arc, + default_timeout_ms: u64, + backend_auth: Option, + ) -> Self { Self { - timeout, + runtime, + default_timeout_ms, backend_auth, heartbeat_interval: Duration::from_secs(30), - slow_response_threshold, reconnect_base_delay: Duration::from_millis(50), max_reconnect_attempts: 3, } } async fn open_stream(&self, address: &str) -> Result> { - let stream = timeout(self.timeout, TcpStream::connect(address)) + let timeout_duration = self.current_timeout(); + let stream = timeout(timeout_duration, TcpStream::connect(address)) .await .with_context(|| format!("connection to {} timed out", address))??; stream @@ -864,7 +887,7 @@ impl ClusterConnector { } let mut framed = Framed::new(stream, RespCodec::default()); if let Some(auth) = &self.backend_auth { - auth.apply_to_stream(&mut framed, self.timeout, address) + auth.apply_to_stream(&mut framed, timeout_duration, address) .await?; } Ok(framed) @@ -881,7 +904,8 @@ impl ClusterConnector { info!(blocking = ?blocking, "cluster connector executing blocking candidate {name}"); } } - timeout(self.timeout, framed.send(command.to_resp())) + let timeout_duration = self.current_timeout(); + timeout(timeout_duration, framed.send(command.to_resp())) .await .context("timed out sending command")??; @@ -891,7 +915,7 @@ impl ClusterConnector { Some(Err(err)) => Err(err.into()), None => Err(anyhow!("backend closed connection")), }, - BlockingKind::None => match timeout(self.timeout, framed.next()).await { + BlockingKind::None => match timeout(timeout_duration, framed.next()).await { Ok(Some(Ok(value))) => Ok(value), Ok(Some(Err(err))) => Err(err.into()), Ok(None) => Err(anyhow!("backend closed connection")), @@ -945,11 +969,12 @@ impl ClusterConnector { use RespValue::{Array, BulkString, SimpleString}; let ping = Array(vec![BulkString(Bytes::from_static(b"PING"))]); - timeout(self.timeout, framed.send(ping)) + let timeout_duration = self.current_timeout(); + timeout(timeout_duration, framed.send(ping)) .await .context("timed out sending heartbeat")??; - match timeout(self.timeout, framed.next()).await { + match timeout(timeout_duration, framed.next()).await { Ok(Some(Ok(resp))) => match resp { SimpleString(ref data) | BulkString(ref data) if data.eq_ignore_ascii_case(b"PONG") => @@ -963,6 +988,16 @@ impl ClusterConnector { Err(_) => Err(anyhow!("timed out waiting for heartbeat reply")), } } + + fn current_timeout(&self) -> Duration { + self.runtime.request_timeout(self.default_timeout_ms) + } + + fn slow_response_threshold(&self) -> Duration { + self.current_timeout() + .checked_mul(3) + .unwrap_or_else(|| Duration::from_secs(5)) + } } #[async_trait] @@ -1005,10 +1040,11 @@ impl Connector for ClusterConnector { } if let Some(ref mut framed) = connection { + let slow_threshold = self.slow_response_threshold(); let started = Instant::now(); let result = self.execute(framed, cmd.request).await; let elapsed = started.elapsed(); - let is_slow = elapsed > self.slow_response_threshold; + let is_slow = elapsed > slow_threshold; let mut should_drop = false; match result { diff --git a/src/config/mod.rs b/src/config/mod.rs index 167ba1b..c198dd0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,18 +1,24 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::env; -use std::path::Path; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicI64, Ordering}; +use std::sync::Arc; -use anyhow::{bail, Context, Result}; -use serde::Deserialize; +use anyhow::{anyhow, bail, Context, Result}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; use tokio::fs; +use tracing::{info, warn}; use crate::auth::{AuthUserConfig, BackendAuthConfig, FrontendAuthConfig}; +use crate::protocol::redis::{RedisCommand, RespValue}; /// Environment variable controlling the default worker thread count when a /// cluster omits the `thread` field. pub const ENV_DEFAULT_THREADS: &str = "ASTER_DEFAULT_THREAD"; +const DUMP_VALUE_DEFAULT: &str = "default"; -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { #[serde(default)] clusters: Vec, @@ -69,7 +75,7 @@ impl Config { } } -#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum CacheType { Redis, @@ -82,7 +88,7 @@ impl Default for CacheType { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct ClusterConfig { pub name: String, pub listen_addr: String, @@ -217,3 +223,343 @@ fn default_worker_threads() -> usize { .map(|nz| nz.get()) .unwrap_or(4) } + +#[derive(Debug)] +pub struct ClusterRuntime { + read_timeout_ms: AtomicI64, + write_timeout_ms: AtomicI64, +} + +impl ClusterRuntime { + fn new(read_timeout: Option, write_timeout: Option) -> Self { + Self { + read_timeout_ms: AtomicI64::new(option_to_atomic(read_timeout)), + write_timeout_ms: AtomicI64::new(option_to_atomic(write_timeout)), + } + } + + pub fn read_timeout(&self) -> Option { + atomic_to_option(self.read_timeout_ms.load(Ordering::Relaxed)) + } + + pub fn write_timeout(&self) -> Option { + atomic_to_option(self.write_timeout_ms.load(Ordering::Relaxed)) + } + + pub fn set_read_timeout(&self, value: Option) { + self.read_timeout_ms + .store(option_to_atomic(value), Ordering::Relaxed); + } + + pub fn set_write_timeout(&self, value: Option) { + self.write_timeout_ms + .store(option_to_atomic(value), Ordering::Relaxed); + } + + pub fn request_timeout(&self, default_ms: u64) -> std::time::Duration { + std::time::Duration::from_millis(self.request_timeout_ms(default_ms)) + } + + pub fn request_timeout_ms(&self, default_ms: u64) -> u64 { + if let Some(value) = self.read_timeout() { + value + } else if let Some(value) = self.write_timeout() { + value + } else { + default_ms + } + } +} + +fn option_to_atomic(value: Option) -> i64 { + match value { + Some(v) => v as i64, + None => -1, + } +} + +fn atomic_to_option(value: i64) -> Option { + if value < 0 { + None + } else { + Some(value as u64) + } +} + +#[derive(Debug, Clone)] +struct ClusterEntry { + index: usize, + runtime: Arc, +} + +#[derive(Debug)] +pub struct ConfigManager { + path: PathBuf, + config: RwLock, + clusters: HashMap, +} + +impl ConfigManager { + pub fn new(path: PathBuf, config: &Config) -> Self { + let mut clusters = HashMap::new(); + for (index, cluster) in config.clusters().iter().enumerate() { + let key = cluster.name.to_ascii_lowercase(); + clusters.insert( + key, + ClusterEntry { + index, + runtime: Arc::new(ClusterRuntime::new( + cluster.read_timeout, + cluster.write_timeout, + )), + }, + ); + } + + Self { + path, + config: RwLock::new(config.clone()), + clusters, + } + } + + pub fn runtime_for(&self, name: &str) -> Option> { + self.clusters + .get(&name.to_ascii_lowercase()) + .map(|entry| entry.runtime.clone()) + } + + pub async fn handle_command(&self, command: &RedisCommand) -> Option { + if !command.command_name().eq_ignore_ascii_case(b"CONFIG") { + return None; + } + + let args = command.args(); + if args.len() < 2 { + return Some(err_response( + "wrong number of arguments for 'config' command", + )); + } + + let sub = args[1].to_vec().to_ascii_uppercase(); + match sub.as_slice() { + b"GET" => Some(self.handle_get(args)), + b"SET" => Some(self.handle_set(args)), + b"DUMP" => Some(self.handle_dump(args)), + b"REWRITE" => Some(self.handle_rewrite(args).await), + other => Some(err_response(format!( + "unsupported config subcommand '{}'", + String::from_utf8_lossy(other).to_ascii_lowercase() + ))), + } + } + + fn handle_get(&self, args: &[bytes::Bytes]) -> RespValue { + if args.len() != 3 { + return err_response("wrong number of arguments for 'config get' command"); + } + let pattern = String::from_utf8_lossy(&args[2]).to_string(); + let entries = self.matching_entries(&pattern); + RespValue::array(flatten_pairs(entries)) + } + + fn handle_set(&self, args: &[bytes::Bytes]) -> RespValue { + if args.len() != 4 { + return err_response("wrong number of arguments for 'config set' command"); + } + + let key = String::from_utf8_lossy(&args[2]).to_string(); + let value = String::from_utf8_lossy(&args[3]).to_string(); + match self.apply_set(&key, &value) { + Ok(()) => RespValue::simple("OK"), + Err(err) => err_response(err.to_string()), + } + } + + fn handle_dump(&self, args: &[bytes::Bytes]) -> RespValue { + if args.len() != 2 { + return err_response("wrong number of arguments for 'config dump' command"); + } + let entries = self.all_entries(); + RespValue::array(flatten_pairs(entries)) + } + + async fn handle_rewrite(&self, args: &[bytes::Bytes]) -> RespValue { + if args.len() != 2 { + return err_response("wrong number of arguments for 'config rewrite' command"); + } + match self.rewrite().await { + Ok(()) => RespValue::simple("OK"), + Err(err) => { + warn!(error = %err, "failed to rewrite configuration file"); + err_response(err.to_string()) + } + } + } + + fn apply_set(&self, key: &str, value: &str) -> Result<()> { + let (cluster_name, field) = parse_key(key)?; + let cluster_key = cluster_name.to_ascii_lowercase(); + let entry = self + .clusters + .get(&cluster_key) + .ok_or_else(|| anyhow!("unknown cluster '{}'", cluster_name))? + .clone(); + + match field { + ClusterField::ReadTimeout => { + let parsed = parse_timeout_value(value)?; + entry.runtime.set_read_timeout(parsed); + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].read_timeout = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster read_timeout updated via CONFIG SET" + ); + } + ClusterField::WriteTimeout => { + let parsed = parse_timeout_value(value)?; + entry.runtime.set_write_timeout(parsed); + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].write_timeout = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster write_timeout updated via CONFIG SET" + ); + } + } + Ok(()) + } + + fn matching_entries(&self, pattern: &str) -> Vec<(String, String)> { + let pattern_lower = pattern.to_ascii_lowercase(); + self.all_entries() + .into_iter() + .filter(|(key, _)| wildcard_match(&pattern_lower, &key.to_ascii_lowercase())) + .collect() + } + + fn all_entries(&self) -> Vec<(String, String)> { + let guard = self.config.read(); + let mut entries = Vec::new(); + for cluster in guard.clusters() { + let name = &cluster.name; + let key = name.to_ascii_lowercase(); + if let Some(entry) = self.clusters.get(&key) { + let runtime = entry.runtime.clone(); + entries.push(( + format!("cluster.{}.read-timeout", name), + option_to_string(runtime.read_timeout()), + )); + entries.push(( + format!("cluster.{}.write-timeout", name), + option_to_string(runtime.write_timeout()), + )); + } + } + entries.sort_by(|a, b| a.0.cmp(&b.0)); + entries + } + + async fn rewrite(&self) -> Result<()> { + let snapshot = { + let guard = self.config.read(); + toml::to_string_pretty(&*guard)? + }; + fs::write(&self.path, snapshot) + .await + .with_context(|| format!("failed to persist configuration to {}", self.path.display())) + } +} + +fn parse_key(key: &str) -> Result<(String, ClusterField)> { + let mut parts = key.splitn(3, '.'); + let scope = parts + .next() + .ok_or_else(|| anyhow!("invalid config parameter '{}'", key))?; + if !scope.eq_ignore_ascii_case("cluster") { + bail!("unsupported config parameter '{}'", key); + } + let cluster = parts + .next() + .ok_or_else(|| anyhow!("config parameter missing cluster name '{}'", key))?; + let field = parts + .next() + .ok_or_else(|| anyhow!("config parameter missing field '{}'", key))?; + let field = match field.to_ascii_lowercase().as_str() { + "read-timeout" => ClusterField::ReadTimeout, + "write-timeout" => ClusterField::WriteTimeout, + unknown => bail!("unknown cluster field '{}'", unknown), + }; + Ok((cluster.to_string(), field)) +} + +fn parse_timeout_value(value: &str) -> Result> { + if value.eq_ignore_ascii_case(DUMP_VALUE_DEFAULT) { + return Ok(None); + } + let trimmed = value.trim(); + let parsed: u64 = trimmed + .parse() + .with_context(|| format!("invalid timeout value '{}'", value))?; + Ok(Some(parsed)) +} + +fn option_to_string(value: Option) -> String { + value + .map(|v| v.to_string()) + .unwrap_or_else(|| DUMP_VALUE_DEFAULT.to_string()) +} + +fn flatten_pairs(entries: Vec<(String, String)>) -> Vec { + let mut values = Vec::with_capacity(entries.len() * 2); + for (key, value) in entries { + values.push(RespValue::bulk(key)); + values.push(RespValue::bulk(value)); + } + values +} + +fn err_response(message: T) -> RespValue { + let payload = format!("ERR {}", message.to_string()); + RespValue::error(payload) +} + +enum ClusterField { + ReadTimeout, + WriteTimeout, +} + +fn wildcard_match(pattern: &str, target: &str) -> bool { + let pattern = pattern.as_bytes(); + let target = target.as_bytes(); + let mut p = 0usize; + let mut t = 0usize; + let mut star = None; + let mut match_idx = 0usize; + + while t < target.len() { + if p < pattern.len() && (pattern[p] == target[t] || pattern[p] == b'?') { + p += 1; + t += 1; + } else if p < pattern.len() && pattern[p] == b'*' { + star = Some(p); + match_idx = t; + p += 1; + } else if let Some(star_idx) = star { + p = star_idx + 1; + match_idx += 1; + t = match_idx; + } else { + return false; + } + } + + while p < pattern.len() && pattern[p] == b'*' { + p += 1; + } + + p == pattern.len() +} diff --git a/src/lib.rs b/src/lib.rs index 685a3e5..f870cb5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ use std::path::PathBuf; use std::sync::Arc; -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use clap::Parser; use tokio::net::{TcpListener, TcpStream}; @@ -27,7 +27,7 @@ pub mod standalone; pub mod utils; use crate::cluster::ClusterProxy; -use crate::config::{CacheType, Config}; +use crate::config::{CacheType, Config, ConfigManager}; use crate::meta::{derive_meta, scope_with_meta}; use crate::standalone::StandaloneProxy; @@ -107,6 +107,7 @@ async fn run_async(options: BootstrapOptions) -> Result<()> { metrics::register_version(env!("CARGO_PKG_VERSION")); let config = Config::load(&options.config).await?; + let config_manager = Arc::new(ConfigManager::new(options.config.clone(), &config)); info!( clusters = config.clusters().len(), @@ -125,6 +126,9 @@ async fn run_async(options: BootstrapOptions) -> Result<()> { let metrics_handles = metrics::spawn_background_tasks(options.metrics_port); for cluster_cfg in config.clusters().iter().cloned() { + let runtime = config_manager + .runtime_for(&cluster_cfg.name) + .ok_or_else(|| anyhow!("missing runtime state for cluster {}", cluster_cfg.name))?; let listen_addr = cluster_cfg.listen_addr.clone(); let listener = TcpListener::bind(&listen_addr) .await @@ -146,7 +150,11 @@ async fn run_async(options: BootstrapOptions) -> Result<()> { match cluster_cfg.cache_type { CacheType::Redis => { - let proxy = Arc::new(StandaloneProxy::new(&cluster_cfg)?); + let proxy = Arc::new(StandaloneProxy::new( + &cluster_cfg, + runtime.clone(), + config_manager.clone(), + )?); let listener = listener; let meta = meta.clone(); let cluster_label = cluster_label.clone(); @@ -155,7 +163,10 @@ async fn run_async(options: BootstrapOptions) -> Result<()> { })); } CacheType::RedisCluster => { - let proxy = Arc::new(ClusterProxy::new(&cluster_cfg).await?); + let proxy = Arc::new( + ClusterProxy::new(&cluster_cfg, runtime.clone(), config_manager.clone()) + .await?, + ); let listener = listener; let meta = meta.clone(); let cluster_label = cluster_label.clone(); diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index 63779a9..2bbae64 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -20,7 +20,7 @@ use tracing::{debug, info, warn}; use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator}; use crate::backend::client::{ClientId, FrontConnectionGuard}; use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; -use crate::config::ClusterConfig; +use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager}; use crate::info::{InfoContext, ProxyMode}; use crate::metrics; use crate::protocol::redis::{ @@ -47,13 +47,18 @@ pub struct StandaloneProxy { auth: Option>, backend_auth: Option, pool: Arc>, - backend_timeout: Duration, + runtime: Arc, + config_manager: Arc, listen_port: u16, backend_nodes: usize, } impl StandaloneProxy { - pub fn new(config: &ClusterConfig) -> Result { + pub fn new( + config: &ClusterConfig, + runtime: Arc, + config_manager: Arc, + ) -> Result { let cluster: Arc = config.name.clone().into(); let hash_tag = config.hash_tag.as_ref().map(|tag| tag.as_bytes().to_vec()); let nodes = parse_servers(&config.servers)?; @@ -65,13 +70,10 @@ impl StandaloneProxy { } let ring = build_ring(&nodes); - let timeout_ms = config - .read_timeout - .or(config.write_timeout) - .unwrap_or(DEFAULT_TIMEOUT_MS); let backend_auth = config.backend_auth_config().map(BackendAuth::from); let connector = Arc::new(RedisConnector::new( - Duration::from_millis(timeout_ms), + runtime.clone(), + DEFAULT_TIMEOUT_MS, backend_auth.clone(), )); let auth = config @@ -91,7 +93,8 @@ impl StandaloneProxy { auth, backend_auth, pool, - backend_timeout: Duration::from_millis(timeout_ms), + runtime, + config_manager, listen_port, backend_nodes, }) @@ -318,7 +321,8 @@ impl StandaloneProxy { node: &BackendNode, ) -> Result> { let addr = node.as_str().to_string(); - let stream = timeout(self.backend_timeout, TcpStream::connect(&addr)) + let timeout_duration = self.runtime.request_timeout(DEFAULT_TIMEOUT_MS); + let stream = timeout(timeout_duration, TcpStream::connect(&addr)) .await .with_context(|| format!("connect to {} timed out", addr))??; stream @@ -326,7 +330,7 @@ impl StandaloneProxy { .with_context(|| format!("failed to set TCP_NODELAY on {}", addr))?; let mut framed = Framed::new(stream, RespCodec::default()); if let Some(auth) = &self.backend_auth { - auth.apply_to_stream(&mut framed, self.backend_timeout, &addr) + auth.apply_to_stream(&mut framed, timeout_duration, &addr) .await?; } Ok(framed) @@ -417,6 +421,13 @@ impl StandaloneProxy { } } + if let Some(response) = self.try_handle_config(&command).await { + let success = !response.is_error(); + metrics::front_command(self.cluster.as_ref(), kind_label, success); + framed.send(response).await?; + continue; + } + if let Some(response) = self.try_handle_info(&command) { metrics::front_command(self.cluster.as_ref(), kind_label, true); framed.send(response).await?; @@ -441,6 +452,10 @@ impl StandaloneProxy { Ok(()) } + async fn try_handle_config(&self, command: &RedisCommand) -> Option { + self.config_manager.handle_command(command).await + } + fn try_handle_info(&self, command: &RedisCommand) -> Option { if !command.command_name().eq_ignore_ascii_case(b"INFO") { return None; @@ -585,24 +600,25 @@ fn subscription_action_kind(kind: &[u8]) -> Option { #[derive(Clone)] struct RedisConnector { - timeout: Duration, + runtime: Arc, + default_timeout_ms: u64, reconnect_delay: Duration, max_reconnect_delay: Duration, - slow_response_threshold: Duration, heartbeat_interval: Duration, backend_auth: Option, } impl RedisConnector { - fn new(timeout: Duration, backend_auth: Option) -> Self { - let slow_response_threshold = timeout - .checked_mul(4) - .unwrap_or_else(|| Duration::from_secs(4)); + fn new( + runtime: Arc, + default_timeout_ms: u64, + backend_auth: Option, + ) -> Self { Self { - timeout, + runtime, + default_timeout_ms, reconnect_delay: Duration::from_millis(100), max_reconnect_delay: Duration::from_secs(2), - slow_response_threshold, heartbeat_interval: Duration::from_secs(20), backend_auth, } @@ -610,7 +626,8 @@ impl RedisConnector { async fn open_stream(&self, node: &BackendNode) -> Result> { let connect_target = node.as_str().to_string(); - let stream = timeout(self.timeout, TcpStream::connect(&connect_target)) + let timeout_duration = self.current_timeout(); + let stream = timeout(timeout_duration, TcpStream::connect(&connect_target)) .await .with_context(|| format!("connect to {} timed out", connect_target))??; stream @@ -631,7 +648,7 @@ impl RedisConnector { } let mut framed = Framed::new(stream, RespCodec::default()); if let Some(auth) = &self.backend_auth { - auth.apply_to_stream(&mut framed, self.timeout, &connect_target) + auth.apply_to_stream(&mut framed, timeout_duration, &connect_target) .await?; } Ok(framed) @@ -644,7 +661,8 @@ impl RedisConnector { ) -> Result { let blocking = request.as_blocking(); let frame = request.to_resp(); - timeout(self.timeout, framed.send(frame)) + let timeout_duration = self.current_timeout(); + timeout(timeout_duration, framed.send(frame)) .await .context("timed out while sending request")??; @@ -654,7 +672,7 @@ impl RedisConnector { Some(Err(err)) => Err(err.into()), None => Err(anyhow!("backend closed connection")), }, - BlockingKind::None => match timeout(self.timeout, framed.next()).await { + BlockingKind::None => match timeout(timeout_duration, framed.next()).await { Ok(Some(Ok(response))) => Ok(response), Ok(Some(Err(err))) => Err(err.into()), Ok(None) => Err(anyhow!("backend closed connection")), @@ -667,11 +685,12 @@ impl RedisConnector { use RespValue::{Array, BulkString, SimpleString}; let ping = Array(vec![BulkString(Bytes::from_static(b"PING"))]); - timeout(self.timeout, framed.send(ping)) + let timeout_duration = self.current_timeout(); + timeout(timeout_duration, framed.send(ping)) .await .context("timed out while sending heartbeat")??; - match timeout(self.timeout, framed.next()).await { + match timeout(timeout_duration, framed.next()).await { Ok(Some(Ok(resp))) => match resp { SimpleString(ref data) | BulkString(ref data) if data.eq_ignore_ascii_case(b"PONG") => @@ -692,6 +711,16 @@ impl RedisConnector { .unwrap_or_else(|| self.max_reconnect_delay); min(doubled, self.max_reconnect_delay) } + + fn current_timeout(&self) -> Duration { + self.runtime.request_timeout(self.default_timeout_ms) + } + + fn slow_response_threshold(&self) -> Duration { + self.current_timeout() + .checked_mul(4) + .unwrap_or_else(|| Duration::from_secs(4)) + } } #[async_trait] @@ -757,10 +786,11 @@ impl Connector for RedisConnector { } if let Some(ref mut framed) = connection { + let slow_threshold = self.slow_response_threshold(); let started = Instant::now(); let result = self.execute_request(framed, request).await; let elapsed = started.elapsed(); - let is_slow = elapsed > self.slow_response_threshold; + let is_slow = elapsed > slow_threshold; match result { Ok(resp) => {