diff --git a/Cargo.lock b/Cargo.lock index 4578b0ee69..743cfe1206 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1687,6 +1687,13 @@ dependencies = [ "pin-project", ] +[[package]] +name = "linkerd-ewma" +version = "0.1.0" +dependencies = [ + "tokio", +] + [[package]] name = "linkerd-exp-backoff" version = "0.1.0" @@ -1751,6 +1758,7 @@ dependencies = [ "futures", "http", "http-body", + "httpdate", "linkerd-error", "linkerd-http-box", "linkerd-stack", @@ -1979,6 +1987,24 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "linkerd-load-biaser" +version = "0.1.0" +dependencies = [ + "futures", + "http", + "linkerd-ewma", + "linkerd-http-classify", + "linkerd-stack", + "parking_lot", + "pin-project", + "tokio", + "tokio-test", + "tower", + "tower-service", + "tracing", +] + [[package]] name = "linkerd-meshtls" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 7ac6863ca9..bc2dfc2e95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "linkerd/error", "linkerd/errno", "linkerd/error-respond", + "linkerd/ewma", "linkerd/exp-backoff", "linkerd/http/access-log", "linkerd/http/body-eos", @@ -42,6 +43,7 @@ members = [ "linkerd/identity", "linkerd/idle-cache", "linkerd/io", + "linkerd/load-biaser", "linkerd/meshtls", "linkerd/meshtls/verifier", "linkerd/metrics", @@ -105,6 +107,7 @@ drain = { version = "0.2", default-features = false } h2 = { version = "0.4" } http = { version = "1" } http-body = { version = "1" } +httpdate = { version = "1.0" } hyper = { version = "1", default-features = false } prometheus-client = { version = "0.23" } prost = { version = "0.14" } diff --git a/linkerd/ewma/Cargo.toml b/linkerd/ewma/Cargo.toml new file mode 100644 index 0000000000..8933c4e5c7 --- /dev/null +++ b/linkerd/ewma/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "linkerd-ewma" +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +edition = { workspace = true } +publish = { workspace = true } + +[dependencies] 
+tokio = { version = "1", features = ["time"] } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt", "test-util"] } diff --git a/linkerd/ewma/src/lib.rs b/linkerd/ewma/src/lib.rs new file mode 100644 index 0000000000..872fcfc518 --- /dev/null +++ b/linkerd/ewma/src/lib.rs @@ -0,0 +1,476 @@ +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +use tokio::time; + +/// Minimum decay duration to prevent division-by-zero in EWMA computations. +/// A small, strictly positive floor (1 ms) chosen so that it does not override +/// validated configs from the control plane (CP should reject decay=0). +pub const MIN_DECAY: time::Duration = time::Duration::from_millis(1); + +/// An exponentially-weighted moving average. +#[derive(Debug)] +pub struct Ewma { + value: f64, + decay: f64, + timestamp: time::Instant, +} + +// === impl Ewma === + +impl Ewma { + #[must_use] + pub fn new(decay: time::Duration, timestamp: time::Instant) -> Self { + Self { + decay: decay.max(MIN_DECAY).as_secs_f64(), + timestamp, + value: f64::INFINITY, + } + } + + /// Creates an EWMA with a specific initial value. + /// + /// This constructor allows setting an initial value, useful for + /// success rate tracking where you want to start at 100% (1.0) + /// success rate. + #[must_use] + pub fn new_with_value(decay: time::Duration, timestamp: time::Instant, initial: f64) -> Self { + debug_assert!(!initial.is_nan(), "EWMA initial value must not be NaN"); + Self { + decay: decay.max(MIN_DECAY).as_secs_f64(), + timestamp, + value: initial, + } + } + + /// Resets the EWMA to a new value and timestamp. + /// + /// This overwrites the current value and timestamp, useful for + /// resetting success rate tracking after recovery from a tripped state. + /// + /// Precondition: `ts` should be <= the timestamps passed to subsequent + /// `add()` calls. 
If it is not, those `add()` calls are silently + /// dropped because `add()` discards samples whose timestamp is at or + /// before the stored one. + pub fn reset(&mut self, value: f64, ts: time::Instant) { + debug_assert!(!value.is_nan(), "EWMA reset value must not be NaN"); + self.value = value; + self.timestamp = ts; + } + + /// Returns the current value of the average. + pub fn get(&self) -> f64 { + self.value + } + + /// Returns the decayed value projected to the given time, without modifying stored state. + /// + /// Instead of returning the raw stored value, this applies exponential decay based + /// on elapsed time since the last update. This is required for load balancing where + /// stale measurements should lose influence over time. + pub fn get_at(&self, now: time::Instant) -> f64 { + debug_assert!(!self.value.is_nan(), "EWMA value must not be NaN"); + + if self.value.is_infinite() || now <= self.timestamp { + return self.value; + } + let elapsed = now.saturating_duration_since(self.timestamp); + self.value * (-elapsed.as_secs_f64() / self.decay).exp() + } + + /// Updates the weighted moving average with a new value and timestamp. + /// + /// Precondition: `value` must not be NaN. Passing NaN poisons the EWMA + /// irreversibly (all subsequent reads return NaN). + pub fn add(&mut self, value: f64, ts: time::Instant) { + debug_assert!(!value.is_nan(), "EWMA input value must not be NaN"); + if ts <= self.timestamp { + return; + } + if self.value == f64::INFINITY { + self.value = value; + self.timestamp = ts; + return; + } + + self.value = { + let elapsed = ts.saturating_duration_since(self.timestamp); + let alpha = 1.0 - (-elapsed.as_secs_f64() / self.decay).exp(); + self.value * (1.0 - alpha) + value * alpha + }; + + self.timestamp = ts; + } + + /// Updates the EWMA with a peak value, replacing the current value if the + /// new value exceeds the decayed projection. 
+ /// + /// When replacement occurs, the stored timestamp is set to `ts`, which + /// may be earlier than the previously stored timestamp. This resets the + /// decay reference point, so subsequent projections via `get_at()` measure + /// elapsed time from `ts`. + pub fn add_peak(&mut self, value: f64, ts: time::Instant) { + debug_assert!(!value.is_nan(), "EWMA peak value must not be NaN"); + if self.value.is_infinite() || self.get_at(ts) < value { + self.value = value; + self.timestamp = ts; + return; + } + self.add(value, ts) + } + + /// Computes 1/elapsed since the last update and feeds it through `add()`. + pub fn add_rate(&mut self, ts: time::Instant) { + if ts <= self.timestamp { + return; + } + let elapsed = ts.saturating_duration_since(self.timestamp); + if !elapsed.is_zero() { + self.add(1.0 / elapsed.as_secs_f64(), ts); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::{Duration, Instant}; + + // Literal value for exp(-1.0), since it's not const + const EXP_NEG1: f64 = 0.36787944117144233; + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_new() { + let now = Instant::now(); + let ewma = Ewma::new(Duration::from_secs(10), now); + assert_eq!(ewma.get(), f64::INFINITY); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add(1.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_rate() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add_rate(now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + ewma.add_rate(now + Duration::from_secs(3)); + assert_eq!(ewma.get(), 0.9093653765389909); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_peak() { + let now = Instant::now(); + 
let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add_peak(1.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + ewma.add_peak(2.0, now + Duration::from_secs_f64(1.5)); + assert_eq!(ewma.get(), 2.0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_decay() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add(1.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + ewma.add(0.0, now + Duration::from_secs(11)); + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_new_with_value() { + let now = Instant::now(); + let ewma = Ewma::new_with_value(Duration::from_secs(10), now, 1.0); + + assert_eq!(ewma.get(), 1.0); + + // Verify this behaves like a normal EWMA after initialization + let mut ewma = Ewma::new_with_value(Duration::from_secs(10), now, 1.0); + + ewma.add(0.0, now + Duration::from_secs(10)); + // After one decay period value decays towards zero. + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_peak_from_infinity_same_timestamp() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + + assert_eq!(ewma.get(), f64::INFINITY); + + // Same timestamp as construction. The first real value should always + // take effect regardless of timestamp. + ewma.add_peak(0.5, now); + assert_eq!(ewma.get(), 0.5); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_peak_replaces_decayed_value() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + + // Set initial peak of 10.0 at t=1s + ewma.add_peak(10.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 10.0); + + // After 25s of decay (t=26s), the decayed projection is: + // 10.0 * exp(-25/10) = 10.0 * exp(-2.5) = 0.8208... 
+ // Since 0.8208 < 1.0, the new value should replace the stale peak. + ewma.add_peak(1.0, now + Duration::from_secs(26)); + assert_eq!(ewma.get(), 1.0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_zero_decay_clamped_in_new() { + let now = Instant::now(); + let zero = Ewma::new(Duration::ZERO, now); + let min = Ewma::new(Duration::from_millis(1), now); + + // Both should produce identical behavior since ZERO gets clamped to 1ms + assert_eq!(zero.get(), min.get()); + + // After adding a value, get_at should produce finite, identical results + let mut zero = Ewma::new(Duration::ZERO, now); + let mut min = Ewma::new(Duration::from_millis(1), now); + zero.add(5.0, now + Duration::from_secs(1)); + min.add(5.0, now + Duration::from_secs(1)); + + let projected_zero = zero.get_at(now + Duration::from_secs(2)); + let projected_min = min.get_at(now + Duration::from_secs(2)); + + assert!( + projected_zero.is_finite(), + "get_at must be finite with clamped decay" + ); + assert!( + !projected_zero.is_nan(), + "get_at must not be NaN with clamped decay" + ); + assert_eq!(projected_zero, projected_min); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_zero_decay_clamped_in_new_with_value() { + let now = Instant::now(); + let zero = Ewma::new_with_value(Duration::ZERO, now, 1.0); + let min = Ewma::new_with_value(Duration::from_millis(1), now, 1.0); + + let projected_zero = zero.get_at(now + Duration::from_secs(1)); + let projected_min = min.get_at(now + Duration::from_secs(1)); + assert!( + projected_zero.is_finite(), + "get_at must be finite with clamped decay" + ); + assert!( + !projected_zero.is_nan(), + "get_at must not be NaN with clamped decay" + ); + assert_eq!(projected_zero, projected_min); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_zero_decay_add_is_safe() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::ZERO, now); + + 
ewma.add(5.0, now + Duration::from_secs(1)); + ewma.add(0.5, now + Duration::from_secs(2)); + + let val = ewma.get(); + assert!( + val.is_finite(), + "add() result must be finite with clamped decay" + ); + assert!( + !val.is_nan(), + "add() result must not be NaN with clamped decay" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_reset() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + + // Set state + ewma.add(0.5, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 0.5); + + // Reset to a new value + ewma.reset(1.0, now + Duration::from_secs(2)); + assert_eq!(ewma.get(), 1.0); + + // Verify EWMA continues working after reset. + ewma.add(0.0, now + Duration::from_secs(12)); + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_same_timestamp() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = add_at; + + let mut ewma = Ewma::new(decay, now); + ewma.add(0.5, add_at); + + assert_eq!(ewma.get_at(read_at), 0.5); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_past_timestamp() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_millis(500); + + let mut ewma = Ewma::new(decay, now); + ewma.add(0.5, add_at); + + assert_eq!(ewma.get_at(read_at), 0.5); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_infinity() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let probe_same = now; + let probe_near = now + Duration::from_secs(1); + let probe_far = now + Duration::from_secs(100); + + // A new Ewma without adding values should project INFINITY + // at every timestamp. 
+ let ewma = Ewma::new(decay, now); + assert_eq!(ewma.get_at(probe_same), f64::INFINITY); + assert_eq!(ewma.get_at(probe_near), f64::INFINITY); + assert_eq!(ewma.get_at(probe_far), f64::INFINITY); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_decay() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(11); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, add_at); + + // Verify that get_at applies value * exp(-elapsed/decay) correctly + // without changing internal state. + assert_eq!(ewma.get_at(read_at), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_large_elapsed() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(3600); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, add_at); + + let result = ewma.get_at(read_at); + assert!( + result.is_finite(), + "get_at at large elapsed must be finite, got {result}" + ); + assert!( + result < 1e-10, + "get_at at large elapsed must be near zero, got {result}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_non_mutation() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let first_add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(6); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, first_add_at); + + // Take internal state before the read + let value_before = ewma.value; + let timestamp_before = ewma.timestamp; + + // Read must not mutate + let _ = ewma.get_at(read_at); + + assert_eq!(ewma.value, value_before); + assert_eq!(ewma.timestamp, timestamp_before); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_after_reset() { + const EWMA_VAL: f64 = 
0.5; + const DECAY_AT_15S_VAL: f64 = EWMA_VAL * EXP_NEG1; + + let now = Instant::now(); + let decay = Duration::from_secs(10); + let first_add_at = now + Duration::from_secs(1); + let reset_at = now + Duration::from_secs(5); + let immediate_read_at = reset_at; + let decay_read_at = now + Duration::from_secs(15); + + let mut ewma = Ewma::new(decay, now); + // Use a large value here to make sure we detect any issues with + // reset not working properly, since we'd likely see a wildly + // different value in the assert below. + ewma.add(100.0, first_add_at); + + // Reset replaces both value and timestamp + ewma.reset(EWMA_VAL, reset_at); + + // Must return the freshly-reset value + assert_eq!(ewma.get_at(immediate_read_at), EWMA_VAL); + + // Read one decay period (10s) after the reset (so 15s decay). + assert_eq!(ewma.get_at(decay_read_at), DECAY_AT_15S_VAL); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_between_adds() { + const EXP_NEG0_5: f64 = 0.6065306597126334; + + let now = Instant::now(); + let decay = Duration::from_secs(10); + let first_add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(6); + let second_add_at = now + Duration::from_secs(11); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, first_add_at); + + assert_eq!(ewma.get_at(read_at), EXP_NEG0_5); + + ewma.add(0.0, second_add_at); + + // Final state after the second add must be exp(-1.0), + // ensuring get_at() doesn't change the internal state. 
+ assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + #[cfg(debug_assertions)] + #[should_panic(expected = "NaN")] + async fn test_get_at_debug_asserts_nan() { + let now = Instant::now(); + // Inject NaN via new_with_value + let ewma = Ewma::new_with_value(Duration::from_secs(10), now, f64::NAN); + + // Should trigger a debug_assert + let _ = ewma.get_at(now); + } +} diff --git a/linkerd/http/classify/Cargo.toml b/linkerd/http/classify/Cargo.toml index c7a7911e40..e5c23eb0b7 100644 --- a/linkerd/http/classify/Cargo.toml +++ b/linkerd/http/classify/Cargo.toml @@ -10,6 +10,7 @@ publish = { workspace = true } futures = { version = "0.3", default-features = false } http = { workspace = true } http-body = { workspace = true } +httpdate.workspace = true pin-project = "1" tokio = { version = "1", default-features = false } tracing = { workspace = true } diff --git a/linkerd/http/classify/src/lib.rs b/linkerd/http/classify/src/lib.rs index e25f2b7701..fcc9c5ef70 100644 --- a/linkerd/http/classify/src/lib.rs +++ b/linkerd/http/classify/src/lib.rs @@ -12,6 +12,7 @@ pub use self::{ mod channel; pub mod gate; mod insert; +pub mod retry_after; /// Determines how a request's response should be classified. pub trait Classify { diff --git a/linkerd/http/classify/src/retry_after.rs b/linkerd/http/classify/src/retry_after.rs new file mode 100644 index 0000000000..9b00ec4303 --- /dev/null +++ b/linkerd/http/classify/src/retry_after.rs @@ -0,0 +1,320 @@ +//! Shared parsing for HTTP Retry-After headers and gRPC retry-pushback-ms trailers. +//! +//! This module provides pure parsing functions with no classification or store logic. +//! Both the load-biaser and the circuit breaker can use these functions to extract +//! backoff hints from HTTP and gRPC responses. + +use http::{HeaderMap, StatusCode}; +use std::time::Duration; + +/// Parse the Retry-After header from a 429 or 503 response. 
+///
+/// Supports two formats per RFC 7231:
+/// - delay-seconds: `Retry-After: 120` -> 120 seconds
+/// - HTTP-date: `Retry-After: Tue, 21 Oct 2025 07:28:00 GMT` -> duration from now
+///
+/// Returns `None` for:
+/// - Non-429/503 responses
+/// - Missing Retry-After header
+/// - Invalid header formats
+///
+/// The returned duration is capped at `max` to prevent abuse.
+pub fn parse_retry_after(
+    status: StatusCode,
+    headers: &HeaderMap,
+    max: Duration,
+) -> Option<Duration> {
+    // Only parse for 429 and 503 responses
+    if status != StatusCode::TOO_MANY_REQUESTS && status != StatusCode::SERVICE_UNAVAILABLE {
+        return None;
+    }
+
+    let value = headers.get(http::header::RETRY_AFTER)?;
+    let s = value.to_str().ok()?;
+
+    parse_retry_after_value(s, max)
+}
+
+/// Parse a Retry-After header value string.
+///
+/// Tries delay-seconds first (most common), then HTTP-date format.
+/// The returned duration is capped at `max`.
+fn parse_retry_after_value(s: &str, max: Duration) -> Option<Duration> {
+    let s = s.trim();
+    // Try delay-seconds first (most common format)
+    if let Ok(secs) = s.parse::<u64>() {
+        let duration = Duration::from_secs(secs);
+        tracing::debug!(?duration, "Parsed Retry-After delay-seconds");
+        return Some(duration.min(max));
+    }
+
+    // Try HTTP-date format
+    if let Ok(datetime) = httpdate::parse_http_date(s) {
+        let now = std::time::SystemTime::now();
+        match datetime.duration_since(now) {
+            Ok(duration) => {
+                tracing::debug!(?duration, "Parsed Retry-After HTTP-date");
+                return Some(duration.min(max));
+            }
+            Err(_) => {
+                tracing::debug!("Retry-After HTTP-date is in the past");
+                return Some(Duration::ZERO);
+            }
+        }
+    }
+
+    tracing::debug!(%s, "Failed to parse Retry-After header");
+    None
+}
+
+/// The grpc-retry-pushback-ms header/trailer name.
+const GRPC_RETRY_PUSHBACK_MS: &str = "grpc-retry-pushback-ms";
+
+/// Parse grpc-retry-pushback-ms from headers or trailers.
+///
+/// Per gRPC A6 spec:
+/// - Non-negative i64: retry after this many milliseconds
+/// - Negative i64: do not retry (returns `None`)
+///
+/// This function does **not** check grpc-status; that is a classification
+/// concern left to the caller.
+///
+/// Returns `None` for:
+/// - Missing header/trailer
+/// - Negative values (interpreted as "do not retry")
+/// - Invalid formats
+///
+/// The returned duration is capped at `max` to prevent abuse.
+pub fn parse_grpc_retry_pushback(headers: &HeaderMap, max: Duration) -> Option<Duration> {
+    let value = headers.get(GRPC_RETRY_PUSHBACK_MS)?;
+    let s = value.to_str().ok()?;
+
+    // Parse as i64 to handle potential negative values
+    let ms: i64 = match s.trim().parse() {
+        Ok(v) => v,
+        Err(_) => {
+            tracing::debug!(%s, "Failed to parse grpc-retry-pushback-ms");
+            return None;
+        }
+    };
+
+    // Negative values mean "do not retry".
+    if ms < 0 {
+        tracing::debug!(ms, "Ignoring negative grpc-retry-pushback-ms");
+        return None;
+    }
+
+    let duration = Duration::from_millis(ms as u64);
+    tracing::debug!(?duration, "Parsed grpc-retry-pushback-ms");
+    Some(duration.min(max))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use http::header::RETRY_AFTER;
+    use http::HeaderValue;
+    use std::time::SystemTime;
+
+    const MAX: Duration = Duration::from_secs(300);
+
+    // === Retry-After tests ===
+
+    #[test]
+    fn parse_delay_seconds() {
+        let mut headers = HeaderMap::new();
+        headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("120"));
+
+        let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX);
+        assert_eq!(result, Some(Duration::from_secs(120)));
+    }
+
+    #[test]
+    fn parse_delay_seconds_zero() {
+        let mut headers = HeaderMap::new();
+        headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("0"));
+
+        let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX);
+        assert_eq!(result, Some(Duration::ZERO));
+    }
+
+    #[test]
+    fn caps_at_max() {
+        let mut headers = HeaderMap::new();
+ headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("3600")); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(MAX)); + } + + #[test] + fn parses_503() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("120")); + + let result = parse_retry_after(StatusCode::SERVICE_UNAVAILABLE, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn ignores_other_status() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("120")); + + let result = parse_retry_after(StatusCode::OK, &headers, MAX); + assert_eq!(result, None); + } + + #[test] + fn ignores_missing_header() { + let headers = HeaderMap::new(); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, None); + } + + #[test] + fn ignores_invalid_value() { + let mut headers = HeaderMap::new(); + headers.insert( + http::header::RETRY_AFTER, + HeaderValue::from_static("not-a-number"), + ); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, None); + } + + #[test] + fn retry_after_http_date_in_past() { + let mut headers = HeaderMap::new(); + // Use a date in the past + headers.insert( + RETRY_AFTER, + "Wed, 01 Jan 2020 00:00:00 GMT".parse().unwrap(), + ); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::ZERO)); + } + + #[test] + fn retry_after_http_date_in_future() { + let target = SystemTime::now() + Duration::from_secs(60); + let date_str = httpdate::fmt_http_date(target); + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, date_str.parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + let dur = result.expect("should parse future HTTP-date"); + + // httpdate truncates 
sub-seconds (1s resolution), and we can have + // some differences for clock time under CI load or NTP adjustment + // so test parsing correctness within a wide enough time window. + assert!( + dur >= Duration::from_secs(55) && dur <= Duration::from_secs(65), + "expected ~60s, got {:?}", + dur, + ); + } + + #[test] + fn retry_after_http_date_caps_at_max() { + let target = SystemTime::now() + Duration::from_secs(600); + let date_str = httpdate::fmt_http_date(target); + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, date_str.parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(MAX)); + } + + // === gRPC pushback tests === + + const MAX_GRPC: Duration = Duration::from_millis(300_000); + + #[test] + fn parse_grpc_pushback_positive() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("5000")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, Some(Duration::from_millis(5000))); + } + + #[test] + fn parse_grpc_pushback_zero() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("0")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, Some(Duration::ZERO)); + } + + #[test] + fn parse_grpc_pushback_negative() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("-1")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, None); + } + + #[test] + fn parse_grpc_pushback_caps_at_max() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("999999")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, Some(MAX_GRPC)); + } + + #[test] + fn parse_grpc_pushback_missing() { + let headers = HeaderMap::new(); + + let result = 
parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, None); + } + + #[test] + fn parse_grpc_pushback_invalid() { + let mut headers = HeaderMap::new(); + headers.insert( + "grpc-retry-pushback-ms", + HeaderValue::from_static("not-a-number"), + ); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, None); + } + + // === Whitespace handling tests === + + #[test] + fn retry_after_trailing_whitespace() { + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, "120 ".parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn retry_after_leading_whitespace() { + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, " 120".parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn grpc_pushback_whitespace() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", " 5000 ".parse().unwrap()); + + let result = parse_grpc_retry_pushback(&headers, MAX); + assert_eq!(result, Some(Duration::from_millis(5000))); + } +} diff --git a/linkerd/load-biaser/Cargo.toml b/linkerd/load-biaser/Cargo.toml new file mode 100644 index 0000000000..a8fa0fb29a --- /dev/null +++ b/linkerd/load-biaser/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "linkerd-load-biaser" +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +edition = { workspace = true } +publish = { workspace = true } + +[features] +default = [] +tokio-test = ["dep:tokio-test"] + +[dependencies] +linkerd-ewma = { path = "../ewma" } +futures = { version = "0.3", default-features = false } +http = { workspace = true } +linkerd-http-classify = { path = "../http/classify" } +linkerd-stack = { path = "../stack" } +parking_lot = "0.12" +pin-project = "1" +tokio = { 
version = "1", features = ["io-util", "net", "time"] } +tokio-test = { version = "0.4", optional = true } +tower = { workspace = true, features = ["load"] } +tower-service = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt", "time"] } +tokio-test = "0.4" diff --git a/linkerd/load-biaser/src/lib.rs b/linkerd/load-biaser/src/lib.rs new file mode 100644 index 0000000000..c90848c62c --- /dev/null +++ b/linkerd/load-biaser/src/lib.rs @@ -0,0 +1,1582 @@ +//! Load tracking with response failure awareness. +//! +//! This module provides a `LoadBiaser` wrapper that tracks request latency (RTT) +//! and detects failure responses (HTTP 429, 503, 5xx), applying artificial +//! penalties. Unlike Tower's `PeakEwma`, this implementation uses +//! `linkerd_ewma::Ewma` and returns `f64` metrics directly, enabling +//! integration with P2C load balancing. +//! +//! This can wrap any service (`Load` trait not required) and it tracks RTT via +//! EWMA, which is updated when responses complete, and pending requests for load +//! calculation. +//! +//! When a failure response is detected, a penalty is applied and the EWMA jumps +//! to a high value. The load is calculated as `max(rtt * (pending + 1), penalty)` +//! (so at least RTT with no pending requests). +//! +//! Both RTT and penalty decay over time via the EWMA. +//! +//! For this type to work responses must implement the `ResponseFailureHint` +//! trait, which classifies responses into failure categories (rate-limited, +//! service unavailable, internal error). For non-HTTP responses the default +//! implementation returns no failure hint, so only RTT tracking occurs. +//! +//! For gRPC, `failure_hint()` detects `RESOURCE_EXHAUSTED` (code 8) and +//! `UNAVAILABLE` (code 14) from the `grpc-status` header in trailers-only +//! (unary) responses. Streaming gRPC puts `grpc-status` in trailers which +//! 
are not visible at response-head time; the circuit breaker handles +//! streaming gRPC failures via `GrpcRetryPushbackClassifyEos`. + +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +use futures::ready; +use linkerd_ewma::{Ewma, MIN_DECAY}; +use linkerd_stack::NewService; +use parking_lot::RwLock; +use pin_project::pin_project; +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::Instant; +use tower::load::Load; +use tower_service::Service; + +/// Default maximum duration for Retry-After hints. +pub const DEFAULT_RETRY_AFTER_MAX_DURATION: Duration = Duration::from_secs(300); + +/// Classification of response failures for load biasing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FailureHint { + /// HTTP 429 Too Many Requests + RateLimited, + /// HTTP 503 Service Unavailable + ServiceUnavailable, + /// HTTP 5xx (other than 503) + InternalError, +} + +/// Cached Retry-After hint stored in HTTP response extensions. +/// +/// Stores the **uncapped** parsed value. Each consumer applies its own cap +/// via `rate_limit_hint(max)`, so different callers (e.g. load biaser vs +/// circuit breaker) can use different maximums from the same cached value. +#[derive(Clone, Copy, Debug)] +pub struct CachedRateLimitHint(pub Duration); + +/// Trait for extracting failure hints from responses. +/// +/// This allows the load biaser to classify responses and apply appropriate +/// penalties. Default implementations return None (no failure detected), +/// which is appropriate for non-HTTP transports. 
+/// +/// The trait splits rate limit hint access into two methods to avoid +/// requiring `&mut self` on the read path: +/// - `attach_parsed_rate_limit_hint(&mut self, max)`: parse and cache (needs `&mut`) +/// - `rate_limit_hint(&self, max)`: read cached value or parse on-read (only needs `&self`) +/// +/// # Stack ordering and caching +/// +/// In the current proxy stack `RetryAfterClassify::start()` (the circuit breaker's +/// classifier) calls `rate_limit_hint()` before `LoadBiaserFuture::poll()` calls +/// `attach_parsed_rate_limit_hint()`, because responses flow inner-to-outer. This +/// means the cache is always cold for the circuit breaker path. +pub trait ResponseFailureHint { + /// Returns a failure hint if the response indicates a failure condition. + fn failure_hint(&self) -> Option { + None + } + + /// Parse and cache the raw (uncapped) rate limit hint from this response. + /// + /// The `_max` parameter is accepted for API symmetry with `rate_limit_hint(max)` + /// but is intentionally unused. The raw uncapped value is cached so that each + /// consumer can apply their own cap via `rate_limit_hint(max)`. + fn attach_parsed_rate_limit_hint(&mut self, _max: Duration) {} + + /// Returns the rate limit hint if available. + /// + /// Checks cached value first (from a previous `attach_parsed_rate_limit_hint` call). + /// If no cached value, attempts to parse the header directly (capping at `max`). + /// Returns `None` only if the header is absent or unparseable. + fn rate_limit_hint(&self, _max: Duration) -> Option { + None + } +} + +/// HTTP responses classify failures by status code and parse Retry-After hints. 
+/// +/// For gRPC responses (which arrive as HTTP 200 with `grpc-status` in headers +/// for trailers-only/unary errors), the `grpc-status` header is checked when +/// the HTTP status is 200: +/// - gRPC status 8 (RESOURCE_EXHAUSTED) -> `RateLimited` +/// - gRPC status 14 (UNAVAILABLE) -> `ServiceUnavailable` +/// - Other non-zero gRPC status codes -> `InternalError` +impl ResponseFailureHint for http::Response { + fn failure_hint(&self) -> Option { + let status = self.status(); + if status == http::StatusCode::TOO_MANY_REQUESTS { + Some(FailureHint::RateLimited) + } else if status == http::StatusCode::SERVICE_UNAVAILABLE { + Some(FailureHint::ServiceUnavailable) + } else if status.is_server_error() { + Some(FailureHint::InternalError) + } else if status == http::StatusCode::OK { + // gRPC trailers-only responses: grpc-status appears in headers. + // Note: for streaming gRPC, grpc-status is in trailers (not headers) + // and will not be detected here. The circuit breaker handles streaming + // gRPC failures via GrpcRetryPushbackClassifyEos. + self.headers() + .get("grpc-status") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .and_then(|code| match code { + 0 => None, // OK + 8 => Some(FailureHint::RateLimited), // RESOURCE_EXHAUSTED + 14 => Some(FailureHint::ServiceUnavailable), // UNAVAILABLE + _ => Some(FailureHint::InternalError), // Other non-zero + }) + } else { + None + } + } + + fn attach_parsed_rate_limit_hint(&mut self, _max: Duration) { + // Store the uncapped value. Each consumer applies their own cap via + // rate_limit_hint(max). 
+ if let Some(d) = linkerd_http_classify::retry_after::parse_retry_after( + self.status(), + self.headers(), + Duration::MAX, + ) { + self.extensions_mut().insert(CachedRateLimitHint(d)); + return; + } + // Try gRPC retry-pushback-ms (for trailers-only responses) + if self.status() == http::StatusCode::OK { + if let Some(d) = linkerd_http_classify::retry_after::parse_grpc_retry_pushback( + self.headers(), + Duration::MAX, + ) { + self.extensions_mut().insert(CachedRateLimitHint(d)); + } + } + } + + fn rate_limit_hint(&self, max: Duration) -> Option { + // Check cache first (from previous attach call), apply caller's cap + if let Some(cached) = self.extensions().get::() { + return Some(cached.0.min(max)); + } + // Parse on-read as fallback (header present but attach wasn't called) + if let Some(d) = linkerd_http_classify::retry_after::parse_retry_after( + self.status(), + self.headers(), + max, + ) { + return Some(d); + } + // Try gRPC pushback + if self.status() == http::StatusCode::OK { + if let Some(d) = + linkerd_http_classify::retry_after::parse_grpc_retry_pushback(self.headers(), max) + { + return Some(d); + } + } + // No header or unparseable + None + } +} + +/// Connection tuples (Connection, Metadata) used by TCP/TLS paths never indicate failures. +/// This is meant for MakeConnection services that return `(I, M)` tuples. +impl ResponseFailureHint for (C, M) {} + +/// TCP streams never indicate failures (no HTTP status codes). +impl ResponseFailureHint for tokio::net::TcpStream {} + +/// Duplex streams (used in testing) never indicate failures. +impl ResponseFailureHint for tokio::io::DuplexStream {} + +/// Mock IO streams never indicate failures. +#[cfg(feature = "tokio-test")] +impl ResponseFailureHint for tokio_test::io::Mock {} + +/// Amplification factor for Retry-After penalties. 
+/// +/// When a 429 or 503 response includes a Retry-After header, the load biaser injects +/// an amplified penalty so it remains meaningful through the server's requested +/// avoidance window. The injected value is: +/// +/// `penalty_secs * RETRY_AFTER_PENALTY_FACTOR * exp(retry_after / penalty_decay)` +/// +/// which decays via EWMA to `penalty_secs * RETRY_AFTER_PENALTY_FACTOR` at +/// exactly `t = retry_after`. The factor controls how aggressively the endpoint +/// is avoided near the Retry-After deadline: +/// +/// At 1.0 the endpoint is fully avoided (~0% traffic) through the window and 46s +/// beyond -- too aggressive for early-recovery discovery. At 0.1, traffic leaks +/// well before the deadline. We use 0.5: the penalty at deadline is penalty_secs/2 +/// (roughly 50x healthy load in a 5-endpoint pool), which lets occasional probes +/// through in the second half while still exceeding a plain failure penalty for +/// RA >= ~7s. For pools with 100+ endpoints the factor barely matters because P2C +/// random pair selection already makes the penalized endpoint unlikely to be drawn. +/// +/// Sending occasional probe traffic during the window is valuable: the endpoint +/// may recover early, or a fresh 429 or 503 with a different RA provides updated +/// information. The circuit breaker (when enabled) provides a strict backoff +/// window, ie. `max(backoff, retry_after)`. +const RETRY_AFTER_PENALTY_FACTOR: f64 = 0.5; + +/// Configuration for LoadBiaser behavior. +#[derive(Clone, Debug)] +pub struct LoadBiaserConfig { + /// Default RTT to use when no measurements are available + pub default_rtt: Duration, + + /// Decay duration for the RTT EWMA. + /// Controls how quickly RTT estimates adapt to changing latency + pub rtt_decay: Duration, + + /// The penalty value to inject on failure responses (429, 503, 5xx) in seconds + pub penalty_secs: f64, + + /// Decay duration for the penalty EWMA. 
+ /// Controls how quickly the penalty decays after a failure response + pub penalty_decay: Duration, + + /// Whether load biasing penalties are enabled. When false, only RTT tracking + /// is active (PeakEwma equivalent). + pub enabled: bool, + + /// Maximum Retry-After duration to honor. Clamped to this value. + pub max_duration: Duration, +} + +impl Default for LoadBiaserConfig { + fn default() -> Self { + Self { + default_rtt: Duration::from_secs(1), + rtt_decay: Duration::from_secs(10), + // 5 second penalty on failure responses (429, 503, 5xx) + penalty_secs: 5.0, + // 10 second decay for penalty - mostly gone after ~30 seconds + penalty_decay: Duration::from_secs(10), + enabled: false, + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + } + } +} + +/// Shared per-endpoint state behind a single `Arc`. +/// +/// Combines the EWMA trackers (under a mutex for atomic RTT+penalty reads), +/// the in-flight request counter, and the immutable config fields that every +/// response future needs. One allocation per endpoint instead of two, and +/// futures carry a single `Arc` clone instead of copying config fields. +#[derive(Debug)] +struct SharedState { + /// RTT and penalty EWMAs; read-locked in load(), write-locked in poll() + metrics: RwLock, + /// Count of in-flight requests (up on call, down on response) + pending: AtomicU32, + /// Penalty value to inject on failure responses (in seconds) + penalty_secs: f64, + /// Whether penalty injection is enabled + enabled: bool, + /// Maximum Retry-After duration to honor (clamped) + max_duration: Duration, + /// Decay duration for the penalty EWMA (used to amplify Retry-After penalties) + penalty_decay: Duration, +} + +#[derive(Debug)] +struct LoadMetrics { + /// EWMA RTT tracking, updated on each response + rtt: Ewma, + /// EWMA penalty tracking, updated on failure responses (429, 503, 5xx) + penalty: Ewma, +} + +/// A service wrapper that tracks RTT and biases load metrics based on failure responses. 
+/// +/// `LoadBiaser` provides load metrics for P2C load balancing by tracking request latency +/// (RTT) via EWMA, in-flight requests, and by injecting penalties when failure responses +/// (429, 503, 5xx) are detected. +/// +/// The `load()` method returns `max(rtt * (pending + 1), penalty)`, causing P2C to +/// prefer endpoints with lower latency and fewer in-flight requests, while avoiding +/// rate-limited endpoints. +#[derive(Debug)] +pub struct LoadBiaser { + inner: S, + /// Per-endpoint metrics, pending counter, and penalty config + shared: Arc, + config: LoadBiaserConfig, +} + +/// A `NewService` implementation that creates `LoadBiaser` wrappers. +#[derive(Debug)] +pub struct NewLoadBiaser { + inner: N, + config: LoadBiaserConfig, + _marker: PhantomData, +} + +/// Response future that tracks RTT and checks for failure responses. +/// +/// When the inner future completes we store the RTT based on elapsed time since the +/// request started, decrement the pending counter, and if the response indicates a +/// failure (using the `ResponseFailureHint` trait) we inject a penalty. +#[pin_project(PinnedDrop)] +pub struct LoadBiaserFuture { + #[pin] + inner: F, + /// Request start instant for RTT calculation + start: Instant, + /// Shared endpoint state (metrics, pending counter, penalty config) + shared: Arc, + /// Whether we've already decremented pending + completed: bool, + /// Marker for the response type (`ResponseFailureHint` bound) + _response: PhantomData Rsp>, +} + +impl LoadBiaser { + /// Creates a new `LoadBiaser` wrapping the given service. 
+ pub fn new(inner: S, mut config: LoadBiaserConfig) -> Self { + if config.penalty_secs.is_nan() || config.penalty_secs < 0.0 { + tracing::warn!( + penalty_secs = config.penalty_secs, + "penalty_secs is NaN or negative, clamping to 0.0" + ); + config.penalty_secs = 0.0; + } + if config.penalty_decay < MIN_DECAY { + tracing::warn!( + penalty_decay = ?config.penalty_decay, + min = ?MIN_DECAY, + "penalty_decay below minimum, will be clamped by EWMA constructor" + ); + } + if config.rtt_decay < MIN_DECAY { + tracing::warn!( + rtt_decay = ?config.rtt_decay, + min = ?MIN_DECAY, + "rtt_decay below minimum, will be clamped by EWMA constructor" + ); + } + let now = Instant::now(); + let shared = Arc::new(SharedState { + metrics: RwLock::new(LoadMetrics { + // Initialize RTT with default_rtt (not INFINITY) so that fresh + // endpoints have comparable load to unmeasured ones. This matches + // Tower's PeakEwma behavior and prevents P2C from permanently + // preferring endpoints whose first request happened to be fast. + rtt: Ewma::new_with_value(config.rtt_decay, now, config.default_rtt.as_secs_f64()), + penalty: Ewma::new(config.penalty_decay, now), + }), + pending: AtomicU32::new(0), + penalty_secs: config.penalty_secs, + enabled: config.enabled, + max_duration: config.max_duration, + penalty_decay: config.penalty_decay, + }); + Self { + inner, + shared, + config, + } + } + + /// Returns a reference to the inner service. 
+ pub fn get_ref(&self) -> &S { + &self.inner + } +} + +impl Clone for LoadBiaser { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + shared: self.shared.clone(), + config: self.config.clone(), + } + } +} + +impl Load for LoadBiaser { + type Metric = f64; + + fn load(&self) -> Self::Metric { + let pending = self.shared.pending.load(Ordering::Acquire); + let now = Instant::now(); + + let (rtt, penalty_val) = { + let metrics = self.shared.metrics.read(); + (metrics.rtt.get_at(now), metrics.penalty.get_at(now)) + }; + + let penalty = if penalty_val.is_infinite() { + 0.0 + } else { + penalty_val + }; + + // Load = RTT * (pending + 1), minimum will be `penalty` + // The +1 ensures idle endpoints have some load based on RTT + let base = rtt * f64::from(pending.saturating_add(1)); + let load = f64::max(base, penalty); + + tracing::trace!( + rtt_secs = rtt, + pending = pending, + penalty_secs = penalty, + load = load, + "LoadBiaser::load" + ); + + load + } +} + +impl Service for LoadBiaser +where + S: Service, + S::Response: ResponseFailureHint, +{ + type Response = S::Response; + type Error = S::Error; + type Future = LoadBiaserFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Req) -> Self::Future { + let prev = self.shared.pending.fetch_add(1, Ordering::AcqRel); + debug_assert!(prev < u32::MAX, "pending counter overflow"); + + LoadBiaserFuture { + inner: self.inner.call(req), + start: Instant::now(), + shared: self.shared.clone(), + completed: false, + _response: PhantomData, + } + } +} + +impl NewLoadBiaser { + /// Creates a new `NewLoadBiaser` with the given configuration. 
+ pub fn new(config: LoadBiaserConfig, inner: N) -> Self { + Self { + inner, + config, + _marker: PhantomData, + } + } +} + +impl Clone for NewLoadBiaser { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + config: self.config.clone(), + _marker: PhantomData, + } + } +} + +impl NewService for NewLoadBiaser +where + N: NewService, +{ + type Service = LoadBiaser; + + fn new_service(&self, target: T) -> LoadBiaser { + LoadBiaser::new(self.inner.new_service(target), self.config.clone()) + } +} + +impl Future for LoadBiaserFuture +where + F: Future>, + Rsp: ResponseFailureHint, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let mut result = ready!(this.inner.poll(cx)); + let shared = &**this.shared; + + let now = Instant::now(); + let elapsed = now.saturating_duration_since(*this.start).as_secs_f64(); + + // Parse rate limit hint while we have &mut access to the response (when load + // biasing is enabled). + if shared.enabled { + if let Ok(ref mut resp) = result { + resp.attach_parsed_rate_limit_hint(shared.max_duration); + } + } + + { + let mut metrics = shared.metrics.write(); + if let Ok(ref resp) = result { + // Update RTT on all HTTP responses (including 429, 503, 5xx). + // Only transport-level errors (connection refused, resets) skip + // this update, because broken endpoints fail fast and would bias + // P2C toward sending more traffic to the broken endpoint. + metrics.rtt.add_peak(elapsed, now); + + if shared.enabled { + if let Some(hint) = resp.failure_hint() { + let base_penalty = shared.penalty_secs; + + // For rate-limited and service-unavailable responses, amplify + // the penalty using the Retry-After hint so it remains meaningful + // through the server's requested avoidance window. 
+ let penalty_val = match hint { + FailureHint::RateLimited | FailureHint::ServiceUnavailable => { + match resp.rate_limit_hint(shared.max_duration) { + Some(ra) if ra.as_secs_f64() > 0.0 => { + let decay_secs = + shared.penalty_decay.max(MIN_DECAY).as_secs_f64(); + let amplified = base_penalty + * RETRY_AFTER_PENALTY_FACTOR + * (ra.as_secs_f64() / decay_secs).exp(); + amplified.min(1e12) + } + _ => base_penalty, + } + } + FailureHint::InternalError => base_penalty, + }; + + tracing::debug!( + penalty_secs = penalty_val, + rtt_secs = elapsed, + ?hint, + "Detected failure response - injecting load penalty" + ); + metrics.penalty.add_peak(penalty_val, now); + } + } + } + } + + shared.pending.fetch_sub(1, Ordering::Release); + *this.completed = true; + + Poll::Ready(result) + } +} + +#[pin_project::pinned_drop] +impl PinnedDrop for LoadBiaserFuture { + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + // Only decrement if we haven't already done so in poll(). + // Avoids leaking the pending count upon cancellation. + if !*this.completed { + this.shared.pending.fetch_sub(1, Ordering::Release); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::convert::Infallible; + use tokio::time; + + impl LoadBiaser { + pub fn get_rtt(&self) -> f64 { + self.shared.metrics.read().rtt.get() + } + + pub fn get_penalty(&self) -> f64 { + self.shared.metrics.read().penalty.get() + } + + pub fn get_pending(&self) -> u32 { + self.shared.pending.load(Ordering::Acquire) + } + + pub fn inject_penalty(&self, penalty_secs: f64) { + self.shared + .metrics + .write() + .penalty + .add_peak(penalty_secs, Instant::now()); + } + + pub fn inject_rtt(&self, rtt_secs: f64) { + self.shared + .metrics + .write() + .rtt + .add_peak(rtt_secs, Instant::now()); + } + } + + // Mock service for testing returning a specific HTTP status. 
+ #[derive(Clone)] + struct MockService { + status: http::StatusCode, + } + + impl MockService { + fn new(status: http::StatusCode) -> Self { + Self { status } + } + } + + impl Service<()> for MockService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + let resp = http::Response::builder() + .status(self.status) + .body("test") + .unwrap(); + futures::future::ready(Ok(resp)) + } + } + + // Mock service that always returns an error. + #[derive(Clone)] + struct ErrorService; + + impl Service<()> for ErrorService { + type Response = http::Response<&'static str>; + type Error = &'static str; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + futures::future::ready(Err("connection refused")) + } + } + + fn test_config() -> LoadBiaserConfig { + LoadBiaserConfig { + default_rtt: Duration::from_millis(100), // 0.1s default + rtt_decay: Duration::from_secs(10), + penalty_secs: 5.0, + penalty_decay: Duration::from_secs(10), + enabled: true, + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_load_uses_default_rtt_initially() { + let inner = MockService::new(http::StatusCode::OK); + let biaser = LoadBiaser::new(inner, test_config()); + + // Initial load should be default_rtt * (pending + 1) = 0.1 * 1 = 0.1 + let load = biaser.load(); + assert!( + (load - 0.1).abs() < 0.001, + "initial load should be ~0.1 (default RTT): {}", + load + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_pending_increases_load() { + let inner = MockService::new(http::StatusCode::OK); + let biaser = LoadBiaser::new(inner, test_config()); + + // 
Inject a known RTT + time::sleep(Duration::from_millis(1)).await; + biaser.inject_rtt(0.05); // 50ms + + // RTT * (0 + 1) = 0.05 + let load_idle = biaser.load(); + + // Start a request (not awaited yet) + // Increment pending to simulate in-flight requests + biaser.shared.pending.fetch_add(1, Ordering::AcqRel); + + // RTT * (1 + 1) = 0.05 * 2 = 0.1 + let load_one_pending = biaser.load(); + + // Increment again + biaser.shared.pending.fetch_add(1, Ordering::AcqRel); + + // RTT * (2 + 1) = 0.05 * 3 = 0.15 + let load_two_pending = biaser.load(); + + assert!( + load_one_pending > load_idle, + "load should increase with pending: {} > {}", + load_one_pending, + load_idle + ); + assert!( + load_two_pending > load_one_pending, + "load should increase more: {} > {}", + load_two_pending, + load_one_pending + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_rtt_tracked_after_request() { + let inner = MockService::new(http::StatusCode::OK); + let mut biaser = LoadBiaser::new(inner, test_config()); + + // Advance time so EWMA accepts updates + time::sleep(Duration::from_millis(1)).await; + + // RTT should start at default_rtt (0.1s) before any requests + let initial_rtt = biaser.get_rtt(); + assert!( + (initial_rtt - 0.1).abs() < 0.01, + "RTT should start at default_rtt (0.1s), got: {initial_rtt}" + ); + + // Make a request (will record RTT) + let _ = biaser.call(()).await; + + // RTT should now reflect the actual request latency + let rtt = biaser.get_rtt(); + assert!( + rtt < initial_rtt, + "RTT should decrease after a fast request" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_injects_penalty() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Penalty should be infinite (none) before we get requests. 
+ assert!(biaser.get_penalty().is_infinite()); + + // Make a request that returns 429. + let _ = biaser.call(()).await; + + // A penalty should be injected (5 seconds) after seeing a 429 + let penalty = biaser.get_penalty(); + assert!( + (penalty - 5.0).abs() < 0.1, + "penalty should be ~5s after 429: {}", + penalty + ); + + // Load should be at least the penalty. + let load = biaser.load(); + assert!(load >= 4.9, "load should be high after 429: {}", load); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_200_does_not_inject_penalty() { + let inner = MockService::new(http::StatusCode::OK); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 200. + let _ = biaser.call(()).await; + + // Penalty should still be infinite (none). + assert!( + biaser.get_penalty().is_infinite(), + "penalty should not be injected for 200" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_penalty_decays_over_time() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + // Use a custom config with default_rtt=1ms (0.001s) so load() tracks penalty directly. + // inject_rtt(0.001) would be a no-op here: add_peak() only replaces when the new + // value exceeds the decayed current value, so 0.001 < 0.1 (test_config default) + // would fail the peak check and fall through to add() which no-ops at ts==self.timestamp. + let config = LoadBiaserConfig { + default_rtt: Duration::from_millis(1), + ..test_config() + }; + let mut biaser = LoadBiaser::new(inner, config); + + time::sleep(Duration::from_millis(1)).await; + + // Trigger penalty via 429 response + let _ = biaser.call(()).await; + + // After call completes, pending is back to 0 (LoadBiaserFuture::poll decrements + // before returning Poll::Ready), so load = max(rtt * 1, penalty). 
+ assert_eq!( + biaser.get_pending(), + 0, + "pending should be 0 after call completes" + ); + + // load() calls get_at(now) which projects the decayed penalty. + // Immediately after injection: penalty ~5.0, rtt ~0.001, so load ~5.0 + let load_before = biaser.load(); + assert!( + load_before > 4.0, + "load before decay should be dominated by the injected 429 penalty: {}", + load_before + ); + + // Advance time by one decay period (10s). + // Penalty decays: 5.0 * e^(-10/10) ~ 1.839 + time::sleep(Duration::from_secs(10)).await; + + // load() projects the decayed penalty at the new timestamp. + // No EWMA mutation needed since this is pure projection. + let load_after = biaser.load(); + assert!( + load_after > 1.0, + "load after one decay period should still be dominated by decayed penalty: {}", + load_after + ); + assert!( + load_after < 2.5, + "load after one decay period should reflect substantial decay (5.0 * e^-1 ~ 1.839): {}", + load_after + ); + assert!( + load_after < load_before, + "penalty should decay over time as observed through load(): {} < {}", + load_after, + load_before + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_load_is_max_of_rtt_based_and_penalty() { + let inner = MockService::new(http::StatusCode::OK); + let biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Inject a high RTT (10 seconds) + biaser.inject_rtt(10.0); + + // Inject a lower penalty (1 second) + biaser.inject_penalty(1.0); + + // Load should be RTT-based since it's higher: 10 * 1 = 10 + let load = biaser.load(); + assert!( + (load - 10.0).abs() < 0.1, + "load should be RTT-based when higher: {}", + load + ); + + // Now inject a very high penalty + biaser.inject_penalty(20.0); + + // Load should be penalty since it's higher + let load = biaser.load(); + assert!( + (load - 20.0).abs() < 0.1, + "load should be penalty when higher: {}", + load + ); + } + + #[tokio::test(flavor = 
"current_thread", start_paused = true)] + async fn test_clone_shares_state() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + let mut biaser1 = LoadBiaser::new(inner.clone(), test_config()); + let biaser2 = biaser1.clone(); + + time::sleep(Duration::from_millis(1)).await; + + // Trigger penalty on biaser1 + let _ = biaser1.call(()).await; + + // biaser2 should see the same penalty (shared state) + assert_eq!(biaser1.get_penalty(), biaser2.get_penalty()); + assert_eq!(biaser1.load(), biaser2.load()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_pending_decremented_on_completion() { + let inner = MockService::new(http::StatusCode::OK); + let mut biaser = LoadBiaser::new(inner, test_config()); + + assert_eq!(biaser.get_pending(), 0, "pending should start at 0"); + + // Start a request, pending increments + let fut = biaser.call(()); + assert_eq!( + biaser.get_pending(), + 1, + "pending should be 1 during request" + ); + + // Complete the request, pending decrements + let _ = fut.await; + assert_eq!( + biaser.get_pending(), + 0, + "pending should be 0 after completion" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_rtt_not_updated_on_error() { + let inner = ErrorService; + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // RTT should start at default_rtt before any requests + let initial_rtt = biaser.get_rtt(); + assert!( + (initial_rtt - 0.1).abs() < 0.01, + "RTT should start at default_rtt (0.1s), got: {initial_rtt}" + ); + + // Make a request that returns an error + let _ = biaser.call(()).await; + + // RTT should remain at default_rtt because error responses don't + // update RTT. This prevents P2C from routing more traffic to broken + // endpoints that fail fast (appearing "fast" to the load metric). 
+ let rtt_after_error = biaser.get_rtt(); + assert!( + (rtt_after_error - initial_rtt).abs() < 0.01, + "RTT should not change on error responses: {rtt_after_error} vs {initial_rtt}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_pending_decremented_on_cancellation() { + use tokio::sync::oneshot; + + // Build a service whose response future blocks on a oneshot receiver. + // This lets us drop the future mid-flight to test PinnedDrop. + struct DelayedService { + rx: Option>>, + } + + impl Service<()> for DelayedService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = DelayedFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + DelayedFuture { + rx: self.rx.take().expect("called more than once"), + } + } + } + + struct DelayedFuture { + rx: oneshot::Receiver>, + } + + impl std::future::Future for DelayedFuture { + type Output = Result, Infallible>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Ready(Ok(resp)) => Poll::Ready(Ok(resp)), + Poll::Ready(Err(_)) => { + // Sender dropped. Return a default response + let resp = http::Response::builder() + .status(http::StatusCode::OK) + .body("cancelled") + .unwrap(); + + Poll::Ready(Ok(resp)) + } + Poll::Pending => Poll::Pending, + } + } + } + + let (tx, rx) = oneshot::channel(); + let inner = DelayedService { rx: Some(rx) }; + let mut biaser = LoadBiaser::new(inner, test_config()); + + assert_eq!(biaser.get_pending(), 0); + + // Initiate the request, pending increments in Service::call() + let fut = biaser.call(()); + assert_eq!(biaser.get_pending(), 1); + + // Drop the future without completing it. The oneshot receiver + // never receives a value, so the inner future is still pending. + // PinnedDrop must decrement the pending count. 
drop(fut);

    assert_eq!(
        biaser.get_pending(),
        0,
        "PinnedDrop should decrement pending on cancellation"
    );

    // tx is unused. Dropping it here is fine, it just closes the channel.
    drop(tx);
}

