Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::{anyhow, bail, Context, Result};
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_util::codec::Framed;
Expand All @@ -16,7 +16,7 @@ use crate::protocol::redis::{RedisCommand, RespCodec, RespValue};
pub const DEFAULT_USER: &str = "default";

/// 前端 ACL 配置,兼容旧版简单密码写法。
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FrontendAuthConfig {
/// 旧版 `password = "xxx"` 样式。
Expand All @@ -25,7 +25,7 @@ pub enum FrontendAuthConfig {
Detailed(FrontendAuthTable),
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FrontendAuthTable {
/// 简化写法:纯密码等价于 default 用户。
#[serde(default)]
Expand All @@ -35,7 +35,7 @@ pub struct FrontendAuthTable {
pub users: Vec<AuthUserConfig>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AuthUserConfig {
pub username: String,
pub password: String,
Expand Down Expand Up @@ -64,7 +64,7 @@ impl FrontendAuthConfig {
}

/// 后端认证配置,支持 ACL 写法与旧式密码。
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum BackendAuthConfig {
Password(String),
Expand Down
88 changes: 62 additions & 26 deletions src/cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tracing::{debug, info, warn};
use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator};
use crate::backend::client::{ClientId, FrontConnectionGuard};
use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand};
use crate::config::ClusterConfig;
use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager};
use crate::info::{InfoContext, ProxyMode};
use crate::metrics;
use crate::protocol::redis::{
Expand Down Expand Up @@ -50,27 +50,29 @@ pub struct ClusterProxy {
slots: Arc<watch::Sender<SlotMap>>,
pool: Arc<ConnectionPool<RedisCommand>>,
fetch_trigger: mpsc::UnboundedSender<()>,
backend_timeout: Duration,
runtime: Arc<ClusterRuntime>,
config_manager: Arc<ConfigManager>,
listen_port: u16,
seed_nodes: usize,
}

impl ClusterProxy {
pub async fn new(config: &ClusterConfig) -> Result<Self> {
pub async fn new(
config: &ClusterConfig,
runtime: Arc<ClusterRuntime>,
config_manager: Arc<ConfigManager>,
) -> Result<Self> {
let cluster: Arc<str> = config.name.clone().into();
let hash_tag = config.hash_tag.as_ref().map(|tag| tag.as_bytes().to_vec());
let read_from_slave = config.read_from_slave.unwrap_or(false);

let (slot_tx, _slot_rx) = watch::channel(SlotMap::new());
let (trigger_tx, trigger_rx) = mpsc::unbounded_channel();

let timeout_ms = config
.read_timeout
.or(config.write_timeout)
.unwrap_or(REQUEST_TIMEOUT_MS);
let backend_auth = config.backend_auth_config().map(BackendAuth::from);
let connector = Arc::new(ClusterConnector::new(
Duration::from_millis(timeout_ms),
runtime.clone(),
REQUEST_TIMEOUT_MS,
backend_auth.clone(),
));
let pool = Arc::new(ConnectionPool::new(cluster.clone(), connector.clone()));
Expand All @@ -90,7 +92,8 @@ impl ClusterProxy {
slots: Arc::new(slot_tx),
pool: pool.clone(),
fetch_trigger: trigger_tx.clone(),
backend_timeout: Duration::from_millis(timeout_ms),
runtime,
config_manager,
listen_port,
seed_nodes: config.servers.len(),
};
Expand Down Expand Up @@ -268,6 +271,19 @@ impl ClusterProxy {
}
}
}
if let Some(response) = self.try_handle_config(&cmd).await {
let kind_label = cmd.kind_label();
let success = !response.is_error();
metrics::front_command(
self.cluster.as_ref(),
kind_label,
success,
);
let fut = async move { response };
pending.push_back(Box::pin(fut));
inflight += 1;
continue;
}
if let Some(response) = self.try_handle_info(&cmd) {
metrics::front_command(
self.cluster.as_ref(),
Expand Down Expand Up @@ -319,6 +335,10 @@ impl ClusterProxy {
Ok(())
}

/// Offers `command` to the shared configuration manager before normal handling.
///
/// Returns exactly what `ConfigManager::handle_command` yields. A `Some(reply)`
/// is a ready-made response for the client; `None` means the manager did not
/// produce one and the caller should continue with its other handlers.
async fn try_handle_config(&self, command: &RedisCommand) -> Option<RespValue> {
    let manager = &self.config_manager;
    manager.handle_command(command).await
}

fn try_handle_info(&self, command: &RedisCommand) -> Option<RespValue> {
if !command.command_name().eq_ignore_ascii_case(b"INFO") {
return None;
Expand Down Expand Up @@ -600,15 +620,16 @@ impl ClusterProxy {
node: &BackendNode,
) -> Result<Framed<TcpStream, RespCodec>> {
let addr = node.as_str().to_string();
let stream = timeout(self.backend_timeout, TcpStream::connect(&addr))
let timeout_duration = self.runtime.request_timeout(REQUEST_TIMEOUT_MS);
let stream = timeout(timeout_duration, TcpStream::connect(&addr))
.await
.with_context(|| format!("connect to {} timed out", addr))??;
stream
.set_nodelay(true)
.with_context(|| format!("failed to set TCP_NODELAY on {}", addr))?;
let mut framed = Framed::new(stream, RespCodec::default());
if let Some(auth) = &self.backend_auth {
auth.apply_to_stream(&mut framed, self.backend_timeout, &addr)
auth.apply_to_stream(&mut framed, timeout_duration, &addr)
.await?;
}
Ok(framed)
Expand Down Expand Up @@ -819,31 +840,33 @@ fn resp_value_to_bytes(value: &RespValue) -> Option<Bytes> {

#[derive(Clone)]
struct ClusterConnector {
timeout: Duration,
runtime: Arc<ClusterRuntime>,
default_timeout_ms: u64,
backend_auth: Option<BackendAuth>,
heartbeat_interval: Duration,
slow_response_threshold: Duration,
reconnect_base_delay: Duration,
max_reconnect_attempts: usize,
}

impl ClusterConnector {
fn new(timeout: Duration, backend_auth: Option<BackendAuth>) -> Self {
let slow_response_threshold = timeout
.checked_mul(3)
.unwrap_or_else(|| Duration::from_secs(5));
fn new(
runtime: Arc<ClusterRuntime>,
default_timeout_ms: u64,
backend_auth: Option<BackendAuth>,
) -> Self {
Self {
timeout,
runtime,
default_timeout_ms,
backend_auth,
heartbeat_interval: Duration::from_secs(30),
slow_response_threshold,
reconnect_base_delay: Duration::from_millis(50),
max_reconnect_attempts: 3,
}
}

async fn open_stream(&self, address: &str) -> Result<Framed<TcpStream, RespCodec>> {
let stream = timeout(self.timeout, TcpStream::connect(address))
let timeout_duration = self.current_timeout();
let stream = timeout(timeout_duration, TcpStream::connect(address))
.await
.with_context(|| format!("connection to {} timed out", address))??;
stream
Expand All @@ -864,7 +887,7 @@ impl ClusterConnector {
}
let mut framed = Framed::new(stream, RespCodec::default());
if let Some(auth) = &self.backend_auth {
auth.apply_to_stream(&mut framed, self.timeout, address)
auth.apply_to_stream(&mut framed, timeout_duration, address)
.await?;
}
Ok(framed)
Expand All @@ -881,7 +904,8 @@ impl ClusterConnector {
info!(blocking = ?blocking, "cluster connector executing blocking candidate {name}");
}
}
timeout(self.timeout, framed.send(command.to_resp()))
let timeout_duration = self.current_timeout();
timeout(timeout_duration, framed.send(command.to_resp()))
.await
.context("timed out sending command")??;

Expand All @@ -891,7 +915,7 @@ impl ClusterConnector {
Some(Err(err)) => Err(err.into()),
None => Err(anyhow!("backend closed connection")),
},
BlockingKind::None => match timeout(self.timeout, framed.next()).await {
BlockingKind::None => match timeout(timeout_duration, framed.next()).await {
Ok(Some(Ok(value))) => Ok(value),
Ok(Some(Err(err))) => Err(err.into()),
Ok(None) => Err(anyhow!("backend closed connection")),
Expand Down Expand Up @@ -945,11 +969,12 @@ impl ClusterConnector {
use RespValue::{Array, BulkString, SimpleString};

let ping = Array(vec![BulkString(Bytes::from_static(b"PING"))]);
timeout(self.timeout, framed.send(ping))
let timeout_duration = self.current_timeout();
timeout(timeout_duration, framed.send(ping))
.await
.context("timed out sending heartbeat")??;

match timeout(self.timeout, framed.next()).await {
match timeout(timeout_duration, framed.next()).await {
Ok(Some(Ok(resp))) => match resp {
SimpleString(ref data) | BulkString(ref data)
if data.eq_ignore_ascii_case(b"PONG") =>
Expand All @@ -963,6 +988,16 @@ impl ClusterConnector {
Err(_) => Err(anyhow!("timed out waiting for heartbeat reply")),
}
}

/// Returns the request timeout to use right now.
///
/// The value is asked from the shared runtime on every call rather than cached
/// on the connector, with `default_timeout_ms` passed through as the default.
// NOTE(review): presumably `request_timeout` falls back to the supplied
// milliseconds when no runtime override is set — confirm against `ClusterRuntime`.
fn current_timeout(&self) -> Duration {
    let fallback_ms = self.default_timeout_ms;
    self.runtime.request_timeout(fallback_ms)
}

/// Returns the elapsed time above which a backend response counts as slow.
///
/// The threshold is three times the current timeout; if that multiplication
/// overflows `Duration`, a fixed five seconds is used instead.
fn slow_response_threshold(&self) -> Duration {
    match self.current_timeout().checked_mul(3) {
        Some(tripled) => tripled,
        None => Duration::from_secs(5),
    }
}
}

#[async_trait]
Expand Down Expand Up @@ -1005,10 +1040,11 @@ impl Connector<RedisCommand> for ClusterConnector {
}

if let Some(ref mut framed) = connection {
let slow_threshold = self.slow_response_threshold();
let started = Instant::now();
let result = self.execute(framed, cmd.request).await;
let elapsed = started.elapsed();
let is_slow = elapsed > self.slow_response_threshold;
let is_slow = elapsed > slow_threshold;

let mut should_drop = false;
match result {
Expand Down
Loading