Skip to content

Commit 3e3bc0b

Browse files
committed
fix: stale oauth tokens invalidate setup instead of forcing re-auth
1 parent dc9801d commit 3e3bc0b

1 file changed

Lines changed: 64 additions & 10 deletions

File tree

src/auth.rs

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use braintrust_sdk_rust::{BraintrustClient, LoginState};
1313
use clap::{Args, Subcommand};
1414
use crossterm::event::{self, Event, KeyCode, KeyEventKind};
1515
use dialoguer::{Confirm, Input, Password};
16-
use oauth2::basic::{BasicClient, BasicTokenType};
17-
use oauth2::reqwest::async_http_client;
16+
use oauth2::basic::{BasicClient, BasicErrorResponseType, BasicRequestTokenError, BasicTokenType};
17+
use oauth2::reqwest::{async_http_client, AsyncHttpClientError};
1818
use oauth2::{
1919
AuthUrl, AuthorizationCode, ClientId, CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge,
2020
PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse, TokenResponse,
@@ -623,7 +623,8 @@ pub async fn resolve_auth(base: &BaseArgs) -> Result<ResolvedAuth> {
623623
),
624624
)
625625
})?;
626-
let refreshed = refresh_oauth_access_token(&api_url, &refresh_token, client_id).await?;
626+
let refreshed =
627+
refresh_oauth_access_token(&api_url, &refresh_token, client_id, &profile_name).await?;
627628
save_profile_oauth_access_token(&profile_name, &refreshed.access_token)?;
628629
if let Some(next_refresh_token) = refreshed.refresh_token.as_ref() {
629630
if next_refresh_token != &refresh_token {
@@ -1152,7 +1153,9 @@ async fn run_login_refresh(base: &BaseArgs) -> Result<()> {
11521153
println!("Cached access token expiry before refresh: unknown");
11531154
}
11541155

1155-
let refreshed = refresh_oauth_access_token(&api_url, &refresh_token, &client_id).await?;
1156+
let refreshed =
1157+
refresh_oauth_access_token(&api_url, &refresh_token, &client_id, profile_name.as_str())
1158+
.await?;
11561159
save_profile_oauth_access_token(profile_name.as_str(), &refreshed.access_token)?;
11571160
let mut refresh_rotated = false;
11581161
if let Some(next_refresh_token) = refreshed.refresh_token.as_ref() {
@@ -2141,22 +2144,43 @@ async fn exchange_oauth_authorization_code(
21412144
Ok(to_oauth_token_response(token_response))
21422145
}
21432146

2147+
fn map_refresh_oauth_error(
2148+
api_url: &str,
2149+
profile_name: &str,
2150+
err: BasicRequestTokenError<AsyncHttpClientError>,
2151+
) -> anyhow::Error {
2152+
if let BasicRequestTokenError::ServerResponse(server_err) = &err {
2153+
if matches!(server_err.error(), BasicErrorResponseType::InvalidGrant) {
2154+
let mut message =
2155+
format!("oauth refresh token expired or was rejected for profile '{profile_name}'");
2156+
if let Some(description) = server_err.error_description() {
2157+
message.push_str(&format!(" ({description})"));
2158+
}
2159+
message.push_str(&format!(
2160+
"; re-run `bt auth login --oauth --profile {profile_name}`"
2161+
));
2162+
return recoverable_auth_error(RecoverableAuthErrorKind::OauthRefreshToken, message);
2163+
}
2164+
}
2165+
2166+
anyhow::Error::new(err).context(format!(
2167+
"failed to call oauth token endpoint {}/oauth/token",
2168+
api_url.trim_end_matches('/')
2169+
))
2170+
}
2171+
21442172
async fn refresh_oauth_access_token(
21452173
api_url: &str,
21462174
refresh_token: &str,
21472175
client_id: &str,
2176+
profile_name: &str,
21482177
) -> Result<OAuthTokenResponse> {
21492178
let oauth_client = build_oauth_client(api_url, client_id, None)?;
21502179
let token_response = oauth_client
21512180
.exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
21522181
.request_async(async_http_client)
21532182
.await
2154-
.with_context(|| {
2155-
format!(
2156-
"failed to call oauth token endpoint {}/oauth/token",
2157-
api_url.trim_end_matches('/')
2158-
)
2159-
})?;
2183+
.map_err(|err| map_refresh_oauth_error(api_url, profile_name, err))?;
21602184
Ok(to_oauth_token_response(token_response))
21612185
}
21622186

@@ -2929,6 +2953,36 @@ mod tests {
29292953
assert!(is_missing_credential_error(&err));
29302954
}
29312955

2956+
#[test]
2957+
fn invalid_grant_refresh_error_is_treated_as_recoverable() {
2958+
let err = map_refresh_oauth_error(
2959+
"https://api.example.com",
2960+
"work",
2961+
BasicRequestTokenError::ServerResponse(oauth2::basic::BasicErrorResponse::new(
2962+
BasicErrorResponseType::InvalidGrant,
2963+
Some("refresh token expired".to_string()),
2964+
None,
2965+
)),
2966+
);
2967+
2968+
assert!(is_missing_credential_error(&err));
2969+
assert!(err.to_string().contains("refresh token expired"));
2970+
}
2971+
2972+
#[test]
2973+
fn nonrecoverable_refresh_errors_remain_nonrecoverable() {
2974+
let err = map_refresh_oauth_error(
2975+
"https://api.example.com",
2976+
"work",
2977+
BasicRequestTokenError::Other("unexpected response".to_string()),
2978+
);
2979+
2980+
assert!(!is_missing_credential_error(&err));
2981+
assert!(err
2982+
.to_string()
2983+
.contains("failed to call oauth token endpoint"));
2984+
}
2985+
29322986
fn restore_env_var(key: &str, previous: Option<OsString>) {
29332987
match previous {
29342988
Some(value) => env::set_var(key, value),

0 commit comments

Comments
 (0)