Skip to content

Commit f6adb1a

Browse files
committed
ServerSampling.AdvertisesSamplingWithHandler|Cancellation.CooperativeReadResourceStopsEarly|SamplingCancellationE2E
1 parent 7db697c commit f6adb1a

8 files changed

Lines changed: 309 additions & 6 deletions

File tree

docs/api/client.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ This page summarizes the public client APIs with signatures and brief descriptio
7676
const JSONValue& includeContext)>
7777
- void SetSamplingHandler(SamplingHandler handler)
7878
- Register a handler for server-initiated sampling/createMessage.
79+
- using SamplingHandlerCancelable = std::function<std::future<JSONValue>(
80+
const JSONValue& messages,
81+
const JSONValue& modelPreferences,
82+
const JSONValue& systemPrompt,
83+
const JSONValue& includeContext,
84+
std::stop_token st)>
85+
- void SetSamplingHandlerCancelable(SamplingHandlerCancelable handler)
86+
- Register a cancelable variant that receives a `std::stop_token`. When set, the client uses this handler for `sampling/createMessage` and will request stop when it receives `notifications/cancelled` targeting the in-flight request id.
7987

8088
## Notifications, progress, and errors
8189
- using NotificationHandler = std::function<void(const std::string& method, const JSONValue& params)>;

docs/api/server.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@ Tests:
182182
- Register a server-side handler for `sampling/createMessage`.
183183
- std::future<JSONValue> RequestCreateMessage(const CreateMessageParams& params)
184184
- Request the client to create a message (server-initiated sampling path).
185+
- std::future<JSONValue> RequestCreateMessageWithId(const CreateMessageParams& params, const std::string& requestId)
186+
- Same as above, but allows the server to specify the JSON-RPC request id. This is useful for end-to-end cancellation tests or workflows where the server needs to target a specific in-flight request with a `notifications/cancelled` message.
187+
188+
Cancellation (server-initiated):
189+
190+
- The server can request cancellation of an in-flight client request by sending:
191+
192+
```cpp
193+
JSONValue::Object cancelParams; cancelParams["id"] = std::make_shared<JSONValue>(std::string("<request-id>"));
194+
server.SendNotification(Methods::Cancelled, JSONValue{cancelParams}).get();
195+
```
196+
197+
- When the client observes this notification, it will propagate `std::stop_token` to a cancelable sampling handler (if registered) and ultimately respond with an error shaped as `{ code: -32603, message: "Cancelled" }`.
198+
199+
Tests: see `tests/test_sampling_cancellation_e2e.cpp`.
185200
186201
## Keepalive / Heartbeat
187202
- void SetKeepaliveIntervalMs(const std::optional<int>& intervalMs)

include/mcp/Client.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <vector>
1717
#include <future>
1818
#include <functional>
19+
#include <stop_token>
1920

