
Commit 47fda99

Copilot and bbockelm authored
Add per-issuer lock to prevent thundering herd on new issuers (#180)
This commit addresses the thundering herd problem when multiple threads simultaneously try to validate tokens from the same issuer.

Per-issuer locking:
- Added a per-issuer mutex map with shared_ptr ownership
- Threads acquire a lock for an issuer before fetching keys from the web
- Other threads wait on the lock, then find the keys in the cache
- Lock ownership transfers through the async status for proper lifecycle
- Limited to 1000 cached mutexes to prevent resource exhaustion

Negative caching:
- On web fetch failure (e.g., 404), store empty keys in the cache
- Uses the same TTL as successful lookups (get_next_update_delta)
- Subsequent lookups hit the cache and fail fast without web requests
- Prevents repeated web requests for known-bad issuers

SQLite busy timeout:
- Added a 5-second busy timeout to handle concurrent DB access
- Applied to all database operations (init, read, write)

Stress tests:
- StressTestValidToken: 10 threads, 5 seconds, valid token
- StressTestInvalidIssuer: 10 threads, 5 seconds, 404 issuer
- ConcurrentNewIssuerLookup: verifies only ONE web fetch occurs

Verified behavior:
- Valid issuer: ONE key lookup for thousands of validations
- Invalid issuer: ONE web request (OIDC + OAuth fallback), then cached

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Brian P Bockelman <bockelman@gmail.com>
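The stress tests referenced above live in the test suite, which is not shown in the diffs below. As a rough illustration of the pattern the commit message describes (several worker threads validating for a fixed wall-clock duration while a counter verifies how many web fetches occurred), here is a minimal, hypothetical sketch; the helper names validate_token and web_fetch_count are illustrative stand-ins, not the project's test harness:

#include <atomic>
#include <chrono>
#include <string>
#include <thread>
#include <vector>

// Stand-ins for the real validator and for the instrumentation that counts
// key fetches from the web; both are hypothetical in this sketch.
std::atomic<int> web_fetch_count{0};
bool validate_token(const std::string & /*token*/) { return true; }

void stress(const std::string &token, int nthreads, std::chrono::seconds duration) {
    std::vector<std::thread> workers;
    for (int i = 0; i < nthreads; ++i) {
        workers.emplace_back([&] {
            auto deadline = std::chrono::steady_clock::now() + duration;
            while (std::chrono::steady_clock::now() < deadline) {
                validate_token(token); // thousands of validations per thread
            }
        });
    }
    for (auto &w : workers) {
        w.join();
    }
    // With the per-issuer lock in place, web_fetch_count is expected to stay
    // at 1 no matter how many validations ran.
}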
1 parent a569ab5 commit 47fda99

5 files changed: +718 additions, −19 deletions

src/scitokens_cache.cpp

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,10 @@
 
 namespace {
 
+// Timeout in milliseconds to wait when database is locked
+// This handles concurrent access from multiple threads/processes
+constexpr int SQLITE_BUSY_TIMEOUT_MS = 5000;
+
 void initialize_cachedb(const std::string &keycache_file) {
 
     sqlite3 *db;
@@ -27,6 +31,8 @@ void initialize_cachedb(const std::string &keycache_file) {
         sqlite3_close(db);
         return;
     }
+    // Set busy timeout to handle concurrent access
+    sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS);
     char *err_msg = nullptr;
     rc = sqlite3_exec(db,
                       "CREATE TABLE IF NOT EXISTS keycache ("
@@ -161,6 +167,8 @@ bool scitokens::Validator::get_public_keys_from_db(const std::string issuer,
         sqlite3_close(db);
         return false;
     }
+    // Set busy timeout to handle concurrent access
+    sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS);
 
     sqlite3_stmt *stmt;
     rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?",
@@ -260,6 +268,8 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer,
         sqlite3_close(db);
         return false;
     }
+    // Set busy timeout to handle concurrent access
+    sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS);
 
     if ((rc = sqlite3_exec(db, "BEGIN", 0, 0, 0)) != SQLITE_OK) {
         sqlite3_close(db);
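For reference, sqlite3_busy_timeout() makes SQLite retry for up to the given number of milliseconds when another connection holds a conflicting lock, instead of returning SQLITE_BUSY immediately. A minimal standalone sketch of the pattern used above (the database path and the abbreviated table schema are illustrative, not the library's actual cache layout):

#include <sqlite3.h>
#include <cstdio>

int main() {
    sqlite3 *db = nullptr;
    if (sqlite3_open("/tmp/keycache_demo.sqlite", &db) != SQLITE_OK) {
        sqlite3_close(db);
        return 1;
    }
    // Retry for up to 5 seconds when the database is locked by another
    // thread or process, rather than failing immediately with SQLITE_BUSY.
    sqlite3_busy_timeout(db, 5000);

    char *err_msg = nullptr;
    int rc = sqlite3_exec(db,
                          "CREATE TABLE IF NOT EXISTS keycache ("
                          "issuer text PRIMARY KEY NOT NULL,"
                          "keys text NOT NULL)",
                          nullptr, nullptr, &err_msg);
    if (rc != SQLITE_OK) {
        std::fprintf(stderr, "sqlite error: %s\n", err_msg ? err_msg : "unknown");
        sqlite3_free(err_msg);
    }
    sqlite3_close(db);
    return rc == SQLITE_OK ? 0 : 1;
}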

src/scitokens_internal.cpp

Lines changed: 111 additions & 14 deletions
@@ -4,6 +4,7 @@
 #include <memory>
 #include <sstream>
 #include <sys/stat.h>
+#include <unordered_map>
 
 #include <jwt-cpp/base.h>
 #include <jwt-cpp/jwt.h>
@@ -34,6 +35,47 @@ CurlRaii myCurl;
 
 std::mutex key_refresh_mutex;
 
+// Per-issuer mutex map for preventing thundering herd on new issuers
+std::mutex issuer_mutex_map_lock;
+std::unordered_map<std::string, std::shared_ptr<std::mutex>> issuer_mutexes;
+constexpr size_t MAX_ISSUER_MUTEXES = 1000;
+
+// Get or create a mutex for a specific issuer
+std::shared_ptr<std::mutex> get_issuer_mutex(const std::string &issuer) {
+    std::lock_guard<std::mutex> guard(issuer_mutex_map_lock);
+
+    auto it = issuer_mutexes.find(issuer);
+    if (it != issuer_mutexes.end()) {
+        return it->second;
+    }
+
+    // Prevent resource exhaustion: limit the number of cached mutexes
+    if (issuer_mutexes.size() >= MAX_ISSUER_MUTEXES) {
+        // Remove mutexes that are no longer in use
+        // Since we hold issuer_mutex_map_lock, no other thread can acquire
+        // a reference to these mutexes, making this check safe
+        for (auto iter = issuer_mutexes.begin();
+             iter != issuer_mutexes.end();) {
+            if (iter->second.use_count() == 1) {
+                // Only we hold a reference, safe to remove
+                iter = issuer_mutexes.erase(iter);
+            } else {
+                ++iter;
+            }
+        }
+
+        // If still at capacity after cleanup, fail rather than unbounded growth
+        if (issuer_mutexes.size() >= MAX_ISSUER_MUTEXES) {
+            throw std::runtime_error(
+                "Too many concurrent issuers - resource exhaustion prevented");
+        }
+    }
+
+    auto mutex_ptr = std::make_shared<std::mutex>();
+    issuer_mutexes[issuer] = mutex_ptr;
+    return mutex_ptr;
+}
+
 } // namespace
 
 namespace scitokens {
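Because get_issuer_mutex() hands every caller the same shared_ptr<std::mutex> for a given issuer, two threads that race on a brand-new issuer serialize on that mutex, and only the first one performs the web fetch. A minimal, self-contained sketch of that behavior (the cache map, fetch counter, and issuer URL below are illustrative stand-ins, not the library's types):

#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

std::mutex map_lock;
std::map<std::string, std::shared_ptr<std::mutex>> per_issuer;
std::map<std::string, std::string> key_cache; // stand-in for the SQLite cache
int web_fetches = 0;                          // how many "web" fetches happened

std::shared_ptr<std::mutex> issuer_mutex(const std::string &issuer) {
    std::lock_guard<std::mutex> g(map_lock);
    auto &ptr = per_issuer[issuer];
    if (!ptr)
        ptr = std::make_shared<std::mutex>();
    return ptr;
}

void lookup(const std::string &issuer) {
    auto m = issuer_mutex(issuer);
    std::lock_guard<std::mutex> g(*m); // only one thread fetches at a time
    if (key_cache.count(issuer))
        return;                        // the second thread finds the cached keys
    ++web_fetches;                     // stand-in for the real web fetch
    key_cache[issuer] = "{...jwks...}";
}

int main() {
    // Both threads target the same (single) issuer, so all cache access below
    // is serialized by that issuer's mutex.
    std::thread a(lookup, "https://issuer.example");
    std::thread b(lookup, "https://issuer.example");
    a.join();
    b.join();
    std::cout << "web fetches: " << web_fetches << "\n"; // prints 1
    return 0;
}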
@@ -948,16 +990,36 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid,
             result->m_done = true;
         }
     } else {
-        // No keys in the DB, or they are expired
+        // No keys in the DB, or they are expired, so get them from the web.
         // Record that we had expired keys if the issuer was previously known
-        // (This is tracked by having an entry in issuer stats)
         auto &issuer_stats =
             internal::MonitoringStats::instance().get_issuer_stats(issuer);
         issuer_stats.inc_expired_key();
 
-        // Get keys from the web.
-        result = get_public_keys_from_web(
-            issuer, internal::SimpleCurlGet::default_timeout);
+        // Use per-issuer lock to prevent thundering herd for new issuers
+        auto issuer_mutex = get_issuer_mutex(issuer);
+        std::unique_lock<std::mutex> issuer_lock(*issuer_mutex);
+
+        // Check again if keys are now in DB (another thread may have fetched
+        // them while we were waiting for the lock)
+        if (get_public_keys_from_db(issuer, now, result->m_keys,
+                                    result->m_next_update)) {
+            // Keys are now available, use them
+            result->m_continue_fetch = false;
+            result->m_do_store = false;
+            result->m_done = true;
+            // Lock released here - no need to hold it
+        } else {
+            // Still no keys, fetch them from the web
+            result = get_public_keys_from_web(
+                issuer, internal::SimpleCurlGet::default_timeout);
+
+            // Transfer ownership of the lock to the async status
+            // The lock will be held until keys are stored in
+            // get_public_key_pem_continue
+            result->m_issuer_mutex = issuer_mutex;
+            result->m_issuer_lock = std::move(issuer_lock);
+        }
     }
     result->m_issuer = issuer;
     result->m_kid = kid;
@@ -973,21 +1035,56 @@ Validator::get_public_key_pem_continue(std::unique_ptr<AsyncStatus> status,
                                        std::string &algorithm) {
 
     if (status->m_continue_fetch) {
-        status = get_public_keys_from_web_continue(std::move(status));
-        if (status->m_continue_fetch) {
-            return std::move(status);
+        // Save issuer and lock info before potentially moving status
+        std::string issuer = status->m_issuer;
+        auto issuer_mutex = status->m_issuer_mutex;
+        std::unique_lock<std::mutex> issuer_lock(
+            std::move(status->m_issuer_lock));
+
+        try {
+            status = get_public_keys_from_web_continue(std::move(status));
+            if (status->m_continue_fetch) {
+                // Restore the lock to status before returning
+                status->m_issuer_mutex = issuer_mutex;
+                status->m_issuer_lock = std::move(issuer_lock);
+                return std::move(status);
+            }
+            // Success - restore the lock to status for later release
+            status->m_issuer_mutex = issuer_mutex;
+            status->m_issuer_lock = std::move(issuer_lock);
+        } catch (...) {
+            // Web fetch failed - store empty keys as negative cache entry
+            // This prevents thundering herd on repeated failed lookups
+            if (issuer_lock.owns_lock()) {
+                // Store empty keys with short TTL for negative caching
+                auto now = std::time(NULL);
+                int negative_cache_ttl =
+                    configurer::Configuration::get_next_update_delta();
+                picojson::value empty_keys;
+                picojson::object keys_obj;
+                keys_obj["keys"] = picojson::value(picojson::array());
+                empty_keys = picojson::value(keys_obj);
+                store_public_keys(issuer, empty_keys, now + negative_cache_ttl,
+                                  now + negative_cache_ttl);
+                issuer_lock.unlock();
+            }
+            throw; // Re-throw the original exception
         }
     }
     if (status->m_do_store) {
         // Async web fetch completed successfully - record monitoring
-        if (status->m_is_refresh) {
-            auto &issuer_stats =
-                internal::MonitoringStats::instance().get_issuer_stats(
-                    status->m_issuer);
-            issuer_stats.inc_successful_key_lookup();
-        }
+        // This counts both initial fetches and refreshes
+        auto &issuer_stats =
+            internal::MonitoringStats::instance().get_issuer_stats(
+                status->m_issuer);
+        issuer_stats.inc_successful_key_lookup();
         store_public_keys(status->m_issuer, status->m_keys,
                           status->m_next_update, status->m_expires);
+        // Release the per-issuer lock now that keys are stored
+        // Other threads waiting on this issuer can now proceed
+        if (status->m_issuer_lock.owns_lock()) {
+            status->m_issuer_lock.unlock();
+        }
     }
     status->m_done = true;

src/scitokens_internal.h

Lines changed: 18 additions & 5 deletions
@@ -552,6 +552,10 @@ class AsyncStatus {
     bool m_is_refresh{false}; // True if this is a refresh of an existing key
     AsyncState m_state{DOWNLOAD_METADATA};
     std::unique_lock<std::mutex> m_refresh_lock;
+    // Per-issuer lock to prevent thundering herd on new issuers
+    // We store both the shared_ptr (to keep mutex alive) and the lock
+    std::shared_ptr<std::mutex> m_issuer_mutex;
+    std::unique_lock<std::mutex> m_issuer_lock;
 
     int64_t m_next_update{-1};
     int64_t m_expires{-1};
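Storing both the shared_ptr (which keeps the mutex alive) and a movable std::unique_lock in the status object is what lets the lock travel with the asynchronous state machine. A minimal standalone sketch of that ownership transfer, with an illustrative struct and function names rather than the library's AsyncStatus:

#include <memory>
#include <mutex>

struct PendingFetch {
    std::shared_ptr<std::mutex> mutex;  // keeps the mutex alive
    std::unique_lock<std::mutex> lock;  // holds it across async steps
};

PendingFetch begin_fetch(std::shared_ptr<std::mutex> issuer_mutex) {
    std::unique_lock<std::mutex> lk(*issuer_mutex); // acquire before fetching
    // ... kick off the web request here ...
    return PendingFetch{std::move(issuer_mutex), std::move(lk)};
}

void finish_fetch(PendingFetch &pending) {
    // ... store the fetched keys in the cache ...
    if (pending.lock.owns_lock())
        pending.lock.unlock(); // waiting threads may now read the cache
}

int main() {
    auto m = std::make_shared<std::mutex>();
    auto pending = begin_fetch(m); // lock moves into the pending state object
    finish_fetch(pending);         // released only after the keys are stored
    return 0;
}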
@@ -776,6 +780,8 @@ class Validator {
 
         try {
             auto result = verify_async(scitoken);
+            // Note: m_is_sync flag no longer needed since counting is only done
+            // in verify_async_continue
 
             // Extract issuer from the result's JWT string after decoding starts
             const jwt::decoded_jwt<jwt::traits::kazuho_picojson> *jwt_decoded =
@@ -834,7 +840,8 @@ class Validator {
                     std::chrono::duration_cast<std::chrono::nanoseconds>(
                         end_time - last_duration_update);
                 issuer_stats->add_sync_time(delta);
-                issuer_stats->inc_successful_validation();
+                // Note: inc_successful_validation() is called in
+                // verify_async_continue
             }
         } catch (const std::exception &e) {
             // Record failure (final duration update)
@@ -882,6 +889,8 @@ class Validator {
         }
 
         auto result = verify_async(jwt);
+        // Note: m_is_sync flag no longer needed since counting is only done
+        // in verify_async_continue
         while (!result->m_done) {
             result = verify_async_continue(std::move(result));
         }
@@ -893,7 +902,8 @@ class Validator {
                 std::chrono::duration_cast<std::chrono::nanoseconds>(
                     end_time - start_time);
             issuer_stats->add_sync_time(duration);
-            issuer_stats->inc_successful_validation();
+            // Note: inc_successful_validation() is called in
+            // verify_async_continue
         }
     } catch (const std::exception &e) {
         // Record failure if we have an issuer
@@ -1004,6 +1014,7 @@ class Validator {
         // Start monitoring timing and record async validation started
         status->m_start_time = std::chrono::steady_clock::now();
         status->m_monitoring_started = true;
+        status->m_issuer = jwt.get_issuer();
         auto &stats = internal::MonitoringStats::instance().get_issuer_stats(
             jwt.get_issuer());
         stats.inc_async_validation_started();
@@ -1181,9 +1192,8 @@ class Validator {
             }
         }
 
-        // Record successful validation (only for async API, sync handles its
-        // own)
-        if (status->m_monitoring_started && !status->m_is_sync) {
+        // Record successful validation
+        if (status->m_monitoring_started) {
             auto end_time = std::chrono::steady_clock::now();
             auto duration =
                 std::chrono::duration_cast<std::chrono::nanoseconds>(
@@ -1195,8 +1205,11 @@ class Validator {
             stats.add_async_time(duration);
         }
 
+        // Create new result, preserving monitoring flags
         std::unique_ptr<AsyncStatus> result(new AsyncStatus());
         result->m_done = true;
+        result->m_is_sync = status->m_is_sync;
+        result->m_monitoring_started = status->m_monitoring_started;
         return result;
     }