#[test]
fn default_max_duration_matches_client_policy() {
    // Enforces the invariant at LoadBiaserConfig::default():
    // max_duration must match the local DEFAULT_RETRY_AFTER_MAX_DURATION constant.
    // Production code reads the constant directly via EwmaConfig::to_load_biaser_config;
    // this assertion keeps the test-only default() in sync.
    assert_eq!(
        LoadBiaserConfig::default().max_duration,
        DEFAULT_RETRY_AFTER_MAX_DURATION,
        "LoadBiaserConfig::default().max_duration must match \
         DEFAULT_RETRY_AFTER_MAX_DURATION (300s)"
    );
}

// Mock service returning an HTTP error with a Retry-After header.
// Parameterized by status code to avoid duplicating 429 and 503 variants.
#[derive(Clone)]
struct RetryAfterService {
    // Status code placed on every response.
    status: http::StatusCode,
    // Value written to the `Retry-After` header, in integer seconds.
    retry_after_secs: u64,
}

impl Service<()> for RetryAfterService {
    type Response = http::Response<&'static str>;
    type Error = Infallible;
    type Future = futures::future::Ready<Result<Self::Response, Self::Error>>;

    // Always ready: the mock has no backpressure.
    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    // Builds a response with the configured status and `Retry-After` value and
    // returns it as an already-completed future.
    fn call(&mut self, _: ()) -> Self::Future {
        let resp = http::Response::builder()
            .status(self.status)
            .header(http::header::RETRY_AFTER, self.retry_after_secs.to_string())
            .body("retry-after response")
            .unwrap();

        futures::future::ready(Ok(resp))
    }
}

