Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 27 additions & 31 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,44 +1124,40 @@ where
}
}
} else {
let (session_id, transport) = self
.session_manager
.create_session()
.await
.map_err(internal_error_response("create session"))?;
// Capture init params for external store persistence before
// extensions are injected (which would require Clone).
let stored_init_params = if self.config.session_store.is_some() {
if let ClientJsonRpcMessage::Request(req) = &message {
if let ClientRequest::InitializeRequest(init_req) = &req.request {
Some(init_req.params.clone())
} else {
None
}
} else {
None
let stored_init_params = match &mut message {
ClientJsonRpcMessage::Request(req) => {
let ClientRequest::InitializeRequest(init_req) = &req.request else {
return Err(unexpected_message_response("initialize request"));
};
// Reject mismatched MCP-Protocol-Version header before binding the session to anything.
validate_header_matches_init_body(
&part.headers,
init_req.params.protocol_version.as_str(),
Some(req.id.clone()),
)?;
let stored_init_params = self
.config
.session_store
.as_ref()
.map(|_| init_req.params.clone());
// inject request part to extensions
req.request.extensions_mut().insert(part);
stored_init_params
}
} else {
None
};
if let ClientJsonRpcMessage::Request(req) = &mut message {
let ClientRequest::InitializeRequest(init_req) = &req.request else {
_ => {
return Err(unexpected_message_response("initialize request"));
};
// Reject mismatched MCP-Protocol-Version header before binding the session to anything.
validate_header_matches_init_body(
&part.headers,
init_req.params.protocol_version.as_str(),
Some(req.id.clone()),
)?;
// inject request part to extensions
req.request.extensions_mut().insert(part);
} else {
return Err(unexpected_message_response("initialize request"));
}
}
};
let service = self
.get_service()
.map_err(internal_error_response("get service"))?;
let (session_id, transport) = self
.session_manager
.create_session()
.await
.map_err(internal_error_response("create session"))?;
// spawn a task to serve the session
Self::spawn_session_worker(
self.session_manager.clone(),
Expand Down
40 changes: 39 additions & 1 deletion crates/rmcp/tests/test_streamable_http_protocol_version.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![cfg(not(feature = "local"))]
//! Regression tests for the `MCP-Protocol-Version` header / initialize body consistency check.
use std::sync::Arc;

use rmcp::transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
};
Expand All @@ -16,10 +18,17 @@ fn init_body(body_version: &str) -> String {

async fn spawn_server(
config: StreamableHttpServerConfig,
) -> (reqwest::Client, String, CancellationToken) {
spawn_server_with_manager(config, Arc::new(LocalSessionManager::default())).await
}

async fn spawn_server_with_manager(
config: StreamableHttpServerConfig,
session_manager: Arc<LocalSessionManager>,
) -> (reqwest::Client, String, CancellationToken) {
let ct = config.cancellation_token.clone();
let service: StreamableHttpService<Calculator, LocalSessionManager> =
StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config);
StreamableHttpService::new(|| Ok(Calculator::new()), session_manager, config);

let router = axum::Router::new().nest_service("/mcp", service);
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
Expand Down Expand Up @@ -71,6 +80,17 @@ async fn post_init(
req.send().await.expect("send initialize request")
}

async fn post_non_initialize(client: &reqwest::Client, url: &str) -> reqwest::Response {
client
.post(url)
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}"#)
.send()
.await
.expect("send non-initialize request")
}

#[tokio::test]
async fn stateless_init_rejects_when_header_older_than_body() -> anyhow::Result<()> {
let (client, url, ct) = spawn_server(stateless_json_config()).await;
Expand Down Expand Up @@ -147,3 +167,21 @@ async fn stateful_init_rejects_when_header_mismatches_body() -> anyhow::Result<(
ct.cancel();
Ok(())
}

#[tokio::test]
async fn stateful_rejected_initial_posts_do_not_create_sessions() -> anyhow::Result<()> {
let session_manager = Arc::new(LocalSessionManager::default());
let (client, url, ct) =
spawn_server_with_manager(stateful_config(), session_manager.clone()).await;

let response = post_non_initialize(&client, &url).await;
assert_eq!(response.status(), 422);
assert_eq!(session_manager.sessions.read().await.len(), 0);

let response = post_init(&client, &url, Some("2024-11-05"), "2025-11-25").await;
assert_eq!(response.status(), 400);
assert_eq!(session_manager.sessions.read().await.len(), 0);

ct.cancel();
Ok(())
}