diff --git a/tower-http/src/compression/predicate.rs b/tower-http/src/compression/predicate.rs index a843d9d68..bf643bdee 100644 --- a/tower-http/src/compression/predicate.rs +++ b/tower-http/src/compression/predicate.rs @@ -146,17 +146,17 @@ impl Predicate for DefaultPredicate { /// [`Predicate`] that will only allow compression of responses above a certain size. #[derive(Clone, Copy, Debug)] -pub struct SizeAbove(u16); +pub struct SizeAbove(u64); impl SizeAbove { - pub(crate) const DEFAULT_MIN_SIZE: u16 = 32; + pub(crate) const DEFAULT_MIN_SIZE: u64 = 32; /// Create a new `SizeAbove` predicate that will only compress responses larger than /// `min_size_bytes`. /// /// The response will be compressed if the exact size cannot be determined through either the /// `content-length` header or [`Body::size_hint`]. - pub const fn new(min_size_bytes: u16) -> Self { + pub const fn new(min_size_bytes: u64) -> Self { Self(min_size_bytes) } } @@ -181,7 +181,7 @@ impl Predicate for SizeAbove { }); match content_size { - Some(size) => size >= (self.0 as u64), + Some(size) => size >= self.0, _ => true, } } diff --git a/tower-http/src/macros.rs b/tower-http/src/macros.rs index f58d34a66..56a37e252 100644 --- a/tower-http/src/macros.rs +++ b/tower-http/src/macros.rs @@ -103,3 +103,35 @@ macro_rules! opaque_future { } } } + +/// Evaluate `$call` at most once every `$interval` per call site. +/// +/// Uses a monotonic clock and atomic timestamp to rate-limit without locks. +/// Adapted from dial9-tokio-telemetry's rate_limit module. +// TODO: Once MSRV >= 1.70, switch to OnceLock for monotonic timing. +// See: https://github.com/dial9-rs/dial9/blob/6772039/dial9-tokio-telemetry/src/rate_limit.rs +#[allow(unused_macros)] +macro_rules! 
rate_limited { + ($interval:expr, $call:expr) => {{ + use std::sync::atomic::{AtomicU64, Ordering}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + static NEXT_CALL: AtomicU64 = AtomicU64::new(0); + + let interval: Duration = $interval; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs(); + let next = NEXT_CALL.load(Ordering::Relaxed); + if now >= next { + let new_next = now.saturating_add(interval.as_secs()); + if NEXT_CALL + .compare_exchange(next, new_next, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + $call; + } + } + }}; +} diff --git a/tower-http/src/services/fs/serve_dir/future.rs b/tower-http/src/services/fs/serve_dir/future.rs index 6386ead47..e073acdb5 100644 --- a/tower-http/src/services/fs/serve_dir/future.rs +++ b/tower-http/src/services/fs/serve_dir/future.rs @@ -119,8 +119,22 @@ where ))); } - Ok(OpenFileOutput::NotModified) => { - break Poll::Ready(Ok(response_with_status(StatusCode::NOT_MODIFIED))); + Ok(OpenFileOutput::NotModified { + etag, + last_modified, + }) => { + let mut res = response_with_status(StatusCode::NOT_MODIFIED); + if let Some(etag) = etag { + res.headers_mut() + .insert(header::ETAG, etag.into_header_value()); + } + if let Some(last_modified) = last_modified { + res.headers_mut().insert( + header::LAST_MODIFIED, + HeaderValue::from_str(&last_modified.0.to_string()).unwrap(), + ); + } + break Poll::Ready(Ok(res)); } Ok(OpenFileOutput::InvalidRedirectUri) => { @@ -250,6 +264,10 @@ fn build_response(output: FileOpened) -> Response { builder = builder.header(header::LAST_MODIFIED, last_modified.0.to_string()); } + if let Some(etag) = output.etag { + builder = builder.header(header::ETAG, etag.into_header_value()); + } + match output.maybe_range { Some(Ok(ranges)) => { if let Some(range) = ranges.first() { diff --git a/tower-http/src/services/fs/serve_dir/headers.rs b/tower-http/src/services/fs/serve_dir/headers.rs index e3a87a2c7..f756d24fb 100644 --- 
a/tower-http/src/services/fs/serve_dir/headers.rs +++ b/tower-http/src/services/fs/serve_dir/headers.rs @@ -2,6 +2,147 @@ use http::header::HeaderValue; use httpdate::HttpDate; use std::time::SystemTime; +/// A strong ETag derived from file metadata (size + mtime with nanosecond precision). +/// +/// Format is an implementation detail and may change between versions. Clients should +/// treat ETags as opaque values per RFC 9110 §8.8.3. +#[derive(Clone, Debug)] +pub(super) struct ETag(HeaderValue); + +impl ETag { + /// Generate an ETag from file size and modification time. + /// + /// Returns `None` only for pre-epoch modification times, which are unsupported. + pub(super) fn from_metadata(size: u64, modified: SystemTime) -> Option { + let duration = modified.duration_since(SystemTime::UNIX_EPOCH).ok()?; + // NOTE: Changing this format is a cache-busting event for all clients, + // but is not a semver break (ETags are opaque per RFC 9110 §8.8.3). + let value = format!( + "\"{:x}.{:08x}-{:x}\"", + duration.as_secs(), + duration.subsec_nanos(), + size + ); + HeaderValue::from_str(&value).ok().map(ETag) + } + + pub(super) fn into_header_value(self) -> HeaderValue { + self.0 + } + + /// Strong comparison per RFC 9110 §8.8.3.2: both must not be weak, + /// and the opaque-tags must be identical. + fn strong_eq(&self, other: &[u8]) -> bool { + if other.starts_with(b"W/") { + return false; + } + self.0.as_bytes() == other + } + + /// Weak comparison per RFC 9110 §8.8.3.2: ignore W/ prefix, + /// compare opaque-tags. + fn weak_eq(&self, other: &[u8]) -> bool { + let this = self.0.as_bytes(); + let other = other.strip_prefix(b"W/").unwrap_or(other); + let this = this.strip_prefix(b"W/").unwrap_or(this); + this == other + } +} + +/// Parsed `If-None-Match` header (RFC 9110 §13.1.2). 
+pub(super) struct IfNoneMatch(HeaderValue); + +impl IfNoneMatch { + pub(super) fn from_header_value(value: &HeaderValue) -> Option { + // Reject empty values + if value.as_bytes().is_empty() { + return None; + } + Some(IfNoneMatch(value.clone())) + } + + /// Returns true if the precondition passes (none of the ETags match). + /// A failed precondition (returns false) means we should return 304. + /// + /// Uses weak comparison per RFC 9110 §13.1.2. + pub(super) fn precondition_passes(&self, etag: &ETag) -> bool { + let bytes = self.0.as_bytes(); + if bytes == b"*" { + return false; + } + !for_each_etag(bytes, |tag| etag.weak_eq(tag)) + } +} + +/// Parsed `If-Match` header (RFC 9110 §13.1.1). +pub(super) struct IfMatch(HeaderValue); + +impl IfMatch { + pub(super) fn from_header_value(value: &HeaderValue) -> Option { + if value.as_bytes().is_empty() { + return None; + } + Some(IfMatch(value.clone())) + } + + /// Returns true if the precondition passes (at least one ETag matches). + /// A failed precondition (returns false) means we should return 412. + /// + /// Uses strong comparison per RFC 9110 §13.1.1. + pub(super) fn precondition_passes(&self, etag: &ETag) -> bool { + let bytes = self.0.as_bytes(); + if bytes == b"*" { + return true; + } + for_each_etag(bytes, |tag| etag.strong_eq(tag)) + } +} + +/// Iterate over comma-separated ETags in a header value, trimming OWS. +/// Returns true if `predicate` returns true for any tag (short-circuits). +/// +/// Handles commas inside quoted strings per RFC 9110 §8.8.3 (ETags are quoted). 
+fn for_each_etag(header: &[u8], mut predicate: impl FnMut(&[u8]) -> bool) -> bool { + let mut start = 0; + let mut in_quotes = false; + for i in 0..header.len() { + match header[i] { + b'"' => in_quotes = !in_quotes, + b',' if !in_quotes => { + let trimmed = trim_ows(&header[start..i]); + if !trimmed.is_empty() && predicate(trimmed) { + return true; + } + start = i + 1; + } + _ => {} + } + } + let trimmed = trim_ows(&header[start..]); + if !trimmed.is_empty() && predicate(trimmed) { + return true; + } + false +} + +/// Trim leading/trailing OWS (SP / HTAB) per RFC 9110. +fn trim_ows(bytes: &[u8]) -> &[u8] { + let start = bytes + .iter() + .position(|&b| b != b' ' && b != b'\t') + .unwrap_or(bytes.len()); + let end = bytes + .iter() + .rposition(|&b| b != b' ' && b != b'\t') + .map(|i| i + 1) + .unwrap_or(0); + if start >= end { + &[] + } else { + &bytes[start..end] + } +} + pub(super) struct LastModified(pub(super) HttpDate); impl From for LastModified { @@ -43,3 +184,79 @@ impl IfUnmodifiedSince { .map(|time| IfUnmodifiedSince(time.into())) } } + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: collect all ETags parsed from a header value. 
+ fn collect_etags(header: &[u8]) -> Vec> { + let mut tags = Vec::new(); + for_each_etag(header, |tag| { + tags.push(tag.to_vec()); + false // don't short-circuit, collect all + }); + tags + } + + #[test] + fn for_each_etag_simple_list() { + let tags = collect_etags(b"\"foo\", \"bar\", \"baz\""); + assert_eq!( + tags, + vec![ + b"\"foo\"".to_vec(), + b"\"bar\"".to_vec(), + b"\"baz\"".to_vec() + ] + ); + } + + #[test] + fn for_each_etag_comma_inside_quotes() { + // An ETag containing a comma inside the quoted string should not be split + let tags = collect_etags(b"\"foo,bar\", \"baz\""); + assert_eq!(tags, vec![b"\"foo,bar\"".to_vec(), b"\"baz\"".to_vec()]); + } + + #[test] + fn for_each_etag_multiple_commas_inside_quotes() { + let tags = collect_etags(b"\"a,b,c\", \"d\""); + assert_eq!(tags, vec![b"\"a,b,c\"".to_vec(), b"\"d\"".to_vec()]); + } + + #[test] + fn for_each_etag_weak_with_comma_inside() { + let tags = collect_etags(b"W/\"foo,bar\", \"baz\""); + assert_eq!(tags, vec![b"W/\"foo,bar\"".to_vec(), b"\"baz\"".to_vec()]); + } + + #[test] + fn for_each_etag_single_tag() { + let tags = collect_etags(b"\"only\""); + assert_eq!(tags, vec![b"\"only\"".to_vec()]); + } + + #[test] + fn for_each_etag_empty() { + let tags = collect_etags(b""); + assert!(tags.is_empty()); + } + + #[test] + fn for_each_etag_whitespace_only() { + let tags = collect_etags(b" , , "); + assert!(tags.is_empty()); + } + + #[test] + fn for_each_etag_short_circuits() { + let mut count = 0; + let found = for_each_etag(b"\"a\", \"b\", \"c\"", |_tag| { + count += 1; + count == 2 // match on second tag + }); + assert!(found); + assert_eq!(count, 2); + } +} diff --git a/tower-http/src/services/fs/serve_dir/open_file.rs b/tower-http/src/services/fs/serve_dir/open_file.rs index d7e760846..6ab4dfb4e 100644 --- a/tower-http/src/services/fs/serve_dir/open_file.rs +++ b/tower-http/src/services/fs/serve_dir/open_file.rs @@ -1,5 +1,5 @@ use super::{ - headers::{IfModifiedSince, IfUnmodifiedSince, 
LastModified}, + headers::{ETag, IfMatch, IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, LastModified}, ServeVariant, }; use crate::content_encoding::{Encoding, QValue}; @@ -18,10 +18,15 @@ use tokio::{fs::File, io::AsyncSeekExt}; pub(super) enum OpenFileOutput { FileOpened(Box), - Redirect { location: HeaderValue }, + Redirect { + location: HeaderValue, + }, FileNotFound, PreconditionFailed, - NotModified, + NotModified { + etag: Option, + last_modified: Option, + }, InvalidRedirectUri, InvalidFilename, } @@ -34,6 +39,7 @@ pub(super) struct FileOpened { pub(super) maybe_range: Option>, RangeUnsatisfiableError>>, pub(super) last_modified: Option, pub(super) precompression_configured: bool, + pub(super) etag: Option, } pub(super) enum FileRequestExtent { @@ -50,15 +56,24 @@ pub(super) async fn open_file( buf_chunk_size: usize, precompression_configured: bool, ) -> io::Result { - let if_unmodified_since = req - .headers() - .get(header::IF_UNMODIFIED_SINCE) - .and_then(IfUnmodifiedSince::from_header_value); - - let if_modified_since = req - .headers() - .get(header::IF_MODIFIED_SINCE) - .and_then(IfModifiedSince::from_header_value); + let preconditions = Preconditions { + if_match: req + .headers() + .get(header::IF_MATCH) + .and_then(IfMatch::from_header_value), + if_unmodified_since: req + .headers() + .get(header::IF_UNMODIFIED_SINCE) + .and_then(IfUnmodifiedSince::from_header_value), + if_none_match: req + .headers() + .get(header::IF_NONE_MATCH) + .and_then(IfNoneMatch::from_header_value), + if_modified_since: req + .headers() + .get(header::IF_MODIFIED_SINCE) + .and_then(IfModifiedSince::from_header_value), + }; let mime = match variant { ServeVariant::Directory { @@ -91,15 +106,26 @@ pub(super) async fn open_file( }; if req.method() == Method::HEAD { + #[cfg(feature = "tracing")] + let _path_str = path_to_file.display().to_string(); let (meta, maybe_encoding) = file_metadata_with_fallback(path_to_file, negotiated_encodings).await?; let last_modified = 
meta.modified().ok().map(LastModified::from); - if let Some(output) = check_modified_headers( - last_modified.as_ref(), - if_unmodified_since, - if_modified_since, - ) { + let etag = meta + .modified() + .ok() + .and_then(|mtime| ETag::from_metadata(meta.len(), mtime)); + + #[cfg(feature = "tracing")] + if etag.is_none() { + rate_limited!( + std::time::Duration::from_secs(60), + tracing::warn!(path = %_path_str, "ETag generation failed (mtime unavailable or pre-epoch)") + ); + } + + if let Some(output) = preconditions.check(etag.as_ref(), last_modified.as_ref()) { return Ok(output); } @@ -113,8 +139,11 @@ pub(super) async fn open_file( maybe_range, last_modified, precompression_configured, + etag, }))) } else { + #[cfg(feature = "tracing")] + let _path_str = path_to_file.display().to_string(); let (mut file, maybe_encoding) = match open_file_with_fallback(path_to_file, negotiated_encodings).await { Ok(result) => result, @@ -128,11 +157,20 @@ pub(super) async fn open_file( let meta = file.metadata().await?; let last_modified = meta.modified().ok().map(LastModified::from); - if let Some(output) = check_modified_headers( - last_modified.as_ref(), - if_unmodified_since, - if_modified_since, - ) { + let etag = meta + .modified() + .ok() + .and_then(|mtime| ETag::from_metadata(meta.len(), mtime)); + + #[cfg(feature = "tracing")] + if etag.is_none() { + rate_limited!( + std::time::Duration::from_secs(60), + tracing::warn!(path = %_path_str, "ETag generation failed (mtime unavailable or pre-epoch)") + ); + } + + if let Some(output) = preconditions.check(etag.as_ref(), last_modified.as_ref()) { return Ok(output); } @@ -153,6 +191,7 @@ pub(super) async fn open_file( maybe_range, last_modified, precompression_configured, + etag, }))) } } @@ -177,34 +216,82 @@ fn is_invalid_filename_error(err: &io::Error) -> bool { false } -fn check_modified_headers( - modified: Option<&LastModified>, +/// Precondition headers parsed from the request. 
+struct Preconditions { + if_match: Option, if_unmodified_since: Option, + if_none_match: Option, if_modified_since: Option, -) -> Option { - if let Some(since) = if_unmodified_since { - let precondition = modified - .as_ref() - .map(|time| since.precondition_passes(time)) - .unwrap_or(false); - - if !precondition { - return Some(OpenFileOutput::PreconditionFailed); +} + +impl Preconditions { + /// Evaluate preconditions per [RFC 9110 §13.2.2](https://www.rfc-editor.org/rfc/rfc9110#section-13.2.2). + /// + /// Precedence order: + /// 1. If-Match (strong comparison) → 412 on failure + /// 2. If-Unmodified-Since (only if If-Match absent) → 412 on failure + /// 3. If-None-Match (weak comparison) → 304 on failure (for GET/HEAD) + /// 4. If-Modified-Since (only if If-None-Match absent) → 304 on failure + fn check( + self, + etag: Option<&ETag>, + last_modified: Option<&LastModified>, + ) -> Option { + // Step 1: If-Match + if let Some(if_match) = self.if_match { + // RFC 9110 §13.1.1: "If the field value is '*', the condition is FALSE + // if the origin server does not have a current representation." + // No ETag means no current representation → fail. + let passes = etag + .map(|etag| if_match.precondition_passes(etag)) + .unwrap_or(false); + if !passes { + return Some(OpenFileOutput::PreconditionFailed); + } + } else { + // Step 2: If-Unmodified-Since (only when If-Match is absent) + // RFC 9110 §13.1.4: "MUST ignore if the resource does not have a + // modification date available." 
+ if let Some(since) = self.if_unmodified_since { + let passes = last_modified + .map(|lm| since.precondition_passes(lm)) + .unwrap_or(true); + if !passes { + return Some(OpenFileOutput::PreconditionFailed); + } + } } - } - if let Some(since) = if_modified_since { - let unmodified = modified - .as_ref() - .map(|time| !since.is_modified(time)) - // no last_modified means its always modified - .unwrap_or(false); - if unmodified { - return Some(OpenFileOutput::NotModified); + // Step 3: If-None-Match + if let Some(if_none_match) = self.if_none_match { + // No ETag available → condition is vacuously true (passes), serve normally. + let passes = etag + .map(|etag| if_none_match.precondition_passes(etag)) + .unwrap_or(true); + if !passes { + return Some(OpenFileOutput::NotModified { + etag: etag.cloned(), + last_modified: last_modified.map(|lm| LastModified(lm.0)), + }); + } + } else { + // Step 4: If-Modified-Since (only when If-None-Match is absent) + // No Last-Modified → treat as modified (serve normally). 
+ if let Some(since) = self.if_modified_since { + let unmodified = last_modified + .map(|lm| !since.is_modified(lm)) + .unwrap_or(false); + if unmodified { + return Some(OpenFileOutput::NotModified { + etag: etag.cloned(), + last_modified: last_modified.map(|lm| LastModified(lm.0)), + }); + } + } } - } - None + None + } } // Returns the preferred_encoding encoding and modifies the path extension diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 0b9a6c781..30fe0d397 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -1294,3 +1294,313 @@ async fn html_as_default_extension_does_not_apply_when_extension_present() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], "text/plain"); } + +#[tokio::test] +async fn etag_is_set_on_response() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let etag = res + .headers() + .get(header::ETAG) + .expect("Missing ETag header"); + let etag_str = etag.to_str().unwrap(); + // Strong ETag format: "-" + assert!(etag_str.starts_with('"')); + assert!(etag_str.ends_with('"')); + assert!(!etag_str.starts_with("W/")); + assert!(etag_str.contains('-')); +} + +#[tokio::test] +async fn if_none_match_returns_304() { + let svc = ServeDir::new(REPO_ROOT); + + // First request to get the ETag + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let etag = res.headers().get(header::ETAG).unwrap().clone(); + let last_modified = res.headers().get(header::LAST_MODIFIED).unwrap().clone(); + + // Second request with If-None-Match + let svc = ServeDir::new(REPO_ROOT); + let req = Request::builder() + 
.uri("/README.md") + .header(header::IF_NONE_MATCH, &etag) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); + // RFC 9110 §15.4.5: 304 MUST include validator headers + assert_eq!(res.headers().get(header::ETAG).unwrap(), &etag); + assert_eq!( + res.headers().get(header::LAST_MODIFIED).unwrap(), + &last_modified + ); + assert!(res.into_body().frame().await.is_none()); +} + +#[tokio::test] +async fn if_none_match_with_non_matching_etag_returns_200() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .header(header::IF_NONE_MATCH, "\"not-a-real-etag\"") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn if_none_match_wildcard_returns_304() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .header(header::IF_NONE_MATCH, "*") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); +} + +#[tokio::test] +async fn if_match_with_matching_etag_succeeds() { + let svc = ServeDir::new(REPO_ROOT); + + // First request to get the ETag + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let etag = res.headers().get(header::ETAG).unwrap().clone(); + + // Second request with If-Match + let svc = ServeDir::new(REPO_ROOT); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MATCH, etag) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn if_match_with_non_matching_etag_returns_412() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + 
.header(header::IF_MATCH, "\"not-a-real-etag\"") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); +} + +#[tokio::test] +async fn if_none_match_takes_precedence_over_if_modified_since() { + // Per RFC 9110 §13.2.2, If-None-Match takes precedence over If-Modified-Since + let svc = ServeDir::new(REPO_ROOT); + + // First request to get the ETag + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let etag = res.headers().get(header::ETAG).unwrap().clone(); + + // Send both If-None-Match (matching) and If-Modified-Since (very old, would normally 200) + // If-None-Match should win and return 304 + let svc = ServeDir::new(REPO_ROOT); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_NONE_MATCH, etag) + .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); +} + +#[tokio::test] +async fn if_match_takes_precedence_over_if_unmodified_since() { + // Per RFC 9110 §13.2.2, If-Match takes precedence over If-Unmodified-Since + let svc = ServeDir::new(REPO_ROOT); + + // Send If-Match (non-matching, should 412) and If-Unmodified-Since (far future, would pass) + // If-Match should win and return 412 + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MATCH, "\"not-a-real-etag\"") + .header(header::IF_UNMODIFIED_SINCE, "Sun, 01 Jan 2100 00:00:00 GMT") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); +} + +#[tokio::test] +async fn if_none_match_weak_comparison() { + // Weak comparison: W/"etag" should match "etag" for If-None-Match + let svc = ServeDir::new(REPO_ROOT); + + // First request 
to get the ETag + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let etag = res + .headers() + .get(header::ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + + // Send with W/ prefix, should still match via weak comparison + let svc = ServeDir::new(REPO_ROOT); + let weak_etag = format!("W/{}", etag); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_NONE_MATCH, &weak_etag) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); +} + +#[tokio::test] +async fn if_match_strong_comparison_rejects_weak_etag() { + // Strong comparison: W/"etag" should NOT match "etag" for If-Match + let svc = ServeDir::new(REPO_ROOT); + + // First request to get the ETag + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let etag = res + .headers() + .get(header::ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + + // Send with W/ prefix for If-Match, should fail (strong comparison) + let svc = ServeDir::new(REPO_ROOT); + let weak_etag = format!("W/{}", etag); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MATCH, &weak_etag) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); +} + +#[tokio::test] +async fn if_none_match_multiple_etags() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + let etag = res + .headers() + .get(header::ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + + // One matching among several should still produce 304 + let svc = ServeDir::new(REPO_ROOT); 
+ let multi = format!("\"bogus\", {}, \"also-bogus\"", etag); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_NONE_MATCH, &multi) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); +} + +#[tokio::test] +async fn if_match_wildcard_succeeds() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MATCH, "*") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn etag_on_head_request() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert!(res.headers().get(header::ETAG).is_some()); +} + +#[tokio::test] +async fn if_modified_since_304_includes_etag() { + let svc = ServeDir::new(REPO_ROOT); + + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + let last_modified = res.headers().get(header::LAST_MODIFIED).unwrap().clone(); + let etag = res.headers().get(header::ETAG).unwrap().clone(); + + // Time-based 304 should also include ETag + let svc = ServeDir::new(REPO_ROOT); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MODIFIED_SINCE, &last_modified) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); + assert_eq!(res.headers().get(header::ETAG).unwrap(), &etag); + assert_eq!( + res.headers().get(header::LAST_MODIFIED).unwrap(), + &last_modified + ); +} diff --git a/tower-http/src/timeout/body.rs b/tower-http/src/timeout/body.rs index d44f35b8c..176fc6243 100644 --- a/tower-http/src/timeout/body.rs +++ 
b/tower-http/src/timeout/body.rs @@ -106,9 +106,9 @@ where } } -/// Error for [`TimeoutBody`]. +/// Error for [`TimeoutBody`] and [`DeadlineBody`][super::DeadlineBody]. #[derive(Debug)] -pub struct TimeoutError(()); +pub struct TimeoutError(pub(super) ()); impl std::error::Error for TimeoutError {} diff --git a/tower-http/src/timeout/deadline_body.rs b/tower-http/src/timeout/deadline_body.rs new file mode 100644 index 000000000..b69c5fcc7 --- /dev/null +++ b/tower-http/src/timeout/deadline_body.rs @@ -0,0 +1,284 @@ +use crate::BoxError; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, + time::Duration, +}; +use tokio::time::{sleep, Sleep}; + +pin_project! { + /// Wrapper around a [`Body`] that enforces a hard deadline on the entire body transfer. + /// + /// Unlike [`TimeoutBody`][super::TimeoutBody], which resets its deadline each time a frame is + /// received, `DeadlineBody` starts a single timer at construction and returns a + /// [`TimeoutError`][super::TimeoutError] if the body is not fully consumed before the deadline. + /// + /// The deadline is **wall-clock time from construction**, not cumulative poll time. The + /// timer continues to count even if the consumer is not actively polling the body. If you + /// poll some frames, pause to do other work, and then resume, the elapsed pause time counts + /// toward the deadline. + /// + /// # When to use this + /// + /// This is primarily useful as middleware on public-facing endpoints where you want to bound + /// the total wall-clock time a single request can hold resources (task slots, memory for + /// buffering, etc.), regardless of how frequently data trickles in. A slow client sending + /// one byte per second will never trip [`TimeoutBody`][super::TimeoutBody]'s idle timeout, + /// but will correctly trip `DeadlineBody`. 
+ /// + /// If you only need to detect stalled connections where no data flows for a period, use + /// [`TimeoutBody`][super::TimeoutBody] instead. The two can be stacked if you want both + /// an idle timeout and a hard deadline. + /// + /// # Example + /// + /// ``` + /// use http::{Request, Response}; + /// use bytes::Bytes; + /// use http_body_util::Full; + /// use std::time::Duration; + /// use tower::ServiceBuilder; + /// use tower_http::timeout::RequestBodyDeadlineLayer; + /// + /// async fn handle(_: Request>) -> Result>, std::convert::Infallible> { + /// // ... + /// # todo!() + /// } + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// let svc = ServiceBuilder::new() + /// // Timeout bodies after 30 seconds total + /// .layer(RequestBodyDeadlineLayer::new(Duration::from_secs(30))) + /// .service_fn(handle); + /// # Ok(()) + /// # } + /// ``` + pub struct DeadlineBody { + #[pin] + sleep: Sleep, + #[pin] + body: B, + } +} + +impl DeadlineBody { + /// Creates a new [`DeadlineBody`]. + /// + /// The timeout starts immediately. If the body is not fully consumed within `timeout`, + /// subsequent `poll_frame` calls will return a [`TimeoutError`][super::TimeoutError]. + pub fn new(timeout: Duration, body: B) -> Self { + DeadlineBody { + sleep: sleep(timeout), + body, + } + } +} + +impl Body for DeadlineBody +where + B: Body, + B::Error: Into, +{ + type Data = B::Data; + type Error = Box; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.project(); + + // Error if the absolute timeout has expired. + if let Poll::Ready(()) = this.sleep.poll(cx) { + return Poll::Ready(Some(Err(Box::new(super::TimeoutError(()))))); + } + + // Check for body data. 
+ let frame = ready!(this.body.poll_frame(cx)); + + Poll::Ready(frame.transpose().map_err(Into::into).transpose()) + } + + fn is_end_stream(&self) -> bool { + self.body.is_end_stream() + } + + fn size_hint(&self) -> http_body::SizeHint { + self.body.size_hint() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use bytes::Bytes; + use http_body::Frame; + use http_body_util::BodyExt; + use pin_project_lite::pin_project; + use std::{error::Error, fmt::Display}; + use tokio::time::sleep; + + #[derive(Debug)] + struct MockError; + + impl Error for MockError {} + + impl Display for MockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "mock error") + } + } + + pin_project! { + /// A body that yields a frame after a delay. + struct MockBody { + #[pin] + sleep: Sleep, + } + } + + impl Body for MockBody { + type Data = Bytes; + type Error = MockError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.project(); + this.sleep + .poll(cx) + .map(|_| Some(Ok(Frame::data(vec![].into())))) + } + } + + pin_project! { + /// A body that yields multiple frames with a delay between each. + struct MultiFrameBody { + frames_remaining: usize, + frame_interval: Duration, + #[pin] + sleep: Option, + } + } + + impl Body for MultiFrameBody { + type Data = Bytes; + type Error = MockError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let mut this = self.project(); + + if *this.frames_remaining == 0 { + return Poll::Ready(None); + } + + // Start the sleep if not active. 
+ let sleep_pinned = if let Some(s) = this.sleep.as_mut().as_pin_mut() { + s + } else { + this.sleep.set(Some(sleep(*this.frame_interval))); + this.sleep.as_mut().as_pin_mut().unwrap() + }; + + ready!(sleep_pinned.poll(cx)); + this.sleep.set(None); + *this.frames_remaining -= 1; + + Poll::Ready(Some(Ok(Frame::data(Bytes::from("chunk"))))) + } + } + + #[tokio::test] + async fn body_completes_within_timeout() { + let mock_body = MockBody { + sleep: sleep(Duration::from_millis(50)), + }; + let timeout_body = DeadlineBody::new(Duration::from_millis(200), mock_body); + + assert!(timeout_body + .boxed() + .frame() + .await + .expect("no frame") + .is_ok()); + } + + #[tokio::test] + async fn body_exceeds_timeout() { + let mock_body = MockBody { + sleep: sleep(Duration::from_millis(200)), + }; + let timeout_body = DeadlineBody::new(Duration::from_millis(50), mock_body); + + let result = timeout_body.boxed().frame().await.unwrap(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .downcast_ref::() + .is_some()); + } + + #[tokio::test] + async fn deadline_fires_despite_steady_frames() { + // Each frame arrives every 30ms (well within an idle timeout of 100ms), + // but total transfer takes 5 * 30ms = 150ms, exceeding the 100ms deadline. + let body = MultiFrameBody { + frames_remaining: 5, + frame_interval: Duration::from_millis(30), + sleep: None, + }; + let timeout_body = DeadlineBody::new(Duration::from_millis(100), body); + + let mut boxed = timeout_body.boxed(); + let mut got_error = false; + + loop { + match boxed.frame().await { + Some(Ok(_)) => {} + Some(Err(_)) => { + got_error = true; + break; + } + None => break, + } + } + + assert!( + got_error, + "expected timeout error before all frames arrived" + ); + } + + #[tokio::test] + async fn all_frames_arrive_within_deadline() { + // Each frame arrives every 20ms, total = 3 * 20ms = 60ms, within 200ms deadline. 
+        let body = MultiFrameBody {
+            frames_remaining: 3,
+            frame_interval: Duration::from_millis(20),
+            sleep: None,
+        };
+        let timeout_body = DeadlineBody::new(Duration::from_millis(200), body);
+
+        let mut boxed = timeout_body.boxed();
+        let mut frame_count = 0;
+
+        loop {
+            match boxed.frame().await {
+                Some(Ok(_)) => frame_count += 1,
+                Some(Err(e)) => panic!("unexpected error: {}", e),
+                None => break,
+            }
+        }
+
+        assert_eq!(frame_count, 3);
+    }
+}
diff --git a/tower-http/src/timeout/mod.rs b/tower-http/src/timeout/mod.rs
index e159b23cc..699fd720d 100644
--- a/tower-http/src/timeout/mod.rs
+++ b/tower-http/src/timeout/mod.rs
@@ -13,6 +13,22 @@
 //! and the specified status code. That means if your service's error type is [`Infallible`], it will
 //! still be [`Infallible`] after applying this middleware.
 //!
+//! # Body timeouts
+//!
+//! Two body timeout wrappers are available for limiting how long a request or response body
+//! transfer can take:
+//!
+//! - [`TimeoutBody`] resets its deadline each time a frame is received. Use this to detect
+//!   idle connections where no data flows for a period of time.
+//! - [`DeadlineBody`] starts a single timer at construction and never resets it. Use
+//!   this to cap the total wall-clock time spent transferring a body, regardless of how
+//!   frequently data arrives.
+//!
+//! Both are applied via their corresponding layer types ([`RequestBodyTimeoutLayer`] /
+//! [`RequestBodyDeadlineLayer`] for request bodies, [`ResponseBodyTimeoutLayer`] /
+//! [`ResponseBodyDeadlineLayer`] for response bodies). They can be stacked if you
+//! want both an idle timeout and an absolute deadline.
+//!
 //! # Example
 //!
 //! ```
@@ -41,10 +57,13 @@
 //! [`Infallible`]: std::convert::Infallible
 
 mod body;
+mod deadline_body;
 mod service;
 
 pub use body::{TimeoutBody, TimeoutError};
+pub use deadline_body::DeadlineBody;
 pub use service::{
-    RequestBodyTimeout, RequestBodyTimeoutLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer,
+    RequestBodyDeadline, RequestBodyDeadlineLayer, RequestBodyTimeout, RequestBodyTimeoutLayer,
+    ResponseBodyDeadline, ResponseBodyDeadlineLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer,
     Timeout, TimeoutLayer,
 };
diff --git a/tower-http/src/timeout/service.rs b/tower-http/src/timeout/service.rs
index 68ea56ef3..c5eb845cc 100644
--- a/tower-http/src/timeout/service.rs
+++ b/tower-http/src/timeout/service.rs
@@ -1,4 +1,5 @@
 use crate::timeout::body::TimeoutBody;
+use crate::timeout::deadline_body::DeadlineBody;
 use http::{Request, Response, StatusCode};
 use pin_project_lite::pin_project;
 use std::{
@@ -305,6 +306,167 @@ where
     }
 }
 