2021
namespace mcp {
2122

@@ -249,6 +250,17 @@ class IClient {
249250
const JSONValue& includeContext)>;
250251
virtual void SetSamplingHandler(SamplingHandler handler) = 0;
251252

253+
// Optional cancelable sampling handler variant that receives a std::stop_token.
254+
// When set, this handler will be used instead of the non-cancelable variant for
255+
// servicing server-initiated sampling/createMessage requests.
256+
using SamplingHandlerCancelable = std::function<std::future<JSONValue>(
257+
const JSONValue& messages,
258+
const JSONValue& modelPreferences,
259+
const JSONValue& systemPrompt,
260+
const JSONValue& includeContext,
261+
std::stop_token st)>;
262+
virtual void SetSamplingHandlerCancelable(SamplingHandlerCancelable handler) = 0;
263+
252264
////////////////////////////////////////// Notification handlers ///////////////////////////////////////////
253265
//==========================================================================================================
254266
// Registers a notification handler for a specific method name.
@@ -368,6 +380,7 @@ class Client : public IClient {
368380
const JSONValue& arguments) override;
369381

370382
void SetSamplingHandler(SamplingHandler handler) override;
383+
void SetSamplingHandlerCancelable(SamplingHandlerCancelable handler) override;
371384

372385
void SetNotificationHandler(const std::string& method, NotificationHandler handler) override;
373386
void RemoveNotificationHandler(const std::string& method) override;

include/mcp/Server.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,18 @@ class IServer {
323323
//==========================================================================================================
324324
virtual std::future<JSONValue> RequestCreateMessage(const CreateMessageParams& params) = 0;
325325

326+
//==========================================================================================================
327+
// Requests the client to create a message (server-initiated sampling) with a caller-provided request id.
328+
// Useful for tests and cancellation E2E scenarios where the server needs to cancel an in-flight request by id.
329+
// Args:
330+
// params: CreateMessage parameters.
331+
// requestId: Caller-provided id to assign to the JSON-RPC request.
332+
// Returns:
333+
// Future with raw JSON result from the client's handler.
334+
//==========================================================================================================
335+
virtual std::future<JSONValue> RequestCreateMessageWithId(const CreateMessageParams& params,
336+
const std::string& requestId) = 0;
337+
326338
/////////////////////////////////////////// Notification sending ///////////////////////////////////////////
327339
//==========================================================================================================
328340
// Sends an arbitrary JSON-RPC notification (method + params) to the client.
@@ -504,6 +516,8 @@ class Server : public IServer {
504516

505517
// Server-initiated sampling (request client to create a message)
506518
std::future<JSONValue> RequestCreateMessage(const CreateMessageParams& params) override;
519+
std::future<JSONValue> RequestCreateMessageWithId(const CreateMessageParams& params,
520+
const std::string& requestId) override;
507521

508522
// IServer message sending
509523
std::future<void> SendNotification(const std::string& method, const JSONValue& params) override;

src/mcp/Client.cpp

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <atomic>
88
#include <chrono>
99
#include <mutex>
10+
#include <algorithm>
1011
#include <optional>
1112
#include <stdexcept>
1213
#include <thread>
@@ -26,7 +27,7 @@ namespace mcp {
2627

2728
// Client implementation
2829
class Client::Impl {
29-
public:
30+
private:
3031
friend class Client; // Allow outer Client to invoke private coroutine helpers
3132
std::unique_ptr<ITransport> transport;
3233
ClientCapabilities capabilities;
@@ -37,6 +38,7 @@ class Client::Impl {
3738
IClient::ProgressHandler progressHandler;
3839
IClient::ErrorHandler errorHandler;
3940
IClient::SamplingHandler samplingHandler;
41+
IClient::SamplingHandlerCancelable samplingHandlerCancelable;
4042
validation::ValidationMode validationMode{validation::ValidationMode::Off};
4143

4244
// Listings cache (optional)
@@ -47,6 +49,85 @@ class Client::Impl {
4749
struct TemplatesCache { std::vector<ResourceTemplate> data; std::chrono::steady_clock::time_point ts; bool set{false}; } templatesCache;
4850
struct PromptsCache { std::vector<Prompt> data; std::chrono::steady_clock::time_point ts; bool set{false}; } promptsCache;
4951

52+
// Cancellation support for server->client requests (e.g., sampling/createMessage)
53+
struct CancellationToken { std::atomic<bool> cancelled{false}; };
54+
std::mutex cancelMutex;
55+
std::unordered_map<std::string, std::shared_ptr<CancellationToken>> cancelMap;
56+
std::unordered_map<std::string, std::vector<std::shared_ptr<std::stop_source>>> stopSources;
57+
58+
static std::string idToString(const JSONRPCId& id) {
59+
std::string idStr;
60+
std::visit([&](const auto& v){
61+
using T = std::decay_t<decltype(v)>;
62+
if constexpr (std::is_same_v<T, std::string>) { idStr = v; }
63+
else if constexpr (std::is_same_v<T, int64_t>) { idStr = std::to_string(v); }
64+
else { idStr = ""; }
65+
}, id);
66+
return idStr;
67+
}
68+
69+
static std::string parseIdFromParams(const JSONValue& params) {
70+
std::string idStr;
71+
if (std::holds_alternative<JSONValue::Object>(params.value)) {
72+
const auto& o = std::get<JSONValue::Object>(params.value);
73+
auto it = o.find("id");
74+
if (it != o.end() && it->second) {
75+
if (std::holds_alternative<std::string>(it->second->value)) idStr = std::get<std::string>(it->second->value);
76+
else if (std::holds_alternative<int64_t>(it->second->value)) idStr = std::to_string(std::get<int64_t>(it->second->value));
77+
}
78+
}
79+
return idStr;
80+
}
81+
82+
std::shared_ptr<CancellationToken> registerCancelToken(const std::string& idStr) {
83+
if (idStr.empty()) return std::make_shared<CancellationToken>();
84+
std::lock_guard<std::mutex> lk(cancelMutex);
85+
auto it = cancelMap.find(idStr);
86+
if (it != cancelMap.end()) return it->second;
87+
auto tok = std::make_shared<CancellationToken>();
88+
cancelMap[idStr] = tok;
89+
return tok;
90+
}
91+
void unregisterCancelToken(const std::string& idStr) {
92+
if (idStr.empty()) return;
93+
std::lock_guard<std::mutex> lk(cancelMutex);
94+
cancelMap.erase(idStr);
95+
stopSources.erase(idStr);
96+
}
97+
void cancelById(const std::string& idStr) {
98+
std::lock_guard<std::mutex> lk(cancelMutex);
99+
auto it = cancelMap.find(idStr);
100+
if (it == cancelMap.end() || !it->second) {
101+
auto tok = std::make_shared<CancellationToken>();
102+
tok->cancelled.store(true);
103+
cancelMap[idStr] = tok;
104+
} else {
105+
it->second->cancelled.store(true);
106+
}
107+
auto itS = stopSources.find(idStr);
108+
if (itS != stopSources.end()) {
109+
for (auto& src : itS->second) { if (src) { try { src->request_stop(); } catch (...) {} } }
110+
}
111+
}
112+
std::shared_ptr<std::stop_source> registerStopSource(const std::string& idStr) {
113+
auto src = std::make_shared<std::stop_source>();
114+
std::lock_guard<std::mutex> lk(cancelMutex);
115+
stopSources[idStr].push_back(src);
116+
auto it = cancelMap.find(idStr);
117+
if (it != cancelMap.end() && it->second && it->second->cancelled.load()) {
118+
try { src->request_stop(); } catch (...) {}
119+
}
120+
return src;
121+
}
122+
void unregisterStopSource(const std::string& idStr, const std::shared_ptr<std::stop_source>& src) {
123+
std::lock_guard<std::mutex> lk(cancelMutex);
124+
auto it = stopSources.find(idStr);
125+
if (it == stopSources.end()) return;
126+
auto& vec = it->second;
127+
vec.erase(std::remove_if(vec.begin(), vec.end(), [&](const std::shared_ptr<std::stop_source>& p){ return p.get() == src.get(); }), vec.end());
128+
if (vec.empty()) stopSources.erase(it);
129+
}
130+
50131
explicit Impl(const Implementation& info)
51132
: clientInfo(info) {
52133
// Set default client capabilities
@@ -207,6 +288,14 @@ void Client::Impl::onNotification(std::unique_ptr<JSONRPCNotification> n) {
207288
const auto& o = std::get<JSONValue::Object>(n->params->value);
208289
this->handleProgressNotification(o);
209290
}
291+
} else if (n->method == Methods::Cancelled) {
292+
// Server-initiated cancellation for a pending request id
293+
if (n->params.has_value()) {
294+
std::string idStr = parseIdFromParams(n->params.value());
295+
if (!idStr.empty()) {
296+
this->cancelById(idStr);
297+
}
298+
}
210299
} else {
211300
this->invalidateCachesForListChanged(n->method);
212301
auto it = this->notificationHandlers.find(n->method);
@@ -256,7 +345,13 @@ void Client::Impl::logInvalidCreateMessageResultContext(const JSONValue& result)
256345
std::unique_ptr<JSONRPCResponse> Client::Impl::onRequest(const JSONRPCRequest& req) {
257346
try {
258347
if (req.method == Methods::CreateMessage) {
259-
if (!this->samplingHandler) {
348+
// Register cancellation and stop_source for this request id
349+
const std::string idStr = Impl::idToString(req.id);
350+
auto token = this->registerCancelToken(idStr);
351+
struct ScopeGuard { std::function<void()> f; ~ScopeGuard(){ if (f) f(); } } guard{ [this, idStr](){ this->unregisterCancelToken(idStr); } };
352+
auto src = this->registerStopSource(idStr);
353+
354+
if (!this->samplingHandler && !this->samplingHandlerCancelable) {
260355
errors::McpError e; e.code = JSONRPCErrorCodes::MethodNotAllowed; e.message = "No sampling handler registered";
261356
return errors::makeErrorResponse(req.id, e);
262357
}
@@ -276,8 +371,15 @@ std::unique_ptr<JSONRPCResponse> Client::Impl::onRequest(const JSONRPCRequest& r
276371
return errors::makeErrorResponse(req.id, e);
277372
}
278373
}
279-
auto fut = this->samplingHandler(messages, modelPreferences, systemPrompt, includeContext);
374+
std::future<JSONValue> fut = this->samplingHandler
375+
? this->samplingHandler(messages, modelPreferences, systemPrompt, includeContext)
376+
: this->samplingHandlerCancelable(messages, modelPreferences, systemPrompt, includeContext, src->get_token());
280377
JSONValue result = fut.get();
378+
// If cancelled while or after handler ran, return Cancelled
379+
if (token && token->cancelled.load()) {
380+
errors::McpError e; e.code = JSONRPCErrorCodes::InternalError; e.message = "Cancelled";
381+
return errors::makeErrorResponse(req.id, e);
382+
}
281383
if (this->validationMode == validation::ValidationMode::Strict) {
282384
if (!validation::validateCreateMessageResultJson(result)) {
283385
this->logInvalidCreateMessageResultContext(result);
@@ -1035,7 +1137,7 @@ mcp::async::Task<JSONValue> Client::Impl::coGetPrompt(const std::string& name, c
10351137

10361138

10371139
Client::Client(const Implementation& clientInfo)
1038-
: pImpl(std::make_unique<Impl>(clientInfo)) {
1140+
: pImpl(std::unique_ptr<Impl>(new Impl(clientInfo))) {
10391141
FUNC_SCOPE();
10401142
}
10411143

@@ -1176,6 +1278,11 @@ void Client::SetSamplingHandler(SamplingHandler handler) {
11761278
pImpl->samplingHandler = std::move(handler);
11771279
}
11781280

1281+
void Client::SetSamplingHandlerCancelable(SamplingHandlerCancelable handler) {
1282+
FUNC_SCOPE();
1283+
pImpl->samplingHandlerCancelable = std::move(handler);
1284+
}
1285+
11791286
void Client::SetErrorHandler(ErrorHandler handler) {
11801287
FUNC_SCOPE();
11811288
pImpl->errorHandler = std::move(handler);

src/mcp/Server.cpp

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace mcp {
3030

3131
// Server implementation
3232
class Server::Impl {
33-
public:
33+
private:
3434
friend class Server;
3535
// Coroutine helpers (declarations)
3636
mcp::async::Task<void> coStart(std::unique_ptr<ITransport> transport);
@@ -40,6 +40,7 @@ class Server::Impl {
4040
mcp::async::Task<JSONValue> coReadResource(const std::string& uri);
4141
mcp::async::Task<JSONValue> coGetPrompt(const std::string& name, const JSONValue& arguments);
4242
mcp::async::Task<JSONValue> coRequestCreateMessage(const CreateMessageParams& params);
43+
mcp::async::Task<JSONValue> coRequestCreateMessageWithId(const CreateMessageParams& params, const std::string& requestId);
4344
mcp::async::Task<void> coSendNotification(const std::string& method, const JSONValue& params);
4445
mcp::async::Task<void> coNotifyResourcesListChanged();
4546
mcp::async::Task<void> coNotifyToolsListChanged();
@@ -1277,8 +1278,64 @@ mcp::async::Task<JSONValue> Server::Impl::coRequestCreateMessage(const CreateMes
12771278
co_return JSONValue{};
12781279
}
12791280

1281+
mcp::async::Task<JSONValue> Server::Impl::coRequestCreateMessageWithId(const CreateMessageParams& params, const std::string& requestId) {
1282+
FUNC_SCOPE();
1283+
if (!this->transport) {
1284+
LOG_ERROR("RequestCreateMessageWithId called without transport");
1285+
co_return JSONValue{};
1286+
}
1287+
auto request = std::make_unique<JSONRPCRequest>();
1288+
request->method = Methods::CreateMessage;
1289+
request->id = requestId; // assign caller-provided id so cancellation can target this request
1290+
JSONValue::Object obj;
1291+
// messages (array)
1292+
JSONValue::Array msgs;
1293+
msgs.reserve(params.messages.size());
1294+
for (const auto& m : params.messages) msgs.push_back(std::make_shared<JSONValue>(m));
1295+
obj["messages"] = std::make_shared<JSONValue>(msgs);
1296+
// Optional fields
1297+
if (params.modelPreferences.has_value()) {
1298+
obj["modelPreferences"] = std::make_shared<JSONValue>(params.modelPreferences.value());
1299+
}
1300+
if (params.systemPrompt.has_value()) {
1301+
obj["systemPrompt"] = std::make_shared<JSONValue>(params.systemPrompt.value());
1302+
}
1303+
if (params.includeContext.has_value()) {
1304+
obj["includeContext"] = std::make_shared<JSONValue>(params.includeContext.value());
1305+
}
1306+
if (params.maxTokens.has_value()) {
1307+
obj["maxTokens"] = std::make_shared<JSONValue>(static_cast<int64_t>(params.maxTokens.value()));
1308+
}
1309+
if (params.temperature.has_value()) {
1310+
obj["temperature"] = std::make_shared<JSONValue>(params.temperature.value());
1311+
}
1312+
if (params.stopSequences.has_value()) {
1313+
JSONValue::Array arr;
1314+
for (const auto& s : params.stopSequences.value()) {
1315+
arr.push_back(std::make_shared<JSONValue>(s));
1316+
}
1317+
obj["stopSequences"] = std::make_shared<JSONValue>(arr);
1318+
}
1319+
if (params.metadata.has_value()) {
1320+
obj["metadata"] = std::make_shared<JSONValue>(params.metadata.value());
1321+
}
1322+
request->params = JSONValue{obj};
1323+
1324+
auto fut = this->transport->SendRequest(std::move(request));
1325+
try {
1326+
auto resp = co_await mcp::async::makeFutureAwaitable(std::move(fut));
1327+
if (resp) {
1328+
if (resp->result.has_value()) co_return resp->result.value();
1329+
if (resp->error.has_value()) co_return resp->error.value();
1330+
}
1331+
} catch (const std::exception& e) {
1332+
LOG_ERROR("RequestCreateMessageWithId exception: {}", e.what());
1333+
}
1334+
co_return JSONValue{};
1335+
}
1336+
12801337
Server::Server(const std::string& serverInfo)
1281-
: pImpl(std::make_unique<Impl>()) {
1338+
: pImpl(std::unique_ptr<Impl>(new Impl())) {
12821339
FUNC_SCOPE();
12831340
pImpl->serverInfo = serverInfo;
12841341
}
@@ -1637,6 +1694,11 @@ std::future<JSONValue> Server::RequestCreateMessage(const CreateMessageParams& p
16371694
return pImpl->coRequestCreateMessage(params).toFuture();
16381695
}
16391696

1697+
std::future<JSONValue> Server::RequestCreateMessageWithId(const CreateMessageParams& params, const std::string& requestId) {
1698+
FUNC_SCOPE();
1699+
return pImpl->coRequestCreateMessageWithId(params, requestId).toFuture();
1700+
}
1701+
16401702
void Server::RegisterResourceTemplate(const ResourceTemplate& resourceTemplate) {
16411703
FUNC_SCOPE();
16421704
LOG_DEBUG("Registering resource template: {}", resourceTemplate.uriTemplate);

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ add_executable(mcp_tests
4141
test_initialize_notifications.cpp
4242
test_errors.cpp
4343
test_client_typed.cpp
44+
test_sampling_cancellation_e2e.cpp
4445
test_validation_mode.cpp
4546
test_validation_lists.cpp
4647
test_validation_sampling.cpp

0 commit comments

Comments
 (0)