Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ edition = "2018"
include = ["Cargo.toml", "LICENSE", "src/**/*"]

[[bench]]
name="internal"
name = "internal"
harness = false

[dependencies]
hyper = { version = "0.14.18", features = ["client"] }
lazy_static = "1.4.0"
tokio = { version = "1.17.0", features = ["io-util", "rt"] }
tracing = "0.1.34"
visibility = { version = "0.0.1", optional = true }

[dev-dependencies]
hyper = { version = "0.14.18", features = ["server"] }
Expand All @@ -40,7 +40,7 @@ hyper-trust-dns = { version = "0.4.2", features = [
"rustls-http2",
"dnssec-ring",
"dns-over-https-rustls",
"rustls-webpki"
"rustls-webpki",
] }
rand = "0.8.5"
tungstenite = "0.17"
Expand All @@ -49,4 +49,4 @@ criterion = "0.3.5"

[features]

__bench=[]
__bench = ["dep:visibility"]
23 changes: 12 additions & 11 deletions benches/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use hyper::client::HttpConnector;
use hyper::header::HeaderName;
use hyper::Uri;
use hyper::{HeaderMap, Request, Response};
use hyper_reverse_proxy::benches as internal_benches;
use hyper_reverse_proxy::ReverseProxy;
use rand::distributions::Alphanumeric;
use rand::prelude::*;
Expand All @@ -31,7 +30,7 @@ fn create_proxied_response(b: &mut Criterion) {

*response.headers_mut().unwrap() = headers_map.clone();

internal_benches::create_proxied_response(black_box(response.body(()).unwrap()));
hyper_reverse_proxy::create_proxied_response(black_box(response.body(()).unwrap()));
})
});
}
Expand All @@ -46,7 +45,7 @@ fn generate_string() -> String {
}

fn build_headers() -> HeaderMap {
let mut headers_map: HeaderMap = (&*internal_benches::hop_headers())
let mut headers_map: HeaderMap = (&hyper_reverse_proxy::HOP_HEADERS)
.iter()
.map(|el: &'static HeaderName| (el.clone(), generate_string().parse().unwrap()))
.collect();
Expand Down Expand Up @@ -108,7 +107,7 @@ fn forward_url_with_str_ending_slash(b: &mut Criterion) {
b.iter(|| {
let request = Request::builder().uri(uri.clone()).body(());

internal_benches::forward_uri(forward_url, &request.unwrap());
hyper_reverse_proxy::forward_uri(forward_url, &request.unwrap());
})
});
}
Expand All @@ -122,7 +121,7 @@ fn forward_url_with_str_ending_slash_and_query(b: &mut Criterion) {
t.iter(|| {
let request = Request::builder().uri(uri.clone()).body(());

internal_benches::forward_uri(forward_url, &request.unwrap());
hyper_reverse_proxy::forward_uri(forward_url, &request.unwrap());
})
});
}
Expand All @@ -136,7 +135,7 @@ fn forward_url_no_ending_slash(b: &mut Criterion) {
t.iter(|| {
let request = Request::builder().uri(uri.clone()).body(());

internal_benches::forward_uri(forward_url, &request.unwrap());
hyper_reverse_proxy::forward_uri(forward_url, &request.unwrap());
})
});
}
Expand All @@ -150,7 +149,7 @@ fn forward_url_with_query(b: &mut Criterion) {
t.iter(|| {
let request = Request::builder().uri(uri.clone()).body(());

internal_benches::forward_uri(forward_url, &request.unwrap());
hyper_reverse_proxy::forward_uri(forward_url, &request.unwrap());
})
});
}
Expand All @@ -175,12 +174,13 @@ fn create_proxied_request_forwarded_for_occupied(b: &mut Criterion) {

*request.headers_mut().unwrap() = headers_map.clone();

internal_benches::create_proxied_request(
hyper_reverse_proxy::create_proxied_request(
client_ip,
forward_url,
request.body(()).unwrap(),
None,
);
)
.unwrap();
})
});
}
Expand All @@ -200,12 +200,13 @@ fn create_proxied_request_forwarded_for_vacant(b: &mut Criterion) {

*request.headers_mut().unwrap() = headers_map.clone();

internal_benches::create_proxied_request(
hyper_reverse_proxy::create_proxied_request(
client_ip,
forward_url,
request.body(()).unwrap(),
None,
);
)
.unwrap();
})
});
}
Expand Down
92 changes: 34 additions & 58 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,28 @@ use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri;
use hyper::upgrade::OnUpgrade;
use hyper::{Body, Client, Error, Request, Response, StatusCode};
use lazy_static::lazy_static;
use std::net::IpAddr;
use tokio::io::copy_bidirectional;

