diff --git a/tower-http/src/cors/allow_headers.rs b/tower-http/src/cors/allow_headers.rs index 8e49e780..929dafda 100644 --- a/tower-http/src/cors/allow_headers.rs +++ b/tower-http/src/cors/allow_headers.rs @@ -58,6 +58,10 @@ impl AllowHeaders { matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD) } + pub(super) fn varies_with_request_headers(&self) -> bool { + !matches!(&self.0, AllowHeadersInner::Const(_)) + } + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { let allow_headers = match &self.0 { AllowHeadersInner::Const(v) => v.clone()?, diff --git a/tower-http/src/cors/allow_methods.rs b/tower-http/src/cors/allow_methods.rs index a2aeb642..5735a7b2 100644 --- a/tower-http/src/cors/allow_methods.rs +++ b/tower-http/src/cors/allow_methods.rs @@ -72,6 +72,10 @@ impl AllowMethods { matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD) } + pub(super) fn varies_with_request_method(&self) -> bool { + !matches!(&self.0, AllowMethodsInner::Const(_)) + } + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { let allow_methods = match &self.0 { AllowMethodsInner::Const(v) => v.clone()?, diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs index 646220fa..cc4e96e0 100644 --- a/tower-http/src/cors/allow_origin.rs +++ b/tower-http/src/cors/allow_origin.rs @@ -111,6 +111,10 @@ impl AllowOrigin { matches!(&self.0, OriginInner::Const(v) if v == WILDCARD) } + pub(super) fn varies_with_origin(&self) -> bool { + !matches!(&self.0, OriginInner::Const(_)) + } + pub(super) fn to_future( &self, origin: Option<&HeaderValue>, diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index e7aca714..cb927597 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -51,10 +51,7 @@ use allow_origin::AllowOriginFuture; use bytes::{BufMut, BytesMut}; -use http::{ - header::{self, HeaderName}, - HeaderMap, HeaderValue, Method, Request, Response, -}; +use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response}; use pin_project_lite::pin_project; use std::{ future::Future, @@ -99,6 +96,7 @@ pub struct CorsLayer { expose_headers: ExposeHeaders, max_age: MaxAge, vary: Vary, + is_vary_custom: bool, } #[allow(clippy::declare_interior_mutable_const)] @@ -122,6 +120,7 @@ impl CorsLayer { expose_headers: Default::default(), max_age: Default::default(), vary: Default::default(), + is_vary_custom: false, } } @@ -432,12 +431,20 @@ impl CorsLayer { /// Set the value(s) of the [`Vary`][mdn] header. /// - /// In contrast to the other headers, this one has a non-empty default of - /// [`preflight_request_headers()`]. + /// By default, this value is derived from whether CORS response headers are + /// request-dependent: + /// + /// - `Origin` is included when `Access-Control-Allow-Origin` depends on the + /// request's `Origin` header (for example, origin lists or predicates). + /// - `Access-Control-Request-Method` is included when + /// `Access-Control-Allow-Methods` mirrors `Access-Control-Request-Method`. + /// - `Access-Control-Request-Headers` is included when + /// `Access-Control-Allow-Headers` mirrors `Access-Control-Request-Headers`. + /// - If none of those values are request-dependent, no `Vary` header is + /// added. /// - /// You only need to set this if you want to remove some of these defaults, - /// or if you use a closure for one of the other headers and want to add a - /// vary header accordingly. + /// Calling this method sets `Vary` explicitly and pins it to the provided + /// value, regardless of future changes to those other CORS settings. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary pub fn vary(mut self, headers: T) -> Self @@ -445,8 +452,34 @@ impl CorsLayer { T: Into, { self.vary = headers.into(); + self.is_vary_custom = true; self } + + /// Recomputes the `Vary` header, if it hasn't been set explicitly. + fn update_vary_header(&mut self) { + if !self.is_vary_custom { + let vary_origin = self.allow_origin.varies_with_origin(); + let vary_method = self.allow_methods.varies_with_request_method(); + let vary_headers = self.allow_headers.varies_with_request_headers(); + + if !(vary_origin || vary_method || vary_headers) { + self.vary = Vary::list([]); + } else { + let mut vary_header_names = Vec::new(); + if vary_origin { + vary_header_names.push(header::ORIGIN); + } + if vary_method { + vary_header_names.push(header::ACCESS_CONTROL_REQUEST_METHOD); + } + if vary_headers { + vary_header_names.push(header::ACCESS_CONTROL_REQUEST_HEADERS); + } + self.vary = Vary::list(vary_header_names); + } + } + } } /// Represents a wildcard value (`*`) used with some CORS headers such as @@ -493,10 +526,12 @@ impl Layer for CorsLayer { fn layer(&self, inner: S) -> Self::Service { ensure_usable_cors_rules(self); - Cors { - inner, - layer: self.clone(), - } + // Clone the layer to modify Vary header logic + let mut layer = self.clone(); + + layer.update_vary_header(); + + Cors { inner, layer } } } @@ -641,6 +676,8 @@ impl Cors { F: FnOnce(CorsLayer) -> CorsLayer, { self.layer = f(self.layer); + + self.layer.update_vary_header(); self } } diff --git a/tower-http/src/cors/tests.rs b/tower-http/src/cors/tests.rs index 8f3f4acb..b74fa2f0 100644 --- a/tower-http/src/cors/tests.rs +++ b/tower-http/src/cors/tests.rs @@ -1,37 +1,98 @@ use std::convert::Infallible; -use crate::test_helpers::Body; -use http::{header, HeaderValue, Request, Response}; +use crate::{cors::Vary, test_helpers::Body}; +use http::{header, HeaderName, HeaderValue, Method, Request, Response}; use tower::{service_fn, util::ServiceExt, Layer}; -use crate::cors::{AllowOrigin, CorsLayer}; +use crate::cors::{AllowHeaders, AllowMethods, AllowOrigin, Any, Cors, CorsLayer}; + +const INITIAL_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding"); +const ADDITIONAL_VARY_HEADERS: [HeaderName; 3] = [ + header::ORIGIN, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::ACCESS_CONTROL_REQUEST_HEADERS, +]; + +#[tokio::test] +async fn permissive_vary_header_is_empty() { + let svc = CorsLayer::permissive().layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder().body(Body::empty()).unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert!( + res.headers().get(header::VARY).is_none(), + "Vary header should be omitted for permissive config" + ); +} #[tokio::test] -#[allow( - clippy::declare_interior_mutable_const, - clippy::borrow_interior_mutable_const -)] -async fn vary_set_by_inner_service() { - const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding"); +async fn include_custom_permissive_to_vary_set_by_inner_service() { const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static( "origin, access-control-request-method, access-control-request-headers", ); async fn inner_svc(_: Request) -> Result, Infallible> { Ok(Response::builder() - .header(header::VARY, CUSTOM_VARY_HEADERS) + .header(header::VARY, INITIAL_VARY_HEADERS) .body(Body::empty()) .unwrap()) } - let svc = CorsLayer::permissive().layer(service_fn(inner_svc)); + let svc = CorsLayer::permissive() + .vary(Vary::list(ADDITIONAL_VARY_HEADERS)) + .layer(service_fn(inner_svc)); + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut vary_headers = res.headers().get_all(header::VARY).into_iter(); - assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS)); + assert_eq!(vary_headers.next(), Some(&INITIAL_VARY_HEADERS)); assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS)); assert_eq!(vary_headers.next(), None); } +#[tokio::test] +async fn permissive_with_custom_vary_builder() { + let custom_vary = HeaderValue::from_static("x-foo"); + let svc = CorsLayer::permissive() + .vary(Vary::list([header::HeaderName::from_static("x-foo")])) + .layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + let vary = res.headers().get(header::VARY); + assert_eq!(vary, Some(&custom_vary)); +} + +#[tokio::test] +async fn permissive_with_inner_and_builder_vary() { + let custom_vary = HeaderValue::from_static("x-foo"); + let inner_vary = HeaderValue::from_static("accept-encoding"); + let svc = CorsLayer::permissive() + .vary(Vary::list([header::HeaderName::from_static("x-foo")])) + .layer(service_fn(|_: Request| { + let inner_vary = inner_vary.clone(); + async move { + Ok::<_, Infallible>( + Response::builder() + .header(header::VARY, inner_vary) + .body(Body::empty()) + .unwrap(), + ) + } + })); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + let mut vary_headers = res.headers().get_all(header::VARY).iter(); + assert_eq!(vary_headers.next(), Some(&inner_vary)); + assert_eq!(vary_headers.next(), Some(&custom_vary)); + assert_eq!(vary_headers.next(), None); +} + #[tokio::test] async fn test_allow_origin_async_predicate() { #[derive(Clone)] @@ -71,3 +132,115 @@ async fn test_allow_origin_async_predicate() { let res = allow_origin.to_future(Some(&invalid_origin), &parts).await; assert!(res.is_none()); } + +#[tokio::test] +async fn derived_vary_header_for_mixed_wildcard_configuration() { + let svc = CorsLayer::new() + .allow_origin(Any) + .allow_methods(AllowMethods::mirror_request()) + .allow_headers(AllowHeaders::mirror_request()) + .layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder() + .method(Method::OPTIONS) + .header(header::ORIGIN, "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!( + res.headers().get(header::VARY), + Some(&HeaderValue::from_static( + "access-control-request-method, access-control-request-headers", + )) + ); +} + +#[tokio::test] +async fn very_permissive_emits_vary_headers() { + let svc = CorsLayer::very_permissive().layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder() + .method(Method::OPTIONS) + .header(header::ORIGIN, "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!( + res.headers().get(header::VARY), + Some(&HeaderValue::from_static( + "origin, access-control-request-method, access-control-request-headers", + )) + ); +} + +#[tokio::test] +async fn cors_map_layer_smoke_without_vary_header() { + let svc = Cors::new(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })) + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + assert!(res.headers().get(header::VARY).is_none()); +} + +#[tokio::test] +async fn cors_map_layer_smoke_with_vary_header() { + let svc = Cors::new(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })) + .allow_origin(Any) + .allow_methods(AllowMethods::mirror_request()) + .allow_headers(Any); + + let req = Request::builder() + .method(Method::OPTIONS) + .header(header::ORIGIN, "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!( + res.headers().get(header::VARY), + Some(&HeaderValue::from_static("access-control-request-method")) + ); +} + +#[tokio::test] +async fn exact_origin_does_not_emit_origin_vary_header() { + let svc = CorsLayer::new() + .allow_origin(AllowOrigin::exact(HeaderValue::from_static( + "http://example.com", + ))) + .allow_methods([Method::GET]) + .allow_headers([header::CONTENT_TYPE]) + .layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder() + .header(header::ORIGIN, "http://example.com") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert!(res.headers().get(header::VARY).is_none()); +} diff --git a/tower-http/src/cors/vary.rs b/tower-http/src/cors/vary.rs index 3ebe4a27..21a3486d 100644 --- a/tower-http/src/cors/vary.rs +++ b/tower-http/src/cors/vary.rs @@ -1,6 +1,6 @@ use http::header::{self, HeaderName, HeaderValue}; -use super::preflight_request_headers; +use crate::cors::preflight_request_headers; /// Holds configuration for how to set the [`Vary`][mdn] header. ///