Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_library(cpr
auth.cpp
callback.cpp
cert_info.cpp
connection_pool.cpp
cookies.cpp
cprtypes.cpp
curl_container.cpp
Expand Down
39 changes: 39 additions & 0 deletions cpr/connection_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "cpr/connection_pool.h"
#include <curl/curl.h>
#include <memory>
#include <mutex>

namespace cpr {
ConnectionPool::ConnectionPool() {
CURLSH* curl_share = curl_share_init();
this->connection_mutex_ = std::make_shared<std::mutex>();

auto lock_f = +[](CURL* /*handle*/, curl_lock_data /*data*/, curl_lock_access /*access*/, void* userptr) {
std::mutex* lock = static_cast<std::mutex*>(userptr);
lock->lock(); // cppcheck-suppress localMutex // False positive: mutex is used as callback for libcurl, not local scope
};

auto unlock_f = +[](CURL* /*handle*/, curl_lock_data /*data*/, void* userptr) {
std::mutex* lock = static_cast<std::mutex*>(userptr);
lock->unlock();
};

curl_share_setopt(curl_share, CURLSHOPT_SHARE, CURL_LOCK_DATA_CONNECT);
curl_share_setopt(curl_share, CURLSHOPT_USERDATA, this->connection_mutex_.get());
curl_share_setopt(curl_share, CURLSHOPT_LOCKFUNC, lock_f);
curl_share_setopt(curl_share, CURLSHOPT_UNLOCKFUNC, unlock_f);

this->curl_sh_ = std::shared_ptr<CURLSH>(curl_share,
[](CURLSH* ptr) {
// Make sure to reset callbacks before cleanup to avoid deadlocks
curl_share_setopt(ptr, CURLSHOPT_LOCKFUNC, nullptr);
curl_share_setopt(ptr, CURLSHOPT_UNLOCKFUNC, nullptr);
curl_share_cleanup(ptr);
});
}

void ConnectionPool::SetupHandler(CURL* easy_handler) const {
curl_easy_setopt(easy_handler, CURLOPT_SHARE, this->curl_sh_.get());
}

} // namespace cpr
7 changes: 7 additions & 0 deletions cpr/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "cpr/body_view.h"
#include "cpr/callback.h"
#include "cpr/connect_timeout.h"
#include "cpr/connection_pool.h"
#include "cpr/cookies.h"
#include "cpr/cprtypes.h"
#include "cpr/curlholder.h"
Expand Down Expand Up @@ -397,6 +398,11 @@ void Session::SetConnectTimeout(const ConnectTimeout& timeout) {
curl_easy_setopt(curl_->handle, CURLOPT_CONNECTTIMEOUT_MS, timeout.Milliseconds());
}

void Session::SetConnectionPool(const ConnectionPool& pool) {
CURL* curl = curl_->handle;
pool.SetupHandler(curl);
}

void Session::SetAuth(const Authentication& auth) {
// Ignore here since this has been defined by libcurl.
switch (auth.GetAuthMode()) {
Expand Down Expand Up @@ -1091,6 +1097,7 @@ void Session::SetOption(const MultiRange& multi_range) { SetMultiRange(multi_ran
void Session::SetOption(const ReserveSize& reserve_size) { SetReserveSize(reserve_size.size); }
void Session::SetOption(const AcceptEncoding& accept_encoding) { SetAcceptEncoding(accept_encoding); }
void Session::SetOption(AcceptEncoding&& accept_encoding) { SetAcceptEncoding(std::move(accept_encoding)); }
void Session::SetOption(const ConnectionPool& pool) { SetConnectionPool(pool); }
// clang-format on

void Session::SetCancellationParam(std::shared_ptr<std::atomic_bool> param) {
Expand Down
80 changes: 80 additions & 0 deletions include/cpr/connection_pool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#ifndef CPR_CONNECTION_POOL_H
#define CPR_CONNECTION_POOL_H

#include <curl/curl.h>
#include <memory>
#include <mutex>

namespace cpr {
/**
* cpr connection pool implementation for sharing connections between HTTP requests.
*
* The ConnectionPool enables connection reuse across multiple HTTP requests to the same host,
* which can significantly improve performance by avoiding the overhead of establishing new
* connections for each request. It uses libcurl's CURLSH (share) interface to manage
* connection sharing in a thread-safe manner.
*
* Example:
* ```cpp
* // Create a connection pool
* cpr::ConnectionPool pool;
*
* // Use the pool with requests to reuse connections
* cpr::Response r1 = cpr::Get(cpr::Url{"http://example.com/api/data"}, pool);
* cpr::Response r2 = cpr::Get(cpr::Url{"http://example.com/api/more"}, pool);
*
* // Or with async requests
* auto future1 = cpr::GetAsync(cpr::Url{"http://example.com/api/data"}, pool);
* auto future2 = cpr::GetAsync(cpr::Url{"http://example.com/api/more"}, pool);
* ```
**/
class ConnectionPool {
public:
/**
* Creates a new connection pool with shared connection state.
* Initializes the underlying CURLSH handle and sets up thread-safe locking mechanisms.
**/
ConnectionPool();

/**
* Copy constructor - creates a new connection pool sharing the same connection state.
* Multiple ConnectionPool instances can share the same underlying connection pool.
**/
ConnectionPool(const ConnectionPool&) = default;

/**
* Copy assignment operator is deleted to prevent accidental copying.
* Use the copy constructor if you need to share the connection pool.
**/
ConnectionPool& operator=(const ConnectionPool&) = delete;

/**
* Configures a CURL easy handle to use this connection pool.
* This method sets up the easy handle to participate in connection sharing
* managed by this pool.
*
* @param easy_handler The CURL easy handle to configure for connection sharing.
**/
void SetupHandler(CURL* easy_handler) const;

private:
/**
* Thread-safe mutex used for synchronizing access to shared connections.
* This mutex is passed to libcurl's locking callbacks to ensure thread safety
* when multiple threads access the same connection pool. It's declared first
* to ensure it's destroyed last, after the CURLSH handle that references it.
**/
std::shared_ptr<std::mutex> connection_mutex_;

/**
* Shared CURL handle (CURLSH) that manages the actual connection sharing.
* This handle maintains the pool of reusable connections and is configured
* with appropriate locking callbacks for thread safety. The shared_ptr uses
* a custom deleter that safely resets the lock/unlock callbacks before
* calling curl_share_cleanup() to prevent use-after-free issues during destruction.
* Declared last to ensure it's destroyed first, before the mutex it references.
**/
std::shared_ptr<CURLSH> curl_sh_;
};
} // namespace cpr
#endif
1 change: 1 addition & 0 deletions include/cpr/cpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cpr/callback.h"
#include "cpr/cert_info.h"
#include "cpr/connect_timeout.h"
#include "cpr/connection_pool.h"
#include "cpr/cookies.h"
#include "cpr/cprtypes.h"
#include "cpr/cprver.h"
Expand Down
1 change: 0 additions & 1 deletion include/cpr/curlholder.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <array>
#include <curl/curl.h>
#include <mutex>
#include <string>

#include "cpr/secure_string.h"

Expand Down
3 changes: 3 additions & 0 deletions include/cpr/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "cpr/body_view.h"
#include "cpr/callback.h"
#include "cpr/connect_timeout.h"
#include "cpr/connection_pool.h"
#include "cpr/cookies.h"
#include "cpr/cprtypes.h"
#include "cpr/curlholder.h"
Expand Down Expand Up @@ -72,6 +73,7 @@ class Session : public std::enable_shared_from_this<Session> {
[[nodiscard]] const Header& GetHeader() const;
void SetTimeout(const Timeout& timeout);
void SetConnectTimeout(const ConnectTimeout& timeout);
void SetConnectionPool(const ConnectionPool& pool);
void SetAuth(const Authentication& auth);
// Only supported with libcurl >= 7.61.0.
// As an alternative use SetHeader and add the token manually.
Expand Down Expand Up @@ -137,6 +139,7 @@ class Session : public std::enable_shared_from_this<Session> {
void SetOption(const Timeout& timeout);
void SetOption(const ConnectTimeout& timeout);
void SetOption(const Authentication& auth);
void SetOption(const ConnectionPool& pool);
// Only supported with libcurl >= 7.61.0.
// As an alternative use SetHeader and add the token manually.
#if LIBCURL_VERSION_NUM >= 0x073D00
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ add_cpr_test(file_upload)
add_cpr_test(singleton)
add_cpr_test(threadpool)
add_cpr_test(testUtils)
add_cpr_test(connection_pool)

if (ENABLE_SSL_TESTS)
add_cpr_test(ssl)
Expand Down
15 changes: 15 additions & 0 deletions test/abstractServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ static void EventHandler(mg_connection* conn, int event, void* event_data, void*

case MG_EV_HTTP_MSG: {
AbstractServer* server = static_cast<AbstractServer*>(context);
// Use the connection address as unique identifier instead
int port = AbstractServer::GetRemotePort(conn);
server->AddConnection(port);
server->OnRequest(conn, static_cast<mg_http_message*>(event_data));
} break;

Expand Down Expand Up @@ -79,6 +82,18 @@ void AbstractServer::Run() {
server_stop_cv.notify_all();
}

void AbstractServer::AddConnection(int remote_port) {
unique_connections.insert(remote_port);
}

size_t AbstractServer::GetConnectionCount() {
return unique_connections.size();
}

void AbstractServer::ResetConnectionCount() {
unique_connections.clear();
}

static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
Expand Down
12 changes: 9 additions & 3 deletions test/abstractServer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <memory>
#include <mutex>
#include <string>
#include <set>

#include "cpr/cpr.h"
#include "mongoose.h"
Expand Down Expand Up @@ -38,18 +39,26 @@ class AbstractServer : public testing::Environment {
void Start();
void Stop();

size_t GetConnectionCount();
void ResetConnectionCount();
void AddConnection(int remote_port);

virtual std::string GetBaseUrl() = 0;
virtual uint16_t GetPort() = 0;

virtual void acceptConnection(mg_connection* conn) = 0;
virtual void OnRequest(mg_connection* conn, mg_http_message* msg) = 0;

static uint16_t GetRemotePort(const mg_connection* conn);
static uint16_t GetLocalPort(const mg_connection* conn);

private:
std::shared_ptr<std::thread> serverThread{nullptr};
std::mutex server_mutex;
std::condition_variable server_start_cv;
std::condition_variable server_stop_cv;
std::atomic<bool> should_run{false};
std::set<int> unique_connections;

void Run();

Expand All @@ -61,9 +70,6 @@ class AbstractServer : public testing::Environment {
static std::string Base64Decode(const std::string& in);
static void SendError(mg_connection* conn, int code, std::string& reason);
static bool IsConnectionActive(mg_mgr* mgr, mg_connection* conn);

static uint16_t GetRemotePort(const mg_connection* conn);
static uint16_t GetLocalPort(const mg_connection* conn);
};
} // namespace cpr

Expand Down
110 changes: 110 additions & 0 deletions test/connection_pool_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#include <gtest/gtest.h>

#include <string>
#include <vector>
#include <thread>
#include <chrono>

#include <cpr/cpr.h>

#include "httpServer.hpp"

using namespace cpr;

static HttpServer* server = new HttpServer();
const size_t NUM_REQUESTS = 10;

TEST(MultipleGetTests, PoolBasicMultipleGetTest) {
Url url{server->GetBaseUrl() + "/hello.html"};
ConnectionPool pool;
server->ResetConnectionCount();

// Without shared connection pool - make 10 sequential requests
for (size_t i = 0; i < NUM_REQUESTS; ++i) {
Response response = cpr::Get(url);
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_EQ(server->GetConnectionCount(), NUM_REQUESTS);

// With shared connection pool - make 10 sequential requests
server->ResetConnectionCount();
for (size_t i = 0; i < NUM_REQUESTS; ++i) {
Response response = cpr::Get(url, pool);
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_LT(server->GetConnectionCount(), NUM_REQUESTS);
}

TEST(MultipleGetTests, PoolAsyncGetMultipleTest) {
Url url{server->GetBaseUrl() + "/hello.html"};
ConnectionPool pool;
std::vector<AsyncResponse> responses;
server->ResetConnectionCount();

const size_t NUM_BATCHES = 2;
const size_t BATCH_SIZE = NUM_REQUESTS / 2; // 5 requests per batch

// Without shared connection pool - two batches with 10ms sleep
responses.reserve(NUM_REQUESTS);

for (size_t batch = 0; batch < NUM_BATCHES; ++batch) {
for (size_t i = 0; i < BATCH_SIZE; ++i) {
responses.emplace_back(cpr::GetAsync(url));
}

// Sleep between batches but not after the last batch
if (batch != NUM_BATCHES - 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
}

// Wait for all responses
for (AsyncResponse& future : responses) {
Response response = future.get();
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_EQ(server->GetConnectionCount(), NUM_REQUESTS);

// With shared connection pool - same two-batch approach
server->ResetConnectionCount();
responses.clear();
responses.reserve(NUM_REQUESTS);

for (size_t batch = 0; batch < NUM_BATCHES; ++batch) {
for (size_t i = 0; i < BATCH_SIZE; ++i) {
responses.emplace_back(cpr::GetAsync(url, pool));
}

// Sleep between batches but not after the last batch
if (batch != NUM_BATCHES - 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
}

// Wait for all responses
for (AsyncResponse& future : responses) {
Response response = future.get();
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}

// With connection pooling, should use fewer connections than requests
EXPECT_LT(server->GetConnectionCount(), NUM_REQUESTS);
}

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(server);
return RUN_ALL_TESTS();
}
Loading