Skip to content

Commit 4f6cad6

Browse files
committed
feat(client): add some general HTTP/1 client middleware
1 parent b9dc3d2 commit 4f6cad6

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

src/client/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/// Legacy implementations of `connect` module and `Client`
44
#[cfg(feature = "client-legacy")]
55
pub mod legacy;
6+
pub mod service;
67

78
#[cfg(feature = "client-proxy")]
89
pub mod proxy;

src/client/service.rs

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
//! todo
2+
3+
use std::task::{Context, Poll};
4+
5+
use http::header::{HeaderValue, HOST};
6+
use http::{Method, Request, Uri};
7+
use tower_service::Service;
8+
9+
/// todo
10+
#[derive(Clone, Debug)]
11+
pub struct SetHost<S> {
12+
inner: S,
13+
}
14+
15+
/// todo
16+
#[derive(Clone, Debug)]
17+
pub struct Http1RequestTarget<S> {
18+
inner: S,
19+
}
20+
21+
// ===== impl SetHost =====
22+
23+
impl<S> SetHost<S> {
24+
/// todo
25+
pub fn new(inner: S) -> Self {
26+
SetHost { inner }
27+
}
28+
29+
/// Access the inner service.
30+
pub fn inner(&self) -> &S {
31+
&self.inner
32+
}
33+
}
34+
35+
impl<S, ReqBody> Service<Request<ReqBody>> for SetHost<S>
36+
where
37+
S: Service<Request<ReqBody>>,
38+
{
39+
type Response = S::Response;
40+
type Error = S::Error;
41+
type Future = S::Future;
42+
43+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
44+
self.inner.poll_ready(cx)
45+
}
46+
47+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
48+
if req.uri().authority().is_some() {
49+
let uri = req.uri().clone();
50+
req.headers_mut().entry(HOST).or_insert_with(|| {
51+
let hostname = uri.host().expect("authority implies host");
52+
if let Some(port) = get_non_default_port(&uri) {
53+
let s = format!("{hostname}:{port}");
54+
HeaderValue::from_str(&s)
55+
} else {
56+
HeaderValue::from_str(hostname)
57+
}
58+
.expect("uri host is valid header value")
59+
});
60+
}
61+
self.inner.call(req)
62+
}
63+
}
64+
65+
fn get_non_default_port(uri: &Uri) -> Option<http::uri::Port<&str>> {
66+
match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) {
67+
(Some(443), true) => None,
68+
(Some(80), false) => None,
69+
_ => uri.port(),
70+
}
71+
}
72+
73+
fn is_schema_secure(uri: &Uri) -> bool {
74+
uri.scheme_str()
75+
.map(|scheme_str| matches!(scheme_str, "wss" | "https"))
76+
.unwrap_or_default()
77+
}
78+
79+
// ===== impl Http1RequestTarget =====
80+
81+
impl<S> Http1RequestTarget<S> {
82+
/// todo
83+
pub fn new(inner: S) -> Self {
84+
Http1RequestTarget { inner }
85+
}
86+
87+
/// Access the inner service.
88+
pub fn inner(&self) -> &S {
89+
&self.inner
90+
}
91+
}
92+
93+
impl<S, ReqBody> Service<Request<ReqBody>> for Http1RequestTarget<S>
94+
where
95+
S: Service<Request<ReqBody>>,
96+
{
97+
type Response = S::Response;
98+
type Error = S::Error;
99+
type Future = S::Future;
100+
101+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102+
self.inner.poll_ready(cx)
103+
}
104+
105+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
106+
// CONNECT always sends authority-form, so check it first...
107+
if req.method() == Method::CONNECT {
108+
authority_form(req.uri_mut());
109+
} else {
110+
origin_form(req.uri_mut());
111+
}
112+
self.inner.call(req)
113+
}
114+
}
115+
116+
fn origin_form(uri: &mut Uri) {
117+
let path = match uri.path_and_query() {
118+
Some(path) if path.as_str() != "/" => {
119+
let mut parts = ::http::uri::Parts::default();
120+
parts.path_and_query = Some(path.clone());
121+
Uri::from_parts(parts).expect("path is valid uri")
122+
}
123+
_none_or_just_slash => {
124+
debug_assert!(Uri::default() == "/");
125+
Uri::default()
126+
}
127+
};
128+
*uri = path
129+
}
130+
131+
fn authority_form(uri: &mut Uri) {
132+
if let Some(path) = uri.path_and_query() {
133+
// `https://hyper.rs` would parse with `/` path, don't
134+
// annoy people about that...
135+
if path != "/" {
136+
tracing::debug!("HTTP/1.1 CONNECT request stripping path: {:?}", path);
137+
}
138+
}
139+
*uri = match uri.authority() {
140+
Some(auth) => {
141+
let mut parts = ::http::uri::Parts::default();
142+
parts.authority = Some(auth.clone());
143+
Uri::from_parts(parts).expect("authority is valid")
144+
}
145+
None => {
146+
unreachable!("authority_form with relative uri");
147+
}
148+
};
149+
}

0 commit comments

Comments
 (0)