// A 429 carrying `Retry-After: 30` should yield the amplified penalty rather
// than the plain `penalty_secs`.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn test_429_with_retry_after_uses_adaptive_penalty() {
    let inner = RetryAfterService {
        status: http::StatusCode::TOO_MANY_REQUESTS,
        retry_after_secs: 30,
    };
    let mut biaser = LoadBiaser::new(inner, test_config());

    time::sleep(Duration::from_millis(1)).await;

    // Make a request that returns 429 with Retry-After: 30
    let _ = biaser.call(()).await;

    // Amplified: penalty_secs * FACTOR * exp(RA/decay) = 5.0 * 0.5 * e^3
    let expected = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (30.0_f64 / 10.0_f64).exp();
    let penalty = biaser.get_penalty();

    assert!(
        (penalty - expected).abs() < 1.0,
        "penalty should be ~{expected:.1} (amplified), got: {penalty}"
    );
}

// A 429 whose `Retry-After` (600s) exceeds `max_duration` should be treated as
// if it had asked for `max_duration` (300s).
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn test_429_with_retry_after_clamped_to_max() {
    let inner = RetryAfterService {
        status: http::StatusCode::TOO_MANY_REQUESTS,
        retry_after_secs: 600,
    };
    let config = LoadBiaserConfig {
        max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION,
        ..test_config()
    };
    let mut biaser = LoadBiaser::new(inner, config);

    time::sleep(Duration::from_millis(1)).await;

    // Make a request that returns 429 with Retry-After: 600, clamped to 300s
    let _ = biaser.call(()).await;

    // Amplified: penalty_secs * FACTOR * exp(clamped_RA/decay) = 5.0 * 0.5 * e^30
    // exceeds the 1e12 penalty clamp, so the actual penalty is clamped.
    let unclamped = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (300.0_f64 / 10.0_f64).exp();
    let expected = unclamped.min(1e12);
    let penalty = biaser.get_penalty();

    // Relative tolerance: the expected value is on the order of 1e12.
    assert!(
        (penalty - expected).abs() / expected < 0.01,
        "penalty should be ~{expected:.0} (amplified, clamped RA=300), got: {penalty}"
    );
}

