diff --git a/cpr/CMakeLists.txt b/cpr/CMakeLists.txt index db48358e4..30367f73e 100644 --- a/cpr/CMakeLists.txt +++ b/cpr/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(cpr auth.cpp callback.cpp cert_info.cpp + connection_pool.cpp cookies.cpp cprtypes.cpp curl_container.cpp diff --git a/cpr/connection_pool.cpp b/cpr/connection_pool.cpp new file mode 100644 index 000000000..6afb8643e --- /dev/null +++ b/cpr/connection_pool.cpp @@ -0,0 +1,39 @@ +#include "cpr/connection_pool.h" +#include +#include +#include + +namespace cpr { +ConnectionPool::ConnectionPool() { + CURLSH* curl_share = curl_share_init(); + this->connection_mutex_ = std::make_shared(); + + auto lock_f = +[](CURL* /*handle*/, curl_lock_data /*data*/, curl_lock_access /*access*/, void* userptr) { + std::mutex* lock = static_cast(userptr); + lock->lock(); // cppcheck-suppress localMutex // False positive: mutex is used as callback for libcurl, not local scope + }; + + auto unlock_f = +[](CURL* /*handle*/, curl_lock_data /*data*/, void* userptr) { + std::mutex* lock = static_cast(userptr); + lock->unlock(); + }; + + curl_share_setopt(curl_share, CURLSHOPT_SHARE, CURL_LOCK_DATA_CONNECT); + curl_share_setopt(curl_share, CURLSHOPT_USERDATA, this->connection_mutex_.get()); + curl_share_setopt(curl_share, CURLSHOPT_LOCKFUNC, lock_f); + curl_share_setopt(curl_share, CURLSHOPT_UNLOCKFUNC, unlock_f); + + this->curl_sh_ = std::shared_ptr(curl_share, + [](CURLSH* ptr) { + // Make sure to reset callbacks before cleanup to avoid deadlocks + curl_share_setopt(ptr, CURLSHOPT_LOCKFUNC, nullptr); + curl_share_setopt(ptr, CURLSHOPT_UNLOCKFUNC, nullptr); + curl_share_cleanup(ptr); + }); +} + +void ConnectionPool::SetupHandler(CURL* easy_handler) const { + curl_easy_setopt(easy_handler, CURLOPT_SHARE, this->curl_sh_.get()); +} + +} // namespace cpr \ No newline at end of file diff --git a/cpr/session.cpp b/cpr/session.cpp index e9953f758..59e2b64e6 100644 --- a/cpr/session.cpp +++ b/cpr/session.cpp @@ -31,6 +31,7 @@ #include "cpr/body_view.h" #include "cpr/callback.h" #include "cpr/connect_timeout.h" +#include "cpr/connection_pool.h" #include "cpr/cookies.h" #include "cpr/cprtypes.h" #include "cpr/curlholder.h" @@ -397,6 +398,11 @@ void Session::SetConnectTimeout(const ConnectTimeout& timeout) { curl_easy_setopt(curl_->handle, CURLOPT_CONNECTTIMEOUT_MS, timeout.Milliseconds()); } +void Session::SetConnectionPool(const ConnectionPool& pool) { + CURL* curl = curl_->handle; + pool.SetupHandler(curl); +} + void Session::SetAuth(const Authentication& auth) { // Ignore here since this has been defined by libcurl. switch (auth.GetAuthMode()) { @@ -1091,6 +1097,7 @@ void Session::SetOption(const MultiRange& multi_range) { SetMultiRange(multi_ran void Session::SetOption(const ReserveSize& reserve_size) { SetReserveSize(reserve_size.size); } void Session::SetOption(const AcceptEncoding& accept_encoding) { SetAcceptEncoding(accept_encoding); } void Session::SetOption(AcceptEncoding&& accept_encoding) { SetAcceptEncoding(std::move(accept_encoding)); } +void Session::SetOption(const ConnectionPool& pool) { SetConnectionPool(pool); } // clang-format on void Session::SetCancellationParam(std::shared_ptr param) { diff --git a/include/cpr/connection_pool.h b/include/cpr/connection_pool.h new file mode 100644 index 000000000..b4a758b4f --- /dev/null +++ b/include/cpr/connection_pool.h @@ -0,0 +1,80 @@ +#ifndef CPR_CONNECTION_POOL_H +#define CPR_CONNECTION_POOL_H + +#include +#include +#include + +namespace cpr { +/** + * cpr connection pool implementation for sharing connections between HTTP requests. + * + * The ConnectionPool enables connection reuse across multiple HTTP requests to the same host, + * which can significantly improve performance by avoiding the overhead of establishing new + * connections for each request. It uses libcurl's CURLSH (share) interface to manage + * connection sharing in a thread-safe manner. + * + * Example: + * ```cpp + * // Create a connection pool + * cpr::ConnectionPool pool; + * + * // Use the pool with requests to reuse connections + * cpr::Response r1 = cpr::Get(cpr::Url{"http://example.com/api/data"}, pool); + * cpr::Response r2 = cpr::Get(cpr::Url{"http://example.com/api/more"}, pool); + * + * // Or with async requests + * auto future1 = cpr::GetAsync(cpr::Url{"http://example.com/api/data"}, pool); + * auto future2 = cpr::GetAsync(cpr::Url{"http://example.com/api/more"}, pool); + * ``` + **/ +class ConnectionPool { + public: + /** + * Creates a new connection pool with shared connection state. + * Initializes the underlying CURLSH handle and sets up thread-safe locking mechanisms. + **/ + ConnectionPool(); + + /** + * Copy constructor - creates a new connection pool sharing the same connection state. + * Multiple ConnectionPool instances can share the same underlying connection pool. + **/ + ConnectionPool(const ConnectionPool&) = default; + + /** + * Copy assignment operator is deleted to prevent accidental copying. + * Use the copy constructor if you need to share the connection pool. + **/ + ConnectionPool& operator=(const ConnectionPool&) = delete; + + /** + * Configures a CURL easy handle to use this connection pool. + * This method sets up the easy handle to participate in connection sharing + * managed by this pool. + * + * @param easy_handler The CURL easy handle to configure for connection sharing. + **/ + void SetupHandler(CURL* easy_handler) const; + + private: + /** + * Thread-safe mutex used for synchronizing access to shared connections. + * This mutex is passed to libcurl's locking callbacks to ensure thread safety + * when multiple threads access the same connection pool. It's declared first + * to ensure it's destroyed last, after the CURLSH handle that references it. + **/ + std::shared_ptr connection_mutex_; + + /** + * Shared CURL handle (CURLSH) that manages the actual connection sharing. + * This handle maintains the pool of reusable connections and is configured + * with appropriate locking callbacks for thread safety. The shared_ptr uses + * a custom deleter that safely resets the lock/unlock callbacks before + * calling curl_share_cleanup() to prevent use-after-free issues during destruction. + * Declared last to ensure it's destroyed first, before the mutex it references. + **/ + std::shared_ptr curl_sh_; +}; +} // namespace cpr +#endif \ No newline at end of file diff --git a/include/cpr/cpr.h b/include/cpr/cpr.h index fbad1726a..5c20f7ddb 100644 --- a/include/cpr/cpr.h +++ b/include/cpr/cpr.h @@ -7,6 +7,7 @@ #include "cpr/callback.h" #include "cpr/cert_info.h" #include "cpr/connect_timeout.h" +#include "cpr/connection_pool.h" #include "cpr/cookies.h" #include "cpr/cprtypes.h" #include "cpr/cprver.h" diff --git a/include/cpr/curlholder.h b/include/cpr/curlholder.h index 27130cd60..19ca2ba09 100644 --- a/include/cpr/curlholder.h +++ b/include/cpr/curlholder.h @@ -4,7 +4,6 @@ #include #include #include -#include #include "cpr/secure_string.h" diff --git a/include/cpr/session.h b/include/cpr/session.h index a563efb46..483b7ef2d 100644 --- a/include/cpr/session.h +++ b/include/cpr/session.h @@ -18,6 +18,7 @@ #include "cpr/body_view.h" #include "cpr/callback.h" #include "cpr/connect_timeout.h" +#include "cpr/connection_pool.h" #include "cpr/cookies.h" #include "cpr/cprtypes.h" #include "cpr/curlholder.h" @@ -72,6 +73,7 @@ class Session : public std::enable_shared_from_this { [[nodiscard]] const Header& GetHeader() const; void SetTimeout(const Timeout& timeout); void SetConnectTimeout(const ConnectTimeout& timeout); + void SetConnectionPool(const ConnectionPool& pool); void SetAuth(const Authentication& auth); // Only supported with libcurl >= 7.61.0. // As an alternative use SetHeader and add the token manually. @@ -137,6 +139,7 @@ class Session : public std::enable_shared_from_this { void SetOption(const Timeout& timeout); void SetOption(const ConnectTimeout& timeout); void SetOption(const Authentication& auth); + void SetOption(const ConnectionPool& pool); // Only supported with libcurl >= 7.61.0. // As an alternative use SetHeader and add the token manually. #if LIBCURL_VERSION_NUM >= 0x073D00 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fd5afe170..14c3d0b90 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -70,6 +70,7 @@ add_cpr_test(file_upload) add_cpr_test(singleton) add_cpr_test(threadpool) add_cpr_test(testUtils) +add_cpr_test(connection_pool) if (ENABLE_SSL_TESTS) add_cpr_test(ssl) diff --git a/test/abstractServer.cpp b/test/abstractServer.cpp index e32394b5a..5f868e9da 100644 --- a/test/abstractServer.cpp +++ b/test/abstractServer.cpp @@ -49,6 +49,9 @@ static void EventHandler(mg_connection* conn, int event, void* event_data, void* case MG_EV_HTTP_MSG: { AbstractServer* server = static_cast(context); + // Use the connection address as unique identifier instead + int port = AbstractServer::GetRemotePort(conn); + server->AddConnection(port); server->OnRequest(conn, static_cast(event_data)); } break; @@ -79,6 +82,18 @@ void AbstractServer::Run() { server_stop_cv.notify_all(); } +void AbstractServer::AddConnection(int remote_port) { + unique_connections.insert(remote_port); +} + +size_t AbstractServer::GetConnectionCount() { + return unique_connections.size(); +} + +void AbstractServer::ResetConnectionCount() { + unique_connections.clear(); +} + static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz" diff --git a/test/abstractServer.hpp b/test/abstractServer.hpp index d2daec26d..2365b51af 100644 --- a/test/abstractServer.hpp +++ b/test/abstractServer.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "cpr/cpr.h" #include "mongoose.h" @@ -38,18 +39,26 @@ class AbstractServer : public testing::Environment { void Start(); void Stop(); + size_t GetConnectionCount(); + void ResetConnectionCount(); + void AddConnection(int remote_port); + virtual std::string GetBaseUrl() = 0; virtual uint16_t GetPort() = 0; virtual void acceptConnection(mg_connection* conn) = 0; virtual void OnRequest(mg_connection* conn, mg_http_message* msg) = 0; + static uint16_t GetRemotePort(const mg_connection* conn); + static uint16_t GetLocalPort(const mg_connection* conn); + private: std::shared_ptr serverThread{nullptr}; std::mutex server_mutex; std::condition_variable server_start_cv; std::condition_variable server_stop_cv; std::atomic should_run{false}; + std::set unique_connections; void Run(); @@ -61,9 +70,6 @@ class AbstractServer : public testing::Environment { static std::string Base64Decode(const std::string& in); static void SendError(mg_connection* conn, int code, std::string& reason); static bool IsConnectionActive(mg_mgr* mgr, mg_connection* conn); - - static uint16_t GetRemotePort(const mg_connection* conn); - static uint16_t GetLocalPort(const mg_connection* conn); }; } // namespace cpr diff --git a/test/connection_pool_tests.cpp b/test/connection_pool_tests.cpp new file mode 100644 index 000000000..ba546a5cd --- /dev/null +++ b/test/connection_pool_tests.cpp @@ -0,0 +1,110 @@ +#include + +#include +#include +#include +#include + +#include + +#include "httpServer.hpp" + +using namespace cpr; + +static HttpServer* server = new HttpServer(); +const size_t NUM_REQUESTS = 10; + +TEST(MultipleGetTests, PoolBasicMultipleGetTest) { + Url url{server->GetBaseUrl() + "/hello.html"}; + ConnectionPool pool; + server->ResetConnectionCount(); + + // Without shared connection pool - make 10 sequential requests + for (size_t i = 0; i < NUM_REQUESTS; ++i) { + Response response = cpr::Get(url); + EXPECT_EQ(std::string{"Hello world!"}, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + } + EXPECT_EQ(server->GetConnectionCount(), NUM_REQUESTS); + + // With shared connection pool - make 10 sequential requests + server->ResetConnectionCount(); + for (size_t i = 0; i < NUM_REQUESTS; ++i) { + Response response = cpr::Get(url, pool); + EXPECT_EQ(std::string{"Hello world!"}, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + } + EXPECT_LT(server->GetConnectionCount(), NUM_REQUESTS); +} + +TEST(MultipleGetTests, PoolAsyncGetMultipleTest) { + Url url{server->GetBaseUrl() + "/hello.html"}; + ConnectionPool pool; + std::vector responses; + server->ResetConnectionCount(); + + const size_t NUM_BATCHES = 2; + const size_t BATCH_SIZE = NUM_REQUESTS / 2; // 5 requests per batch + + // Without shared connection pool - two batches with 10ms sleep + responses.reserve(NUM_REQUESTS); + + for (size_t batch = 0; batch < NUM_BATCHES; ++batch) { + for (size_t i = 0; i < BATCH_SIZE; ++i) { + responses.emplace_back(cpr::GetAsync(url)); + } + + // Sleep between batches but not after the last batch + if (batch != NUM_BATCHES - 1) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + } + + // Wait for all responses + for (AsyncResponse& future : responses) { + Response response = future.get(); + EXPECT_EQ(std::string{"Hello world!"}, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + } + EXPECT_EQ(server->GetConnectionCount(), NUM_REQUESTS); + + // With shared connection pool - same two-batch approach + server->ResetConnectionCount(); + responses.clear(); + responses.reserve(NUM_REQUESTS); + + for (size_t batch = 0; batch < NUM_BATCHES; ++batch) { + for (size_t i = 0; i < BATCH_SIZE; ++i) { + responses.emplace_back(cpr::GetAsync(url, pool)); + } + + // Sleep between batches but not after the last batch + if (batch != NUM_BATCHES - 1) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + } + + // Wait for all responses + for (AsyncResponse& future : responses) { + Response response = future.get(); + EXPECT_EQ(std::string{"Hello world!"}, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + } + + // With connection pooling, should use fewer connections than requests + EXPECT_LT(server->GetConnectionCount(), NUM_REQUESTS); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + ::testing::AddGlobalTestEnvironment(server); + return RUN_ALL_TESTS(); +} \ No newline at end of file