From 82699fd5dd4e091d409edfe5e0a3c35c5bddf22a Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Fri, 26 Jun 2026 12:46:10 -0400 Subject: [PATCH] feat!: add MRTR behavior support --- conformance/src/bin/server.rs | 21 +- crates/rmcp-macros/src/prompt_handler.rs | 2 +- crates/rmcp-macros/src/tool_handler.rs | 2 +- crates/rmcp/src/handler/server.rs | 32 +- crates/rmcp/src/handler/server/prompt.rs | 39 +- crates/rmcp/src/handler/server/router.rs | 6 +- .../rmcp/src/handler/server/router/prompt.rs | 7 +- crates/rmcp/src/handler/server/router/tool.rs | 18 +- crates/rmcp/src/handler/server/tool.rs | 65 +-- .../rmcp/src/handler/server/wrapper/json.rs | 9 +- crates/rmcp/src/model/mrtr.rs | 106 ++++- crates/rmcp/src/service.rs | 3 + crates/rmcp/src/service/client.rs | 381 +++++++++++++++++- crates/rmcp/tests/test_mrtr_behavior.rs | 166 ++++++++ crates/rmcp/tests/test_structured_output.rs | 14 +- 15 files changed, 777 insertions(+), 94 deletions(-) create mode 100644 crates/rmcp/tests/test_mrtr_behavior.rs diff --git a/conformance/src/bin/server.rs b/conformance/src/bin/server.rs index fda1ab938..078d76821 100644 --- a/conformance/src/bin/server.rs +++ b/conformance/src/bin/server.rs @@ -213,9 +213,9 @@ impl ServerHandler for ConformanceServer { &self, request: CallToolRequestParams, cx: RequestContext, - ) -> Result { + ) -> Result { let args = request.arguments.unwrap_or_default(); - match request.name.as_ref() { + let result = match request.name.as_ref() { "test_simple_text" => Ok(CallToolResult::success(vec![ContentBlock::text( "This is a simple text response for testing.", )])), @@ -530,7 +530,8 @@ impl ServerHandler for ConformanceServer { format!("Unknown tool: {}", request.name), None, )), - } + }; + result.map(Into::into) } async fn list_resources( @@ -555,9 +556,9 @@ impl ServerHandler for ConformanceServer { &self, request: ReadResourceRequestParams, _cx: RequestContext, - ) -> Result { + ) -> Result { let uri = request.uri.as_str(); - match uri { + let result = match uri { "test://static-text" => Ok(ReadResourceResult::new(vec![ ResourceContents::TextResourceContents { uri: uri.into(), @@ -598,7 +599,8 @@ impl ServerHandler for ConformanceServer { )) } } - } + }; + result.map(Into::into) } async fn list_resource_templates( @@ -679,8 +681,8 @@ impl ServerHandler for ConformanceServer { &self, request: GetPromptRequestParams, _cx: RequestContext, - ) -> Result { - match request.name.as_str() { + ) -> Result { + let result = match request.name.as_str() { "test_simple_prompt" => Ok(GetPromptResult::new(vec![PromptMessage::new_text( Role::User, "This is a simple test prompt.", @@ -724,7 +726,8 @@ impl ServerHandler for ConformanceServer { format!("Unknown prompt: {}", request.name), None, )), - } + }; + result.map(Into::into) } async fn complete( diff --git a/crates/rmcp-macros/src/prompt_handler.rs b/crates/rmcp-macros/src/prompt_handler.rs index af4f24bd8..24032eaab 100644 --- a/crates/rmcp-macros/src/prompt_handler.rs +++ b/crates/rmcp-macros/src/prompt_handler.rs @@ -35,7 +35,7 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> syn::Result, - ) -> Result { + ) -> Result { let prompt_context = rmcp::handler::server::prompt::PromptContext::new( self, request.name, diff --git a/crates/rmcp-macros/src/tool_handler.rs b/crates/rmcp-macros/src/tool_handler.rs index dc935828d..1e39eb5f3 100644 --- a/crates/rmcp-macros/src/tool_handler.rs +++ b/crates/rmcp-macros/src/tool_handler.rs @@ -47,7 +47,7 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> syn::Result, - ) -> Result { + ) -> Result { let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); #router.call(tcc).await } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 0fb4bf891..43acf1eff 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -26,6 +26,9 @@ impl Service for H { ) -> Result<::Resp, McpError> { // `context` is moved into the dispatch below, so read the negotiated version first. let protocol_version = context.protocol_version(); + let mrtr_supported = protocol_version + .as_ref() + .is_some_and(|v| v.as_str() >= ProtocolVersion::V_2026_07_28.as_str()); let result = match request { ClientRequest::InitializeRequest(request) => self .initialize(request.params, context) @@ -45,7 +48,7 @@ impl Service for H { ClientRequest::GetPromptRequest(request) => self .get_prompt(request.params, context) .await - .map(ServerResult::GetPromptResult), + .map(ServerResult::from), ClientRequest::ListPromptsRequest(request) => self .list_prompts(request.params, context) .await @@ -61,7 +64,7 @@ impl Service for H { ClientRequest::ReadResourceRequest(request) => self .read_resource(request.params, context) .await - .map(ServerResult::ReadResourceResult), + .map(ServerResult::from), ClientRequest::SubscribeRequest(request) => self .subscribe(request.params, context) .await @@ -104,7 +107,7 @@ impl Service for H { } else { self.call_tool(request.params, context) .await - .map(ServerResult::CallToolResult) + .map(ServerResult::from) } } ClientRequest::ListToolsRequest(request) => self @@ -132,6 +135,17 @@ impl Service for H { .await .map(ServerResult::CancelTaskResult), }; + let result = result.and_then(|result| { + if matches!(result, ServerResult::InputRequiredResult(_)) && !mrtr_supported { + Err(McpError::invalid_request( + "InputRequiredResult requires negotiated protocol version 2026-07-28 or newer", + None, + )) + } else { + Ok(result) + } + }); + // SEP-2164: peers negotiating 2026-07-28+ get the standard INVALID_PARAMS code for // resource-not-found; older peers keep RESOURCE_NOT_FOUND. ISO `YYYY-MM-DD` versions // compare lexically the same as chronologically. @@ -223,7 +237,7 @@ macro_rules! server_handler_methods { &self, request: GetPromptRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { std::future::ready(Err(McpError::method_not_found::())) } fn list_prompts( @@ -253,7 +267,7 @@ macro_rules! server_handler_methods { &self, request: ReadResourceRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { std::future::ready(Err( McpError::method_not_found::(), )) @@ -306,7 +320,7 @@ macro_rules! server_handler_methods { &self, request: CallToolRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { std::future::ready(Err(McpError::method_not_found::())) } fn list_tools( @@ -479,7 +493,7 @@ macro_rules! impl_server_handler_for_wrapper { &self, request: GetPromptRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { (**self).get_prompt(request, context) } @@ -512,7 +526,7 @@ macro_rules! impl_server_handler_for_wrapper { &self, request: ReadResourceRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { (**self).read_resource(request, context) } @@ -536,7 +550,7 @@ macro_rules! impl_server_handler_for_wrapper { &self, request: CallToolRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { (**self).call_tool(request, context) } diff --git a/crates/rmcp/src/handler/server/prompt.rs b/crates/rmcp/src/handler/server/prompt.rs index ffce6b2e0..a75e02713 100644 --- a/crates/rmcp/src/handler/server/prompt.rs +++ b/crates/rmcp/src/handler/server/prompt.rs @@ -15,7 +15,7 @@ pub use super::common::{Extension, RequestId}; use crate::{ RoleServer, handler::server::wrapper::Parameters, - model::{GetPromptResult, PromptMessage}, + model::{GetPromptResponse, GetPromptResult, InputRequiredResult, PromptMessage}, service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext}, }; @@ -59,12 +59,12 @@ pub trait GetPromptHandler { fn handle( self, context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result>; + ) -> MaybeBoxFuture<'_, Result>; } /// Type alias for dynamic prompt handlers #[cfg(not(feature = "local"))] -pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> +pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + Send + Sync; @@ -73,7 +73,7 @@ pub type DynGetPromptHandler = dyn for<'a> Fn( PromptContext<'a, S>, ) -> futures::future::LocalBoxFuture< 'a, - Result, + Result, >; /// Adapter type for async methods that return `Vec` @@ -91,28 +91,35 @@ pub struct SyncPromptMethodAdapter(PhantomData R>); /// Trait for types that can be converted into GetPromptResult pub trait IntoGetPromptResult { - fn into_get_prompt_result(self) -> Result; + fn into_get_prompt_result(self) -> Result; } impl IntoGetPromptResult for GetPromptResult { - fn into_get_prompt_result(self) -> Result { - Ok(self) + fn into_get_prompt_result(self) -> Result { + Ok(self.into()) + } +} + +impl IntoGetPromptResult for InputRequiredResult { + fn into_get_prompt_result(self) -> Result { + Ok(self.into()) } } impl IntoGetPromptResult for Vec { - fn into_get_prompt_result(self) -> Result { + fn into_get_prompt_result(self) -> Result { Ok(GetPromptResult { result_type: Default::default(), description: None, messages: self, meta: None, - }) + } + .into()) } } impl IntoGetPromptResult for Result { - fn into_get_prompt_result(self) -> Result { + fn into_get_prompt_result(self) -> Result { self.and_then(|v| v.into_get_prompt_result()) } } @@ -129,7 +136,7 @@ pin_project_lite::pin_project! { }, Ready { #[pin] - result: futures::future::Ready>, + result: futures::future::Ready>, } } } @@ -139,7 +146,7 @@ where F: Future, R: IntoGetPromptResult, { - type Output = Result; + type Output = Result; fn poll( self: std::pin::Pin<&mut Self>, @@ -216,7 +223,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { $( let result = $Tn::from_context_part(&mut context); @@ -249,7 +256,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { $( let result = $Tn::from_context_part(&mut context); @@ -280,7 +287,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { // Extract all parameters before moving into the async block $( @@ -315,7 +322,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { $( let result = $Tn::from_context_part(&mut context); diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 45ff9a586..e934137b8 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -106,7 +106,7 @@ where context, ); let result = self.tool_router.call(tool_call_context).await?; - Ok(ServerResult::CallToolResult(result)) + Ok(ServerResult::from(result)) } else { self.service .handle_request(ClientRequest::CallToolRequest(request), context) @@ -129,7 +129,7 @@ where context, ); let result = self.prompt_router.get_prompt(prompt_context).await?; - Ok(ServerResult::GetPromptResult(result)) + Ok(ServerResult::from(result)) } else { self.service .handle_request(ClientRequest::GetPromptRequest(request), context) @@ -193,7 +193,7 @@ mod tests { async fn test_router_deferred_notifier_e2e() { let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn( Tool::new("my_tool", "test", Arc::new(Default::default())), - |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + |_ctx| Box::pin(async { Ok(CallToolResult::default().into()) }), )); let id_provider: Arc = diff --git a/crates/rmcp/src/handler/server/router/prompt.rs b/crates/rmcp/src/handler/server/router/prompt.rs index e952b2a39..509fb4287 100644 --- a/crates/rmcp/src/handler/server/router/prompt.rs +++ b/crates/rmcp/src/handler/server/router/prompt.rs @@ -2,7 +2,7 @@ use std::{borrow::Cow, sync::Arc}; use crate::{ handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext}, - model::{GetPromptResult, Prompt}, + model::{GetPromptResponse, Prompt}, service::{MaybeBoxFuture, MaybeSend}, }; @@ -50,7 +50,8 @@ impl PromptRoute { where H: for<'a> Fn( PromptContext<'a, S>, - ) -> MaybeBoxFuture<'a, Result> + ) + -> MaybeBoxFuture<'a, Result> + MaybeSend + 'static, { @@ -175,7 +176,7 @@ where pub async fn get_prompt( &self, context: PromptContext<'_, S>, - ) -> Result { + ) -> Result { let item = self.map.get(context.name.as_str()).ok_or_else(|| { crate::ErrorData::invalid_params( format!("prompt '{}' not found", context.name), diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index dece66d95..215116250 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -137,21 +137,19 @@ use crate::{ tool::{CallToolHandler, DynCallToolHandler, ToolCallContext}, tool_name_validation::validate_and_warn_tool_name, }, - model::{CallToolResult, ContentBlock, ErrorCode, Tool, ToolAnnotations}, + model::{CallToolResponse, CallToolResult, ContentBlock, ErrorCode, Tool, ToolAnnotations}, service::{MaybeBoxFuture, MaybeSend}, }; const TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX: &str = "failed to deserialize parameters:"; -fn into_tool_argument_error(error: crate::ErrorData) -> Result { +fn into_tool_argument_error(error: crate::ErrorData) -> Result { if error.code == ErrorCode::INVALID_PARAMS && error .message .starts_with(TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX) { - return Ok(CallToolResult::error(vec![ContentBlock::text( - error.message, - )])); + return Ok(CallToolResult::error(vec![ContentBlock::text(error.message)]).into()); } Err(error) @@ -200,7 +198,8 @@ impl ToolRoute { where C: for<'a> Fn( ToolCallContext<'a, S>, - ) -> MaybeBoxFuture<'a, Result> + ) + -> MaybeBoxFuture<'a, Result> + MaybeSend + 'static, { @@ -561,7 +560,7 @@ where pub async fn call( &self, context: ToolCallContext<'_, S>, - ) -> Result { + ) -> Result { let name = context.name(); if self.disabled.contains(name) { return Err(crate::ErrorData::invalid_params("tool not found", None)); @@ -679,6 +678,9 @@ mod tests { .call(ctx) .await .expect("argument validation should be a tool result"); + let CallToolResponse::Complete(result) = result else { + panic!("expected complete CallToolResult"); + }; assert_eq!(result.is_error, Some(true)); let text = result @@ -696,7 +698,7 @@ mod tests { let service = DummyService; let mut router = ToolRouter::new().with_route(ToolRoute::new_dyn( crate::model::Tool::new("test_tool", "a test tool", Arc::new(Default::default())), - |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + |_ctx| Box::pin(async { Ok(CallToolResult::default().into()) }), )); router.disable_route("test_tool"); diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index bf350797d..cb4966df0 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -16,7 +16,10 @@ pub use super::{ use crate::{ RoleServer, handler::server::wrapper::Parameters, - model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject}, + model::{ + CallToolRequestParams, CallToolResponse, CallToolResult, InputRequiredResult, IntoContents, + JsonObject, + }, service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext}, }; @@ -77,36 +80,46 @@ impl AsRequestContext for ToolCallContext<'_, S> { } pub trait IntoCallToolResult { - fn into_call_tool_result(self) -> Result; + fn into_call_tool_result(self) -> Result; } impl IntoCallToolResult for T { - fn into_call_tool_result(self) -> Result { - Ok(CallToolResult::success(self.into_contents())) + fn into_call_tool_result(self) -> Result { + Ok(CallToolResult::success(self.into_contents()).into()) } } impl IntoCallToolResult for CallToolResult { - fn into_call_tool_result(self) -> Result { - Ok(self) + fn into_call_tool_result(self) -> Result { + Ok(self.into()) + } +} + +impl IntoCallToolResult for InputRequiredResult { + fn into_call_tool_result(self) -> Result { + Ok(self.into()) } } impl IntoCallToolResult for crate::ErrorData { - fn into_call_tool_result(self) -> Result { + fn into_call_tool_result(self) -> Result { Err(self) } } impl IntoCallToolResult for Result { - fn into_call_tool_result(self) -> Result { + fn into_call_tool_result(self) -> Result { match self { Ok(value) => value.into_call_tool_result(), Err(error) => match error.into_call_tool_result() { - Ok(mut result) => { + Ok(CallToolResponse::Complete(mut result)) => { result.is_error = Some(true); - Ok(result) + Ok(result.into()) } + Ok(CallToolResponse::InputRequired(_)) => Err(crate::ErrorData::internal_error( + "InputRequiredResult cannot be returned from a tool error branch", + None, + )), Err(e) => Err(e), }, } @@ -124,7 +137,7 @@ pin_project_lite::pin_project! { }, Ready { #[pin] - result: Ready>, + result: Ready>, } } } @@ -134,7 +147,7 @@ where F: Future, R: IntoCallToolResult, { - type Output = Result; + type Output = Result; fn poll( self: std::pin::Pin<&mut Self>, @@ -153,20 +166,21 @@ pub trait CallToolHandler { fn call( self, context: ToolCallContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result>; + ) -> MaybeBoxFuture<'_, Result>; } #[cfg(not(feature = "local"))] -pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> +pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> + Send + Sync; #[cfg(feature = "local")] -pub type DynCallToolHandler = - dyn for<'s> Fn( - ToolCallContext<'s, S>, - ) - -> futures::future::LocalBoxFuture<'s, Result>; +pub type DynCallToolHandler = dyn for<'s> Fn( + ToolCallContext<'s, S>, +) -> futures::future::LocalBoxFuture< + 's, + Result, +>; // Tool-specific extractor for tool name #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] @@ -205,7 +219,10 @@ impl FromContextPart> for JsonObject { } impl<'s, S> ToolCallContext<'s, S> { - pub fn invoke(self, h: H) -> MaybeBoxFuture<'s, Result> + pub fn invoke( + self, + h: H, + ) -> MaybeBoxFuture<'s, Result> where H: CallToolHandler, { @@ -248,7 +265,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result>{ + ) -> MaybeBoxFuture<'_, Result>{ $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { @@ -279,7 +296,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext, - ) -> MaybeBoxFuture<'static, Result>{ + ) -> MaybeBoxFuture<'static, Result>{ $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { @@ -308,7 +325,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext, - ) -> MaybeBoxFuture<'static, Result> { + ) -> MaybeBoxFuture<'static, Result> { $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { @@ -333,7 +350,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext, - ) -> MaybeBoxFuture<'static, Result> { + ) -> MaybeBoxFuture<'static, Result> { $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { diff --git a/crates/rmcp/src/handler/server/wrapper/json.rs b/crates/rmcp/src/handler/server/wrapper/json.rs index c03fbd032..7c5297963 100644 --- a/crates/rmcp/src/handler/server/wrapper/json.rs +++ b/crates/rmcp/src/handler/server/wrapper/json.rs @@ -3,7 +3,10 @@ use std::borrow::Cow; use schemars::JsonSchema; use serde::Serialize; -use crate::{handler::server::tool::IntoCallToolResult, model::CallToolResult}; +use crate::{ + handler::server::tool::IntoCallToolResult, + model::{CallToolResponse, CallToolResult}, +}; /// Json wrapper for structured output /// @@ -27,7 +30,7 @@ impl JsonSchema for Json { // Implementation for Json to create structured content impl IntoCallToolResult for Json { - fn into_call_tool_result(self) -> Result { + fn into_call_tool_result(self) -> Result { let value = serde_json::to_value(self.0).map_err(|e| { crate::ErrorData::internal_error( format!("Failed to serialize structured content: {}", e), @@ -35,6 +38,6 @@ impl IntoCallToolResult for Json { ) })?; - Ok(CallToolResult::structured(value)) + Ok(CallToolResult::structured(value).into()) } } diff --git a/crates/rmcp/src/model/mrtr.rs b/crates/rmcp/src/model/mrtr.rs index e4a5b3fb8..657ceb64b 100644 --- a/crates/rmcp/src/model/mrtr.rs +++ b/crates/rmcp/src/model/mrtr.rs @@ -16,7 +16,16 @@ use std::collections::BTreeMap; use serde::{Deserialize, Serialize}; use serde_json::Value; -use super::{CreateMessageRequest, ElicitRequest, ListRootsRequest, Meta, ResultType}; +use super::{ + CallToolResult, CreateMessageRequest, ElicitRequest, GetPromptResult, ListRootsRequest, Meta, + ReadResourceResult, ResultType, ServerResult, +}; + +/// Default maximum number of MRTR rounds a high-level client call will drive. +/// +/// This matches the default used by other Tier 1 SDKs and prevents a +/// misbehaving peer from keeping a request alive indefinitely. +pub const DEFAULT_MRTR_MAX_ROUNDS: usize = 10; /// A server-initiated request that can appear inside [`InputRequests`]. /// @@ -53,6 +62,101 @@ pub type InputRequests = BTreeMap; /// for use as a `BTreeMap` value. pub type InputResponses = BTreeMap; +/// Result of a `tools/call` request, including the MRTR intermediate result. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum CallToolResponse { + /// The server completed the tool call. + Complete(CallToolResult), + /// The server requires client-side input before the tool call can complete. + InputRequired(InputRequiredResult), +} + +impl From for CallToolResponse { + fn from(result: CallToolResult) -> Self { + Self::Complete(result) + } +} + +impl From for CallToolResponse { + fn from(result: InputRequiredResult) -> Self { + Self::InputRequired(result) + } +} + +impl From for ServerResult { + fn from(response: CallToolResponse) -> Self { + match response { + CallToolResponse::Complete(result) => ServerResult::CallToolResult(result), + CallToolResponse::InputRequired(result) => ServerResult::InputRequiredResult(result), + } + } +} + +/// Result of a `prompts/get` request, including the MRTR intermediate result. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum GetPromptResponse { + /// The server completed the prompt request. + Complete(GetPromptResult), + /// The server requires client-side input before the prompt can be returned. + InputRequired(InputRequiredResult), +} + +impl From for GetPromptResponse { + fn from(result: GetPromptResult) -> Self { + Self::Complete(result) + } +} + +impl From for GetPromptResponse { + fn from(result: InputRequiredResult) -> Self { + Self::InputRequired(result) + } +} + +impl From for ServerResult { + fn from(response: GetPromptResponse) -> Self { + match response { + GetPromptResponse::Complete(result) => ServerResult::GetPromptResult(result), + GetPromptResponse::InputRequired(result) => ServerResult::InputRequiredResult(result), + } + } +} + +/// Result of a `resources/read` request, including the MRTR intermediate result. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum ReadResourceResponse { + /// The server completed the resource read. + Complete(ReadResourceResult), + /// The server requires client-side input before the resource can be returned. + InputRequired(InputRequiredResult), +} + +impl From for ReadResourceResponse { + fn from(result: ReadResourceResult) -> Self { + Self::Complete(result) + } +} + +impl From for ReadResourceResponse { + fn from(result: InputRequiredResult) -> Self { + Self::InputRequired(result) + } +} + +impl From for ServerResult { + fn from(response: ReadResourceResponse) -> Self { + match response { + ReadResourceResponse::Complete(result) => ServerResult::ReadResourceResult(result), + ReadResourceResponse::InputRequired(result) => { + ServerResult::InputRequiredResult(result) + } + } + } +} + /// A result indicating that additional input is needed before the request /// can be completed. /// diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 29d822a58..991df3b3b 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -86,6 +86,9 @@ pub enum ServiceError { Cancelled { reason: Option }, #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())] Timeout { timeout: Duration }, + /// The peer kept returning `input_required` beyond the configured round cap. + #[error("input_required did not complete within {max_rounds} MRTR rounds")] + InputRequiredRoundsExceeded { max_rounds: usize }, } trait TransferObject: diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 05c2749fe..84541799b 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -1,24 +1,26 @@ // Sampling/Roots/Logging are SEP-2577-deprecated; internal references are expected. #![expect(deprecated)] -use std::borrow::Cow; +use std::{borrow::Cow, sync::Arc, time::Duration}; use thiserror::Error; use super::*; use crate::{ model::{ - ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult, + ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResponse, CallToolResult, CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams, - CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest, - GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification, - JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, - ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, - ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam, - ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId, - RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, - ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, - SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, + CompleteResult, CompletionContext, CompletionInfo, DEFAULT_MRTR_MAX_ROUNDS, ErrorData, + GetExtensions, GetMeta, GetPromptRequest, GetPromptRequestParams, GetPromptResponse, + GetPromptResult, InitializeRequest, InitializedNotification, InputRequest, + InputRequiredResult, InputResponses, JsonRpcResponse, ListPromptsRequest, + ListPromptsResult, ListResourceTemplatesRequest, ListResourceTemplatesResult, + ListResourcesRequest, ListResourcesResult, ListToolsRequest, ListToolsResult, + NumberOrString, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam, + ReadResourceRequest, ReadResourceRequestParams, ReadResourceResponse, ReadResourceResult, + Reference, RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, + ServerNotification, ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, + SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, }, transport::DynamicTransportError, }; @@ -361,6 +363,72 @@ macro_rules! method { } impl Peer { + /// Send one `tools/call` request and return either a final result or an MRTR + /// `InputRequiredResult` without driving any follow-up rounds. + pub async fn call_tool_once( + &self, + params: CallToolRequestParams, + ) -> Result { + let result = self + .send_request(ClientRequest::CallToolRequest(CallToolRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::CallToolResult(result) => Ok(CallToolResponse::Complete(result)), + ServerResult::InputRequiredResult(result) => { + Ok(CallToolResponse::InputRequired(result)) + } + _ => Err(ServiceError::UnexpectedResponse), + } + } + + /// Send one `prompts/get` request and return either a final result or an MRTR + /// `InputRequiredResult` without driving any follow-up rounds. + pub async fn get_prompt_once( + &self, + params: GetPromptRequestParams, + ) -> Result { + let result = self + .send_request(ClientRequest::GetPromptRequest(GetPromptRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::GetPromptResult(result) => Ok(GetPromptResponse::Complete(result)), + ServerResult::InputRequiredResult(result) => { + Ok(GetPromptResponse::InputRequired(result)) + } + _ => Err(ServiceError::UnexpectedResponse), + } + } + + /// Send one `resources/read` request and return either a final result or an + /// MRTR `InputRequiredResult` without driving any follow-up rounds. + pub async fn read_resource_once( + &self, + params: ReadResourceRequestParams, + ) -> Result { + let result = self + .send_request(ClientRequest::ReadResourceRequest(ReadResourceRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::ReadResourceResult(result) => Ok(ReadResourceResponse::Complete(result)), + ServerResult::InputRequiredResult(result) => { + Ok(ReadResourceResponse::InputRequired(result)) + } + _ => Err(ServiceError::UnexpectedResponse), + } + } + method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult); method!( #[deprecated( @@ -558,3 +626,294 @@ impl Peer { Ok(completion.values) } } + +impl RunningService +where + S: Service, +{ + /// Send one `tools/call` request without driving MRTR follow-up rounds. + pub async fn call_tool_once( + &self, + params: CallToolRequestParams, + ) -> Result { + self.peer.call_tool_once(params).await + } + + /// Send one `prompts/get` request without driving MRTR follow-up rounds. + pub async fn get_prompt_once( + &self, + params: GetPromptRequestParams, + ) -> Result { + self.peer.get_prompt_once(params).await + } + + /// Send one `resources/read` request without driving MRTR follow-up rounds. + pub async fn read_resource_once( + &self, + params: ReadResourceRequestParams, + ) -> Result { + self.peer.read_resource_once(params).await + } + + /// High-level `tools/call` helper that automatically fulfils SEP-2322 + /// `input_required` rounds through the local [`ClientHandler`] service. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] if the peer does + /// not produce a final [`CallToolResult`] within the default MRTR round cap. + /// Other transport, protocol, and local input-handler errors are propagated. + pub async fn call_tool( + &self, + params: CallToolRequestParams, + ) -> Result { + self.call_tool_with_mrtr_max_rounds(params, DEFAULT_MRTR_MAX_ROUNDS) + .await + } + + /// Same as [`Self::call_tool`], with an explicit MRTR round cap. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] once `max_rounds` + /// `input_required` responses have been driven without receiving a final + /// [`CallToolResult`]. Other transport, protocol, and local input-handler + /// errors are propagated. + pub async fn call_tool_with_mrtr_max_rounds( + &self, + mut params: CallToolRequestParams, + max_rounds: usize, + ) -> Result { + let mut state_only_rounds = 0usize; + for _round in 0..max_rounds { + match self.peer.call_tool_once(params.clone()).await? { + CallToolResponse::Complete(result) => return Ok(result), + CallToolResponse::InputRequired(result) => { + let (input_responses, request_state) = self + .prepare_input_required_retry(result, &mut state_only_rounds) + .await?; + params.input_responses = input_responses; + params.request_state = request_state; + } + } + } + Err(ServiceError::InputRequiredRoundsExceeded { max_rounds }) + } + + /// High-level `prompts/get` helper that automatically fulfils SEP-2322 + /// `input_required` rounds through the local [`ClientHandler`] service. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] if the peer does + /// not produce a final [`GetPromptResult`] within the default MRTR round cap. + /// Other transport, protocol, and local input-handler errors are propagated. + pub async fn get_prompt( + &self, + params: GetPromptRequestParams, + ) -> Result { + self.get_prompt_with_mrtr_max_rounds(params, DEFAULT_MRTR_MAX_ROUNDS) + .await + } + + /// Same as [`Self::get_prompt`], with an explicit MRTR round cap. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] once `max_rounds` + /// `input_required` responses have been driven without receiving a final + /// [`GetPromptResult`]. Other transport, protocol, and local input-handler + /// errors are propagated. + pub async fn get_prompt_with_mrtr_max_rounds( + &self, + mut params: GetPromptRequestParams, + max_rounds: usize, + ) -> Result { + let mut state_only_rounds = 0usize; + for _round in 0..max_rounds { + match self.peer.get_prompt_once(params.clone()).await? { + GetPromptResponse::Complete(result) => return Ok(result), + GetPromptResponse::InputRequired(result) => { + let (input_responses, request_state) = self + .prepare_input_required_retry(result, &mut state_only_rounds) + .await?; + params.input_responses = input_responses; + params.request_state = request_state; + } + } + } + Err(ServiceError::InputRequiredRoundsExceeded { max_rounds }) + } + + /// High-level `resources/read` helper that automatically fulfils SEP-2322 + /// `input_required` rounds through the local [`ClientHandler`] service. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] if the peer does + /// not produce a final [`ReadResourceResult`] within the default MRTR round + /// cap. Other transport, protocol, and local input-handler errors are + /// propagated. + pub async fn read_resource( + &self, + params: ReadResourceRequestParams, + ) -> Result { + self.read_resource_with_mrtr_max_rounds(params, DEFAULT_MRTR_MAX_ROUNDS) + .await + } + + /// Same as [`Self::read_resource`], with an explicit MRTR round cap. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] once `max_rounds` + /// `input_required` responses have been driven without receiving a final + /// [`ReadResourceResult`]. Other transport, protocol, and local input-handler + /// errors are propagated. + pub async fn read_resource_with_mrtr_max_rounds( + &self, + mut params: ReadResourceRequestParams, + max_rounds: usize, + ) -> Result { + let mut state_only_rounds = 0usize; + for _round in 0..max_rounds { + match self.peer.read_resource_once(params.clone()).await? { + ReadResourceResponse::Complete(result) => return Ok(result), + ReadResourceResponse::InputRequired(result) => { + let (input_responses, request_state) = self + .prepare_input_required_retry(result, &mut state_only_rounds) + .await?; + params.input_responses = input_responses; + params.request_state = request_state; + } + } + } + Err(ServiceError::InputRequiredRoundsExceeded { max_rounds }) + } + + async fn prepare_input_required_retry( + &self, + result: InputRequiredResult, + state_only_rounds: &mut usize, + ) -> Result<(Option, Option), ServiceError> { + let had_input_requests = result + .input_requests + .as_ref() + .is_some_and(|requests| !requests.is_empty()); + if !had_input_requests && result.request_state.is_none() { + return Err(ServiceError::UnexpectedResponse); + } + + let responses = self + .fulfill_input_requests(result.input_requests.unwrap_or_default()) + .await?; + if had_input_requests { + *state_only_rounds = 0; + } else { + Self::sleep_state_only_round(*state_only_rounds).await; + *state_only_rounds += 1; + } + + Ok(( + (!responses.is_empty()).then_some(responses), + result.request_state, + )) + } + + async fn fulfill_input_requests( + &self, + requests: crate::model::InputRequests, + ) -> Result { + let responses = futures::future::try_join_all( + requests + .into_iter() + .map(|(key, request)| self.fulfill_input_request(key, request)), + ) + .await?; + Ok(responses.into_iter().collect()) + } + + async fn fulfill_input_request( + &self, + key: String, + request: InputRequest, + ) -> Result<(String, serde_json::Value), ServiceError> { + let response = match request { + InputRequest::CreateMessage(request) => { + let mut request = ServerRequest::CreateMessageRequest(request); + let context = self.input_request_context(&key, &mut request); + match self + .service + .handle_request(request, context) + .await + .map_err(ServiceError::McpError)? + { + ClientResult::CreateMessageResult(result) => { + serde_json::to_value(result).map_err(Self::serde_to_service_error)? + } + _ => return Err(ServiceError::UnexpectedResponse), + } + } + InputRequest::Elicitation(request) => { + let mut request = ServerRequest::ElicitRequest(request); + let context = self.input_request_context(&key, &mut request); + match self + .service + .handle_request(request, context) + .await + .map_err(ServiceError::McpError)? + { + ClientResult::ElicitResult(result) => { + serde_json::to_value(result).map_err(Self::serde_to_service_error)? + } + _ => return Err(ServiceError::UnexpectedResponse), + } + } + InputRequest::ListRoots(request) => { + let mut request = ServerRequest::ListRootsRequest(request); + let context = self.input_request_context(&key, &mut request); + match self + .service + .handle_request(request, context) + .await + .map_err(ServiceError::McpError)? + { + ClientResult::ListRootsResult(result) => { + serde_json::to_value(result).map_err(Self::serde_to_service_error)? + } + _ => return Err(ServiceError::UnexpectedResponse), + } + } + }; + Ok((key, response)) + } + + fn input_request_context(&self, key: &str, request: &mut T) -> RequestContext + where + T: GetMeta + GetExtensions, + { + let mut meta = Default::default(); + let mut extensions = Default::default(); + std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); + RequestContext { + ct: tokio_util::sync::CancellationToken::new(), + id: NumberOrString::String(Arc::from(key)), + peer: self.peer.clone(), + meta, + extensions, + } + } + + async fn sleep_state_only_round(state_only_rounds: usize) { + let millis = (50u64.saturating_mul(1_u64 << state_only_rounds.min(3))).min(250); + tokio::time::sleep(Duration::from_millis(millis)).await; + } + + fn serde_to_service_error(error: serde_json::Error) -> ServiceError { + ServiceError::McpError(ErrorData::internal_error( + format!("failed to serialize MRTR input response: {error}"), + None, + )) + } +} diff --git a/crates/rmcp/tests/test_mrtr_behavior.rs b/crates/rmcp/tests/test_mrtr_behavior.rs new file mode 100644 index 000000000..b0cfef0e6 --- /dev/null +++ b/crates/rmcp/tests/test_mrtr_behavior.rs @@ -0,0 +1,166 @@ +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use rmcp::{ + ClientHandler, ServerHandler, + model::*, + service::{RequestContext, RoleClient, RoleServer, serve_directly}, +}; +use serde_json::json; + +#[derive(Clone, Default)] +struct MrtrServer { + calls: Arc, +} + +impl ServerHandler for MrtrServer { + fn get_info(&self) -> ServerInfo { + let mut info = ServerInfo::new(ServerCapabilities::builder().enable_tools().build()); + info.protocol_version = ProtocolVersion::V_2026_07_28; + info + } + + async fn call_tool( + &self, + request: CallToolRequestParams, + _context: RequestContext, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + + if let Some(input_responses) = request.input_responses { + assert_eq!(request.request_state.as_deref(), Some("opaque-state")); + let answer = input_responses + .get("answer") + .expect("answer input response should be echoed"); + assert_eq!(answer["action"], "accept"); + assert_eq!(answer["content"]["name"], "Ferris"); + return Ok(CallToolResult::success(vec![ContentBlock::text("done")]).into()); + } + + let mut input_requests = InputRequests::new(); + input_requests.insert( + "answer".to_string(), + InputRequest::Elicitation(ElicitRequest::new( + ElicitRequestParams::FormElicitationParams { + meta: None, + message: "Name?".into(), + requested_schema: serde_json::from_value(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "required": ["name"] + })) + .unwrap(), + }, + )), + ); + Ok(InputRequiredResult::new(Some(input_requests), Some("opaque-state".into())).into()) + } +} + +#[derive(Clone, Default)] +struct MrtrClient; + +impl ClientHandler for MrtrClient { + async fn create_elicitation( + &self, + _request: ElicitRequestParams, + _context: RequestContext, + ) -> Result { + Ok( + ElicitResult::new(ElicitationAction::Accept).with_content(json!({ + "name": "Ferris" + })), + ) + } +} + +fn client_info_2026() -> ClientInfo { + ClientInfo::new( + ClientCapabilities::builder().enable_elicitation().build(), + Implementation::new("mrtr-test-client", "0.0.0"), + ) + .with_protocol_version(ProtocolVersion::V_2026_07_28) +} + +fn server_info_2026() -> ServerInfo { + let mut info = ServerInfo::new(ServerCapabilities::builder().enable_tools().build()); + info.protocol_version = ProtocolVersion::V_2026_07_28; + info +} + +#[tokio::test(flavor = "current_thread")] +async fn client_auto_fulfills_input_required_tool_call() -> anyhow::Result<()> { + tokio::task::LocalSet::new() + .run_until(async { + let (server_transport, client_transport) = tokio::io::duplex(8192); + let server = MrtrServer::default(); + let calls = server.calls.clone(); + + let server_task = tokio::task::spawn_local(async move { + let running = serve_directly::( + server, + server_transport, + Some(client_info_2026()), + ); + running.waiting().await?; + anyhow::Ok(()) + }); + + let client = serve_directly::( + MrtrClient, + client_transport, + Some(server_info_2026()), + ); + + let result = client + .call_tool(CallToolRequestParams::new("needs_input")) + .await?; + assert_eq!(result.content.len(), 1); + assert_eq!(result.content[0].as_text().unwrap().text, "done"); + assert_eq!(calls.load(Ordering::SeqCst), 2); + + drop(client); + server_task.abort(); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn manual_once_returns_input_required_without_retry() -> anyhow::Result<()> { + tokio::task::LocalSet::new() + .run_until(async { + let (server_transport, client_transport) = tokio::io::duplex(8192); + let server = MrtrServer::default(); + + let server_task = tokio::task::spawn_local(async move { + let running = serve_directly::( + server, + server_transport, + Some(client_info_2026()), + ); + running.waiting().await?; + anyhow::Ok(()) + }); + + let client = serve_directly::( + MrtrClient, + client_transport, + Some(server_info_2026()), + ); + + let result = client + .call_tool_once(CallToolRequestParams::new("needs_input")) + .await?; + assert!(matches!(result, CallToolResponse::InputRequired(_))); + + drop(client); + server_task.abort(); + Ok(()) + }) + .await +} diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index bb0d5e029..513c5ee5e 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -3,7 +3,7 @@ use rmcp::{ Json, ServerHandler, handler::server::{router::tool::ToolRouter, tool::IntoCallToolResult, wrapper::Parameters}, - model::{CallToolResult, ContentBlock, ServerResult, Tool}, + model::{CallToolResponse, CallToolResult, ContentBlock, ServerResult, Tool}, tool, tool_handler, tool_router, }; use schemars::JsonSchema; @@ -224,11 +224,13 @@ async fn test_structured_return_conversion() { }; let structured = Json(calc_result); - let result: Result = + let result: Result = rmcp::handler::server::tool::IntoCallToolResult::into_call_tool_result(structured); assert!(result.is_ok()); - let call_result = result.unwrap(); + let CallToolResponse::Complete(call_result) = result.unwrap() else { + panic!("expected complete CallToolResult"); + }; // Tools which return structured content should also return a serialized version as // Content::text for backwards compatibility. @@ -285,11 +287,13 @@ async fn test_output_schema_requires_structured_content() { let result = server.calculate(params).await.unwrap(); // Convert the Json to CallToolResult - let call_result: Result = + let call_result: Result = IntoCallToolResult::into_call_tool_result(result); assert!(call_result.is_ok()); - let call_result = call_result.unwrap(); + let CallToolResponse::Complete(call_result) = call_result.unwrap() else { + panic!("expected complete CallToolResult"); + }; // Verify it has structured_content and content assert!(call_result.structured_content.is_some());