// Attaching and then reading a hint on a 429 with `Retry-After: 45` yields 45s.
#[test]
fn test_rate_limit_hint_parses_retry_after() {
    let mut resp = http::Response::builder()
        .status(http::StatusCode::TOO_MANY_REQUESTS)
        .header(http::header::RETRY_AFTER, "45")
        .body("rate limited")
        .unwrap();
    let max = Duration::from_secs(60);

    resp.attach_parsed_rate_limit_hint(max);

    assert_eq!(resp.rate_limit_hint(max), Some(Duration::from_secs(45)));
}

// A 200 response yields no hint even when a `Retry-After` header is present.
#[test]
fn test_rate_limit_hint_none_for_200() {
    let mut resp = http::Response::builder()
        .status(http::StatusCode::OK)
        .header(http::header::RETRY_AFTER, "45")
        .body("ok")
        .unwrap();
    let max = Duration::from_secs(60);

    resp.attach_parsed_rate_limit_hint(max);

    assert_eq!(resp.rate_limit_hint(max), None);
}

// A 429 without a `Retry-After` header yields no hint.
#[test]
fn test_rate_limit_hint_none_without_header() {
    let mut resp = http::Response::builder()
        .status(http::StatusCode::TOO_MANY_REQUESTS)
        .body("rate limited")
        .unwrap();
    let max = Duration::from_secs(60);

    resp.attach_parsed_rate_limit_hint(max);

    assert_eq!(resp.rate_limit_hint(max), None);
}

