Skip to content

Commit a569ab5

Browse files
Copilotbbockelm
andauthored
Add optional background thread for JWKS refresh (#192)
* Implement background JWKS refresh infrastructure * Add integration test for background JWKS refresh * Address code review: include expired entries, use acquire/release semantics * Add background refresh statistics and monitoring coordination --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Brian P Bockelman <bockelman@gmail.com>
1 parent 57399fc commit a569ab5

File tree

9 files changed

+717
-69
lines changed

9 files changed

+717
-69
lines changed

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ add_library(SciTokens SHARED src/scitokens.cpp src/scitokens_internal.cpp src/sc
4848
target_compile_features(SciTokens PUBLIC cxx_std_11) # Use at least C++11 for building and when linking to scitokens
4949
target_include_directories(SciTokens PUBLIC ${JWT_CPP_INCLUDES} "${PROJECT_SOURCE_DIR}/src" PRIVATE ${CURL_INCLUDE_DIRS} ${OPENSSL_INCLUDE_DIRS} ${LIBCRYPTO_INCLUDE_DIRS} ${SQLITE_INCLUDE_DIRS} ${UUID_INCLUDE_DIRS})
5050

51-
target_link_libraries(SciTokens PUBLIC ${OPENSSL_LIBRARIES} ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES} ${UUID_LIBRARIES})
51+
# Find threading library
52+
find_package(Threads REQUIRED)
53+
54+
target_link_libraries(SciTokens PUBLIC ${OPENSSL_LIBRARIES} ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES} ${UUID_LIBRARIES} Threads::Threads)
5255
if (UNIX)
5356
# pkg_check_modules fails to return an absolute path on RHEL7. Set the
5457
# link directories accordingly.

src/scitokens.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@ void load_config_from_environment() {
4343
bool is_int;
4444
};
4545

