Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ name = "test_streamable_http_protocol_version"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
path = "tests/test_streamable_http_protocol_version.rs"

[[test]]
name = "test_stateless_protocol_version"
required-features = ["server", "transport-streamable-http-server", "reqwest"]
path = "tests/test_stateless_protocol_version.rs"

[[test]]
name = "test_protocol_version_negotiation"
required-features = ["server", "client"]
path = "tests/test_protocol_version_negotiation.rs"

[[test]]
name = "test_streamable_http_4xx_error_body"
required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"]
Expand Down
10 changes: 8 additions & 2 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
model::{TaskSupport, *},
service::{
MaybeSendFuture, NotificationContext, RequestContext, RoleServer, Service, ServiceRole,
negotiate_protocol_version,
},
};

Expand Down Expand Up @@ -202,8 +203,13 @@ macro_rules! server_handler_methods {
request: InitializeRequestParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
context.peer.set_peer_info(request);
std::future::ready(Ok(self.get_info()))
context.peer.set_peer_info(request.clone());
let mut info = self.get_info();
info.protocol_version = negotiate_protocol_version(
&request.protocol_version,
info.protocol_version,
);
std::future::ready(Ok(info))
}
fn complete(
&self,
Expand Down
7 changes: 6 additions & 1 deletion crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ where
}

