diff --git a/md/SUMMARY.md b/md/SUMMARY.md index 6750364..d5cf434 100644 --- a/md/SUMMARY.md +++ b/md/SUMMARY.md @@ -6,6 +6,7 @@ - [Design Overview](./design.md) - [Protocol Reference](./protocol.md) +- [Request Cancellation](./request-cancellation.md) - [Protocol V2](./protocol-v2.md) # Conductor (agent-client-protocol-conductor) diff --git a/md/request-cancellation.md b/md/request-cancellation.md new file mode 100644 index 0000000..4e8dc64 --- /dev/null +++ b/md/request-cancellation.md @@ -0,0 +1,106 @@ +# Request Cancellation + +The SDK exposes the ACP `$/cancel_request` notification behind the +`unstable_cancel_request` feature. The notification is protocol-level: either +side may send it to ask the peer to cancel one outstanding JSON-RPC request by +ID. + +Enable the feature when depending on the crate: + +```toml +agent-client-protocol = { version = "...", features = ["unstable_cancel_request"] } +``` + +To cancel a request sent through `ConnectionTo::send_request`, keep the +returned `SentRequest` and call `cancel` on it: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +# use agent_client_protocol_test::MyRequest; +# async fn example(cx: ConnectionTo) -> Result<(), Error> { +let request = cx.send_request(MyRequest {}); +request.cancel()?; +# Ok(()) +# } +``` + +The `SentRequest` remembers the peer and any proxy wrapping used for the +original request, so this also works for requests sent through +`ConnectionTo::send_request_to`. + +Dropping a `SentRequest` before the SDK receives a response also sends +`$/cancel_request`. This covers abandoned request handles and futures. Once the +SDK routes a response to the waiting request handle, automatic cancellation is +disarmed, even if caller code has not yet consumed it with `block_task`, +`on_receiving_result`, or `forward_response_to`. + +If you already have the JSON-RPC request ID, send the notification directly: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +# async fn example(cx: ConnectionTo) -> Result<(), Error> { +cx.send_cancel_request("request-id".to_string())?; +# Ok(()) +# } +``` + +For incoming requests, get the request-local cancellation marker from the +`Responder`. This keeps cancellation handling next to the request work it +controls: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, Responder, UntypedRole}; +# use agent_client_protocol_test::{MyRequest, MyResponse}; +# async fn example(request: MyRequest, responder: Responder, cx: ConnectionTo) -> Result<(), Error> { +# async fn run_request(_request: MyRequest) -> Result { todo!() } +let cancellation = responder.cancellation(); + +cx.spawn(async move { + let response = cancellation.run_until_cancelled(run_request(request)).await; + responder.respond_with_result(response) +})?; +Ok(()) +# } +``` + +`run_until_cancelled` is the simple path for handlers that should stop work and +reply with the standard cancellation error as soon as cancellation is requested. +If the handler needs cleanup, partial results, or custom cancellation behavior, +use `cancellation.cancelled()` or `cancellation.is_cancelled()` directly inside +the request work instead. + +Cancellation markers are only updated when the connection can process the +incoming `$/cancel_request` notification. Long-running handlers should return +quickly and move work into `ConnectionTo::spawn`, `SentRequest` callbacks, or +another task. + +When proxying with `SentRequest::forward_response_to`, the SDK observes the +upstream `Responder` cancellation marker and forwards cancellation to the +downstream request automatically. + +Register `CancelRequestNotification` or `ProtocolLevelNotification` directly +only when you need low-level access to cancellation notifications, such as +custom routing or protocol tracing: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +use agent_client_protocol::schema::CancelRequestNotification; + +# fn builder() -> agent_client_protocol::Builder { +UntypedRole.builder() + .on_receive_notification( + async |cancel: CancelRequestNotification, _cx: ConnectionTo| { + let request_id = cancel.request_id; + // Mark the matching in-flight operation cancelled. + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ) +# } +``` + +Cancellation is cooperative. A peer may ignore `$/cancel_request`, may finish +with normal data, or may respond to the original request with +`Error::request_cancelled()` (`-32800`). The SDK ignores unhandled `$/...` +notifications so unsupported protocol-level notifications do not produce +method-not-found errors. diff --git a/src/agent-client-protocol/CHANGELOG.md b/src/agent-client-protocol/CHANGELOG.md index f1aadd0..b20dc90 100644 --- a/src/agent-client-protocol/CHANGELOG.md +++ b/src/agent-client-protocol/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Added + +- *(unstable)* Add SDK support for protocol-level request cancellation, including `SentRequest::cancel`, automatic cancellation when a `SentRequest` is dropped before receiving a response, request-local cancellation helpers on `Responder`, and forwarded cancellation propagation. + ## [0.12.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.12.0...v0.12.1) - 2026-05-17 ### Other diff --git a/src/agent-client-protocol/Cargo.toml b/src/agent-client-protocol/Cargo.toml index 7946942..4820782 100644 --- a/src/agent-client-protocol/Cargo.toml +++ b/src/agent-client-protocol/Cargo.toml @@ -18,6 +18,7 @@ default = [] unstable = [ "unstable_auth_methods", "unstable_boolean_config", + "unstable_cancel_request", "unstable_logout", "unstable_mcp_over_acp", "unstable_message_id", @@ -29,6 +30,7 @@ unstable = [ ] unstable_auth_methods = ["agent-client-protocol-schema/unstable_auth_methods"] unstable_boolean_config = ["agent-client-protocol-schema/unstable_boolean_config"] +unstable_cancel_request = ["agent-client-protocol-schema/unstable_cancel_request"] unstable_logout = ["agent-client-protocol-schema/unstable_logout"] unstable_mcp_over_acp = ["agent-client-protocol-schema/unstable_mcp_over_acp"] unstable_message_id = ["agent-client-protocol-schema/unstable_message_id"] diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 73c553d..a8e8ca4 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -7,11 +7,20 @@ pub use jsonrpcmsg; // Types re-exported from crate root use serde::{Deserialize, Serialize}; use std::any::TypeId; +#[cfg(feature = "unstable_cancel_request")] +use std::collections::HashMap; use std::fmt::Debug; use std::panic::Location; use std::pin::pin; +#[cfg(feature = "unstable_cancel_request")] +use std::sync::{ + Arc, Mutex, + atomic::{AtomicBool, Ordering}, +}; use uuid::Uuid; +#[cfg(feature = "unstable_cancel_request")] +use futures::FutureExt; use futures::channel::{mpsc, oneshot}; use futures::future::{self, BoxFuture, Either}; use futures::{AsyncRead, AsyncWrite, StreamExt}; @@ -1349,6 +1358,9 @@ enum ReplyMessage { method: String, sender: oneshot::Sender, + + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: SentRequestCancellationDisarm, }, } @@ -1364,6 +1376,267 @@ impl std::fmt::Debug for ReplyMessage { } } +/// A request-local marker that is set when the peer asks to cancel the request. +/// +/// Request handlers can get this handle from [`Responder::cancellation`] and +/// use it from spawned work to stop long-running request processing +/// cooperatively. +#[cfg(feature = "unstable_cancel_request")] +#[derive(Clone)] +pub struct RequestCancellation { + state: Arc, +} + +#[cfg(feature = "unstable_cancel_request")] +struct RequestCancellationState { + cancelled: AtomicBool, + signal_tx: Mutex>>, + signal_rx: future::Shared>, +} + +#[cfg(feature = "unstable_cancel_request")] +impl RequestCancellation { + fn new() -> Self { + let (signal_tx, signal_rx) = oneshot::channel(); + let signal_rx = signal_rx.map(|_| ()).boxed().shared(); + Self { + state: Arc::new(RequestCancellationState { + cancelled: AtomicBool::new(false), + signal_tx: Mutex::new(Some(signal_tx)), + signal_rx, + }), + } + } + + /// Wait until the peer sends `$/cancel_request` for this request. + /// + /// If cancellation was already requested, this returns immediately. + pub async fn cancelled(&self) { + self.state.signal_rx.clone().await; + } + + /// Run request work until it completes or the peer asks to cancel it. + /// + /// If cancellation is requested first, this returns + /// [`Error::request_cancelled`]. This is a convenience for request handlers + /// that want to respond with the normal result or the standard + /// cancellation error. + /// + /// [`Error::request_cancelled`]: crate::Error::request_cancelled + pub async fn run_until_cancelled( + &self, + future: impl std::future::Future>, + ) -> Result { + if self.is_cancelled() { + return Err(crate::Error::request_cancelled()); + } + + match future::select(Box::pin(future), Box::pin(self.cancelled())).await { + Either::Left((result, _)) => result, + Either::Right(((), _)) => Err(crate::Error::request_cancelled()), + } + } + + /// Returns whether the peer has already requested cancellation. + #[must_use] + pub fn is_cancelled(&self) -> bool { + self.state.cancelled.load(Ordering::Acquire) + } + + fn cancel(&self) { + if self.state.cancelled.swap(true, Ordering::AcqRel) { + return; + } + + if let Some(signal_tx) = self + .state + .signal_tx + .lock() + .expect("request cancellation signal mutex poisoned") + .take() + { + let _ = signal_tx.send(()); + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Debug for RequestCancellation { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("RequestCancellation") + .field("is_cancelled", &self.is_cancelled()) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Clone, Debug, Default)] +struct RequestCancellationRegistry { + inner: Arc>>, +} + +#[cfg(not(feature = "unstable_cancel_request"))] +#[derive(Clone, Debug, Default)] +struct RequestCancellationRegistry; + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Debug)] +struct ResponderCancellation { + id: serde_json::Value, + registry: RequestCancellationRegistry, + cancellation: RequestCancellation, +} + +#[cfg(not(feature = "unstable_cancel_request"))] +#[derive(Debug)] +struct ResponderCancellation; + +#[cfg(feature = "unstable_cancel_request")] +impl RequestCancellationRegistry { + fn new() -> Self { + Self::default() + } + + fn register(&self, id: serde_json::Value) -> ResponderCancellation { + let cancellation = RequestCancellation::new(); + self.inner + .lock() + .expect("request cancellation registry mutex poisoned") + .insert(id.clone(), cancellation.clone()); + ResponderCancellation { + id, + registry: self.clone(), + cancellation, + } + } + + fn cancel_if_requested(&self, dispatch: &Dispatch) -> Result { + let Some(request_id) = cancellation_request_id(dispatch)? else { + return Ok(false); + }; + Ok(self.cancel(&request_id)) + } + + fn cancel(&self, request_id: &serde_json::Value) -> bool { + let cancellation = self + .inner + .lock() + .expect("request cancellation registry mutex poisoned") + .get(request_id) + .cloned(); + if let Some(cancellation) = cancellation { + cancellation.cancel(); + true + } else { + false + } + } + + fn remove(&self, request_id: &serde_json::Value) { + self.inner + .lock() + .expect("request cancellation registry mutex poisoned") + .remove(request_id); + } +} + +#[cfg(not(feature = "unstable_cancel_request"))] +impl RequestCancellationRegistry { + fn new() -> Self { + Self + } + + #[expect( + clippy::unused_self, + reason = "feature-disabled stub mirrors the real registry API" + )] + fn register(&self, _id: serde_json::Value) -> ResponderCancellation { + ResponderCancellation + } + + #[expect( + clippy::unused_self, + clippy::unnecessary_wraps, + reason = "feature-disabled stub mirrors the real registry API" + )] + fn cancel_if_requested(&self, _dispatch: &Dispatch) -> Result { + Ok(false) + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl ResponderCancellation { + fn cancellation(&self) -> RequestCancellation { + self.cancellation.clone() + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Drop for ResponderCancellation { + fn drop(&mut self) { + self.registry.remove(&self.id); + } +} + +#[cfg(feature = "unstable_cancel_request")] +fn cancellation_request_id(dispatch: &Dispatch) -> Result, crate::Error> { + let Dispatch::Notification(message) = dispatch else { + return Ok(None); + }; + cancellation_request_id_from_message(message) +} + +#[cfg(feature = "unstable_cancel_request")] +fn cancellation_request_id_from_message( + message: &UntypedMessage, +) -> Result, crate::Error> { + if crate::schema::CancelRequestNotification::matches_method(&message.method) { + let notification = crate::schema::CancelRequestNotification::parse_message( + &message.method, + &message.params, + )?; + return serde_json::to_value(notification.request_id) + .map(Some) + .map_err(crate::Error::into_internal_error); + } + + if crate::schema::SuccessorMessage::::matches_method(&message.method) { + let successor = crate::schema::SuccessorMessage::::parse_message( + &message.method, + &message.params, + )?; + return cancellation_request_id_from_message(&successor.message); + } + + Ok(None) +} + +fn is_protocol_level_notification(dispatch: &Dispatch) -> bool { + let Dispatch::Notification(message) = dispatch else { + return false; + }; + is_protocol_level_notification_message(message) +} + +fn is_protocol_level_notification_message(message: &UntypedMessage) -> bool { + if message.method.starts_with("$/") { + return true; + } + + if crate::schema::SuccessorMessage::::matches_method(&message.method) { + let Ok(successor) = crate::schema::SuccessorMessage::::parse_message( + &message.method, + &message.params, + ) else { + return false; + }; + return is_protocol_level_notification_message(&successor.message); + } + + false +} + /// Messages send to be serialized over the transport. #[derive(Debug)] enum OutgoingMessage { @@ -1384,6 +1657,9 @@ enum OutgoingMessage { /// where to send the response when it arrives (includes ack channel) response_tx: oneshot::Sender, + + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: SentRequestCancellationDisarm, }, /// Send a notification to the server. @@ -1721,6 +1997,9 @@ impl ConnectionTo { let (response_tx, response_rx) = oneshot::channel(); let role_id = peer.role_id(); let remote_style = self.counterpart.remote_style(peer); + #[cfg(feature = "unstable_cancel_request")] + let cancellation = + SentRequestCancellation::new(self.message_tx.clone(), &remote_style, &id); match remote_style.transform_outgoing_message(request) { Ok(untyped) => { // Transform the message for the target role @@ -1730,11 +2009,16 @@ impl ConnectionTo { role_id, untyped, response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: cancellation.disarm_handle(), }; match self.message_tx.unbounded_send(message) { Ok(()) => (), Err(error) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + let OutgoingMessage::Request { method, response_tx, @@ -1757,6 +2041,9 @@ impl ConnectionTo { } Err(err) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + response_tx .send(ResponsePayload { result: Err(crate::util::internal_error(format!( @@ -1768,8 +2055,15 @@ impl ConnectionTo { } } - SentRequest::new(id, method.clone(), self.task_tx.clone(), response_rx) - .map(move |json| ::from_value(&method, json)) + SentRequest::new( + id, + method.clone(), + self.task_tx.clone(), + response_rx, + #[cfg(feature = "unstable_cancel_request")] + cancellation, + ) + .map(move |json| ::from_value(&method, json)) } /// Send an outgoing notification to the default counterpart peer (no reply expected). @@ -1833,6 +2127,36 @@ impl ConnectionTo { ) } + /// Send a `$/cancel_request` notification for an arbitrary request ID to + /// the default counterpart peer. + #[cfg(feature = "unstable_cancel_request")] + pub fn send_cancel_request( + &self, + request_id: impl Into, + ) -> Result<(), crate::Error> + where + Counterpart: HasPeer, + { + self.send_cancel_request_to(self.counterpart.clone(), request_id) + } + + /// Send a `$/cancel_request` notification for an arbitrary request ID to a + /// specific peer. + #[cfg(feature = "unstable_cancel_request")] + pub fn send_cancel_request_to( + &self, + peer: Peer, + request_id: impl Into, + ) -> Result<(), crate::Error> + where + Counterpart: HasPeer, + { + self.send_notification_to( + peer, + crate::schema::CancelRequestNotification::new(request_id), + ) + } + /// Send an error notification (no reply expected). pub fn send_error_notification(&self, error: crate::Error) -> Result<(), crate::Error> { send_raw_message(&self.message_tx, OutgoingMessage::Error { error }) @@ -1943,6 +2267,9 @@ pub struct Responder { /// The `id` of the message we are replying to. id: jsonrpcmsg::Id, + /// Request-local cancellation state. + cancellation: ResponderCancellation, + /// Function to send the response to its destination. /// /// For incoming requests: serializes to JSON and sends over the wire. @@ -1964,12 +2291,19 @@ impl Responder { /// Create a new request context for an incoming request. /// /// The response will be serialized to JSON and sent over the wire. - fn new(message_tx: OutgoingMessageTx, method: String, id: jsonrpcmsg::Id) -> Self { + fn new( + message_tx: OutgoingMessageTx, + method: String, + id: jsonrpcmsg::Id, + cancellation_registry: &RequestCancellationRegistry, + ) -> Self { let id_clone = id.clone(); let method_clone = method.clone(); + let cancellation = cancellation_registry.register(crate::util::id_to_json(&id)); Self { method, id, + cancellation, send_fn: Box::new(move |response: Result| { send_raw_message( &message_tx, @@ -2007,6 +2341,20 @@ impl Responder { crate::util::id_to_json(&self.id) } + /// Returns the cancellation marker for this request. + /// + /// The marker is set when the peer sends `$/cancel_request` for this + /// request's JSON-RPC ID. Cancellation is cooperative: handlers should use + /// the marker to stop long-running work and then decide whether to respond + /// with [`Error::request_cancelled`] or partial data. + /// + /// [`Error::request_cancelled`]: crate::Error::request_cancelled + #[cfg(feature = "unstable_cancel_request")] + #[must_use] + pub fn cancellation(&self) -> RequestCancellation { + self.cancellation.cancellation() + } + /// Convert to a `Responder` that expects a JSON value /// and which checks (dynamically) that the JSON value it receives /// can be converted to `T`. @@ -2019,6 +2367,7 @@ impl Responder { Responder { method, id: self.id, + cancellation: self.cancellation, send_fn: self.send_fn, } } @@ -2035,6 +2384,7 @@ impl Responder { Responder { method: self.method, id: self.id, + cancellation: self.cancellation, send_fn: Box::new(move |input: Result| { let t_value = wrap_fn(&method, input); (self.send_fn)(t_value) @@ -2106,26 +2456,40 @@ impl ResponseRouter { /// Create a new response context for routing a response to a local awaiter. /// /// When `respond_with_result` is called, the response is sent through the oneshot - /// channel to the code that originally sent the request. + /// channel to the code that originally sent the request. If that receiver was + /// dropped, the response is discarded because there is no local awaiter left. pub(crate) fn new( method: String, id: jsonrpcmsg::Id, role_id: RoleId, sender: oneshot::Sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: SentRequestCancellationDisarm, ) -> Self { + let response_method = method.clone(); + let response_id = id.clone(); Self { method, id, role_id, send_fn: Box::new(move |response: Result| { - sender + if sender .send(ResponsePayload { result: response, ack_tx: None, }) - .map_err(|_| { - crate::util::internal_error("failed to send response, receiver dropped") - }) + .is_err() + { + tracing::debug!( + method = %response_method, + id = ?response_id, + "dropped response because local receiver was gone" + ); + } else { + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm.disarm(); + } + Ok(()) }), } } @@ -2813,16 +3177,166 @@ pub struct SentRequest { task_tx: TaskTx, response_rx: oneshot::Receiver, to_result: Box Result + Send>, + #[cfg(feature = "unstable_cancel_request")] + cancellation: SentRequestCancellation, +} + +#[cfg(feature = "unstable_cancel_request")] +fn jsonrpc_id_to_request_id(id: &jsonrpcmsg::Id) -> Result { + match id { + jsonrpcmsg::Id::String(value) => Ok(crate::schema::RequestId::Str(value.clone())), + jsonrpcmsg::Id::Number(value) => Ok(crate::schema::RequestId::Number( + i64::try_from(*value).map_err(|_| { + crate::util::internal_error(format!( + "request ID `{value}` cannot be represented as an ACP request ID" + )) + })?, + )), + jsonrpcmsg::Id::Null => Ok(crate::schema::RequestId::Null), + } +} + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Clone, Debug)] +pub(crate) struct SentRequestCancellationDisarm { + armed: Arc, +} + +#[cfg(feature = "unstable_cancel_request")] +impl SentRequestCancellationDisarm { + fn new() -> Self { + Self { + armed: Arc::new(AtomicBool::new(true)), + } + } + + fn disarm(&self) { + self.armed.store(false, Ordering::Release); + } +} + +#[cfg(feature = "unstable_cancel_request")] +enum SentRequestCancellation { + Send { + message_tx: OutgoingMessageTx, + notification: UntypedMessage, + disarm: SentRequestCancellationDisarm, + }, + Failed { + error: String, + disarm: SentRequestCancellationDisarm, + }, +} + +#[cfg(feature = "unstable_cancel_request")] +impl SentRequestCancellation { + fn new( + message_tx: OutgoingMessageTx, + remote_style: &crate::role::RemoteStyle, + request_id: &jsonrpcmsg::Id, + ) -> Self { + let notification = jsonrpc_id_to_request_id(request_id) + .and_then(|request_id| { + remote_style.transform_outgoing_message( + crate::schema::CancelRequestNotification::new(request_id), + ) + }) + .map_err(|error| error.to_string()); + let disarm = SentRequestCancellationDisarm::new(); + + match notification { + Ok(notification) => Self::Send { + message_tx, + notification, + disarm, + }, + Err(error) => Self::Failed { error, disarm }, + } + } + + fn disarm(&self) { + self.disarm_handle().disarm(); + } + + fn disarm_handle(&self) -> SentRequestCancellationDisarm { + match self { + Self::Send { disarm, .. } | Self::Failed { disarm, .. } => disarm.clone(), + } + } + + fn send(&self) -> Result<(), crate::Error> { + match self { + Self::Send { + message_tx, + notification, + disarm, + } => { + if !disarm.armed.swap(false, Ordering::AcqRel) { + return Ok(()); + } + + send_raw_message( + message_tx, + OutgoingMessage::Notification { + untyped: notification.clone(), + }, + ) + } + Self::Failed { error, disarm } => { + if !disarm.armed.swap(false, Ordering::AcqRel) { + return Ok(()); + } + + Err(crate::util::internal_error(format!( + "failed to create cancel request notification: {error}" + ))) + } + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Drop for SentRequestCancellation { + fn drop(&mut self) { + if let Err(error) = self.send() { + tracing::debug!(?error, "failed to auto-cancel dropped request"); + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Debug for SentRequestCancellation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Send { + notification, + disarm, + .. + } => f + .debug_struct("SentRequestCancellation") + .field("notification", notification) + .field("armed", &disarm.armed.load(Ordering::Acquire)) + .finish(), + Self::Failed { error, disarm } => f + .debug_struct("SentRequestCancellation") + .field("error", error) + .field("armed", &disarm.armed.load(Ordering::Acquire)) + .finish(), + } + } } impl Debug for SentRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SentRequest") + let mut debug = f.debug_struct("SentRequest"); + debug .field("id", &self.id) .field("method", &self.method) .field("task_tx", &self.task_tx) - .field("response_rx", &self.response_rx) - .finish_non_exhaustive() + .field("response_rx", &self.response_rx); + #[cfg(feature = "unstable_cancel_request")] + debug.field("cancellation", &self.cancellation); + debug.finish_non_exhaustive() } } @@ -2832,6 +3346,7 @@ impl SentRequest { method: String, task_tx: mpsc::UnboundedSender, response_rx: oneshot::Receiver, + #[cfg(feature = "unstable_cancel_request")] cancellation: SentRequestCancellation, ) -> Self { Self { id, @@ -2839,10 +3354,24 @@ impl SentRequest { response_rx, task_tx, to_result: Box::new(Ok), + #[cfg(feature = "unstable_cancel_request")] + cancellation, } } } +impl SentRequest { + /// Send a `$/cancel_request` notification for this outgoing request. + /// + /// This uses the same peer and message wrapping that were used to send the + /// original request, so it is the preferred way to cancel a [`SentRequest`] + /// when the request handle is still available. + #[cfg(feature = "unstable_cancel_request")] + pub fn cancel(&self) -> Result<(), crate::Error> { + self.cancellation.send() + } +} + impl SentRequest { /// The id of the outgoing request. #[must_use] @@ -2867,6 +3396,8 @@ impl SentRequest { response_rx: self.response_rx, task_tx: self.task_tx, to_result: Box::new(move |value| map_fn((self.to_result)(value)?)), + #[cfg(feature = "unstable_cancel_request")] + cancellation: self.cancellation, } } @@ -2925,7 +3456,68 @@ impl SentRequest { where T: Send, { - self.on_receiving_result(async move |result| responder.respond_with_result(result)) + #[cfg(feature = "unstable_cancel_request")] + { + self.forward_response_to_observing_cancellation(responder) + } + #[cfg(not(feature = "unstable_cancel_request"))] + { + self.on_receiving_result(async move |result| responder.respond_with_result(result)) + } + } + + #[cfg(feature = "unstable_cancel_request")] + #[track_caller] + fn forward_response_to_observing_cancellation( + self, + responder: Responder, + ) -> Result<(), crate::Error> + where + T: Send, + { + let task_tx = self.task_tx.clone(); + let method = self.method; + let response_rx = self.response_rx; + let to_result = self.to_result; + let downstream_cancellation = self.cancellation; + let upstream_cancellation = responder.cancellation(); + let location = Location::caller(); + + Task::new(location, async move { + let response = if upstream_cancellation.is_cancelled() { + downstream_cancellation.send()?; + response_rx.await + } else { + match future::select(Box::pin(upstream_cancellation.cancelled()), response_rx).await + { + Either::Left(((), response_rx)) => { + downstream_cancellation.send()?; + response_rx.await + } + Either::Right((response, _)) => response, + } + }; + + downstream_cancellation.disarm(); + + let ResponsePayload { result, ack_tx } = response.map_err(|err| { + crate::util::internal_error(format!("response to `{method}` never received: {err}")) + })?; + + let typed_result = match result { + Ok(json_value) => to_result(json_value), + Err(err) => Err(err), + }; + + let outcome = responder.respond_with_result(typed_result); + + if let Some(tx) = ack_tx { + let _ = tx.send(()); + } + + outcome + }) + .spawn(&task_tx) } /// Block the current task until the response is received. @@ -2999,6 +3591,9 @@ impl SentRequest { result: Ok(json_value), ack_tx, }) => { + #[cfg(feature = "unstable_cancel_request")] + self.cancellation.disarm(); + // Ack immediately - we're in a spawned task, so the dispatch loop // can continue while we process the value. if let Some(tx) = ack_tx { @@ -3013,15 +3608,23 @@ impl SentRequest { result: Err(err), ack_tx, }) => { + #[cfg(feature = "unstable_cancel_request")] + self.cancellation.disarm(); + if let Some(tx) = ack_tx { let _ = tx.send(()); } Err(err) } - Err(err) => Err(crate::util::internal_error(format!( - "response to `{}` never received: {}", - self.method, err - ))), + Err(err) => { + #[cfg(feature = "unstable_cancel_request")] + self.cancellation.disarm(); + + Err(crate::util::internal_error(format!( + "response to `{}` never received: {}", + self.method, err + ))) + } } } @@ -3171,11 +3774,16 @@ impl SentRequest { let method = self.method; let response_rx = self.response_rx; let to_result = self.to_result; + #[cfg(feature = "unstable_cancel_request")] + let cancellation = self.cancellation; let location = Location::caller(); Task::new(location, async move { match response_rx.await { Ok(ResponsePayload { result, ack_tx }) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + // Convert the result using to_result for Ok values let typed_result = match result { Ok(json_value) => to_result(json_value), @@ -3193,9 +3801,14 @@ impl SentRequest { outcome } - Err(err) => Err(crate::util::internal_error(format!( - "response to `{method}` never received: {err}" - ))), + Err(err) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + + Err(crate::util::internal_error(format!( + "response to `{method}` never received: {err}" + ))) + } } }) .spawn(&task_tx) diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 302554f..1a5b80d 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -30,6 +30,8 @@ struct PendingReply { method: String, role_id: RoleId, sender: oneshot::Sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: super::SentRequestCancellationDisarm, } /// Incoming protocol actor: The central dispatch loop for a connection. @@ -62,6 +64,8 @@ pub(super) async fn incoming_protocol_actor( FxHashMap::default(); let mut pending_messages: Vec = vec![]; + let request_cancellations = super::RequestCancellationRegistry::new(); + // Map from request ID to (method, sender) for response dispatch. // Keys are JSON values because jsonrpcmsg::Id doesn't implement Eq. // The method is stored to allow routing responses through typed handlers. @@ -76,6 +80,8 @@ pub(super) async fn incoming_protocol_actor( role_id, method, sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, } => { tracing::trace!(?id, %method, "incoming_actor: subscribing to response"); let id = serde_json::to_value(&id).unwrap(); @@ -85,6 +91,8 @@ pub(super) async fn incoming_protocol_actor( method, role_id, sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, }, ); } @@ -135,7 +143,12 @@ pub(super) async fn incoming_protocol_actor( tracing::trace!(method = %request.method, id = ?request.id, "Handling request"); let request_method = request.method.clone(); let request_id = request.id.clone(); - match dispatch_from_request(connection, request, &protocol_compat) { + match dispatch_from_request( + connection, + request, + &protocol_compat, + &request_cancellations, + ) { Ok(dispatch) => { dispatch_dispatch( counterpart.clone(), @@ -144,6 +157,7 @@ pub(super) async fn incoming_protocol_actor( &mut dynamic_handlers, &mut handler, &mut pending_messages, + &request_cancellations, ) .await?; } @@ -183,6 +197,7 @@ pub(super) async fn incoming_protocol_actor( &mut dynamic_handlers, &mut handler, &mut pending_messages, + &request_cancellations, ) .await?; } else { @@ -218,6 +233,7 @@ fn dispatch_from_request( connection: &ConnectionTo, request: jsonrpcmsg::Request, protocol_compat: &ProtocolCompat, + request_cancellations: &super::RequestCancellationRegistry, ) -> Result { let message = UntypedMessage::new(&request.method, &request.params).expect("well-formed JSON"); let message = protocol_compat.incoming_message(message)?; @@ -229,6 +245,7 @@ fn dispatch_from_request( connection.message_tx.clone(), request.method.clone(), id.clone(), + request_cancellations, ), )), None => Ok(Dispatch::Notification(message)), @@ -249,10 +266,19 @@ fn dispatch_from_response( method, role_id, sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, } = pending_reply; // Create a Dispatch::Response with a ResponseRouter that routes to the oneshot - let router = ResponseRouter::new(method.clone(), id.clone(), role_id, sender); + let router = ResponseRouter::new( + method.clone(), + id.clone(), + role_id, + sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, + ); Dispatch::Response(result, router) } @@ -268,6 +294,7 @@ async fn dispatch_dispatch( dynamic_handlers: &mut FxHashMap>>, handler: &mut impl HandleDispatchFrom, pending_messages: &mut Vec, + request_cancellations: &super::RequestCancellationRegistry, ) -> Result<(), crate::Error> { tracing::trace!(?dispatch, "dispatch_dispatch"); @@ -276,6 +303,22 @@ async fn dispatch_dispatch( let id = dispatch.id(); let method = dispatch.method().to_string(); + match request_cancellations.cancel_if_requested(&dispatch) { + Ok(true) => { + tracing::debug!(?method, "Marked request as cancelled"); + } + Ok(false) => {} + Err(err) => { + tracing::warn!( + ?method, + ?id, + ?err, + "Request cancellation notification errored" + ); + return report_handler_error(connection, id, method, err); + } + } + // First, apply the handlers given by the user. tracing::trace!(handler = ?handler.describe_chain(), "Attempting handler chain"); match handler @@ -351,6 +394,11 @@ async fn dispatch_dispatch( } } + if super::is_protocol_level_notification(&dispatch) { + tracing::debug!(?method, "Ignoring unhandled protocol-level notification"); + return Ok(()); + } + // If the message was never handled, check whether the retry flag was set. // If so, enqueue it for later processing. Else, reject it. if retry_any { diff --git a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs index 0b54ff7..65a5611 100644 --- a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs @@ -41,6 +41,8 @@ pub(super) async fn outgoing_protocol_actor( method, untyped, response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, } => { let request = match protocol_compat .outgoing_message(untyped) @@ -49,6 +51,8 @@ pub(super) async fn outgoing_protocol_actor( Ok(request) => request, Err(error) => { tracing::warn!(?id, %method, ?error, "Failed to convert outgoing request"); + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm.disarm(); complete_request_with_error(response_tx, error); continue; } @@ -61,6 +65,8 @@ pub(super) async fn outgoing_protocol_actor( role_id, method, sender: response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, }) .map_err(crate::Error::into_internal_error)?; @@ -167,6 +173,8 @@ mod tests { method: "session/new".into(), untyped: malformed_v2_known_method()?, response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: crate::jsonrpc::SentRequestCancellationDisarm::new(), }) .map_err(crate::Error::into_internal_error)?; drop(outgoing_tx); diff --git a/src/agent-client-protocol/src/lib.rs b/src/agent-client-protocol/src/lib.rs index 9a28f59..103666f 100644 --- a/src/agent-client-protocol/src/lib.rs +++ b/src/agent-client-protocol/src/lib.rs @@ -108,6 +108,8 @@ pub mod jsonrpcmsg { pub use jsonrpcmsg::{Error, Id, Message, Params, Request, Response}; } +#[cfg(feature = "unstable_cancel_request")] +pub use jsonrpc::RequestCancellation; pub use jsonrpc::{ Builder, ByteStreams, Channel, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, IntoHandled, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Lines, diff --git a/src/agent-client-protocol/src/schema/mod.rs b/src/agent-client-protocol/src/schema/mod.rs index 6279701..8724839 100644 --- a/src/agent-client-protocol/src/schema/mod.rs +++ b/src/agent-client-protocol/src/schema/mod.rs @@ -257,6 +257,7 @@ macro_rules! impl_jsonrpc_response_enum { mod agent_to_client; mod client_to_agent; mod enum_impls; +mod protocol_level; mod proxy_protocol; #[cfg(feature = "unstable_protocol_v2")] mod v2_impls; diff --git a/src/agent-client-protocol/src/schema/protocol_level.rs b/src/agent-client-protocol/src/schema/protocol_level.rs new file mode 100644 index 0000000..81ea0da --- /dev/null +++ b/src/agent-client-protocol/src/schema/protocol_level.rs @@ -0,0 +1,38 @@ +#[cfg(feature = "unstable_cancel_request")] +use crate::{ + JsonRpcMessage, JsonRpcNotification, UntypedMessage, + schema::{CancelRequestNotification, ProtocolLevelNotification}, +}; + +#[cfg(feature = "unstable_cancel_request")] +impl_jsonrpc_notification!(CancelRequestNotification, "$/cancel_request"); + +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcMessage for ProtocolLevelNotification { + fn matches_method(method: &str) -> bool { + method == "$/cancel_request" + } + + fn method(&self) -> &str { + match self { + Self::CancelRequestNotification(_) => "$/cancel_request", + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new(self.method(), self) + } + + fn parse_message(method: &str, params: &impl serde::Serialize) -> Result { + match method { + "$/cancel_request" => { + crate::util::json_cast_params(params).map(Self::CancelRequestNotification) + } + _ => Err(crate::Error::method_not_found()), + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcNotification for ProtocolLevelNotification {} diff --git a/src/agent-client-protocol/src/schema/v2_impls.rs b/src/agent-client-protocol/src/schema/v2_impls.rs index 5b87c8a..e47117a 100644 --- a/src/agent-client-protocol/src/schema/v2_impls.rs +++ b/src/agent-client-protocol/src/schema/v2_impls.rs @@ -270,6 +270,8 @@ impl_v2_jsonrpc_request!( #[cfg(feature = "unstable_mcp_over_acp")] impl_v2_jsonrpc_request!(v2::MessageMcpRequest, v2::MessageMcpResponse, "mcp/message"); +#[cfg(feature = "unstable_cancel_request")] +impl_v2_jsonrpc_notification!(v2::CancelRequestNotification, "$/cancel_request"); impl_v2_jsonrpc_notification!(v2::CancelNotification, "session/cancel"); #[cfg(feature = "unstable_mcp_over_acp")] impl_v2_jsonrpc_notification!(v2::MessageMcpNotification, "mcp/message"); @@ -325,6 +327,36 @@ impl_v2_jsonrpc_request!( impl_v2_jsonrpc_notification!(v2::SessionNotification, "session/update"); +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcMessage for v2::ProtocolLevelNotification { + fn matches_method(method: &str) -> bool { + method == "$/cancel_request" + } + + fn method(&self) -> &str { + match self { + Self::CancelRequestNotification(_) => "$/cancel_request", + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new(self.method(), self) + } + + fn parse_message(method: &str, params: &impl serde::Serialize) -> Result { + match method { + "$/cancel_request" => { + crate::util::json_cast_params(params).map(Self::CancelRequestNotification) + } + _ => Err(crate::Error::method_not_found()), + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcNotification for v2::ProtocolLevelNotification {} + impl_v2_jsonrpc_request_enum!(v2::ClientRequest { InitializeRequest => "initialize", AuthenticateRequest => "authenticate", diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs new file mode 100644 index 0000000..9383fd0 --- /dev/null +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -0,0 +1,948 @@ +#![cfg(feature = "unstable_cancel_request")] + +use std::sync::{Arc, Mutex}; + +use agent_client_protocol::{ + Channel, ConnectionTo, Dispatch, Handled, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + Responder, Role, RoleId, SentRequest, + role::UntypedRole, + schema::{CancelRequestNotification, ProtocolLevelNotification, RequestId}, +}; +use expect_test::expect; +use futures::{AsyncRead, AsyncWrite}; +use serde::{Deserialize, Serialize}; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +fn setup_test_streams() -> ( + impl AsyncRead, + impl AsyncWrite, + impl AsyncRead, + impl AsyncWrite, +) { + let (client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); + + (server_reader, server_writer, client_reader, client_writer) +} + +async fn read_jsonrpc_response_line( + reader: &mut tokio::io::BufReader, +) -> serde_json::Value { + use tokio::io::AsyncBufReadExt as _; + + let mut line = String::new(); + match tokio::time::timeout( + tokio::time::Duration::from_secs(1), + reader.read_line(&mut line), + ) + .await + { + Ok(Ok(0)) | Err(_) => panic!("timed out waiting for JSON-RPC response"), + Ok(Ok(_)) => serde_json::from_str(line.trim()).expect("response should be valid JSON"), + Ok(Err(error)) => panic!("failed to read JSON-RPC response line: {error}"), + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleRequest { + message: String, +} + +impl JsonRpcMessage for SimpleRequest { + fn matches_method(method: &str) -> bool { + method == "simple_method" + } + + fn method(&self) -> &'static str { + "simple_method" + } + + fn to_untyped_message( + &self, + ) -> Result { + agent_client_protocol::UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl Serialize, + ) -> Result { + if !Self::matches_method(method) { + return Err(agent_client_protocol::Error::method_not_found()); + } + agent_client_protocol::util::json_cast_params(params) + } +} + +impl JsonRpcRequest for SimpleRequest { + type Response = SimpleResponse; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleResponse { + result: String, +} + +impl JsonRpcResponse for SimpleResponse { + fn into_json(self, _method: &str) -> Result { + serde_json::to_value(self).map_err(agent_client_protocol::Error::into_internal_error) + } + + fn from_value( + _method: &str, + value: serde_json::Value, + ) -> Result { + agent_client_protocol::util::json_cast(&value) + } +} + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedHost; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedCounterpart; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedSuccessor; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedSuccessorCounterpart; + +impl Role for WrappedHost { + type Counterpart = WrappedCounterpart; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedCounterpart + } +} + +impl Role for WrappedCounterpart { + type Counterpart = WrappedHost; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedHost + } +} + +impl Role for WrappedSuccessor { + type Counterpart = WrappedSuccessorCounterpart; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedSuccessorCounterpart + } +} + +impl Role for WrappedSuccessorCounterpart { + type Counterpart = WrappedSuccessor; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedSuccessor + } +} + +impl agent_client_protocol::role::HasPeer for WrappedCounterpart { + fn remote_style(&self, _peer: WrappedCounterpart) -> agent_client_protocol::role::RemoteStyle { + agent_client_protocol::role::RemoteStyle::Counterpart + } +} + +impl agent_client_protocol::role::HasPeer for WrappedCounterpart { + fn remote_style(&self, _peer: WrappedSuccessor) -> agent_client_protocol::role::RemoteStyle { + agent_client_protocol::role::RemoteStyle::Successor + } +} + +#[tokio::test(flavor = "current_thread")] +async fn unhandled_protocol_level_notifications_are_ignored() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + let server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"$/cancel_request","params":{"requestId":"req-1"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after cancel"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 2, + "jsonrpc": "2.0", + "result": { + "result": "echo: after cancel" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unhandled_wrapped_protocol_level_notifications_are_ignored() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + let server = WrappedHost + .builder() + .on_receive_notification_from( + WrappedSuccessor, + async |cancel: CancelRequestNotification, + cx: ConnectionTo| { + Ok::<_, agent_client_protocol::Error>(Handled::No { + message: (cancel, cx), + retry: false, + }) + }, + agent_client_protocol::on_receive_notification!(), + ) + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"_proxy/successor","params":{"method":"$/cancel_request","params":{"requestId":"req-1"}}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after wrapped cancel"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 2, + "jsonrpc": "2.0", + "result": { + "result": "echo: after wrapped cancel" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn cancel_request_notification_can_be_sent_and_handled() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + cx.send_cancel_request("request-42".to_string())?; + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(()) + }) + .await + .unwrap(); + + assert_eq!( + *received.lock().unwrap(), + vec![RequestId::Str("request-42".into())] + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn sent_request_can_send_cancellation_for_its_id() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |_request: SimpleRequest, + _responder: Responder, + _connection: ConnectionTo| { Ok(()) }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let expected_id = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "slow".into(), + }); + let expected_id = request.id(); + request.cancel()?; + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(expected_id) + }) + .await + .unwrap(); + + let received = received.lock().unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropped_sent_request_sends_cancellation_for_its_id() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |_request: SimpleRequest, + _responder: Responder, + _connection: ConnectionTo| { Ok(()) }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let expected_id = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "abandoned".into(), + }); + let expected_id = request.id(); + drop(request); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(expected_id) + }) + .await + .unwrap(); + + let received = received.lock().unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn late_response_after_dropped_sent_request_does_not_close_connection() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + connection: ConnectionTo| { + if request.message == "late" { + connection.spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + responder.respond(SimpleResponse { + result: "late response".into(), + }) + })?; + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let (expected_id, response) = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "late".into(), + }); + let expected_id = request.id(); + drop(request); + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + let response = cx + .send_request(SimpleRequest { + message: "after late".into(), + }) + .block_task() + .await?; + Ok((expected_id, response)) + }) + .await + .unwrap(); + + assert_eq!(response.result, "echo: after late"); + let received = received.lock().unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn response_buffered_before_drop_disarms_auto_cancellation() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let response = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "buffered".into(), + }); + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + drop(request); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + cx.send_request(SimpleRequest { + message: "after buffered".into(), + }) + .block_task() + .await + }) + .await + .unwrap(); + + assert_eq!(response.result, "echo: after buffered"); + assert!(received.lock().unwrap().is_empty()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn completed_sent_request_does_not_send_cancellation_on_drop() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let response = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let response = cx + .send_request(SimpleRequest { + message: "complete".into(), + }) + .block_task() + .await?; + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(response) + }) + .await + .unwrap(); + + assert_eq!(response.result, "echo: complete"); + assert!(received.lock().unwrap().is_empty()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn forward_response_to_propagates_cancellation_to_downstream_request() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let backend_cancellations = Arc::new(Mutex::new(Vec::new())); + let backend_cancellations_for_handler = backend_cancellations.clone(); + + let (backend_for_proxy, backend_for_server) = Channel::duplex(); + let (backend_connection_tx, backend_connection_rx) = + futures::channel::oneshot::channel(); + + tokio::task::spawn_local(async move { + let result = UntypedRole + .builder() + .connect_with(backend_for_proxy, async |connection| { + drop(backend_connection_tx.send(connection.clone())); + std::future::pending::>().await + }) + .await; + if let Err(error) = result { + panic!("proxy-to-backend connection should stay alive: {error:?}"); + } + }); + + let backend_server = UntypedRole + .builder() + .on_receive_request( + async |_request: SimpleRequest, + _responder: Responder, + _connection: ConnectionTo| { Ok(()) }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + backend_cancellations_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = backend_server.connect_to(backend_for_server).await { + panic!("backend server should stay alive: {error:?}"); + } + }); + + let backend_connection = backend_connection_rx + .await + .expect("backend connection should start"); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let proxy_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let proxy = UntypedRole.builder().on_receive_request( + { + let backend_connection = backend_connection.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + backend_connection + .send_request(request) + .forward_response_to(responder)?; + Ok(()) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_transport).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .connect_with(client_transport, async |connection| { + let request: SentRequest = + connection.send_request(SimpleRequest { + message: "cancel downstream".into(), + }); + request.cancel()?; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + Ok(()) + }) + .await + .unwrap(); + + let backend_cancellations = backend_cancellations.lock().unwrap(); + assert_eq!(backend_cancellations.len(), 1); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_handler_can_observe_cancellation_from_responder() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |_request: SimpleRequest, + responder: Responder, + connection: ConnectionTo| { + let cancellation = responder.cancellation(); + assert!(!cancellation.is_cancelled()); + + connection.spawn(async move { + let response = cancellation + .run_until_cancelled(futures::future::pending::< + Result, + >()) + .await; + assert!(cancellation.is_cancelled()); + responder.respond_with_result(response) + })?; + + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let error = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "cancel me".into(), + }); + request.cancel()?; + Ok(request + .block_task() + .await + .expect_err("request should be cancelled")) + }) + .await + .unwrap(); + + assert_eq!(i32::from(error.code), -32800); + assert_eq!(error.message, "Request cancelled"); + }) + .await; +} + +#[test] +fn protocol_level_notification_and_cancelled_error_code_are_typed() { + let notification = ProtocolLevelNotification::parse_message( + "$/cancel_request", + &serde_json::json!({ "requestId": "req-1" }), + ) + .unwrap(); + assert_eq!(notification.method(), "$/cancel_request"); + + let error = agent_client_protocol::Error::request_cancelled(); + assert_eq!(i32::from(error.code), -32800); + assert_eq!(error.message, "Request cancelled"); +}