Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tower-http/src/cors/allow_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ impl AllowHeaders {
matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD)
}

pub(super) fn varies_with_request_headers(&self) -> bool {
!matches!(&self.0, AllowHeadersInner::Const(_))
}

pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> {
let allow_headers = match &self.0 {
AllowHeadersInner::Const(v) => v.clone()?,
Expand Down
4 changes: 4 additions & 0 deletions tower-http/src/cors/allow_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ impl AllowMethods {
matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD)
}

pub(super) fn varies_with_request_method(&self) -> bool {
!matches!(&self.0, AllowMethodsInner::Const(_))
}

pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> {
let allow_methods = match &self.0 {
AllowMethodsInner::Const(v) => v.clone()?,
Expand Down
4 changes: 4 additions & 0 deletions tower-http/src/cors/allow_origin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ impl AllowOrigin {
matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
}

pub(super) fn varies_with_origin(&self) -> bool {
!matches!(&self.0, OriginInner::Const(_))
}

pub(super) fn to_future(
&self,
origin: Option<&HeaderValue>,
Expand Down
63 changes: 50 additions & 13 deletions tower-http/src/cors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@

use allow_origin::AllowOriginFuture;
use bytes::{BufMut, BytesMut};
use http::{
header::{self, HeaderName},
HeaderMap, HeaderValue, Method, Request, Response,
};
use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response};
use pin_project_lite::pin_project;
use std::{
future::Future,
Expand Down Expand Up @@ -99,6 +96,7 @@ pub struct CorsLayer {
expose_headers: ExposeHeaders,
max_age: MaxAge,
vary: Vary,
is_vary_custom: bool,
}

#[allow(clippy::declare_interior_mutable_const)]
Expand All @@ -122,6 +120,7 @@ impl CorsLayer {
expose_headers: Default::default(),
max_age: Default::default(),
vary: Default::default(),
is_vary_custom: false,
}
}

Expand Down Expand Up @@ -432,21 +431,55 @@ impl CorsLayer {

/// Set the value(s) of the [`Vary`][mdn] header.
///
/// In contrast to the other headers, this one has a non-empty default of
/// [`preflight_request_headers()`].
/// By default, this value is derived from whether CORS response headers are
/// request-dependent:
///
/// - `Origin` is included when `Access-Control-Allow-Origin` depends on the
/// request's `Origin` header (for example, origin lists or predicates).
/// - `Access-Control-Request-Method` is included when
/// `Access-Control-Allow-Methods` mirrors `Access-Control-Request-Method`.
/// - `Access-Control-Request-Headers` is included when
/// `Access-Control-Allow-Headers` mirrors `Access-Control-Request-Headers`.
/// - If none of those values are request-dependent, no `Vary` header is
/// added.
///
/// You only need to set this if you want to remove some of these defaults,
/// or if you use a closure for one of the other headers and want to add a
/// vary header accordingly.
/// Calling this method sets `Vary` explicitly and pins it to the provided
/// value, regardless of future changes to those other CORS settings.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
pub fn vary<T>(mut self, headers: T) -> Self
where
T: Into<Vary>,
{
self.vary = headers.into();
self.is_vary_custom = true;
self
}

/// Recomputes the `Vary` header, if it hasn't been set explicitly.
fn update_vary_header(&mut self) {
if !self.is_vary_custom {
let vary_origin = self.allow_origin.varies_with_origin();
let vary_method = self.allow_methods.varies_with_request_method();
let vary_headers = self.allow_headers.varies_with_request_headers();

if !(vary_origin || vary_method || vary_headers) {
self.vary = Vary::list([]);
} else {
let mut vary_header_names = Vec::new();
if vary_origin {
vary_header_names.push(header::ORIGIN);
}
if vary_method {
vary_header_names.push(header::ACCESS_CONTROL_REQUEST_METHOD);
}
if vary_headers {
vary_header_names.push(header::ACCESS_CONTROL_REQUEST_HEADERS);
}
self.vary = Vary::list(vary_header_names);
}
}
}
}

/// Represents a wildcard value (`*`) used with some CORS headers such as
Expand Down Expand Up @@ -493,10 +526,12 @@ impl<S> Layer<S> for CorsLayer {
fn layer(&self, inner: S) -> Self::Service {
ensure_usable_cors_rules(self);

Cors {
inner,
layer: self.clone(),
}
// Clone the layer to modify Vary header logic
let mut layer = self.clone();

layer.update_vary_header();

Cors { inner, layer }
}
}

Expand Down Expand Up @@ -641,6 +676,8 @@ impl<S> Cors<S> {
F: FnOnce(CorsLayer) -> CorsLayer,
{
self.layer = f(self.layer);

self.layer.update_vary_header();
self
}
}
Expand Down
197 changes: 185 additions & 12 deletions tower-http/src/cors/tests.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,98 @@
use std::convert::Infallible;

use crate::test_helpers::Body;
use http::{header, HeaderValue, Request, Response};
use crate::{cors::Vary, test_helpers::Body};
use http::{header, HeaderName, HeaderValue, Method, Request, Response};
use tower::{service_fn, util::ServiceExt, Layer};

use crate::cors::{AllowOrigin, CorsLayer};
use crate::cors::{AllowHeaders, AllowMethods, AllowOrigin, Any, Cors, CorsLayer};

const INITIAL_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding");
const ADDITIONAL_VARY_HEADERS: [HeaderName; 3] = [
header::ORIGIN,
header::ACCESS_CONTROL_REQUEST_METHOD,
header::ACCESS_CONTROL_REQUEST_HEADERS,
];

#[tokio::test]
async fn permissive_vary_header_is_empty() {
let svc = CorsLayer::permissive().layer(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}));

let req = Request::builder().body(Body::empty()).unwrap();

let res = svc.oneshot(req).await.unwrap();
assert!(
res.headers().get(header::VARY).is_none(),
"Vary header should be omitted for permissive config"
);
}

#[tokio::test]
#[allow(
clippy::declare_interior_mutable_const,
clippy::borrow_interior_mutable_const
)]
async fn vary_set_by_inner_service() {
const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding");
async fn include_custom_permissive_to_vary_set_by_inner_service() {
const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static(
"origin, access-control-request-method, access-control-request-headers",
);

async fn inner_svc(_: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::builder()
.header(header::VARY, CUSTOM_VARY_HEADERS)
.header(header::VARY, INITIAL_VARY_HEADERS)
.body(Body::empty())
.unwrap())
}

let svc = CorsLayer::permissive().layer(service_fn(inner_svc));
let svc = CorsLayer::permissive()
.vary(Vary::list(ADDITIONAL_VARY_HEADERS))
.layer(service_fn(inner_svc));

let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
let mut vary_headers = res.headers().get_all(header::VARY).into_iter();
assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS));
assert_eq!(vary_headers.next(), Some(&INITIAL_VARY_HEADERS));
assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS));
assert_eq!(vary_headers.next(), None);
}