/// Echoes the client-requested version if known; otherwise returns `server_fallback`.
fn negotiate_protocol_version(
pub(crate) fn negotiate_protocol_version(
client_requested: &ProtocolVersion,
server_fallback: ProtocolVersion,
) -> ProtocolVersion {
Expand Down Expand Up @@ -254,6 +254,11 @@ where
&peer_info.params.protocol_version,
init_response.protocol_version,
);
// Update peer_info so context.protocol_version() reflects the negotiated
// version in all subsequent request handlers.
let mut negotiated_peer_info = peer_info.params.clone();
negotiated_peer_info.protocol_version = init_response.protocol_version.clone();
peer.set_peer_info(negotiated_peer_info);
transport
.send(ServerJsonRpcMessage::response(
ServerResult::InitializeResult(init_response),
Expand Down
44 changes: 41 additions & 3 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use super::session::{
use crate::{
RoleServer,
model::{
ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, GetExtensions,
InitializeRequest, InitializedNotification, JsonRpcError, ProtocolVersion, RequestId,
ClientCapabilities, ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData,
GetExtensions, Implementation, InitializeRequest, InitializeRequestParams,
InitializedNotification, JsonRpcError, ProtocolVersion, RequestId,
},
serve_server,
service::serve_directly,
Expand Down Expand Up @@ -1243,10 +1244,17 @@ where
.map_err(internal_error_response("get service"))?;
match message {
ClientJsonRpcMessage::Request(mut request) => {
// Build a peer_info so context.protocol_version() works inside handlers.
// serve_directly skips the handshake and receives None by default, making
// protocol_version() always return None in stateless mode. We reconstruct it:
// - initialize requests: version comes from the request body params
// - all other requests: version comes from the MCP-Protocol-Version header
// (already validated above; absent header defaults to 2025-03-26)
let peer_info = Self::peer_info_for_stateless_request(&request, &part.headers);
request.request.extensions_mut().insert(part);
let (transport, mut receiver) =
OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
let service = serve_directly(service, transport, None);
let service = serve_directly(service, transport, peer_info);
tokio::spawn(async move {
// on service created
let _ = service.waiting().await;
Expand Down Expand Up @@ -1335,4 +1343,34 @@ where
}
Ok(accepted_response())
}

/// Build a `ClientInfo` (peer_info) for a stateless request so that
/// `context.protocol_version()` returns the correct value inside handlers.
///
/// `serve_directly` skips the MCP handshake and accepts `peer_info = None`,
/// which means `context.protocol_version()` is always `None` in stateless mode.
/// We reconstruct the protocol version from the available signal per request type:
/// - initialize: version is in the request body params (authoritative)
/// - all other requests: version is in the MCP-Protocol-Version header
/// (validated before this point; absent header defaults to 2025-03-26)
fn peer_info_for_stateless_request(
request: &crate::model::JsonRpcRequest<ClientRequest>,
headers: &HeaderMap,
) -> Option<InitializeRequestParams> {
let version = if let ClientRequest::InitializeRequest(ref init) = request.request {
init.params.protocol_version.clone()
} else {
headers
.get(HEADER_MCP_PROTOCOL_VERSION)
.and_then(|v| v.to_str().ok())
.and_then(|s| serde_json::from_value(serde_json::Value::String(s.to_owned())).ok())
.unwrap_or(ProtocolVersion::V_2025_03_26)
};
Some(InitializeRequestParams {
meta: None,
protocol_version: version,
capabilities: ClientCapabilities::default(),
client_info: Implementation::default(),
})
}
}
83 changes: 83 additions & 0 deletions crates/rmcp/tests/test_protocol_version_negotiation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//! Tests for protocol version negotiation in the default ServerHandler::initialize impl.
//!
//! Known versions are echoed back; unknown versions fall back to LATEST.
#![cfg(not(feature = "local"))]
#![cfg(feature = "client")]

use rmcp::{
ClientHandler, ServerHandler, ServiceExt,
model::{ClientInfo, ProtocolVersion, ServerInfo},
};

#[derive(Debug, Clone, Default)]
struct EchoServer;

impl ServerHandler for EchoServer {
fn get_info(&self) -> ServerInfo {
ServerInfo::default()
}
}

#[derive(Debug, Clone)]
struct VersionedClient {
protocol_version: ProtocolVersion,
}

impl ClientHandler for VersionedClient {
fn get_info(&self) -> ClientInfo {
let mut info = ClientInfo::default();
info.protocol_version = self.protocol_version.clone();
info
}
}

async fn negotiated_version(client_version: ProtocolVersion) -> ProtocolVersion {
let (server_transport, client_transport) = tokio::io::duplex(4096);

tokio::spawn(async move {
let _ = EchoServer
.serve(server_transport)
.await
.expect("server should start")
.waiting()
.await;
});

let client = VersionedClient {
protocol_version: client_version,
}
.serve(client_transport)
.await
.expect("client should connect");

let version = client
.peer_info()
.expect("peer_info should be set")
.protocol_version
.clone();

client.cancel().await.expect("client should cancel");
version
}

#[tokio::test]
async fn known_version_echoed_back() {
for version in ProtocolVersion::KNOWN_VERSIONS {
let negotiated = negotiated_version(version.clone()).await;
assert_eq!(
negotiated, *version,
"known version {version} should be echoed back"
);
}
}

#[tokio::test]
async fn unknown_version_falls_back_to_latest() {
let unknown: ProtocolVersion = serde_json::from_str(r#""1999-01-01""#).unwrap();
let negotiated = negotiated_version(unknown).await;
assert_eq!(
negotiated,
ProtocolVersion::LATEST,
"unknown version should fall back to LATEST"
);
}
99 changes: 99 additions & 0 deletions crates/rmcp/tests/test_stateless_protocol_version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//! Tests for protocol version negotiation in stateless HTTP mode.
//!
//! Known versions are echoed back; unknown versions fall back to LATEST.
#![cfg(not(feature = "local"))]

use rmcp::{
model::ProtocolVersion,
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use tokio_util::sync::CancellationToken;

mod common;
use common::calculator::Calculator;

fn stateless_json_config() -> StreamableHttpServerConfig {
StreamableHttpServerConfig::default()
.with_stateful_mode(false)
.with_json_response(true)
.with_sse_keep_alive(None)
.with_cancellation_token(CancellationToken::new())
}

async fn spawn_server(
config: StreamableHttpServerConfig,
) -> (reqwest::Client, String, CancellationToken) {
let ct = config.cancellation_token.clone();
let service: StreamableHttpService<Calculator, LocalSessionManager> =
StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config);

let router = axum::Router::new().nest_service("/mcp", service);
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = tcp_listener.local_addr().unwrap();

tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(tcp_listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});

(reqwest::Client::new(), format!("http://{addr}/mcp"), ct)
}

async fn post_init(client: &reqwest::Client, url: &str, body_version: &str) -> serde_json::Value {
let body = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": body_version,
"capabilities": {},
"clientInfo": {"name": "test", "version": "0.0.1"}
}
});
let resp = client
.post(url)
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(body.to_string())
.send()
.await
.expect("send request");
assert!(resp.status().is_success(), "HTTP {}", resp.status());
resp.json().await.expect("parse JSON")
}

#[tokio::test]
async fn stateless_init_echoes_known_version() {
let (client, url, ct) = spawn_server(stateless_json_config()).await;

for version in ProtocolVersion::KNOWN_VERSIONS {
let resp = post_init(&client, &url, version.as_str()).await;
assert_eq!(
resp["result"]["protocolVersion"],
version.as_str(),
"known version {version} should be echoed back"
);
}

ct.cancel();
}

#[tokio::test]
async fn stateless_init_unknown_version_falls_back_to_latest() {
let (client, url, ct) = spawn_server(stateless_json_config()).await;

let resp = post_init(&client, &url, "1999-01-01").await;
assert_eq!(
resp["result"]["protocolVersion"],
ProtocolVersion::LATEST.as_str(),
"unknown version should fall back to LATEST"
);

ct.cancel();
}