Skip to content

Commit cebcac2

Browse files
committed
added shared memrory transport to match other transports design patterns
1 parent 0cf145c commit cebcac2

9 files changed

Lines changed: 1515 additions & 476 deletions

File tree

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ add_library(mcp_cpp STATIC
4444
src/mcp/JSONParser.cpp
4545
src/mcp/StdioTransport.cpp
4646
src/mcp/InMemoryTransport.cpp
47+
src/mcp/SharedMemoryTransport.cpp
4748
src/mcp/HTTPTransport.cpp
4849
src/mcp/HTTPServer.cpp
4950
src/mcp/Client.cpp
@@ -81,7 +82,8 @@ target_include_directories(mcp_cpp PUBLIC ${BOOST_INCLUDE_DIR})
8182
target_compile_definitions(mcp_cpp PUBLIC
8283
BOOST_ERROR_CODE_HEADER_ONLY
8384
BOOST_SYSTEM_NO_DEPRECATED
84-
BOOST_ASIO_NO_DEPRECATED)
85+
BOOST_ASIO_NO_DEPRECATED
86+
BOOST_DATE_TIME_NO_LIB)
8587

8688
target_link_libraries(mcp_cpp PUBLIC OpenSSL::SSL OpenSSL::Crypto)
8789

examples/mcp_server/main.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ int main() {
4747

4848
// Prepare waitable completion and handler (fires on EOF or transport error)
4949
std::promise<void> stopped;
50-
server->SetErrorHandler([&stopped](const std::string& err) {
50+
server->SetErrorHandler([&stopped, &server](const std::string& err) {
5151
LOG_INFO("Server stopping: {}", err);
52+
// Proactively stop the server/transport so the demo process exits promptly
53+
try { (void)server->Stop().get(); } catch (...) {}
5254
try { stopped.set_value(); } catch (...) {}
5355
});
5456

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
//==========================================================================================================
2+
// SPDX-License-Identifier: MIT
3+
// Copyright (c) 2025 Vinny Parla
4+
// File: SharedMemoryTransport.hpp
5+
// Purpose: Cross-process shared-memory JSON-RPC transport built on Boost.Interprocess
6+
//==========================================================================================================
7+
8+
#pragma once
9+
10+
#include <future>
11+
#include <memory>
12+
#include <string>
13+
#include "mcp/Transport.h"
14+
15+
namespace mcp {
16+
17+
class SharedMemoryTransport : public ITransport {
18+
public:
19+
//==========================================================================================================
20+
// Options
21+
// Purpose: Configuration for channel identity and queue sizing.
22+
// Fields:
23+
// channelName: Base channel name. Two queues are derived: "<channelName>_c2s" and "<channelName>_s2c".
24+
// create: When true, create queues (server-side). When false, open existing queues (client-side).
25+
// maxMessageSize: Maximum size per message in bytes when creating queues (default: 1 MiB).
26+
// maxMessageCount: Maximum number of messages buffered per queue when creating (default: 256).
27+
//==========================================================================================================
28+
struct Options {
29+
std::string channelName;
30+
bool create{false};
31+
std::size_t maxMessageSize{1024ull * 1024ull};
32+
unsigned int maxMessageCount{256u};
33+
};
34+
35+
explicit SharedMemoryTransport(const Options& opts);
36+
~SharedMemoryTransport() override;
37+
38+
////////////////////////////////////////// ITransport //////////////////////////////////////////
39+
//==========================================================================================================
40+
// Starts the transport I/O loop. Returns when worker is ready to accept requests.
41+
//==========================================================================================================
42+
std::future<void> Start() override;
43+
44+
//==========================================================================================================
45+
// Closes the transport and stops the I/O loop.
46+
//==========================================================================================================
47+
std::future<void> Close() override;
48+
49+
//==========================================================================================================
50+
// Indicates whether the transport session is currently connected.
51+
//==========================================================================================================
52+
bool IsConnected() const override;
53+
54+
//==========================================================================================================
55+
// Returns a transport session identifier for diagnostics.
56+
//==========================================================================================================
57+
std::string GetSessionId() const override;
58+
59+
//==========================================================================================================
60+
// Sends a JSON-RPC request and returns a future for the response.
61+
//==========================================================================================================
62+
std::future<std::unique_ptr<JSONRPCResponse>> SendRequest(
63+
std::unique_ptr<JSONRPCRequest> request) override;
64+
65+
//==========================================================================================================
66+
// Sends a JSON-RPC notification (no response expected).
67+
//==========================================================================================================
68+
std::future<void> SendNotification(
69+
std::unique_ptr<JSONRPCNotification> notification) override;
70+
71+
//==========================================================================================================
72+
// Registers handlers for incoming notifications and errors. RequestHandler is used when this
73+
// transport instance is wired on the server side.
74+
//==========================================================================================================
75+
void SetNotificationHandler(NotificationHandler handler) override;
76+
void SetRequestHandler(RequestHandler handler) override;
77+
void SetErrorHandler(ErrorHandler handler) override;
78+
79+
private:
80+
class Impl;
81+
std::unique_ptr<Impl> pImpl;
82+
};
83+
84+
//==========================================================================================================
85+
// SharedMemoryTransportFactory
86+
// Purpose: Factory for creating shared-memory transports from a configuration string.
87+
// Supported formats:
88+
// - "shm://<channelName>?create=true&maxSize=<bytes>&maxCount=<n>"
89+
// - "<channelName>" (defaults to create=false)
90+
// Unknown parameters are ignored.
91+
//==========================================================================================================
92+
class SharedMemoryTransportFactory : public ITransportFactory {
93+
public:
94+
std::unique_ptr<ITransport> CreateTransport(const std::string& config) override;
95+
};
96+
97+
} // namespace mcp

src/mcp/Server.cpp

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <limits>
1515
#include <stop_token>
1616
#include <thread>
17+
#include <condition_variable>
1718
#include <unordered_map>
1819
#include <unordered_set>
1920

@@ -81,6 +82,8 @@ class Server::Impl {
8182
std::atomic<unsigned int> keepaliveFailureThreshold{3u};
8283
std::atomic<bool> keepaliveSending{false};
8384
std::atomic<bool> keepaliveSendFailed{false};
85+
std::mutex keepaliveMutex;
86+
std::condition_variable cvKeepalive;
8487

8588
// Client logging preference (minimum severity)
8689
std::atomic<Logger::Level> clientLogMin{Logger::Level::DEBUG};
@@ -890,10 +893,8 @@ mcp::async::Task<void> Server::Impl::coStart(std::unique_ptr<ITransport> transpo
890893

891894
this->transport->SetErrorHandler([this](const std::string& err){
892895
LOG_ERROR("Transport error: {}", err);
893-
// If an error occurs during a keepalive send, mark last send as failed
894-
if (this->keepaliveSending.load()) {
895-
this->keepaliveSendFailed.store(true);
896-
}
896+
// Treat any transport error as a keepalive send failure to avoid races
897+
this->keepaliveSendFailed.store(true);
897898
if (this->errorCallback) {
898899
try {
899900
this->errorCallback(err);
@@ -982,34 +983,33 @@ std::unique_ptr<JSONRPCResponse> Server::Impl::dispatchRequest(const JSONRPCRequ
982983
}
983984

984985
void Server::Impl::sendInitializedAndListChangedAsync() {
985-
std::thread([this]() {
986-
try {
987-
auto note = std::make_unique<JSONRPCNotification>();
988-
note->method = Methods::Initialized;
989-
note->params = JSONValue{JSONValue::Object{}};
990-
(void)this->transport->SendNotification(std::move(note));
991-
{
992-
auto n = std::make_unique<JSONRPCNotification>();
993-
n->method = Methods::ToolListChanged;
994-
n->params = JSONValue{JSONValue::Object{}};
995-
(void)this->transport->SendNotification(std::move(n));
996-
}
997-
{
998-
auto n = std::make_unique<JSONRPCNotification>();
999-
n->method = Methods::ResourceListChanged;
1000-
n->params = JSONValue{JSONValue::Object{}};
1001-
(void)this->transport->SendNotification(std::move(n));
1002-
}
1003-
{
1004-
auto n = std::make_unique<JSONRPCNotification>();
1005-
n->method = Methods::PromptListChanged;
1006-
n->params = JSONValue{JSONValue::Object{}};
1007-
(void)this->transport->SendNotification(std::move(n));
1008-
}
1009-
} catch (const std::exception& e) {
1010-
LOG_ERROR("Initialized/list_changed async send exception: {}", e.what());
986+
// Send synchronously to avoid lifetime races in tests/demos
987+
try {
988+
auto note = std::make_unique<JSONRPCNotification>();
989+
note->method = Methods::Initialized;
990+
note->params = JSONValue{JSONValue::Object{}};
991+
(void)this->transport->SendNotification(std::move(note));
992+
{
993+
auto n = std::make_unique<JSONRPCNotification>();
994+
n->method = Methods::ToolListChanged;
995+
n->params = JSONValue{JSONValue::Object{}};
996+
(void)this->transport->SendNotification(std::move(n));
997+
}
998+
{
999+
auto n = std::make_unique<JSONRPCNotification>();
1000+
n->method = Methods::ResourceListChanged;
1001+
n->params = JSONValue{JSONValue::Object{}};
1002+
(void)this->transport->SendNotification(std::move(n));
1003+
}
1004+
{
1005+
auto n = std::make_unique<JSONRPCNotification>();
1006+
n->method = Methods::PromptListChanged;
1007+
n->params = JSONValue{JSONValue::Object{}};
1008+
(void)this->transport->SendNotification(std::move(n));
10111009
}
1012-
}).detach();
1010+
} catch (const std::exception& e) {
1011+
LOG_ERROR("Initialized/list_changed send exception: {}", e.what());
1012+
}
10131013
}
10141014

10151015
// Common helpers for paging and strict validation (used by list handlers)
@@ -1168,6 +1168,7 @@ mcp::async::Task<void> Server::Impl::coStop() {
11681168
LOG_INFO("Stopping MCP server");
11691169
if (this->keepaliveThread.joinable()) {
11701170
this->keepaliveStop.store(true);
1171+
this->cvKeepalive.notify_all();
11711172
try {
11721173
this->keepaliveThread.join();
11731174
} catch (...) {
@@ -1512,6 +1513,7 @@ Server::~Server() {
15121513
// Ensure keepalive thread is stopped
15131514
if (pImpl && pImpl->keepaliveThread.joinable()) {
15141515
pImpl->keepaliveStop.store(true);
1516+
pImpl->cvKeepalive.notify_all();
15151517
try {
15161518
pImpl->keepaliveThread.join();
15171519
} catch (...) {
@@ -1766,10 +1768,11 @@ void Server::SetKeepaliveIntervalMs(const std::optional<int>& intervalMs) {
17661768
pImpl->capabilities.experimental.erase("keepalive");
17671769
if (pImpl->keepaliveThread.joinable()) {
17681770
pImpl->keepaliveStop.store(true);
1771+
// Wake the keepalive thread if it's waiting
1772+
pImpl->cvKeepalive.notify_all();
17691773
try {
17701774
pImpl->keepaliveThread.join();
1771-
} catch (...) {
1772-
}
1775+
} catch (...) {}
17731776
pImpl->keepaliveStop.store(false);
17741777
}
17751778
return;
@@ -1783,20 +1786,29 @@ void Server::SetKeepaliveIntervalMs(const std::optional<int>& intervalMs) {
17831786
pImpl->capabilities.experimental["keepalive"] = JSONValue{kv};
17841787

17851788
pImpl->keepaliveIntervalMs.store(ms);
1789+
// Wake the keepalive thread to re-evaluate the interval immediately
1790+
pImpl->cvKeepalive.notify_all();
17861791

17871792
// Start background loop if not running
17881793
if (!pImpl->keepaliveThread.joinable()) {
17891794
pImpl->keepaliveStop.store(false);
17901795
pImpl->keepaliveThread = std::thread([this]() {
17911796
while (!pImpl->keepaliveStop.load()) {
17921797
int delay = pImpl->keepaliveIntervalMs.load();
1793-
if (delay <= 0) {
1794-
break;
1795-
}
1796-
std::this_thread::sleep_for(std::chrono::milliseconds(delay));
1797-
if (pImpl->keepaliveStop.load()) {
1798-
break;
1798+
if (delay <= 0) { break; }
1799+
{
1800+
std::unique_lock<std::mutex> lk(pImpl->keepaliveMutex);
1801+
// Wake early if stop requested or interval changes
1802+
bool pred = pImpl->cvKeepalive.wait_for(
1803+
lk,
1804+
std::chrono::milliseconds(delay),
1805+
[this, delay]() { return pImpl->keepaliveStop.load() || pImpl->keepaliveIntervalMs.load() != delay; }
1806+
);
1807+
(void)pred; // if pred true and interval changed, loop re-evaluates delay; if stop, loop breaks below
17991808
}
1809+
if (pImpl->keepaliveStop.load()) { break; }
1810+
// Re-check interval and connection after potential interval change
1811+
if (pImpl->keepaliveIntervalMs.load() <= 0) { break; }
18001812
if (!pImpl->transport || !pImpl->transport->IsConnected()) {
18011813
continue;
18021814
}
@@ -1807,8 +1819,13 @@ void Server::SetKeepaliveIntervalMs(const std::optional<int>& intervalMs) {
18071819
n->method = Methods::Keepalive;
18081820
n->params = JSONValue{JSONValue::Object{}};
18091821
(void)pImpl->transport->SendNotification(std::move(n));
1810-
} catch (...) {
1822+
} catch (const std::exception& e) {
18111823
pImpl->keepaliveSendFailed.store(true);
1824+
if (pImpl->errorCallback) {
1825+
try {
1826+
pImpl->errorCallback("Keepalive failure: " + std::string(e.what()));
1827+
} catch (...) {}
1828+
}
18121829
}
18131830
pImpl->keepaliveSending.store(false);
18141831
if (pImpl->keepaliveSendFailed.load()) {
@@ -1819,21 +1836,15 @@ void Server::SetKeepaliveIntervalMs(const std::optional<int>& intervalMs) {
18191836
}
18201837
pImpl->keepaliveConsecutiveFailures.store(next);
18211838
if (next >= pImpl->keepaliveFailureThreshold.load()) {
1822-
LOG_ERROR("Keepalive failure threshold reached ({}); closing transport", next);
1823-
try {
1824-
(void)pImpl->transport->Close();
1825-
} catch (...) {
1826-
}
1839+
LOG_ERROR("Keepalive failure threshold reached ({})", next);
1840+
try { (void)pImpl->transport->Close(); } catch (...) {}
18271841
if (pImpl->errorCallback) {
1828-
try {
1829-
pImpl->errorCallback("Keepalive failure threshold reached; closing transport");
1830-
}
1831-
catch (...) {
1832-
}
1842+
try { pImpl->errorCallback("Keepalive failure threshold reached; closing transport"); } catch (...) {}
18331843
}
18341844
break;
18351845
}
18361846
} else {
1847+
// Success path: reset consecutive failures
18371848
pImpl->keepaliveConsecutiveFailures.store(0u);
18381849
}
18391850
}

0 commit comments

Comments
 (0)