lazy_static! {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer");
static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
// A list of the headers, using hypers actual HeaderName comparison
static ref HOP_HEADERS: [HeaderName; 9] = [
CONNECTION_HEADER.clone(),
TE_HEADER.clone(),
TRAILER_HEADER.clone(),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-connection"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
];

static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
}
static TE_HEADER: HeaderName = HeaderName::from_static("te");
static CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
static UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
static TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");

// A list of the headers, using hypers actual HeaderName comparison
#[cfg_attr(feature = "__bench", visibility::make(pub))]
static HOP_HEADERS: [HeaderName; 9] = [
HeaderName::from_static("connection"),
HeaderName::from_static("te"),
HeaderName::from_static("trailer"),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-connection"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
];

#[derive(Debug)]
pub enum ProxyError {
Expand Down Expand Up @@ -69,25 +66,25 @@ impl From<InvalidHeaderValue> for ProxyError {
fn remove_hop_headers(headers: &mut HeaderMap) {
debug!("Removing hop headers");

for header in &*HOP_HEADERS {
for header in &HOP_HEADERS {
headers.remove(header);
}
}

fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
#[allow(clippy::blocks_in_if_conditions)]
if headers
.get(&*CONNECTION_HEADER)
.get(&CONNECTION_HEADER)
.map(|value| {
value
.to_str()
.unwrap()
.split(',')
.any(|e| e.trim() == *UPGRADE_HEADER)
.any(|e| e.trim() == UPGRADE_HEADER)
})
.unwrap_or(false)
{
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
if let Some(upgrade_value) = headers.get(&UPGRADE_HEADER) {
debug!(
"Found upgrade header with value: {}",
upgrade_value.to_str().unwrap().to_owned()
Expand All @@ -101,10 +98,10 @@ fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
}

fn remove_connection_headers(headers: &mut HeaderMap) {
if headers.get(&*CONNECTION_HEADER).is_some() {
if headers.get(&CONNECTION_HEADER).is_some() {
debug!("Removing connection headers");

let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap();
let value = headers.get(&CONNECTION_HEADER).cloned().unwrap();

for name in value.to_str().unwrap().split(',') {
if !name.trim().is_empty() {
Expand All @@ -114,6 +111,7 @@ fn remove_connection_headers(headers: &mut HeaderMap) {
}
}

#[cfg_attr(feature = "__bench", visibility::make(pub))]
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
info!("Creating proxied response");

Expand All @@ -123,6 +121,7 @@ fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
response
}

#[cfg_attr(feature = "__bench", visibility::make(pub))]
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
debug!("Building forward uri");

Expand Down Expand Up @@ -201,6 +200,7 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
url.parse().unwrap()
}

#[cfg_attr(feature = "__bench", visibility::make(pub))]
fn create_proxied_request<B>(
client_ip: IpAddr,
forward_url: &str,
Expand All @@ -211,13 +211,13 @@ fn create_proxied_request<B>(

let contains_te_trailers_value = request
.headers()
.get(&*TE_HEADER)
.get(&TE_HEADER)
.map(|value| {
value
.to_str()
.unwrap()
.split(',')
.any(|e| e.trim() == *TRAILERS_HEADER)
.any(|e| e.trim() == TRAILERS_HEADER)
})
.unwrap_or(false);

Expand All @@ -239,22 +239,22 @@ fn create_proxied_request<B>(

request
.headers_mut()
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
.insert(&TE_HEADER, HeaderValue::from_static("trailers"));
}

if let Some(value) = upgrade_type {
debug!("Repopulate upgrade headers");

request
.headers_mut()
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
.insert(&UPGRADE_HEADER, value.parse().unwrap());
request
.headers_mut()
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
.insert(&CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
}

// Add forwarding information in the headers
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
match request.headers_mut().entry(&X_FORWARDED_FOR) {
hyper::header::Entry::Vacant(entry) => {
debug!("X-Fowraded-for header was vacant");
entry.insert(client_ip.to_string().parse()?);
Expand Down Expand Up @@ -362,27 +362,3 @@ impl<T: hyper::client::connect::Connect + Clone + Send + Sync + 'static> Reverse
call::<T>(client_ip, forward_uri, request, &self.client).await
}
}

#[cfg(feature = "__bench")]
pub mod benches {
pub fn hop_headers() -> &'static [crate::HeaderName] {
&*super::HOP_HEADERS
}

pub fn create_proxied_response<T>(response: crate::Response<T>) {
super::create_proxied_response(response);
}

pub fn forward_uri<B>(forward_url: &str, req: &crate::Request<B>) {
super::forward_uri(forward_url, req);
}

pub fn create_proxied_request<B>(
client_ip: crate::IpAddr,
forward_url: &str,
request: crate::Request<B>,
upgrade_type: Option<&String>,
) {
super::create_proxied_request(client_ip, forward_url, request, upgrade_type).unwrap();
}
}