From 5f971f6d2d8d14eb6d77326dec077169f856c02f Mon Sep 17 00:00:00 2001 From: Alejandro Martinez Ruiz Date: Thu, 21 May 2026 19:59:44 +0200 Subject: [PATCH 1/5] feat(ewma): add base EWMA crate Introduce linkerd-ewma, a general-purpose exponentially-weighted moving average crate. The crate provides five public methods on an Ewma struct: new (initializes with INFINITY sentinel), get (returns stored value), add (blends a new sample using exponential decay), add_peak (replaces stored value when the new sample exceeds it), and add_rate (derives a rate from the inverse of the elapsed interval and feeds it through add). This is being added in spite of tower::PeakEwma because this is not limited to middleware-based RTT computing. We specifically plan to use this implementation for a load biasing feature and a success-rate circuit breaker policy, which would otherwise not be possible. Signed-off-by: Alejandro Martinez Ruiz --- Cargo.lock | 7 +++ Cargo.toml | 1 + linkerd/ewma/Cargo.toml | 13 +++++ linkerd/ewma/src/lib.rs | 125 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 146 insertions(+) create mode 100644 linkerd/ewma/Cargo.toml create mode 100644 linkerd/ewma/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 4578b0ee69..728172de74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1687,6 +1687,13 @@ dependencies = [ "pin-project", ] +[[package]] +name = "linkerd-ewma" +version = "0.1.0" +dependencies = [ + "tokio", +] + [[package]] name = "linkerd-exp-backoff" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 7ac6863ca9..cd7cbd45e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "linkerd/error", "linkerd/errno", "linkerd/error-respond", + "linkerd/ewma", "linkerd/exp-backoff", "linkerd/http/access-log", "linkerd/http/body-eos", diff --git a/linkerd/ewma/Cargo.toml b/linkerd/ewma/Cargo.toml new file mode 100644 index 0000000000..8933c4e5c7 --- /dev/null +++ b/linkerd/ewma/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "linkerd-ewma" +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +edition = { workspace = true } +publish = { workspace = true } + +[dependencies] +tokio = { version = "1", features = ["time"] } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt", "test-util"] } diff --git a/linkerd/ewma/src/lib.rs b/linkerd/ewma/src/lib.rs new file mode 100644 index 0000000000..a69601d8b9 --- /dev/null +++ b/linkerd/ewma/src/lib.rs @@ -0,0 +1,125 @@ +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +use tokio::time; + +/// An exponentially-weighted moving average. +#[derive(Debug)] +pub struct Ewma { + value: f64, + decay: f64, + timestamp: time::Instant, +} + +// === impl Ewma === + +impl Ewma { + #[must_use] + pub fn new(decay: time::Duration, timestamp: time::Instant) -> Self { + Self { + decay: decay.as_secs_f64(), + timestamp, + value: f64::INFINITY, + } + } + + /// Returns the current value of the average. + pub fn get(&self) -> f64 { + self.value + } + + /// Updates the weighted moving average with a new value and timestamp. + /// + /// Precondition: `value` must not be NaN. Passing NaN poisons the EWMA + /// irreversibly (all subsequent reads return NaN). + pub fn add(&mut self, value: f64, ts: time::Instant) { + debug_assert!(!value.is_nan(), "EWMA input value must not be NaN"); + if ts <= self.timestamp { + return; + } + if self.value == f64::INFINITY { + self.value = value; + self.timestamp = ts; + return; + } + + self.value = { + let elapsed = ts.saturating_duration_since(self.timestamp); + let alpha = 1.0 - (-elapsed.as_secs_f64() / self.decay).exp(); + self.value * (1.0 - alpha) + value * alpha + }; + + self.timestamp = ts; + } + + pub fn add_peak(&mut self, value: f64, ts: time::Instant) { + if self.value < value { + self.value = value; + self.timestamp = ts; + return; + } + self.add(value, ts) + } + + /// Computes 1/elapsed since the last update and feeds it through `add()`. + pub fn add_rate(&mut self, ts: time::Instant) { + if ts <= self.timestamp { + return; + } + let elapsed = ts.saturating_duration_since(self.timestamp); + if !elapsed.is_zero() { + self.add(1.0 / elapsed.as_secs_f64(), ts); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::{Duration, Instant}; + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_new() { + let now = Instant::now(); + let ewma = Ewma::new(Duration::from_secs(10), now); + assert_eq!(ewma.get(), f64::INFINITY); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add(1.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_rate() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add_rate(now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + ewma.add_rate(now + Duration::from_secs(3)); + assert_eq!(ewma.get(), 0.9093653765389909); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_peak() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add_peak(1.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + ewma.add_peak(2.0, now + Duration::from_secs_f64(1.5)); + assert_eq!(ewma.get(), 2.0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_decay() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + ewma.add(1.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 1.0); + ewma.add(0.0, now + Duration::from_secs(11)); + assert_eq!(ewma.get(), 0.36787944117144233); + } +} From 6cc60229d11a929145de46c53e4e17e3fca45554 Mon Sep 17 00:00:00 2001 From: Alejandro Martinez Ruiz Date: Thu, 21 May 2026 20:00:02 +0200 Subject: [PATCH 2/5] feat(ewma): add success-rate tracking extensions Extend linkerd-ewma with the API surface needed for success-rate circuit breaking. A MIN_DECAY constant (1 ms) is now applied in both constructors so that a zero-duration decay never produces division-by-zero or NaN results in downstream arithmetic. New methods: new_with_value sets an explicit initial sample instead of the INFINITY sentinel, reset overwrites both value and timestamp for breaker recovery, and get_at projects the stored value forward through exponential decay without mutating internal state. Also add_peak is now decay-aware: it projects the stored value to the candidate timestamp before deciding whether to replace it, and it unconditionally replaces INFINITY so that the first real sample always takes effect even at the construction timestamp. Signed-off-by: Alejandro Martinez Ruiz --- linkerd/ewma/src/lib.rs | 357 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 354 insertions(+), 3 deletions(-) diff --git a/linkerd/ewma/src/lib.rs b/linkerd/ewma/src/lib.rs index a69601d8b9..872fcfc518 100644 --- a/linkerd/ewma/src/lib.rs +++ b/linkerd/ewma/src/lib.rs @@ -3,6 +3,11 @@ use tokio::time; +/// Minimum decay duration to prevent division-by-zero in EWMA computations. +/// Chosen as the smallest Duration that is strictly positive without overriding +/// validated configs from the control plane (CP should reject decay=0). +pub const MIN_DECAY: time::Duration = time::Duration::from_millis(1); + /// An exponentially-weighted moving average. #[derive(Debug)] pub struct Ewma { @@ -17,17 +22,62 @@ impl Ewma { #[must_use] pub fn new(decay: time::Duration, timestamp: time::Instant) -> Self { Self { - decay: decay.as_secs_f64(), + decay: decay.max(MIN_DECAY).as_secs_f64(), timestamp, value: f64::INFINITY, } } + /// Creates an EWMA with a specific initial value. + /// + /// This constructor allows setting an initial value, useful for + /// success rate tracking where you want to start at 100% (1.0) + /// success rate. + #[must_use] + pub fn new_with_value(decay: time::Duration, timestamp: time::Instant, initial: f64) -> Self { + debug_assert!(!initial.is_nan(), "EWMA initial value must not be NaN"); + Self { + decay: decay.max(MIN_DECAY).as_secs_f64(), + timestamp, + value: initial, + } + } + + /// Resets the EWMA to a new value and timestamp. + /// + /// This overwrites the current value and timestamp, useful for + /// resetting success rate tracking after recovery from a tripped state. + /// + /// Precondition: `ts` should be <= the timestamps passed to subsequent + /// `add()` calls. If it is not, those `add()` calls are silently + /// dropped because `add()` discards samples whose timestamp is at or + /// before the stored one. + pub fn reset(&mut self, value: f64, ts: time::Instant) { + debug_assert!(!value.is_nan(), "EWMA reset value must not be NaN"); + self.value = value; + self.timestamp = ts; + } + /// Returns the current value of the average. pub fn get(&self) -> f64 { self.value } + /// Returns the decayed value projected to the given time, without modifying stored state. + /// + /// Instead of returning the raw stored value, this applies exponential decay based + /// on elapsed time since the last update. This is required for load balancing where + /// stale measurements should lose influence over time. + pub fn get_at(&self, now: time::Instant) -> f64 { + debug_assert!(!self.value.is_nan(), "EWMA value must not be NaN"); + + if self.value.is_infinite() || now <= self.timestamp { + return self.value; + } + let elapsed = now.saturating_duration_since(self.timestamp); + self.value * (-elapsed.as_secs_f64() / self.decay).exp() + } + /// Updates the weighted moving average with a new value and timestamp. /// /// Precondition: `value` must not be NaN. Passing NaN poisons the EWMA @@ -52,8 +102,16 @@ impl Ewma { self.timestamp = ts; } + /// Updates the EWMA with a peak value, replacing the current value if the + /// new value exceeds the decayed projection. + /// + /// When replacement occurs, the stored timestamp is set to `ts`, which + /// may be earlier than the previously stored timestamp. This resets the + /// decay reference point, so subsequent projections via `get_at()` measure + /// elapsed time from `ts`. pub fn add_peak(&mut self, value: f64, ts: time::Instant) { - if self.value < value { + debug_assert!(!value.is_nan(), "EWMA peak value must not be NaN"); + if self.value.is_infinite() || self.get_at(ts) < value { self.value = value; self.timestamp = ts; return; @@ -78,6 +136,9 @@ mod tests { use super::*; use tokio::time::{Duration, Instant}; + // Literal value for exp(-1.0), since it's not const + const EXP_NEG1: f64 = 0.36787944117144233; + #[tokio::test(flavor = "current_thread", start_paused = true)] async fn test_new() { let now = Instant::now(); @@ -120,6 +181,296 @@ mod tests { ewma.add(1.0, now + Duration::from_secs(1)); assert_eq!(ewma.get(), 1.0); ewma.add(0.0, now + Duration::from_secs(11)); - assert_eq!(ewma.get(), 0.36787944117144233); + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_new_with_value() { + let now = Instant::now(); + let ewma = Ewma::new_with_value(Duration::from_secs(10), now, 1.0); + + assert_eq!(ewma.get(), 1.0); + + // Verify this behaves like a normal EWMA after initialization + let mut ewma = Ewma::new_with_value(Duration::from_secs(10), now, 1.0); + + ewma.add(0.0, now + Duration::from_secs(10)); + // After one decay period value decays towards zero. + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_peak_from_infinity_same_timestamp() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + + assert_eq!(ewma.get(), f64::INFINITY); + + // Same timestamp as construction. The first real value should always + // take effect regardless of timestamp. + ewma.add_peak(0.5, now); + assert_eq!(ewma.get(), 0.5); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_add_peak_replaces_decayed_value() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + + // Set initial peak of 10.0 at t=1s + ewma.add_peak(10.0, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 10.0); + + // After 25s of decay (t=26s), the decayed projection is: + // 10.0 * exp(-25/10) = 10.0 * exp(-2.5) = 0.8208... + // Since 0.8208 < 1.0, the new value should replace the stale peak. + ewma.add_peak(1.0, now + Duration::from_secs(26)); + assert_eq!(ewma.get(), 1.0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_zero_decay_clamped_in_new() { + let now = Instant::now(); + let zero = Ewma::new(Duration::ZERO, now); + let min = Ewma::new(Duration::from_millis(1), now); + + // Both should produce identical behavior since ZERO gets clamped to 1ms + assert_eq!(zero.get(), min.get()); + + // After adding a value, get_at should produce finite, identical results + let mut zero = Ewma::new(Duration::ZERO, now); + let mut min = Ewma::new(Duration::from_millis(1), now); + zero.add(5.0, now + Duration::from_secs(1)); + min.add(5.0, now + Duration::from_secs(1)); + + let projected_zero = zero.get_at(now + Duration::from_secs(2)); + let projected_min = min.get_at(now + Duration::from_secs(2)); + + assert!( + projected_zero.is_finite(), + "get_at must be finite with clamped decay" + ); + assert!( + !projected_zero.is_nan(), + "get_at must not be NaN with clamped decay" + ); + assert_eq!(projected_zero, projected_min); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_zero_decay_clamped_in_new_with_value() { + let now = Instant::now(); + let zero = Ewma::new_with_value(Duration::ZERO, now, 1.0); + let min = Ewma::new_with_value(Duration::from_millis(1), now, 1.0); + + let projected_zero = zero.get_at(now + Duration::from_secs(1)); + let projected_min = min.get_at(now + Duration::from_secs(1)); + assert!( + projected_zero.is_finite(), + "get_at must be finite with clamped decay" + ); + assert!( + !projected_zero.is_nan(), + "get_at must not be NaN with clamped decay" + ); + assert_eq!(projected_zero, projected_min); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_zero_decay_add_is_safe() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::ZERO, now); + + ewma.add(5.0, now + Duration::from_secs(1)); + ewma.add(0.5, now + Duration::from_secs(2)); + + let val = ewma.get(); + assert!( + val.is_finite(), + "add() result must be finite with clamped decay" + ); + assert!( + !val.is_nan(), + "add() result must not be NaN with clamped decay" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_reset() { + let now = Instant::now(); + let mut ewma = Ewma::new(Duration::from_secs(10), now); + + // Set state + ewma.add(0.5, now + Duration::from_secs(1)); + assert_eq!(ewma.get(), 0.5); + + // Reset to a new value + ewma.reset(1.0, now + Duration::from_secs(2)); + assert_eq!(ewma.get(), 1.0); + + // Verify EWMA continues working after reset. + ewma.add(0.0, now + Duration::from_secs(12)); + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_same_timestamp() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = add_at; + + let mut ewma = Ewma::new(decay, now); + ewma.add(0.5, add_at); + + assert_eq!(ewma.get_at(read_at), 0.5); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_past_timestamp() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_millis(500); + + let mut ewma = Ewma::new(decay, now); + ewma.add(0.5, add_at); + + assert_eq!(ewma.get_at(read_at), 0.5); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_infinity() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let probe_same = now; + let probe_near = now + Duration::from_secs(1); + let probe_far = now + Duration::from_secs(100); + + // A new Ewma without addding values should project INFINITY + // at every timestamp. + let ewma = Ewma::new(decay, now); + assert_eq!(ewma.get_at(probe_same), f64::INFINITY); + assert_eq!(ewma.get_at(probe_near), f64::INFINITY); + assert_eq!(ewma.get_at(probe_far), f64::INFINITY); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_decay() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(11); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, add_at); + + // Verify that get_at applies value * exp(-elapsed/decay) correctly + // without changing internal state. + assert_eq!(ewma.get_at(read_at), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_large_elapsed() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(3600); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, add_at); + + let result = ewma.get_at(read_at); + assert!( + result.is_finite(), + "get_at at large elapsed must be finite, got {result}" + ); + assert!( + result < 1e-10, + "get_at at large elapsed must be near zero, got {result}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_non_mutation() { + let now = Instant::now(); + let decay = Duration::from_secs(10); + let first_add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(6); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, first_add_at); + + // Take internal state before the read + let value_before = ewma.value; + let timestamp_before = ewma.timestamp; + + // Read must not mutate + let _ = ewma.get_at(read_at); + + assert_eq!(ewma.value, value_before); + assert_eq!(ewma.timestamp, timestamp_before); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_after_reset() { + const EWMA_VAL: f64 = 0.5; + const DECAY_AT_15S_VAL: f64 = EWMA_VAL * EXP_NEG1; + + let now = Instant::now(); + let decay = Duration::from_secs(10); + let first_add_at = now + Duration::from_secs(1); + let reset_at = now + Duration::from_secs(5); + let immediate_read_at = reset_at; + let decay_read_at = now + Duration::from_secs(15); + + let mut ewma = Ewma::new(decay, now); + // Use a large value here to make sure we detect any issues with + // reset not working properly, since we'd likely see a wildly + // different value in the assert below. + ewma.add(100.0, first_add_at); + + // Reset replaces both value and timestamp + ewma.reset(EWMA_VAL, reset_at); + + // Must return the freshly-reset value + assert_eq!(ewma.get_at(immediate_read_at), EWMA_VAL); + + // Read one decay period (10s) after the reset (so 15s decay). + assert_eq!(ewma.get_at(decay_read_at), DECAY_AT_15S_VAL); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_get_at_between_adds() { + const EXP_NEG0_5: f64 = 0.6065306597126334; + + let now = Instant::now(); + let decay = Duration::from_secs(10); + let first_add_at = now + Duration::from_secs(1); + let read_at = now + Duration::from_secs(6); + let second_add_at = now + Duration::from_secs(11); + + let mut ewma = Ewma::new(decay, now); + ewma.add(1.0, first_add_at); + + assert_eq!(ewma.get_at(read_at), EXP_NEG0_5); + + ewma.add(0.0, second_add_at); + + // Final state after the second add must be exp(-1.0), + // ensuring get_at() doesn't change the internal state. + assert_eq!(ewma.get(), EXP_NEG1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + #[cfg(debug_assertions)] + #[should_panic(expected = "NaN")] + async fn test_get_at_debug_asserts_nan() { + let now = Instant::now(); + // Inject NaN via new_with_value + let ewma = Ewma::new_with_value(Duration::from_secs(10), now, f64::NAN); + + // Should trigger a debug_assert + let _ = ewma.get_at(now); } } From 03485467c6ab10c6ba89628fcd158e97d4233909 Mon Sep 17 00:00:00 2001 From: Alejandro Martinez Ruiz Date: Thu, 7 May 2026 14:09:55 +0200 Subject: [PATCH 3/5] feat(classify): add Retry-After and gRPC pushback parsers Add a retry_after module to linkerd-http-classify with shared parsing functions for extracting backoff hints from HTTP and gRPC responses. parse_retry_after handles 429/503 responses with both delay-seconds and HTTP-date formats per RFC 7231, capping the returned duration at a caller-specified maximum. parse_grpc_retry_pushback reads the grpc-retry-pushback-ms header per the gRPC A6 spec, rejecting negative values and capping positive ones. We use the httpdate crate for the actual RFC 7231 HTTP-date parsing. Signed-off-by: Alejandro Martinez Ruiz --- Cargo.lock | 1 + Cargo.toml | 1 + linkerd/http/classify/Cargo.toml | 1 + linkerd/http/classify/src/lib.rs | 1 + linkerd/http/classify/src/retry_after.rs | 320 +++++++++++++++++++++++ 5 files changed, 324 insertions(+) create mode 100644 linkerd/http/classify/src/retry_after.rs diff --git a/Cargo.lock b/Cargo.lock index 728172de74..278b7031f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1758,6 +1758,7 @@ dependencies = [ "futures", "http", "http-body", + "httpdate", "linkerd-error", "linkerd-http-box", "linkerd-stack", diff --git a/Cargo.toml b/Cargo.toml index cd7cbd45e4..f10ee0c3f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,7 @@ drain = { version = "0.2", default-features = false } h2 = { version = "0.4" } http = { version = "1" } http-body = { version = "1" } +httpdate = { version = "1.0" } hyper = { version = "1", default-features = false } prometheus-client = { version = "0.23" } prost = { version = "0.14" } diff --git a/linkerd/http/classify/Cargo.toml b/linkerd/http/classify/Cargo.toml index c7a7911e40..e5c23eb0b7 100644 --- a/linkerd/http/classify/Cargo.toml +++ b/linkerd/http/classify/Cargo.toml @@ -10,6 +10,7 @@ publish = { workspace = true } futures = { version = "0.3", default-features = false } http = { workspace = true } http-body = { workspace = true } +httpdate.workspace = true pin-project = "1" tokio = { version = "1", default-features = false } tracing = { workspace = true } diff --git a/linkerd/http/classify/src/lib.rs b/linkerd/http/classify/src/lib.rs index e25f2b7701..fcc9c5ef70 100644 --- a/linkerd/http/classify/src/lib.rs +++ b/linkerd/http/classify/src/lib.rs @@ -12,6 +12,7 @@ pub use self::{ mod channel; pub mod gate; mod insert; +pub mod retry_after; /// Determines how a request's response should be classified. pub trait Classify { diff --git a/linkerd/http/classify/src/retry_after.rs b/linkerd/http/classify/src/retry_after.rs new file mode 100644 index 0000000000..9b00ec4303 --- /dev/null +++ b/linkerd/http/classify/src/retry_after.rs @@ -0,0 +1,320 @@ +//! Shared parsing for HTTP Retry-After headers and gRPC retry-pushback-ms trailers. +//! +//! This module provides pure parsing functions with no classification or store logic. +//! Both the load-biaser and the circuit breaker can use these functions to extract +//! backoff hints from HTTP and gRPC responses. + +use http::{HeaderMap, StatusCode}; +use std::time::Duration; + +/// Parse the Retry-After header from a 429 or 503 response. +/// +/// Supports two formats per RFC 7231: +/// - delay-seconds: `Retry-After: 120` -> 120 seconds +/// - HTTP-date: `Retry-After: Wed, 21 Oct 2025 07:28:00 GMT` -> duration from now +/// +/// Returns `None` for: +/// - Non-429/503 responses +/// - Missing Retry-After header +/// - Invalid header formats +/// +/// The returned duration is capped at `max` to prevent abuse. +pub fn parse_retry_after( + status: StatusCode, + headers: &HeaderMap, + max: Duration, +) -> Option { + // Only parse for 429 and 503 responses + if status != StatusCode::TOO_MANY_REQUESTS && status != StatusCode::SERVICE_UNAVAILABLE { + return None; + } + + let value = headers.get(http::header::RETRY_AFTER)?; + let s = value.to_str().ok()?; + + parse_retry_after_value(s, max) +} + +/// Parse a Retry-After header value string. +/// +/// Tries delay-seconds first (most common), then HTTP-date format. +/// The returned duration is capped at `max`. +fn parse_retry_after_value(s: &str, max: Duration) -> Option { + let s = s.trim(); + // Try delay-seconds first (most common format) + if let Ok(secs) = s.parse::() { + let duration = Duration::from_secs(secs); + tracing::debug!(?duration, "Parsed Retry-After delay-seconds"); + return Some(duration.min(max)); + } + + // Try HTTP-date format + if let Ok(datetime) = httpdate::parse_http_date(s) { + let now = std::time::SystemTime::now(); + match datetime.duration_since(now) { + Ok(duration) => { + tracing::debug!(?duration, "Parsed Retry-After HTTP-date"); + return Some(duration.min(max)); + } + Err(_) => { + tracing::debug!("Retry-After HTTP-date is in the past"); + return Some(Duration::ZERO); + } + } + } + + tracing::debug!(%s, "Failed to parse Retry-After header"); + None +} + +/// The grpc-retry-pushback-ms header/trailer name. +const GRPC_RETRY_PUSHBACK_MS: &str = "grpc-retry-pushback-ms"; + +/// Parse grpc-retry-pushback-ms from headers or trailers. +/// +/// Per gRPC A6 spec: +/// - Positive i64: retry after this many milliseconds +/// - Negative i64: do not retry (returns `None`) +/// +/// This function does **not** check grpc-status; that is a classification +/// concern left to the caller. +/// +/// Returns `None` for: +/// - Missing header/trailer +/// - Negative values (interpreted as "do not retry") +/// - Invalid formats +/// +/// The returned duration is capped at `max` to prevent abuse. +pub fn parse_grpc_retry_pushback(headers: &HeaderMap, max: Duration) -> Option { + let value = headers.get(GRPC_RETRY_PUSHBACK_MS)?; + let s = value.to_str().ok()?; + + // Parse as i64 to handle potential negative values + let ms: i64 = match s.trim().parse() { + Ok(v) => v, + Err(_) => { + tracing::debug!(%s, "Failed to parse grpc-retry-pushback-ms"); + return None; + } + }; + + // Negative values mean "do not retry". + if ms < 0 { + tracing::debug!(ms, "Ignoring negative grpc-retry-pushback-ms"); + return None; + } + + let duration = Duration::from_millis(ms as u64); + tracing::debug!(?duration, "Parsed grpc-retry-pushback-ms"); + Some(duration.min(max)) +} + +#[cfg(test)] +mod tests { + use super::*; + use http::header::RETRY_AFTER; + use http::HeaderValue; + use std::time::SystemTime; + + const MAX: Duration = Duration::from_secs(300); + + // === Retry-After tests === + + #[test] + fn parse_delay_seconds() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("120")); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn parse_delay_seconds_zero() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("0")); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::ZERO)); + } + + #[test] + fn caps_at_max() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("3600")); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(MAX)); + } + + #[test] + fn parses_503() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("120")); + + let result = parse_retry_after(StatusCode::SERVICE_UNAVAILABLE, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn ignores_other_status() { + let mut headers = HeaderMap::new(); + headers.insert(http::header::RETRY_AFTER, HeaderValue::from_static("120")); + + let result = parse_retry_after(StatusCode::OK, &headers, MAX); + assert_eq!(result, None); + } + + #[test] + fn ignores_missing_header() { + let headers = HeaderMap::new(); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, None); + } + + #[test] + fn ignores_invalid_value() { + let mut headers = HeaderMap::new(); + headers.insert( + http::header::RETRY_AFTER, + HeaderValue::from_static("not-a-number"), + ); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, None); + } + + #[test] + fn retry_after_http_date_in_past() { + let mut headers = HeaderMap::new(); + // Use a date in the past + headers.insert( + RETRY_AFTER, + "Wed, 01 Jan 2020 00:00:00 GMT".parse().unwrap(), + ); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::ZERO)); + } + + #[test] + fn retry_after_http_date_in_future() { + let target = SystemTime::now() + Duration::from_secs(60); + let date_str = httpdate::fmt_http_date(target); + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, date_str.parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + let dur = result.expect("should parse future HTTP-date"); + + // httpdate truncates sub-seconds (1s resolution), and we can have + // some differences for clock time under CI load or NTP adjustment + // so test parsing correctness within a wide enough time window. + assert!( + dur >= Duration::from_secs(55) && dur <= Duration::from_secs(65), + "expected ~60s, got {:?}", + dur, + ); + } + + #[test] + fn retry_after_http_date_caps_at_max() { + let target = SystemTime::now() + Duration::from_secs(600); + let date_str = httpdate::fmt_http_date(target); + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, date_str.parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(MAX)); + } + + // === gRPC pushback tests === + + const MAX_GRPC: Duration = Duration::from_millis(300_000); + + #[test] + fn parse_grpc_pushback_positive() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("5000")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, Some(Duration::from_millis(5000))); + } + + #[test] + fn parse_grpc_pushback_zero() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("0")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, Some(Duration::ZERO)); + } + + #[test] + fn parse_grpc_pushback_negative() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("-1")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, None); + } + + #[test] + fn parse_grpc_pushback_caps_at_max() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", HeaderValue::from_static("999999")); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, Some(MAX_GRPC)); + } + + #[test] + fn parse_grpc_pushback_missing() { + let headers = HeaderMap::new(); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, None); + } + + #[test] + fn parse_grpc_pushback_invalid() { + let mut headers = HeaderMap::new(); + headers.insert( + "grpc-retry-pushback-ms", + HeaderValue::from_static("not-a-number"), + ); + + let result = parse_grpc_retry_pushback(&headers, MAX_GRPC); + assert_eq!(result, None); + } + + // === Whitespace handling tests === + + #[test] + fn retry_after_trailing_whitespace() { + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, "120 ".parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn retry_after_leading_whitespace() { + let mut headers = HeaderMap::new(); + headers.insert(RETRY_AFTER, " 120".parse().unwrap()); + + let result = parse_retry_after(StatusCode::TOO_MANY_REQUESTS, &headers, MAX); + assert_eq!(result, Some(Duration::from_secs(120))); + } + + #[test] + fn grpc_pushback_whitespace() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-retry-pushback-ms", " 5000 ".parse().unwrap()); + + let result = parse_grpc_retry_pushback(&headers, MAX); + assert_eq!(result, Some(Duration::from_millis(5000))); + } +} From 8a7fbb2837ee1e4fb6a21c6edf7d6322af8e05e6 Mon Sep 17 00:00:00 2001 From: Alejandro Martinez Ruiz Date: Thu, 21 May 2026 20:07:52 +0200 Subject: [PATCH 4/5] feat(load-biaser): add load biasing crate with RTT tracking and failure penalties Introduce the linkerd-load-biaser crate, which wraps any tower::Service to provide per-endpoint load metrics for P2C balancing. The crate tracks request latency via EWMA and injects penalties when failure responses are detected, steering traffic away from unhealthy endpoints. Penalty injection covers HTTP 429/503/5xx and gRPC RESOURCE_EXHAUSTED/UNAVAILABLE trailers-only responses (not streaming gRPC failures since we can only access headers here). For responses with backoff hints, Retry-After on HTTP 429/503 or grpc-retry-pushback-ms on gRPC trailers-only errors, the penalty is amplified so that the EWMA value remains meaningful through the server-requested backoff window. The amplification is clamped to prevent infinity from permanently disabling the endpoint. The load metric is computed as `max(rtt * (pending + 1), penalty)`, where `rtt` is the peak-EWMA latency, and `pending` is the number of in-flight requests. This is returned via tower::load::Load for direct P2C integration. The load biaser is disabled by default, preserving RTT-only behavior (PeakEwma equivalent), unless explicitly activated. Signed-off-by: Alejandro Martinez Ruiz --- Cargo.lock | 18 + Cargo.toml | 1 + linkerd/load-biaser/Cargo.toml | 29 ++ linkerd/load-biaser/src/lib.rs | 597 +++++++++++++++++++++++++++++++++ 4 files changed, 645 insertions(+) create mode 100644 linkerd/load-biaser/Cargo.toml create mode 100644 linkerd/load-biaser/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 278b7031f7..743cfe1206 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1987,6 +1987,24 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "linkerd-load-biaser" +version = "0.1.0" +dependencies = [ + "futures", + "http", + "linkerd-ewma", + "linkerd-http-classify", + "linkerd-stack", + "parking_lot", + "pin-project", + "tokio", + "tokio-test", + "tower", + "tower-service", + "tracing", +] + [[package]] name = "linkerd-meshtls" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f10ee0c3f6..bc2dfc2e95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ members = [ "linkerd/identity", "linkerd/idle-cache", "linkerd/io", + "linkerd/load-biaser", "linkerd/meshtls", "linkerd/meshtls/verifier", "linkerd/metrics", diff --git a/linkerd/load-biaser/Cargo.toml b/linkerd/load-biaser/Cargo.toml new file mode 100644 index 0000000000..a8fa0fb29a --- /dev/null +++ b/linkerd/load-biaser/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "linkerd-load-biaser" +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +edition = { workspace = true } +publish = { workspace = true } + +[features] +default = [] +tokio-test = ["dep:tokio-test"] + +[dependencies] +linkerd-ewma = { path = "../ewma" } +futures = { version = "0.3", default-features = false } +http = { workspace = true } +linkerd-http-classify = { path = "../http/classify" } +linkerd-stack = { path = "../stack" } +parking_lot = "0.12" +pin-project = "1" +tokio = { version = "1", features = ["io-util", "net", "time"] } +tokio-test = { version = "0.4", optional = true } +tower = { workspace = true, features = ["load"] } +tower-service = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt", "time"] } +tokio-test = "0.4" diff --git a/linkerd/load-biaser/src/lib.rs b/linkerd/load-biaser/src/lib.rs new file mode 100644 index 0000000000..7254ca67ef --- /dev/null +++ b/linkerd/load-biaser/src/lib.rs @@ -0,0 +1,597 @@ +//! Load tracking with response failure awareness. +//! +//! This module provides a `LoadBiaser` wrapper that tracks request latency (RTT) +//! and detects failure responses (HTTP 429, 503, 5xx), applying artificial +//! penalties. Unlike Tower's `PeakEwma`, this implementation uses +//! `linkerd_ewma::Ewma` and returns `f64` metrics directly, enabling +//! integration with P2C load balancing. +//! +//! This can wrap any service (`Load` trait not required) and it tracks RTT via +//! EWMA, which is updated when responses complete, and pending requests for load +//! calculation. +//! +//! When a failure response is detected, a penalty is applied and the EWMA jumps +//! to a high value. The load is calculated as `max(rtt * (pending + 1), penalty)` +//! (so at least RTT with no pending requests). +//! +//! Both RTT and penalty decay over time via the EWMA. +//! +//! For this type to work responses must implement the `ResponseFailureHint` +//! trait, which classifies responses into failure categories (rate-limited, +//! service unavailable, internal error). For non-HTTP responses the default +//! implementation returns no failure hint, so only RTT tracking occurs. +//! +//! For gRPC, `failure_hint()` detects `RESOURCE_EXHAUSTED` (code 8) and +//! `UNAVAILABLE` (code 14) from the `grpc-status` header in trailers-only +//! (unary) responses. Streaming gRPC puts `grpc-status` in trailers which +//! are not visible at response-head time; the circuit breaker handles +//! streaming gRPC failures via `GrpcRetryPushbackClassifyEos`. + +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +use futures::ready; +use linkerd_ewma::{Ewma, MIN_DECAY}; +use linkerd_stack::NewService; +use parking_lot::RwLock; +use pin_project::pin_project; +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::Instant; +use tower::load::Load; +use tower_service::Service; + +/// Default maximum duration for Retry-After hints. +pub const DEFAULT_RETRY_AFTER_MAX_DURATION: Duration = Duration::from_secs(300); + +/// Classification of response failures for load biasing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FailureHint { + /// HTTP 429 Too Many Requests + RateLimited, + /// HTTP 503 Service Unavailable + ServiceUnavailable, + /// HTTP 5xx (other than 503) + InternalError, +} + +/// Cached Retry-After hint stored in HTTP response extensions. +/// +/// Stores the **uncapped** parsed value. Each consumer applies its own cap +/// via `rate_limit_hint(max)`, so different callers (e.g. load biaser vs +/// circuit breaker) can use different maximums from the same cached value. +#[derive(Clone, Copy, Debug)] +pub struct CachedRateLimitHint(pub Duration); + +/// Trait for extracting failure hints from responses. +/// +/// This allows the load biaser to classify responses and apply appropriate +/// penalties. Default implementations return None (no failure detected), +/// which is appropriate for non-HTTP transports. +/// +/// The trait splits rate limit hint access into two methods to avoid +/// requiring `&mut self` on the read path: +/// - `attach_parsed_rate_limit_hint(&mut self, max)`: parse and cache (needs `&mut`) +/// - `rate_limit_hint(&self, max)`: read cached value or parse on-read (only needs `&self`) +/// +/// # Stack ordering and caching +/// +/// In the current proxy stack `RetryAfterClassify::start()` (the circuit breaker's +/// classifier) calls `rate_limit_hint()` before `LoadBiaserFuture::poll()` calls +/// `attach_parsed_rate_limit_hint()`, because responses flow inner-to-outer. This +/// means the cache is always cold for the circuit breaker path. +pub trait ResponseFailureHint { + /// Returns a failure hint if the response indicates a failure condition. + fn failure_hint(&self) -> Option { + None + } + + /// Parse and cache the raw (uncapped) rate limit hint from this response. + /// + /// The `_max` parameter is accepted for API symmetry with `rate_limit_hint(max)` + /// but is intentionally unused. The raw uncapped value is cached so that each + /// consumer can apply their own cap via `rate_limit_hint(max)`. + fn attach_parsed_rate_limit_hint(&mut self, _max: Duration) {} + + /// Returns the rate limit hint if available. + /// + /// Checks cached value first (from a previous `attach_parsed_rate_limit_hint` call). + /// If no cached value, attempts to parse the header directly (capping at `max`). + /// Returns `None` only if the header is absent or unparseable. + fn rate_limit_hint(&self, _max: Duration) -> Option { + None + } +} + +/// HTTP responses classify failures by status code and parse Retry-After hints. +/// +/// For gRPC responses (which arrive as HTTP 200 with `grpc-status` in headers +/// for trailers-only/unary errors), the `grpc-status` header is checked when +/// the HTTP status is 200: +/// - gRPC status 8 (RESOURCE_EXHAUSTED) -> `RateLimited` +/// - gRPC status 14 (UNAVAILABLE) -> `ServiceUnavailable` +/// - Other non-zero gRPC status codes -> `InternalError` +impl ResponseFailureHint for http::Response { + fn failure_hint(&self) -> Option { + let status = self.status(); + if status == http::StatusCode::TOO_MANY_REQUESTS { + Some(FailureHint::RateLimited) + } else if status == http::StatusCode::SERVICE_UNAVAILABLE { + Some(FailureHint::ServiceUnavailable) + } else if status.is_server_error() { + Some(FailureHint::InternalError) + } else if status == http::StatusCode::OK { + // gRPC trailers-only responses: grpc-status appears in headers. + // Note: for streaming gRPC, grpc-status is in trailers (not headers) + // and will not be detected here. The circuit breaker handles streaming + // gRPC failures via GrpcRetryPushbackClassifyEos. + self.headers() + .get("grpc-status") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .and_then(|code| match code { + 0 => None, // OK + 8 => Some(FailureHint::RateLimited), // RESOURCE_EXHAUSTED + 14 => Some(FailureHint::ServiceUnavailable), // UNAVAILABLE + _ => Some(FailureHint::InternalError), // Other non-zero + }) + } else { + None + } + } + + fn attach_parsed_rate_limit_hint(&mut self, _max: Duration) { + // Store the uncapped value. Each consumer applies their own cap via + // rate_limit_hint(max). + if let Some(d) = linkerd_http_classify::retry_after::parse_retry_after( + self.status(), + self.headers(), + Duration::MAX, + ) { + self.extensions_mut().insert(CachedRateLimitHint(d)); + return; + } + // Try gRPC retry-pushback-ms (for trailers-only responses) + if self.status() == http::StatusCode::OK { + if let Some(d) = linkerd_http_classify::retry_after::parse_grpc_retry_pushback( + self.headers(), + Duration::MAX, + ) { + self.extensions_mut().insert(CachedRateLimitHint(d)); + } + } + } + + fn rate_limit_hint(&self, max: Duration) -> Option { + // Check cache first (from previous attach call), apply caller's cap + if let Some(cached) = self.extensions().get::() { + return Some(cached.0.min(max)); + } + // Parse on-read as fallback (header present but attach wasn't called) + if let Some(d) = linkerd_http_classify::retry_after::parse_retry_after( + self.status(), + self.headers(), + max, + ) { + return Some(d); + } + // Try gRPC pushback + if self.status() == http::StatusCode::OK { + if let Some(d) = + linkerd_http_classify::retry_after::parse_grpc_retry_pushback(self.headers(), max) + { + return Some(d); + } + } + // No header or unparseable + None + } +} + +/// Connection tuples (Connection, Metadata) used by TCP/TLS paths never indicate failures. +/// This is meant for MakeConnection services that return `(I, M)` tuples. +impl ResponseFailureHint for (C, M) {} + +/// TCP streams never indicate failures (no HTTP status codes). +impl ResponseFailureHint for tokio::net::TcpStream {} + +/// Duplex streams (used in testing) never indicate failures. +impl ResponseFailureHint for tokio::io::DuplexStream {} + +/// Mock IO streams never indicate failures. +#[cfg(feature = "tokio-test")] +impl ResponseFailureHint for tokio_test::io::Mock {} + +/// Amplification factor for Retry-After penalties. +/// +/// When a 429 or 503 response includes a Retry-After header, the load biaser injects +/// an amplified penalty so it remains meaningful through the server's requested +/// avoidance window. The injected value is: +/// +/// `penalty_secs * RETRY_AFTER_PENALTY_FACTOR * exp(retry_after / penalty_decay)` +/// +/// which decays via EWMA to `penalty_secs * RETRY_AFTER_PENALTY_FACTOR` at +/// exactly `t = retry_after`. The factor controls how aggressively the endpoint +/// is avoided near the Retry-After deadline: +/// +/// At 1.0 the endpoint is fully avoided (~0% traffic) through the window and 46s +/// beyond -- too aggressive for early-recovery discovery. At 0.1, traffic leaks +/// well before the deadline. We use 0.5: the penalty at deadline is penalty_secs/2 +/// (roughly 50x healthy load in a 5-endpoint pool), which lets occasional probes +/// through in the second half while still exceeding a plain failure penalty for +/// RA >= ~7s. For pools with 100+ endpoints the factor barely matters because P2C +/// random pair selection already makes the penalized endpoint unlikely to be drawn. +/// +/// Sending occasional probe traffic during the window is valuable: the endpoint +/// may recover early, or a fresh 429 or 503 with a different RA provides updated +/// information. The circuit breaker (when enabled) provides a strict backoff +/// window, ie. `max(backoff, retry_after)`. +const RETRY_AFTER_PENALTY_FACTOR: f64 = 0.5; + +/// Configuration for LoadBiaser behavior. +#[derive(Clone, Debug)] +pub struct LoadBiaserConfig { + /// Default RTT to use when no measurements are available + pub default_rtt: Duration, + + /// Decay duration for the RTT EWMA. + /// Controls how quickly RTT estimates adapt to changing latency + pub rtt_decay: Duration, + + /// The penalty value to inject on failure responses (429, 503, 5xx) in seconds + pub penalty_secs: f64, + + /// Decay duration for the penalty EWMA. + /// Controls how quickly the penalty decays after a failure response + pub penalty_decay: Duration, + + /// Whether load biasing penalties are enabled. When false, only RTT tracking + /// is active (PeakEwma equivalent). + pub enabled: bool, + + /// Maximum Retry-After duration to honor. Clamped to this value. + pub max_duration: Duration, +} + +impl Default for LoadBiaserConfig { + fn default() -> Self { + Self { + default_rtt: Duration::from_secs(1), + rtt_decay: Duration::from_secs(10), + // 5 second penalty on failure responses (429, 503, 5xx) + penalty_secs: 5.0, + // 10 second decay for penalty - mostly gone after ~30 seconds + penalty_decay: Duration::from_secs(10), + enabled: false, + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + } + } +} + +/// Shared per-endpoint state behind a single `Arc`. +/// +/// Combines the EWMA trackers (under a mutex for atomic RTT+penalty reads), +/// the in-flight request counter, and the immutable config fields that every +/// response future needs. One allocation per endpoint instead of two, and +/// futures carry a single `Arc` clone instead of copying config fields. +#[derive(Debug)] +struct SharedState { + /// RTT and penalty EWMAs; read-locked in load(), write-locked in poll() + metrics: RwLock, + /// Count of in-flight requests (up on call, down on response) + pending: AtomicU32, + /// Penalty value to inject on failure responses (in seconds) + penalty_secs: f64, + /// Whether penalty injection is enabled + enabled: bool, + /// Maximum Retry-After duration to honor (clamped) + max_duration: Duration, + /// Decay duration for the penalty EWMA (used to amplify Retry-After penalties) + penalty_decay: Duration, +} + +#[derive(Debug)] +struct LoadMetrics { + /// EWMA RTT tracking, updated on each response + rtt: Ewma, + /// EWMA penalty tracking, updated on failure responses (429, 503, 5xx) + penalty: Ewma, +} + +/// A service wrapper that tracks RTT and biases load metrics based on failure responses. +/// +/// `LoadBiaser` provides load metrics for P2C load balancing by tracking request latency +/// (RTT) via EWMA, in-flight requests, and by injecting penalties when failure responses +/// (429, 503, 5xx) are detected. +/// +/// The `load()` method returns `max(rtt * (pending + 1), penalty)`, causing P2C to +/// prefer endpoints with lower latency and fewer in-flight requests, while avoiding +/// rate-limited endpoints. +#[derive(Debug)] +pub struct LoadBiaser { + inner: S, + /// Per-endpoint metrics, pending counter, and penalty config + shared: Arc, + config: LoadBiaserConfig, +} + +/// A `NewService` implementation that creates `LoadBiaser` wrappers. +#[derive(Debug)] +pub struct NewLoadBiaser { + inner: N, + config: LoadBiaserConfig, + _marker: PhantomData, +} + +/// Response future that tracks RTT and checks for failure responses. +/// +/// When the inner future completes we store the RTT based on elapsed time since the +/// request started, decrement the pending counter, and if the response indicates a +/// failure (using the `ResponseFailureHint` trait) we inject a penalty. +#[pin_project(PinnedDrop)] +pub struct LoadBiaserFuture { + #[pin] + inner: F, + /// Request start instant for RTT calculation + start: Instant, + /// Shared endpoint state (metrics, pending counter, penalty config) + shared: Arc, + /// Whether we've already decremented pending + completed: bool, + /// Marker for the response type (`ResponseFailureHint` bound) + _response: PhantomData Rsp>, +} + +impl LoadBiaser { + /// Creates a new `LoadBiaser` wrapping the given service. + pub fn new(inner: S, mut config: LoadBiaserConfig) -> Self { + if config.penalty_secs.is_nan() || config.penalty_secs < 0.0 { + tracing::warn!( + penalty_secs = config.penalty_secs, + "penalty_secs is NaN or negative, clamping to 0.0" + ); + config.penalty_secs = 0.0; + } + if config.penalty_decay < MIN_DECAY { + tracing::warn!( + penalty_decay = ?config.penalty_decay, + min = ?MIN_DECAY, + "penalty_decay below minimum, will be clamped by EWMA constructor" + ); + } + if config.rtt_decay < MIN_DECAY { + tracing::warn!( + rtt_decay = ?config.rtt_decay, + min = ?MIN_DECAY, + "rtt_decay below minimum, will be clamped by EWMA constructor" + ); + } + let now = Instant::now(); + let shared = Arc::new(SharedState { + metrics: RwLock::new(LoadMetrics { + // Initialize RTT with default_rtt (not INFINITY) so that fresh + // endpoints have comparable load to unmeasured ones. This matches + // Tower's PeakEwma behavior and prevents P2C from permanently + // preferring endpoints whose first request happened to be fast. + rtt: Ewma::new_with_value(config.rtt_decay, now, config.default_rtt.as_secs_f64()), + penalty: Ewma::new(config.penalty_decay, now), + }), + pending: AtomicU32::new(0), + penalty_secs: config.penalty_secs, + enabled: config.enabled, + max_duration: config.max_duration, + penalty_decay: config.penalty_decay, + }); + Self { + inner, + shared, + config, + } + } + + /// Returns a reference to the inner service. + pub fn get_ref(&self) -> &S { + &self.inner + } +} + +impl Clone for LoadBiaser { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + shared: self.shared.clone(), + config: self.config.clone(), + } + } +} + +impl Load for LoadBiaser { + type Metric = f64; + + fn load(&self) -> Self::Metric { + let pending = self.shared.pending.load(Ordering::Acquire); + let now = Instant::now(); + + let (rtt, penalty_val) = { + let metrics = self.shared.metrics.read(); + (metrics.rtt.get_at(now), metrics.penalty.get_at(now)) + }; + + let penalty = if penalty_val.is_infinite() { + 0.0 + } else { + penalty_val + }; + + // Load = RTT * (pending + 1), minimum will be `penalty` + // The +1 ensures idle endpoints have some load based on RTT + let base = rtt * f64::from(pending.saturating_add(1)); + let load = f64::max(base, penalty); + + tracing::trace!( + rtt_secs = rtt, + pending = pending, + penalty_secs = penalty, + load = load, + "LoadBiaser::load" + ); + + load + } +} + +impl Service for LoadBiaser +where + S: Service, + S::Response: ResponseFailureHint, +{ + type Response = S::Response; + type Error = S::Error; + type Future = LoadBiaserFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Req) -> Self::Future { + let prev = self.shared.pending.fetch_add(1, Ordering::AcqRel); + debug_assert!(prev < u32::MAX, "pending counter overflow"); + + LoadBiaserFuture { + inner: self.inner.call(req), + start: Instant::now(), + shared: self.shared.clone(), + completed: false, + _response: PhantomData, + } + } +} + +impl NewLoadBiaser { + /// Creates a new `NewLoadBiaser` with the given configuration. + pub fn new(config: LoadBiaserConfig, inner: N) -> Self { + Self { + inner, + config, + _marker: PhantomData, + } + } +} + +impl Clone for NewLoadBiaser { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + config: self.config.clone(), + _marker: PhantomData, + } + } +} + +impl NewService for NewLoadBiaser +where + N: NewService, +{ + type Service = LoadBiaser; + + fn new_service(&self, target: T) -> LoadBiaser { + LoadBiaser::new(self.inner.new_service(target), self.config.clone()) + } +} + +impl Future for LoadBiaserFuture +where + F: Future>, + Rsp: ResponseFailureHint, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let mut result = ready!(this.inner.poll(cx)); + let shared = &**this.shared; + + let now = Instant::now(); + let elapsed = now.saturating_duration_since(*this.start).as_secs_f64(); + + // Parse rate limit hint while we have &mut access to the response (when load + // biasing is enabled). + if shared.enabled { + if let Ok(ref mut resp) = result { + resp.attach_parsed_rate_limit_hint(shared.max_duration); + } + } + + { + let mut metrics = shared.metrics.write(); + if let Ok(ref resp) = result { + // Update RTT on all HTTP responses (including 429, 503, 5xx). + // Only transport-level errors (connection refused, resets) skip + // this update, because broken endpoints fail fast and would bias + // P2C toward sending more traffic to the broken endpoint. + metrics.rtt.add_peak(elapsed, now); + + if shared.enabled { + if let Some(hint) = resp.failure_hint() { + let base_penalty = shared.penalty_secs; + + // For rate-limited and service-unavailable responses, amplify + // the penalty using the Retry-After hint so it remains meaningful + // through the server's requested avoidance window. + let penalty_val = match hint { + FailureHint::RateLimited | FailureHint::ServiceUnavailable => { + match resp.rate_limit_hint(shared.max_duration) { + Some(ra) if ra.as_secs_f64() > 0.0 => { + let decay_secs = + shared.penalty_decay.max(MIN_DECAY).as_secs_f64(); + let amplified = base_penalty + * RETRY_AFTER_PENALTY_FACTOR + * (ra.as_secs_f64() / decay_secs).exp(); + amplified.min(1e12) + } + _ => base_penalty, + } + } + FailureHint::InternalError => base_penalty, + }; + + tracing::debug!( + penalty_secs = penalty_val, + rtt_secs = elapsed, + ?hint, + "Detected failure response - injecting load penalty" + ); + metrics.penalty.add_peak(penalty_val, now); + } + } + } + } + + shared.pending.fetch_sub(1, Ordering::Release); + *this.completed = true; + + Poll::Ready(result) + } +} + +#[pin_project::pinned_drop] +impl PinnedDrop for LoadBiaserFuture { + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + // Only decrement if we haven't already done so in poll(). + // Avoids leaking the pending count upon cancellation. + if !*this.completed { + this.shared.pending.fetch_sub(1, Ordering::Release); + } + } +} + From c40a0af477be2cbd1bc330c356b8aa9cea47e302 Mon Sep 17 00:00:00 2001 From: Alejandro Martinez Ruiz Date: Thu, 21 May 2026 20:08:06 +0200 Subject: [PATCH 5/5] test(load-biaser): add unit tests for load biasing lifecycle These cover the complete load biasing lifecycle, including penalty injection, hint parsing, cancellation safety via PinnedDrop, and backwards-compatible behavior when disabled (ie. RTT-only behavior equivalent to PeakEwma). Signed-off-by: Alejandro Martinez Ruiz --- linkerd/load-biaser/src/lib.rs | 985 +++++++++++++++++++++++++++++++++ 1 file changed, 985 insertions(+) diff --git a/linkerd/load-biaser/src/lib.rs b/linkerd/load-biaser/src/lib.rs index 7254ca67ef..c90848c62c 100644 --- a/linkerd/load-biaser/src/lib.rs +++ b/linkerd/load-biaser/src/lib.rs @@ -595,3 +595,988 @@ impl PinnedDrop for LoadBiaserFuture { } } +#[cfg(test)] +mod tests { + use super::*; + use std::convert::Infallible; + use tokio::time; + + impl LoadBiaser { + pub fn get_rtt(&self) -> f64 { + self.shared.metrics.read().rtt.get() + } + + pub fn get_penalty(&self) -> f64 { + self.shared.metrics.read().penalty.get() + } + + pub fn get_pending(&self) -> u32 { + self.shared.pending.load(Ordering::Acquire) + } + + pub fn inject_penalty(&self, penalty_secs: f64) { + self.shared + .metrics + .write() + .penalty + .add_peak(penalty_secs, Instant::now()); + } + + pub fn inject_rtt(&self, rtt_secs: f64) { + self.shared + .metrics + .write() + .rtt + .add_peak(rtt_secs, Instant::now()); + } + } + + // Mock service for testing returning a specific HTTP status. + #[derive(Clone)] + struct MockService { + status: http::StatusCode, + } + + impl MockService { + fn new(status: http::StatusCode) -> Self { + Self { status } + } + } + + impl Service<()> for MockService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + let resp = http::Response::builder() + .status(self.status) + .body("test") + .unwrap(); + futures::future::ready(Ok(resp)) + } + } + + // Mock service that always returns an error. + #[derive(Clone)] + struct ErrorService; + + impl Service<()> for ErrorService { + type Response = http::Response<&'static str>; + type Error = &'static str; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + futures::future::ready(Err("connection refused")) + } + } + + fn test_config() -> LoadBiaserConfig { + LoadBiaserConfig { + default_rtt: Duration::from_millis(100), // 0.1s default + rtt_decay: Duration::from_secs(10), + penalty_secs: 5.0, + penalty_decay: Duration::from_secs(10), + enabled: true, + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_load_uses_default_rtt_initially() { + let inner = MockService::new(http::StatusCode::OK); + let biaser = LoadBiaser::new(inner, test_config()); + + // Initial load should be default_rtt * (pending + 1) = 0.1 * 1 = 0.1 + let load = biaser.load(); + assert!( + (load - 0.1).abs() < 0.001, + "initial load should be ~0.1 (default RTT): {}", + load + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_pending_increases_load() { + let inner = MockService::new(http::StatusCode::OK); + let biaser = LoadBiaser::new(inner, test_config()); + + // Inject a known RTT + time::sleep(Duration::from_millis(1)).await; + biaser.inject_rtt(0.05); // 50ms + + // RTT * (0 + 1) = 0.05 + let load_idle = biaser.load(); + + // Start a request (not awaited yet) + // Increment pending to simulate in-flight requests + biaser.shared.pending.fetch_add(1, Ordering::AcqRel); + + // RTT * (1 + 1) = 0.05 * 2 = 0.1 + let load_one_pending = biaser.load(); + + // Increment again + biaser.shared.pending.fetch_add(1, Ordering::AcqRel); + + // RTT * (2 + 1) = 0.05 * 3 = 0.15 + let load_two_pending = biaser.load(); + + assert!( + load_one_pending > load_idle, + "load should increase with pending: {} > {}", + load_one_pending, + load_idle + ); + assert!( + load_two_pending > load_one_pending, + "load should increase more: {} > {}", + load_two_pending, + load_one_pending + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_rtt_tracked_after_request() { + let inner = MockService::new(http::StatusCode::OK); + let mut biaser = LoadBiaser::new(inner, test_config()); + + // Advance time so EWMA accepts updates + time::sleep(Duration::from_millis(1)).await; + + // RTT should start at default_rtt (0.1s) before any requests + let initial_rtt = biaser.get_rtt(); + assert!( + (initial_rtt - 0.1).abs() < 0.01, + "RTT should start at default_rtt (0.1s), got: {initial_rtt}" + ); + + // Make a request (will record RTT) + let _ = biaser.call(()).await; + + // RTT should now reflect the actual request latency + let rtt = biaser.get_rtt(); + assert!( + rtt < initial_rtt, + "RTT should decrease after a fast request" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_injects_penalty() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Penalty should be infinite (none) before we get requests. + assert!(biaser.get_penalty().is_infinite()); + + // Make a request that returns 429. + let _ = biaser.call(()).await; + + // A penalty should be injected (5 seconds) after seeing a 429 + let penalty = biaser.get_penalty(); + assert!( + (penalty - 5.0).abs() < 0.1, + "penalty should be ~5s after 429: {}", + penalty + ); + + // Load should be at least the penalty. + let load = biaser.load(); + assert!(load >= 4.9, "load should be high after 429: {}", load); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_200_does_not_inject_penalty() { + let inner = MockService::new(http::StatusCode::OK); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 200. + let _ = biaser.call(()).await; + + // Penalty should still be infinite (none). + assert!( + biaser.get_penalty().is_infinite(), + "penalty should not be injected for 200" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_penalty_decays_over_time() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + // Use a custom config with default_rtt=1ms (0.001s) so load() tracks penalty directly. + // inject_rtt(0.001) would be a no-op here: add_peak() only replaces when the new + // value exceeds the decayed current value, so 0.001 < 0.1 (test_config default) + // would fail the peak check and fall through to add() which no-ops at ts==self.timestamp. + let config = LoadBiaserConfig { + default_rtt: Duration::from_millis(1), + ..test_config() + }; + let mut biaser = LoadBiaser::new(inner, config); + + time::sleep(Duration::from_millis(1)).await; + + // Trigger penalty via 429 response + let _ = biaser.call(()).await; + + // After call completes, pending is back to 0 (LoadBiaserFuture::poll decrements + // before returning Poll::Ready), so load = max(rtt * 1, penalty). + assert_eq!( + biaser.get_pending(), + 0, + "pending should be 0 after call completes" + ); + + // load() calls get_at(now) which projects the decayed penalty. + // Immediately after injection: penalty ~5.0, rtt ~0.001, so load ~5.0 + let load_before = biaser.load(); + assert!( + load_before > 4.0, + "load before decay should be dominated by the injected 429 penalty: {}", + load_before + ); + + // Advance time by one decay period (10s). + // Penalty decays: 5.0 * e^(-10/10) ~ 1.839 + time::sleep(Duration::from_secs(10)).await; + + // load() projects the decayed penalty at the new timestamp. + // No EWMA mutation needed since this is pure projection. + let load_after = biaser.load(); + assert!( + load_after > 1.0, + "load after one decay period should still be dominated by decayed penalty: {}", + load_after + ); + assert!( + load_after < 2.5, + "load after one decay period should reflect substantial decay (5.0 * e^-1 ~ 1.839): {}", + load_after + ); + assert!( + load_after < load_before, + "penalty should decay over time as observed through load(): {} < {}", + load_after, + load_before + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_load_is_max_of_rtt_based_and_penalty() { + let inner = MockService::new(http::StatusCode::OK); + let biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Inject a high RTT (10 seconds) + biaser.inject_rtt(10.0); + + // Inject a lower penalty (1 second) + biaser.inject_penalty(1.0); + + // Load should be RTT-based since it's higher: 10 * 1 = 10 + let load = biaser.load(); + assert!( + (load - 10.0).abs() < 0.1, + "load should be RTT-based when higher: {}", + load + ); + + // Now inject a very high penalty + biaser.inject_penalty(20.0); + + // Load should be penalty since it's higher + let load = biaser.load(); + assert!( + (load - 20.0).abs() < 0.1, + "load should be penalty when higher: {}", + load + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_clone_shares_state() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + let mut biaser1 = LoadBiaser::new(inner.clone(), test_config()); + let biaser2 = biaser1.clone(); + + time::sleep(Duration::from_millis(1)).await; + + // Trigger penalty on biaser1 + let _ = biaser1.call(()).await; + + // biaser2 should see the same penalty (shared state) + assert_eq!(biaser1.get_penalty(), biaser2.get_penalty()); + assert_eq!(biaser1.load(), biaser2.load()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_pending_decremented_on_completion() { + let inner = MockService::new(http::StatusCode::OK); + let mut biaser = LoadBiaser::new(inner, test_config()); + + assert_eq!(biaser.get_pending(), 0, "pending should start at 0"); + + // Start a request, pending increments + let fut = biaser.call(()); + assert_eq!( + biaser.get_pending(), + 1, + "pending should be 1 during request" + ); + + // Complete the request, pending decrements + let _ = fut.await; + assert_eq!( + biaser.get_pending(), + 0, + "pending should be 0 after completion" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_rtt_not_updated_on_error() { + let inner = ErrorService; + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // RTT should start at default_rtt before any requests + let initial_rtt = biaser.get_rtt(); + assert!( + (initial_rtt - 0.1).abs() < 0.01, + "RTT should start at default_rtt (0.1s), got: {initial_rtt}" + ); + + // Make a request that returns an error + let _ = biaser.call(()).await; + + // RTT should remain at default_rtt because error responses don't + // update RTT. This prevents P2C from routing more traffic to broken + // endpoints that fail fast (appearing "fast" to the load metric). + let rtt_after_error = biaser.get_rtt(); + assert!( + (rtt_after_error - initial_rtt).abs() < 0.01, + "RTT should not change on error responses: {rtt_after_error} vs {initial_rtt}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_pending_decremented_on_cancellation() { + use tokio::sync::oneshot; + + // Build a service whose response future blocks on a oneshot receiver. + // This lets us drop the future mid-flight to test PinnedDrop. + struct DelayedService { + rx: Option>>, + } + + impl Service<()> for DelayedService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = DelayedFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + DelayedFuture { + rx: self.rx.take().expect("called more than once"), + } + } + } + + struct DelayedFuture { + rx: oneshot::Receiver>, + } + + impl std::future::Future for DelayedFuture { + type Output = Result, Infallible>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Ready(Ok(resp)) => Poll::Ready(Ok(resp)), + Poll::Ready(Err(_)) => { + // Sender dropped. Return a default response + let resp = http::Response::builder() + .status(http::StatusCode::OK) + .body("cancelled") + .unwrap(); + + Poll::Ready(Ok(resp)) + } + Poll::Pending => Poll::Pending, + } + } + } + + let (tx, rx) = oneshot::channel(); + let inner = DelayedService { rx: Some(rx) }; + let mut biaser = LoadBiaser::new(inner, test_config()); + + assert_eq!(biaser.get_pending(), 0); + + // Initiate the request, pending increments in Service::call() + let fut = biaser.call(()); + assert_eq!(biaser.get_pending(), 1); + + // Drop the future without completing it. The oneshot receiver + // never receives a value, so the inner future is still pending. + // PinnedDrop must decrement the pending count. + drop(fut); + + assert_eq!( + biaser.get_pending(), + 0, + "PinnedDrop should decrement pending on cancellation" + ); + + // tx is unused. Dropping it here is fine, it just closes the channel. + drop(tx); + } + + #[test] + fn default_max_duration_matches_client_policy() { + // Enforces the invariant at LoadBiaserConfig::default(): + // max_duration must match the local DEFAULT_RETRY_AFTER_MAX_DURATION constant. + // Production code reads the constant directly via EwmaConfig::to_load_biaser_config; + // this assertion keeps the test-only default() in sync. + assert_eq!( + LoadBiaserConfig::default().max_duration, + DEFAULT_RETRY_AFTER_MAX_DURATION, + "LoadBiaserConfig::default().max_duration must match \ + DEFAULT_RETRY_AFTER_MAX_DURATION (300s)" + ); + } + + // Mock service returning an HTTP error with a Retry-After header. + // Parameterized by status code to avoid duplicating 429 and 503 variants. + #[derive(Clone)] + struct RetryAfterService { + status: http::StatusCode, + retry_after_secs: u64, + } + + impl Service<()> for RetryAfterService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + let resp = http::Response::builder() + .status(self.status) + .header(http::header::RETRY_AFTER, self.retry_after_secs.to_string()) + .body("retry-after response") + .unwrap(); + + futures::future::ready(Ok(resp)) + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_with_retry_after_uses_adaptive_penalty() { + let inner = RetryAfterService { + status: http::StatusCode::TOO_MANY_REQUESTS, + retry_after_secs: 30, + }; + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 429 with Retry-After: 30 + let _ = biaser.call(()).await; + + // Amplified: penalty_secs * FACTOR * exp(RA/decay) = 5.0 * 0.5 * e^3 + let expected = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (30.0_f64 / 10.0_f64).exp(); + let penalty = biaser.get_penalty(); + + assert!( + (penalty - expected).abs() < 1.0, + "penalty should be ~{expected:.1} (amplified), got: {}", + penalty + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_with_retry_after_clamped_to_max() { + let inner = RetryAfterService { + status: http::StatusCode::TOO_MANY_REQUESTS, + retry_after_secs: 600, + }; + let config = LoadBiaserConfig { + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + ..test_config() + }; + let mut biaser = LoadBiaser::new(inner, config); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 429 with Retry-After: 600, clamped to 300s + let _ = biaser.call(()).await; + + // Amplified: penalty_secs * FACTOR * exp(clamped_RA/decay) = 5.0 * 0.5 * e^30 + // exceeds the 1e12 penalty clamp, so the actual penalty is clamped. + let unclamped = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (300.0_f64 / 10.0_f64).exp(); + let expected = unclamped.min(1e12); + let penalty = biaser.get_penalty(); + + assert!( + (penalty - expected).abs() / expected < 0.01, + "penalty should be ~{expected:.0} (amplified, clamped RA=300), got: {}", + penalty + ); + } + + #[test] + fn test_rate_limit_hint_parses_retry_after() { + let mut resp = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, "45") + .body("rate limited") + .unwrap(); + let max = Duration::from_secs(60); + + resp.attach_parsed_rate_limit_hint(max); + + assert_eq!(resp.rate_limit_hint(max), Some(Duration::from_secs(45))); + } + + #[test] + fn test_rate_limit_hint_none_for_200() { + let mut resp = http::Response::builder() + .status(http::StatusCode::OK) + .header(http::header::RETRY_AFTER, "45") + .body("ok") + .unwrap(); + let max = Duration::from_secs(60); + + resp.attach_parsed_rate_limit_hint(max); + + assert_eq!(resp.rate_limit_hint(max), None); + } + + #[test] + fn test_rate_limit_hint_none_without_header() { + let mut resp = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .body("rate limited") + .unwrap(); + let max = Duration::from_secs(60); + + resp.attach_parsed_rate_limit_hint(max); + + assert_eq!(resp.rate_limit_hint(max), None); + } + + #[test] + fn test_rate_limit_hint_on_read_without_attach() { + let resp = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, "45") + .body("rate limited") + .unwrap(); + + // No attach call, test on-read fallback + assert_eq!( + resp.rate_limit_hint(Duration::from_secs(60)), + Some(Duration::from_secs(45)) + ); + } + + #[test] + fn test_rate_limit_hint_on_read_caps_at_max() { + let resp = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, "300") + .body("rate limited") + .unwrap(); + + // On-read parse caps at caller's max + assert_eq!( + resp.rate_limit_hint(Duration::from_secs(60)), + Some(Duration::from_secs(60)) + ); + } + + // Test multiple caps + + const SMALL_CAP: Duration = Duration::from_secs(60); + const DEFAULT_CAP: Duration = DEFAULT_RETRY_AFTER_MAX_DURATION; + const LARGE_CAP: Duration = Duration::from_secs(1800); + + // Constructs a 429 response whose `Retry-After` header uses the integer + // seconds form (delay-seconds per RFC 7231). + fn build_http_retry_after_response(retry_after_secs: u64) -> http::Response<&'static str> { + http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, retry_after_secs.to_string()) + .body("rate limited") + .unwrap() + } + + // Constructs a 200 OK trailers-only gRPC error response using + // `grpc-status: 8` (RESOURCE_EXHAUSTED) and `grpc-retry-pushback-ms: `. + fn build_grpc_pushback_response(pushback_ms: u64) -> http::Response<&'static str> { + http::Response::builder() + .status(http::StatusCode::OK) + .header("grpc-status", "8") + .header("grpc-retry-pushback-ms", pushback_ms.to_string()) + .body("grpc error") + .unwrap() + } + + // Verifies the cached-path `.min(max)` clamp in `rate_limit_hint()` clamps HTTP + // Retry-After hints to the caller-supplied cap consistently across cap + // magnitudes (60s / 300s / 1800s) and directions (below-cap / over-cap). + #[test] + fn test_rate_limit_hint_multi_cap_http() { + struct Row { + cap: Duration, + header_value: u64, // Retry-After in integer seconds + } + + let rows = [ + Row { + cap: SMALL_CAP, + header_value: 30, + }, // below + Row { + cap: SMALL_CAP, + header_value: 120, + }, // over + Row { + cap: DEFAULT_CAP, + header_value: 120, + }, // below + Row { + cap: DEFAULT_CAP, + header_value: 600, + }, // over + Row { + cap: LARGE_CAP, + header_value: 900, + }, // below + Row { + cap: LARGE_CAP, + header_value: 3600, + }, // over + ]; + + for Row { cap, header_value } in rows { + let mut resp = build_http_retry_after_response(header_value); + resp.attach_parsed_rate_limit_hint(Duration::MAX); + + // Remove the source header after attach + resp.headers_mut().remove(http::header::RETRY_AFTER); + let parsed = Duration::from_secs(header_value); + let expected = parsed.min(cap); + + assert_eq!( + resp.rate_limit_hint(cap), + Some(expected), + "cap={cap:?}, header={header_value}s, expected={expected:?}", + ); + } + } + + // Verifies the cached-path `.min(max)` clamp in `rate_limit_hint()` clamps gRPC + // retry-pushback-ms hints to the caller-supplied cap consistently across + // cap magnitudes (60s / 300s / 1800s) and directions (below-cap / over-cap). + #[test] + fn test_rate_limit_hint_multi_cap_grpc() { + struct Row { + cap: Duration, + header_value: u64, // grpc-retry-pushback in milliseconds + } + + let rows = [ + Row { + cap: SMALL_CAP, + header_value: 30_000, + }, // below (30s) + Row { + cap: SMALL_CAP, + header_value: 120_000, + }, // over (120s) + Row { + cap: DEFAULT_CAP, + header_value: 120_000, + }, // below (120s) + Row { + cap: DEFAULT_CAP, + header_value: 600_000, + }, // over (600s) + Row { + cap: LARGE_CAP, + header_value: 900_000, + }, // below (900s) + Row { + cap: LARGE_CAP, + header_value: 3_600_000, + }, // over (3600s) + ]; + + for Row { cap, header_value } in rows { + let mut resp = build_grpc_pushback_response(header_value); + resp.attach_parsed_rate_limit_hint(Duration::MAX); + + // Remove the source header after attach + resp.headers_mut().remove("grpc-retry-pushback-ms"); + let parsed = Duration::from_millis(header_value); + let expected = parsed.min(cap); + + assert_eq!( + resp.rate_limit_hint(cap), + Some(expected), + "cap={cap:?}, header={header_value}ms, expected={expected:?}", + ); + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_injects_full_penalty() { + let inner = MockService::new(http::StatusCode::SERVICE_UNAVAILABLE); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + let _ = biaser.call(()).await; + let penalty = biaser.get_penalty(); + + // test_config has penalty_secs = 5.0 + assert!( + (penalty - 5.0).abs() < 0.1, + "503 should inject full penalty (~5.0s), got: {penalty}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_with_retry_after_uses_adaptive_penalty() { + let inner = RetryAfterService { + status: http::StatusCode::SERVICE_UNAVAILABLE, + retry_after_secs: 30, + }; + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 503 with Retry-After: 30 + let _ = biaser.call(()).await; + + // Same amplification formula as 429: + // penalty_secs * FACTOR * exp(RA/decay) = 5.0 * 0.5 * e^3 + let expected = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (30.0_f64 / 10.0_f64).exp(); + let penalty = biaser.get_penalty(); + assert!( + (penalty - expected).abs() < 1.0, + "503+RA penalty should be ~{expected:.1} (amplified), got: {penalty}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_with_retry_after_clamped_to_max() { + let inner = RetryAfterService { + status: http::StatusCode::SERVICE_UNAVAILABLE, + retry_after_secs: 600, + }; + let config = LoadBiaserConfig { + max_duration: DEFAULT_RETRY_AFTER_MAX_DURATION, + ..test_config() + }; + let mut biaser = LoadBiaser::new(inner, config); + + time::sleep(Duration::from_millis(1)).await; + + // Make a request that returns 503 with Retry-After: 600, clamped to 300s + let _ = biaser.call(()).await; + + // Amplified with clamped RA: + // penalty_secs * FACTOR * exp(clamped_RA/decay) = 5.0 * 0.5 * e^30 + // exceeds the 1e12 penalty clamp, so the actual penalty is clamped. + let unclamped = 5.0_f64 * RETRY_AFTER_PENALTY_FACTOR * (300.0_f64 / 10.0_f64).exp(); + let expected = unclamped.min(1e12); + let penalty = biaser.get_penalty(); + assert!( + (penalty - expected).abs() / expected < 0.01, + "503+RA penalty should be ~{expected:.0} (amplified, clamped RA=300), got: {penalty}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_and_503_retry_after_produce_identical_penalty() { + // 429 path: RetryAfterService returns TOO_MANY_REQUESTS with Retry-After: 30 + let inner_429 = RetryAfterService { + status: http::StatusCode::TOO_MANY_REQUESTS, + retry_after_secs: 30, + }; + let mut biaser_429 = LoadBiaser::new(inner_429, test_config()); + + // 503 path: RetryAfterService returns SERVICE_UNAVAILABLE with Retry-After: 30 + let inner_503 = RetryAfterService { + status: http::StatusCode::SERVICE_UNAVAILABLE, + retry_after_secs: 30, + }; + let mut biaser_503 = LoadBiaser::new(inner_503, test_config()); + + // Bootstrap EWMA timestamps + time::sleep(Duration::from_millis(1)).await; + let _ = biaser_429.call(()).await; + + time::sleep(Duration::from_millis(1)).await; + let _ = biaser_503.call(()).await; + + let load_429 = biaser_429.load(); + let load_503 = biaser_503.load(); + + assert!( + (load_429 - load_503).abs() / load_429.max(1e-9) < 0.05, + "429+RA and 503+RA should produce similar load; 429={load_429}, 503={load_503}" + ); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_500_injects_penalty() { + let inner = MockService::new(http::StatusCode::INTERNAL_SERVER_ERROR); + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + let _ = biaser.call(()).await; + let penalty = biaser.get_penalty(); + + // test_config has penalty_secs = 5.0 + assert!( + (penalty - 5.0).abs() < 0.1, + "500 should inject full penalty (~5.0s), got: {penalty}" + ); + } + + // Mock service that returns HTTP 200 with a `grpc-status` header, + // simulating a gRPC trailers-only error response. + #[derive(Clone)] + struct GrpcErrorService { + grpc_status: u16, + } + + impl Service<()> for GrpcErrorService { + type Response = http::Response<&'static str>; + type Error = Infallible; + type Future = futures::future::Ready>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ()) -> Self::Future { + let resp = http::Response::builder() + .status(http::StatusCode::OK) + .header("grpc-status", self.grpc_status.to_string()) + .body("grpc error") + .unwrap(); + + futures::future::ready(Ok(resp)) + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_resource_exhausted_injects_penalty() { + let inner = GrpcErrorService { grpc_status: 8 }; // RESOURCE_EXHAUSTED + let mut biaser = LoadBiaser::new(inner, test_config()); + + time::sleep(Duration::from_millis(1)).await; + + let _ = biaser.call(()).await; + // gRPC RESOURCE_EXHAUSTED maps to FailureHint::RateLimited, + // which injects the full penalty_secs (5.0). + let penalty = biaser.get_penalty(); + + assert!( + (penalty - 5.0).abs() < 0.1, + "gRPC RESOURCE_EXHAUSTED should inject full penalty (~5s), got: {penalty}" + ); + } + + async fn assert_disabled_no_penalty(mut biaser: LoadBiaser, label: &str) + where + S: Service<(), Response = http::Response<&'static str>, Error = Infallible>, + { + time::sleep(Duration::from_millis(1)).await; + let _ = biaser.call(()).await; + assert!( + biaser.get_penalty().is_infinite(), + "penalty should not be injected when disabled for {label}: {}", + biaser.get_penalty() + ); + } + + fn disabled_config() -> LoadBiaserConfig { + LoadBiaserConfig { + enabled: false, + ..test_config() + } + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_429_disabled_no_penalty() { + let inner = MockService::new(http::StatusCode::TOO_MANY_REQUESTS); + assert_disabled_no_penalty(LoadBiaser::new(inner, disabled_config()), "429").await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_503_disabled_no_penalty() { + let inner = MockService::new(http::StatusCode::SERVICE_UNAVAILABLE); + assert_disabled_no_penalty(LoadBiaser::new(inner, disabled_config()), "503").await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_500_disabled_no_penalty() { + let inner = MockService::new(http::StatusCode::INTERNAL_SERVER_ERROR); + assert_disabled_no_penalty(LoadBiaser::new(inner, disabled_config()), "500").await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_resource_exhausted_disabled_no_penalty() { + let inner = GrpcErrorService { grpc_status: 8 }; + assert_disabled_no_penalty( + LoadBiaser::new(inner, disabled_config()), + "gRPC RESOURCE_EXHAUSTED (status 8)", + ) + .await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_unavailable_disabled_no_penalty() { + let inner = GrpcErrorService { grpc_status: 14 }; + assert_disabled_no_penalty( + LoadBiaser::new(inner, disabled_config()), + "gRPC UNAVAILABLE (status 14)", + ) + .await; + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn test_grpc_internal_disabled_no_penalty() { + let inner = GrpcErrorService { grpc_status: 13 }; + assert_disabled_no_penalty( + LoadBiaser::new(inner, disabled_config()), + "gRPC INTERNAL (status 13)", + ) + .await; + } +}