Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 27 additions & 29 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func TestJWTRoundTrip(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

pair, err := mgr.GenerateTokenPair("user-123", "test@example.com", "maintainer", "team-456")
if err != nil {
Expand Down Expand Up @@ -47,7 +47,7 @@ func TestJWTRoundTrip(t *testing.T) {
}

func TestJWTExpiredToken(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", -1*time.Second, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", -1*time.Second, 7*24*time.Hour)

pair, err := mgr.GenerateTokenPair("user-123", "test@example.com", "maintainer", "")
if err != nil {
Expand All @@ -61,8 +61,8 @@ func TestJWTExpiredToken(t *testing.T) {
}

func TestJWTInvalidSecret(t *testing.T) {
mgr1 := NewJWTManager("secret-one-that-is-long-enough!!", 15*time.Minute, 7*24*time.Hour)
mgr2 := NewJWTManager("secret-two-that-is-long-enough!!", 15*time.Minute, 7*24*time.Hour)
mgr1, _ := NewJWTManager("secret-one-that-is-long-enough!!", 15*time.Minute, 7*24*time.Hour)
mgr2, _ := NewJWTManager("secret-two-that-is-long-enough!!", 15*time.Minute, 7*24*time.Hour)

pair, _ := mgr1.GenerateTokenPair("user-123", "test@example.com", "maintainer", "")

Expand Down Expand Up @@ -122,7 +122,7 @@ func TestAPITokenUniqueness(t *testing.T) {
}

func TestMiddlewareNoToken(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mw := Middleware(mgr, nil)

handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -142,7 +142,7 @@ func TestMiddlewareNoToken(t *testing.T) {
}

func TestMiddlewareQueryTokenBlockedForNonWebSocket(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
pair, _ := mgr.GenerateTokenPair("user-1", "test@example.com", "owner", "team-1")

mw := Middleware(mgr, nil)
Expand All @@ -161,7 +161,7 @@ func TestMiddlewareQueryTokenBlockedForNonWebSocket(t *testing.T) {
}

func TestMiddlewareQueryTokenAllowedForWebSocket(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
pair, _ := mgr.GenerateTokenPair("user-1", "test@example.com", "owner", "team-1")

mw := Middleware(mgr, nil)
Expand All @@ -186,7 +186,7 @@ func TestMiddlewareQueryTokenAllowedForWebSocket(t *testing.T) {
}

func TestMiddlewareValidJWT(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mw := Middleware(mgr, nil)

pair, _ := mgr.GenerateTokenPair("user-123", "test@example.com", "owner", "team-1")
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestMiddlewareValidJWT(t *testing.T) {
}

func TestMiddlewareAPIToken(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

apiToken, _ := GenerateAPIToken()
expectedClaims := &Claims{
Expand Down Expand Up @@ -253,7 +253,7 @@ func TestMiddlewareAPIToken(t *testing.T) {
// --- JWT edge cases ---

func TestJWTMalformedToken(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

malformed := []string{
"",
Expand All @@ -272,17 +272,15 @@ func TestJWTMalformedToken(t *testing.T) {
}
}

func TestJWTShortSecretPanics(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic for short secret")
}
}()
NewJWTManager("short", 15*time.Minute, 7*24*time.Hour)
func TestJWTShortSecretReturnsError(t *testing.T) {
_, err := NewJWTManager("short", 15*time.Minute, 7*24*time.Hour)
if err == nil {
t.Error("expected error for short secret, got nil")
}
}

func TestJWTEmptyTeamID(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
pair, err := mgr.GenerateTokenPair("user-1", "test@example.com", "owner", "")
if err != nil {
t.Fatalf("GenerateTokenPair() error: %v", err)
Expand All @@ -298,7 +296,7 @@ func TestJWTEmptyTeamID(t *testing.T) {
}

func TestJWTExpiresAtInFuture(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 5*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 5*time.Minute, 7*24*time.Hour)
pair, _ := mgr.GenerateTokenPair("u", "e@e.com", "owner", "")

if !pair.ExpiresAt.After(time.Now()) {
Expand All @@ -310,14 +308,14 @@ func TestJWTExpiresAtInFuture(t *testing.T) {
}

func TestRefreshDuration(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 48*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 48*time.Hour)
if mgr.RefreshDuration() != 48*time.Hour {
t.Errorf("RefreshDuration = %v, want 48h", mgr.RefreshDuration())
}
}

func TestRefreshTokenRandomness(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
p1, _ := mgr.GenerateTokenPair("u", "e@e.com", "owner", "")
p2, _ := mgr.GenerateTokenPair("u", "e@e.com", "owner", "")
if p1.RefreshToken == p2.RefreshToken {
Expand All @@ -331,7 +329,7 @@ func TestRefreshTokenRandomness(t *testing.T) {
// --- Middleware edge cases ---

func TestMiddlewareExpiredJWT(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", -1*time.Second, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", -1*time.Second, 7*24*time.Hour)
mw := Middleware(mgr, nil)

pair, _ := mgr.GenerateTokenPair("user-1", "test@example.com", "owner", "")
Expand All @@ -351,7 +349,7 @@ func TestMiddlewareExpiredJWT(t *testing.T) {
}

func TestMiddlewareInvalidAPIToken(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

lookup := func(hash string) (*Claims, error) {
return nil, fmt.Errorf("token not found")
Expand All @@ -373,7 +371,7 @@ func TestMiddlewareInvalidAPIToken(t *testing.T) {
}

func TestMiddlewareAPITokenNoLookupConfigured(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

// tokenLookup is nil — API tokens should be rejected
mw := Middleware(mgr, nil)
Expand All @@ -392,7 +390,7 @@ func TestMiddlewareAPITokenNoLookupConfigured(t *testing.T) {
}

func TestMiddlewareBearerWithAPITokenPrefix(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

apiToken, _ := GenerateAPIToken()
expectedClaims := &Claims{UserID: "api-user", Role: "owner", TeamID: "team-1"}
Expand Down Expand Up @@ -425,18 +423,18 @@ func TestMiddlewareBearerWithAPITokenPrefix(t *testing.T) {
}

func TestMiddlewareMalformedAuthHeader(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mw := Middleware(mgr, nil)

handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called for malformed auth header")
}))

headers := []string{
"Bearer ", // empty token after prefix
"Bearer ", // empty token after prefix
"Basic dXNlcjpwYXNz", // basic auth (not supported)
"InvalidScheme xyz",
"Bearer", // no space after Bearer
"Bearer", // no space after Bearer
}

for _, h := range headers {
Expand Down Expand Up @@ -498,7 +496,7 @@ func TestPasswordHashUniqueness(t *testing.T) {
}

func TestRequireRole(t *testing.T) {
mgr := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)
mgr, _ := NewJWTManager("test-secret-32-chars-long-enough!", 15*time.Minute, 7*24*time.Hour)

tests := []struct {
name string
Expand Down
37 changes: 23 additions & 14 deletions internal/auth/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
"time"
Expand All @@ -25,24 +26,22 @@ const (
//
// The hmacKey is used to sign CSRF cookie values so that attackers cannot
// forge valid cookie+header pairs.
func CSRFMiddleware(hmacKey []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
func CSRFMiddleware(hmacKey []byte) (func(http.Handler) http.Handler, error) {
if len(hmacKey) == 0 {
return nil, fmt.Errorf("csrf: HMAC key must not be empty")
}
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip CSRF for safe methods
if isSafeMethod(r.Method) {
next.ServeHTTP(w, r)
return
}

// Skip CSRF for any request carrying an explicit Authorization
// header (Bearer JWT or sct_ API token). These are never
// auto-attached by browsers, so CSRF is not a threat.
if hasExplicitAuth(r) {
next.ServeHTTP(w, r)
return
}

// Validate double-submit: cookie value must match header value.
cookie, err := r.Cookie(csrfCookieName)
if err != nil || cookie.Value == "" {
jsonError(w, "missing CSRF token", http.StatusForbidden)
Expand All @@ -68,22 +67,29 @@ func CSRFMiddleware(hmacKey []byte) func(http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
return mw, nil
}

// SetCSRFCookie generates a new signed CSRF token and writes it as a cookie.
// Call this from the csrf-token endpoint so the SPA can read and echo it back.
func SetCSRFCookie(w http.ResponseWriter, hmacKey []byte, secure bool) string {
token := generateSignedToken(hmacKey)
func SetCSRFCookie(w http.ResponseWriter, hmacKey []byte, secure bool) (string, error) {
if len(hmacKey) == 0 {
return "", fmt.Errorf("csrf: HMAC key must not be empty")
}
token, err := generateSignedToken(hmacKey)
if err != nil {
return "", err
}
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: token,
Path: "/",
HttpOnly: false, // JS must read this cookie
HttpOnly: false,
Secure: secure,
SameSite: http.SameSiteStrictMode,
MaxAge: int((24 * time.Hour).Seconds()),
})
return token
return token, nil
}

func isSafeMethod(method string) bool {
Expand All @@ -100,14 +106,17 @@ func hasExplicitAuth(r *http.Request) bool {

// generateSignedToken creates a random token with an HMAC signature appended.
// Format: <random_hex>.<hmac_hex>
func generateSignedToken(key []byte) string {
func generateSignedToken(key []byte) (string, error) {
if len(key) == 0 {
return "", fmt.Errorf("csrf: HMAC key must not be empty")
}
b := make([]byte, csrfTokenBytes)
if _, err := rand.Read(b); err != nil {
panic("csrf: failed to read random bytes: " + err.Error())
return "", fmt.Errorf("csrf: failed to read random bytes: %w", err)
}
raw := hex.EncodeToString(b)
sig := computeHMAC(raw, key)
return raw + "." + sig
return raw + "." + sig, nil
}

// validSignedToken checks that a token has a valid HMAC signature.
Expand Down
Loading
Loading