// Reading a hint without a prior attach still parses the header.
#[test]
fn test_rate_limit_hint_on_read_without_attach() {
    let resp = http::Response::builder()
        .status(http::StatusCode::TOO_MANY_REQUESTS)
        .header(http::header::RETRY_AFTER, "45")
        .body("rate limited")
        .unwrap();

    // No attach call, test on-read fallback
    assert_eq!(
        resp.rate_limit_hint(Duration::from_secs(60)),
        Some(Duration::from_secs(45))
    );
}

// Reading a hint without a prior attach clamps the parsed value to the cap.
#[test]
fn test_rate_limit_hint_on_read_caps_at_max() {
    let resp = http::Response::builder()
        .status(http::StatusCode::TOO_MANY_REQUESTS)
        .header(http::header::RETRY_AFTER, "300")
        .body("rate limited")
        .unwrap();

    // On-read parse caps at caller's max
    assert_eq!(
        resp.rate_limit_hint(Duration::from_secs(60)),
        Some(Duration::from_secs(60))
    );
}

// Test multiple caps

const SMALL_CAP: Duration = Duration::from_secs(60);
const DEFAULT_CAP: Duration = DEFAULT_RETRY_AFTER_MAX_DURATION;
const LARGE_CAP: Duration = Duration::from_secs(1800);