46-
const std::array<ConfigMapping, 6> known_configs = {
46+
const std::array<ConfigMapping, 8> known_configs = {
4747
{{"keycache.update_interval_s", "KEYCACHE_UPDATE_INTERVAL_S", true},
4848
{"keycache.expiration_interval_s", "KEYCACHE_EXPIRATION_INTERVAL_S",
4949
true},
5050
{"keycache.cache_home", "KEYCACHE_CACHE_HOME", false},
5151
{"tls.ca_file", "TLS_CA_FILE", false},
5252
{"monitoring.file", "MONITORING_FILE", false},
53-
{"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true}}};
53+
{"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true},
54+
{"keycache.refresh_interval_ms", "KEYCACHE_REFRESH_INTERVAL_MS", true},
55+
{"keycache.refresh_threshold_ms", "KEYCACHE_REFRESH_THRESHOLD_MS",
56+
true}}};
5457

5558
const char *prefix = "SCITOKEN_CONFIG_";
5659

@@ -128,6 +131,13 @@ int configurer::Configuration::get_monitoring_file_interval() {
128131
return m_monitoring_file_interval;
129132
}
130133

134+
// Background refresh config
135+
std::atomic_bool configurer::Configuration::m_background_refresh_enabled{false};
136+
std::atomic_int configurer::Configuration::m_refresh_interval_ms{
137+
60000}; // 60 seconds
138+
std::atomic_int configurer::Configuration::m_refresh_threshold_ms{
139+
600000}; // 10 minutes
140+
131141
SciTokenKey scitoken_key_create(const char *key_id, const char *alg,
132142
const char *public_contents,
133143
const char *private_contents, char **err_msg) {
@@ -1099,6 +1109,31 @@ int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg) {
10991109
return 0;
11001110
}
11011111

1112+
int keycache_set_background_refresh(int enabled, char **err_msg) {
1113+
try {
1114+
bool enable = (enabled != 0);
1115+
configurer::Configuration::set_background_refresh_enabled(enable);
1116+
1117+
if (enable) {
1118+
scitokens::internal::BackgroundRefreshManager::get_instance()
1119+
.start();
1120+
} else {
1121+
scitokens::internal::BackgroundRefreshManager::get_instance()
1122+
.stop();
1123+
}
1124+
} catch (std::exception &exc) {
1125+
if (err_msg) {
1126+
*err_msg = strdup(exc.what());
1127+
}
1128+
return -1;
1129+
}
1130+
return 0;
1131+
}
1132+
1133+
int keycache_stop_background_refresh(char **err_msg) {
1134+
return keycache_set_background_refresh(0, err_msg);
1135+
}
1136+
11021137
int config_set_int(const char *key, int value, char **err_msg) {
11031138
return scitoken_config_set_int(key, value, err_msg);
11041139
}
@@ -1145,6 +1180,28 @@ int scitoken_config_set_int(const char *key, int value, char **err_msg) {
11451180
return 0;
11461181
}
11471182

1183+
else if (_key == "keycache.refresh_interval_ms") {
1184+
if (value < 0) {
1185+
if (err_msg) {
1186+
*err_msg = strdup("Refresh interval must be positive.");
1187+
}
1188+
return -1;
1189+
}
1190+
configurer::Configuration::set_refresh_interval(value);
1191+
return 0;
1192+
}
1193+
1194+
else if (_key == "keycache.refresh_threshold_ms") {
1195+
if (value < 0) {
1196+
if (err_msg) {
1197+
*err_msg = strdup("Refresh threshold must be positive.");
1198+
}
1199+
return -1;
1200+
}
1201+
configurer::Configuration::set_refresh_threshold(value);
1202+
return 0;
1203+
}
1204+
11481205
else {
11491206
if (err_msg) {
11501207
*err_msg = strdup("Key not recognized.");
@@ -1178,6 +1235,14 @@ int scitoken_config_get_int(const char *key, char **err_msg) {
11781235
return configurer::Configuration::get_monitoring_file_interval();
11791236
}
11801237

1238+
else if (_key == "keycache.refresh_interval_ms") {
1239+
return configurer::Configuration::get_refresh_interval();
1240+
}
1241+
1242+
else if (_key == "keycache.refresh_threshold_ms") {
1243+
return configurer::Configuration::get_refresh_threshold();
1244+
}
1245+
11811246
else {
11821247
if (err_msg) {
11831248
*err_msg = strdup("Key not recognized.");

src/scitokens.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,25 @@ int keycache_get_cached_jwks(const char *issuer, char **jwks, char **err_msg);
290290
*/
291291
int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg);
292292

293+
/**
294+
* Enable or disable the background refresh thread for JWKS.
295+
* - When enabled, a background thread will periodically check if any known
296+
* issuers need their JWKS refreshed based on the configured refresh interval
297+
* and threshold.
298+
* - If enabled=1 and the thread is not running, it will be started.
299+
* - If enabled=0 and the thread is running, it will be stopped gracefully.
300+
* - Returns 0 on success, nonzero on failure.
301+
*/
302+
int keycache_set_background_refresh(int enabled, char **err_msg);
303+
304+
/**
305+
* Stop the background refresh thread if it is running.
306+
* - This is a convenience function equivalent to
307+
* keycache_set_background_refresh(0, err_msg).
308+
* - Returns 0 on success, nonzero on failure.
309+
*/
310+
int keycache_stop_background_refresh(char **err_msg);
311+
293312
/**
294313
* APIs for managing scitokens configuration parameters.
295314
*/
@@ -308,6 +327,10 @@ int config_set_int(const char *key, int value, char **err_msg);
308327
* - "keycache.expiration_interval_s": Key cache expiration time (seconds)
309328
* - "monitoring.file_interval_s": Interval between monitoring file writes
310329
* (seconds, default 60)
330+
* - "keycache.refresh_interval_ms": Background refresh thread check interval
331+
* (milliseconds, default 60000)
332+
* - "keycache.refresh_threshold_ms": Time before next_update when background
333+
* refresh triggers (milliseconds, default 600000)
311334
*/
312335
int scitoken_config_set_int(const char *key, int value, char **err_msg);
313336

@@ -325,6 +348,10 @@ int config_get_int(const char *key, char **err_msg);
325348
* - "keycache.expiration_interval_s": Key cache expiration time (seconds)
326349
* - "monitoring.file_interval_s": Interval between monitoring file writes
327350
* (seconds, default 60)
351+
* - "keycache.refresh_interval_ms": Background refresh thread check interval
352+
* (milliseconds, default 60000)
353+
* - "keycache.refresh_threshold_ms": Time before next_update when background
354+
* refresh triggers (milliseconds, default 600000)
328355
*/
329356
int scitoken_config_get_int(const char *key, char **err_msg);
330357

src/scitokens_cache.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,81 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer,
308308
sqlite3_close(db);
309309
return true;
310310
}
311+
312+
std::vector<std::pair<std::string, int64_t>>
313+
scitokens::Validator::get_all_issuers_from_db(int64_t now) {
314+
std::vector<std::pair<std::string, int64_t>> result;
315+
316+
auto cache_fname = get_cache_file();
317+
if (cache_fname.size() == 0) {
318+
return result;
319+
}
320+
321+
sqlite3 *db;
322+
int rc = sqlite3_open(cache_fname.c_str(), &db);
323+
if (rc) {
324+
sqlite3_close(db);
325+
return result;
326+
}
327+
328+
sqlite3_stmt *stmt;
329+
rc = sqlite3_prepare_v2(db, "SELECT issuer, keys FROM keycache", -1, &stmt,
330+
NULL);
331+
if (rc != SQLITE_OK) {
332+
sqlite3_close(db);
333+
return result;
334+
}
335+
336+
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
337+
const unsigned char *issuer_data = sqlite3_column_text(stmt, 0);
338+
const unsigned char *keys_data = sqlite3_column_text(stmt, 1);
339+
340+
if (!issuer_data || !keys_data) {
341+
continue;
342+
}
343+
344+
std::string issuer(reinterpret_cast<const char *>(issuer_data));
345+
std::string metadata(reinterpret_cast<const char *>(keys_data));
346+
347+
// Parse the metadata to get next_update and check expiry
348+
picojson::value json_obj;
349+
auto err = picojson::parse(json_obj, metadata);
350+
if (!err.empty() || !json_obj.is<picojson::object>()) {
351+
continue;
352+
}
353+
354+
auto top_obj = json_obj.get<picojson::object>();
355+
356+
// Get expiry time
357+
auto expires_iter = top_obj.find("expires");
358+
if (expires_iter == top_obj.end() ||
359+
!expires_iter->second.is<int64_t>()) {
360+
continue;
361+
}
362+
auto expiry = expires_iter->second.get<int64_t>();
363+
364+
// Get next_update time
365+
auto next_update_iter = top_obj.find("next_update");
366+
int64_t next_update;
367+
if (next_update_iter == top_obj.end() ||
368+
!next_update_iter->second.is<int64_t>()) {
369+
// If next_update is not set, default to 4 hours before expiry
370+
next_update = expiry - 4 * 3600;
371+
} else {
372+
next_update = next_update_iter->second.get<int64_t>();
373+
}
374+
375+
// Include expired entries - they should be refreshed after a long
376+
// downtime If expired, set next_update to now so they get refreshed
377+
// immediately
378+
if (now > expiry) {
379+
next_update = now;
380+
}
381+
382+
result.push_back({issuer, next_update});
383+
}
384+
385+
sqlite3_finalize(stmt);
386+
sqlite3_close(db);
387+
return result;
388+
}

src/scitokens_internal.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
#include <chrono>
23
#include <functional>
34
#include <memory>
45
#include <sstream>
@@ -37,8 +38,101 @@ std::mutex key_refresh_mutex;
3738

3839
namespace scitokens {
3940

41+
// Define the static once_flag for Validator
42+
std::once_flag Validator::m_background_refresh_once;
43+
4044
namespace internal {
4145

46+
// BackgroundRefreshManager implementation
47+
void BackgroundRefreshManager::start() {
48+
std::lock_guard<std::mutex> lock(m_mutex);
49+
if (m_running.load(std::memory_order_acquire)) {
50+
return; // Already running
51+
}
52+
m_shutdown.store(false, std::memory_order_release);
53+
m_running.store(true, std::memory_order_release);
54+
m_thread = std::make_unique<std::thread>(
55+
&BackgroundRefreshManager::refresh_loop, this);
56+
}
57+
58+
void BackgroundRefreshManager::stop() {
59+
std::unique_ptr<std::thread> thread_to_join;
60+
61+
{
62+
std::lock_guard<std::mutex> lock(m_mutex);
63+
if (!m_running.load(std::memory_order_acquire)) {
64+
return; // Not running
65+
}
66+
67+
m_shutdown.store(true, std::memory_order_release);
68+
m_running.store(false, std::memory_order_release);
69+
thread_to_join = std::move(m_thread);
70+
}
71+
72+
m_cv.notify_all();
73+
74+
if (thread_to_join && thread_to_join->joinable()) {
75+
thread_to_join->join();
76+
}
77+
}
78+
79+
void BackgroundRefreshManager::refresh_loop() {
80+
while (!m_shutdown.load(std::memory_order_acquire)) {
81+
auto interval = configurer::Configuration::get_refresh_interval();
82+
auto threshold = configurer::Configuration::get_refresh_threshold();
83+
84+
// Wait for the interval or until shutdown
85+
{
86+
std::unique_lock<std::mutex> lock(m_mutex);
87+
m_cv.wait_for(lock, std::chrono::milliseconds(interval), [this]() {
88+
return m_shutdown.load(std::memory_order_acquire);
89+
});
90+
}
91+
92+
if (m_shutdown.load(std::memory_order_acquire)) {
93+
break;
94+
}
95+
96+
// Get list of issuers from the database
97+
auto now = std::time(NULL);
98+
auto issuers = scitokens::Validator::get_all_issuers_from_db(now);
99+
100+
for (const auto &issuer_pair : issuers) {
101+
if (m_shutdown.load(std::memory_order_acquire)) {
102+
break;
103+
}
104+
105+
const auto &issuer = issuer_pair.first;
106+
const auto &next_update = issuer_pair.second;
107+
108+
// Calculate time until next_update in milliseconds
109+
int64_t time_until_update = (next_update - now) * 1000;
110+
111+
// If next update is within threshold, try to refresh
112+
if (time_until_update <= threshold) {
113+
auto &stats =
114+
MonitoringStats::instance().get_issuer_stats(issuer);
115+
try {
116+
// Perform refresh (this will use the refresh_jwks method)
117+
scitokens::Validator::refresh_jwks(issuer);
118+
stats.inc_background_successful_refresh();
119+
} catch (std::exception &) {
120+
// Track failed refresh attempts
121+
stats.inc_background_failed_refresh();
122+
// Silently ignore errors in background refresh to avoid
123+
// disrupting the application. Background refresh is a
124+
// best-effort optimization. If it fails, the next token
125+
// verification will trigger a foreground refresh as usual.
126+
}
127+
}
128+
}
129+
130+
// Write monitoring file from background thread if configured
131+
// This avoids writing from verify() when background thread is running
132+
MonitoringStats::instance().maybe_write_monitoring_file();
133+
}
134+
}
135+
42136
SimpleCurlGet::GetStatus SimpleCurlGet::perform_start(const std::string &url) {
43137
m_len = 0;
44138

0 commit comments

Comments
 (0)