diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index cd2f5f1e..8ebec4e5 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -1124,44 +1124,40 @@ where } } } else { - let (session_id, transport) = self - .session_manager - .create_session() - .await - .map_err(internal_error_response("create session"))?; // Capture init params for external store persistence before // extensions are injected (which would require Clone). - let stored_init_params = if self.config.session_store.is_some() { - if let ClientJsonRpcMessage::Request(req) = &message { - if let ClientRequest::InitializeRequest(init_req) = &req.request { - Some(init_req.params.clone()) - } else { - None - } - } else { - None + let stored_init_params = match &mut message { + ClientJsonRpcMessage::Request(req) => { + let ClientRequest::InitializeRequest(init_req) = &req.request else { + return Err(unexpected_message_response("initialize request")); + }; + // Reject mismatched MCP-Protocol-Version header before binding the session to anything. + validate_header_matches_init_body( + &part.headers, + init_req.params.protocol_version.as_str(), + Some(req.id.clone()), + )?; + let stored_init_params = self + .config + .session_store + .as_ref() + .map(|_| init_req.params.clone()); + // inject request part to extensions + req.request.extensions_mut().insert(part); + stored_init_params } - } else { - None - }; - if let ClientJsonRpcMessage::Request(req) = &mut message { - let ClientRequest::InitializeRequest(init_req) = &req.request else { + _ => { return Err(unexpected_message_response("initialize request")); - }; - // Reject mismatched MCP-Protocol-Version header before binding the session to anything. - validate_header_matches_init_body( - &part.headers, - init_req.params.protocol_version.as_str(), - Some(req.id.clone()), - )?; - // inject request part to extensions - req.request.extensions_mut().insert(part); - } else { - return Err(unexpected_message_response("initialize request")); - } + } + }; let service = self .get_service() .map_err(internal_error_response("get service"))?; + let (session_id, transport) = self + .session_manager + .create_session() + .await + .map_err(internal_error_response("create session"))?; // spawn a task to serve the session Self::spawn_session_worker( self.session_manager.clone(), diff --git a/crates/rmcp/tests/test_streamable_http_protocol_version.rs b/crates/rmcp/tests/test_streamable_http_protocol_version.rs index 3500266b..0ed61c0e 100644 --- a/crates/rmcp/tests/test_streamable_http_protocol_version.rs +++ b/crates/rmcp/tests/test_streamable_http_protocol_version.rs @@ -1,5 +1,7 @@ #![cfg(not(feature = "local"))] //! Regression tests for the `MCP-Protocol-Version` header / initialize body consistency check. +use std::sync::Arc; + use rmcp::transport::streamable_http_server::{ StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, }; @@ -16,10 +18,17 @@ fn init_body(body_version: &str) -> String { async fn spawn_server( config: StreamableHttpServerConfig, +) -> (reqwest::Client, String, CancellationToken) { + spawn_server_with_manager(config, Arc::new(LocalSessionManager::default())).await +} + +async fn spawn_server_with_manager( + config: StreamableHttpServerConfig, + session_manager: Arc, ) -> (reqwest::Client, String, CancellationToken) { let ct = config.cancellation_token.clone(); let service: StreamableHttpService = - StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config); + StreamableHttpService::new(|| Ok(Calculator::new()), session_manager, config); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -71,6 +80,17 @@ async fn post_init( req.send().await.expect("send initialize request") } +async fn post_non_initialize(client: &reqwest::Client, url: &str) -> reqwest::Response { + client + .post(url) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(r#"{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}"#) + .send() + .await + .expect("send non-initialize request") +} + #[tokio::test] async fn stateless_init_rejects_when_header_older_than_body() -> anyhow::Result<()> { let (client, url, ct) = spawn_server(stateless_json_config()).await; @@ -147,3 +167,21 @@ async fn stateful_init_rejects_when_header_mismatches_body() -> anyhow::Result<( ct.cancel(); Ok(()) } + +#[tokio::test] +async fn stateful_rejected_initial_posts_do_not_create_sessions() -> anyhow::Result<()> { + let session_manager = Arc::new(LocalSessionManager::default()); + let (client, url, ct) = + spawn_server_with_manager(stateful_config(), session_manager.clone()).await; + + let response = post_non_initialize(&client, &url).await; + assert_eq!(response.status(), 422); + assert_eq!(session_manager.sessions.read().await.len(), 0); + + let response = post_init(&client, &url, Some("2024-11-05"), "2025-11-25").await; + assert_eq!(response.status(), 400); + assert_eq!(session_manager.sessions.read().await.len(), 0); + + ct.cancel(); + Ok(()) +}