From 3b56d2d2e8be2d3b75e446a9d093c67f50a47d56 Mon Sep 17 00:00:00 2001 From: Jess Izen <44884346+jlizen@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:07:45 -0700 Subject: [PATCH 1/3] feat!: Add configurable Backend trait for ServeDir, bump MSRV 1.65 (#684) * feat(fs): add configurable Backend trait for ServeDir * chore: bump MSRV to 1.65, add changelog entries * fix clippy --- .github/workflows/CI.yml | 2 +- README.md | 2 +- tower-http/CHANGELOG.md | 2 + tower-http/Cargo.toml | 2 +- tower-http/src/services/fs/mod.rs | 5 + .../src/services/fs/serve_dir/backend.rs | 140 ++++++++++ .../src/services/fs/serve_dir/future.rs | 4 +- tower-http/src/services/fs/serve_dir/mod.rs | 58 +++- .../src/services/fs/serve_dir/open_file.rs | 75 +++--- tower-http/src/services/fs/serve_dir/tests.rs | 251 ++++++++++++++++++ 10 files changed, 491 insertions(+), 50 deletions(-) create mode 100644 tower-http/src/services/fs/serve_dir/backend.rs diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f02987f86..c09952b5b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -102,7 +102,7 @@ jobs: # Still better to maintain fewer manual version overrides though. - run: cargo update -p async-compression --precise 0.4.23 - run: cargo update -p flate2 --precise 1.0.35 - - uses: dtolnay/rust-toolchain@1.64 + - uses: dtolnay/rust-toolchain@1.65 - uses: Swatinem/rust-cache@v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} diff --git a/README.md b/README.md index 7b76c46fd..ff4b72c5b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ The [examples] folder contains various examples of how to use Tower HTTP: ## Minimum supported Rust version -tower-http's MSRV is 1.66. +tower-http's MSRV is 1.65. ## Getting Help diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 2f31ac24c..b2dd83131 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 feature entries; the underlying dependencies are still pulled in transitively by the features that need them (e.g. `compression-gzip`, `fs`, `timeout`). ([#628]) +- MSRV bumped from 1.64 to 1.65. [#215]: https://github.com/tower-rs/tower-http/issues/215 [#628]: https://github.com/tower-rs/tower-http/pull/628 @@ -40,6 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added - `body`: `UnsyncBoxBody::new()` constructor and `From` conversion to avoid double-boxing when combining `ServeDir` responses with other body types ([#537]) +- `fs`: Add `Backend` trait to make `ServeDir` work with non-filesystem sources. The default `TokioBackend` preserves existing behavior. Use `ServeDir::with_backend()` to plug in custom implementations. # 0.6.11 diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 632c77f39..6b7e1012d 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -10,7 +10,7 @@ repository = "https://github.com/tower-rs/tower-http" homepage = "https://github.com/tower-rs/tower-http" categories = ["asynchronous", "network-programming", "web-programming"] keywords = ["io", "async", "futures", "service", "http"] -rust-version = "1.64" +rust-version = "1.65" [dependencies] bitflags = "2.0.2" diff --git a/tower-http/src/services/fs/mod.rs b/tower-http/src/services/fs/mod.rs index 4673fb2ab..3652e5be4 100644 --- a/tower-http/src/services/fs/mod.rs +++ b/tower-http/src/services/fs/mod.rs @@ -18,10 +18,15 @@ mod serve_file; pub use self::{ serve_dir::{ future::ResponseFuture as ServeFileSystemResponseFuture, + Backend, DefaultServeDirFallback, + File, + Metadata, // The response body and future are used for both ServeDir and ServeFile ResponseBody as ServeFileSystemResponseBody, ServeDir, + TokioBackend, + TokioFile, }, serve_file::ServeFile, }; diff --git a/tower-http/src/services/fs/serve_dir/backend.rs b/tower-http/src/services/fs/serve_dir/backend.rs new file mode 100644 index 000000000..d3e87afcf --- /dev/null +++ b/tower-http/src/services/fs/serve_dir/backend.rs @@ -0,0 +1,140 @@ +//! Pluggable backend trait for [`ServeDir`](super::ServeDir). +//! +//! The [`Backend`] trait abstracts file system operations so that `ServeDir` can serve +//! files from sources other than the local filesystem (e.g. rust-embed, include_dir, S3). + +use std::{future::Future, io, path::PathBuf, pin::Pin, time::SystemTime}; +use tokio::io::{AsyncRead, AsyncSeek}; + +/// Trait for file metadata. +/// +/// This is the information `ServeDir` needs about a file or directory without opening it. +pub trait Metadata: Send + 'static { + /// Returns `true` if this metadata refers to a directory. + fn is_dir(&self) -> bool; + + /// Returns the last modification time, if available. + fn modified(&self) -> io::Result; + + /// Returns the size of the file in bytes. + fn len(&self) -> u64; + + /// Returns `true` if the file is empty (zero bytes). + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +/// Trait for an opened file. +/// +/// Must support async reading and seeking (for HTTP range requests). +/// In-memory backends can use [`std::io::Cursor`] to satisfy the `AsyncSeek` requirement. +pub trait File: AsyncRead + AsyncSeek + Unpin + Send + Sync { + /// The metadata type returned by this file. + type Metadata: Metadata; + + /// Future returned by [`File::metadata`]. + type MetadataFuture<'a>: Future> + Send + where + Self: 'a; + + /// Returns metadata for this opened file. + fn metadata(&self) -> Self::MetadataFuture<'_>; +} + +/// Trait abstracting filesystem operations for [`ServeDir`](super::ServeDir). +/// +/// Implement this trait to serve files from non-filesystem sources. +/// The default implementation ([`TokioBackend`]) wraps `tokio::fs`. +pub trait Backend: Clone + Send + Sync + 'static { + /// The file type returned by [`Backend::open`]. + type File: File; + + /// The metadata type returned by [`Backend::metadata`]. + type Metadata: Metadata; + + /// Future returned by [`Backend::open`]. + type OpenFuture: Future> + Send; + + /// Future returned by [`Backend::metadata`]. + type MetadataFuture: Future> + Send; + + /// Open a file at the given path. + fn open(&self, path: PathBuf) -> Self::OpenFuture; + + /// Retrieve metadata for the given path without opening the file. + fn metadata(&self, path: PathBuf) -> Self::MetadataFuture; +} + +/// Default [`Backend`] implementation using `tokio::fs`. +#[derive(Clone, Debug, Default)] +pub struct TokioBackend; + +impl Backend for TokioBackend { + type File = TokioFile; + type Metadata = std::fs::Metadata; + type OpenFuture = Pin> + Send>>; + type MetadataFuture = Pin> + Send>>; + + fn open(&self, path: PathBuf) -> Self::OpenFuture { + Box::pin(async move { + let file = tokio::fs::File::open(&path).await?; + Ok(TokioFile(file)) + }) + } + + fn metadata(&self, path: PathBuf) -> Self::MetadataFuture { + Box::pin(async move { tokio::fs::metadata(&path).await }) + } +} + +/// Wrapper around [`tokio::fs::File`] implementing the [`File`] trait. +#[derive(Debug)] +pub struct TokioFile(tokio::fs::File); + +impl AsyncRead for TokioFile { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncSeek for TokioFile { + fn start_seek(mut self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> { + Pin::new(&mut self.0).start_seek(position) + } + + fn poll_complete( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.0).poll_complete(cx) + } +} + +impl File for TokioFile { + type Metadata = std::fs::Metadata; + type MetadataFuture<'a> = + Pin> + Send + 'a>>; + + fn metadata(&self) -> Self::MetadataFuture<'_> { + Box::pin(async move { self.0.metadata().await }) + } +} + +impl Metadata for std::fs::Metadata { + fn is_dir(&self) -> bool { + self.is_dir() + } + + fn modified(&self) -> io::Result { + self.modified() + } + + fn len(&self) -> u64 { + self.len() + } +} diff --git a/tower-http/src/services/fs/serve_dir/future.rs b/tower-http/src/services/fs/serve_dir/future.rs index e073acdb5..2c019d6bb 100644 --- a/tower-http/src/services/fs/serve_dir/future.rs +++ b/tower-http/src/services/fs/serve_dir/future.rs @@ -239,8 +239,8 @@ where fn build_response(output: FileOpened) -> Response { let (maybe_file, size) = match output.extent { - FileRequestExtent::Full(file, meta) => (Some(file), meta.len()), - FileRequestExtent::Head(meta) => (None, meta.len()), + FileRequestExtent::Full(file, size) => (Some(file), size), + FileRequestExtent::Head(size) => (None, size), }; let mut builder = Response::builder() diff --git a/tower-http/src/services/fs/serve_dir/mod.rs b/tower-http/src/services/fs/serve_dir/mod.rs index 9143b569d..60aa49c3d 100644 --- a/tower-http/src/services/fs/serve_dir/mod.rs +++ b/tower-http/src/services/fs/serve_dir/mod.rs @@ -17,6 +17,7 @@ use std::{ }; use tower_service::Service; +mod backend; pub(crate) mod future; mod headers; mod open_file; @@ -24,6 +25,8 @@ mod open_file; #[cfg(test)] mod tests; +pub use self::backend::{Backend, File, Metadata, TokioBackend, TokioFile}; + // default capacity 64KiB const DEFAULT_CAPACITY: usize = 65536; @@ -50,7 +53,7 @@ const DEFAULT_CAPACITY: usize = 65536; /// let service = ServeDir::new("assets"); /// ``` #[derive(Clone, Debug)] -pub struct ServeDir { +pub struct ServeDir { base: PathBuf, redirect_path_prefix: String, buf_chunk_size: usize, @@ -60,6 +63,7 @@ pub struct ServeDir { variant: ServeVariant, fallback: Option, call_fallback_on_method_not_allowed: bool, + backend: B, } impl ServeDir { @@ -82,6 +86,7 @@ impl ServeDir { }, fallback: None, call_fallback_on_method_not_allowed: false, + backend: TokioBackend, } } @@ -97,11 +102,39 @@ impl ServeDir { variant: ServeVariant::SingleFile { mime }, fallback: None, call_fallback_on_method_not_allowed: false, + backend: TokioBackend, } } } -impl ServeDir { +impl ServeDir { + /// Create a new [`ServeDir`] with a custom [`Backend`]. + /// + /// This allows serving files from sources other than the local filesystem. + pub fn with_backend

(path: P, backend: B) -> Self + where + P: AsRef, + { + let mut base = PathBuf::from("."); + base.push(path.as_ref()); + + ServeDir { + base, + buf_chunk_size: DEFAULT_CAPACITY, + precompressed_variants: None, + variant: ServeVariant::Directory { + append_index_html_on_directories: true, + html_as_default_extension: false, + }, + fallback: None, + call_fallback_on_method_not_allowed: false, + redirect_path_prefix: String::new(), + backend, + } + } +} + +impl ServeDir { /// If the requested path is a directory append `index.html`. /// /// This is useful for static sites. @@ -243,7 +276,7 @@ impl ServeDir { /// // respond with `not_found.html` for missing files /// .fallback(ServeFile::new("assets/not_found.html")); /// ``` - pub fn fallback(self, new_fallback: F2) -> ServeDir { + pub fn fallback(self, new_fallback: F2) -> ServeDir { ServeDir { redirect_path_prefix: self.redirect_path_prefix, base: self.base, @@ -252,6 +285,7 @@ impl ServeDir { variant: self.variant, fallback: Some(new_fallback), call_fallback_on_method_not_allowed: self.call_fallback_on_method_not_allowed, + backend: self.backend, } } @@ -272,7 +306,7 @@ impl ServeDir { /// ``` /// /// Setups like this are often found in single page applications. - pub fn not_found_service(self, new_fallback: F2) -> ServeDir> { + pub fn not_found_service(self, new_fallback: F2) -> ServeDir, B> { self.fallback(SetStatus::new(new_fallback, StatusCode::NOT_FOUND)) } @@ -410,31 +444,29 @@ impl ServeDir { ) .collect(); - let open_file_config = open_file::OpenFileConfig { + let open_file_future = Box::pin(open_file::open_file(open_file::OpenFileRequest { variant: self.variant.clone(), redirect_path_prefix, - buf_chunk_size, - precompression_configured, - }; - - let open_file_future = Box::pin(open_file::open_file( - open_file_config, path_to_file, req, negotiated_encodings, range_header, - )); + buf_chunk_size, + precompression_configured, + backend: self.backend.clone(), + })); ResponseFuture::open_file_future(open_file_future, fallback_and_request) } } -impl Service> for ServeDir +impl Service> for ServeDir where F: Service, Response = Response, Error = Infallible> + Clone, F::Future: Send + 'static, FResBody: http_body::Body + Send + 'static, FResBody::Error: Into>, + B: Backend, { type Response = Response; type Error = Infallible; diff --git a/tower-http/src/services/fs/serve_dir/open_file.rs b/tower-http/src/services/fs/serve_dir/open_file.rs index bb2f4c91b..b03c40e8b 100644 --- a/tower-http/src/services/fs/serve_dir/open_file.rs +++ b/tower-http/src/services/fs/serve_dir/open_file.rs @@ -1,4 +1,5 @@ use super::{ + backend::{Backend, File as _, Metadata as _}, headers::{ETag, IfMatch, IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, LastModified}, ServeVariant, }; @@ -9,12 +10,11 @@ use http_body_util::Empty; use http_range_header::RangeUnsatisfiableError; use std::{ ffi::OsStr, - fs::Metadata, io::{self, ErrorKind, SeekFrom}, ops::RangeInclusive, path::{Path, PathBuf}, }; -use tokio::{fs::File, io::AsyncSeekExt}; +use tokio::io::AsyncSeekExt; pub(super) enum OpenFileOutput { FileOpened(Box), @@ -43,31 +43,36 @@ pub(super) struct FileOpened { } pub(super) enum FileRequestExtent { - Full(File, Metadata), - Head(Metadata), + Full(Box, u64), + Head(u64), } -pub(super) struct OpenFileConfig { +pub(super) struct OpenFileRequest { pub(super) variant: ServeVariant, pub(super) redirect_path_prefix: String, + pub(super) path_to_file: PathBuf, + pub(super) req: Request>, + pub(super) negotiated_encodings: Vec<(Encoding, QValue)>, + pub(super) range_header: Option, pub(super) buf_chunk_size: usize, pub(super) precompression_configured: bool, + pub(super) backend: B, } -pub(super) async fn open_file( - config: OpenFileConfig, - mut path_to_file: PathBuf, - req: Request>, - negotiated_encodings: Vec<(Encoding, QValue)>, - range_header: Option, +pub(super) async fn open_file( + request: OpenFileRequest, ) -> io::Result { - let OpenFileConfig { + let OpenFileRequest { variant, redirect_path_prefix, + mut path_to_file, + req, + negotiated_encodings, + range_header, buf_chunk_size, precompression_configured, - } = config; - + backend, + } = request; let preconditions = Preconditions { if_match: req .headers() @@ -101,6 +106,7 @@ pub(super) async fn open_file( req.uri(), append_index_html_on_directories, html_as_default_extension, + &backend, ) .await { @@ -122,7 +128,7 @@ pub(super) async fn open_file( #[cfg(feature = "tracing")] let _path_str = path_to_file.display().to_string(); let (meta, maybe_encoding) = - file_metadata_with_fallback(path_to_file, negotiated_encodings).await?; + file_metadata_with_fallback(&backend, path_to_file, negotiated_encodings).await?; let last_modified = meta.modified().ok().map(LastModified::from); let etag = meta @@ -145,7 +151,7 @@ pub(super) async fn open_file( let maybe_range = try_parse_range(range_header.as_deref(), meta.len()); Ok(OpenFileOutput::FileOpened(Box::new(FileOpened { - extent: FileRequestExtent::Head(meta), + extent: FileRequestExtent::Head(meta.len()), chunk_size: buf_chunk_size, mime_header_value: mime, maybe_encoding, @@ -158,7 +164,7 @@ pub(super) async fn open_file( #[cfg(feature = "tracing")] let _path_str = path_to_file.display().to_string(); let (mut file, maybe_encoding) = - match open_file_with_fallback(path_to_file, negotiated_encodings).await { + match open_file_with_fallback(&backend, path_to_file, negotiated_encodings).await { Ok(result) => result, Err(err) if is_invalid_filename_error(&err) => { @@ -187,7 +193,8 @@ pub(super) async fn open_file( return Ok(output); } - let maybe_range = try_parse_range(range_header.as_deref(), meta.len()); + let size = meta.len(); + let maybe_range = try_parse_range(range_header.as_deref(), size); if let Some(Ok(ranges)) = maybe_range.as_ref() { // if there is any other amount of ranges than 1 we'll return an // unsatisfiable later as there isn't yet support for multipart ranges @@ -197,7 +204,7 @@ pub(super) async fn open_file( } Ok(OpenFileOutput::FileOpened(Box::new(FileOpened { - extent: FileRequestExtent::Full(file, meta), + extent: FileRequestExtent::Full(Box::new(file), size), chunk_size: buf_chunk_size, mime_header_value: mime, maybe_encoding, @@ -336,14 +343,15 @@ fn preferred_encoding( // Attempts to open the file with any of the possible negotiated_encodings in the // preferred order. If none of the negotiated_encodings have a corresponding precompressed // file the uncompressed file is used as a fallback. -async fn open_file_with_fallback( +async fn open_file_with_fallback( + backend: &B, mut path: PathBuf, mut negotiated_encoding: Vec<(Encoding, QValue)>, -) -> io::Result<(File, Option)> { +) -> io::Result<(B::File, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. let encoding = preferred_encoding(&mut path, &negotiated_encoding); - match (File::open(&path).await, encoding) { + match (backend.open(path.clone()).await, encoding) { (Ok(file), maybe_encoding) => break (file, maybe_encoding), (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound && encoding != Encoding::Identity => @@ -364,15 +372,16 @@ async fn open_file_with_fallback( // Attempts to get the file metadata with any of the possible negotiated_encodings in the // preferred order. If none of the negotiated_encodings have a corresponding precompressed // file the uncompressed file is used as a fallback. -async fn file_metadata_with_fallback( +async fn file_metadata_with_fallback( + backend: &B, mut path: PathBuf, mut negotiated_encoding: Vec<(Encoding, QValue)>, -) -> io::Result<(Metadata, Option)> { - let (file, encoding) = loop { +) -> io::Result<(B::Metadata, Option)> { + let (meta, encoding) = loop { // Get the preferred encoding among the negotiated ones. let encoding = preferred_encoding(&mut path, &negotiated_encoding); - match (tokio::fs::metadata(&path).await, encoding) { - (Ok(file), maybe_encoding) => break (file, maybe_encoding), + match (backend.metadata(path.clone()).await, encoding) { + (Ok(meta), maybe_encoding) => break (meta, maybe_encoding), (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound && encoding != Encoding::Identity => { @@ -386,19 +395,20 @@ async fn file_metadata_with_fallback( (Err(err), _) => return Err(err), } }; - Ok((file, encoding)) + Ok((meta, encoding)) } -async fn maybe_redirect_or_append_path( +async fn maybe_redirect_or_append_path( redirect_path_prefix: &str, path_to_file: &mut PathBuf, uri: &Uri, append_index_html_on_directories: bool, html_as_default_extension: bool, + backend: &B, ) -> Option { let uri_path = uri.path(); - let is_directory = is_dir(path_to_file).await; + let is_directory = is_dir(path_to_file, backend).await; if uri_path.ends_with('/') && uri_path != "/" && is_directory != Some(true) { return Some(OpenFileOutput::FileNotFound); @@ -441,8 +451,9 @@ fn try_parse_range( }) } -async fn is_dir(path_to_file: &Path) -> Option { - tokio::fs::metadata(path_to_file) +async fn is_dir(path_to_file: &Path, backend: &B) -> Option { + backend + .metadata(path_to_file.to_owned()) .await .ok() .map(|meta_data| meta_data.is_dir()) diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 4a9a58e9c..3351568d8 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -1684,3 +1684,254 @@ async fn if_modified_since_304_includes_etag() { &last_modified ); } + +mod memory_backend { + use super::*; + use crate::services::fs::serve_dir::backend::{Backend, File, Metadata}; + use std::{ + collections::HashMap, future::Future, io, path::PathBuf, pin::Pin, sync::Arc, + time::SystemTime, + }; + use tokio::io::{AsyncRead, AsyncSeek}; + + /// In-memory file metadata. + #[derive(Clone)] + struct MemMetadata { + is_dir: bool, + len: u64, + modified: SystemTime, + } + + impl Metadata for MemMetadata { + fn is_dir(&self) -> bool { + self.is_dir + } + + fn modified(&self) -> io::Result { + Ok(self.modified) + } + + fn len(&self) -> u64 { + self.len + } + } + + /// In-memory file backed by a Cursor. + struct MemFile { + cursor: std::io::Cursor>, + meta: MemMetadata, + } + + impl AsyncRead for MemFile { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.cursor).poll_read(cx, buf) + } + } + + impl AsyncSeek for MemFile { + fn start_seek(mut self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> { + Pin::new(&mut self.cursor).start_seek(position) + } + + fn poll_complete( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.cursor).poll_complete(cx) + } + } + + impl File for MemFile { + type Metadata = MemMetadata; + type MetadataFuture<'a> = std::future::Ready>; + + fn metadata(&self) -> Self::MetadataFuture<'_> { + std::future::ready(Ok(self.meta.clone())) + } + } + + /// In-memory backend storing files in a HashMap. + #[derive(Clone)] + struct MemBackend { + files: Arc>>, + dirs: Arc>, + } + + impl MemBackend { + fn new() -> Self { + Self { + files: Arc::new(HashMap::new()), + dirs: Arc::new(Vec::new()), + } + } + + fn with_file(mut self, path: impl Into, content: impl Into>) -> Self { + Arc::get_mut(&mut self.files) + .unwrap() + .insert(path.into(), content.into()); + self + } + + fn with_dir(mut self, path: impl Into) -> Self { + Arc::get_mut(&mut self.dirs).unwrap().push(path.into()); + self + } + } + + impl Backend for MemBackend { + type File = MemFile; + type Metadata = MemMetadata; + type OpenFuture = Pin> + Send>>; + type MetadataFuture = Pin> + Send>>; + + fn open(&self, path: PathBuf) -> Self::OpenFuture { + let files = self.files.clone(); + Box::pin(async move { + match files.get(&path) { + Some(data) => Ok(MemFile { + meta: MemMetadata { + is_dir: false, + len: data.len() as u64, + modified: SystemTime::UNIX_EPOCH, + }, + cursor: std::io::Cursor::new(data.clone()), + }), + None => Err(io::Error::new(io::ErrorKind::NotFound, "not found")), + } + }) + } + + fn metadata(&self, path: PathBuf) -> Self::MetadataFuture { + let files = self.files.clone(); + let dirs = self.dirs.clone(); + Box::pin(async move { + if dirs.contains(&path) { + return Ok(MemMetadata { + is_dir: true, + len: 0, + modified: SystemTime::UNIX_EPOCH, + }); + } + match files.get(&path) { + Some(data) => Ok(MemMetadata { + is_dir: false, + len: data.len() as u64, + modified: SystemTime::UNIX_EPOCH, + }), + None => Err(io::Error::new(io::ErrorKind::NotFound, "not found")), + } + }) + } + } + + #[tokio::test] + async fn serve_file_from_memory() { + let backend = MemBackend::new().with_file("./assets/hello.txt", "Hello, world!"); + + let svc = ServeDir::with_backend("assets", backend); + + let req = Request::builder() + .uri("/hello.txt") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/plain"); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "Hello, world!"); + } + + #[tokio::test] + async fn not_found_from_memory() { + let backend = MemBackend::new(); + + let svc = ServeDir::with_backend("assets", backend); + + let req = Request::builder() + .uri("/missing.txt") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn head_request_from_memory() { + let backend = MemBackend::new().with_file("./assets/hello.txt", "Hello, world!"); + + let svc = ServeDir::with_backend("assets", backend); + + let req = Request::builder() + .method(Method::HEAD) + .uri("/hello.txt") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-length"], "13"); + + // HEAD should have empty body + let body = body_into_text(res.into_body()).await; + assert!(body.is_empty()); + } + + #[tokio::test] + async fn range_request_from_memory() { + let backend = MemBackend::new().with_file("./assets/hello.txt", "Hello, world!"); + + let svc = ServeDir::with_backend("assets", backend); + + let req = Request::builder() + .uri("/hello.txt") + .header("range", "bytes=0-4") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); + assert_eq!(res.headers()["content-range"], "bytes 0-4/13"); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "Hello"); + } + + #[tokio::test] + async fn directory_redirect_from_memory() { + let backend = MemBackend::new() + .with_dir("./assets/sub") + .with_file("./assets/sub/index.html", "

Index

"); + + let svc = ServeDir::with_backend("assets", backend); + + // Request without trailing slash should redirect + let req = Request::builder().uri("/sub").body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); + assert_eq!(res.headers()["location"], "/sub/"); + } + + #[tokio::test] + async fn directory_serves_index_html_from_memory() { + let backend = MemBackend::new() + .with_dir("./assets/sub") + .with_file("./assets/sub/index.html", "

Index

"); + + let svc = ServeDir::with_backend("assets", backend); + + let req = Request::builder().uri("/sub/").body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "

Index

"); + } +} From 8cb8d99a841ca5c7853275886f2854f0fad6f09c Mon Sep 17 00:00:00 2001 From: Oliver THEBAULT Date: Fri, 12 Jun 2026 20:27:00 +0200 Subject: [PATCH 2/3] feat(ValidateRequestHeaderLayer): add has_header("...").with_value("...") function (#360) * feat(ValidateRequestHeaderLayer): add assert() function * fix: update ValidateRequestHeaderLayer with typed-builder style * fix: revert typed builder * fix(validate-request): address review feedback on has_header_value --------- Co-authored-by: Jess Izen <44884346+jlizen@users.noreply.github.com> --- tower-http/CHANGELOG.md | 2 + tower-http/src/validate_request.rs | 417 +++++++++++++++++++++++++++-- 2 files changed, 393 insertions(+), 26 deletions(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index b2dd83131..2edc73ac6 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -35,11 +35,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - MSRV bumped from 1.64 to 1.65. [#215]: https://github.com/tower-rs/tower-http/issues/215 +[#360]: https://github.com/tower-rs/tower-http/pull/360 [#628]: https://github.com/tower-rs/tower-http/pull/628 [#642]: https://github.com/tower-rs/tower-http/pull/642 ## Added +- **validate-request:** Add `ValidateRequestHeaderLayer::has_header_value()` to reject requests when a header does not have an expected value ([#360]) - `body`: `UnsyncBoxBody::new()` constructor and `From` conversion to avoid double-boxing when combining `ServeDir` responses with other body types ([#537]) - `fs`: Add `Backend` trait to make `ServeDir` work with non-filesystem sources. The default `TokioBackend` preserves existing behavior. Use `ServeDir::with_backend()` to plug in custom implementations. diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index efb301e4f..89d3bcf4b 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -2,6 +2,8 @@ //! //! # Example //! +//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]: +//! //! ``` //! use tower_http::validate_request::ValidateRequestHeaderLayer; //! use http::{Request, Response, StatusCode, header::ACCEPT}; @@ -51,6 +53,116 @@ //! # } //! ``` //! +//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::has_header_value()`]: +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use http::{Request, Response, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! +//! async fn handle(request: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let mut service = ServiceBuilder::new() +//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response +//! .layer(ValidateRequestHeaderLayer::has_header_value( +//! "x-custom-header", +//! "random-value-1234567890", +//! ).expect("invalid validate header")) +//! .service_fn(handle); +//! +//! // Requests with the correct value are allowed through +//! let request = Request::builder() +//! .header("x-custom-header", "random-value-1234567890") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid value get a `403 Forbidden` response +//! let request = Request::builder() +//! .header("x-custom-header", "wrong-value") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::FORBIDDEN, response.status()); +//! # Ok(()) +//! # } +//! ``` +//! +//! To require only that a header is present, use [`ValidateRequestHeaderLayer::custom()`]: +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use http::{Request, Response, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{ServiceBuilder, service_fn, BoxError}; +//! +//! async fn handle(request: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # fn main() { +//! let service = ServiceBuilder::new() +//! .layer(ValidateRequestHeaderLayer::custom(|req: &mut Request>| { +//! if req.headers().contains_key("x-custom-header") { +//! Ok(()) +//! } else { +//! let mut res = Response::new(Full::::default()); +//! *res.status_mut() = StatusCode::FORBIDDEN; +//! Err(res) +//! } +//! })) +//! .service_fn(handle); +//! # } +//! ``` +//! +//! To serve a custom response when validation fails, also use [`ValidateRequestHeaderLayer::custom()`]: +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use http::{Request, Response, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{ServiceBuilder, service_fn, BoxError}; +//! +//! async fn handle(request: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # fn main() { +//! let service = ServiceBuilder::new() +//! .layer(ValidateRequestHeaderLayer::custom(|req: &mut Request>| { +//! match req.headers().get("x-custom-header").map(|v| v.as_bytes()) { +//! Some(b"random-value-1234567890") => Ok(()), +//! _ => Err(Response::builder() +//! .status(StatusCode::FORBIDDEN) +//! .body(Full::::default()) +//! .unwrap()), +//! } +//! })) +//! .service_fn(handle); +//! # } +//! ``` +//! //! Custom validation can be made by implementing [`ValidateRequest`]: //! //! ``` @@ -90,33 +202,10 @@ //! # } //! ``` //! -//! Or using a closure: -//! -//! ``` -//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; -//! use http::{Request, Response, StatusCode, header::ACCEPT}; -//! use bytes::Bytes; -//! use http_body_util::Full; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; -//! -//! async fn handle(request: Request>) -> Result>, BoxError> { -//! # todo!(); -//! // ... -//! } -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), BoxError> { -//! let service = ServiceBuilder::new() -//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request>| { -//! // Validate the request -//! # Ok::<_, Response>>(()) -//! })) -//! .service_fn(handle); -//! # Ok(()) -//! # } -//! ``` +//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept -use http::{header, Request, Response, StatusCode}; +use http::header::InvalidHeaderName; +use http::{header, header::HeaderName, Request, Response, StatusCode}; use mime::{Mime, MimeIter}; use pin_project_lite::pin_project; use std::{ @@ -168,6 +257,65 @@ impl ValidateRequestHeaderLayer> { } } +impl ValidateRequestHeaderLayer> { + /// Validate requests have a required header with a specific value. + /// + /// Rejects with `403 Forbidden` if the header is missing or does not have the expected value. + /// Header values that are not valid UTF-8 are treated as non-matching. + /// + /// If the request contains multiple values for the header, only the first occurrence is + /// checked. + /// + /// # Errors + /// + /// Returns an error if `expected_header_name` is not a valid HTTP header name per RFC 7230 + /// (non-empty, at most 32,768 bytes, containing only valid token characters). + /// + /// # Example + /// + /// ``` + /// use http::{Request, Response, StatusCode}; + /// use http_body_util::Full; + /// use bytes::Bytes; + /// use tower::{Service, ServiceBuilder, ServiceExt, service_fn}; + /// use tower_http::validate_request::ValidateRequestHeaderLayer; + /// + /// async fn handle(request: Request>) -> Result>, std::convert::Infallible> { + /// Ok(Response::new(request.into_body())) + /// } + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut service = ServiceBuilder::new() + /// .layer(ValidateRequestHeaderLayer::has_header_value( + /// "x-custom-header", + /// "random-value-1234567890", + /// ).expect("invalid validate header")) + /// .service_fn(handle); + /// + /// let request = Request::builder() + /// .header("x-custom-header", "random-value-1234567890") + /// .body(Full::default()) + /// .unwrap(); + /// + /// let response = service.ready().await.unwrap().call(request).await.unwrap(); + /// assert_eq!(response.status(), StatusCode::OK); + /// # } + /// ``` + pub fn has_header_value( + expected_header_name: &str, + expected_header_value: &str, + ) -> Result + where + ResBody: Default, + { + Ok(Self::custom(RequiredHeaderValue::new( + expected_header_name.parse::()?, + expected_header_value, + ))) + } +} + impl ValidateRequestHeaderLayer { /// Validate requests using a custom method. pub fn custom(validate: T) -> ValidateRequestHeaderLayer { @@ -411,6 +559,67 @@ where } } +/// Type that rejects requests if a header is not present or does not have an expected value. +pub struct RequiredHeaderValue { + expected_header_name: HeaderName, + expected_header_value: Arc, + _ty: PhantomData ResBody>, +} + +impl RequiredHeaderValue { + fn new(expected_header_name: HeaderName, expected_header_value: &str) -> Self + where + ResBody: Default, + { + Self { + expected_header_name, + expected_header_value: expected_header_value.into(), + _ty: PhantomData, + } + } +} + +impl Clone for RequiredHeaderValue { + fn clone(&self) -> Self { + Self { + expected_header_name: self.expected_header_name.clone(), + expected_header_value: self.expected_header_value.clone(), + _ty: PhantomData, + } + } +} + +impl fmt::Debug for RequiredHeaderValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequiredHeaderValue") + .field("expected_header_name", &self.expected_header_name) + .field("expected_header_value", &self.expected_header_value) + .finish() + } +} + +impl ValidateRequest for RequiredHeaderValue +where + ResBody: Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, req: &mut Request) -> Result<(), Response> { + let request_header_value = req + .headers() + .get(&self.expected_header_name) + .and_then(|v| v.to_str().ok()); + + if request_header_value != Some(&*self.expected_header_value) { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::FORBIDDEN; + return Err(res); + } + + Ok(()) + } +} + #[cfg(test)] mod tests { #[allow(unused_imports)] @@ -581,6 +790,162 @@ mod tests { assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); } + #[tokio::test] + async fn valid_custom_header() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value( + "x-custom-header", + "random-value-1234567890", + ) + .expect("invalid validate header"), + ) + .service_fn(echo); + + let request = Request::get("/") + .header("x-custom-header", "random-value-1234567890") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn invalid_custom_header() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value( + "x-custom-header", + "random-value-1234567890", + ) + .expect("invalid validate header"), + ) + .service_fn(echo); + + let request = Request::get("/") + .header("x-custom-header", "wrong-value") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn missing_custom_header() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value( + "x-custom-header", + "random-value-1234567890", + ) + .expect("invalid validate header"), + ) + .service_fn(echo); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn custom_header_multiple_values_uses_first() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value("x-custom-header", "correct-value") + .expect("invalid validate header"), + ) + .service_fn(echo); + + // First value matches: should pass + let request = Request::get("/") + .header("x-custom-header", "correct-value") + .header("x-custom-header", "other-value") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // First value does not match: should reject even if second matches + let request = Request::get("/") + .header("x-custom-header", "wrong-value") + .header("x-custom-header", "correct-value") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[test] + fn invalid_header_name_returns_error() { + let result = ValidateRequestHeaderLayer::>::has_header_value( + "invalid header name with spaces", + "value", + ); + assert!(result.is_err()); + } + + #[tokio::test] + async fn custom_header_non_utf8_value_rejects() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value("x-custom-header", "expected-value") + .expect("invalid validate header"), + ) + .service_fn(echo); + + let request = Request::get("/") + .header("x-custom-header", b"\xff\xfe".as_slice()) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn custom_header_name_is_case_insensitive() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value("x-custom-header", "my-value") + .expect("invalid validate header"), + ) + .service_fn(echo); + + let request = Request::get("/") + .header("X-Custom-Header", "my-value") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn custom_header_value_is_case_sensitive() { + let mut service = ServiceBuilder::new() + .layer( + ValidateRequestHeaderLayer::has_header_value("x-custom-header", "My-Value") + .expect("invalid validate header"), + ) + .service_fn(echo); + + let request = Request::get("/") + .header("x-custom-header", "my-value") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } From af828a6ec99dca9f562fbb534f6c2b806becc7f2 Mon Sep 17 00:00:00 2001 From: Jess Izen <44884346+jlizen@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:17:17 -0700 Subject: [PATCH 3/3] feat(follow_redirect)!: preserve request extensions across redirects (#706) * feat(follow_redirect)!: preserve request extensions across redirects * fix(follow_redirect)!: persist header/extension removals across hops * fix(follow_redirect): hand-write FilterCredentials Debug for MSRV * docs(follow_redirect): tighten extension and policy doc comments --- tower-http/CHANGELOG.md | 12 + tower-http/src/follow_redirect/mod.rs | 226 +++++++++++++++++- .../policy/filter_credentials.rs | 178 +++++++++++++- tower-http/src/follow_redirect/policy/mod.rs | 3 + 4 files changed, 405 insertions(+), 14 deletions(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 2edc73ac6..54fec3d4a 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -33,9 +33,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 by the features that need them (e.g. `compression-gzip`, `fs`, `timeout`). ([#628]) - MSRV bumped from 1.64 to 1.65. +- **breaking:** `follow-redirect`: `FollowRedirect` now forwards request + `Extensions` to redirected requests instead of dropping them. The `Standard` + policy drops extensions on cross-origin redirections (same-origin keeps them). + Opt out with `FollowRedirectLayer::preserve_extensions(false)`; keep specific + types with `FilterCredentials::allow_extension::()` or all of them with + `keep_all_extensions()`. ([#581]) +- **breaking:** `follow-redirect`: header and extension filtering is now + cumulative. A value a policy drops on one hop is no longer replayed on later + hops, so `FilterCredentials` no longer re-sends `Cookie`/`Authorization` to a + same-origin target reached after a cross-origin hop. Custom `Policy::on_request` + impls now see the previous hop's filtered request, not the original. ([#581]) [#215]: https://github.com/tower-rs/tower-http/issues/215 [#360]: https://github.com/tower-rs/tower-http/pull/360 +[#581]: https://github.com/tower-rs/tower-http/pull/581 [#628]: https://github.com/tower-rs/tower-http/pull/628 [#642]: https://github.com/tower-rs/tower-http/pull/642 diff --git a/tower-http/src/follow_redirect/mod.rs b/tower-http/src/follow_redirect/mod.rs index 34d9e607e..4d74959df 100644 --- a/tower-http/src/follow_redirect/mod.rs +++ b/tower-http/src/follow_redirect/mod.rs @@ -6,11 +6,16 @@ //! redirections. //! //! The middleware tries to clone the original [`Request`] when making a redirected request. -//! However, since [`Extensions`][http::Extensions] are `!Clone`, any extensions set by outer -//! middleware will be discarded. Also, the request body cannot always be cloned. When the -//! original body is known to be empty by [`Body::size_hint`], the middleware uses `Default` -//! implementation of the body type to create a new request body. If you know that the body can be -//! cloned in some way, you can tell the middleware to clone it by configuring a [`policy`]. +//! Request headers and [`Extensions`] are carried over to redirected requests; the [`policy`] +//! decides which survive each hop (the [`Standard`] policy drops credential headers and all +//! extensions cross-origin), and filtering is cumulative, so a dropped value never reappears later +//! in the chain. Extension forwarding can be disabled with +//! [`FollowRedirectLayer::preserve_extensions`]. +//! +//! The request body cannot always be cloned. When the original body is known to be empty by +//! [`Body::size_hint`], the middleware uses the `Default` implementation of the body type. If the +//! body can be cloned in some way, you can tell the middleware to clone it by configuring a +//! [`policy`]. //! //! # Examples //! @@ -98,8 +103,8 @@ use self::policy::{Action, Attempt, Policy, Standard}; use futures_util::future::Either; use http::{ header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION, - header::TRANSFER_ENCODING, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri, - Version, + header::TRANSFER_ENCODING, Extensions, HeaderMap, HeaderValue, Method, Request, Response, + StatusCode, Uri, Version, }; use http_body::Body; use pin_project_lite::pin_project; @@ -119,9 +124,10 @@ use url::Url; /// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses. /// /// See the [module docs](self) for more details. -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug)] pub struct FollowRedirectLayer

