diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index fffcd393..37c2b08f 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -299,9 +299,13 @@ impl StreamableHttpClientTransport { /// Disables idle connection pooling to avoid ~40 ms stalls caused by /// TCP Delayed ACK on Linux when the previous response body was not /// fully consumed before the pool attempts to reuse the connection. + /// + /// Automatic redirects are disabled so caller-supplied custom headers + /// cannot be replayed to a redirect target. fn default_http_client() -> reqwest::Client { reqwest::Client::builder() .pool_max_idle_per_host(0) + .redirect(reqwest::redirect::Policy::none()) .build() .expect("failed to build default reqwest client") } @@ -313,7 +317,7 @@ mod tests { use super::parse_json_rpc_error; use crate::{ - model::JsonRpcMessage, + model::{ClientJsonRpcMessage, ClientRequest, JsonRpcMessage, PingRequest, RequestId}, transport::streamable_http_client::{AuthRequiredError, InsufficientScopeError}, }; @@ -359,4 +363,135 @@ mod tests { fn parse_json_rpc_error_rejects_non_error_bodies(#[case] body: &str) { assert!(parse_json_rpc_error(body).is_none()); } + + #[tokio::test] + async fn default_http_client_does_not_leak_custom_headers_to_redirect_target() + -> anyhow::Result<()> { + use std::{collections::HashMap, net::SocketAddr, sync::Arc}; + + use axum::{ + Router, extract::State, http::StatusCode, response::IntoResponse, routing::post, + }; + use http::{HeaderMap, HeaderName, HeaderValue, header::LOCATION}; + use tokio::sync::Mutex; + + use super::StreamableHttpClientTransport; + use crate::transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}; + + const API_KEY_HEADER: &str = "x-api-key"; + const API_KEY_VALUE: &str = "secret"; + + type CapturedHeader = Arc>>; + + #[derive(Clone)] + struct RedirectState { + location: String, + captured_header: CapturedHeader, + } + + async fn capture_api_key_header(headers: &HeaderMap, captured_header: &CapturedHeader) { + if let Some(value) = headers + .get(API_KEY_HEADER) + .and_then(|value| value.to_str().ok()) + { + *captured_header.lock().await = Some(value.to_owned()); + } + } + + async fn redirect_handler( + State(state): State, + headers: HeaderMap, + ) -> impl IntoResponse { + capture_api_key_header(&headers, &state.captured_header).await; + + ( + StatusCode::TEMPORARY_REDIRECT, + [(LOCATION, state.location)], + "", + ) + } + + async fn redirected_handler( + State(captured_header): State, + headers: HeaderMap, + ) -> impl IntoResponse { + capture_api_key_header(&headers, &captured_header).await; + + ( + StatusCode::OK, + [(http::header::CONTENT_TYPE, "application/json")], + r#"{"jsonrpc":"2.0","id":1,"result":{}}"#, + ) + } + + let redirected_header = Arc::new(Mutex::new(None)); + let redirected_listener = + tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; + let redirected_addr = redirected_listener.local_addr()?; + let redirected_server = tokio::spawn({ + let redirected_header = redirected_header.clone(); + async move { + let app = Router::new() + .route("/capture", post(redirected_handler)) + .with_state(redirected_header); + axum::serve(redirected_listener, app).await + } + }); + + let original_header = Arc::new(Mutex::new(None)); + let redirect_listener = + tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; + let redirect_addr = redirect_listener.local_addr()?; + let redirect_server = tokio::spawn({ + let state = RedirectState { + location: format!("http://{redirected_addr}/capture"), + captured_header: original_header.clone(), + }; + async move { + let app = Router::new() + .route("/mcp", post(redirect_handler)) + .with_state(state); + axum::serve(redirect_listener, app).await + } + }); + + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static(API_KEY_HEADER), + HeaderValue::from_static(API_KEY_VALUE), + ); + let message = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let client = StreamableHttpClientTransport::::default_http_client(); + let result = client + .post_message( + Arc::::from(format!("http://{redirect_addr}/mcp")), + message, + None, + None, + custom_headers, + ) + .await; + + assert!( + matches!( + result, + Err(StreamableHttpError::UnexpectedServerResponse(_)) + ), + "redirect response should be returned to the transport, got {result:?}" + ); + assert_eq!(original_header.lock().await.as_deref(), Some(API_KEY_VALUE)); + assert!( + redirected_header.lock().await.is_none(), + "custom headers should not be sent to redirect targets" + ); + + redirect_server.abort(); + redirected_server.abort(); + + Ok(()) + } }