Skip to content

Commit d889a33

Browse files
committed
Add --refresh-before flag to databricks auth token
Users who use `databricks auth token` as an API key helper (e.g., for Claude Code) get expired tokens because the oauth2 library only refreshes within ~10 seconds of expiry. The new `--refresh-before` flag (e.g., `--refresh-before 5m`) refreshes the token if it expires within the given window. Depends on: databricks/databricks-sdk-go#1532 Resolves #4564
1 parent fd9f50e commit d889a33

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

cmd/auth/token.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ using a client ID and secret is not supported.`,
6565
cmd.Flags().DurationVar(&tokenTimeout, "timeout", defaultTimeout,
6666
"Timeout for acquiring a token.")
6767

68+
var refreshBefore time.Duration
69+
cmd.Flags().DurationVar(&refreshBefore, "refresh-before", 0,
70+
"Refresh the token if it expires within this duration (e.g., 5m, 30s).")
71+
6872
cmd.RunE = func(cmd *cobra.Command, args []string) error {
6973
ctx := cmd.Context()
7074
profileName := ""
@@ -78,6 +82,7 @@ using a client ID and secret is not supported.`,
7882
profileName: profileName,
7983
args: args,
8084
tokenTimeout: tokenTimeout,
85+
refreshBefore: refreshBefore,
8186
profiler: profile.DefaultProfiler,
8287
persistentAuthOpts: nil,
8388
})
@@ -108,6 +113,9 @@ type loadTokenArgs struct {
108113
// tokenTimeout is the timeout for retrieving (and potentially refreshing) an OAuth token.
109114
tokenTimeout time.Duration
110115

116+
// refreshBefore triggers a token refresh if the token expires within this duration.
117+
refreshBefore time.Duration
118+
111119
// profiler is the profiler to use for reading the host and account ID from the .databrickscfg file.
112120
profiler profile.Profiler
113121

@@ -242,6 +250,9 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) {
242250
return nil, err
243251
}
244252
allArgs := append(args.persistentAuthOpts, u2m.WithOAuthArgument(oauthArgument))
253+
if args.refreshBefore > 0 {
254+
allArgs = append(allArgs, u2m.WithExpiryDelta(args.refreshBefore))
255+
}
245256
persistentAuth, err := u2m.NewPersistentAuth(ctx, allArgs...)
246257
if err != nil {
247258
helpMsg := helpfulError(ctx, args.profileName, oauthArgument)

cmd/auth/token_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ func TestToken_loadToken(t *testing.T) {
130130
Name: "legacy-ws",
131131
Host: "https://legacy-ws.cloud.databricks.com",
132132
},
133+
{
134+
Name: "valid-token",
135+
Host: "https://accounts.cloud.databricks.com",
136+
AccountID: "valid-token",
137+
},
133138
{
134139
Name: "m2m-profile",
135140
Host: "https://m2m.cloud.databricks.com",
@@ -642,6 +647,65 @@ func TestToken_loadToken(t *testing.T) {
642647
},
643648
validateToken: validateToken,
644649
},
650+
{
651+
name: "refreshBefore skips refresh when token has enough time",
652+
args: loadTokenArgs{
653+
authArguments: &auth.AuthArguments{},
654+
profileName: "valid-token",
655+
args: []string{},
656+
tokenTimeout: 1 * time.Hour,
657+
refreshBefore: 5 * time.Minute,
658+
profiler: profiler,
659+
persistentAuthOpts: []u2m.PersistentAuthOption{
660+
u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{
661+
"valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)},
662+
}}),
663+
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
664+
},
665+
},
666+
validateToken: func(resp *oauth2.Token) {
667+
assert.Equal(t, "still-valid", resp.AccessToken)
668+
},
669+
},
670+
{
671+
name: "refreshBefore zero preserves default behavior",
672+
args: loadTokenArgs{
673+
authArguments: &auth.AuthArguments{},
674+
profileName: "valid-token",
675+
args: []string{},
676+
tokenTimeout: 1 * time.Hour,
677+
refreshBefore: 0,
678+
profiler: profiler,
679+
persistentAuthOpts: []u2m.PersistentAuthOption{
680+
u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{
681+
"valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)},
682+
}}),
683+
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
684+
},
685+
},
686+
validateToken: func(resp *oauth2.Token) {
687+
assert.Equal(t, "still-valid", resp.AccessToken)
688+
},
689+
},
690+
{
691+
name: "refreshBefore forces refresh when token expires within window",
692+
args: loadTokenArgs{
693+
authArguments: &auth.AuthArguments{},
694+
profileName: "valid-token",
695+
args: []string{},
696+
tokenTimeout: 1 * time.Hour,
697+
refreshBefore: 2 * time.Hour,
698+
profiler: profiler,
699+
persistentAuthOpts: []u2m.PersistentAuthOption{
700+
u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{
701+
"valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)},
702+
}}),
703+
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
704+
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
705+
},
706+
},
707+
validateToken: validateToken,
708+
},
645709
{
646710
name: "host flag with profile env var disambiguates multi-profile",
647711
setupCtx: func(ctx context.Context) context.Context {

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,5 @@ require (
110110
google.golang.org/grpc v1.78.0 // indirect
111111
google.golang.org/protobuf v1.36.11 // indirect
112112
)
113+
114+
replace github.com/databricks/databricks-sdk-go => /Users/anthony.ivan/projects/databricks-sdk-go

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
7575
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
7676
github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s=
7777
github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
78-
github.com/databricks/databricks-sdk-go v0.117.0 h1:CJNFcQIkHgPMVJTSeiQoHftl0cIIvG4bOMpJSRssXpE=
79-
github.com/databricks/databricks-sdk-go v0.117.0/go.mod h1:hWoHnHbNLjPKiTm5K/7bcIv3J3Pkgo5x9pPzh8K3RVE=
8078
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
8179
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
8280
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

0 commit comments

Comments
 (0)