diff --git a/rust_http_proxy/src/forward_proxy_client.rs b/rust_http_proxy/src/forward_proxy_client.rs index 738fba5..92b65a1 100644 --- a/rust_http_proxy/src/forward_proxy_client.rs +++ b/rust_http_proxy/src/forward_proxy_client.rs @@ -13,7 +13,7 @@ use http::{HeaderMap, HeaderValue, Version, header}; use hyper::{ Request, Response, body::{self, Body}, - client::conn::http1, + client::conn::http1::{self}, }; use hyper_util::rt::TokioIo; use io_x::{CounterIO, TimeoutIO}; @@ -251,7 +251,7 @@ where let access_label = access_label.clone(); tokio::spawn(async move { - if let Err(err) = connection.await { + if let Err(err) = connection.with_upgrades().await { handle_http1_connection_error(err, access_label); } }); diff --git a/rust_http_proxy/src/proxy.rs b/rust_http_proxy/src/proxy.rs index 97f7fc5..91605b5 100644 --- a/rust_http_proxy/src/proxy.rs +++ b/rust_http_proxy/src/proxy.rs @@ -300,120 +300,70 @@ impl ProxyHandler { async fn handle_websocket_upgrade_forward( &self, mut req: Request, traffic_label: AccessLabel, ) -> Result>, io::Error> { - use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; + info!("[forward] WebSocket upgrade request to {}", &traffic_label.target); - // 直接建立到上游的 TCP 连接 - let upstream_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?; - info!("[forward] WebSocket TCP connection established to {}", &traffic_label.target); + // 在消费 request 之前先获取客户端的 upgrade future + let client_upgrade = hyper::upgrade::on(&mut req); - let mut upstream_io = - CounterIO::new(upstream_stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())); - - // 构建 HTTP 请求行和头部 - let mut request_bytes = Vec::new(); - request_bytes.extend_from_slice( - format!( - "{} {} {:?}\r\n", - req.method(), - req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("/"), - req.version() + // 使用 send_request_no_cache 发送请求 + let mut upstream_response = self + .forward_proxy_client + .send_request_no_cache( + req, + &traffic_label, + self.config.ipv6_first, + |stream: EitherTlsStream, access_label: AccessLabel| { + CounterIO::new(stream, METRICS.proxy_traffic.clone(), LabelImpl::new(access_label)) + }, ) - .as_bytes(), - ); - - // 添加所有请求头 - for (name, value) in req.headers() { - request_bytes.extend_from_slice(name.as_str().as_bytes()); - request_bytes.extend_from_slice(b": "); - request_bytes.extend_from_slice(value.as_bytes()); - request_bytes.extend_from_slice(b"\r\n"); - } - request_bytes.extend_from_slice(b"\r\n"); - - // 发送请求到上游 - upstream_io.write_all(&request_bytes).await?; - upstream_io.flush().await?; - - info!("[forward] WebSocket upgrade request sent to upstream"); - - // 读取上游响应 - // 将 upstream_io 包装在 BufReader 中,以便按行高效读取 HTTP 状态行和头部。 - // 在完成 HTTP 响应解析后,我们会通过 reader.into_inner() 取回底层的 upstream_io, - // 继续将同一个 TCP 连接用于 WebSocket 隧道的数据转发。 - let mut reader = tokio::io::BufReader::new(upstream_io); - let mut response_line = String::new(); - reader.read_line(&mut response_line).await?; + .await?; // 检查响应状态码 - let status_code = response_line.split_whitespace().nth(1).unwrap_or(""); - if status_code.is_empty() { - warn!("[forward] Failed to parse status code from upstream response: {}", response_line); - return Err(io::Error::other(format!( - "Failed to parse status code from upstream response: {}", - response_line - ))); - } - if status_code != "101" { - warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line); - return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line))); - } - - info!("[forward] WebSocket upgrade successful, status: {}", status_code); - - // 读取并保存响应头 - let mut response_headers = Vec::new(); - loop { - let mut header_line = String::new(); - reader.read_line(&mut header_line).await?; - if header_line == "\r\n" || header_line == "\n" { - break; - } - response_headers.push(header_line); + if upstream_response.status() != http::StatusCode::SWITCHING_PROTOCOLS { + warn!("[forward] WebSocket upgrade failed, upstream returned: {}", upstream_response.status()); + return Err(io::Error::other(format!("WebSocket upgrade failed: {}", upstream_response.status()))); } - // 从BufReader中取回原始stream - let upstream_io = reader.into_inner(); - - // 构造 101 响应给客户端,并添加上游返回的响应头 - let mut response_builder = Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS); - - // 添加上游返回的所有响应头 - for header_line in response_headers { - if let Some((name, value)) = header_line.trim_end().split_once(':') { - let name = name.trim(); - let value = value.trim(); - if let Ok(header_value) = HeaderValue::from_str(value) { - response_builder = response_builder.header(name, header_value); - } - } - } + info!("[forward] WebSocket upgrade successful, status: {}", upstream_response.status()); - let client_response = response_builder - .body(http_body_util::Empty::::new().map_err(|e| match e {}).boxed()) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; + // 获取上游的 upgrade future + let upstream_upgrade = hyper::upgrade::on(&mut upstream_response); // 启动异步任务进行双向数据转发 tokio::spawn(async move { - match hyper::upgrade::on(&mut req).await { - Ok(client_upgraded) => { - if let Err(e) = Self::tunnel_websocket_forward(upstream_io, client_upgraded).await { - warn!("[forward] WebSocket tunnel error: {e:?}"); + match client_upgrade.await { + Ok(client_upgraded) => match upstream_upgrade.await { + Ok(upstream_upgraded) => { + if let Err(e) = + Self::tunnel_websocket_forward_upgraded(client_upgraded, upstream_upgraded).await + { + warn!("[forward] WebSocket tunnel error: {e:?}"); + } } - } + Err(e) => { + warn!("[forward] WebSocket upstream upgrade error: {e:?}"); + } + }, Err(e) => { warn!("[forward] WebSocket client upgrade error: {e:?}"); } } }); - Ok(client_response) + let response = upstream_response.map(|body| { + body.map_err(|e| { + let e = e; + io::Error::new(ErrorKind::InvalidData, e) + }) + .boxed() + }); + Ok(response) } - /// WebSocket 双向数据转发(正向代理场景) - async fn tunnel_websocket_forward( - mut upstream_io: CounterIO>, client: Upgraded, - ) -> io::Result<()> { + /// WebSocket 双向数据转发(正向代理场景)- Upgraded 版本 + async fn tunnel_websocket_forward_upgraded(client: Upgraded, upstream: Upgraded) -> io::Result<()> { let mut client_io = TokioIo::new(client); + let mut upstream_io = TokioIo::new(upstream); // 双向数据转发 let _ = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?;