From 5f00a3886408eb120c28e281cb080a4ca49defeb Mon Sep 17 00:00:00 2001 From: arloor Date: Fri, 16 Jan 2026 11:11:53 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20WebSocket=20=E5=8D=87?= =?UTF-8?q?=E7=BA=A7=E8=AF=B7=E6=B1=82=E5=A4=84=E7=90=86=E5=92=8C=E5=8F=8C?= =?UTF-8?q?=E5=90=91=E6=95=B0=E6=8D=AE=E8=BD=AC=E5=8F=91=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rust_http_proxy/src/proxy.rs | 130 +++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/rust_http_proxy/src/proxy.rs b/rust_http_proxy/src/proxy.rs index c4825e9..ce11f45 100644 --- a/rust_http_proxy/src/proxy.rs +++ b/rust_http_proxy/src/proxy.rs @@ -296,6 +296,86 @@ impl ProxyHandler { // 默认为正向代理 } + /// 处理 WebSocket 升级请求(正向代理场景) + async fn handle_websocket_upgrade_forward( + &self, upstream_req: Request, client_upgrade_fut: hyper::upgrade::OnUpgrade, + traffic_label: AccessLabel, + ) -> Result>, io::Error> { + // 发送升级请求到上游 + let mut upstream_resp = self + .forward_proxy_client + .send_request( + upstream_req, + &traffic_label, + self.config.ipv6_first, + |stream: EitherTlsStream, access_label: AccessLabel| { + CounterIO::new(stream, METRICS.proxy_traffic.clone(), LabelImpl::new(access_label)) + }, + ) + .await?; + + // 检查上游是否返回 101 Switching Protocols + if upstream_resp.status() != http::StatusCode::SWITCHING_PROTOCOLS { + warn!("[forward] WebSocket upgrade failed, upstream returned: {}", upstream_resp.status()); + return Ok(upstream_resp.map(|body| body.map_err(|e| io::Error::new(ErrorKind::InvalidData, e)).boxed())); + } + + info!("[forward] WebSocket upgrade successful, status: {}", upstream_resp.status()); + + // 准备上游的升级 + let upstream_upgrade_fut = hyper::upgrade::on(&mut upstream_resp); + + // 构造 101 响应给客户端,复制上游的响应头 + let mut client_response_builder = Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS); + + // 复制所有响应头 + if let Some(headers) = client_response_builder.headers_mut() { + for (key, value) in upstream_resp.headers() { + headers.insert(key.clone(), value.clone()); + } + } + + let client_response = client_response_builder + .body(http_body_util::Empty::::new().map_err(|e| match e {}).boxed()) + .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; + + // 启动异步任务进行双向数据转发 + tokio::spawn(async move { + match (upstream_upgrade_fut.await, client_upgrade_fut.await) { + (Ok(upstream_upgraded), Ok(client_upgraded)) => { + if let Err(e) = + Self::tunnel_websocket_forward(upstream_upgraded, client_upgraded, traffic_label).await + { + warn!("[forward] WebSocket tunnel error: {e:?}"); + } + } + (Err(e), _) | (_, Err(e)) => { + warn!("[forward] WebSocket upgrade error: {e:?}"); + } + } + }); + + Ok(client_response) + } + + /// WebSocket 双向数据转发(正向代理场景) + async fn tunnel_websocket_forward( + upstream: Upgraded, client: Upgraded, _traffic_label: AccessLabel, + ) -> io::Result<()> { + let mut upstream_io = TokioIo::new(upstream); + let mut client_io = TokioIo::new(client); + + // 双向数据转发 + let _ = tokio::io::copy_bidirectional( + &mut client_io, + // &mut CounterIO::new(client_io, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())), + &mut upstream_io, + ) + .await?; + + Ok(()) + } + /// 代理普通请求 /// HTTP/1.1 GET/POST/PUT/DELETE/HEAD async fn simple_proxy( @@ -309,6 +389,32 @@ impl ProxyHandler { username, relay_over_tls: None, }; + + // 先检测是否是 WebSocket 升级请求(在 request 被消费之前) + let is_websocket = req + .headers() + .get(http::header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + + if is_websocket { + info!( + "[forward] WebSocket upgrade request: {:^35} ==> {} {:?}", + client_socket_addr.to_string(), + req.method(), + req.uri(), + ); + // 在消费 request 之前,先获取客户端的 upgrade future + let client_upgrade_fut = hyper::upgrade::on(&mut req); + + mod_http1_proxy_req(&mut req)?; + + return self + .handle_websocket_upgrade_forward(req, client_upgrade_fut, access_label) + .await; + } + mod_http1_proxy_req(&mut req)?; match self .forward_proxy_client @@ -344,6 +450,15 @@ impl ProxyHandler { username, relay_over_tls: Some(forward_bypass_config.is_https), }; + + // 先检测是否是 WebSocket 升级请求(在 request 被消费之前) + let is_websocket = req + .headers() + .get(http::header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + // 如果配置了 username 和 password,添加 Proxy-Authorization 头 if let (Some(username), Some(password)) = (&forward_bypass_config.username, &forward_bypass_config.password) { let credentials = format!("{}:{}", username, password); @@ -363,6 +478,21 @@ impl ProxyHandler { info!("change host header: {origin:?} -> {host_header:?}"); } + if is_websocket { + info!( + "[forward_bypass] WebSocket upgrade request: {:^35} ==> {} {:?}", + client_socket_addr.to_string(), + req.method(), + req.uri(), + ); + // 在消费 request 之前,先获取客户端的 upgrade future + let client_upgrade_fut = hyper::upgrade::on(&mut req); + + return self + .handle_websocket_upgrade_forward(req, client_upgrade_fut, access_label) + .await; + } + warn!("bypass {:?} {} {}", req.version(), req.method(), req.uri()); match self From 13b2c25692ee35b345d4c4ad322dda40690702f4 Mon Sep 17 00:00:00 2001 From: arloor Date: Fri, 16 Jan 2026 12:38:32 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20WebSocket=20=E5=8D=87?= =?UTF-8?q?=E7=BA=A7=E8=AF=B7=E6=B1=82=E7=9A=84=E5=A4=84=E7=90=86=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E4=BC=98=E5=8C=96=E8=AF=B7=E6=B1=82=E5=8F=91?= =?UTF-8?q?=E9=80=81=E5=92=8C=E5=93=8D=E5=BA=94=E6=8E=A5=E6=94=B6=E6=B5=81?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rust_http_proxy/src/forward_proxy_client.rs | 28 +++- rust_http_proxy/src/proxy.rs | 135 ++++++++++++-------- 2 files changed, 104 insertions(+), 59 deletions(-) diff --git a/rust_http_proxy/src/forward_proxy_client.rs b/rust_http_proxy/src/forward_proxy_client.rs index 723638c..b1946c4 100644 --- a/rust_http_proxy/src/forward_proxy_client.rs +++ b/rust_http_proxy/src/forward_proxy_client.rs @@ -46,13 +46,33 @@ where } } + #[allow(unused)] + pub async fn send_request_no_cache( + &self, req: Request, access_label: &AccessLabel, ipv6_first: Option, + stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, + ) -> Result, std::io::Error> { + // Make a new connection + let mut c = match HttpConnection::connect(access_label, ipv6_first, stream_map_func).await { + Ok(c) => c, + Err(err) => { + error!("failed to connect to host: {}, error: {}", &access_label.target, err); + return Err(io::Error::new(io::ErrorKind::InvalidData, err)); + } + }; + + trace!("HTTP making request to host: {access_label}, request: {req:?}"); + let response = c.send_request(req).await.map_err(io::Error::other)?; + trace!("HTTP received response from host: {access_label}, response: {response:?}"); + Ok(response) + } + /// Make HTTP requests #[inline] pub async fn send_request( &self, req: Request, access_label: &AccessLabel, ipv6_first: Option, stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, ) -> Result, std::io::Error> { - // 1. Check if there is an available client + // 1. Check if there is an available client (skip for WebSocket upgrades) if let Some(c) = self.get_cached_connection(access_label).await { debug!("HTTP client for host: {} taken from cache", &access_label); match self.send_request_conn(access_label, c, req).await { @@ -97,7 +117,7 @@ where None } - async fn send_request_conn( + pub(crate) async fn send_request_conn( &self, access_label: &AccessLabel, mut c: HttpConnection, req: Request, ) -> hyper::Result> { trace!("HTTP making request to host: {access_label}, request: {req:?}"); @@ -164,7 +184,7 @@ fn get_keep_alive_val(values: header::GetAll) -> Option { } #[allow(dead_code)] -enum HttpConnection { +pub(crate) enum HttpConnection { Http1(http1::SendRequest), } @@ -174,7 +194,7 @@ where B::Data: Send, B::Error: Into>, { - async fn connect( + pub(crate) async fn connect( access_label: &AccessLabel, ipv6_first: Option, stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, ) -> io::Result> { diff --git a/rust_http_proxy/src/proxy.rs b/rust_http_proxy/src/proxy.rs index ce11f45..6a69176 100644 --- a/rust_http_proxy/src/proxy.rs +++ b/rust_http_proxy/src/proxy.rs @@ -298,59 +298,100 @@ impl ProxyHandler { /// 处理 WebSocket 升级请求(正向代理场景) async fn handle_websocket_upgrade_forward( - &self, upstream_req: Request, client_upgrade_fut: hyper::upgrade::OnUpgrade, - traffic_label: AccessLabel, + &self, mut req: Request, traffic_label: AccessLabel, ) -> Result>, io::Error> { - // 发送升级请求到上游 - let mut upstream_resp = self - .forward_proxy_client - .send_request( - upstream_req, - &traffic_label, - self.config.ipv6_first, - |stream: EitherTlsStream, access_label: AccessLabel| { - CounterIO::new(stream, METRICS.proxy_traffic.clone(), LabelImpl::new(access_label)) - }, + use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; + + // 直接建立到上游的 TCP 连接 + let upstream_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?; + info!("[forward] WebSocket TCP connection established to {}", &traffic_label.target); + + let mut upstream_io = + CounterIO::new(upstream_stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())); + + // 构建 HTTP 请求行和头部 + let mut request_bytes = Vec::new(); + request_bytes.extend_from_slice( + format!( + "{} {} {:?}\r\n", + req.method(), + req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("/"), + req.version() ) - .await?; + .as_bytes(), + ); + + // 添加所有请求头 + for (name, value) in req.headers() { + request_bytes.extend_from_slice(name.as_str().as_bytes()); + request_bytes.extend_from_slice(b": "); + request_bytes.extend_from_slice(value.as_bytes()); + request_bytes.extend_from_slice(b"\r\n"); + } + request_bytes.extend_from_slice(b"\r\n"); + + // 发送请求到上游 + upstream_io.write_all(&request_bytes).await?; + upstream_io.flush().await?; - // 检查上游是否返回 101 Switching Protocols - if upstream_resp.status() != http::StatusCode::SWITCHING_PROTOCOLS { - warn!("[forward] WebSocket upgrade failed, upstream returned: {}", upstream_resp.status()); - return Ok(upstream_resp.map(|body| body.map_err(|e| io::Error::new(ErrorKind::InvalidData, e)).boxed())); + info!("[forward] WebSocket upgrade request sent to upstream"); + + // 读取上游响应 + let mut reader = tokio::io::BufReader::new(upstream_io); + let mut response_line = String::new(); + reader.read_line(&mut response_line).await?; + + // 检查响应状态码 + let status_code = response_line.split_whitespace().nth(1).unwrap_or(""); + if status_code != "101" { + warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line); + return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line))); } - info!("[forward] WebSocket upgrade successful, status: {}", upstream_resp.status()); + info!("[forward] WebSocket upgrade successful, status: {}", status_code); + + // 读取并保存响应头 + let mut response_headers = Vec::new(); + loop { + let mut header_line = String::new(); + reader.read_line(&mut header_line).await?; + if header_line == "\r\n" || header_line == "\n" { + break; + } + response_headers.push(header_line); + } - // 准备上游的升级 - let upstream_upgrade_fut = hyper::upgrade::on(&mut upstream_resp); + // 从BufReader中取回原始stream + let upstream_io = reader.into_inner(); - // 构造 101 响应给客户端,复制上游的响应头 - let mut client_response_builder = Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS); + // 构造 101 响应给客户端,并添加上游返回的响应头 + let mut response_builder = Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS); - // 复制所有响应头 - if let Some(headers) = client_response_builder.headers_mut() { - for (key, value) in upstream_resp.headers() { - headers.insert(key.clone(), value.clone()); + // 添加上游返回的所有响应头 + for header_line in response_headers { + if let Some((name, value)) = header_line.trim_end().split_once(':') { + let name = name.trim(); + let value = value.trim(); + if let Ok(header_value) = HeaderValue::from_str(value) { + response_builder = response_builder.header(name, header_value); + } } } - let client_response = client_response_builder + let client_response = response_builder .body(http_body_util::Empty::::new().map_err(|e| match e {}).boxed()) .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; // 启动异步任务进行双向数据转发 tokio::spawn(async move { - match (upstream_upgrade_fut.await, client_upgrade_fut.await) { - (Ok(upstream_upgraded), Ok(client_upgraded)) => { - if let Err(e) = - Self::tunnel_websocket_forward(upstream_upgraded, client_upgraded, traffic_label).await - { + match hyper::upgrade::on(&mut req).await { + Ok(client_upgraded) => { + if let Err(e) = Self::tunnel_websocket_forward(upstream_io, client_upgraded).await { warn!("[forward] WebSocket tunnel error: {e:?}"); } } - (Err(e), _) | (_, Err(e)) => { - warn!("[forward] WebSocket upgrade error: {e:?}"); + Err(e) => { + warn!("[forward] WebSocket client upgrade error: {e:?}"); } } }); @@ -360,18 +401,12 @@ impl ProxyHandler { /// WebSocket 双向数据转发(正向代理场景) async fn tunnel_websocket_forward( - upstream: Upgraded, client: Upgraded, _traffic_label: AccessLabel, + mut upstream_io: CounterIO>, client: Upgraded, ) -> io::Result<()> { - let mut upstream_io = TokioIo::new(upstream); let mut client_io = TokioIo::new(client); // 双向数据转发 - let _ = tokio::io::copy_bidirectional( - &mut client_io, - // &mut CounterIO::new(client_io, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())), - &mut upstream_io, - ) - .await?; + let _ = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?; Ok(()) } @@ -398,6 +433,7 @@ impl ProxyHandler { .map(|v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false); + mod_http1_proxy_req(&mut req)?; if is_websocket { info!( "[forward] WebSocket upgrade request: {:^35} ==> {} {:?}", @@ -405,17 +441,10 @@ impl ProxyHandler { req.method(), req.uri(), ); - // 在消费 request 之前,先获取客户端的 upgrade future - let client_upgrade_fut = hyper::upgrade::on(&mut req); - - mod_http1_proxy_req(&mut req)?; - return self - .handle_websocket_upgrade_forward(req, client_upgrade_fut, access_label) - .await; + return self.handle_websocket_upgrade_forward(req, access_label).await; } - mod_http1_proxy_req(&mut req)?; match self .forward_proxy_client .send_request( @@ -485,12 +514,8 @@ impl ProxyHandler { req.method(), req.uri(), ); - // 在消费 request 之前,先获取客户端的 upgrade future - let client_upgrade_fut = hyper::upgrade::on(&mut req); - return self - .handle_websocket_upgrade_forward(req, client_upgrade_fut, access_label) - .await; + return self.handle_websocket_upgrade_forward(req, access_label).await; } warn!("bypass {:?} {} {}", req.version(), req.method(), req.uri()); From 547c8a0e38f36864d8b9b23653e8faf4f3308ac5 Mon Sep 17 00:00:00 2001 From: arloor Date: Fri, 16 Jan 2026 12:47:22 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BC=98=E5=8C=96=20WebSocket=20=E5=8D=87?= =?UTF-8?q?=E7=BA=A7=E8=AF=B7=E6=B1=82=E7=9A=84=E5=AE=A2=E6=88=B7=E7=AB=AF?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E6=A3=80=E6=9F=A5=EF=BC=8C=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rust_http_proxy/src/forward_proxy_client.rs | 2 +- rust_http_proxy/src/proxy.rs | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/rust_http_proxy/src/forward_proxy_client.rs b/rust_http_proxy/src/forward_proxy_client.rs index b1946c4..4c91d6b 100644 --- a/rust_http_proxy/src/forward_proxy_client.rs +++ b/rust_http_proxy/src/forward_proxy_client.rs @@ -72,7 +72,7 @@ where &self, req: Request, access_label: &AccessLabel, ipv6_first: Option, stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, ) -> Result, std::io::Error> { - // 1. Check if there is an available client (skip for WebSocket upgrades) + // 1. Check if there is an available client if let Some(c) = self.get_cached_connection(access_label).await { debug!("HTTP client for host: {} taken from cache", &access_label); match self.send_request_conn(access_label, c, req).await { diff --git a/rust_http_proxy/src/proxy.rs b/rust_http_proxy/src/proxy.rs index 6a69176..97f7fc5 100644 --- a/rust_http_proxy/src/proxy.rs +++ b/rust_http_proxy/src/proxy.rs @@ -337,12 +337,22 @@ impl ProxyHandler { info!("[forward] WebSocket upgrade request sent to upstream"); // 读取上游响应 + // 将 upstream_io 包装在 BufReader 中,以便按行高效读取 HTTP 状态行和头部。 + // 在完成 HTTP 响应解析后,我们会通过 reader.into_inner() 取回底层的 upstream_io, + // 继续将同一个 TCP 连接用于 WebSocket 隧道的数据转发。 let mut reader = tokio::io::BufReader::new(upstream_io); let mut response_line = String::new(); reader.read_line(&mut response_line).await?; // 检查响应状态码 let status_code = response_line.split_whitespace().nth(1).unwrap_or(""); + if status_code.is_empty() { + warn!("[forward] Failed to parse status code from upstream response: {}", response_line); + return Err(io::Error::other(format!( + "Failed to parse status code from upstream response: {}", + response_line + ))); + } if status_code != "101" { warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line); return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line)));