{ policy: P, + preserve_extensions: bool, } impl FollowRedirectLayer { @@ -134,7 +140,26 @@ impl FollowRedirectLayer { impl

FollowRedirectLayer

{ /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`]. pub fn with_policy(policy: P) -> Self { - FollowRedirectLayer { policy } + FollowRedirectLayer { + policy, + preserve_extensions: true, + } + } + + /// Whether request [`Extensions`] are carried over to redirected requests. Defaults to `true`. + /// + /// Setting this to `false` drops all extensions on redirected requests. When preserved, the + /// [`policy`] still filters them via [`Policy::on_request`]; the [`Standard`] policy drops + /// extensions cross-origin (see [`FilterCredentials`][policy::FilterCredentials]). + pub fn preserve_extensions(mut self, preserve: bool) -> Self { + self.preserve_extensions = preserve; + self + } +} + +impl Default for FollowRedirectLayer

{ + fn default() -> Self { + FollowRedirectLayer::with_policy(P::default()) } } @@ -147,6 +172,7 @@ where fn layer(&self, inner: S) -> Self::Service { FollowRedirect::with_policy(inner, self.policy.clone()) + .preserve_extensions(self.preserve_extensions) } } @@ -157,6 +183,7 @@ where pub struct FollowRedirect { inner: S, policy: P, + preserve_extensions: bool, } impl FollowRedirect { @@ -179,7 +206,11 @@ where { /// Create a new [`FollowRedirect`] with the given redirection [`Policy`]. pub fn with_policy(inner: S, policy: P) -> Self { - FollowRedirect { inner, policy } + FollowRedirect { + inner, + policy, + preserve_extensions: true, + } } /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware @@ -193,6 +224,16 @@ where define_inner_service_accessors!(); } +impl FollowRedirect { + /// Whether request [`Extensions`] are carried over to redirected requests. Defaults to `true`. + /// + /// See [`FollowRedirectLayer::preserve_extensions`]. + pub fn preserve_extensions(mut self, preserve: bool) -> Self { + self.preserve_extensions = preserve; + self + } +} + impl Service> for FollowRedirect where S: Service, Response = Response> + Clone, @@ -214,11 +255,18 @@ where let mut body = BodyRepr::None; body.try_clone_from(req.body(), &policy); policy.on_request(&mut req); + // Snapshot the extensions to replay on redirected requests (empty when not preserving). + let extensions = if self.preserve_extensions { + req.extensions().clone() + } else { + Extensions::new() + }; ResponseFuture { method: req.method().clone(), uri: req.uri().clone(), version: req.version(), headers: req.headers().clone(), + extensions, body, future: Either::Left(service.call(req)), service, @@ -242,6 +290,7 @@ pin_project! { uri: Uri, version: Version, headers: HeaderMap, + extensions: Extensions, body: BodyRepr, } } @@ -325,7 +374,12 @@ where *req.method_mut() = this.method.clone(); *req.version_mut() = *this.version; *req.headers_mut() = this.headers.clone(); + *req.extensions_mut() = this.extensions.clone(); this.policy.on_request(&mut req); + // Carry the filtered headers and extensions forward so anything dropped on this + // hop stays dropped on the next one (e.g. credentials after a cross-origin hop). + *this.headers = req.headers().clone(); + *this.extensions = req.extensions().clone(); this.future .set(Either::Right(Oneshot::new(this.service.clone(), req))); @@ -337,7 +391,7 @@ where } } -/// Response [`Extensions`][http::Extensions] value that represents the effective request URI of +/// Response [`Extensions`] value that represents the effective request URI of /// a response returned by a [`FollowRedirect`] middleware. /// /// The value differs from the original request's effective URI if the middleware has followed @@ -463,8 +517,153 @@ mod tests { ); } + #[derive(Clone, Debug, PartialEq)] + struct Marker(u32); + + #[tokio::test] + async fn preserves_extensions() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::new()) + .buffer(1) + .service_fn(handle); + let mut req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + req.extensions_mut().insert(Marker(7)); + let res = svc.oneshot(req).await.unwrap(); + // The same-origin redirect chain should carry the extension through to the final request. + assert_eq!(res.extensions().get::(), Some(&Marker(7))); + } + + #[tokio::test] + async fn preserve_extensions_opt_out() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::new().preserve_extensions(false)) + .buffer(1) + .service_fn(handle); + let mut req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + req.extensions_mut().insert(Marker(7)); + let res = svc.oneshot(req).await.unwrap(); + assert!(res.extensions().get::().is_none()); + } + + #[tokio::test] + async fn drops_extensions_cross_origin() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::new()) + .buffer(1) + .service_fn(cross_origin); + let mut req = Request::builder() + .uri("http://a.example.com/") + .body(Body::empty()) + .unwrap(); + req.extensions_mut().insert(Marker(7)); + let res = svc.oneshot(req).await.unwrap(); + // The Standard policy treats the cross-origin hop as blocked and drops the extension. + assert!(res.extensions().get::().is_none()); + assert_eq!( + res.extensions().get::().unwrap().0, + "http://b.example.com/" + ); + } + + #[tokio::test] + async fn allowlisted_extension_survives_cross_origin() { + #[derive(Clone, Debug, PartialEq)] + struct Allowed(u32); + + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::with_policy( + FilterCredentials::new().allow_extension::(), + )) + .buffer(1) + .service_fn(cross_origin); + let mut req = Request::builder() + .uri("http://a.example.com/") + .body(Body::empty()) + .unwrap(); + req.extensions_mut().insert(Marker(7)); + req.extensions_mut().insert(Allowed(9)); + let res = svc.oneshot(req).await.unwrap(); + assert!(res.extensions().get::().is_none()); + assert_eq!(res.extensions().get::(), Some(&Allowed(9))); + } + + #[tokio::test] + async fn headers_and_extensions_do_not_resurrect_after_cross_origin() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::new()) + .buffer(1) + .service_fn(resurrection_chain); + let mut req = Request::builder() + .uri("http://a.example.com/") + .header(http::header::COOKIE, "secret") + .body(Body::empty()) + .unwrap(); + req.extensions_mut().insert(Marker(7)); + let res = svc.oneshot(req).await.unwrap(); + // The chain is a.example.com -> b.example.com/second (cross-origin, both dropped) -> + // b.example.com/final (same-origin). Neither the cookie nor the extension may reappear on + // the final, same-origin request just because the original snapshot is replayed. + assert_eq!( + res.extensions().get::().unwrap().0, + "http://b.example.com/final" + ); + assert!(res.extensions().get::().is_none()); + assert!(!res.headers().contains_key("x-saw-cookie")); + } + + /// Redirects `a.example.com` to `b.example.com` once, then echoes the final request's + /// extensions back on the response. + async fn cross_origin(req: Request) -> Result, Infallible> { + let mut res = Response::builder(); + if req.uri().host() == Some("a.example.com") { + res = res + .status(StatusCode::MOVED_PERMANENTLY) + .header(LOCATION, "http://b.example.com/"); + } + if let Some(extensions) = res.extensions_mut() { + *extensions = req.extensions().clone(); + } + Ok::<_, Infallible>(res.body(0).unwrap()) + } + + /// A three-hop chain: `a.example.com` redirects cross-origin to `b.example.com/second`, which + /// redirects same-origin to `b.example.com/final`. Each response echoes the request's + /// extensions and flags (via the `x-saw-cookie` response header) whether the request still + /// carried a `Cookie`, so a test can detect credentials or extensions reappearing after the + /// cross-origin hop. + async fn resurrection_chain(req: Request) -> Result, Infallible> { + let location = match (req.uri().host(), req.uri().path()) { + (Some("a.example.com"), _) => Some("http://b.example.com/second"), + (Some("b.example.com"), "/second") => Some("http://b.example.com/final"), + _ => None, + }; + let saw_cookie = req.headers().contains_key(http::header::COOKIE); + let mut builder = Response::builder(); + if let Some(location) = location { + builder = builder + .status(StatusCode::TEMPORARY_REDIRECT) + .header(LOCATION, location); + } + if let Some(extensions) = builder.extensions_mut() { + *extensions = req.extensions().clone(); + } + let mut res = builder.body(0).unwrap(); + if saw_cookie { + res.headers_mut() + .insert("x-saw-cookie", HeaderValue::from_static("yes")); + } + Ok::<_, Infallible>(res) + } + /// A server with an endpoint `/{n}` which redirects to `/{n-1}` unless `n` equals zero, - /// returning `n` as the response body. + /// returning `n` as the response body. The request's extensions are echoed back on the + /// response so tests can observe which extensions reached the final request. async fn handle(req: Request) -> Result, Infallible> { let n: u64 = req.uri().path()[1..].parse().unwrap(); let mut res = Response::builder(); @@ -473,6 +672,9 @@ mod tests { .status(StatusCode::MOVED_PERMANENTLY) .header(LOCATION, format!("/{}", n - 1)); } + if let Some(extensions) = res.extensions_mut() { + *extensions = req.extensions().clone(); + } Ok::<_, Infallible>(res.body(n).unwrap()) } diff --git a/tower-http/src/follow_redirect/policy/filter_credentials.rs b/tower-http/src/follow_redirect/policy/filter_credentials.rs index f58988fcf..fe27d5b39 100644 --- a/tower-http/src/follow_redirect/policy/filter_credentials.rs +++ b/tower-http/src/follow_redirect/policy/filter_credentials.rs @@ -1,19 +1,44 @@ use super::{eq_origin, Action, Attempt, Policy}; use http::{ header::{self, HeaderName}, - Request, + Extensions, Request, }; /// A redirection [`Policy`] that removes credentials from requests in redirections. -#[derive(Clone, Debug)] +/// +/// Besides headers, it filters request [`Extensions`] on "blocked" redirections. Extensions are +/// keyed by arbitrary types with no blocklist to mirror the header one, so blocked redirections +/// drop *all* extensions by default; re-admit types with [`allow_extension`][Self::allow_extension]. +/// +/// Filtering is cumulative: a value removed on one hop is not reintroduced on later hops. +#[derive(Clone)] pub struct FilterCredentials { block_cross_origin: bool, block_any: bool, remove_blocklisted: bool, remove_all: bool, + remove_all_extensions: bool, + extension_allowlist: Vec, blocked: bool, } +// `Debug` is implemented by hand rather than derived: deriving it would require `Debug` for the +// higher-ranked `fn` pointers in `extension_allowlist`, which does not hold on older compilers +// (and would only print opaque addresses anyway). The allowlist is summarized by its length. +impl std::fmt::Debug for FilterCredentials { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FilterCredentials") + .field("block_cross_origin", &self.block_cross_origin) + .field("block_any", &self.block_any) + .field("remove_blocklisted", &self.remove_blocklisted) + .field("remove_all", &self.remove_all) + .field("remove_all_extensions", &self.remove_all_extensions) + .field("allowed_extensions", &self.extension_allowlist.len()) + .field("blocked", &self.blocked) + .finish() + } +} + const BLOCKLIST: &[HeaderName] = &[ header::AUTHORIZATION, header::COOKIE, @@ -29,6 +54,8 @@ impl FilterCredentials { block_any: false, remove_blocklisted: true, remove_all: false, + remove_all_extensions: true, + extension_allowlist: Vec::new(), blocked: false, } } @@ -74,6 +101,38 @@ impl FilterCredentials { self.remove_all = false; self.remove_blocklisted(false) } + + /// Remove all non-allowlisted extensions on "blocked" redirections. This is the default. + /// + /// Re-admit specific types with [`allow_extension`][Self::allow_extension]. + pub fn remove_all_extensions(mut self) -> Self { + self.remove_all_extensions = true; + self + } + + /// Keep all request extensions on "blocked" redirections. + /// + /// Forwards every extension, including cross-origin. Use only when no extension carries + /// sensitive, origin-scoped data. + pub fn keep_all_extensions(mut self) -> Self { + self.remove_all_extensions = false; + self + } + + /// Keep extension type `T` on "blocked" redirections even when other extensions are removed. + /// + /// No effect under [`keep_all_extensions`][Self::keep_all_extensions]. + pub fn allow_extension(mut self) -> Self + where + T: Clone + Send + Sync + 'static, + { + self.extension_allowlist.push(|from, to| { + if let Some(value) = from.remove::() { + to.insert(value); + } + }); + self + } } impl Default for FilterCredentials { @@ -99,6 +158,19 @@ impl Policy for FilterCredentials { headers.remove(key); } } + + if self.remove_all_extensions { + let extensions = request.extensions_mut(); + if self.extension_allowlist.is_empty() { + extensions.clear(); + } else { + let mut allowed = Extensions::new(); + for transfer in &self.extension_allowlist { + transfer(extensions, &mut allowed); + } + *extensions = allowed; + } + } } } } @@ -162,4 +234,106 @@ mod tests { Policy::<(), ()>::on_request(&mut policy, &mut request); assert!(!request.headers().contains_key(header::COOKIE)); } + + #[derive(Clone, Debug, PartialEq)] + struct Kept(u32); + + #[derive(Clone, Debug, PartialEq)] + struct Dropped(u32); + + fn cross_origin_attempt<'a>(previous: &'a Uri, location: &'a Uri) -> Attempt<'a> { + Attempt { + status: Default::default(), + method: &Method::GET, + location, + previous_method: &Method::GET, + previous, + } + } + + #[test] + fn extensions_are_kept_same_origin_and_dropped_cross_origin() { + let initial = Uri::from_static("http://example.com/old"); + let same_origin = Uri::from_static("http://example.com/new"); + let cross_origin = Uri::from_static("https://example.com/new"); + + let mut policy = FilterCredentials::default(); + + let attempt = cross_origin_attempt(&initial, &same_origin); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + let mut request = Request::builder().uri(&same_origin).body(()).unwrap(); + request.extensions_mut().insert(Kept(42)); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert_eq!(request.extensions().get::(), Some(&Kept(42))); + + let attempt = cross_origin_attempt(&same_origin, &cross_origin); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + let mut request = Request::builder().uri(&cross_origin).body(()).unwrap(); + request.extensions_mut().insert(Kept(42)); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert!(request.extensions().get::().is_none()); + } + + #[test] + fn allowlisted_extensions_survive_cross_origin() { + let initial = Uri::from_static("http://example.com/old"); + let cross_origin = Uri::from_static("https://example.com/new"); + + let mut policy = FilterCredentials::default().allow_extension::(); + let attempt = cross_origin_attempt(&initial, &cross_origin); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + + let mut request = Request::builder().uri(&cross_origin).body(()).unwrap(); + request.extensions_mut().insert(Kept(1)); + request.extensions_mut().insert(Dropped(2)); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert_eq!(request.extensions().get::(), Some(&Kept(1))); + assert!(request.extensions().get::().is_none()); + } + + #[test] + fn keep_all_extensions_forwards_cross_origin() { + let initial = Uri::from_static("http://example.com/old"); + let cross_origin = Uri::from_static("https://example.com/new"); + + let mut policy = FilterCredentials::default().keep_all_extensions(); + let attempt = cross_origin_attempt(&initial, &cross_origin); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + + let mut request = Request::builder().uri(&cross_origin).body(()).unwrap(); + request.extensions_mut().insert(Kept(1)); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert_eq!(request.extensions().get::(), Some(&Kept(1))); + } + + #[test] + fn allow_extension_is_ignored_when_keeping_all() { + let initial = Uri::from_static("http://example.com/old"); + let cross_origin = Uri::from_static("https://example.com/new"); + + // The allowlist only takes effect while extensions are being removed; keep_all disables + // removal, so everything is forwarded regardless of the allowlist. + let mut policy = FilterCredentials::default() + .keep_all_extensions() + .allow_extension::(); + let attempt = cross_origin_attempt(&initial, &cross_origin); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + + let mut request = Request::builder().uri(&cross_origin).body(()).unwrap(); + request.extensions_mut().insert(Kept(1)); + request.extensions_mut().insert(Dropped(2)); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert_eq!(request.extensions().get::(), Some(&Kept(1))); + assert_eq!(request.extensions().get::(), Some(&Dropped(2))); + } } diff --git a/tower-http/src/follow_redirect/policy/mod.rs b/tower-http/src/follow_redirect/policy/mod.rs index 36a5bc3f0..549404eaa 100644 --- a/tower-http/src/follow_redirect/policy/mod.rs +++ b/tower-http/src/follow_redirect/policy/mod.rs @@ -60,6 +60,9 @@ pub trait Policy { /// This can for example be used to remove sensitive headers from the request /// or prepare the request in other ways. /// + /// On a redirected request, whatever this method leaves on the request becomes the baseline for + /// the next hop, so a value removed here stays removed for the rest of the chain. + /// /// The default implementation does nothing. fn on_request(&mut self, _request: &mut Request) {}