From 7b7bb9d890daea4fe6644ec9f3c29764aa67546b Mon Sep 17 00:00:00 2001 From: actsalan Date: Fri, 26 Jun 2026 09:48:18 -0700 Subject: [PATCH 1/2] fix: negotiate protocol version in handler (fixes #916) --- crates/rmcp/Cargo.toml | 10 ++ crates/rmcp/src/handler/server.rs | 9 +- crates/rmcp/src/service/server.rs | 7 +- .../transport/streamable_http_server/tower.rs | 44 ++++++++- .../test_protocol_version_negotiation.rs | 83 ++++++++++++++++ .../tests/test_stateless_protocol_version.rs | 99 +++++++++++++++++++ 6 files changed, 247 insertions(+), 5 deletions(-) create mode 100644 crates/rmcp/tests/test_protocol_version_negotiation.rs create mode 100644 crates/rmcp/tests/test_stateless_protocol_version.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 638679812..720431487 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -282,6 +282,16 @@ name = "test_streamable_http_protocol_version" required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] path = "tests/test_streamable_http_protocol_version.rs" +[[test]] +name = "test_stateless_protocol_version" +required-features = ["server", "transport-streamable-http-server", "reqwest"] +path = "tests/test_stateless_protocol_version.rs" + +[[test]] +name = "test_protocol_version_negotiation" +required-features = ["server", "client"] +path = "tests/test_protocol_version_negotiation.rs" + [[test]] name = "test_streamable_http_4xx_error_body" required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"] diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 0fb4bf891..696a588b1 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -7,6 +7,7 @@ use crate::{ model::{TaskSupport, *}, service::{ MaybeSendFuture, NotificationContext, RequestContext, RoleServer, Service, ServiceRole, + negotiate_protocol_version, }, }; @@ -202,8 +203,14 @@ macro_rules! server_handler_methods { request: InitializeRequestParams, context: RequestContext, ) -> impl Future> + MaybeSendFuture + '_ { + let negotiated = negotiate_protocol_version( + &request.protocol_version, + ProtocolVersion::LATEST, + ); context.peer.set_peer_info(request); - std::future::ready(Ok(self.get_info())) + let mut info = self.get_info(); + info.protocol_version = negotiated; + std::future::ready(Ok(info)) } fn complete( &self, diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index c369e5aca..4f479b9e8 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -162,7 +162,7 @@ where } /// Echoes the client-requested version if known; otherwise returns `server_fallback`. -fn negotiate_protocol_version( +pub(crate) fn negotiate_protocol_version( client_requested: &ProtocolVersion, server_fallback: ProtocolVersion, ) -> ProtocolVersion { @@ -254,6 +254,11 @@ where &peer_info.params.protocol_version, init_response.protocol_version, ); + // Update peer_info so context.protocol_version() reflects the negotiated + // version in all subsequent request handlers. + let mut negotiated_peer_info = peer_info.params.clone(); + negotiated_peer_info.protocol_version = init_response.protocol_version.clone(); + peer.set_peer_info(negotiated_peer_info); transport .send(ServerJsonRpcMessage::response( ServerResult::InitializeResult(init_response), diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index cd2f5f1e5..a4629f95d 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -16,8 +16,9 @@ use super::session::{ use crate::{ RoleServer, model::{ - ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, GetExtensions, - InitializeRequest, InitializedNotification, JsonRpcError, ProtocolVersion, RequestId, + ClientCapabilities, ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, + GetExtensions, Implementation, InitializeRequest, InitializeRequestParams, + InitializedNotification, JsonRpcError, ProtocolVersion, RequestId, }, serve_server, service::serve_directly, @@ -1243,10 +1244,17 @@ where .map_err(internal_error_response("get service"))?; match message { ClientJsonRpcMessage::Request(mut request) => { + // Build a peer_info so context.protocol_version() works inside handlers. + // serve_directly skips the handshake and receives None by default, making + // protocol_version() always return None in stateless mode. We reconstruct it: + // - initialize requests: version comes from the request body params + // - all other requests: version comes from the MCP-Protocol-Version header + // (already validated above; absent header defaults to 2025-03-26) + let peer_info = Self::peer_info_for_stateless_request(&request, &part.headers); request.request.extensions_mut().insert(part); let (transport, mut receiver) = OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); - let service = serve_directly(service, transport, None); + let service = serve_directly(service, transport, peer_info); tokio::spawn(async move { // on service created let _ = service.waiting().await; @@ -1335,4 +1343,34 @@ where } Ok(accepted_response()) } + + /// Build a `ClientInfo` (peer_info) for a stateless request so that + /// `context.protocol_version()` returns the correct value inside handlers. + /// + /// `serve_directly` skips the MCP handshake and accepts `peer_info = None`, + /// which means `context.protocol_version()` is always `None` in stateless mode. + /// We reconstruct the protocol version from the available signal per request type: + /// - initialize: version is in the request body params (authoritative) + /// - all other requests: version is in the MCP-Protocol-Version header + /// (validated before this point; absent header defaults to 2025-03-26) + fn peer_info_for_stateless_request( + request: &crate::model::JsonRpcRequest, + headers: &HeaderMap, + ) -> Option { + let version = if let ClientRequest::InitializeRequest(ref init) = request.request { + init.params.protocol_version.clone() + } else { + headers + .get(HEADER_MCP_PROTOCOL_VERSION) + .and_then(|v| v.to_str().ok()) + .and_then(|s| serde_json::from_value(serde_json::Value::String(s.to_owned())).ok()) + .unwrap_or(ProtocolVersion::V_2025_03_26) + }; + Some(InitializeRequestParams { + meta: None, + protocol_version: version, + capabilities: ClientCapabilities::default(), + client_info: Implementation::default(), + }) + } } diff --git a/crates/rmcp/tests/test_protocol_version_negotiation.rs b/crates/rmcp/tests/test_protocol_version_negotiation.rs new file mode 100644 index 000000000..44a314e68 --- /dev/null +++ b/crates/rmcp/tests/test_protocol_version_negotiation.rs @@ -0,0 +1,83 @@ +//! Tests for protocol version negotiation in the default ServerHandler::initialize impl. +//! +//! Known versions are echoed back; unknown versions fall back to LATEST. +#![cfg(not(feature = "local"))] +#![cfg(feature = "client")] + +use rmcp::{ + ClientHandler, ServerHandler, ServiceExt, + model::{ClientInfo, ProtocolVersion, ServerInfo}, +}; + +#[derive(Debug, Clone, Default)] +struct EchoServer; + +impl ServerHandler for EchoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::default() + } +} + +#[derive(Debug, Clone)] +struct VersionedClient { + protocol_version: ProtocolVersion, +} + +impl ClientHandler for VersionedClient { + fn get_info(&self) -> ClientInfo { + let mut info = ClientInfo::default(); + info.protocol_version = self.protocol_version.clone(); + info + } +} + +async fn negotiated_version(client_version: ProtocolVersion) -> ProtocolVersion { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + tokio::spawn(async move { + let _ = EchoServer + .serve(server_transport) + .await + .expect("server should start") + .waiting() + .await; + }); + + let client = VersionedClient { + protocol_version: client_version, + } + .serve(client_transport) + .await + .expect("client should connect"); + + let version = client + .peer_info() + .expect("peer_info should be set") + .protocol_version + .clone(); + + client.cancel().await.expect("client should cancel"); + version +} + +#[tokio::test] +async fn known_version_echoed_back() { + for version in ProtocolVersion::KNOWN_VERSIONS { + let negotiated = negotiated_version(version.clone()).await; + assert_eq!( + negotiated, *version, + "known version {version} should be echoed back" + ); + } +} + +#[tokio::test] +async fn unknown_version_falls_back_to_latest() { + let unknown: ProtocolVersion = serde_json::from_str(r#""1999-01-01""#).unwrap(); + let negotiated = negotiated_version(unknown).await; + assert_eq!( + negotiated, + ProtocolVersion::LATEST, + "unknown version should fall back to LATEST" + ); +} diff --git a/crates/rmcp/tests/test_stateless_protocol_version.rs b/crates/rmcp/tests/test_stateless_protocol_version.rs new file mode 100644 index 000000000..5103ddd8d --- /dev/null +++ b/crates/rmcp/tests/test_stateless_protocol_version.rs @@ -0,0 +1,99 @@ +//! Tests for protocol version negotiation in stateless HTTP mode. +//! +//! Known versions are echoed back; unknown versions fall back to LATEST. +#![cfg(not(feature = "local"))] + +use rmcp::{ + model::ProtocolVersion, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, +}; +use tokio_util::sync::CancellationToken; + +mod common; +use common::calculator::Calculator; + +fn stateless_json_config() -> StreamableHttpServerConfig { + StreamableHttpServerConfig::default() + .with_stateful_mode(false) + .with_json_response(true) + .with_sse_keep_alive(None) + .with_cancellation_token(CancellationToken::new()) +} + +async fn spawn_server( + config: StreamableHttpServerConfig, +) -> (reqwest::Client, String, CancellationToken) { + let ct = config.cancellation_token.clone(); + let service: StreamableHttpService = + StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + (reqwest::Client::new(), format!("http://{addr}/mcp"), ct) +} + +async fn post_init(client: &reqwest::Client, url: &str, body_version: &str) -> serde_json::Value { + let body = serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": body_version, + "capabilities": {}, + "clientInfo": {"name": "test", "version": "0.0.1"} + } + }); + let resp = client + .post(url) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(body.to_string()) + .send() + .await + .expect("send request"); + assert!(resp.status().is_success(), "HTTP {}", resp.status()); + resp.json().await.expect("parse JSON") +} + +#[tokio::test] +async fn stateless_init_echoes_known_version() { + let (client, url, ct) = spawn_server(stateless_json_config()).await; + + for version in ProtocolVersion::KNOWN_VERSIONS { + let resp = post_init(&client, &url, version.as_str()).await; + assert_eq!( + resp["result"]["protocolVersion"], + version.as_str(), + "known version {version} should be echoed back" + ); + } + + ct.cancel(); +} + +#[tokio::test] +async fn stateless_init_unknown_version_falls_back_to_latest() { + let (client, url, ct) = spawn_server(stateless_json_config()).await; + + let resp = post_init(&client, &url, "1999-01-01").await; + assert_eq!( + resp["result"]["protocolVersion"], + ProtocolVersion::LATEST.as_str(), + "unknown version should fall back to LATEST" + ); + + ct.cancel(); +} From 0f12a965a224459113edc785ef14446c2f1a3d25 Mon Sep 17 00:00:00 2001 From: actsalan Date: Fri, 26 Jun 2026 10:25:43 -0700 Subject: [PATCH 2/2] fix: use server pinned version as fallback in default initialize handler --- crates/rmcp/src/handler/server.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 696a588b1..3cec563e8 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -203,13 +203,12 @@ macro_rules! server_handler_methods { request: InitializeRequestParams, context: RequestContext, ) -> impl Future> + MaybeSendFuture + '_ { - let negotiated = negotiate_protocol_version( + context.peer.set_peer_info(request.clone()); + let mut info = self.get_info(); + info.protocol_version = negotiate_protocol_version( &request.protocol_version, - ProtocolVersion::LATEST, + info.protocol_version, ); - context.peer.set_peer_info(request); - let mut info = self.get_info(); - info.protocol_version = negotiated; std::future::ready(Ok(info)) } fn complete(