+/// Applies a [`DeadlineBody`] to the request body.
+///
+/// Unlike [`RequestBodyTimeoutLayer`], which resets on each frame, this enforces a hard
+/// deadline on the entire body transfer.
+#[derive(Clone, Debug)]
+pub struct RequestBodyDeadlineLayer {
+    timeout: Duration,
+}
+
+impl RequestBodyDeadlineLayer {
+    /// Creates a new [`RequestBodyDeadlineLayer`].
+    pub fn new(timeout: Duration) -> Self {
+        Self { timeout }
+    }
+}
+
+impl<S> Layer<S> for RequestBodyDeadlineLayer {
+    type Service = RequestBodyDeadline<S>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        RequestBodyDeadline::new(inner, self.timeout)
+    }
+}
+
+/// Applies a [`DeadlineBody`] to the request body.
+#[derive(Clone, Debug)]
+pub struct RequestBodyDeadline<S> {
+    inner: S,
+    timeout: Duration,
+}
+
+impl<S> RequestBodyDeadline<S> {
+    /// Creates a new [`RequestBodyDeadline`].
+    pub fn new(service: S, timeout: Duration) -> Self {
+        Self {
+            inner: service,
+            timeout,
+        }
+    }
+
+    /// Returns a new [`Layer`] that wraps services with a [`RequestBodyDeadlineLayer`] middleware.
+    ///
+    /// [`Layer`]: tower_layer::Layer
+    pub fn layer(timeout: Duration) -> RequestBodyDeadlineLayer {
+        RequestBodyDeadlineLayer::new(timeout)
+    }
+
+    define_inner_service_accessors!();
+}
+
+impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyDeadline<S>
+where
+    S: Service<Request<DeadlineBody<ReqBody>>>,
+{
+    type Response = S::Response;
+    type Error = S::Error;
+    type Future = S::Future;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
+        let req = req.map(|body| DeadlineBody::new(self.timeout, body));
+        self.inner.call(req)
+    }
+}
+
+/// Applies a [`DeadlineBody`] to the response body.
+///
+/// Unlike [`ResponseBodyTimeoutLayer`], which resets on each frame, this enforces a hard
+/// deadline on the entire body transfer.
+#[derive(Clone, Debug)]
+pub struct ResponseBodyDeadlineLayer {
+    timeout: Duration,
+}
+
+impl ResponseBodyDeadlineLayer {
+    /// Creates a new [`ResponseBodyDeadlineLayer`].
+    pub fn new(timeout: Duration) -> Self {
+        Self { timeout }
+    }
+}
+
+impl<S> Layer<S> for ResponseBodyDeadlineLayer {
+    type Service = ResponseBodyDeadline<S>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        ResponseBodyDeadline::new(inner, self.timeout)
+    }
+}
+
+/// Applies a [`DeadlineBody`] to the response body.
+#[derive(Clone, Debug)]
+pub struct ResponseBodyDeadline<S> {
+    inner: S,
+    timeout: Duration,
+}
+
+impl<S> ResponseBodyDeadline<S> {
+    /// Creates a new [`ResponseBodyDeadline`].
+    pub fn new(service: S, timeout: Duration) -> Self {
+        Self {
+            inner: service,
+            timeout,
+        }
+    }
+
+    /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyDeadlineLayer`] middleware.
+    ///
+    /// [`Layer`]: tower_layer::Layer
+    pub fn layer(timeout: Duration) -> ResponseBodyDeadlineLayer {
+        ResponseBodyDeadlineLayer::new(timeout)
+    }
+
+    define_inner_service_accessors!();
+}
+
+impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyDeadline<S>
+where
+    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
+{
+    type Response = Response<DeadlineBody<ResBody>>;
+    type Error = S::Error;
+    type Future = ResponseBodyDeadlineFuture<S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
+        ResponseBodyDeadlineFuture {
+            inner: self.inner.call(req),
+            timeout: self.timeout,
+        }
+    }
+}
+
+pin_project! {
+    /// Response future for [`ResponseBodyDeadline`].
+    pub struct ResponseBodyDeadlineFuture<Fut> {
+        #[pin]
+        inner: Fut,
+        timeout: Duration,
+    }
+}
+
+impl<Fut, ResBody, E> Future for ResponseBodyDeadlineFuture<Fut>
+where
+    Fut: Future<Output = Result<Response<ResBody>, E>>,
+{
+    type Output = Result<Response<DeadlineBody<ResBody>>, E>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let timeout = self.timeout;
+        let this = self.project();
+        let res = ready!(this.inner.poll(cx))?;
+        Poll::Ready(Ok(res.map(|body| DeadlineBody::new(timeout, body))))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;