Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 136 additions & 1 deletion crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,13 @@ impl StreamableHttpClientTransport<reqwest::Client> {
/// Disables idle connection pooling to avoid ~40 ms stalls caused by
/// TCP Delayed ACK on Linux when the previous response body was not
/// fully consumed before the pool attempts to reuse the connection.
///
/// Automatic redirects are disabled so caller-supplied custom headers
/// cannot be replayed to a redirect target.
fn default_http_client() -> reqwest::Client {
reqwest::Client::builder()
.pool_max_idle_per_host(0)
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("failed to build default reqwest client")
}
Expand All @@ -313,7 +317,7 @@ mod tests {

use super::parse_json_rpc_error;
use crate::{
model::JsonRpcMessage,
model::{ClientJsonRpcMessage, ClientRequest, JsonRpcMessage, PingRequest, RequestId},
transport::streamable_http_client::{AuthRequiredError, InsufficientScopeError},
};

Expand Down Expand Up @@ -359,4 +363,135 @@ mod tests {
fn parse_json_rpc_error_rejects_non_error_bodies(#[case] body: &str) {
assert!(parse_json_rpc_error(body).is_none());
}

#[tokio::test]
async fn default_http_client_does_not_leak_custom_headers_to_redirect_target()
-> anyhow::Result<()> {
use std::{collections::HashMap, net::SocketAddr, sync::Arc};

use axum::{
Router, extract::State, http::StatusCode, response::IntoResponse, routing::post,
};
use http::{HeaderMap, HeaderName, HeaderValue, header::LOCATION};
use tokio::sync::Mutex;

use super::StreamableHttpClientTransport;
use crate::transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError};

const API_KEY_HEADER: &str = "x-api-key";
const API_KEY_VALUE: &str = "secret";

type CapturedHeader = Arc<Mutex<Option<String>>>;

#[derive(Clone)]
struct RedirectState {
location: String,
captured_header: CapturedHeader,
}

async fn capture_api_key_header(headers: &HeaderMap, captured_header: &CapturedHeader) {
if let Some(value) = headers
.get(API_KEY_HEADER)
.and_then(|value| value.to_str().ok())
{
*captured_header.lock().await = Some(value.to_owned());
}
}

async fn redirect_handler(
State(state): State<RedirectState>,
headers: HeaderMap,
) -> impl IntoResponse {
capture_api_key_header(&headers, &state.captured_header).await;

(
StatusCode::TEMPORARY_REDIRECT,
[(LOCATION, state.location)],
"",
)
}

async fn redirected_handler(
State(captured_header): State<CapturedHeader>,
headers: HeaderMap,
) -> impl IntoResponse {
capture_api_key_header(&headers, &captured_header).await;

(
StatusCode::OK,
[(http::header::CONTENT_TYPE, "application/json")],
r#"{"jsonrpc":"2.0","id":1,"result":{}}"#,
)
}

let redirected_header = Arc::new(Mutex::new(None));
let redirected_listener =
tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?;
let redirected_addr = redirected_listener.local_addr()?;
let redirected_server = tokio::spawn({
let redirected_header = redirected_header.clone();
async move {
let app = Router::new()
.route("/capture", post(redirected_handler))
.with_state(redirected_header);
axum::serve(redirected_listener, app).await
}
});

let original_header = Arc::new(Mutex::new(None));
let redirect_listener =
tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?;
let redirect_addr = redirect_listener.local_addr()?;
let redirect_server = tokio::spawn({
let state = RedirectState {
location: format!("http://{redirected_addr}/capture"),
captured_header: original_header.clone(),
};
async move {
let app = Router::new()
.route("/mcp", post(redirect_handler))
.with_state(state);
axum::serve(redirect_listener, app).await
}
});

let mut custom_headers = HashMap::new();
custom_headers.insert(
HeaderName::from_static(API_KEY_HEADER),
HeaderValue::from_static(API_KEY_VALUE),
);
let message = ClientJsonRpcMessage::request(
ClientRequest::PingRequest(PingRequest::default()),
RequestId::Number(1),
);

let client = StreamableHttpClientTransport::<reqwest::Client>::default_http_client();
let result = client
.post_message(
Arc::<str>::from(format!("http://{redirect_addr}/mcp")),
message,
None,
None,
custom_headers,
)
.await;

assert!(
matches!(
result,
Err(StreamableHttpError::UnexpectedServerResponse(_))
),
"redirect response should be returned to the transport, got {result:?}"
);
assert_eq!(original_header.lock().await.as_deref(), Some(API_KEY_VALUE));
assert!(
redirected_header.lock().await.is_none(),
"custom headers should not be sent to redirect targets"
);

redirect_server.abort();
redirected_server.abort();

Ok(())
}
}