Skip to content

Commit 1779bb9

Browse files
refactor(oauth): address code review suggestions for token refresh
- Extract containsIgnoreCase to new internal/stringutil package - Replace custom toLower implementations with strings.ToLower - Add MaxExpiredTokenAge constant (24h) with documentation - Add TestRefreshStateSync to ensure health and oauth RefreshState stay in sync These changes improve code maintainability by: - Eliminating duplicate containsIgnoreCase implementations - Using stdlib strings.ToLower instead of custom ASCII-only version - Making the 24-hour expiry threshold explicit and documented - Adding a guard test to catch RefreshState constant drift Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9f90dc6 commit 1779bb9

5 files changed

Lines changed: 78 additions & 63 deletions

File tree

internal/health/calculator.go

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package health
33

44
import (
55
"fmt"
6+
"strings"
67
"time"
78

89
"mcpproxy-go/internal/contracts"
10+
"mcpproxy-go/internal/stringutil"
911
)
1012

1113
// RefreshState represents the current state of token refresh for health reporting.
@@ -125,7 +127,7 @@ func CalculateHealth(input HealthCalculatorInput, cfg *HealthCalculatorConfig) *
125127
// 4. Connection state checks
126128
// Normalize state to lowercase for consistent matching
127129
// (ConnectionState.String() returns "Error", "Disconnected", etc.)
128-
state := toLower(input.State)
130+
state := strings.ToLower(input.State)
129131
switch state {
130132
case "error":
131133
// For OAuth-required servers with OAuth-related errors, suggest login instead of restart
@@ -320,7 +322,7 @@ func formatErrorSummary(lastError string) string {
320322

321323
// Check for known patterns (in order)
322324
for _, mapping := range errorMappings {
323-
if containsIgnoreCase(lastError, mapping.pattern) {
325+
if stringutil.ContainsIgnoreCase(lastError, mapping.pattern) {
324326
return mapping.friendly
325327
}
326328
}
@@ -375,36 +377,6 @@ func formatRefreshRetryDetail(retryCount int, nextAttempt *time.Time, lastError
375377
return detail
376378
}
377379

378-
// containsIgnoreCase checks if s contains substr, ignoring case.
379-
func containsIgnoreCase(s, substr string) bool {
380-
return len(s) >= len(substr) &&
381-
(s == substr ||
382-
containsLower(toLower(s), toLower(substr)))
383-
}
384-
385-
// toLower is a simple ASCII lowercase conversion.
386-
func toLower(s string) string {
387-
b := make([]byte, len(s))
388-
for i := 0; i < len(s); i++ {
389-
c := s[i]
390-
if c >= 'A' && c <= 'Z' {
391-
c += 'a' - 'A'
392-
}
393-
b[i] = c
394-
}
395-
return string(b)
396-
}
397-
398-
// containsLower checks if s contains substr (both should be lowercase).
399-
func containsLower(s, substr string) bool {
400-
for i := 0; i <= len(s)-len(substr); i++ {
401-
if s[i:i+len(substr)] == substr {
402-
return true
403-
}
404-
}
405-
return false
406-
}
407-
408380
// isOAuthRelatedError checks if the error message indicates an OAuth issue.
409381
func isOAuthRelatedError(err string) bool {
410382
if err == "" {
@@ -422,7 +394,7 @@ func isOAuthRelatedError(err string) bool {
422394
"access_denied",
423395
}
424396
for _, pattern := range oauthPatterns {
425-
if containsIgnoreCase(err, pattern) {
397+
if stringutil.ContainsIgnoreCase(err, pattern) {
426398
return true
427399
}
428400
}
@@ -475,7 +447,7 @@ func ExtractOAuthConfigError(lastError string) string {
475447
}
476448

477449
for _, pattern := range configPatterns {
478-
if containsIgnoreCase(lastError, pattern) {
450+
if stringutil.ContainsIgnoreCase(lastError, pattern) {
479451
return lastError
480452
}
481453
}

internal/health/calculator_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"time"
66

77
"github.com/stretchr/testify/assert"
8+
9+
"mcpproxy-go/internal/oauth"
810
)
911

1012
func TestCalculateHealth_DisabledServer(t *testing.T) {
@@ -762,3 +764,19 @@ func TestFormatRefreshRetryDetail(t *testing.T) {
762764
assert.LessOrEqual(t, len(result), 200) // Reasonable max length
763765
})
764766
}
767+
768+
// TestRefreshStateSync ensures health.RefreshState values stay in sync with oauth.RefreshState.
769+
// The health package mirrors oauth.RefreshState for decoupling, but the values must match
770+
// for proper state mapping when wiring RefreshManager state into health calculation.
771+
func TestRefreshStateSync(t *testing.T) {
772+
// Verify that the integer values match between health and oauth packages
773+
// This test will fail if either package changes its constants without updating the other
774+
assert.Equal(t, int(RefreshStateIdle), int(oauth.RefreshStateIdle),
775+
"RefreshStateIdle values must match between health and oauth packages")
776+
assert.Equal(t, int(RefreshStateScheduled), int(oauth.RefreshStateScheduled),
777+
"RefreshStateScheduled values must match between health and oauth packages")
778+
assert.Equal(t, int(RefreshStateRetrying), int(oauth.RefreshStateRetrying),
779+
"RefreshStateRetrying values must match between health and oauth packages")
780+
assert.Equal(t, int(RefreshStateFailed), int(oauth.RefreshStateFailed),
781+
"RefreshStateFailed values must match between health and oauth packages")
782+
}

internal/oauth/refresh_manager.go

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"go.uber.org/zap"
1111

1212
"mcpproxy-go/internal/storage"
13+
"mcpproxy-go/internal/stringutil"
1314
)
1415

1516
// Default refresh configuration
@@ -31,6 +32,11 @@ const (
3132

3233
// MaxRetryBackoff is the maximum backoff duration (5 minutes per FR-009).
3334
MaxRetryBackoff = 5 * time.Minute
35+
36+
// MaxExpiredTokenAge is how long after token expiration we continue retrying
37+
// before giving up completely. After this duration, we assume the refresh token
38+
// is no longer valid even if it wasn't explicitly rejected.
39+
MaxExpiredTokenAge = 24 * time.Hour
3440
)
3541

3642
// RefreshState represents the current state of token refresh for health reporting.
@@ -656,9 +662,9 @@ func (m *RefreshManager) handleRefreshFailure(serverName string, err error) {
656662
if !expiresAt.IsZero() && now.After(expiresAt) {
657663
// Token has already expired - check if we should give up
658664
// We'll keep trying as long as there's a chance the refresh token is still valid
659-
// Only give up if we've been trying for a very long time (e.g., > 24 hours past expiration)
665+
// Only give up if we've been trying for too long (MaxExpiredTokenAge)
660666
timeSinceExpiry := now.Sub(expiresAt)
661-
if timeSinceExpiry > 24*time.Hour {
667+
if timeSinceExpiry > MaxExpiredTokenAge {
662668
m.logger.Error("OAuth token refresh failed - token expired too long ago",
663669
zap.String("server", serverName),
664670
zap.Duration("expired_for", timeSinceExpiry),
@@ -699,7 +705,7 @@ func classifyRefreshError(err error) string {
699705
"refresh token invalid",
700706
}
701707
for _, pattern := range permanentErrors {
702-
if containsIgnoreCase(errStr, pattern) {
708+
if stringutil.ContainsIgnoreCase(errStr, pattern) {
703709
return "failed_invalid_grant"
704710
}
705711
}
@@ -716,39 +722,14 @@ func classifyRefreshError(err error) string {
716722
"context deadline exceeded",
717723
}
718724
for _, pattern := range networkErrors {
719-
if containsIgnoreCase(errStr, pattern) {
725+
if stringutil.ContainsIgnoreCase(errStr, pattern) {
720726
return "failed_network"
721727
}
722728
}
723729

724730
return "failed_other"
725731
}
726732

727-
// containsIgnoreCase checks if s contains substr, ignoring case.
728-
func containsIgnoreCase(s, substr string) bool {
729-
sLower := toLower(s)
730-
substrLower := toLower(substr)
731-
for i := 0; i <= len(sLower)-len(substrLower); i++ {
732-
if sLower[i:i+len(substrLower)] == substrLower {
733-
return true
734-
}
735-
}
736-
return false
737-
}
738-
739-
// toLower converts a string to lowercase (ASCII only).
740-
func toLower(s string) string {
741-
b := make([]byte, len(s))
742-
for i := 0; i < len(s); i++ {
743-
c := s[i]
744-
if c >= 'A' && c <= 'Z' {
745-
c += 'a' - 'A'
746-
}
747-
b[i] = c
748-
}
749-
return string(b)
750-
}
751-
752733
// calculateBackoff calculates the exponential backoff duration for a given retry count.
753734
// The formula is: base * 2^retryCount, capped at MaxRetryBackoff (5 minutes).
754735
// Sequence: 10s → 20s → 40s → 80s → 160s → 300s (cap).

internal/stringutil/strings.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Package stringutil provides common string utility functions.
2+
package stringutil
3+
4+
import "strings"
5+
6+
// ContainsIgnoreCase checks if s contains substr, ignoring case.
7+
func ContainsIgnoreCase(s, substr string) bool {
8+
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
9+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package stringutil
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestContainsIgnoreCase(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
s string
13+
substr string
14+
expected bool
15+
}{
16+
{"exact match", "hello", "hello", true},
17+
{"case insensitive match", "Hello World", "hello", true},
18+
{"case insensitive match upper", "hello world", "WORLD", true},
19+
{"mixed case", "HeLLo WoRLD", "ello wor", true},
20+
{"no match", "hello", "goodbye", false},
21+
{"empty substr", "hello", "", true},
22+
{"empty string", "", "hello", false},
23+
{"both empty", "", "", true},
24+
{"substr longer than string", "hi", "hello", false},
25+
{"special chars", "error: invalid_grant", "INVALID_GRANT", true},
26+
{"network error", "connection timeout", "TIMEOUT", true},
27+
}
28+
29+
for _, tt := range tests {
30+
t.Run(tt.name, func(t *testing.T) {
31+
result := ContainsIgnoreCase(tt.s, tt.substr)
32+
assert.Equal(t, tt.expected, result)
33+
})
34+
}
35+
}

0 commit comments

Comments
 (0)