77#include < atomic>
88#include < chrono>
99#include < mutex>
10+ #include < algorithm>
1011#include < optional>
1112#include < stdexcept>
1213#include < thread>
@@ -26,7 +27,7 @@ namespace mcp {
2627
2728// Client implementation
2829class Client ::Impl {
29- public :
30+ private :
3031 friend class Client ; // Allow outer Client to invoke private coroutine helpers
3132 std::unique_ptr<ITransport> transport;
3233 ClientCapabilities capabilities;
@@ -37,6 +38,7 @@ class Client::Impl {
3738 IClient::ProgressHandler progressHandler;
3839 IClient::ErrorHandler errorHandler;
3940 IClient::SamplingHandler samplingHandler;
41+ IClient::SamplingHandlerCancelable samplingHandlerCancelable;
4042 validation::ValidationMode validationMode{validation::ValidationMode::Off};
4143
4244 // Listings cache (optional)
@@ -47,6 +49,85 @@ class Client::Impl {
4749 struct TemplatesCache { std::vector<ResourceTemplate> data; std::chrono::steady_clock::time_point ts; bool set{false }; } templatesCache;
4850 struct PromptsCache { std::vector<Prompt> data; std::chrono::steady_clock::time_point ts; bool set{false }; } promptsCache;
4951
52+ // Cancellation support for server->client requests (e.g., sampling/createMessage)
53+ struct CancellationToken { std::atomic<bool > cancelled{false }; };
54+ std::mutex cancelMutex;
55+ std::unordered_map<std::string, std::shared_ptr<CancellationToken>> cancelMap;
56+ std::unordered_map<std::string, std::vector<std::shared_ptr<std::stop_source>>> stopSources;
57+
58+ static std::string idToString (const JSONRPCId& id) {
59+ std::string idStr;
60+ std::visit ([&](const auto & v){
61+ using T = std::decay_t <decltype (v)>;
62+ if constexpr (std::is_same_v<T, std::string>) { idStr = v; }
63+ else if constexpr (std::is_same_v<T, int64_t >) { idStr = std::to_string (v); }
64+ else { idStr = " " ; }
65+ }, id);
66+ return idStr;
67+ }
68+
69+ static std::string parseIdFromParams (const JSONValue& params) {
70+ std::string idStr;
71+ if (std::holds_alternative<JSONValue::Object>(params.value )) {
72+ const auto & o = std::get<JSONValue::Object>(params.value );
73+ auto it = o.find (" id" );
74+ if (it != o.end () && it->second ) {
75+ if (std::holds_alternative<std::string>(it->second ->value )) idStr = std::get<std::string>(it->second ->value );
76+ else if (std::holds_alternative<int64_t >(it->second ->value )) idStr = std::to_string (std::get<int64_t >(it->second ->value ));
77+ }
78+ }
79+ return idStr;
80+ }
81+
82+ std::shared_ptr<CancellationToken> registerCancelToken (const std::string& idStr) {
83+ if (idStr.empty ()) return std::make_shared<CancellationToken>();
84+ std::lock_guard<std::mutex> lk (cancelMutex);
85+ auto it = cancelMap.find (idStr);
86+ if (it != cancelMap.end ()) return it->second ;
87+ auto tok = std::make_shared<CancellationToken>();
88+ cancelMap[idStr] = tok;
89+ return tok;
90+ }
91+ void unregisterCancelToken (const std::string& idStr) {
92+ if (idStr.empty ()) return ;
93+ std::lock_guard<std::mutex> lk (cancelMutex);
94+ cancelMap.erase (idStr);
95+ stopSources.erase (idStr);
96+ }
97+ void cancelById (const std::string& idStr) {
98+ std::lock_guard<std::mutex> lk (cancelMutex);
99+ auto it = cancelMap.find (idStr);
100+ if (it == cancelMap.end () || !it->second ) {
101+ auto tok = std::make_shared<CancellationToken>();
102+ tok->cancelled .store (true );
103+ cancelMap[idStr] = tok;
104+ } else {
105+ it->second ->cancelled .store (true );
106+ }
107+ auto itS = stopSources.find (idStr);
108+ if (itS != stopSources.end ()) {
109+ for (auto & src : itS->second ) { if (src) { try { src->request_stop (); } catch (...) {} } }
110+ }
111+ }
112+ std::shared_ptr<std::stop_source> registerStopSource (const std::string& idStr) {
113+ auto src = std::make_shared<std::stop_source>();
114+ std::lock_guard<std::mutex> lk (cancelMutex);
115+ stopSources[idStr].push_back (src);
116+ auto it = cancelMap.find (idStr);
117+ if (it != cancelMap.end () && it->second && it->second ->cancelled .load ()) {
118+ try { src->request_stop (); } catch (...) {}
119+ }
120+ return src;
121+ }
122+ void unregisterStopSource (const std::string& idStr, const std::shared_ptr<std::stop_source>& src) {
123+ std::lock_guard<std::mutex> lk (cancelMutex);
124+ auto it = stopSources.find (idStr);
125+ if (it == stopSources.end ()) return ;
126+ auto & vec = it->second ;
127+ vec.erase (std::remove_if (vec.begin (), vec.end (), [&](const std::shared_ptr<std::stop_source>& p){ return p.get () == src.get (); }), vec.end ());
128+ if (vec.empty ()) stopSources.erase (it);
129+ }
130+
50131 explicit Impl (const Implementation& info)
51132 : clientInfo(info) {
52133 // Set default client capabilities
@@ -207,6 +288,14 @@ void Client::Impl::onNotification(std::unique_ptr<JSONRPCNotification> n) {
207288 const auto & o = std::get<JSONValue::Object>(n->params ->value );
208289 this ->handleProgressNotification (o);
209290 }
291+ } else if (n->method == Methods::Cancelled) {
292+ // Server-initiated cancellation for a pending request id
293+ if (n->params .has_value ()) {
294+ std::string idStr = parseIdFromParams (n->params .value ());
295+ if (!idStr.empty ()) {
296+ this ->cancelById (idStr);
297+ }
298+ }
210299 } else {
211300 this ->invalidateCachesForListChanged (n->method );
212301 auto it = this ->notificationHandlers .find (n->method );
@@ -256,7 +345,13 @@ void Client::Impl::logInvalidCreateMessageResultContext(const JSONValue& result)
256345std::unique_ptr<JSONRPCResponse> Client::Impl::onRequest (const JSONRPCRequest& req) {
257346 try {
258347 if (req.method == Methods::CreateMessage) {
259- if (!this ->samplingHandler ) {
348+ // Register cancellation and stop_source for this request id
349+ const std::string idStr = Impl::idToString (req.id );
350+ auto token = this ->registerCancelToken (idStr);
351+ struct ScopeGuard { std::function<void ()> f; ~ScopeGuard (){ if (f) f (); } } guard{ [this , idStr](){ this ->unregisterCancelToken (idStr); } };
352+ auto src = this ->registerStopSource (idStr);
353+
354+ if (!this ->samplingHandler && !this ->samplingHandlerCancelable ) {
260355 errors::McpError e; e.code = JSONRPCErrorCodes::MethodNotAllowed; e.message = " No sampling handler registered" ;
261356 return errors::makeErrorResponse (req.id , e);
262357 }
@@ -276,8 +371,15 @@ std::unique_ptr<JSONRPCResponse> Client::Impl::onRequest(const JSONRPCRequest& r
276371 return errors::makeErrorResponse (req.id , e);
277372 }
278373 }
279- auto fut = this ->samplingHandler (messages, modelPreferences, systemPrompt, includeContext);
374+ std::future<JSONValue> fut = this ->samplingHandler
375+ ? this ->samplingHandler (messages, modelPreferences, systemPrompt, includeContext)
376+ : this ->samplingHandlerCancelable (messages, modelPreferences, systemPrompt, includeContext, src->get_token ());
280377 JSONValue result = fut.get ();
378+ // If cancelled while or after handler ran, return Cancelled
379+ if (token && token->cancelled .load ()) {
380+ errors::McpError e; e.code = JSONRPCErrorCodes::InternalError; e.message = " Cancelled" ;
381+ return errors::makeErrorResponse (req.id , e);
382+ }
281383 if (this ->validationMode == validation::ValidationMode::Strict) {
282384 if (!validation::validateCreateMessageResultJson (result)) {
283385 this ->logInvalidCreateMessageResultContext (result);
@@ -1035,7 +1137,7 @@ mcp::async::Task<JSONValue> Client::Impl::coGetPrompt(const std::string& name, c
10351137
10361138
10371139Client::Client (const Implementation& clientInfo)
1038- : pImpl(std::make_unique <Impl>(clientInfo)) {
1140+ : pImpl(std::unique_ptr <Impl>(new Impl( clientInfo) )) {
10391141 FUNC_SCOPE ();
10401142}
10411143
@@ -1176,6 +1278,11 @@ void Client::SetSamplingHandler(SamplingHandler handler) {
11761278 pImpl->samplingHandler = std::move (handler);
11771279}
11781280
1281+ void Client::SetSamplingHandlerCancelable (SamplingHandlerCancelable handler) {
1282+ FUNC_SCOPE ();
1283+ pImpl->samplingHandlerCancelable = std::move (handler);
1284+ }
1285+
11791286void Client::SetErrorHandler (ErrorHandler handler) {
11801287 FUNC_SCOPE ();
11811288 pImpl->errorHandler = std::move (handler);
0 commit comments