#[tokio::test]
async fn permissive_with_custom_vary_builder() {
let custom_vary = HeaderValue::from_static("x-foo");
let svc = CorsLayer::permissive()
.vary(Vary::list([header::HeaderName::from_static("x-foo")]))
.layer(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}));

let req = Request::builder().body(Body::empty()).unwrap();
let res = svc.oneshot(req).await.unwrap();
let vary = res.headers().get(header::VARY);
assert_eq!(vary, Some(&custom_vary));
}

#[tokio::test]
async fn permissive_with_inner_and_builder_vary() {
let custom_vary = HeaderValue::from_static("x-foo");
let inner_vary = HeaderValue::from_static("accept-encoding");
let svc = CorsLayer::permissive()
.vary(Vary::list([header::HeaderName::from_static("x-foo")]))
.layer(service_fn(|_: Request<Body>| {
let inner_vary = inner_vary.clone();
async move {
Ok::<_, Infallible>(
Response::builder()
.header(header::VARY, inner_vary)
.body(Body::empty())
.unwrap(),
)
}
}));

let req = Request::builder().body(Body::empty()).unwrap();
let res = svc.oneshot(req).await.unwrap();
let mut vary_headers = res.headers().get_all(header::VARY).iter();
assert_eq!(vary_headers.next(), Some(&inner_vary));
assert_eq!(vary_headers.next(), Some(&custom_vary));
assert_eq!(vary_headers.next(), None);
}

#[tokio::test]
async fn test_allow_origin_async_predicate() {
#[derive(Clone)]
Expand Down Expand Up @@ -71,3 +132,115 @@ async fn test_allow_origin_async_predicate() {
let res = allow_origin.to_future(Some(&invalid_origin), &parts).await;
assert!(res.is_none());
}

#[tokio::test]
async fn derived_vary_header_for_mixed_wildcard_configuration() {
let svc = CorsLayer::new()
.allow_origin(Any)
.allow_methods(AllowMethods::mirror_request())
.allow_headers(AllowHeaders::mirror_request())
.layer(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}));

let req = Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type")
.body(Body::empty())
.unwrap();

let res = svc.oneshot(req).await.unwrap();

assert_eq!(
res.headers().get(header::VARY),
Some(&HeaderValue::from_static(
"access-control-request-method, access-control-request-headers",
))
);
}

#[tokio::test]
async fn very_permissive_emits_vary_headers() {
let svc = CorsLayer::very_permissive().layer(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}));

let req = Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type")
.body(Body::empty())
.unwrap();

let res = svc.oneshot(req).await.unwrap();

assert_eq!(
res.headers().get(header::VARY),
Some(&HeaderValue::from_static(
"origin, access-control-request-method, access-control-request-headers",
))
);
}

#[tokio::test]
async fn cors_map_layer_smoke_without_vary_header() {
let svc = Cors::new(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}))
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);

let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

assert!(res.headers().get(header::VARY).is_none());
}

#[tokio::test]
async fn cors_map_layer_smoke_with_vary_header() {
let svc = Cors::new(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}))
.allow_origin(Any)
.allow_methods(AllowMethods::mirror_request())
.allow_headers(Any);

let req = Request::builder()
.method(Method::OPTIONS)
.header(header::ORIGIN, "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.body(Body::empty())
.unwrap();

let res = svc.oneshot(req).await.unwrap();

assert_eq!(
res.headers().get(header::VARY),
Some(&HeaderValue::from_static("access-control-request-method"))
);
}

#[tokio::test]
async fn exact_origin_does_not_emit_origin_vary_header() {
let svc = CorsLayer::new()
.allow_origin(AllowOrigin::exact(HeaderValue::from_static(
"http://example.com",
)))
.allow_methods([Method::GET])
.allow_headers([header::CONTENT_TYPE])
.layer(service_fn(|_: Request<Body>| async {
Ok::<_, Infallible>(Response::new(Body::empty()))
}));

let req = Request::builder()
.header(header::ORIGIN, "http://example.com")
.body(Body::empty())
.unwrap();

let res = svc.oneshot(req).await.unwrap();

assert!(res.headers().get(header::VARY).is_none());
}
2 changes: 1 addition & 1 deletion tower-http/src/cors/vary.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use http::header::{self, HeaderName, HeaderValue};

use super::preflight_request_headers;
use crate::cors::preflight_request_headers;

/// Holds configuration for how to set the [`Vary`][mdn] header.
///
Expand Down
Loading