@@ -13,8 +13,8 @@ use braintrust_sdk_rust::{BraintrustClient, LoginState};
1313use clap:: { Args , Subcommand } ;
1414use crossterm:: event:: { self , Event , KeyCode , KeyEventKind } ;
1515use dialoguer:: { Confirm , Input , Password } ;
16- use oauth2:: basic:: { BasicClient , BasicTokenType } ;
17- use oauth2:: reqwest:: async_http_client;
16+ use oauth2:: basic:: { BasicClient , BasicErrorResponseType , BasicRequestTokenError , BasicTokenType } ;
17+ use oauth2:: reqwest:: { async_http_client, AsyncHttpClientError } ;
1818use oauth2:: {
1919 AuthUrl , AuthorizationCode , ClientId , CsrfToken , EmptyExtraTokenFields , PkceCodeChallenge ,
2020 PkceCodeVerifier , RedirectUrl , RefreshToken , Scope , StandardTokenResponse , TokenResponse ,
@@ -623,7 +623,8 @@ pub async fn resolve_auth(base: &BaseArgs) -> Result<ResolvedAuth> {
623623 ) ,
624624 )
625625 } ) ?;
626- let refreshed = refresh_oauth_access_token ( & api_url, & refresh_token, client_id) . await ?;
626+ let refreshed =
627+ refresh_oauth_access_token ( & api_url, & refresh_token, client_id, & profile_name) . await ?;
627628 save_profile_oauth_access_token ( & profile_name, & refreshed. access_token ) ?;
628629 if let Some ( next_refresh_token) = refreshed. refresh_token . as_ref ( ) {
629630 if next_refresh_token != & refresh_token {
@@ -1152,7 +1153,9 @@ async fn run_login_refresh(base: &BaseArgs) -> Result<()> {
11521153 println ! ( "Cached access token expiry before refresh: unknown" ) ;
11531154 }
11541155
1155- let refreshed = refresh_oauth_access_token ( & api_url, & refresh_token, & client_id) . await ?;
1156+ let refreshed =
1157+ refresh_oauth_access_token ( & api_url, & refresh_token, & client_id, profile_name. as_str ( ) )
1158+ . await ?;
11561159 save_profile_oauth_access_token ( profile_name. as_str ( ) , & refreshed. access_token ) ?;
11571160 let mut refresh_rotated = false ;
11581161 if let Some ( next_refresh_token) = refreshed. refresh_token . as_ref ( ) {
@@ -2141,22 +2144,43 @@ async fn exchange_oauth_authorization_code(
21412144 Ok ( to_oauth_token_response ( token_response) )
21422145}
21432146
2147+ fn map_refresh_oauth_error (
2148+ api_url : & str ,
2149+ profile_name : & str ,
2150+ err : BasicRequestTokenError < AsyncHttpClientError > ,
2151+ ) -> anyhow:: Error {
2152+ if let BasicRequestTokenError :: ServerResponse ( server_err) = & err {
2153+ if matches ! ( server_err. error( ) , BasicErrorResponseType :: InvalidGrant ) {
2154+ let mut message =
2155+ format ! ( "oauth refresh token expired or was rejected for profile '{profile_name}'" ) ;
2156+ if let Some ( description) = server_err. error_description ( ) {
2157+ message. push_str ( & format ! ( " ({description})" ) ) ;
2158+ }
2159+ message. push_str ( & format ! (
2160+ "; re-run `bt auth login --oauth --profile {profile_name}`"
2161+ ) ) ;
2162+ return recoverable_auth_error ( RecoverableAuthErrorKind :: OauthRefreshToken , message) ;
2163+ }
2164+ }
2165+
2166+ anyhow:: Error :: new ( err) . context ( format ! (
2167+ "failed to call oauth token endpoint {}/oauth/token" ,
2168+ api_url. trim_end_matches( '/' )
2169+ ) )
2170+ }
2171+
21442172async fn refresh_oauth_access_token (
21452173 api_url : & str ,
21462174 refresh_token : & str ,
21472175 client_id : & str ,
2176+ profile_name : & str ,
21482177) -> Result < OAuthTokenResponse > {
21492178 let oauth_client = build_oauth_client ( api_url, client_id, None ) ?;
21502179 let token_response = oauth_client
21512180 . exchange_refresh_token ( & RefreshToken :: new ( refresh_token. to_string ( ) ) )
21522181 . request_async ( async_http_client)
21532182 . await
2154- . with_context ( || {
2155- format ! (
2156- "failed to call oauth token endpoint {}/oauth/token" ,
2157- api_url. trim_end_matches( '/' )
2158- )
2159- } ) ?;
2183+ . map_err ( |err| map_refresh_oauth_error ( api_url, profile_name, err) ) ?;
21602184 Ok ( to_oauth_token_response ( token_response) )
21612185}
21622186
@@ -2929,6 +2953,36 @@ mod tests {
29292953 assert ! ( is_missing_credential_error( & err) ) ;
29302954 }
29312955
2956+ #[ test]
2957+ fn invalid_grant_refresh_error_is_treated_as_recoverable ( ) {
2958+ let err = map_refresh_oauth_error (
2959+ "https://api.example.com" ,
2960+ "work" ,
2961+ BasicRequestTokenError :: ServerResponse ( oauth2:: basic:: BasicErrorResponse :: new (
2962+ BasicErrorResponseType :: InvalidGrant ,
2963+ Some ( "refresh token expired" . to_string ( ) ) ,
2964+ None ,
2965+ ) ) ,
2966+ ) ;
2967+
2968+ assert ! ( is_missing_credential_error( & err) ) ;
2969+ assert ! ( err. to_string( ) . contains( "refresh token expired" ) ) ;
2970+ }
2971+
2972+ #[ test]
2973+ fn nonrecoverable_refresh_errors_remain_nonrecoverable ( ) {
2974+ let err = map_refresh_oauth_error (
2975+ "https://api.example.com" ,
2976+ "work" ,
2977+ BasicRequestTokenError :: Other ( "unexpected response" . to_string ( ) ) ,
2978+ ) ;
2979+
2980+ assert ! ( !is_missing_credential_error( & err) ) ;
2981+ assert ! ( err
2982+ . to_string( )
2983+ . contains( "failed to call oauth token endpoint" ) ) ;
2984+ }
2985+
29322986 fn restore_env_var ( key : & str , previous : Option < OsString > ) {
29332987 match previous {
29342988 Some ( value) => env:: set_var ( key, value) ,
0 commit comments