From 9fc69d1c96b4b348361782155ee20ad5e05df961 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Wed, 27 May 2026 07:47:00 +0700 Subject: [PATCH 1/5] feat: support OAuth 2.1 / MCP login flow Add the pieces an MCP CLI client needs to authenticate against this service as an OAuth 2.1 authorization server, while keeping the existing confidential web-client flow working. - discovery metadata (RFC 8414) at /.well-known/oauth-authorization-server - PKCE (S256) on the authorize and token endpoints - public clients: /token accepts code_verifier instead of client_secret - Dynamic Client Registration (RFC 7591) at /register - token introspection (RFC 7662) at /introspect for the resource server - loopback redirect URIs (RFC 8252) for native/CLI clients - standard token response (access_token + expires_in; refresh_token kept for backward compatibility) - BASE_URL / INTROSPECTION_TOKEN env vars; periodic cleanup of expired oauth2 sessions and codes Flows are discriminated by client token_endpoint_auth_method: existing rows default to client_secret_post (unchanged); registered MCP clients are public (none) and PKCE-gated. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 18 +++++- cleanup.go | 33 ++++++++++ handler.go | 166 ++++++++++++++++++++++++++++++++++++++------------ introspect.go | 66 ++++++++++++++++++++ main.go | 17 +++++- metadata.go | 32 ++++++++++ oauth2.go | 90 +++++++++++++++++++-------- pkce.go | 80 ++++++++++++++++++++++++ register.go | 85 ++++++++++++++++++++++++++ schema.sql | 38 +++++++----- token.go | 19 ++++++ 11 files changed, 564 insertions(+), 80 deletions(-) create mode 100644 cleanup.go create mode 100644 introspect.go create mode 100644 metadata.go create mode 100644 pkce.go create mode 100644 register.go diff --git a/README.md b/README.md index d3bce71..a9a1c06 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ Required environment variables: | `OAUTH2_CLIENT_ID` | Google OAuth app client ID | | `OAUTH2_CLIENT_SECRET` | Google OAuth app client secret | | `PORT` | Listen port (default: `8080`) | +| `BASE_URL` | Public base URL of this service (default: `https://auth.deploys.app`) | +| `INTROSPECTION_TOKEN` | Shared secret for the `/introspect` endpoint; if unset, introspection is disabled | ```shell $ ./auth @@ -31,11 +33,23 @@ the service starts. | Method | Path | Purpose | |---|---|---| -| `GET` | `/` | Validate the OAuth2 client and redirect to Google | +| `GET` | `/.well-known/oauth-authorization-server` | OAuth 2.0 Authorization Server Metadata (RFC 8414) | +| `GET` | `/` | Validate the OAuth2 client and redirect to Google (authorize endpoint; supports PKCE) | | `GET` | `/callback` | Receive Google's code and issue an internal auth code | -| `POST` | `/token` | Exchange client credentials + code for a user token | +| `POST` | `/token` | Exchange code (+ secret or PKCE verifier) for a user token | +| `POST` | `/register` | Dynamic Client Registration for public clients (RFC 7591) | +| `POST` | `/introspect` | Token introspection for resource servers (RFC 7662) | | `POST` | `/revoke` | Revoke a user token | +### MCP / public clients + +The service is an OAuth 2.1 authorization server. CLI / MCP clients register +dynamically at `/register` (public clients, no secret), use **PKCE (S256)** at +the authorize and token endpoints, and may use loopback redirect URIs +(`http://127.0.0.1:`). Confidential web clients keep using +`client_secret` as before. A resource server validates issued bearer tokens via +`/introspect` (authenticated with `INTROSPECTION_TOKEN`). + ## Deployment The provided [Dockerfile](./Dockerfile) builds a `gcr.io/distroless/static` diff --git a/cleanup.go b/cleanup.go new file mode 100644 index 0000000..8c6c927 --- /dev/null +++ b/cleanup.go @@ -0,0 +1,33 @@ +package main + +import ( + "context" + "database/sql" + "log/slog" + "time" +) + +// startCleanupWorker periodically removes expired oauth2 sessions and codes. +// Both have a 1-hour TTL but are only deleted on use, so abandoned rows would +// otherwise accumulate. Registered clients are never reaped — MCP clients reuse +// their client_id across logins. +func startCleanupWorker(db *sql.DB) { + go func() { + for { + cleanupExpired(db) + time.Sleep(15 * time.Minute) + } + }() +} + +func cleanupExpired(db *sql.DB) { + ctx := context.Background() + for _, q := range []string{ + `delete from oauth2_sessions where created_at < now() - interval '1 hour'`, + `delete from oauth2_codes where created_at < now() - interval '1 hour'`, + } { + if _, err := db.ExecContext(ctx, q); err != nil { + slog.ErrorContext(ctx, "cleanup: delete expired rows", "error", err) + } + } +} diff --git a/handler.go b/handler.go index 7cdf6aa..33308c3 100644 --- a/handler.go +++ b/handler.go @@ -1,6 +1,7 @@ package main import ( + "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -16,6 +17,7 @@ import ( type RedirectHandler struct { OAuth2ClientID string + BaseURL string } func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -41,6 +43,13 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + codeChallenge := r.FormValue("code_challenge") + codeChallengeMethod := r.FormValue("code_challenge_method") + if codeChallenge != "" && codeChallengeMethod == "" { + codeChallengeMethod = "S256" + } + resource := r.FormValue("resource") + ctx := r.Context() oauth2Client, err := getOAuth2Client(ctx, clientID) @@ -54,25 +63,47 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "Internal server error", http.StatusInternalServerError) return } - pattern := oauth2Client.RedirectURI - pattern = strings.ReplaceAll(pattern, ".", `\.`) - pattern = strings.ReplaceAll(pattern, "/", `\/`) - pattern = strings.ReplaceAll(pattern, "*", `.*`) - re := regexp.MustCompile(`^` + pattern + `$`) - if !re.MatchString(callbackURL) { - slog.WarnContext(ctx, "redirect: invalid redirect_uri", "client_id", clientID, "redirect_uri", callbackURL) - http.Error(w, "Invalid redirect_uri parameter", http.StatusBadRequest) - return + + if oauth2Client.IsPublic() { + // Public clients (CLI / MCP) must use PKCE and an exact, pre-registered + // redirect URI (loopback port is allowed to vary). + if codeChallenge == "" { + http.Error(w, "Missing code_challenge parameter", http.StatusBadRequest) + return + } + if codeChallengeMethod != "S256" { + http.Error(w, "Unsupported code_challenge_method parameter", http.StatusBadRequest) + return + } + if !redirectURIAllowed(oauth2Client.RedirectURIs, callbackURL) { + slog.WarnContext(ctx, "redirect: invalid redirect_uri", "client_id", clientID, "redirect_uri", callbackURL) + http.Error(w, "Invalid redirect_uri parameter", http.StatusBadRequest) + return + } + } else { + pattern := oauth2Client.RedirectURI + pattern = strings.ReplaceAll(pattern, ".", `\.`) + pattern = strings.ReplaceAll(pattern, "/", `\/`) + pattern = strings.ReplaceAll(pattern, "*", `.*`) + re := regexp.MustCompile(`^` + pattern + `$`) + if !re.MatchString(callbackURL) { + slog.WarnContext(ctx, "redirect: invalid redirect_uri", "client_id", clientID, "redirect_uri", callbackURL) + http.Error(w, "Invalid redirect_uri parameter", http.StatusBadRequest) + return + } } state := generateState() sessionID := generateSessionID() err = saveSession(ctx, sessionID, &Session{ - ClientID: oauth2Client.ID, - State: state, - CallbackState: callbackState, - CallbackURL: callbackURL, + ClientID: oauth2Client.ID, + State: state, + CallbackState: callbackState, + CallbackURL: callbackURL, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + Resource: resource, }) if err != nil { slog.ErrorContext(ctx, "redirect: save session", "error", err) @@ -92,7 +123,7 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { params := url.Values{} params.Set("response_type", "code") params.Set("client_id", h.OAuth2ClientID) - params.Set("redirect_uri", "https://auth.deploys.app/callback") + params.Set("redirect_uri", h.BaseURL+"/callback") params.Set("scope", "https://www.googleapis.com/auth/userinfo.email") params.Set("access_type", "online") params.Set("prompt", "consent") @@ -113,6 +144,7 @@ func isURL(s string) bool { type CallbackHandler struct { OAuth2ClientID string OAuth2ClientSecret string + BaseURL string } func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -161,7 +193,7 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { params := url.Values{} params.Set("grant_type", "authorization_code") params.Set("code", code) - params.Set("redirect_uri", "https://auth.deploys.app/callback") + params.Set("redirect_uri", h.BaseURL+"/callback") params.Set("client_id", h.OAuth2ClientID) params.Set("client_secret", h.OAuth2ClientSecret) @@ -195,7 +227,13 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } returnCode := generateCode() - err = insertOAuth2Code(ctx, session.ClientID, returnCode, email) + err = insertOAuth2Code(ctx, session.ClientID, returnCode, &OAuth2Code{ + Email: email, + CodeChallenge: session.CodeChallenge, + CodeChallengeMethod: session.CodeChallengeMethod, + RedirectURI: session.CallbackURL, + Resource: session.Resource, + }) if err != nil { slog.ErrorContext(ctx, "callback: insert oauth2 code", "error", err) http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -297,69 +335,121 @@ func (RevokePostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type TokenHandler struct{} func (TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - clientID := r.PostFormValue("client_id") - if clientID == "" { - http.Error(w, "Missing client_id parameter", http.StatusBadRequest) + grantType := r.PostFormValue("grant_type") + if grantType != "" && grantType != "authorization_code" { + oauthError(w, http.StatusBadRequest, "unsupported_grant_type", "only authorization_code is supported") return } - clientSecret := r.PostFormValue("client_secret") - if clientSecret == "" { - http.Error(w, "Missing client_secret parameter", http.StatusBadRequest) + + clientID := r.PostFormValue("client_id") + if clientID == "" { + oauthError(w, http.StatusBadRequest, "invalid_request", "missing client_id") return } code := r.PostFormValue("code") if code == "" { - http.Error(w, "Missing code parameter", http.StatusBadRequest) + oauthError(w, http.StatusBadRequest, "invalid_request", "missing code") return } ctx := r.Context() oauth2Client, err := getOAuth2Client(ctx, clientID) if errors.Is(err, ErrOAuth2ClientNotFound) { - http.Error(w, "Invalid client_id parameter", http.StatusBadRequest) + oauthError(w, http.StatusUnauthorized, "invalid_client", "unknown client_id") return } if err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) + slog.ErrorContext(ctx, "token: get oauth2 client", "error", err) + oauthError(w, http.StatusInternalServerError, "server_error", "") return } - if oauth2Client.Secret != clientSecret { - http.Error(w, "Invalid client_secret parameter", http.StatusBadRequest) - return + + // Authenticate the client. Public clients (CLI / MCP) rely on PKCE instead + // of a secret; confidential clients must present client_secret. + if oauth2Client.IsPublic() { + if r.PostFormValue("code_verifier") == "" { + oauthError(w, http.StatusBadRequest, "invalid_request", "missing code_verifier") + return + } + } else { + clientSecret := r.PostFormValue("client_secret") + if clientSecret == "" { + oauthError(w, http.StatusBadRequest, "invalid_request", "missing client_secret") + return + } + if subtle.ConstantTimeCompare([]byte(clientSecret), []byte(oauth2Client.Secret)) != 1 { + oauthError(w, http.StatusUnauthorized, "invalid_client", "invalid client_secret") + return + } } - email, err := getOAuth2EmailFromCode(ctx, clientID, code) + oauth2Code, err := getOAuth2Code(ctx, clientID, code) if errors.Is(err, ErrOAuth2CodeNotFound) { - slog.WarnContext(ctx, "token: invalid code", "client_id", clientID, "code", code) - http.Error(w, "Invalid code parameter", http.StatusBadRequest) + slog.WarnContext(ctx, "token: invalid code", "client_id", clientID) + oauthError(w, http.StatusBadRequest, "invalid_grant", "invalid or expired code") return } if err != nil { - slog.ErrorContext(ctx, "token: get oauth2 email from code", "error", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + slog.ErrorContext(ctx, "token: get oauth2 code", "error", err) + oauthError(w, http.StatusInternalServerError, "server_error", "") return } + // PKCE: verify whenever the code was issued with a challenge. + if oauth2Code.CodeChallenge != "" { + verifier := r.PostFormValue("code_verifier") + if !verifyPKCE(verifier, oauth2Code.CodeChallenge, oauth2Code.CodeChallengeMethod) { + slog.WarnContext(ctx, "token: pkce verification failed", "client_id", clientID) + oauthError(w, http.StatusBadRequest, "invalid_grant", "PKCE verification failed") + return + } + } + + // For public clients the redirect_uri presented here must match the one the + // code was bound to at the authorize step (RFC 6749 §4.1.3). + if oauth2Client.IsPublic() && oauth2Code.RedirectURI != "" { + if r.PostFormValue("redirect_uri") != oauth2Code.RedirectURI { + oauthError(w, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") + return + } + } + token := generateToken() hashedToken := hashToken(token) - err = insertToken(ctx, hashedToken, email) + err = insertToken(ctx, hashedToken, oauth2Code.Email) if err != nil { slog.ErrorContext(ctx, "token: insert token", "error", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + oauthError(w, http.StatusInternalServerError, "server_error", "") return } var resp struct { - RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` } - resp.TokenType = "bearer" - resp.RefreshToken = token + resp.AccessToken = token + resp.TokenType = "Bearer" + resp.ExpiresIn = tokenTTLSeconds + resp.RefreshToken = token // retained for backward compatibility with the existing web client w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") json.NewEncoder(w).Encode(resp) } +func oauthError(w http.ResponseWriter, status int, code, desc string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(status) + body := map[string]string{"error": code} + if desc != "" { + body["error_description"] = desc + } + json.NewEncoder(w).Encode(body) +} + type apiResult struct { OK bool `json:"ok"` Result any `json:"result,omitempty"` diff --git a/introspect.go b/introspect.go new file mode 100644 index 0000000..0b1b134 --- /dev/null +++ b/introspect.go @@ -0,0 +1,66 @@ +package main + +import ( + "crypto/subtle" + "database/sql" + "encoding/json" + "errors" + "log/slog" + "net/http" +) + +// IntrospectHandler implements OAuth 2.0 Token Introspection (RFC 7662). The MCP +// resource server calls it to validate an opaque bearer token it received. +// +// The endpoint is itself protected by a shared secret (INTROSPECTION_TOKEN) +// presented as `Authorization: Bearer `, so it is not a public oracle. +type IntrospectHandler struct { + Token string +} + +func (h IntrospectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.Token == "" { + http.Error(w, "introspection not configured", http.StatusServiceUnavailable) + return + } + expected := "Bearer " + h.Token + if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), []byte(expected)) != 1 { + w.Header().Set("WWW-Authenticate", "Bearer") + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + token := r.PostFormValue("token") + if token == "" { + writeInactive(w) + return + } + + ctx := r.Context() + email, exp, err := lookupToken(ctx, hashToken(token)) + if errors.Is(err, sql.ErrNoRows) { + writeInactive(w) + return + } + if err != nil { + slog.ErrorContext(ctx, "introspect: lookup token", "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + json.NewEncoder(w).Encode(map[string]any{ + "active": true, + "sub": email, + "username": email, + "token_type": "Bearer", + "exp": exp, + }) +} + +func writeInactive(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + json.NewEncoder(w).Encode(map[string]any{"active": false}) +} diff --git a/main.go b/main.go index 14e8fb8..c1f5c8a 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "database/sql" "net/http" "os" + "strings" "github.com/acoshift/pgsql/pgctx" _ "github.com/lib/pq" @@ -14,12 +15,20 @@ func main() { if port == "" { port = "8080" } + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + baseURL = "https://auth.deploys.app" + } + baseURL = strings.TrimRight(baseURL, "/") + oauth2ClientID := os.Getenv("OAUTH2_CLIENT_ID") if oauth2ClientID == "" { panic("missing OAUTH2_CLIENT_ID") } oauth2ClientSecret := os.Getenv("OAUTH2_CLIENT_SECRET") + introspectionToken := os.Getenv("INTROSPECTION_TOKEN") + sqlURL := os.Getenv("SQL_URL") if sqlURL == "" { panic("missing SQL_URL") @@ -30,15 +39,21 @@ func main() { } defer db.Close() + startCleanupWorker(db) + mux := http.NewServeMux() - mux.Handle("GET /", RedirectHandler{OAuth2ClientID: oauth2ClientID}) + mux.Handle("GET /.well-known/oauth-authorization-server", MetadataHandler{BaseURL: baseURL}) + mux.Handle("GET /", RedirectHandler{OAuth2ClientID: oauth2ClientID, BaseURL: baseURL}) mux.Handle("GET /callback", CallbackHandler{ OAuth2ClientID: oauth2ClientID, OAuth2ClientSecret: oauth2ClientSecret, + BaseURL: baseURL, }) mux.Handle("GET /revoke", RevokeHandler{}) // TODO: remove ? mux.Handle("POST /revoke", RevokePostHandler{}) mux.Handle("POST /token", TokenHandler{}) + mux.Handle("POST /register", RegisterHandler{BaseURL: baseURL}) + mux.Handle("POST /introspect", IntrospectHandler{Token: introspectionToken}) http.ListenAndServe(":"+port, pgctx.Middleware(db)(mux)) } diff --git a/metadata.go b/metadata.go new file mode 100644 index 0000000..e795958 --- /dev/null +++ b/metadata.go @@ -0,0 +1,32 @@ +package main + +import ( + "encoding/json" + "net/http" +) + +// MetadataHandler serves OAuth 2.0 Authorization Server Metadata (RFC 8414). +// MCP clients fetch this to discover the authorize, token and registration +// endpoints plus the supported PKCE methods and client auth methods. +type MetadataHandler struct { + BaseURL string +} + +func (h MetadataHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + meta := map[string]any{ + "issuer": h.BaseURL, + "authorization_endpoint": h.BaseURL + "/", + "token_endpoint": h.BaseURL + "/token", + "registration_endpoint": h.BaseURL + "/register", + "revocation_endpoint": h.BaseURL + "/revoke", + "introspection_endpoint": h.BaseURL + "/introspect", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "code_challenge_methods_supported": []string{"S256"}, + "token_endpoint_auth_methods_supported": []string{"client_secret_post", "none"}, + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "public, max-age=3600") + json.NewEncoder(w).Encode(meta) +} diff --git a/oauth2.go b/oauth2.go index 4a82307..76ec028 100644 --- a/oauth2.go +++ b/oauth2.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "strings" "github.com/acoshift/pgsql/pgctx" ) @@ -15,26 +16,47 @@ var ( ) type OAuth2Client struct { - ID string - Secret string - RedirectURI string + ID string + Secret string + RedirectURI string // legacy glob pattern (confidential clients) + RedirectURIs []string // exact URIs (public / DCR clients) + TokenEndpointAuthMethod string // "client_secret_post" or "none" + ClientName string +} + +// IsPublic reports whether the client authenticates without a secret (PKCE only). +func (c *OAuth2Client) IsPublic() bool { + return c.TokenEndpointAuthMethod == "none" } type Session struct { - ClientID string - State string - CallbackState string - CallbackURL string + ClientID string + State string + CallbackState string + CallbackURL string + CodeChallenge string + CodeChallengeMethod string + Resource string +} + +type OAuth2Code struct { + Email string + CodeChallenge string + CodeChallengeMethod string + RedirectURI string + Resource string } func getOAuth2Client(ctx context.Context, clientID string) (*OAuth2Client, error) { var x OAuth2Client + var secret sql.NullString + var redirectURIs string err := pgctx.QueryRow(ctx, ` - select id, secret, redirect_uri + select id, secret, redirect_uri, redirect_uris, token_endpoint_auth_method from oauth2_clients where id = $1 `, clientID).Scan( - &x.ID, &x.Secret, &x.RedirectURI, + &x.ID, &secret, &x.RedirectURI, &redirectURIs, &x.TokenEndpointAuthMethod, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrOAuth2ClientNotFound @@ -42,41 +64,58 @@ func getOAuth2Client(ctx context.Context, clientID string) (*OAuth2Client, error if err != nil { return nil, err } + x.Secret = secret.String + if redirectURIs != "" { + x.RedirectURIs = strings.Split(redirectURIs, "\n") + } return &x, nil } -func insertOAuth2Code(ctx context.Context, clientID, code, email string) error { +// insertOAuth2Client persists a dynamically registered public client. +func insertOAuth2Client(ctx context.Context, c *OAuth2Client) error { + _, err := pgctx.Exec(ctx, ` + insert into oauth2_clients (id, secret, redirect_uri, redirect_uris, token_endpoint_auth_method, client_name) + values ($1, null, '', $2, $3, $4) + `, c.ID, strings.Join(c.RedirectURIs, "\n"), c.TokenEndpointAuthMethod, c.ClientName) + return err +} + +func insertOAuth2Code(ctx context.Context, clientID, code string, c *OAuth2Code) error { _, err := pgctx.Exec(ctx, ` - insert into oauth2_codes (id, client_id, email) - values ($1, $2, $3) - `, code, clientID, email) + insert into oauth2_codes (id, client_id, email, code_challenge, code_challenge_method, redirect_uri, resource) + values ($1, $2, $3, $4, $5, $6, $7) + `, code, clientID, c.Email, c.CodeChallenge, c.CodeChallengeMethod, c.RedirectURI, c.Resource) return err } -func getOAuth2EmailFromCode(ctx context.Context, clientID, code string) (string, error) { - var email string +// getOAuth2Code atomically consumes a code, returning the associated email and +// the PKCE / redirect binding stored when it was issued. 1-hour TTL. +func getOAuth2Code(ctx context.Context, clientID, code string) (*OAuth2Code, error) { + var x OAuth2Code err := pgctx.QueryRow(ctx, ` delete from oauth2_codes where id = $1 and client_id = $2 and now() < created_at + interval '1 hour' - returning email - `, code, clientID).Scan(&email) + returning email, code_challenge, code_challenge_method, redirect_uri, resource + `, code, clientID).Scan( + &x.Email, &x.CodeChallenge, &x.CodeChallengeMethod, &x.RedirectURI, &x.Resource, + ) if errors.Is(err, sql.ErrNoRows) { - return "", ErrOAuth2CodeNotFound + return nil, ErrOAuth2CodeNotFound } if err != nil { - return "", err + return nil, err } - return email, nil + return &x, nil } func getSession(ctx context.Context, sessionID string) (*Session, error) { var x Session err := pgctx.QueryRow(ctx, ` - select client_id, state, callback_state, callback_url + select client_id, state, callback_state, callback_url, code_challenge, code_challenge_method, resource from oauth2_sessions where id = $1 and now() < created_at + interval '1 hour' `, sessionID).Scan( - &x.ClientID, &x.State, &x.CallbackState, &x.CallbackURL, + &x.ClientID, &x.State, &x.CallbackState, &x.CallbackURL, &x.CodeChallenge, &x.CodeChallengeMethod, &x.Resource, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrOAuth2SessionNotFound @@ -95,8 +134,9 @@ func getSession(ctx context.Context, sessionID string) (*Session, error) { func saveSession(ctx context.Context, sessionID string, session *Session) error { _, err := pgctx.Exec(ctx, ` - insert into oauth2_sessions (id, client_id, state, callback_state, callback_url) - values ($1, $2, $3, $4, $5) - `, sessionID, session.ClientID, session.State, session.CallbackState, session.CallbackURL) + insert into oauth2_sessions (id, client_id, state, callback_state, callback_url, code_challenge, code_challenge_method, resource) + values ($1, $2, $3, $4, $5, $6, $7, $8) + `, sessionID, session.ClientID, session.State, session.CallbackState, session.CallbackURL, + session.CodeChallenge, session.CodeChallengeMethod, session.Resource) return err } diff --git a/pkce.go b/pkce.go new file mode 100644 index 0000000..32c90ef --- /dev/null +++ b/pkce.go @@ -0,0 +1,80 @@ +package main + +import ( + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "net/url" + "strings" +) + +// verifyPKCE checks a PKCE code_verifier against the stored code_challenge. +// Only the S256 method is supported (OAuth 2.1 / MCP requirement). +func verifyPKCE(verifier, challenge, method string) bool { + if verifier == "" || challenge == "" { + return false + } + if method != "S256" { + return false + } + sum := sha256.Sum256([]byte(verifier)) + computed := base64.RawURLEncoding.EncodeToString(sum[:]) + return subtle.ConstantTimeCompare([]byte(computed), []byte(challenge)) == 1 +} + +func isLoopbackHost(host string) bool { + return host == "127.0.0.1" || host == "::1" || host == "localhost" +} + +// redirectURIAllowed reports whether got matches one of the registered exact +// redirect URIs. Per RFC 8252 the port is ignored for loopback addresses, so a +// CLI listening on an ephemeral localhost port is accepted. +func redirectURIAllowed(registered []string, got string) bool { + gu, err := url.Parse(got) + if err != nil { + return false + } + for _, raw := range registered { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + ru, err := url.Parse(raw) + if err != nil { + continue + } + if ru.Scheme != gu.Scheme { + continue + } + if !strings.EqualFold(ru.Hostname(), gu.Hostname()) { + continue + } + if ru.Path != gu.Path { + continue + } + if isLoopbackHost(gu.Hostname()) { + return true + } + if ru.Port() == gu.Port() { + return true + } + } + return false +} + +// validRegistrationRedirectURI enforces the redirect URIs a client may register: +// HTTPS anywhere, or plain HTTP only for loopback (native/CLI clients). +func validRegistrationRedirectURI(s string) bool { + u, err := url.Parse(s) + if err != nil || u.Host == "" { + return false + } + switch u.Scheme { + case "https": + return true + case "http": + return isLoopbackHost(u.Hostname()) + default: + return false + } +} diff --git a/register.go b/register.go new file mode 100644 index 0000000..af549a9 --- /dev/null +++ b/register.go @@ -0,0 +1,85 @@ +package main + +import ( + "encoding/json" + "log/slog" + "net/http" + "time" +) + +// RegisterHandler implements OAuth 2.0 Dynamic Client Registration (RFC 7591). +// It only issues public clients (token_endpoint_auth_method "none"); they +// authenticate via PKCE, so no client secret is generated. This is the path MCP +// CLIs use to obtain a client_id without manual provisioning. +type RegisterHandler struct { + BaseURL string +} + +func (h RegisterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var req struct { + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + registrationError(w, "invalid_client_metadata", "invalid request body") + return + } + + if req.TokenEndpointAuthMethod != "" && req.TokenEndpointAuthMethod != "none" { + registrationError(w, "invalid_client_metadata", "only token_endpoint_auth_method \"none\" is supported") + return + } + if len(req.RedirectURIs) == 0 { + registrationError(w, "invalid_redirect_uri", "at least one redirect_uri is required") + return + } + for _, uri := range req.RedirectURIs { + if !validRegistrationRedirectURI(uri) { + registrationError(w, "invalid_redirect_uri", "redirect_uri must be https or http loopback: "+uri) + return + } + } + + ctx := r.Context() + clientID := generateClientID() + err := insertOAuth2Client(ctx, &OAuth2Client{ + ID: clientID, + RedirectURIs: req.RedirectURIs, + TokenEndpointAuthMethod: "none", + ClientName: req.ClientName, + }) + if err != nil { + slog.ErrorContext(ctx, "register: insert oauth2 client", "error", err) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]string{"error": "server_error"}) + return + } + + resp := map[string]any{ + "client_id": clientID, + "client_id_issued_at": time.Now().Unix(), + "client_name": req.ClientName, + "redirect_uris": req.RedirectURIs, + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + "token_endpoint_auth_method": "none", + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(resp) +} + +func registrationError(w http.ResponseWriter, code, desc string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{ + "error": code, + "error_description": desc, + }) +} diff --git a/schema.sql b/schema.sql index c9155d5..a561bc2 100644 --- a/schema.sql +++ b/schema.sql @@ -1,28 +1,38 @@ create table oauth2_clients ( - id string, - secret string not null, - redirect_uri string not null, - created_at timestamptz not null default now(), + id string, + secret string, + redirect_uri string not null default '', + redirect_uris string not null default '', + token_endpoint_auth_method string not null default 'client_secret_post', + client_name string not null default '', + created_at timestamptz not null default now(), primary key (id) ); create table oauth2_codes ( - id string, - client_id string not null, - email string not null, - created_at timestamptz not null default now(), + id string, + client_id string not null, + email string not null, + code_challenge string not null default '', + code_challenge_method string not null default '', + redirect_uri string not null default '', + resource string not null default '', + created_at timestamptz not null default now(), primary key (id), foreign key (client_id) references oauth2_clients (id) on delete cascade ); create index oauth2_codes_created_at_idx on oauth2_codes (created_at); create table oauth2_sessions ( - id string, - client_id string not null, - state string not null, - callback_state string not null, - callback_url string not null, - created_at timestamptz not null default now(), + id string, + client_id string not null, + state string not null, + callback_state string not null, + callback_url string not null, + code_challenge string not null default '', + code_challenge_method string not null default '', + resource string not null default '', + created_at timestamptz not null default now(), primary key (id) ); create index oauth2_sessions_created_at_idx on oauth2_sessions (created_at); diff --git a/token.go b/token.go index 92431c5..4f8c15b 100644 --- a/token.go +++ b/token.go @@ -11,6 +11,10 @@ import ( const tokenPrefix = "deploys-api." +// tokenTTLSeconds mirrors the 7-day TTL applied to user_tokens rows; reported +// to clients as expires_in. +const tokenTTLSeconds = 7 * 24 * 60 * 60 + func generateBase64RandomString(s int) string { b := make([]byte, s) _, err := rand.Read(b[:]) @@ -36,6 +40,10 @@ func generateSessionID() string { return generateBase64RandomString(32) } +func generateClientID() string { + return generateBase64RandomString(16) +} + func hashToken(token string) string { h := sha256.New() h.Write([]byte(token)) @@ -55,3 +63,14 @@ func deleteToken(ctx context.Context, token string) error { _, err := pgctx.Exec(ctx, `delete from user_tokens where token = $1`, token) return err } + +// lookupToken resolves a hashed token to its owner email and expiry (unix +// seconds), returning sql.ErrNoRows when the token is unknown or expired. +func lookupToken(ctx context.Context, hashedToken string) (email string, exp int64, err error) { + err = pgctx.QueryRow(ctx, ` + select email, extract(epoch from expires_at)::bigint + from user_tokens + where token = $1 and expires_at > now() + `, hashedToken).Scan(&email, &exp) + return +} From 1550960d0fca061b28c5cf17b2568c8ebf2af221 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Wed, 27 May 2026 07:56:08 +0700 Subject: [PATCH 2/5] test: cover old confidential flow and new MCP/PKCE flow Add the first tests to the repo, using go-sqlmock to fake the DB layer (pgctx talks to a standard *sql.DB, no transactions) and an httptest stub for Google's token endpoint. Regression (old flow): - TokenHandler confidential client_secret path still returns refresh_token (and now access_token/expires_in), wrong/missing secret rejected - RedirectHandler glob redirect validation + Google redirect + session cookie - CallbackHandler session lookup, state check, internal code issuance New flow: - TokenHandler public/PKCE: success, bad verifier, redirect mismatch, missing verifier, unsupported grant_type, unknown client, invalid code - RedirectHandler public: PKCE required, S256 only, loopback redirect match - CallbackHandler carries PKCE/redirect/resource onto the code - Dynamic Client Registration validation + success - Introspection: config/auth gates, active/unknown/empty token - Discovery metadata; PKCE + redirect + registration URI helpers To make the callback testable, googleTokenURL is now a package var. Co-Authored-By: Claude Opus 4.7 (1M context) --- callback_test.go | 125 +++++++++++++++++ go.mod | 2 + go.sum | 5 +- handler.go | 6 +- handler_helpers_test.go | 102 ++++++++++++++ introspect_test.go | 111 ++++++++++++++++ pkce_test.go | 86 ++++++++++++ redirect_test.go | 197 +++++++++++++++++++++++++++ register_test.go | 83 ++++++++++++ testutil_test.go | 56 ++++++++ token_test.go | 287 ++++++++++++++++++++++++++++++++++++++++ 11 files changed, 1057 insertions(+), 3 deletions(-) create mode 100644 callback_test.go create mode 100644 handler_helpers_test.go create mode 100644 introspect_test.go create mode 100644 pkce_test.go create mode 100644 redirect_test.go create mode 100644 register_test.go create mode 100644 testutil_test.go create mode 100644 token_test.go diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..db4e1ba --- /dev/null +++ b/callback_test.go @@ -0,0 +1,125 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func sessionRow(clientID, state, callbackState, callbackURL, challenge, method, resource string) *sqlmock.Rows { + cols := []string{"client_id", "state", "callback_state", "callback_url", "code_challenge", "code_challenge_method", "resource"} + return sqlmock.NewRows(cols).AddRow(clientID, state, callbackState, callbackURL, challenge, method, resource) +} + +// stubGoogle points the callback token exchange at a local server that returns +// an id_token carrying email, and restores the real URL afterwards. +func stubGoogle(t *testing.T, email string) { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"id_token":%q}`, fakeIDToken(email)) + })) + old := googleTokenURL + googleTokenURL = srv.URL + t.Cleanup(func() { + googleTokenURL = old + srv.Close() + }) +} + +// --- OLD FLOW (regression): Google callback issues an internal code --- + +func TestCallbackHandler_Confidential_Success(t *testing.T) { + stubGoogle(t, "user@example.com") + + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_sessions"). + WillReturnRows(sessionRow("web", "gstate", "cbstate", "https://app.example.com/cb", "", "", "")) + mock.ExpectExec("delete from oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("insert into oauth2_codes"). + WithArgs(sqlmock.AnyArg(), "web", "user@example.com", "", "", "https://app.example.com/cb", ""). + WillReturnResult(sqlmock.NewResult(0, 1)) + + q := url.Values{"state": {"gstate"}, "code": {"google-code"}} + req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) + req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) + rec := httptest.NewRecorder() + CallbackHandler{OAuth2ClientID: "g", OAuth2ClientSecret: "gs", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(req, db)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "https://app.example.com/cb?") { + t.Fatalf("Location = %q, want redirect to client callback", loc) + } + u, _ := url.Parse(loc) + if u.Query().Get("state") != "cbstate" { + t.Errorf("callback state = %q, want cbstate", u.Query().Get("state")) + } + if u.Query().Get("code") == "" { + t.Error("callback code is empty") + } + assertExpectations(t, mock) +} + +func TestCallbackHandler_MissingSessionCookie(t *testing.T) { + db, _ := newMock(t) + q := url.Values{"state": {"gstate"}, "code": {"google-code"}} + req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) // no cookie + rec := httptest.NewRecorder() + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(req, db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } +} + +func TestCallbackHandler_StateMismatch(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_sessions"). + WillReturnRows(sessionRow("web", "expected-state", "cbstate", "https://app.example.com/cb", "", "", "")) + mock.ExpectExec("delete from oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) + + q := url.Values{"state": {"attacker-state"}, "code": {"google-code"}} + req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) + req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) + rec := httptest.NewRecorder() + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(req, db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (state mismatch)", rec.Code) + } + assertExpectations(t, mock) +} + +// --- NEW FLOW: PKCE challenge from the session is carried onto the code --- + +func TestCallbackHandler_Public_CarriesPKCE(t *testing.T) { + stubGoogle(t, "user@example.com") + _, challenge := pkcePair() + + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_sessions"). + WillReturnRows(sessionRow("cli", "gstate", "cbstate", "http://127.0.0.1:55001/callback", challenge, "S256", "https://api.deploys.app")) + mock.ExpectExec("delete from oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) + // The code must be persisted with the PKCE challenge, redirect and resource. + mock.ExpectExec("insert into oauth2_codes"). + WithArgs(sqlmock.AnyArg(), "cli", "user@example.com", challenge, "S256", "http://127.0.0.1:55001/callback", "https://api.deploys.app"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + q := url.Values{"state": {"gstate"}, "code": {"google-code"}} + req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) + req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) + rec := httptest.NewRecorder() + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(req, db)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + assertExpectations(t, mock) +} diff --git a/go.mod b/go.mod index 8d9f4aa..b64f078 100644 --- a/go.mod +++ b/go.mod @@ -6,3 +6,5 @@ require ( github.com/acoshift/pgsql v0.16.0 github.com/lib/pq v1.12.3 ) + +require github.com/DATA-DOG/go-sqlmock v1.5.2 diff --git a/go.sum b/go.sum index 66994d2..d5e7c24 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,10 @@ -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/acoshift/pgsql v0.16.0 h1:ak+fwy8Xnx0uZBhSmvFhGGk3a7EQNu8IiRLpnP99IT4= github.com/acoshift/pgsql v0.16.0/go.mod h1:HtdMa77CYeRb9pD6+cT/ZPjpudiUVQ2OIKv6QXjgEZw= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/handler.go b/handler.go index 33308c3..44863ee 100644 --- a/handler.go +++ b/handler.go @@ -141,6 +141,10 @@ func isURL(s string) bool { return (p.Scheme == "http" || p.Scheme == "https") && p.Host != "" } +// googleTokenURL is Google's OAuth2 token endpoint. It is a variable so tests +// can point the callback exchange at a stub server. +var googleTokenURL = "https://oauth2.googleapis.com/token" + type CallbackHandler struct { OAuth2ClientID string OAuth2ClientSecret string @@ -197,7 +201,7 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { params.Set("client_id", h.OAuth2ClientID) params.Set("client_secret", h.OAuth2ClientSecret) - resp, err := http.Post("https://oauth2.googleapis.com/token", "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) + resp, err := http.Post(googleTokenURL, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) if err != nil { slog.WarnContext(ctx, "callback: exchange token", "error", err) failResponse(w, r) diff --git a/handler_helpers_test.go b/handler_helpers_test.go new file mode 100644 index 0000000..847f1d7 --- /dev/null +++ b/handler_helpers_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestIsURL(t *testing.T) { + for in, want := range map[string]bool{ + "https://app.example.com/cb": true, + "http://127.0.0.1:8080/cb": true, + "app.example.com": false, // no scheme + "ftp://example.com": false, // unsupported scheme + "https://": false, // no host + "": false, + } { + if got := isURL(in); got != want { + t.Errorf("isURL(%q) = %v, want %v", in, got, want) + } + } +} + +func TestExtractEmailFromIDToken(t *testing.T) { + t.Run("valid", func(t *testing.T) { + email, err := extractEmailFromIDToken(fakeIDToken("user@example.com")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if email != "user@example.com" { + t.Errorf("email = %q, want user@example.com", email) + } + }) + t.Run("malformed (not 3 parts)", func(t *testing.T) { + if _, err := extractEmailFromIDToken("only.two"); err == nil { + t.Error("expected error for malformed token") + } + }) +} + +func TestHashToken(t *testing.T) { + a := hashToken("deploys-api.abc") + if a == "" { + t.Fatal("hashToken returned empty") + } + if a != hashToken("deploys-api.abc") { + t.Error("hashToken is not deterministic") + } + if a == hashToken("deploys-api.different") { + t.Error("different inputs produced the same hash") + } +} + +func TestMetadataHandler(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + MetadataHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var meta struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + } + if err := json.NewDecoder(rec.Body).Decode(&meta); err != nil { + t.Fatalf("decode: %v", err) + } + if meta.Issuer != "https://auth.test" { + t.Errorf("issuer = %q", meta.Issuer) + } + if meta.AuthorizationEndpoint != "https://auth.test/" { + t.Errorf("authorization_endpoint = %q", meta.AuthorizationEndpoint) + } + if meta.TokenEndpoint != "https://auth.test/token" { + t.Errorf("token_endpoint = %q", meta.TokenEndpoint) + } + if meta.RegistrationEndpoint != "https://auth.test/register" { + t.Errorf("registration_endpoint = %q", meta.RegistrationEndpoint) + } + if !contains(meta.CodeChallengeMethodsSupported, "S256") { + t.Errorf("code_challenge_methods_supported = %v, want S256", meta.CodeChallengeMethodsSupported) + } + if !contains(meta.TokenEndpointAuthMethodsSupported, "none") || + !contains(meta.TokenEndpointAuthMethodsSupported, "client_secret_post") { + t.Errorf("token_endpoint_auth_methods_supported = %v", meta.TokenEndpointAuthMethodsSupported) + } +} + +func contains(s []string, v string) bool { + for _, x := range s { + if x == v { + return true + } + } + return false +} diff --git a/introspect_test.go b/introspect_test.go new file mode 100644 index 0000000..af0c9e2 --- /dev/null +++ b/introspect_test.go @@ -0,0 +1,111 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func introspectReq(t *testing.T, auth, token string) *http.Request { + t.Helper() + form := url.Values{"token": {token}} + req := httptest.NewRequest(http.MethodPost, "/introspect", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if auth != "" { + req.Header.Set("Authorization", auth) + } + return req +} + +func TestIntrospectHandler_NotConfigured(t *testing.T) { + db, _ := newMock(t) + rec := httptest.NewRecorder() + IntrospectHandler{Token: ""}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer x", "tok"), db)) + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", rec.Code) + } +} + +func TestIntrospectHandler_Unauthorized(t *testing.T) { + db, _ := newMock(t) + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer wrong", "tok"), db)) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", rec.Code) + } +} + +func TestIntrospectHandler_ActiveToken(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from user_tokens"). + WillReturnRows(sqlmock.NewRows([]string{"email", "exp"}).AddRow("user@example.com", int64(1893456000))) + + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer secret", "deploys-api.abc"), db)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + Active bool `json:"active"` + Sub string `json:"sub"` + Exp int64 `json:"exp"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if !resp.Active { + t.Error("active = false, want true") + } + if resp.Sub != "user@example.com" { + t.Errorf("sub = %q", resp.Sub) + } + if resp.Exp != 1893456000 { + t.Errorf("exp = %d", resp.Exp) + } + assertExpectations(t, mock) +} + +func TestIntrospectHandler_UnknownToken(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from user_tokens").WillReturnError(sqlNoRows()) + + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer secret", "deploys-api.gone"), db)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var resp struct { + Active bool `json:"active"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.Active { + t.Error("active = true, want false for unknown token") + } + assertExpectations(t, mock) +} + +func TestIntrospectHandler_EmptyToken(t *testing.T) { + db, _ := newMock(t) // no DB lookup for empty token + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer secret", ""), db)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var resp struct { + Active bool `json:"active"` + } + json.NewDecoder(rec.Body).Decode(&resp) + if resp.Active { + t.Error("active = true, want false for empty token") + } +} diff --git a/pkce_test.go b/pkce_test.go new file mode 100644 index 0000000..d3d9d56 --- /dev/null +++ b/pkce_test.go @@ -0,0 +1,86 @@ +package main + +import "testing" + +func TestVerifyPKCE(t *testing.T) { + verifier, challenge := pkcePair() + + cases := []struct { + name string + verifier string + chal string + method string + want bool + }{ + {"valid S256", verifier, challenge, "S256", true}, + {"wrong verifier", "not-the-verifier", challenge, "S256", false}, + {"empty verifier", "", challenge, "S256", false}, + {"empty challenge", verifier, "", "S256", false}, + {"plain method rejected", verifier, verifier, "plain", false}, + {"empty method rejected", verifier, challenge, "", false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := verifyPKCE(c.verifier, c.chal, c.method); got != c.want { + t.Errorf("verifyPKCE(%q,%q,%q) = %v, want %v", c.verifier, c.chal, c.method, got, c.want) + } + }) + } +} + +func TestIsLoopbackHost(t *testing.T) { + for host, want := range map[string]bool{ + "127.0.0.1": true, + "::1": true, + "localhost": true, + "example.com": false, + "10.0.0.1": false, + "": false, + } { + if got := isLoopbackHost(host); got != want { + t.Errorf("isLoopbackHost(%q) = %v, want %v", host, got, want) + } + } +} + +func TestRedirectURIAllowed(t *testing.T) { + cases := []struct { + name string + registered []string + got string + want bool + }{ + {"exact https", []string{"https://app.example.com/cb"}, "https://app.example.com/cb", true}, + {"loopback ignores port", []string{"http://127.0.0.1:1234/callback"}, "http://127.0.0.1:55001/callback", true}, + {"localhost ignores port", []string{"http://localhost:1/callback"}, "http://localhost:9999/callback", true}, + {"loopback path mismatch", []string{"http://127.0.0.1:1234/callback"}, "http://127.0.0.1:55001/other", false}, + {"https port mismatch", []string{"https://app.example.com:443/cb"}, "https://app.example.com:8443/cb", false}, + {"scheme mismatch", []string{"https://app.example.com/cb"}, "http://app.example.com/cb", false}, + {"host mismatch", []string{"https://app.example.com/cb"}, "https://evil.example.com/cb", false}, + {"not in list", []string{"https://a.example.com/cb"}, "https://b.example.com/cb", false}, + {"matches second entry", []string{"https://a.example.com/cb", "http://127.0.0.1:1/cb"}, "http://127.0.0.1:42/cb", true}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := redirectURIAllowed(c.registered, c.got); got != c.want { + t.Errorf("redirectURIAllowed(%v, %q) = %v, want %v", c.registered, c.got, got, c.want) + } + }) + } +} + +func TestValidRegistrationRedirectURI(t *testing.T) { + for uri, want := range map[string]bool{ + "https://app.example.com/cb": true, + "http://127.0.0.1:1234/cb": true, + "http://localhost/cb": true, + "http://app.example.com/cb": false, // plain http only allowed for loopback + "ftp://example.com": false, + "not-a-url": false, + "": false, + } { + if got := validRegistrationRedirectURI(uri); got != want { + t.Errorf("validRegistrationRedirectURI(%q) = %v, want %v", uri, got, want) + } + } +} diff --git a/redirect_test.go b/redirect_test.go new file mode 100644 index 0000000..4ae477c --- /dev/null +++ b/redirect_test.go @@ -0,0 +1,197 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func getReq(t *testing.T, q url.Values) *http.Request { + t.Helper() + return httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) +} + +func findCookie(rec *httptest.ResponseRecorder, name string) *http.Cookie { + for _, c := range rec.Result().Cookies() { + if c.Name == name { + return c + } + } + return nil +} + +// --- OLD FLOW (regression): confidential client redirect to Google --- + +func TestRedirectHandler_Confidential_Success(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + mock.ExpectExec("insert into oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) + + q := url.Values{ + "client_id": {"web"}, + "state": {"cbstate"}, + "redirect_uri": {"https://app.example.com/cb"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "googleclient", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "https://accounts.google.com/o/oauth2/auth?") { + t.Errorf("Location = %q, want Google authorize URL", loc) + } + if !strings.Contains(loc, "client_id=googleclient") { + t.Errorf("Location missing google client_id: %q", loc) + } + if !strings.Contains(loc, url.QueryEscape("https://auth.test/callback")) { + t.Errorf("Location missing callback redirect_uri: %q", loc) + } + if c := findCookie(rec, "s"); c == nil || c.Value == "" { + t.Error("session cookie 's' not set") + } + assertExpectations(t, mock) +} + +func TestRedirectHandler_MissingParams(t *testing.T) { + cases := map[string]url.Values{ + "missing client_id": {"state": {"s"}, "redirect_uri": {"https://a/cb"}}, + "missing state": {"client_id": {"web"}, "redirect_uri": {"https://a/cb"}}, + "missing redirect_uri": {"client_id": {"web"}, "state": {"s"}}, + "invalid redirect_uri": {"client_id": {"web"}, "state": {"s"}, "redirect_uri": {"not-a-url"}}, + } + for name, q := range cases { + t.Run(name, func(t *testing.T) { + db, _ := newMock(t) // no DB calls expected + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", rec.Code) + } + }) + } +} + +func TestRedirectHandler_UnknownClient(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients").WillReturnError(sqlNoRows()) + + q := url.Values{"client_id": {"ghost"}, "state": {"s"}, "redirect_uri": {"https://a/cb"}} + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } +} + +func TestRedirectHandler_Confidential_RedirectNotAllowed(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + + q := url.Values{"client_id": {"web"}, "state": {"s"}, "redirect_uri": {"https://evil.example.com/cb"}} + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } +} + +// --- NEW FLOW: public client requires PKCE + exact (loopback) redirect --- + +func TestRedirectHandler_Public_Success(t *testing.T) { + _, challenge := pkcePair() + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + // PKCE challenge + method must be persisted on the session for the callback. + mock.ExpectExec("insert into oauth2_sessions"). + WithArgs(sqlmock.AnyArg(), "cli", sqlmock.AnyArg(), "cbstate", + "http://127.0.0.1:55001/callback", challenge, "S256", ""). + WillReturnResult(sqlmock.NewResult(0, 1)) + + q := url.Values{ + "client_id": {"cli"}, + "state": {"cbstate"}, + "redirect_uri": {"http://127.0.0.1:55001/callback"}, // different loopback port, allowed + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "googleclient", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + if !strings.HasPrefix(rec.Header().Get("Location"), "https://accounts.google.com/") { + t.Errorf("Location = %q", rec.Header().Get("Location")) + } + assertExpectations(t, mock) +} + +func TestRedirectHandler_Public_MissingChallenge(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + + q := url.Values{"client_id": {"cli"}, "state": {"s"}, "redirect_uri": {"http://127.0.0.1:55001/callback"}} + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (PKCE required for public clients)", rec.Code) + } +} + +func TestRedirectHandler_Public_UnsupportedChallengeMethod(t *testing.T) { + _, challenge := pkcePair() + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + + q := url.Values{ + "client_id": {"cli"}, + "state": {"s"}, + "redirect_uri": {"http://127.0.0.1:55001/callback"}, + "code_challenge": {challenge}, + "code_challenge_method": {"plain"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (only S256)", rec.Code) + } +} + +func TestRedirectHandler_Public_RedirectNotRegistered(t *testing.T) { + _, challenge := pkcePair() + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + + q := url.Values{ + "client_id": {"cli"}, + "state": {"s"}, + "redirect_uri": {"https://evil.example.com/callback"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, withDB(getReq(t, q), db)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (redirect not registered)", rec.Code) + } +} diff --git a/register_test.go b/register_test.go new file mode 100644 index 0000000..76c348d --- /dev/null +++ b/register_test.go @@ -0,0 +1,83 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func postJSON(t *testing.T, path, body string) *http.Request { + t.Helper() + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + return req +} + +func TestRegisterHandler_Success(t *testing.T) { + db, mock := newMock(t) + mock.ExpectExec("insert into oauth2_clients"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + body := `{"client_name":"my cli","redirect_uris":["http://127.0.0.1:1234/cb"]}` + rec := httptest.NewRecorder() + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(postJSON(t, "/register", body), db)) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + ClientID string `json:"client_id"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + RedirectURIs []string `json:"redirect_uris"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.ClientID == "" { + t.Error("client_id is empty") + } + if resp.TokenEndpointAuthMethod != "none" { + t.Errorf("token_endpoint_auth_method = %q, want none", resp.TokenEndpointAuthMethod) + } + if len(resp.RedirectURIs) != 1 || resp.RedirectURIs[0] != "http://127.0.0.1:1234/cb" { + t.Errorf("redirect_uris = %v", resp.RedirectURIs) + } + assertExpectations(t, mock) +} + +func TestRegisterHandler_HTTPSRedirectAllowed(t *testing.T) { + db, mock := newMock(t) + mock.ExpectExec("insert into oauth2_clients").WillReturnResult(sqlmock.NewResult(0, 1)) + + body := `{"redirect_uris":["https://app.example.com/cb"]}` + rec := httptest.NewRecorder() + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(postJSON(t, "/register", body), db)) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) + } + assertExpectations(t, mock) +} + +func TestRegisterHandler_Rejections(t *testing.T) { + cases := map[string]string{ + "invalid json": `{not json`, + "no redirect_uris": `{"client_name":"x"}`, + "non-loopback http": `{"redirect_uris":["http://app.example.com/cb"]}`, + "confidential rejected": `{"redirect_uris":["https://app.example.com/cb"],"token_endpoint_auth_method":"client_secret_post"}`, + } + for name, body := range cases { + t.Run(name, func(t *testing.T) { + db, _ := newMock(t) // no insert expected + rec := httptest.NewRecorder() + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(postJSON(t, "/register", body), db)) + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + }) + } +} diff --git a/testutil_test.go b/testutil_test.go new file mode 100644 index 0000000..8505463 --- /dev/null +++ b/testutil_test.go @@ -0,0 +1,56 @@ +package main + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "net/http" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/acoshift/pgsql/pgctx" +) + +// newMock returns a mock *sql.DB plus its controller. The DB is closed via +// t.Cleanup. pgctx talks to it through the standard database/sql interface, so +// no transaction (begin/commit) expectations are needed for these handlers. +func newMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db, mock +} + +// withDB binds the mock DB to the request context the way pgctx.Middleware would. +func withDB(req *http.Request, db *sql.DB) *http.Request { + return req.WithContext(pgctx.NewContext(req.Context(), db)) +} + +// pkcePair returns a PKCE verifier and its S256 challenge. +func pkcePair() (verifier, challenge string) { + verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge +} + +// fakeIDToken builds a JWT-shaped token whose payload carries the given email, +// matching what extractEmailFromIDToken expects (RawStdEncoding-encoded body). +func fakeIDToken(email string) string { + payload := base64.RawStdEncoding.EncodeToString([]byte(`{"email":"` + email + `"}`)) + return "header." + payload + ".signature" +} + +func assertExpectations(t *testing.T, mock sqlmock.Sqlmock) { + t.Helper() + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// sqlNoRows is the error a QueryRow returns when nothing matched; the DB layer +// translates it into the ErrOAuth2*NotFound sentinels. +func sqlNoRows() error { return sql.ErrNoRows } diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..59a33dd --- /dev/null +++ b/token_test.go @@ -0,0 +1,287 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func postForm(t *testing.T, path string, form url.Values) *http.Request { + t.Helper() + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req +} + +func clientRow(id, secret, redirectURI, redirectURIs, authMethod string) *sqlmock.Rows { + cols := []string{"id", "secret", "redirect_uri", "redirect_uris", "token_endpoint_auth_method"} + var sec any = secret + if secret == "" && authMethod == "none" { + sec = nil // public clients store NULL secret + } + return sqlmock.NewRows(cols).AddRow(id, sec, redirectURI, redirectURIs, authMethod) +} + +func codeRow(email, challenge, method, redirectURI, resource string) *sqlmock.Rows { + cols := []string{"email", "code_challenge", "code_challenge_method", "redirect_uri", "resource"} + return sqlmock.NewRows(cols).AddRow(email, challenge, method, redirectURI, resource) +} + +// --- OLD FLOW (regression): confidential client + client_secret --- + +func TestTokenHandler_Confidential_Success(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + mock.ExpectQuery("delete from oauth2_codes"). + WillReturnRows(codeRow("user@example.com", "", "", "https://app.example.com/cb", "")) + mock.ExpectExec("insert into user_tokens"). + WithArgs(sqlmock.AnyArg(), "user@example.com"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + form := url.Values{"client_id": {"web"}, "client_secret": {"topsecret"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + // Backward compatibility: the legacy client reads refresh_token. + if resp.RefreshToken == "" { + t.Error("refresh_token must remain populated for the legacy client") + } + if !strings.HasPrefix(resp.RefreshToken, tokenPrefix) { + t.Errorf("refresh_token %q missing prefix %q", resp.RefreshToken, tokenPrefix) + } + // New fields for OAuth2.1 clients. + if resp.AccessToken != resp.RefreshToken { + t.Errorf("access_token (%q) should equal refresh_token (%q)", resp.AccessToken, resp.RefreshToken) + } + if resp.ExpiresIn != tokenTTLSeconds { + t.Errorf("expires_in = %d, want %d", resp.ExpiresIn, tokenTTLSeconds) + } + if resp.TokenType != "Bearer" { + t.Errorf("token_type = %q, want Bearer", resp.TokenType) + } + assertExpectations(t, mock) +} + +func TestTokenHandler_Confidential_WrongSecret(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + + form := url.Values{"client_id": {"web"}, "client_secret": {"wrong"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", rec.Code) + } + assertOAuthError(t, rec, "invalid_client") + assertExpectations(t, mock) +} + +func TestTokenHandler_Confidential_MissingSecret(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + + form := url.Values{"client_id": {"web"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_request") +} + +// --- NEW FLOW: public client + PKCE --- + +func TestTokenHandler_Public_PKCE_Success(t *testing.T) { + verifier, challenge := pkcePair() + const redirect = "http://127.0.0.1:5000/callback" + + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", redirect, "none")) + mock.ExpectQuery("delete from oauth2_codes"). + WillReturnRows(codeRow("user@example.com", challenge, "S256", redirect, "")) + mock.ExpectExec("insert into user_tokens"). + WithArgs(sqlmock.AnyArg(), "user@example.com"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {"cli"}, + "code": {"abc"}, + "code_verifier": {verifier}, + "redirect_uri": {redirect}, + } + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.AccessToken == "" { + t.Error("access_token must be populated") + } + if resp.TokenType != "Bearer" { + t.Errorf("token_type = %q, want Bearer", resp.TokenType) + } + if resp.ExpiresIn != tokenTTLSeconds { + t.Errorf("expires_in = %d, want %d", resp.ExpiresIn, tokenTTLSeconds) + } + assertExpectations(t, mock) +} + +func TestTokenHandler_Public_PKCE_BadVerifier(t *testing.T) { + _, challenge := pkcePair() + const redirect = "http://127.0.0.1:5000/callback" + + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", redirect, "none")) + mock.ExpectQuery("delete from oauth2_codes"). + WillReturnRows(codeRow("user@example.com", challenge, "S256", redirect, "")) + + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {"cli"}, + "code": {"abc"}, + "code_verifier": {"this-is-the-wrong-verifier"}, + "redirect_uri": {redirect}, + } + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + assertOAuthError(t, rec, "invalid_grant") + assertExpectations(t, mock) +} + +func TestTokenHandler_Public_RedirectMismatch(t *testing.T) { + verifier, challenge := pkcePair() + + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:5000/callback", "none")) + mock.ExpectQuery("delete from oauth2_codes"). + WillReturnRows(codeRow("user@example.com", challenge, "S256", "http://127.0.0.1:5000/callback", "")) + + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {"cli"}, + "code": {"abc"}, + "code_verifier": {verifier}, + "redirect_uri": {"http://127.0.0.1:5000/different"}, + } + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_grant") + assertExpectations(t, mock) +} + +func TestTokenHandler_Public_MissingVerifier(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:5000/callback", "none")) + + form := url.Values{"grant_type": {"authorization_code"}, "client_id": {"cli"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_request") + assertExpectations(t, mock) +} + +// --- shared request validation --- + +func TestTokenHandler_UnsupportedGrantType(t *testing.T) { + db, _ := newMock(t) + form := url.Values{"grant_type": {"password"}, "client_id": {"x"}, "code": {"y"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "unsupported_grant_type") +} + +func TestTokenHandler_UnknownClient(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients").WillReturnError(sqlNoRows()) + + form := url.Values{"client_id": {"ghost"}, "client_secret": {"x"}, "code": {"y"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", rec.Code) + } + assertOAuthError(t, rec, "invalid_client") +} + +func TestTokenHandler_InvalidCode(t *testing.T) { + db, mock := newMock(t) + mock.ExpectQuery("from oauth2_clients"). + WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + mock.ExpectQuery("delete from oauth2_codes").WillReturnError(sqlNoRows()) + + form := url.Values{"client_id": {"web"}, "client_secret": {"topsecret"}, "code": {"gone"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_grant") + assertExpectations(t, mock) +} + +func assertOAuthError(t *testing.T, rec *httptest.ResponseRecorder, wantCode string) { + t.Helper() + var body struct { + Error string `json:"error"` + } + if err := json.NewDecoder(rec.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if body.Error != wantCode { + t.Errorf("error = %q, want %q", body.Error, wantCode) + } +} From aae4667bb7053b566318b7b35adc2aaf62e89068 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Wed, 27 May 2026 07:56:39 +0700 Subject: [PATCH 3/5] docs: note test command, sqlmock, and new env vars in CLAUDE.md Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index 145769e..b850c4c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -7,9 +7,12 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ```bash go build -o auth . # build binary go vet ./... # lint +go test ./... # run tests ``` -No test files exist in this codebase. +Handler tests fake the DB with `github.com/DATA-DOG/go-sqlmock` (bound via +`pgctx.NewContext`; no transactions are used) and stub Google's token endpoint +through the `googleTokenURL` package var. ### Required environment variables @@ -19,6 +22,8 @@ No test files exist in this codebase. | `OAUTH2_CLIENT_ID` | Google OAuth app client ID | | `OAUTH2_CLIENT_SECRET` | Google OAuth app client secret | | `PORT` | Listen port (default: `8080`) | +| `BASE_URL` | Public base URL of this service (default: `https://auth.deploys.app`) | +| `INTROSPECTION_TOKEN` | Shared secret guarding `POST /introspect`; unset disables the endpoint | ## Architecture From 391973b5b7d44782ffd9f1291b81006c47cdf2be Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Wed, 27 May 2026 08:16:24 +0700 Subject: [PATCH 4/5] test: run against real CockroachDB instead of mocking sql Follow the ../dropbox pattern: a schema package embeds the SQL migrations (01_init pre-MCP + user_tokens, 02_mcp the PR's ALTERs) and tu.Setup starts an isolated in-memory cockroach-go/v2 testserver per test, applying the migration. Tests seed via pgctx and assert against real DB state (token persisted, code consumed, session PKCE carried through), which the sqlmock version could not verify. As a bonus the suite exercises the MCP migration SQL on a real CockroachDB. Google's token endpoint is stubbed via a shared httptest server keyed by the auth code (parallel-safe), set once through the googleTokenURL package var. Drops the go-sqlmock dependency; adds cockroach-go/v2 (test only). Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 9 ++- callback_test.go | 135 +++++++++++++++++++++++----------------- go.mod | 7 ++- go.sum | 24 ++++++-- introspect_test.go | 78 ++++++++++++++--------- pkce_test.go | 24 ++++---- redirect_test.go | 88 ++++++++++++++------------ register_test.go | 33 +++++----- schema/01_init.sql | 36 +++++++++++ schema/02_mcp.sql | 16 +++++ schema/schema.go | 76 +++++++++++++++++++++++ setup_test.go | 17 +++++ testutil_test.go | 146 +++++++++++++++++++++++++++++++++++-------- token_test.go | 150 +++++++++++++++++++++------------------------ tu/tu.go | 65 ++++++++++++++++++++ 15 files changed, 641 insertions(+), 263 deletions(-) create mode 100644 schema/01_init.sql create mode 100644 schema/02_mcp.sql create mode 100644 schema/schema.go create mode 100644 setup_test.go create mode 100644 tu/tu.go diff --git a/CLAUDE.md b/CLAUDE.md index b850c4c..0f41c4c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,9 +10,12 @@ go vet ./... # lint go test ./... # run tests ``` -Handler tests fake the DB with `github.com/DATA-DOG/go-sqlmock` (bound via -`pgctx.NewContext`; no transactions are used) and stub Google's token endpoint -through the `googleTokenURL` package var. +Handler tests run against a real CockroachDB: `tu.Setup()` starts an isolated +in-memory `cockroach-go/v2/testserver` per test and `schema.Migrate` applies the +embedded `schema/*.sql` migrations. `newTestDB(t)` (setup_test.go) returns a +`*tu.Context`; use `db.Ctx()` for both seeding (via `pgctx`) and the request +context. Google's token endpoint is stubbed through the `googleTokenURL` package +var. Tests download the CockroachDB binary on first run. ### Required environment variables diff --git a/callback_test.go b/callback_test.go index db4e1ba..0cbe892 100644 --- a/callback_test.go +++ b/callback_test.go @@ -6,51 +6,53 @@ import ( "net/http/httptest" "net/url" "strings" + "sync" "testing" - - "github.com/DATA-DOG/go-sqlmock" ) -func sessionRow(clientID, state, callbackState, callbackURL, challenge, method, resource string) *sqlmock.Rows { - cols := []string{"client_id", "state", "callback_state", "callback_url", "code_challenge", "code_challenge_method", "resource"} - return sqlmock.NewRows(cols).AddRow(clientID, state, callbackState, callbackURL, challenge, method, resource) -} +// A single mock Google token endpoint is started lazily and shared by all +// callback tests. It keys the returned id_token email by the authorization +// code in the request, so parallel tests do not race over googleTokenURL. +var ( + googleMockOnce sync.Once + googleMockMap sync.Map // code -> email +) -// stubGoogle points the callback token exchange at a local server that returns -// an id_token carrying email, and restores the real URL afterwards. -func stubGoogle(t *testing.T, email string) { +func registerGoogleCode(t *testing.T, code, email string) { t.Helper() - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id_token":%q}`, fakeIDToken(email)) - })) - old := googleTokenURL - googleTokenURL = srv.URL - t.Cleanup(func() { - googleTokenURL = old - srv.Close() + googleMockOnce.Do(func() { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + email := "default@example.com" + if v, ok := googleMockMap.Load(r.FormValue("code")); ok { + email = v.(string) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"id_token":%q}`, fakeIDToken(email)) + })) + googleTokenURL = srv.URL }) + googleMockMap.Store(code, email) + t.Cleanup(func() { googleMockMap.Delete(code) }) } // --- OLD FLOW (regression): Google callback issues an internal code --- func TestCallbackHandler_Confidential_Success(t *testing.T) { - stubGoogle(t, "user@example.com") + t.Parallel() + registerGoogleCode(t, t.Name(), "user@example.com") - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_sessions"). - WillReturnRows(sessionRow("web", "gstate", "cbstate", "https://app.example.com/cb", "", "", "")) - mock.ExpectExec("delete from oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("insert into oauth2_codes"). - WithArgs(sqlmock.AnyArg(), "web", "user@example.com", "", "", "https://app.example.com/cb", ""). - WillReturnResult(sqlmock.NewResult(0, 1)) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + seedSession(t, ctx, "sess123", "web", "gstate", "cbstate", "https://app.example.com/cb", "", "", "") - q := url.Values{"state": {"gstate"}, "code": {"google-code"}} - req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) + q := url.Values{"state": {"gstate"}, "code": {t.Name()}} + req := getReqPath(t, "/callback", q) req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) rec := httptest.NewRecorder() CallbackHandler{OAuth2ClientID: "g", OAuth2ClientSecret: "gs", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(req, db)) + ServeHTTP(rec, req.WithContext(ctx)) if rec.Code != http.StatusFound { t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) @@ -63,63 +65,86 @@ func TestCallbackHandler_Confidential_Success(t *testing.T) { if u.Query().Get("state") != "cbstate" { t.Errorf("callback state = %q, want cbstate", u.Query().Get("state")) } - if u.Query().Get("code") == "" { - t.Error("callback code is empty") + returnedCode := u.Query().Get("code") + if returnedCode == "" { + t.Fatal("callback code is empty") + } + // The session is single-use and an internal code was minted for the email. + if n := countRows(t, ctx, "oauth2_sessions"); n != 0 { + t.Errorf("oauth2_sessions = %d, want 0 (consumed)", n) + } + email, challenge, method := codeEmailPKCE(t, ctx, returnedCode) + if email != "user@example.com" { + t.Errorf("code email = %q, want user@example.com", email) + } + if challenge != "" || method != "" { + t.Errorf("confidential code carried PKCE: (%q,%q)", challenge, method) } - assertExpectations(t, mock) } func TestCallbackHandler_MissingSessionCookie(t *testing.T) { - db, _ := newMock(t) + t.Parallel() q := url.Values{"state": {"gstate"}, "code": {"google-code"}} - req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) // no cookie + req := getReqPath(t, "/callback", q) // no cookie rec := httptest.NewRecorder() - CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(req, db)) + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) } } func TestCallbackHandler_StateMismatch(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_sessions"). - WillReturnRows(sessionRow("web", "expected-state", "cbstate", "https://app.example.com/cb", "", "", "")) - mock.ExpectExec("delete from oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedSession(t, ctx, "sess123", "web", "expected-state", "cbstate", "https://app.example.com/cb", "", "", "") q := url.Values{"state": {"attacker-state"}, "code": {"google-code"}} - req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) + req := getReqPath(t, "/callback", q) req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) rec := httptest.NewRecorder() - CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(req, db)) + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req.WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400 (state mismatch)", rec.Code) } - assertExpectations(t, mock) + // Session is consumed on read even on mismatch; no code should be minted. + if n := countRows(t, ctx, "oauth2_codes"); n != 0 { + t.Errorf("oauth2_codes = %d, want 0", n) + } } // --- NEW FLOW: PKCE challenge from the session is carried onto the code --- func TestCallbackHandler_Public_CarriesPKCE(t *testing.T) { - stubGoogle(t, "user@example.com") + t.Parallel() + registerGoogleCode(t, t.Name(), "user@example.com") _, challenge := pkcePair() - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_sessions"). - WillReturnRows(sessionRow("cli", "gstate", "cbstate", "http://127.0.0.1:55001/callback", challenge, "S256", "https://api.deploys.app")) - mock.ExpectExec("delete from oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) - // The code must be persisted with the PKCE challenge, redirect and resource. - mock.ExpectExec("insert into oauth2_codes"). - WithArgs(sqlmock.AnyArg(), "cli", "user@example.com", challenge, "S256", "http://127.0.0.1:55001/callback", "https://api.deploys.app"). - WillReturnResult(sqlmock.NewResult(0, 1)) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:55001/callback") + seedSession(t, ctx, "sess123", "cli", "gstate", "cbstate", "http://127.0.0.1:55001/callback", challenge, "S256", "https://api.deploys.app") - q := url.Values{"state": {"gstate"}, "code": {"google-code"}} - req := httptest.NewRequest(http.MethodGet, "/callback?"+q.Encode(), nil) + q := url.Values{"state": {"gstate"}, "code": {t.Name()}} + req := getReqPath(t, "/callback", q) req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) rec := httptest.NewRecorder() - CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(req, db)) + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req.WithContext(ctx)) if rec.Code != http.StatusFound { t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) } - assertExpectations(t, mock) + u, _ := url.Parse(rec.Header().Get("Location")) + email, gotChallenge, gotMethod := codeEmailPKCE(t, ctx, u.Query().Get("code")) + if email != "user@example.com" { + t.Errorf("code email = %q", email) + } + if gotChallenge != challenge || gotMethod != "S256" { + t.Errorf("code PKCE = (%q,%q), want (%q,S256)", gotChallenge, gotMethod, challenge) + } +} + +func getReqPath(t *testing.T, path string, q url.Values) *http.Request { + t.Helper() + return httptest.NewRequest(http.MethodGet, path+"?"+q.Encode(), nil) } diff --git a/go.mod b/go.mod index b64f078..c1f7861 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,9 @@ require ( github.com/lib/pq v1.12.3 ) -require github.com/DATA-DOG/go-sqlmock v1.5.2 +require ( + github.com/cockroachdb/cockroach-go/v2 v2.4.3 + github.com/gofrs/flock v0.12.1 // indirect + golang.org/x/sys v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index d5e7c24..baba46b 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,29 @@ -github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= -github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/acoshift/pgsql v0.16.0 h1:ak+fwy8Xnx0uZBhSmvFhGGk3a7EQNu8IiRLpnP99IT4= github.com/acoshift/pgsql v0.16.0/go.mod h1:HtdMa77CYeRb9pD6+cT/ZPjpudiUVQ2OIKv6QXjgEZw= +github.com/cockroachdb/cockroach-go/v2 v2.4.3 h1:LJO3K3jC5WXvMePRQSJE1NsIGoFGcEx1LW83W6RAlhw= +github.com/cockroachdb/cockroach-go/v2 v2.4.3/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/introspect_test.go b/introspect_test.go index af0c9e2..c8484d1 100644 --- a/introspect_test.go +++ b/introspect_test.go @@ -7,8 +7,6 @@ import ( "net/url" "strings" "testing" - - "github.com/DATA-DOG/go-sqlmock" ) func introspectReq(t *testing.T, auth, token string) *http.Request { @@ -23,30 +21,32 @@ func introspectReq(t *testing.T, auth, token string) *http.Request { } func TestIntrospectHandler_NotConfigured(t *testing.T) { - db, _ := newMock(t) + t.Parallel() rec := httptest.NewRecorder() - IntrospectHandler{Token: ""}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer x", "tok"), db)) + IntrospectHandler{Token: ""}.ServeHTTP(rec, introspectReq(t, "Bearer x", "tok")) if rec.Code != http.StatusServiceUnavailable { t.Fatalf("status = %d, want 503", rec.Code) } } func TestIntrospectHandler_Unauthorized(t *testing.T) { - db, _ := newMock(t) + t.Parallel() rec := httptest.NewRecorder() - IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer wrong", "tok"), db)) + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer wrong", "tok")) if rec.Code != http.StatusUnauthorized { t.Fatalf("status = %d, want 401", rec.Code) } } func TestIntrospectHandler_ActiveToken(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from user_tokens"). - WillReturnRows(sqlmock.NewRows([]string{"email", "exp"}).AddRow("user@example.com", int64(1893456000))) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + const raw = "deploys-api.abc" + seedToken(t, ctx, hashToken(raw), "user@example.com") rec := httptest.NewRecorder() - IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer secret", "deploys-api.abc"), db)) + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", raw).WithContext(ctx)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) @@ -65,47 +65,69 @@ func TestIntrospectHandler_ActiveToken(t *testing.T) { if resp.Sub != "user@example.com" { t.Errorf("sub = %q", resp.Sub) } - if resp.Exp != 1893456000 { - t.Errorf("exp = %d", resp.Exp) + if resp.Exp == 0 { + t.Error("exp = 0, want a future unix timestamp") } - assertExpectations(t, mock) } func TestIntrospectHandler_UnknownToken(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from user_tokens").WillReturnError(sqlNoRows()) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() // empty user_tokens rec := httptest.NewRecorder() - IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer secret", "deploys-api.gone"), db)) + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", "deploys-api.gone").WithContext(ctx)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want 200", rec.Code) } - var resp struct { - Active bool `json:"active"` + if active := decodeActive(t, rec); active { + t.Error("active = true, want false for unknown token") } - if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { - t.Fatalf("decode: %v", err) +} + +func TestIntrospectHandler_ExpiredToken(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + const raw = "deploys-api.expired" + // Insert a token that already expired. + if _, err := pgctxExec(t, ctx, `insert into user_tokens (token, email, expires_at) values ($1, $2, now() - interval '1 hour')`, hashToken(raw), "old@example.com"); err != nil { + t.Fatal(err) } - if resp.Active { - t.Error("active = true, want false for unknown token") + + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", raw).WithContext(ctx)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if active := decodeActive(t, rec); active { + t.Error("active = true, want false for expired token") } - assertExpectations(t, mock) } func TestIntrospectHandler_EmptyToken(t *testing.T) { - db, _ := newMock(t) // no DB lookup for empty token + t.Parallel() + // Empty token short-circuits before any DB lookup. rec := httptest.NewRecorder() - IntrospectHandler{Token: "secret"}.ServeHTTP(rec, withDB(introspectReq(t, "Bearer secret", ""), db)) + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", "")) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want 200", rec.Code) } + if active := decodeActive(t, rec); active { + t.Error("active = true, want false for empty token") + } +} + +func decodeActive(t *testing.T, rec *httptest.ResponseRecorder) bool { + t.Helper() var resp struct { Active bool `json:"active"` } - json.NewDecoder(rec.Body).Decode(&resp) - if resp.Active { - t.Error("active = true, want false for empty token") + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) } + return resp.Active } diff --git a/pkce_test.go b/pkce_test.go index d3d9d56..55a399b 100644 --- a/pkce_test.go +++ b/pkce_test.go @@ -30,12 +30,12 @@ func TestVerifyPKCE(t *testing.T) { func TestIsLoopbackHost(t *testing.T) { for host, want := range map[string]bool{ - "127.0.0.1": true, - "::1": true, - "localhost": true, + "127.0.0.1": true, + "::1": true, + "localhost": true, "example.com": false, - "10.0.0.1": false, - "": false, + "10.0.0.1": false, + "": false, } { if got := isLoopbackHost(host); got != want { t.Errorf("isLoopbackHost(%q) = %v, want %v", host, got, want) @@ -71,13 +71,13 @@ func TestRedirectURIAllowed(t *testing.T) { func TestValidRegistrationRedirectURI(t *testing.T) { for uri, want := range map[string]bool{ - "https://app.example.com/cb": true, - "http://127.0.0.1:1234/cb": true, - "http://localhost/cb": true, - "http://app.example.com/cb": false, // plain http only allowed for loopback - "ftp://example.com": false, - "not-a-url": false, - "": false, + "https://app.example.com/cb": true, + "http://127.0.0.1:1234/cb": true, + "http://localhost/cb": true, + "http://app.example.com/cb": false, // plain http only allowed for loopback + "ftp://example.com": false, + "not-a-url": false, + "": false, } { if got := validRegistrationRedirectURI(uri); got != want { t.Errorf("validRegistrationRedirectURI(%q) = %v, want %v", uri, got, want) diff --git a/redirect_test.go b/redirect_test.go index 4ae477c..5337d1c 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -6,8 +6,6 @@ import ( "net/url" "strings" "testing" - - "github.com/DATA-DOG/go-sqlmock" ) func getReq(t *testing.T, q url.Values) *http.Request { @@ -27,10 +25,10 @@ func findCookie(rec *httptest.ResponseRecorder, name string) *http.Cookie { // --- OLD FLOW (regression): confidential client redirect to Google --- func TestRedirectHandler_Confidential_Success(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) - mock.ExpectExec("insert into oauth2_sessions").WillReturnResult(sqlmock.NewResult(0, 1)) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") q := url.Values{ "client_id": {"web"}, @@ -39,7 +37,7 @@ func TestRedirectHandler_Confidential_Success(t *testing.T) { } rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "googleclient", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusFound { t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) @@ -57,10 +55,14 @@ func TestRedirectHandler_Confidential_Success(t *testing.T) { if c := findCookie(rec, "s"); c == nil || c.Value == "" { t.Error("session cookie 's' not set") } - assertExpectations(t, mock) + if n := countRows(t, ctx, "oauth2_sessions"); n != 1 { + t.Errorf("oauth2_sessions = %d, want 1", n) + } } func TestRedirectHandler_MissingParams(t *testing.T) { + t.Parallel() + // All rejected before any DB access, so no test DB is needed. cases := map[string]url.Values{ "missing client_id": {"state": {"s"}, "redirect_uri": {"https://a/cb"}}, "missing state": {"client_id": {"web"}, "redirect_uri": {"https://a/cb"}}, @@ -69,10 +71,8 @@ func TestRedirectHandler_MissingParams(t *testing.T) { } for name, q := range cases { t.Run(name, func(t *testing.T) { - db, _ := newMock(t) // no DB calls expected rec := httptest.NewRecorder() - RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}.ServeHTTP(rec, getReq(t, q)) if rec.Code != http.StatusBadRequest { t.Errorf("status = %d, want 400", rec.Code) } @@ -81,27 +81,29 @@ func TestRedirectHandler_MissingParams(t *testing.T) { } func TestRedirectHandler_UnknownClient(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients").WillReturnError(sqlNoRows()) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() // empty DB q := url.Values{"client_id": {"ghost"}, "state": {"s"}, "redirect_uri": {"https://a/cb"}} rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) } } func TestRedirectHandler_Confidential_RedirectNotAllowed(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") q := url.Values{"client_id": {"web"}, "state": {"s"}, "redirect_uri": {"https://evil.example.com/cb"}} rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) } @@ -110,15 +112,11 @@ func TestRedirectHandler_Confidential_RedirectNotAllowed(t *testing.T) { // --- NEW FLOW: public client requires PKCE + exact (loopback) redirect --- func TestRedirectHandler_Public_Success(t *testing.T) { + t.Parallel() _, challenge := pkcePair() - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) - // PKCE challenge + method must be persisted on the session for the callback. - mock.ExpectExec("insert into oauth2_sessions"). - WithArgs(sqlmock.AnyArg(), "cli", sqlmock.AnyArg(), "cbstate", - "http://127.0.0.1:55001/callback", challenge, "S256", ""). - WillReturnResult(sqlmock.NewResult(0, 1)) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") q := url.Values{ "client_id": {"cli"}, @@ -129,7 +127,7 @@ func TestRedirectHandler_Public_Success(t *testing.T) { } rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "googleclient", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusFound { t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) @@ -137,28 +135,37 @@ func TestRedirectHandler_Public_Success(t *testing.T) { if !strings.HasPrefix(rec.Header().Get("Location"), "https://accounts.google.com/") { t.Errorf("Location = %q", rec.Header().Get("Location")) } - assertExpectations(t, mock) + // The PKCE challenge must be persisted on the session for the callback. + challengeStored, methodStored, cbURL := oneSessionPKCE(t, ctx) + if challengeStored != challenge || methodStored != "S256" { + t.Errorf("session PKCE = (%q,%q), want (%q,S256)", challengeStored, methodStored, challenge) + } + if cbURL != "http://127.0.0.1:55001/callback" { + t.Errorf("session callback_url = %q", cbURL) + } } func TestRedirectHandler_Public_MissingChallenge(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") q := url.Values{"client_id": {"cli"}, "state": {"s"}, "redirect_uri": {"http://127.0.0.1:55001/callback"}} rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400 (PKCE required for public clients)", rec.Code) } } func TestRedirectHandler_Public_UnsupportedChallengeMethod(t *testing.T) { + t.Parallel() _, challenge := pkcePair() - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") q := url.Values{ "client_id": {"cli"}, @@ -169,17 +176,18 @@ func TestRedirectHandler_Public_UnsupportedChallengeMethod(t *testing.T) { } rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400 (only S256)", rec.Code) } } func TestRedirectHandler_Public_RedirectNotRegistered(t *testing.T) { + t.Parallel() _, challenge := pkcePair() - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:1234/callback", "none")) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") q := url.Values{ "client_id": {"cli"}, @@ -190,7 +198,7 @@ func TestRedirectHandler_Public_RedirectNotRegistered(t *testing.T) { } rec := httptest.NewRecorder() RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. - ServeHTTP(rec, withDB(getReq(t, q), db)) + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400 (redirect not registered)", rec.Code) } diff --git a/register_test.go b/register_test.go index 76c348d..bebff4f 100644 --- a/register_test.go +++ b/register_test.go @@ -6,8 +6,6 @@ import ( "net/http/httptest" "strings" "testing" - - "github.com/DATA-DOG/go-sqlmock" ) func postJSON(t *testing.T, path, body string) *http.Request { @@ -18,13 +16,13 @@ func postJSON(t *testing.T, path, body string) *http.Request { } func TestRegisterHandler_Success(t *testing.T) { - db, mock := newMock(t) - mock.ExpectExec("insert into oauth2_clients"). - WillReturnResult(sqlmock.NewResult(0, 1)) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() body := `{"client_name":"my cli","redirect_uris":["http://127.0.0.1:1234/cb"]}` rec := httptest.NewRecorder() - RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(postJSON(t, "/register", body), db)) + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, postJSON(t, "/register", body).WithContext(ctx)) if rec.Code != http.StatusCreated { t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) @@ -38,7 +36,7 @@ func TestRegisterHandler_Success(t *testing.T) { t.Fatalf("decode: %v", err) } if resp.ClientID == "" { - t.Error("client_id is empty") + t.Fatal("client_id is empty") } if resp.TokenEndpointAuthMethod != "none" { t.Errorf("token_endpoint_auth_method = %q, want none", resp.TokenEndpointAuthMethod) @@ -46,24 +44,32 @@ func TestRegisterHandler_Success(t *testing.T) { if len(resp.RedirectURIs) != 1 || resp.RedirectURIs[0] != "http://127.0.0.1:1234/cb" { t.Errorf("redirect_uris = %v", resp.RedirectURIs) } - assertExpectations(t, mock) + // The client is persisted as a public client with the registered redirect. + if uris, ok := clientRedirectURIs(t, ctx, resp.ClientID); !ok || uris != "http://127.0.0.1:1234/cb" { + t.Errorf("persisted redirect_uris = %q (ok=%v)", uris, ok) + } } func TestRegisterHandler_HTTPSRedirectAllowed(t *testing.T) { - db, mock := newMock(t) - mock.ExpectExec("insert into oauth2_clients").WillReturnResult(sqlmock.NewResult(0, 1)) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() body := `{"redirect_uris":["https://app.example.com/cb"]}` rec := httptest.NewRecorder() - RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(postJSON(t, "/register", body), db)) + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, postJSON(t, "/register", body).WithContext(ctx)) if rec.Code != http.StatusCreated { t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) } - assertExpectations(t, mock) + if n := countRows(t, ctx, "oauth2_clients"); n != 1 { + t.Errorf("oauth2_clients = %d, want 1", n) + } } func TestRegisterHandler_Rejections(t *testing.T) { + t.Parallel() + // All rejected before any DB write, so no test DB is needed. cases := map[string]string{ "invalid json": `{not json`, "no redirect_uris": `{"client_name":"x"}`, @@ -72,9 +78,8 @@ func TestRegisterHandler_Rejections(t *testing.T) { } for name, body := range cases { t.Run(name, func(t *testing.T) { - db, _ := newMock(t) // no insert expected rec := httptest.NewRecorder() - RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, withDB(postJSON(t, "/register", body), db)) + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, postJSON(t, "/register", body)) if rec.Code != http.StatusBadRequest { t.Errorf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) } diff --git a/schema/01_init.sql b/schema/01_init.sql new file mode 100644 index 0000000..083b135 --- /dev/null +++ b/schema/01_init.sql @@ -0,0 +1,36 @@ +create table oauth2_clients ( + id string, + secret string not null, + redirect_uri string not null, + created_at timestamptz not null default now(), + primary key (id) +); + +create table oauth2_codes ( + id string, + client_id string not null, + email string not null, + created_at timestamptz not null default now(), + primary key (id), + foreign key (client_id) references oauth2_clients (id) on delete cascade +); +create index oauth2_codes_created_at_idx on oauth2_codes (created_at); + +create table oauth2_sessions ( + id string, + client_id string not null, + state string not null, + callback_state string not null, + callback_url string not null, + created_at timestamptz not null default now(), + primary key (id) +); +create index oauth2_sessions_created_at_idx on oauth2_sessions (created_at); + +create table user_tokens ( + token string, + email string not null, + expires_at timestamptz not null, + created_at timestamptz not null default now(), + primary key (token) +); diff --git a/schema/02_mcp.sql b/schema/02_mcp.sql new file mode 100644 index 0000000..b2fe212 --- /dev/null +++ b/schema/02_mcp.sql @@ -0,0 +1,16 @@ +-- MCP / OAuth 2.1 support (public clients, PKCE, DCR, resource indicators). + +alter table oauth2_clients alter column secret drop not null; +alter table oauth2_clients alter column redirect_uri set default ''; +alter table oauth2_clients add column if not exists redirect_uris string not null default ''; +alter table oauth2_clients add column if not exists token_endpoint_auth_method string not null default 'client_secret_post'; +alter table oauth2_clients add column if not exists client_name string not null default ''; + +alter table oauth2_codes add column if not exists code_challenge string not null default ''; +alter table oauth2_codes add column if not exists code_challenge_method string not null default ''; +alter table oauth2_codes add column if not exists redirect_uri string not null default ''; +alter table oauth2_codes add column if not exists resource string not null default ''; + +alter table oauth2_sessions add column if not exists code_challenge string not null default ''; +alter table oauth2_sessions add column if not exists code_challenge_method string not null default ''; +alter table oauth2_sessions add column if not exists resource string not null default ''; diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 0000000..2d6257c --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,76 @@ +package schema + +import ( + "context" + "database/sql" + "embed" +) + +//go:embed *.sql +var fs embed.FS + +func Migrate(ctx context.Context, db *sql.DB) error { + if err := setupMigrationTable(ctx, db); err != nil { + return err + } + + list, err := fs.ReadDir(".") + if err != nil { + return err + } + + for _, x := range list { + migrateID := x.Name() + migrated, err := isMigrated(ctx, db, migrateID) + if err != nil { + return err + } + if migrated { + continue + } + + b, err := fs.ReadFile(migrateID) + if err != nil { + return err + } + if _, err = db.ExecContext(ctx, string(b)); err != nil { + return err + } + if err = stampMigrated(ctx, db, migrateID, string(b)); err != nil { + return err + } + } + + return nil +} + +func setupMigrationTable(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` + create table if not exists migrations ( + id varchar not null, + content varchar not null, + ts timestamptz not null default now(), + primary key (id) + ) + `) + return err +} + +func isMigrated(ctx context.Context, db *sql.DB, migrateID string) (bool, error) { + var migrated bool + err := db.QueryRowContext(ctx, ` + select exists ( + select 1 from migrations where id = $1 + ) + `, migrateID).Scan(&migrated) + return migrated, err +} + +func stampMigrated(ctx context.Context, db *sql.DB, migrateID, content string) error { + _, err := db.ExecContext(ctx, ` + insert into migrations (id, content) + values ($1, $2) + on conflict (id) do nothing + `, migrateID, content) + return err +} diff --git a/setup_test.go b/setup_test.go new file mode 100644 index 0000000..1c858e3 --- /dev/null +++ b/setup_test.go @@ -0,0 +1,17 @@ +package main + +import ( + "testing" + + "github.com/deploys-app/auth/tu" +) + +// newTestDB starts a fresh CockroachDB instance for the test and tears it down +// on completion. Each test gets its own isolated database so they can run in +// parallel without sharing rows or worrying about cleanup. +func newTestDB(t *testing.T) *tu.Context { + t.Helper() + c := tu.Setup() + t.Cleanup(c.Teardown) + return c +} diff --git a/testutil_test.go b/testutil_test.go index 8505463..eee42fc 100644 --- a/testutil_test.go +++ b/testutil_test.go @@ -1,34 +1,15 @@ package main import ( + "context" "crypto/sha256" "database/sql" "encoding/base64" - "net/http" "testing" - "github.com/DATA-DOG/go-sqlmock" "github.com/acoshift/pgsql/pgctx" ) -// newMock returns a mock *sql.DB plus its controller. The DB is closed via -// t.Cleanup. pgctx talks to it through the standard database/sql interface, so -// no transaction (begin/commit) expectations are needed for these handlers. -func newMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { - t.Helper() - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("sqlmock.New: %v", err) - } - t.Cleanup(func() { db.Close() }) - return db, mock -} - -// withDB binds the mock DB to the request context the way pgctx.Middleware would. -func withDB(req *http.Request, db *sql.DB) *http.Request { - return req.WithContext(pgctx.NewContext(req.Context(), db)) -} - // pkcePair returns a PKCE verifier and its S256 challenge. func pkcePair() (verifier, challenge string) { verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" @@ -44,13 +25,126 @@ func fakeIDToken(email string) string { return "header." + payload + ".signature" } -func assertExpectations(t *testing.T, mock sqlmock.Sqlmock) { +// --- seeding helpers (operate on a context already carrying the DB) --- + +func seedConfidentialClient(t *testing.T, ctx context.Context, id, secret, redirectGlob string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_clients (id, secret, redirect_uri, token_endpoint_auth_method) + values ($1, $2, $3, 'client_secret_post') + `, id, secret, redirectGlob) + if err != nil { + t.Fatalf("seed confidential client: %v", err) + } +} + +func seedPublicClient(t *testing.T, ctx context.Context, id string, redirectURIs string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_clients (id, redirect_uris, token_endpoint_auth_method) + values ($1, $2, 'none') + `, id, redirectURIs) + if err != nil { + t.Fatalf("seed public client: %v", err) + } +} + +func seedCode(t *testing.T, ctx context.Context, id, clientID, email, challenge, method, redirect, resource string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_codes (id, client_id, email, code_challenge, code_challenge_method, redirect_uri, resource) + values ($1, $2, $3, $4, $5, $6, $7) + `, id, clientID, email, challenge, method, redirect, resource) + if err != nil { + t.Fatalf("seed code: %v", err) + } +} + +func seedSession(t *testing.T, ctx context.Context, id, clientID, state, cbState, cbURL, challenge, method, resource string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_sessions (id, client_id, state, callback_state, callback_url, code_challenge, code_challenge_method, resource) + values ($1, $2, $3, $4, $5, $6, $7, $8) + `, id, clientID, state, cbState, cbURL, challenge, method, resource) + if err != nil { + t.Fatalf("seed session: %v", err) + } +} + +// pgctxExec runs an arbitrary statement against the test DB (for one-off seeds +// that the typed helpers don't cover, e.g. an already-expired token). +func pgctxExec(t *testing.T, ctx context.Context, query string, args ...any) (sql.Result, error) { t.Helper() - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unmet sqlmock expectations: %v", err) + return pgctx.Exec(ctx, query, args...) +} + +func seedToken(t *testing.T, ctx context.Context, hashedToken, email string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into user_tokens (token, email, expires_at) + values ($1, $2, now() + interval '7 days') + `, hashedToken, email) + if err != nil { + t.Fatalf("seed token: %v", err) + } +} + +// --- assertion helpers --- + +func tokenEmail(t *testing.T, ctx context.Context, hashedToken string) (string, bool) { + t.Helper() + var email string + err := pgctx.QueryRow(ctx, `select email from user_tokens where token = $1`, hashedToken).Scan(&email) + if err != nil { + return "", false + } + return email, true +} + +func countRows(t *testing.T, ctx context.Context, table string) int { + t.Helper() + var n int + if err := pgctx.QueryRow(ctx, `select count(*) from `+table).Scan(&n); err != nil { + t.Fatalf("count %s: %v", table, err) } + return n } -// sqlNoRows is the error a QueryRow returns when nothing matched; the DB layer -// translates it into the ErrOAuth2*NotFound sentinels. -func sqlNoRows() error { return sql.ErrNoRows } +func clientRedirectURIs(t *testing.T, ctx context.Context, id string) (string, bool) { + t.Helper() + var uris string + err := pgctx.QueryRow(ctx, `select redirect_uris from oauth2_clients where id = $1`, id).Scan(&uris) + if err != nil { + return "", false + } + return uris, true +} + +// codeEmailPKCE reads back a minted oauth2_codes row (without consuming it). +func codeEmailPKCE(t *testing.T, ctx context.Context, id string) (email, challenge, method string) { + t.Helper() + err := pgctx.QueryRow(ctx, ` + select email, code_challenge, code_challenge_method + from oauth2_codes + where id = $1 + `, id).Scan(&email, &challenge, &method) + if err != nil { + t.Fatalf("read code %q: %v", id, err) + } + return email, challenge, method +} + +// oneSessionPKCE returns the PKCE challenge/method and callback URL of the only +// session row (the redirect handler is expected to have created exactly one). +func oneSessionPKCE(t *testing.T, ctx context.Context) (challenge, method, callbackURL string) { + t.Helper() + err := pgctx.QueryRow(ctx, ` + select code_challenge, code_challenge_method, callback_url + from oauth2_sessions + limit 1 + `).Scan(&challenge, &method, &callbackURL) + if err != nil { + t.Fatalf("read session: %v", err) + } + return challenge, method, callbackURL +} diff --git a/token_test.go b/token_test.go index 59a33dd..4107946 100644 --- a/token_test.go +++ b/token_test.go @@ -7,8 +7,6 @@ import ( "net/url" "strings" "testing" - - "github.com/DATA-DOG/go-sqlmock" ) func postForm(t *testing.T, path string, form url.Values) *http.Request { @@ -18,35 +16,18 @@ func postForm(t *testing.T, path string, form url.Values) *http.Request { return req } -func clientRow(id, secret, redirectURI, redirectURIs, authMethod string) *sqlmock.Rows { - cols := []string{"id", "secret", "redirect_uri", "redirect_uris", "token_endpoint_auth_method"} - var sec any = secret - if secret == "" && authMethod == "none" { - sec = nil // public clients store NULL secret - } - return sqlmock.NewRows(cols).AddRow(id, sec, redirectURI, redirectURIs, authMethod) -} - -func codeRow(email, challenge, method, redirectURI, resource string) *sqlmock.Rows { - cols := []string{"email", "code_challenge", "code_challenge_method", "redirect_uri", "resource"} - return sqlmock.NewRows(cols).AddRow(email, challenge, method, redirectURI, resource) -} - // --- OLD FLOW (regression): confidential client + client_secret --- func TestTokenHandler_Confidential_Success(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) - mock.ExpectQuery("delete from oauth2_codes"). - WillReturnRows(codeRow("user@example.com", "", "", "https://app.example.com/cb", "")) - mock.ExpectExec("insert into user_tokens"). - WithArgs(sqlmock.AnyArg(), "user@example.com"). - WillReturnResult(sqlmock.NewResult(0, 1)) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + seedCode(t, ctx, "abc", "web", "user@example.com", "", "", "https://app.example.com/cb", "") form := url.Values{"client_id": {"web"}, "client_secret": {"topsecret"}, "code": {"abc"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) @@ -61,13 +42,9 @@ func TestTokenHandler_Confidential_Success(t *testing.T) { t.Fatalf("decode: %v", err) } // Backward compatibility: the legacy client reads refresh_token. - if resp.RefreshToken == "" { - t.Error("refresh_token must remain populated for the legacy client") - } if !strings.HasPrefix(resp.RefreshToken, tokenPrefix) { t.Errorf("refresh_token %q missing prefix %q", resp.RefreshToken, tokenPrefix) } - // New fields for OAuth2.1 clients. if resp.AccessToken != resp.RefreshToken { t.Errorf("access_token (%q) should equal refresh_token (%q)", resp.AccessToken, resp.RefreshToken) } @@ -77,33 +54,46 @@ func TestTokenHandler_Confidential_Success(t *testing.T) { if resp.TokenType != "Bearer" { t.Errorf("token_type = %q, want Bearer", resp.TokenType) } - assertExpectations(t, mock) + // The token is persisted (hashed) against the right email. + if email, ok := tokenEmail(t, ctx, hashToken(resp.RefreshToken)); !ok || email != "user@example.com" { + t.Errorf("persisted token email = %q (ok=%v), want user@example.com", email, ok) + } + // The code is single-use. + if n := countRows(t, ctx, "oauth2_codes"); n != 0 { + t.Errorf("oauth2_codes = %d, want 0 (code consumed)", n) + } } func TestTokenHandler_Confidential_WrongSecret(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + seedCode(t, ctx, "abc", "web", "user@example.com", "", "", "https://app.example.com/cb", "") form := url.Values{"client_id": {"web"}, "client_secret": {"wrong"}, "code": {"abc"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusUnauthorized { t.Fatalf("status = %d, want 401", rec.Code) } assertOAuthError(t, rec, "invalid_client") - assertExpectations(t, mock) + // The code must not be consumed on a failed exchange. + if n := countRows(t, ctx, "oauth2_codes"); n != 1 { + t.Errorf("oauth2_codes = %d, want 1 (code preserved)", n) + } } func TestTokenHandler_Confidential_MissingSecret(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") form := url.Values{"client_id": {"web"}, "code": {"abc"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) @@ -114,17 +104,14 @@ func TestTokenHandler_Confidential_MissingSecret(t *testing.T) { // --- NEW FLOW: public client + PKCE --- func TestTokenHandler_Public_PKCE_Success(t *testing.T) { + t.Parallel() verifier, challenge := pkcePair() const redirect = "http://127.0.0.1:5000/callback" - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", redirect, "none")) - mock.ExpectQuery("delete from oauth2_codes"). - WillReturnRows(codeRow("user@example.com", challenge, "S256", redirect, "")) - mock.ExpectExec("insert into user_tokens"). - WithArgs(sqlmock.AnyArg(), "user@example.com"). - WillReturnResult(sqlmock.NewResult(0, 1)) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", redirect) + seedCode(t, ctx, "abc", "cli", "user@example.com", challenge, "S256", redirect, "") form := url.Values{ "grant_type": {"authorization_code"}, @@ -134,7 +121,7 @@ func TestTokenHandler_Public_PKCE_Success(t *testing.T) { "redirect_uri": {redirect}, } rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) @@ -156,18 +143,20 @@ func TestTokenHandler_Public_PKCE_Success(t *testing.T) { if resp.ExpiresIn != tokenTTLSeconds { t.Errorf("expires_in = %d, want %d", resp.ExpiresIn, tokenTTLSeconds) } - assertExpectations(t, mock) + if email, ok := tokenEmail(t, ctx, hashToken(resp.AccessToken)); !ok || email != "user@example.com" { + t.Errorf("persisted token email = %q (ok=%v), want user@example.com", email, ok) + } } func TestTokenHandler_Public_PKCE_BadVerifier(t *testing.T) { + t.Parallel() _, challenge := pkcePair() const redirect = "http://127.0.0.1:5000/callback" - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", redirect, "none")) - mock.ExpectQuery("delete from oauth2_codes"). - WillReturnRows(codeRow("user@example.com", challenge, "S256", redirect, "")) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", redirect) + seedCode(t, ctx, "abc", "cli", "user@example.com", challenge, "S256", redirect, "") form := url.Values{ "grant_type": {"authorization_code"}, @@ -177,23 +166,26 @@ func TestTokenHandler_Public_PKCE_BadVerifier(t *testing.T) { "redirect_uri": {redirect}, } rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) } assertOAuthError(t, rec, "invalid_grant") - assertExpectations(t, mock) + if n := countRows(t, ctx, "user_tokens"); n != 0 { + t.Errorf("user_tokens = %d, want 0 (no token issued on PKCE failure)", n) + } } func TestTokenHandler_Public_RedirectMismatch(t *testing.T) { + t.Parallel() verifier, challenge := pkcePair() + const redirect = "http://127.0.0.1:5000/callback" - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:5000/callback", "none")) - mock.ExpectQuery("delete from oauth2_codes"). - WillReturnRows(codeRow("user@example.com", challenge, "S256", "http://127.0.0.1:5000/callback", "")) + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", redirect) + seedCode(t, ctx, "abc", "cli", "user@example.com", challenge, "S256", redirect, "") form := url.Values{ "grant_type": {"authorization_code"}, @@ -203,38 +195,38 @@ func TestTokenHandler_Public_RedirectMismatch(t *testing.T) { "redirect_uri": {"http://127.0.0.1:5000/different"}, } rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) } assertOAuthError(t, rec, "invalid_grant") - assertExpectations(t, mock) } func TestTokenHandler_Public_MissingVerifier(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("cli", "", "", "http://127.0.0.1:5000/callback", "none")) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:5000/callback") form := url.Values{"grant_type": {"authorization_code"}, "client_id": {"cli"}, "code": {"abc"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) } assertOAuthError(t, rec, "invalid_request") - assertExpectations(t, mock) } // --- shared request validation --- func TestTokenHandler_UnsupportedGrantType(t *testing.T) { - db, _ := newMock(t) + t.Parallel() form := url.Values{"grant_type": {"password"}, "client_id": {"x"}, "code": {"y"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + // Rejected before any DB access. + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) @@ -243,12 +235,13 @@ func TestTokenHandler_UnsupportedGrantType(t *testing.T) { } func TestTokenHandler_UnknownClient(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients").WillReturnError(sqlNoRows()) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() // empty DB: client lookup misses form := url.Values{"client_id": {"ghost"}, "client_secret": {"x"}, "code": {"y"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusUnauthorized { t.Fatalf("status = %d, want 401", rec.Code) @@ -257,20 +250,19 @@ func TestTokenHandler_UnknownClient(t *testing.T) { } func TestTokenHandler_InvalidCode(t *testing.T) { - db, mock := newMock(t) - mock.ExpectQuery("from oauth2_clients"). - WillReturnRows(clientRow("web", "topsecret", "https://app.example.com/*", "", "client_secret_post")) - mock.ExpectQuery("delete from oauth2_codes").WillReturnError(sqlNoRows()) + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") form := url.Values{"client_id": {"web"}, "client_secret": {"topsecret"}, "code": {"gone"}} rec := httptest.NewRecorder() - TokenHandler{}.ServeHTTP(rec, withDB(postForm(t, "/token", form), db)) + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want 400", rec.Code) } assertOAuthError(t, rec, "invalid_grant") - assertExpectations(t, mock) } func assertOAuthError(t *testing.T, rec *httptest.ResponseRecorder, wantCode string) { diff --git a/tu/tu.go b/tu/tu.go new file mode 100644 index 0000000..fa7f9c8 --- /dev/null +++ b/tu/tu.go @@ -0,0 +1,65 @@ +// Package tu is the test utility for spinning up an isolated CockroachDB. +package tu + +import ( + "context" + "database/sql" + + "github.com/acoshift/pgsql/pgctx" + "github.com/cockroachdb/cockroach-go/v2/testserver" + + "github.com/deploys-app/auth/schema" +) + +// Context holds the test server and DB connection. +type Context struct { + ts testserver.TestServer + DB *sql.DB +} + +func (c *Context) setup() { + var err error + defer func() { + if err != nil { + panic(err) + } + }() + + c.ts, err = testserver.NewTestServer() + if err != nil { + return + } + defer func() { + if err != nil { + c.Teardown() + } + }() + + c.DB, err = sql.Open("postgres", c.ts.PGURL().String()+"&enable_implicit_transaction_for_batch_statements=off") + if err != nil { + return + } + + err = schema.Migrate(context.Background(), c.DB) +} + +func (c *Context) Teardown() { + if c.DB != nil { + c.DB.Close() + } + if c.ts != nil { + c.ts.Stop() + } +} + +// Ctx returns a context with the DB injected, the way pgctx.Middleware would. +func (c *Context) Ctx() context.Context { + return pgctx.NewContext(context.Background(), c.DB) +} + +// Setup starts a CockroachDB test server and runs the schema migration. +func Setup() *Context { + c := &Context{} + c.setup() + return c +} From aa59997f5abe58870e5fd6363f8e14c19da359d9 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Wed, 27 May 2026 08:43:12 +0700 Subject: [PATCH 5/5] refactor: store redirect_uris as a string[] column Switch oauth2_clients.redirect_uris from a newline-delimited string to a native CockroachDB string[] array. Reads scan straight into the []string (pgsql wraps slice destinations with pq.Array); writes pass pq.Array. Drops the strings join/split in oauth2.go. Co-Authored-By: Claude Opus 4.7 (1M context) --- oauth2.go | 10 +++------- register_test.go | 4 ++-- schema.sql | 2 +- schema/02_mcp.sql | 2 +- testutil_test.go | 11 ++++++----- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/oauth2.go b/oauth2.go index 76ec028..5041590 100644 --- a/oauth2.go +++ b/oauth2.go @@ -4,9 +4,9 @@ import ( "context" "database/sql" "errors" - "strings" "github.com/acoshift/pgsql/pgctx" + "github.com/lib/pq" ) var ( @@ -50,13 +50,12 @@ type OAuth2Code struct { func getOAuth2Client(ctx context.Context, clientID string) (*OAuth2Client, error) { var x OAuth2Client var secret sql.NullString - var redirectURIs string err := pgctx.QueryRow(ctx, ` select id, secret, redirect_uri, redirect_uris, token_endpoint_auth_method from oauth2_clients where id = $1 `, clientID).Scan( - &x.ID, &secret, &x.RedirectURI, &redirectURIs, &x.TokenEndpointAuthMethod, + &x.ID, &secret, &x.RedirectURI, &x.RedirectURIs, &x.TokenEndpointAuthMethod, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrOAuth2ClientNotFound @@ -65,9 +64,6 @@ func getOAuth2Client(ctx context.Context, clientID string) (*OAuth2Client, error return nil, err } x.Secret = secret.String - if redirectURIs != "" { - x.RedirectURIs = strings.Split(redirectURIs, "\n") - } return &x, nil } @@ -76,7 +72,7 @@ func insertOAuth2Client(ctx context.Context, c *OAuth2Client) error { _, err := pgctx.Exec(ctx, ` insert into oauth2_clients (id, secret, redirect_uri, redirect_uris, token_endpoint_auth_method, client_name) values ($1, null, '', $2, $3, $4) - `, c.ID, strings.Join(c.RedirectURIs, "\n"), c.TokenEndpointAuthMethod, c.ClientName) + `, c.ID, pq.Array(c.RedirectURIs), c.TokenEndpointAuthMethod, c.ClientName) return err } diff --git a/register_test.go b/register_test.go index bebff4f..d932259 100644 --- a/register_test.go +++ b/register_test.go @@ -45,8 +45,8 @@ func TestRegisterHandler_Success(t *testing.T) { t.Errorf("redirect_uris = %v", resp.RedirectURIs) } // The client is persisted as a public client with the registered redirect. - if uris, ok := clientRedirectURIs(t, ctx, resp.ClientID); !ok || uris != "http://127.0.0.1:1234/cb" { - t.Errorf("persisted redirect_uris = %q (ok=%v)", uris, ok) + if uris, ok := clientRedirectURIs(t, ctx, resp.ClientID); !ok || len(uris) != 1 || uris[0] != "http://127.0.0.1:1234/cb" { + t.Errorf("persisted redirect_uris = %v (ok=%v)", uris, ok) } } diff --git a/schema.sql b/schema.sql index a561bc2..d754914 100644 --- a/schema.sql +++ b/schema.sql @@ -2,7 +2,7 @@ create table oauth2_clients ( id string, secret string, redirect_uri string not null default '', - redirect_uris string not null default '', + redirect_uris string[] not null default array[]::string[], token_endpoint_auth_method string not null default 'client_secret_post', client_name string not null default '', created_at timestamptz not null default now(), diff --git a/schema/02_mcp.sql b/schema/02_mcp.sql index b2fe212..7e62402 100644 --- a/schema/02_mcp.sql +++ b/schema/02_mcp.sql @@ -2,7 +2,7 @@ alter table oauth2_clients alter column secret drop not null; alter table oauth2_clients alter column redirect_uri set default ''; -alter table oauth2_clients add column if not exists redirect_uris string not null default ''; +alter table oauth2_clients add column if not exists redirect_uris string[] not null default array[]::string[]; alter table oauth2_clients add column if not exists token_endpoint_auth_method string not null default 'client_secret_post'; alter table oauth2_clients add column if not exists client_name string not null default ''; diff --git a/testutil_test.go b/testutil_test.go index eee42fc..9c080e5 100644 --- a/testutil_test.go +++ b/testutil_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/acoshift/pgsql/pgctx" + "github.com/lib/pq" ) // pkcePair returns a PKCE verifier and its S256 challenge. @@ -38,12 +39,12 @@ func seedConfidentialClient(t *testing.T, ctx context.Context, id, secret, redir } } -func seedPublicClient(t *testing.T, ctx context.Context, id string, redirectURIs string) { +func seedPublicClient(t *testing.T, ctx context.Context, id string, redirectURIs ...string) { t.Helper() _, err := pgctx.Exec(ctx, ` insert into oauth2_clients (id, redirect_uris, token_endpoint_auth_method) values ($1, $2, 'none') - `, id, redirectURIs) + `, id, pq.Array(redirectURIs)) if err != nil { t.Fatalf("seed public client: %v", err) } @@ -110,12 +111,12 @@ func countRows(t *testing.T, ctx context.Context, table string) int { return n } -func clientRedirectURIs(t *testing.T, ctx context.Context, id string) (string, bool) { +func clientRedirectURIs(t *testing.T, ctx context.Context, id string) ([]string, bool) { t.Helper() - var uris string + var uris []string err := pgctx.QueryRow(ctx, `select redirect_uris from oauth2_clients where id = $1`, id).Scan(&uris) if err != nil { - return "", false + return nil, false } return uris, true }