From 99719efcb013400d59e04c356e25849d6ba3d449 Mon Sep 17 00:00:00 2001 From: Jacob Clayden Date: Sun, 3 May 2026 02:23:58 +0100 Subject: [PATCH] fix: persist refreshed Claude account tokens --- internal/keyring/keyring.go | 116 ++++++++++----- internal/keyring/keyring_test.go | 248 +++++++++++++++++++++++++++++-- 2 files changed, 314 insertions(+), 50 deletions(-) diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go index bc521b7..c957502 100644 --- a/internal/keyring/keyring.go +++ b/internal/keyring/keyring.go @@ -150,11 +150,11 @@ func mergeAnonymousFresh(accounts []ClaudeOAuth) []ClaudeOAuth { // pickWinner reports whether candidate should replace current as the "token // winner" when two entries represent the same logical account. It implements the // shared tie-break policy: -// 1. Higher ExpiresAt wins outright. -// 2. On a tie: prefer non-empty AccountUUID. -// 3. Then: prefer non-nil TokenAccount. -// 4. Then: prefer richer (longer) Scopes list. -// 5. Otherwise keep current (return false). +// 1. Higher ExpiresAt wins outright. +// 2. On a tie: prefer non-empty AccountUUID. +// 3. Then: prefer non-nil TokenAccount. +// 4. Then: prefer richer (longer) Scopes list. +// 5. Otherwise keep current (return false). func pickWinner(candidate, current ClaudeOAuth) bool { if candidate.ExpiresAt > current.ExpiresAt { return true @@ -181,7 +181,6 @@ func pickWinner(candidate, current ClaudeOAuth) bool { return false } - // mergeIdentifiedByFreshness deduplicates identified accounts (those with // AccountUUID or Email) across discovery sources by preferring the entry with // the highest ExpiresAt. This fixes the source-order bias bug where a stale @@ -586,61 +585,100 @@ func BackfillCredentialsFile(acct *ClaudeOAuth) { } } +var ( + updateKeychainEntryForRefresh = UpdateKeychainEntry + storeCQAccountForRefresh = StoreCQAccount +) + // PersistRefreshedToken updates stored Claude credentials after a successful refresh. func PersistRefreshedToken(acct *ClaudeOAuth) { + cqAccount := *acct + home, err := os.UserHomeDir() - if err != nil { - return + if err == nil { + path := filepath.Join(home, ".claude", ".credentials.json") + data, err := os.ReadFile(path) + if err == nil { + var creds ClaudeCredentials + if json.Unmarshal(data, &creds) == nil && canUpdateStoredAccount(creds.ClaudeAiOauth, acct) { + stored := creds.ClaudeAiOauth + updated := mergeRefreshedAccount(stored, acct) + creds.ClaudeAiOauth = &updated + cqAccount = updated + if err := WriteCredentialsFile(&creds); err != nil { + fmt.Fprintf(os.Stderr, "cq: PersistRefreshedToken: write creds: %v\n", err) + } else if err := updateKeychainEntryForRefresh("Claude Code-credentials", &creds); err != nil { + fmt.Fprintf(os.Stderr, "cq: PersistRefreshedToken: update keychain: %v\n", err) + } + } + } } - path := filepath.Join(home, ".claude", ".credentials.json") - data, err := os.ReadFile(path) - if err != nil { - return + + if cqAccount.AccountUUID == "" { + cqAccount.AccountUUID = acct.AccountUUID } - var creds ClaudeCredentials - if json.Unmarshal(data, &creds) != nil || creds.ClaudeAiOauth == nil { - return + if cqAccount.Email == "" { + cqAccount.Email = acct.Email } - stored := creds.ClaudeAiOauth - if !sameStoredAccount(stored, acct) { - return + if cqAccount.AccountUUID != "" { + if err := storeCQAccountForRefresh(&cqAccount); err != nil { + fmt.Fprintf(os.Stderr, "cq: PersistRefreshedToken: store cq account: %v\n", err) + } + } +} + +func canUpdateStoredAccount(stored, acct *ClaudeOAuth) bool { + if stored == nil || acct == nil { + return false } + if stored.Email != "" && acct.Email != "" && stored.Email != acct.Email { + return false + } + if stored.AccountUUID != "" && acct.AccountUUID != "" && stored.AccountUUID != acct.AccountUUID { + return false + } + if stored.Email != "" && acct.Email != "" { + return true + } + if stored.AccountUUID != "" && acct.AccountUUID != "" { + return true + } + return sameStoredAccount(stored, acct) +} +func mergeRefreshedAccount(stored, acct *ClaudeOAuth) ClaudeOAuth { updated := *stored - changed := false - if acct.AccessToken != "" && stored.AccessToken != acct.AccessToken { + if acct.AccessToken != "" { updated.AccessToken = acct.AccessToken - changed = true } - if acct.ExpiresAt > 0 && stored.ExpiresAt != acct.ExpiresAt { + if acct.ExpiresAt > 0 { updated.ExpiresAt = acct.ExpiresAt - changed = true } - if acct.RefreshToken != "" && stored.RefreshToken != acct.RefreshToken { + if acct.RefreshToken != "" { updated.RefreshToken = acct.RefreshToken - changed = true } if len(acct.Scopes) > 0 && len(stored.Scopes) == 0 { updated.Scopes = acct.Scopes - changed = true } - if !changed { - return + if updated.Email == "" { + updated.Email = acct.Email } - creds.ClaudeAiOauth = &updated - - if err := WriteCredentialsFile(&creds); err != nil { - fmt.Fprintf(os.Stderr, "cq: PersistRefreshedToken: write creds: %v\n", err) - return + if updated.AccountUUID == "" { + updated.AccountUUID = acct.AccountUUID } - if err := UpdateKeychainEntry("Claude Code-credentials", &creds); err != nil { - fmt.Fprintf(os.Stderr, "cq: PersistRefreshedToken: update keychain: %v\n", err) + if updated.SubscriptionType == "" { + updated.SubscriptionType = acct.SubscriptionType } - if updated.AccountUUID != "" { - if err := StoreCQAccount(&updated); err != nil { - fmt.Fprintf(os.Stderr, "cq: PersistRefreshedToken: store cq account: %v\n", err) - } + if updated.RateLimitTier == "" { + updated.RateLimitTier = acct.RateLimitTier + } + if updated.Profile == nil { + updated.Profile = acct.Profile + } + if updated.TokenAccount == nil { + updated.TokenAccount = acct.TokenAccount } + return updated } // ActiveClaudeEmail returns the email of the currently active Claude account diff --git a/internal/keyring/keyring_test.go b/internal/keyring/keyring_test.go index 5699db5..4554f61 100644 --- a/internal/keyring/keyring_test.go +++ b/internal/keyring/keyring_test.go @@ -3,6 +3,7 @@ package keyring import ( "encoding/json" "os" + "strings" "testing" ) @@ -162,8 +163,8 @@ func TestMergeAnonymousFresh(t *testing.T) { t.Run("multiple anonymous entries — only fresher matching one merges", func(t *testing.T) { input := []ClaudeOAuth{ {Email: "a@example.com", AccountUUID: "uuid1", AccessToken: "base-at", RefreshToken: "rt-shared", ExpiresAt: 100}, - {AccessToken: "anon1-at", RefreshToken: "anon1-rt", ExpiresAt: 50}, // staler and no match - {AccessToken: "anon2-at", RefreshToken: "rt-shared", ExpiresAt: 200}, // fresher with matching RT + {AccessToken: "anon1-at", RefreshToken: "anon1-rt", ExpiresAt: 50}, // staler and no match + {AccessToken: "anon2-at", RefreshToken: "rt-shared", ExpiresAt: 200}, // fresher with matching RT } got := mergeAnonymousFresh(input) // anon2 merges into identified (shared RT); anon1 is kept (no match) @@ -605,6 +606,231 @@ func TestWriteCredentialsFile(t *testing.T) { }) } +// ── PersistRefreshedToken ───────────────────────────────────────────────────── + +func TestPersistRefreshedToken(t *testing.T) { + origUpdateKeychain := updateKeychainEntryForRefresh + origStore := storeCQAccountForRefresh + defer func() { + updateKeychainEntryForRefresh = origUpdateKeychain + storeCQAccountForRefresh = origStore + }() + + t.Run("stores cq account when credentials file missing", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + + updateKeychainEntryForRefresh = func(service string, creds *ClaudeCredentials) error { + t.Fatalf("UpdateKeychainEntry called without matching credentials file") + return nil + } + + var stored *ClaudeOAuth + storeCQAccountForRefresh = func(acct *ClaudeOAuth) error { + copy := *acct + stored = © + return nil + } + + PersistRefreshedToken(&ClaudeOAuth{ + AccountUUID: "uuid-secondary", + Email: "secondary@example.com", + AccessToken: "new-at", + RefreshToken: "new-rt", + ExpiresAt: 123456, + }) + + if stored == nil { + t.Fatal("expected refreshed account to be stored in cq keyring") + } + if stored.AccountUUID != "uuid-secondary" { + t.Fatalf("AccountUUID = %q, want uuid-secondary", stored.AccountUUID) + } + if stored.AccessToken != "new-at" { + t.Fatalf("AccessToken = %q, want new-at", stored.AccessToken) + } + if _, err := os.Stat(dir + "/.claude/.credentials.json"); !os.IsNotExist(err) { + t.Fatalf("credentials file should not be created for non-active account, stat err = %v", err) + } + }) + + t.Run("stores refreshed account UUID when active credentials matched by token", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + if err := os.MkdirAll(dir+"/.claude", 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + creds := ClaudeCredentials{ClaudeAiOauth: &ClaudeOAuth{ + AccessToken: "old-at", + RefreshToken: "old-rt", + ExpiresAt: 100, + }} + data, err := json.MarshalIndent(creds, "", " ") + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(dir+"/.claude/.credentials.json", data, 0o600); err != nil { + t.Fatalf("write: %v", err) + } + + updateKeychainEntryForRefresh = func(service string, creds *ClaudeCredentials) error { + if creds.ClaudeAiOauth.AccountUUID != "uuid-active" { + t.Fatalf("keychain AccountUUID = %q, want uuid-active", creds.ClaudeAiOauth.AccountUUID) + } + return nil + } + + var stored *ClaudeOAuth + storeCQAccountForRefresh = func(acct *ClaudeOAuth) error { + copy := *acct + stored = © + return nil + } + + PersistRefreshedToken(&ClaudeOAuth{ + AccountUUID: "uuid-active", + Email: "active@example.com", + AccessToken: "new-at", + RefreshToken: "old-rt", + ExpiresAt: 200, + }) + + if stored == nil { + t.Fatal("expected token-matched account to be stored in cq keyring") + } + if stored.AccountUUID != "uuid-active" { + t.Fatalf("stored AccountUUID = %q, want uuid-active", stored.AccountUUID) + } + if stored.AccessToken != "new-at" { + t.Fatalf("stored AccessToken = %q, want new-at", stored.AccessToken) + } + }) + + t.Run("does not update active credentials with conflicting stable identifiers", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + if err := os.MkdirAll(dir+"/.claude", 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + creds := ClaudeCredentials{ClaudeAiOauth: &ClaudeOAuth{ + AccountUUID: "uuid-active", + Email: "active@example.com", + AccessToken: "shared-at", + RefreshToken: "shared-rt", + ExpiresAt: 100, + }} + data, err := json.MarshalIndent(creds, "", " ") + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(dir+"/.claude/.credentials.json", data, 0o600); err != nil { + t.Fatalf("write: %v", err) + } + + updateKeychainEntryForRefresh = func(service string, creds *ClaudeCredentials) error { + t.Fatalf("UpdateKeychainEntry called for conflicting account") + return nil + } + + var stored *ClaudeOAuth + storeCQAccountForRefresh = func(acct *ClaudeOAuth) error { + copy := *acct + stored = © + return nil + } + + PersistRefreshedToken(&ClaudeOAuth{ + AccountUUID: "uuid-other", + Email: "other@example.com", + AccessToken: "new-at", + RefreshToken: "shared-rt", + ExpiresAt: 200, + }) + + data, err = os.ReadFile(dir + "/.claude/.credentials.json") + if err != nil { + t.Fatalf("read: %v", err) + } + var gotCreds ClaudeCredentials + if err := json.Unmarshal(data, &gotCreds); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if gotCreds.ClaudeAiOauth.AccessToken != "shared-at" { + t.Fatalf("active AccessToken = %q, want shared-at", gotCreds.ClaudeAiOauth.AccessToken) + } + if stored == nil { + t.Fatal("expected conflicting refreshed account to be stored in cq keyring") + } + if stored.AccountUUID != "uuid-other" { + t.Fatalf("stored AccountUUID = %q, want uuid-other", stored.AccountUUID) + } + }) + + t.Run("stores refreshed active account after credentials update", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + if err := os.MkdirAll(dir+"/.claude", 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + creds := ClaudeCredentials{ClaudeAiOauth: &ClaudeOAuth{ + AccountUUID: "uuid-active", + Email: "active@example.com", + AccessToken: "old-at", + RefreshToken: "old-rt", + ExpiresAt: 100, + Scopes: []string{"existing:scope"}, + }} + data, err := json.MarshalIndent(creds, "", " ") + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(dir+"/.claude/.credentials.json", data, 0o600); err != nil { + t.Fatalf("write: %v", err) + } + + keychainUpdated := false + updateKeychainEntryForRefresh = func(service string, creds *ClaudeCredentials) error { + keychainUpdated = true + if got := creds.ClaudeAiOauth.AccessToken; got != "new-at" { + t.Fatalf("keychain AccessToken = %q, want new-at", got) + } + return nil + } + + var stored *ClaudeOAuth + storeCQAccountForRefresh = func(acct *ClaudeOAuth) error { + copy := *acct + stored = © + return nil + } + + PersistRefreshedToken(&ClaudeOAuth{ + AccountUUID: "uuid-active", + Email: "active@example.com", + AccessToken: "new-at", + RefreshToken: "new-rt", + ExpiresAt: 200, + Scopes: []string{"new:scope"}, + }) + + if !keychainUpdated { + t.Fatal("expected active keychain credentials update") + } + if stored == nil { + t.Fatal("expected refreshed active account to be stored in cq keyring") + } + if stored.AccessToken != "new-at" { + t.Fatalf("stored AccessToken = %q, want new-at", stored.AccessToken) + } + if stored.RefreshToken != "new-rt" { + t.Fatalf("stored RefreshToken = %q, want new-rt", stored.RefreshToken) + } + if got := strings.Join(stored.Scopes, ","); got != "existing:scope" { + t.Fatalf("stored Scopes = %q, want existing:scope", got) + } + }) +} + // ── BackfillCredentialsFile ─────────────────────────────────────────────────── func TestBackfillCredentialsFile(t *testing.T) { @@ -691,11 +917,11 @@ func TestBackfillCredentialsFile(t *testing.T) { t.Setenv("HOME", dir) stored := ClaudeOAuth{ - AccessToken: "tok-match", + AccessToken: "tok-match", RefreshToken: "rt", - ExpiresAt: 100, - Email: "already@example.com", - AccountUUID: "uuid-already", + ExpiresAt: 100, + Email: "already@example.com", + AccountUUID: "uuid-already", } writeInitialCreds(t, dir, stored) @@ -849,11 +1075,11 @@ func TestMergeIdentifiedByFreshness(t *testing.T) { // TestMergeIdentifiedByFreshnessTieBreaking covers the deterministic tie-break // policy when two entries share the same ExpiresAt. The requested policy is: -// 1. Prefer the entry with a non-empty AccountUUID. -// 2. Then prefer the entry with a non-nil TokenAccount. -// 3. Then prefer the entry with richer (longer) Scopes list. -// 4. Token winner keeps its own token fields; metadata is enriched from loser. -// 5. When otherwise equivalent, output is stable (first-seen wins). +// 1. Prefer the entry with a non-empty AccountUUID. +// 2. Then prefer the entry with a non-nil TokenAccount. +// 3. Then prefer the entry with richer (longer) Scopes list. +// 4. Token winner keeps its own token fields; metadata is enriched from loser. +// 5. When otherwise equivalent, output is stable (first-seen wins). func TestMergeIdentifiedByFreshnessTieBreaking(t *testing.T) { t.Run("equal ExpiresAt prefers entry with non-empty AccountUUID", func(t *testing.T) { // Both entries share the same email and the same ExpiresAt.