Skip to content

Commit ed86d9a

Browse files
Copilotbbockelm
andcommitted
Address code review feedback: use exception hierarchy, report total time in seconds, periodic stats updates
Co-authored-by: bbockelm <1093447+bbockelm@users.noreply.github.com>
1 parent e27ae90 commit ed86d9a

File tree

4 files changed

+119
-77
lines changed

4 files changed

+119
-77
lines changed

src/scitokens.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ int scitoken_config_get_str(const char *key, char **output, char **err_msg);
335335
* - successful_validations: count of successful token validations
336336
* - unsuccessful_validations: count of failed token validations
337337
* - expired_tokens: count of expired tokens encountered
338-
* - average_validation_time_ms: average validation time in milliseconds
338+
* - total_validation_time_s: total validation time in seconds
339339
* - failed_issuer_lookups: count of failed issuer lookups (limited to prevent DDoS)
340340
*
341341
* The returned string must be freed by the caller using free().

src/scitokens_internal.cpp

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -652,35 +652,41 @@ SciToken::deserialize_continue(std::unique_ptr<SciTokenAsyncStatus> status) {
652652
std::unique_ptr<AsyncStatus>
653653
Validator::get_public_keys_from_web(const std::string &issuer,
654654
unsigned timeout) {
655-
std::string openid_metadata, oauth_metadata;
656-
get_metadata_endpoint(issuer, openid_metadata, oauth_metadata);
657-
658-
std::unique_ptr<AsyncStatus> status(new AsyncStatus());
659-
status->m_oauth_metadata_url = oauth_metadata;
660-
status->m_cget.reset(new internal::SimpleCurlGet(1024 * 1024, timeout));
661-
auto cget_status = status->m_cget->perform_start(openid_metadata);
662-
status->m_continue_fetch = true;
663-
if (!cget_status.m_done) {
664-
return status;
665-
}
666-
return get_public_keys_from_web_continue(std::move(status));
655+
try {
656+
std::string openid_metadata, oauth_metadata;
657+
get_metadata_endpoint(issuer, openid_metadata, oauth_metadata);
658+
659+
std::unique_ptr<AsyncStatus> status(new AsyncStatus());
660+
status->m_oauth_metadata_url = oauth_metadata;
661+
status->m_cget.reset(new internal::SimpleCurlGet(1024 * 1024, timeout));
662+
auto cget_status = status->m_cget->perform_start(openid_metadata);
663+
status->m_continue_fetch = true;
664+
if (!cget_status.m_done) {
665+
return status;
666+
}
667+
return get_public_keys_from_web_continue(std::move(status));
668+
} catch (const CurlException &e) {
669+
// Rethrow CURL errors during issuer key fetch as IssuerLookupException
670+
throw IssuerLookupException(e.what());
671+
}
667672
}
668673

669674
std::unique_ptr<AsyncStatus> Validator::get_public_keys_from_web_continue(
670675
std::unique_ptr<AsyncStatus> status) {
671-
char *buffer;
672-
size_t len;
676+
try {
677+
char *buffer;
678+
size_t len;
673679

674-
switch (status->m_state) {
680+
switch (status->m_state) {
675681

676-
case AsyncStatus::DOWNLOAD_METADATA: {
677-
auto cget_status = status->m_cget->perform_continue();
678-
if (!cget_status.m_done) {
679-
return std::move(status);
680-
}
682+
case AsyncStatus::DOWNLOAD_METADATA: {
683+
auto cget_status = status->m_cget->perform_continue();
684+
if (!cget_status.m_done) {
685+
return std::move(status);
686+
}
681687
if (cget_status.m_status_code != 200) {
682688
if (status->m_oauth_fallback) {
683-
throw CurlException("Failed to retrieve metadata provider "
689+
throw IssuerLookupException("Failed to retrieve metadata provider "
684690
"information for issuer.");
685691
} else {
686692
status->m_oauth_fallback = true;
@@ -729,7 +735,7 @@ std::unique_ptr<AsyncStatus> Validator::get_public_keys_from_web_continue(
729735
return std::move(status);
730736
}
731737
if (cget_status.m_status_code != 200) {
732-
throw CurlException("Failed to retrieve the issuer's key set");
738+
throw IssuerLookupException("Failed to retrieve the issuer's key set");
733739
}
734740

735741
status->m_cget->get_data(buffer, len);
@@ -762,6 +768,14 @@ std::unique_ptr<AsyncStatus> Validator::get_public_keys_from_web_continue(
762768

763769
} // Switch
764770
return std::move(status);
771+
} catch (const CurlException &e) {
772+
// Rethrow CURL errors during issuer key fetch as IssuerLookupException
773+
// (unless it's already an IssuerLookupException)
774+
if (dynamic_cast<const IssuerLookupException*>(&e)) {
775+
throw;
776+
}
777+
throw IssuerLookupException(e.what());
778+
}
765779
}
766780

767781
std::string Validator::get_jwks(const std::string &issuer) {

src/scitokens_internal.h

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ struct IssuerStats {
121121
std::atomic<uint64_t> successful_validations{0};
122122
std::atomic<uint64_t> unsuccessful_validations{0};
123123
std::atomic<uint64_t> expired_tokens{0};
124-
std::atomic<uint64_t> total_time_ns{0}; // Total time in nanoseconds
125-
std::atomic<uint64_t> validation_count{0}; // For computing average
124+
std::atomic<double> total_time_s{0.0}; // Total time in seconds
126125
};
127126

128127
/**
@@ -134,10 +133,13 @@ class MonitoringStats {
134133
public:
135134
static MonitoringStats &instance();
136135

136+
// Get a reference to issuer stats for periodic updates
137+
IssuerStats* get_issuer_stats(const std::string &issuer);
138+
137139
void record_validation_success(const std::string &issuer,
138-
uint64_t duration_ns);
140+
double duration_s);
139141
void record_validation_failure(const std::string &issuer,
140-
uint64_t duration_ns);
142+
double duration_s);
141143
void record_expired_token(const std::string &issuer);
142144
void record_failed_issuer_lookup(const std::string &issuer);
143145

@@ -180,6 +182,17 @@ class CurlException : public std::runtime_error {
180182
explicit CurlException(const std::string &msg) : std::runtime_error(msg) {}
181183
};
182184

185+
class IssuerLookupException : public CurlException {
186+
public:
187+
explicit IssuerLookupException(const std::string &msg) : CurlException(msg) {}
188+
};
189+
190+
class TokenExpiredException : public JWTVerificationException {
191+
public:
192+
explicit TokenExpiredException(const std::string &msg)
193+
: JWTVerificationException(msg) {}
194+
};
195+
183196
class MissingIssuerException : public std::runtime_error {
184197
public:
185198
MissingIssuerException()
@@ -473,21 +486,39 @@ class Validator {
473486
void verify(const SciToken &scitoken, time_t expiry_time) {
474487
std::string issuer = "";
475488
auto start_time = std::chrono::steady_clock::now();
476-
bool has_issuer = false;
489+
internal::IssuerStats* stats_ptr = nullptr;
477490

478491
try {
479-
// Try to extract issuer for monitoring
480-
if (scitoken.m_decoded && scitoken.m_decoded->has_payload_claim("iss")) {
481-
issuer = scitoken.m_decoded->get_issuer();
482-
has_issuer = true;
492+
auto result = verify_async(scitoken);
493+
494+
// Extract issuer from the result's JWT string after decoding starts
495+
const jwt::decoded_jwt<jwt::traits::kazuho_picojson> *jwt_decoded =
496+
scitoken.m_decoded.get();
497+
if (jwt_decoded && jwt_decoded->has_payload_claim("iss")) {
498+
issuer = jwt_decoded->get_issuer();
499+
stats_ptr = internal::MonitoringStats::instance().get_issuer_stats(issuer);
483500
}
484501

485-
auto result = verify_async(scitoken);
486502
while (!result->m_done) {
487503
auto timeout_val = result->get_timeout_val(expiry_time);
504+
// Limit select to 50ms for periodic updates
505+
if (timeout_val.tv_sec > 0 || timeout_val.tv_usec > 50000) {
506+
timeout_val.tv_sec = 0;
507+
timeout_val.tv_usec = 50000;
508+
}
509+
488510
select(result->get_max_fd() + 1, result->get_read_fd_set(),
489511
result->get_write_fd_set(), result->get_exc_fd_set(),
490512
&timeout_val);
513+
514+
// Update elapsed time periodically
515+
if (stats_ptr) {
516+
auto current_time = std::chrono::steady_clock::now();
517+
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
518+
current_time - start_time);
519+
stats_ptr->total_time_s = duration.count();
520+
}
521+
491522
if (time(NULL) >= expiry_time) {
492523
throw CurlException("Timeout when loading the OIDC metadata.");
493524
}
@@ -496,17 +527,20 @@ class Validator {
496527
}
497528

498529
// Record successful validation
499-
if (has_issuer) {
530+
if (!issuer.empty()) {
500531
auto end_time = std::chrono::steady_clock::now();
501-
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
532+
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
502533
end_time - start_time);
503534
internal::MonitoringStats::instance().record_validation_success(
504535
issuer, duration.count());
505536
}
506537
} catch (const std::exception &e) {
507538
// Record failure if we have an issuer
508-
if (has_issuer) {
509-
record_validation_error(issuer, e, start_time);
539+
if (!issuer.empty()) {
540+
auto end_time = std::chrono::steady_clock::now();
541+
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
542+
end_time - start_time);
543+
record_validation_error(issuer, e, duration.count());
510544
}
511545
throw;
512546
}
@@ -515,13 +549,11 @@ class Validator {
515549
void verify(const jwt::decoded_jwt<jwt::traits::kazuho_picojson> &jwt) {
516550
std::string issuer = "";
517551
auto start_time = std::chrono::steady_clock::now();
518-
bool has_issuer = false;
519552

520553
try {
521554
// Try to extract issuer for monitoring
522555
if (jwt.has_payload_claim("iss")) {
523556
issuer = jwt.get_issuer();
524-
has_issuer = true;
525557
}
526558

527559
auto result = verify_async(jwt);
@@ -530,17 +562,20 @@ class Validator {
530562
}
531563

532564
// Record successful validation
533-
if (has_issuer) {
565+
if (!issuer.empty()) {
534566
auto end_time = std::chrono::steady_clock::now();
535-
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
567+
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
536568
end_time - start_time);
537569
internal::MonitoringStats::instance().record_validation_success(
538570
issuer, duration.count());
539571
}
540572
} catch (const std::exception &e) {
541573
// Record failure if we have an issuer
542-
if (has_issuer) {
543-
record_validation_error(issuer, e, start_time);
574+
if (!issuer.empty()) {
575+
auto end_time = std::chrono::steady_clock::now();
576+
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
577+
end_time - start_time);
578+
record_validation_error(issuer, e, duration.count());
544579
}
545580
throw;
546581
}
@@ -652,7 +687,17 @@ class Validator {
652687

653688
const jwt::decoded_jwt<jwt::traits::kazuho_picojson> jwt(
654689
status->m_jwt_string);
655-
verifier.verify(jwt);
690+
try {
691+
verifier.verify(jwt);
692+
} catch (const std::exception &e) {
693+
// Check if this is an expiration error from jwt-cpp
694+
std::string error_msg = e.what();
695+
if (error_msg.find("exp") != std::string::npos ||
696+
error_msg.find("expir") != std::string::npos) {
697+
throw TokenExpiredException(error_msg);
698+
}
699+
throw;
700+
}
656701

657702
bool must_verify_everything = true;
658703
if (jwt.has_payload_claim("ver")) {
@@ -927,29 +972,18 @@ class Validator {
927972
*/
928973
void record_validation_error(const std::string &issuer,
929974
const std::exception &e,
930-
const std::chrono::steady_clock::time_point &start_time) {
931-
auto end_time = std::chrono::steady_clock::now();
932-
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
933-
end_time - start_time);
934-
935-
std::string error_msg = e.what();
936-
937-
// Check if this is a failed issuer lookup (network/DNS errors)
938-
if (error_msg.find("resolve") != std::string::npos ||
939-
error_msg.find("host") != std::string::npos ||
940-
error_msg.find("network") != std::string::npos ||
941-
error_msg.find("Failed to retrieve") != std::string::npos) {
975+
double duration_s) {
976+
// Check exception type instead of string introspection
977+
if (dynamic_cast<const IssuerLookupException*>(&e)) {
942978
internal::MonitoringStats::instance().record_failed_issuer_lookup(issuer);
943979
}
944980

945-
// Check if this is an expiration error
946-
if (error_msg.find("exp") != std::string::npos ||
947-
error_msg.find("expir") != std::string::npos) {
981+
if (dynamic_cast<const TokenExpiredException*>(&e)) {
948982
internal::MonitoringStats::instance().record_expired_token(issuer);
949983
}
950984

951985
internal::MonitoringStats::instance().record_validation_failure(
952-
issuer, duration.count());
986+
issuer, duration_s);
953987
}
954988

955989
bool m_validate_all_claims{true};

src/scitokens_monitoring.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,25 @@ MonitoringStats &MonitoringStats::instance() {
1616
return instance;
1717
}
1818

19+
IssuerStats* MonitoringStats::get_issuer_stats(const std::string &issuer) {
20+
std::lock_guard<std::mutex> lock(m_mutex);
21+
return &m_issuer_stats[issuer];
22+
}
23+
1924
void MonitoringStats::record_validation_success(const std::string &issuer,
20-
uint64_t duration_ns) {
25+
double duration_s) {
2126
std::lock_guard<std::mutex> lock(m_mutex);
2227
auto &stats = m_issuer_stats[issuer];
2328
stats.successful_validations++;
24-
stats.total_time_ns += duration_ns;
25-
stats.validation_count++;
29+
stats.total_time_s = stats.total_time_s.load() + duration_s;
2630
}
2731

2832
void MonitoringStats::record_validation_failure(const std::string &issuer,
29-
uint64_t duration_ns) {
33+
double duration_s) {
3034
std::lock_guard<std::mutex> lock(m_mutex);
3135
auto &stats = m_issuer_stats[issuer];
3236
stats.unsuccessful_validations++;
33-
stats.total_time_ns += duration_ns;
34-
stats.validation_count++;
37+
stats.total_time_s = stats.total_time_s.load() + duration_s;
3538
}
3639

3740
void MonitoringStats::record_expired_token(const std::string &issuer) {
@@ -112,17 +115,8 @@ std::string MonitoringStats::get_json() const {
112115
stats.unsuccessful_validations.load()));
113116
issuer_obj["expired_tokens"] =
114117
picojson::value(static_cast<double>(stats.expired_tokens.load()));
115-
116-
uint64_t validation_count = stats.validation_count.load();
117-
if (validation_count > 0) {
118-
uint64_t total_time_ns = stats.total_time_ns.load();
119-
double avg_time_ms =
120-
static_cast<double>(total_time_ns) / validation_count / 1e6;
121-
issuer_obj["average_validation_time_ms"] =
122-
picojson::value(avg_time_ms);
123-
} else {
124-
issuer_obj["average_validation_time_ms"] = picojson::value(0.0);
125-
}
118+
issuer_obj["total_validation_time_s"] =
119+
picojson::value(stats.total_time_s.load());
126120

127121
std::string sanitized_issuer = sanitize_issuer_for_json(issuer);
128122
issuers_obj[sanitized_issuer] = picojson::value(issuer_obj);

0 commit comments

Comments
 (0)