// Constructs a 429 response whose `Retry-After` header uses the integer
// seconds form (delay-seconds per RFC 7231).
fn build_http_retry_after_response(retry_after_secs: u64) -> http::Response<&'static str> {
    http::Response::builder()
        .status(http::StatusCode::TOO_MANY_REQUESTS)
        .header(http::header::RETRY_AFTER, retry_after_secs.to_string())
        .body("rate limited")
        .unwrap()
}

// Constructs a 200 OK trailers-only gRPC error response using
// `grpc-status: 8` (RESOURCE_EXHAUSTED) and `grpc-retry-pushback-ms: <pushback_ms>`.
fn build_grpc_pushback_response(pushback_ms: u64) -> http::Response<&'static str> {
    http::Response::builder()
        .status(http::StatusCode::OK)
        .header("grpc-status", "8")
        .header("grpc-retry-pushback-ms", pushback_ms.to_string())
        .body("grpc error")
        .unwrap()
}

// Verifies the cached-path `.min(max)` clamp in `rate_limit_hint()` clamps HTTP
// Retry-After hints to the caller-supplied cap consistently across cap
// magnitudes (60s / 300s / 1800s) and directions (below-cap / over-cap).
#[test]
fn test_rate_limit_hint_multi_cap_http() {
    // Each entry is (cap, Retry-After in integer seconds). Every cap is paired
    // with one header value below it and one above it.
    let cases: [(Duration, u64); 6] = [
        (SMALL_CAP, 30),    // below
        (SMALL_CAP, 120),   // over
        (DEFAULT_CAP, 120), // below
        (DEFAULT_CAP, 600), // over
        (LARGE_CAP, 900),   // below
        (LARGE_CAP, 3600),  // over
    ];

    for (cap, header_value) in cases {
        // Attach the hint with an effectively unbounded cap, then remove the
        // source header so the lookup below has to rely on the attached hint.
        let mut resp = build_http_retry_after_response(header_value);
        resp.attach_parsed_rate_limit_hint(Duration::MAX);
        resp.headers_mut().remove(http::header::RETRY_AFTER);

        let expected = Duration::from_secs(header_value).min(cap);
        let observed = resp.rate_limit_hint(cap);

        assert_eq!(
            observed,
            Some(expected),
            "cap={cap:?}, header={header_value}s, expected={expected:?}",
        );
    }
}

// Verifies the cached-path `.min(max)` clamp in `rate_limit_hint()` clamps gRPC
// retry-pushback-ms hints to the caller-supplied cap consistently across
// cap magnitudes (60s / 300s / 1800s) and directions (below-cap / over-cap).
+ #[test] + fn test_rate_limit_hint_multi_cap_grpc() { + struct Row { + cap: Duration, + header_value: u64, // grpc-retry-pushback in milliseconds + } + + let rows = [ + Row { + cap: SMALL_CAP, + header_value: 30_000, + }, // below (30s) + Row { + cap: SMALL_CAP, + header_value: 120_000, + }, // over (120s) + Row { + cap: DEFAULT_CAP, + header_value: 120_000, + }, // below (120s) + Row { + cap: DEFAULT_CAP, + header_value: 600_000, + }, // over (600s) + Row { + cap: LARGE_CAP, + header_value: 900_000, + }, // below (900s) + Row { + cap: LARGE_CAP, + header_value: 3_600_000, + }, // over (3600s) + ]; + + for Row { cap, header_value } in rows { + let mut resp = build_grpc_pushback_response(header_value); + resp.attach_parsed_rate_limit_hint(Duration::MAX); + + // Remove the source header after attach + resp.headers_mut().remove("grpc-retry-pushback-ms"); + let parsed = Duration::from_millis(header_value); + let expected = parsed.min(cap); + + assert_eq!( + resp.rate_limit_hint(cap), + Some(expected), + "cap={cap:?}, header={header_value}ms, expected={expected:?}", + ); + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_injects_full_penalty() { + let inner = MockService::new(http::StatusCode::SERVICE_UNAVAILABLE); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + let _ = biaser.call(()).await; + let penalty = biaser.get_penalty(); + + // test_config has penalty_secs = 5.0 + assert!( + (penalty - 5.0).abs() < 0.1, + "503 should inject full penalty (~5.0s), got: {penalty}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_with_retry_after_uses_adaptive_penalty() { + let inner = RetryAfterService { + status: http::StatusCode::SERVICE_UNAVAILABLE, + retry_after_secs: 30, + }; + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 503 with 
Retry-After: 30 + let _ = biaser.call(()).await; + + // Same amplification formula as 429: + // penalty_secs * FACTOR * exp(RA/decay) = 5.0 * 0.5 * e^3 + let expected = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (30.0_f64 / 10.0_f64).exp(); + let penalty = biaser.get_penalty(); + assert!( + (penalty - expected).abs() < 1.0, + "503+RA penalty should be ~{expected:.1} (amplified), got: {penalty}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_with_retry_after_clamped_to_max() { + let inner = RetryAfterService { + status: http::StatusCode::SERVICE_UNAVAILABLE, + retry_after_secs: 600, + }; + let config = LoadBiaserConfig { + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + ..test_config() + }; + let mut biaser = LoadBiaser::new(inner, config); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 503 with Retry-After: 600, clamped to 300s + let _ = biaser.call(()).await; + + // Amplified with clamped RA: + // penalty_secs * FACTOR * exp(clamped_RA/decay) = 5.0 * 0.5 * e^30 + // exceeds the 1e12 penalty clamp, so the actual penalty is clamped. 
+ let unclamped = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (300.0_f64 / 10.0_f64).exp(); + let expected = unclamped.min(1e12); + let penalty = biaser.get_penalty(); + assert!( + (penalty - expected).abs() / expected < 0.01, + "503+RA penalty should be ~{expected:.0} (amplified, clamped RA=300), got: {penalty}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_and_503_retry_after_produce_identical_penalty() { + // 429 path: RetryAfterService returns TOO_MANY_REQUESTS with Retry-After: 30 + let inner_429 = RetryAfterService { + status: http::StatusCode::TOO_MANY_REQUESTS, + retry_after_secs: 30, + }; + let mut biaser_429 = LoadBiaser::new(inner_429, test_config()); + + // 503 path: RetryAfterService returns SERVICE_UNAVAILABLE with Retry-After: 30 + let inner_503 = RetryAfterService { + status: http::StatusCode::SERVICE_UNAVAILABLE, + retry_after_secs: 30, + }; + let mut biaser_503 = LoadBiaser::new(inner_503, test_config()); + + // Bootstrap EWMA timestamps + time::sleep(Duration::from_millis(1)).await; + let _ = biaser_429.call(()).await; + + time::sleep(Duration::from_millis(1)).await; + let _ = biaser_503.call(()).await; + + let load_429 = biaser_429.load(); + let load_503 = biaser_503.load(); + + assert!( + (load_429 - load_503).abs() / load_429.max(1e-9) < 0.05, + "429+RA and 503+RA should produce similar load; 429={load_429}, 503={load_503}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_500_injects_penalty() { + let inner = MockService::new(http::StatusCode::INTERNAL_SERVER_ERROR); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + let _ = biaser.call(()).await; + let penalty = biaser.get_penalty(); + + // test_config has penalty_secs = 5.0 + assert!( + (penalty - 5.0).abs() < 0.1, + "500 should inject full penalty (~5.0s), got: {penalty}" + ); + } + + // Mock service that returns HTTP 200 with a `grpc-status` 
header, + // simulating a gRPC trailers-only error response. + #[derive(Clone)] + struct GrpcErrorService { + grpc_status: u16, + } + + impl Service<()> for GrpcErrorService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + let resp = http::Response::builder() + .status(http::StatusCode::OK) + .header("grpc-status", self.grpc_status.to_string()) + .body("grpc error") + .unwrap(); + + futures::future::ready(Ok(resp)) + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_resource_exhausted_injects_penalty() { + let inner = GrpcErrorService { grpc_status: 8 }; // RESOURCE_EXHAUSTED + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + let _ = biaser.call(()).await; + // gRPC RESOURCE_EXHAUSTED maps to FailureHint::RateLimited, + // which injects the full penalty_secs (5.0). 
+ let penalty = biaser.get_penalty(); + + assert!( + (penalty - 5.0).abs() < 0.1, + "gRPC RESOURCE_EXHAUSTED should inject full penalty (~5s), got: {penalty}" + ); + } + + async fn assert_disabled_no_penalty(mut biaser: LoadBiaser, label: &str) + where + S: Service<(), Response = http::Response<&'static str>, Error = Infallible>, + { + time::sleep(Duration::from_millis(1)).await; + let _ = biaser.call(()).await; + assert!( + biaser.get_penalty().is_infinite(), + "penalty should not be injected when disabled for {label}: {}", + biaser.get_penalty() + ); + } + + fn disabled_config() -> LoadBiaserConfig { + LoadBiaserConfig { + enabled: false, + ..test_config() + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_disabled_no_penalty() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + assert_disabled_no_penalty(LoadBiaser::new(inner, disabled_config()), "429").await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_disabled_no_penalty() { + let inner = MockService::new(http::StatusCode::SERVICE_UNAVAILABLE); + assert_disabled_no_penalty(LoadBiaser::new(inner, disabled_config()), "503").await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_500_disabled_no_penalty() { + let inner = MockService::new(http::StatusCode::INTERNAL_SERVER_ERROR); + assert_disabled_no_penalty(LoadBiaser::new(inner, disabled_config()), "500").await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_resource_exhausted_disabled_no_penalty() { + let inner = GrpcErrorService { grpc_status: 8 }; + assert_disabled_no_penalty( + LoadBiaser::new(inner, disabled_config()), + "gRPC RESOURCE_EXHAUSTED (status 8)", + ) + .await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_unavailable_disabled_no_penalty() { + let inner = GrpcErrorService { grpc_status: 14 }; + 
assert_disabled_no_penalty( + LoadBiaser::new(inner, disabled_config()), + "gRPC UNAVAILABLE (status 14)", + ) + .await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_internal_disabled_no_penalty() { + let inner = GrpcErrorService { grpc_status: 13 }; + assert_disabled_no_penalty( + LoadBiaser::new(inner, disabled_config()), + "gRPC INTERNAL (status 13)", + ) + .await; + } +}