diff --git a/CLAUDE.md b/CLAUDE.md index 145769e..0f41c4c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -7,9 +7,15 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ```bash go build -o auth . # build binary go vet ./... # lint +go test ./... # run tests ``` -No test files exist in this codebase. +Handler tests run against a real CockroachDB: `tu.Setup()` starts an isolated +in-memory `cockroach-go/v2/testserver` per test and `schema.Migrate` applies the +embedded `schema/*.sql` migrations. `newTestDB(t)` (setup_test.go) returns a +`*tu.Context`; use `db.Ctx()` for both seeding (via `pgctx`) and the request +context. Google's token endpoint is stubbed through the `googleTokenURL` package +var. Tests download the CockroachDB binary on first run. ### Required environment variables @@ -19,6 +25,8 @@ No test files exist in this codebase. | `OAUTH2_CLIENT_ID` | Google OAuth app client ID | | `OAUTH2_CLIENT_SECRET` | Google OAuth app client secret | | `PORT` | Listen port (default: `8080`) | +| `BASE_URL` | Public base URL of this service (default: `https://auth.deploys.app`) | +| `INTROSPECTION_TOKEN` | Shared secret guarding `POST /introspect`; unset disables the endpoint | ## Architecture diff --git a/README.md b/README.md index d3bce71..a9a1c06 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ Required environment variables: | `OAUTH2_CLIENT_ID` | Google OAuth app client ID | | `OAUTH2_CLIENT_SECRET` | Google OAuth app client secret | | `PORT` | Listen port (default: `8080`) | +| `BASE_URL` | Public base URL of this service (default: `https://auth.deploys.app`) | +| `INTROSPECTION_TOKEN` | Shared secret for the `/introspect` endpoint; if unset, introspection is disabled | ```shell $ ./auth @@ -31,11 +33,23 @@ the service starts. | Method | Path | Purpose | |---|---|---| -| `GET` | `/` | Validate the OAuth2 client and redirect to Google | +| `GET` | `/.well-known/oauth-authorization-server` | OAuth 2.0 Authorization Server Metadata (RFC 8414) | +| `GET` | `/` | Validate the OAuth2 client and redirect to Google (authorize endpoint; supports PKCE) | | `GET` | `/callback` | Receive Google's code and issue an internal auth code | -| `POST` | `/token` | Exchange client credentials + code for a user token | +| `POST` | `/token` | Exchange code (+ secret or PKCE verifier) for a user token | +| `POST` | `/register` | Dynamic Client Registration for public clients (RFC 7591) | +| `POST` | `/introspect` | Token introspection for resource servers (RFC 7662) | | `POST` | `/revoke` | Revoke a user token | +### MCP / public clients + +The service is an OAuth 2.1 authorization server. CLI / MCP clients register +dynamically at `/register` (public clients, no secret), use **PKCE (S256)** at +the authorize and token endpoints, and may use loopback redirect URIs +(`http://127.0.0.1:`). Confidential web clients keep using +`client_secret` as before. A resource server validates issued bearer tokens via +`/introspect` (authenticated with `INTROSPECTION_TOKEN`). + ## Deployment The provided [Dockerfile](./Dockerfile) builds a `gcr.io/distroless/static` diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..0cbe892 --- /dev/null +++ b/callback_test.go @@ -0,0 +1,150 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" +) + +// A single mock Google token endpoint is started lazily and shared by all +// callback tests. It keys the returned id_token email by the authorization +// code in the request, so parallel tests do not race over googleTokenURL. +var ( + googleMockOnce sync.Once + googleMockMap sync.Map // code -> email +) + +func registerGoogleCode(t *testing.T, code, email string) { + t.Helper() + googleMockOnce.Do(func() { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + email := "default@example.com" + if v, ok := googleMockMap.Load(r.FormValue("code")); ok { + email = v.(string) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"id_token":%q}`, fakeIDToken(email)) + })) + googleTokenURL = srv.URL + }) + googleMockMap.Store(code, email) + t.Cleanup(func() { googleMockMap.Delete(code) }) +} + +// --- OLD FLOW (regression): Google callback issues an internal code --- + +func TestCallbackHandler_Confidential_Success(t *testing.T) { + t.Parallel() + registerGoogleCode(t, t.Name(), "user@example.com") + + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + seedSession(t, ctx, "sess123", "web", "gstate", "cbstate", "https://app.example.com/cb", "", "", "") + + q := url.Values{"state": {"gstate"}, "code": {t.Name()}} + req := getReqPath(t, "/callback", q) + req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) + rec := httptest.NewRecorder() + CallbackHandler{OAuth2ClientID: "g", OAuth2ClientSecret: "gs", BaseURL: "https://auth.test"}. + ServeHTTP(rec, req.WithContext(ctx)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "https://app.example.com/cb?") { + t.Fatalf("Location = %q, want redirect to client callback", loc) + } + u, _ := url.Parse(loc) + if u.Query().Get("state") != "cbstate" { + t.Errorf("callback state = %q, want cbstate", u.Query().Get("state")) + } + returnedCode := u.Query().Get("code") + if returnedCode == "" { + t.Fatal("callback code is empty") + } + // The session is single-use and an internal code was minted for the email. + if n := countRows(t, ctx, "oauth2_sessions"); n != 0 { + t.Errorf("oauth2_sessions = %d, want 0 (consumed)", n) + } + email, challenge, method := codeEmailPKCE(t, ctx, returnedCode) + if email != "user@example.com" { + t.Errorf("code email = %q, want user@example.com", email) + } + if challenge != "" || method != "" { + t.Errorf("confidential code carried PKCE: (%q,%q)", challenge, method) + } +} + +func TestCallbackHandler_MissingSessionCookie(t *testing.T) { + t.Parallel() + q := url.Values{"state": {"gstate"}, "code": {"google-code"}} + req := getReqPath(t, "/callback", q) // no cookie + rec := httptest.NewRecorder() + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } +} + +func TestCallbackHandler_StateMismatch(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedSession(t, ctx, "sess123", "web", "expected-state", "cbstate", "https://app.example.com/cb", "", "", "") + + q := url.Values{"state": {"attacker-state"}, "code": {"google-code"}} + req := getReqPath(t, "/callback", q) + req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) + rec := httptest.NewRecorder() + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req.WithContext(ctx)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (state mismatch)", rec.Code) + } + // Session is consumed on read even on mismatch; no code should be minted. + if n := countRows(t, ctx, "oauth2_codes"); n != 0 { + t.Errorf("oauth2_codes = %d, want 0", n) + } +} + +// --- NEW FLOW: PKCE challenge from the session is carried onto the code --- + +func TestCallbackHandler_Public_CarriesPKCE(t *testing.T) { + t.Parallel() + registerGoogleCode(t, t.Name(), "user@example.com") + _, challenge := pkcePair() + + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:55001/callback") + seedSession(t, ctx, "sess123", "cli", "gstate", "cbstate", "http://127.0.0.1:55001/callback", challenge, "S256", "https://api.deploys.app") + + q := url.Values{"state": {"gstate"}, "code": {t.Name()}} + req := getReqPath(t, "/callback", q) + req.AddCookie(&http.Cookie{Name: "s", Value: "sess123"}) + rec := httptest.NewRecorder() + CallbackHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req.WithContext(ctx)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + u, _ := url.Parse(rec.Header().Get("Location")) + email, gotChallenge, gotMethod := codeEmailPKCE(t, ctx, u.Query().Get("code")) + if email != "user@example.com" { + t.Errorf("code email = %q", email) + } + if gotChallenge != challenge || gotMethod != "S256" { + t.Errorf("code PKCE = (%q,%q), want (%q,S256)", gotChallenge, gotMethod, challenge) + } +} + +func getReqPath(t *testing.T, path string, q url.Values) *http.Request { + t.Helper() + return httptest.NewRequest(http.MethodGet, path+"?"+q.Encode(), nil) +} diff --git a/cleanup.go b/cleanup.go new file mode 100644 index 0000000..8c6c927 --- /dev/null +++ b/cleanup.go @@ -0,0 +1,33 @@ +package main + +import ( + "context" + "database/sql" + "log/slog" + "time" +) + +// startCleanupWorker periodically removes expired oauth2 sessions and codes. +// Both have a 1-hour TTL but are only deleted on use, so abandoned rows would +// otherwise accumulate. Registered clients are never reaped — MCP clients reuse +// their client_id across logins. +func startCleanupWorker(db *sql.DB) { + go func() { + for { + cleanupExpired(db) + time.Sleep(15 * time.Minute) + } + }() +} + +func cleanupExpired(db *sql.DB) { + ctx := context.Background() + for _, q := range []string{ + `delete from oauth2_sessions where created_at < now() - interval '1 hour'`, + `delete from oauth2_codes where created_at < now() - interval '1 hour'`, + } { + if _, err := db.ExecContext(ctx, q); err != nil { + slog.ErrorContext(ctx, "cleanup: delete expired rows", "error", err) + } + } +} diff --git a/go.mod b/go.mod index 8d9f4aa..c1f7861 100644 --- a/go.mod +++ b/go.mod @@ -6,3 +6,10 @@ require ( github.com/acoshift/pgsql v0.16.0 github.com/lib/pq v1.12.3 ) + +require ( + github.com/cockroachdb/cockroach-go/v2 v2.4.3 + github.com/gofrs/flock v0.12.1 // indirect + golang.org/x/sys v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 66994d2..baba46b 100644 --- a/go.sum +++ b/go.sum @@ -2,13 +2,28 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20O github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/acoshift/pgsql v0.16.0 h1:ak+fwy8Xnx0uZBhSmvFhGGk3a7EQNu8IiRLpnP99IT4= github.com/acoshift/pgsql v0.16.0/go.mod h1:HtdMa77CYeRb9pD6+cT/ZPjpudiUVQ2OIKv6QXjgEZw= +github.com/cockroachdb/cockroach-go/v2 v2.4.3 h1:LJO3K3jC5WXvMePRQSJE1NsIGoFGcEx1LW83W6RAlhw= +github.com/cockroachdb/cockroach-go/v2 v2.4.3/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go index 7cdf6aa..44863ee 100644 --- a/handler.go +++ b/handler.go @@ -1,6 +1,7 @@ package main import ( + "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -16,6 +17,7 @@ import ( type RedirectHandler struct { OAuth2ClientID string + BaseURL string } func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -41,6 +43,13 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + codeChallenge := r.FormValue("code_challenge") + codeChallengeMethod := r.FormValue("code_challenge_method") + if codeChallenge != "" && codeChallengeMethod == "" { + codeChallengeMethod = "S256" + } + resource := r.FormValue("resource") + ctx := r.Context() oauth2Client, err := getOAuth2Client(ctx, clientID) @@ -54,25 +63,47 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "Internal server error", http.StatusInternalServerError) return } - pattern := oauth2Client.RedirectURI - pattern = strings.ReplaceAll(pattern, ".", `\.`) - pattern = strings.ReplaceAll(pattern, "/", `\/`) - pattern = strings.ReplaceAll(pattern, "*", `.*`) - re := regexp.MustCompile(`^` + pattern + `$`) - if !re.MatchString(callbackURL) { - slog.WarnContext(ctx, "redirect: invalid redirect_uri", "client_id", clientID, "redirect_uri", callbackURL) - http.Error(w, "Invalid redirect_uri parameter", http.StatusBadRequest) - return + + if oauth2Client.IsPublic() { + // Public clients (CLI / MCP) must use PKCE and an exact, pre-registered + // redirect URI (loopback port is allowed to vary). + if codeChallenge == "" { + http.Error(w, "Missing code_challenge parameter", http.StatusBadRequest) + return + } + if codeChallengeMethod != "S256" { + http.Error(w, "Unsupported code_challenge_method parameter", http.StatusBadRequest) + return + } + if !redirectURIAllowed(oauth2Client.RedirectURIs, callbackURL) { + slog.WarnContext(ctx, "redirect: invalid redirect_uri", "client_id", clientID, "redirect_uri", callbackURL) + http.Error(w, "Invalid redirect_uri parameter", http.StatusBadRequest) + return + } + } else { + pattern := oauth2Client.RedirectURI + pattern = strings.ReplaceAll(pattern, ".", `\.`) + pattern = strings.ReplaceAll(pattern, "/", `\/`) + pattern = strings.ReplaceAll(pattern, "*", `.*`) + re := regexp.MustCompile(`^` + pattern + `$`) + if !re.MatchString(callbackURL) { + slog.WarnContext(ctx, "redirect: invalid redirect_uri", "client_id", clientID, "redirect_uri", callbackURL) + http.Error(w, "Invalid redirect_uri parameter", http.StatusBadRequest) + return + } } state := generateState() sessionID := generateSessionID() err = saveSession(ctx, sessionID, &Session{ - ClientID: oauth2Client.ID, - State: state, - CallbackState: callbackState, - CallbackURL: callbackURL, + ClientID: oauth2Client.ID, + State: state, + CallbackState: callbackState, + CallbackURL: callbackURL, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + Resource: resource, }) if err != nil { slog.ErrorContext(ctx, "redirect: save session", "error", err) @@ -92,7 +123,7 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { params := url.Values{} params.Set("response_type", "code") params.Set("client_id", h.OAuth2ClientID) - params.Set("redirect_uri", "https://auth.deploys.app/callback") + params.Set("redirect_uri", h.BaseURL+"/callback") params.Set("scope", "https://www.googleapis.com/auth/userinfo.email") params.Set("access_type", "online") params.Set("prompt", "consent") @@ -110,9 +141,14 @@ func isURL(s string) bool { return (p.Scheme == "http" || p.Scheme == "https") && p.Host != "" } +// googleTokenURL is Google's OAuth2 token endpoint. It is a variable so tests +// can point the callback exchange at a stub server. +var googleTokenURL = "https://oauth2.googleapis.com/token" + type CallbackHandler struct { OAuth2ClientID string OAuth2ClientSecret string + BaseURL string } func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -161,11 +197,11 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { params := url.Values{} params.Set("grant_type", "authorization_code") params.Set("code", code) - params.Set("redirect_uri", "https://auth.deploys.app/callback") + params.Set("redirect_uri", h.BaseURL+"/callback") params.Set("client_id", h.OAuth2ClientID) params.Set("client_secret", h.OAuth2ClientSecret) - resp, err := http.Post("https://oauth2.googleapis.com/token", "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) + resp, err := http.Post(googleTokenURL, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) if err != nil { slog.WarnContext(ctx, "callback: exchange token", "error", err) failResponse(w, r) @@ -195,7 +231,13 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } returnCode := generateCode() - err = insertOAuth2Code(ctx, session.ClientID, returnCode, email) + err = insertOAuth2Code(ctx, session.ClientID, returnCode, &OAuth2Code{ + Email: email, + CodeChallenge: session.CodeChallenge, + CodeChallengeMethod: session.CodeChallengeMethod, + RedirectURI: session.CallbackURL, + Resource: session.Resource, + }) if err != nil { slog.ErrorContext(ctx, "callback: insert oauth2 code", "error", err) http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -297,69 +339,121 @@ func (RevokePostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type TokenHandler struct{} func (TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - clientID := r.PostFormValue("client_id") - if clientID == "" { - http.Error(w, "Missing client_id parameter", http.StatusBadRequest) + grantType := r.PostFormValue("grant_type") + if grantType != "" && grantType != "authorization_code" { + oauthError(w, http.StatusBadRequest, "unsupported_grant_type", "only authorization_code is supported") return } - clientSecret := r.PostFormValue("client_secret") - if clientSecret == "" { - http.Error(w, "Missing client_secret parameter", http.StatusBadRequest) + + clientID := r.PostFormValue("client_id") + if clientID == "" { + oauthError(w, http.StatusBadRequest, "invalid_request", "missing client_id") return } code := r.PostFormValue("code") if code == "" { - http.Error(w, "Missing code parameter", http.StatusBadRequest) + oauthError(w, http.StatusBadRequest, "invalid_request", "missing code") return } ctx := r.Context() oauth2Client, err := getOAuth2Client(ctx, clientID) if errors.Is(err, ErrOAuth2ClientNotFound) { - http.Error(w, "Invalid client_id parameter", http.StatusBadRequest) + oauthError(w, http.StatusUnauthorized, "invalid_client", "unknown client_id") return } if err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) + slog.ErrorContext(ctx, "token: get oauth2 client", "error", err) + oauthError(w, http.StatusInternalServerError, "server_error", "") return } - if oauth2Client.Secret != clientSecret { - http.Error(w, "Invalid client_secret parameter", http.StatusBadRequest) - return + + // Authenticate the client. Public clients (CLI / MCP) rely on PKCE instead + // of a secret; confidential clients must present client_secret. + if oauth2Client.IsPublic() { + if r.PostFormValue("code_verifier") == "" { + oauthError(w, http.StatusBadRequest, "invalid_request", "missing code_verifier") + return + } + } else { + clientSecret := r.PostFormValue("client_secret") + if clientSecret == "" { + oauthError(w, http.StatusBadRequest, "invalid_request", "missing client_secret") + return + } + if subtle.ConstantTimeCompare([]byte(clientSecret), []byte(oauth2Client.Secret)) != 1 { + oauthError(w, http.StatusUnauthorized, "invalid_client", "invalid client_secret") + return + } } - email, err := getOAuth2EmailFromCode(ctx, clientID, code) + oauth2Code, err := getOAuth2Code(ctx, clientID, code) if errors.Is(err, ErrOAuth2CodeNotFound) { - slog.WarnContext(ctx, "token: invalid code", "client_id", clientID, "code", code) - http.Error(w, "Invalid code parameter", http.StatusBadRequest) + slog.WarnContext(ctx, "token: invalid code", "client_id", clientID) + oauthError(w, http.StatusBadRequest, "invalid_grant", "invalid or expired code") return } if err != nil { - slog.ErrorContext(ctx, "token: get oauth2 email from code", "error", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + slog.ErrorContext(ctx, "token: get oauth2 code", "error", err) + oauthError(w, http.StatusInternalServerError, "server_error", "") return } + // PKCE: verify whenever the code was issued with a challenge. + if oauth2Code.CodeChallenge != "" { + verifier := r.PostFormValue("code_verifier") + if !verifyPKCE(verifier, oauth2Code.CodeChallenge, oauth2Code.CodeChallengeMethod) { + slog.WarnContext(ctx, "token: pkce verification failed", "client_id", clientID) + oauthError(w, http.StatusBadRequest, "invalid_grant", "PKCE verification failed") + return + } + } + + // For public clients the redirect_uri presented here must match the one the + // code was bound to at the authorize step (RFC 6749 §4.1.3). + if oauth2Client.IsPublic() && oauth2Code.RedirectURI != "" { + if r.PostFormValue("redirect_uri") != oauth2Code.RedirectURI { + oauthError(w, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") + return + } + } + token := generateToken() hashedToken := hashToken(token) - err = insertToken(ctx, hashedToken, email) + err = insertToken(ctx, hashedToken, oauth2Code.Email) if err != nil { slog.ErrorContext(ctx, "token: insert token", "error", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + oauthError(w, http.StatusInternalServerError, "server_error", "") return } var resp struct { - RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` } - resp.TokenType = "bearer" - resp.RefreshToken = token + resp.AccessToken = token + resp.TokenType = "Bearer" + resp.ExpiresIn = tokenTTLSeconds + resp.RefreshToken = token // retained for backward compatibility with the existing web client w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") json.NewEncoder(w).Encode(resp) } +func oauthError(w http.ResponseWriter, status int, code, desc string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(status) + body := map[string]string{"error": code} + if desc != "" { + body["error_description"] = desc + } + json.NewEncoder(w).Encode(body) +} + type apiResult struct { OK bool `json:"ok"` Result any `json:"result,omitempty"` diff --git a/handler_helpers_test.go b/handler_helpers_test.go new file mode 100644 index 0000000..847f1d7 --- /dev/null +++ b/handler_helpers_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestIsURL(t *testing.T) { + for in, want := range map[string]bool{ + "https://app.example.com/cb": true, + "http://127.0.0.1:8080/cb": true, + "app.example.com": false, // no scheme + "ftp://example.com": false, // unsupported scheme + "https://": false, // no host + "": false, + } { + if got := isURL(in); got != want { + t.Errorf("isURL(%q) = %v, want %v", in, got, want) + } + } +} + +func TestExtractEmailFromIDToken(t *testing.T) { + t.Run("valid", func(t *testing.T) { + email, err := extractEmailFromIDToken(fakeIDToken("user@example.com")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if email != "user@example.com" { + t.Errorf("email = %q, want user@example.com", email) + } + }) + t.Run("malformed (not 3 parts)", func(t *testing.T) { + if _, err := extractEmailFromIDToken("only.two"); err == nil { + t.Error("expected error for malformed token") + } + }) +} + +func TestHashToken(t *testing.T) { + a := hashToken("deploys-api.abc") + if a == "" { + t.Fatal("hashToken returned empty") + } + if a != hashToken("deploys-api.abc") { + t.Error("hashToken is not deterministic") + } + if a == hashToken("deploys-api.different") { + t.Error("different inputs produced the same hash") + } +} + +func TestMetadataHandler(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + MetadataHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var meta struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + } + if err := json.NewDecoder(rec.Body).Decode(&meta); err != nil { + t.Fatalf("decode: %v", err) + } + if meta.Issuer != "https://auth.test" { + t.Errorf("issuer = %q", meta.Issuer) + } + if meta.AuthorizationEndpoint != "https://auth.test/" { + t.Errorf("authorization_endpoint = %q", meta.AuthorizationEndpoint) + } + if meta.TokenEndpoint != "https://auth.test/token" { + t.Errorf("token_endpoint = %q", meta.TokenEndpoint) + } + if meta.RegistrationEndpoint != "https://auth.test/register" { + t.Errorf("registration_endpoint = %q", meta.RegistrationEndpoint) + } + if !contains(meta.CodeChallengeMethodsSupported, "S256") { + t.Errorf("code_challenge_methods_supported = %v, want S256", meta.CodeChallengeMethodsSupported) + } + if !contains(meta.TokenEndpointAuthMethodsSupported, "none") || + !contains(meta.TokenEndpointAuthMethodsSupported, "client_secret_post") { + t.Errorf("token_endpoint_auth_methods_supported = %v", meta.TokenEndpointAuthMethodsSupported) + } +} + +func contains(s []string, v string) bool { + for _, x := range s { + if x == v { + return true + } + } + return false +} diff --git a/introspect.go b/introspect.go new file mode 100644 index 0000000..0b1b134 --- /dev/null +++ b/introspect.go @@ -0,0 +1,66 @@ +package main + +import ( + "crypto/subtle" + "database/sql" + "encoding/json" + "errors" + "log/slog" + "net/http" +) + +// IntrospectHandler implements OAuth 2.0 Token Introspection (RFC 7662). The MCP +// resource server calls it to validate an opaque bearer token it received. +// +// The endpoint is itself protected by a shared secret (INTROSPECTION_TOKEN) +// presented as `Authorization: Bearer `, so it is not a public oracle. +type IntrospectHandler struct { + Token string +} + +func (h IntrospectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.Token == "" { + http.Error(w, "introspection not configured", http.StatusServiceUnavailable) + return + } + expected := "Bearer " + h.Token + if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), []byte(expected)) != 1 { + w.Header().Set("WWW-Authenticate", "Bearer") + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + token := r.PostFormValue("token") + if token == "" { + writeInactive(w) + return + } + + ctx := r.Context() + email, exp, err := lookupToken(ctx, hashToken(token)) + if errors.Is(err, sql.ErrNoRows) { + writeInactive(w) + return + } + if err != nil { + slog.ErrorContext(ctx, "introspect: lookup token", "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + json.NewEncoder(w).Encode(map[string]any{ + "active": true, + "sub": email, + "username": email, + "token_type": "Bearer", + "exp": exp, + }) +} + +func writeInactive(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + json.NewEncoder(w).Encode(map[string]any{"active": false}) +} diff --git a/introspect_test.go b/introspect_test.go new file mode 100644 index 0000000..c8484d1 --- /dev/null +++ b/introspect_test.go @@ -0,0 +1,133 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func introspectReq(t *testing.T, auth, token string) *http.Request { + t.Helper() + form := url.Values{"token": {token}} + req := httptest.NewRequest(http.MethodPost, "/introspect", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if auth != "" { + req.Header.Set("Authorization", auth) + } + return req +} + +func TestIntrospectHandler_NotConfigured(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + IntrospectHandler{Token: ""}.ServeHTTP(rec, introspectReq(t, "Bearer x", "tok")) + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", rec.Code) + } +} + +func TestIntrospectHandler_Unauthorized(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer wrong", "tok")) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", rec.Code) + } +} + +func TestIntrospectHandler_ActiveToken(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + const raw = "deploys-api.abc" + seedToken(t, ctx, hashToken(raw), "user@example.com") + + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", raw).WithContext(ctx)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + Active bool `json:"active"` + Sub string `json:"sub"` + Exp int64 `json:"exp"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if !resp.Active { + t.Error("active = false, want true") + } + if resp.Sub != "user@example.com" { + t.Errorf("sub = %q", resp.Sub) + } + if resp.Exp == 0 { + t.Error("exp = 0, want a future unix timestamp") + } +} + +func TestIntrospectHandler_UnknownToken(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() // empty user_tokens + + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", "deploys-api.gone").WithContext(ctx)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if active := decodeActive(t, rec); active { + t.Error("active = true, want false for unknown token") + } +} + +func TestIntrospectHandler_ExpiredToken(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + const raw = "deploys-api.expired" + // Insert a token that already expired. + if _, err := pgctxExec(t, ctx, `insert into user_tokens (token, email, expires_at) values ($1, $2, now() - interval '1 hour')`, hashToken(raw), "old@example.com"); err != nil { + t.Fatal(err) + } + + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", raw).WithContext(ctx)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if active := decodeActive(t, rec); active { + t.Error("active = true, want false for expired token") + } +} + +func TestIntrospectHandler_EmptyToken(t *testing.T) { + t.Parallel() + // Empty token short-circuits before any DB lookup. + rec := httptest.NewRecorder() + IntrospectHandler{Token: "secret"}.ServeHTTP(rec, introspectReq(t, "Bearer secret", "")) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if active := decodeActive(t, rec); active { + t.Error("active = true, want false for empty token") + } +} + +func decodeActive(t *testing.T, rec *httptest.ResponseRecorder) bool { + t.Helper() + var resp struct { + Active bool `json:"active"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + return resp.Active +} diff --git a/main.go b/main.go index 14e8fb8..c1f5c8a 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "database/sql" "net/http" "os" + "strings" "github.com/acoshift/pgsql/pgctx" _ "github.com/lib/pq" @@ -14,12 +15,20 @@ func main() { if port == "" { port = "8080" } + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + baseURL = "https://auth.deploys.app" + } + baseURL = strings.TrimRight(baseURL, "/") + oauth2ClientID := os.Getenv("OAUTH2_CLIENT_ID") if oauth2ClientID == "" { panic("missing OAUTH2_CLIENT_ID") } oauth2ClientSecret := os.Getenv("OAUTH2_CLIENT_SECRET") + introspectionToken := os.Getenv("INTROSPECTION_TOKEN") + sqlURL := os.Getenv("SQL_URL") if sqlURL == "" { panic("missing SQL_URL") @@ -30,15 +39,21 @@ func main() { } defer db.Close() + startCleanupWorker(db) + mux := http.NewServeMux() - mux.Handle("GET /", RedirectHandler{OAuth2ClientID: oauth2ClientID}) + mux.Handle("GET /.well-known/oauth-authorization-server", MetadataHandler{BaseURL: baseURL}) + mux.Handle("GET /", RedirectHandler{OAuth2ClientID: oauth2ClientID, BaseURL: baseURL}) mux.Handle("GET /callback", CallbackHandler{ OAuth2ClientID: oauth2ClientID, OAuth2ClientSecret: oauth2ClientSecret, + BaseURL: baseURL, }) mux.Handle("GET /revoke", RevokeHandler{}) // TODO: remove ? mux.Handle("POST /revoke", RevokePostHandler{}) mux.Handle("POST /token", TokenHandler{}) + mux.Handle("POST /register", RegisterHandler{BaseURL: baseURL}) + mux.Handle("POST /introspect", IntrospectHandler{Token: introspectionToken}) http.ListenAndServe(":"+port, pgctx.Middleware(db)(mux)) } diff --git a/metadata.go b/metadata.go new file mode 100644 index 0000000..e795958 --- /dev/null +++ b/metadata.go @@ -0,0 +1,32 @@ +package main + +import ( + "encoding/json" + "net/http" +) + +// MetadataHandler serves OAuth 2.0 Authorization Server Metadata (RFC 8414). +// MCP clients fetch this to discover the authorize, token and registration +// endpoints plus the supported PKCE methods and client auth methods. +type MetadataHandler struct { + BaseURL string +} + +func (h MetadataHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + meta := map[string]any{ + "issuer": h.BaseURL, + "authorization_endpoint": h.BaseURL + "/", + "token_endpoint": h.BaseURL + "/token", + "registration_endpoint": h.BaseURL + "/register", + "revocation_endpoint": h.BaseURL + "/revoke", + "introspection_endpoint": h.BaseURL + "/introspect", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "code_challenge_methods_supported": []string{"S256"}, + "token_endpoint_auth_methods_supported": []string{"client_secret_post", "none"}, + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "public, max-age=3600") + json.NewEncoder(w).Encode(meta) +} diff --git a/oauth2.go b/oauth2.go index 4a82307..5041590 100644 --- a/oauth2.go +++ b/oauth2.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/acoshift/pgsql/pgctx" + "github.com/lib/pq" ) var ( @@ -15,26 +16,46 @@ var ( ) type OAuth2Client struct { - ID string - Secret string - RedirectURI string + ID string + Secret string + RedirectURI string // legacy glob pattern (confidential clients) + RedirectURIs []string // exact URIs (public / DCR clients) + TokenEndpointAuthMethod string // "client_secret_post" or "none" + ClientName string +} + +// IsPublic reports whether the client authenticates without a secret (PKCE only). +func (c *OAuth2Client) IsPublic() bool { + return c.TokenEndpointAuthMethod == "none" } type Session struct { - ClientID string - State string - CallbackState string - CallbackURL string + ClientID string + State string + CallbackState string + CallbackURL string + CodeChallenge string + CodeChallengeMethod string + Resource string +} + +type OAuth2Code struct { + Email string + CodeChallenge string + CodeChallengeMethod string + RedirectURI string + Resource string } func getOAuth2Client(ctx context.Context, clientID string) (*OAuth2Client, error) { var x OAuth2Client + var secret sql.NullString err := pgctx.QueryRow(ctx, ` - select id, secret, redirect_uri + select id, secret, redirect_uri, redirect_uris, token_endpoint_auth_method from oauth2_clients where id = $1 `, clientID).Scan( - &x.ID, &x.Secret, &x.RedirectURI, + &x.ID, &secret, &x.RedirectURI, &x.RedirectURIs, &x.TokenEndpointAuthMethod, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrOAuth2ClientNotFound @@ -42,41 +63,55 @@ func getOAuth2Client(ctx context.Context, clientID string) (*OAuth2Client, error if err != nil { return nil, err } + x.Secret = secret.String return &x, nil } -func insertOAuth2Code(ctx context.Context, clientID, code, email string) error { +// insertOAuth2Client persists a dynamically registered public client. +func insertOAuth2Client(ctx context.Context, c *OAuth2Client) error { _, err := pgctx.Exec(ctx, ` - insert into oauth2_codes (id, client_id, email) - values ($1, $2, $3) - `, code, clientID, email) + insert into oauth2_clients (id, secret, redirect_uri, redirect_uris, token_endpoint_auth_method, client_name) + values ($1, null, '', $2, $3, $4) + `, c.ID, pq.Array(c.RedirectURIs), c.TokenEndpointAuthMethod, c.ClientName) return err } -func getOAuth2EmailFromCode(ctx context.Context, clientID, code string) (string, error) { - var email string +func insertOAuth2Code(ctx context.Context, clientID, code string, c *OAuth2Code) error { + _, err := pgctx.Exec(ctx, ` + insert into oauth2_codes (id, client_id, email, code_challenge, code_challenge_method, redirect_uri, resource) + values ($1, $2, $3, $4, $5, $6, $7) + `, code, clientID, c.Email, c.CodeChallenge, c.CodeChallengeMethod, c.RedirectURI, c.Resource) + return err +} + +// getOAuth2Code atomically consumes a code, returning the associated email and +// the PKCE / redirect binding stored when it was issued. 1-hour TTL. +func getOAuth2Code(ctx context.Context, clientID, code string) (*OAuth2Code, error) { + var x OAuth2Code err := pgctx.QueryRow(ctx, ` delete from oauth2_codes where id = $1 and client_id = $2 and now() < created_at + interval '1 hour' - returning email - `, code, clientID).Scan(&email) + returning email, code_challenge, code_challenge_method, redirect_uri, resource + `, code, clientID).Scan( + &x.Email, &x.CodeChallenge, &x.CodeChallengeMethod, &x.RedirectURI, &x.Resource, + ) if errors.Is(err, sql.ErrNoRows) { - return "", ErrOAuth2CodeNotFound + return nil, ErrOAuth2CodeNotFound } if err != nil { - return "", err + return nil, err } - return email, nil + return &x, nil } func getSession(ctx context.Context, sessionID string) (*Session, error) { var x Session err := pgctx.QueryRow(ctx, ` - select client_id, state, callback_state, callback_url + select client_id, state, callback_state, callback_url, code_challenge, code_challenge_method, resource from oauth2_sessions where id = $1 and now() < created_at + interval '1 hour' `, sessionID).Scan( - &x.ClientID, &x.State, &x.CallbackState, &x.CallbackURL, + &x.ClientID, &x.State, &x.CallbackState, &x.CallbackURL, &x.CodeChallenge, &x.CodeChallengeMethod, &x.Resource, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrOAuth2SessionNotFound @@ -95,8 +130,9 @@ func getSession(ctx context.Context, sessionID string) (*Session, error) { func saveSession(ctx context.Context, sessionID string, session *Session) error { _, err := pgctx.Exec(ctx, ` - insert into oauth2_sessions (id, client_id, state, callback_state, callback_url) - values ($1, $2, $3, $4, $5) - `, sessionID, session.ClientID, session.State, session.CallbackState, session.CallbackURL) + insert into oauth2_sessions (id, client_id, state, callback_state, callback_url, code_challenge, code_challenge_method, resource) + values ($1, $2, $3, $4, $5, $6, $7, $8) + `, sessionID, session.ClientID, session.State, session.CallbackState, session.CallbackURL, + session.CodeChallenge, session.CodeChallengeMethod, session.Resource) return err } diff --git a/pkce.go b/pkce.go new file mode 100644 index 0000000..32c90ef --- /dev/null +++ b/pkce.go @@ -0,0 +1,80 @@ +package main + +import ( + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "net/url" + "strings" +) + +// verifyPKCE checks a PKCE code_verifier against the stored code_challenge. +// Only the S256 method is supported (OAuth 2.1 / MCP requirement). +func verifyPKCE(verifier, challenge, method string) bool { + if verifier == "" || challenge == "" { + return false + } + if method != "S256" { + return false + } + sum := sha256.Sum256([]byte(verifier)) + computed := base64.RawURLEncoding.EncodeToString(sum[:]) + return subtle.ConstantTimeCompare([]byte(computed), []byte(challenge)) == 1 +} + +func isLoopbackHost(host string) bool { + return host == "127.0.0.1" || host == "::1" || host == "localhost" +} + +// redirectURIAllowed reports whether got matches one of the registered exact +// redirect URIs. Per RFC 8252 the port is ignored for loopback addresses, so a +// CLI listening on an ephemeral localhost port is accepted. +func redirectURIAllowed(registered []string, got string) bool { + gu, err := url.Parse(got) + if err != nil { + return false + } + for _, raw := range registered { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + ru, err := url.Parse(raw) + if err != nil { + continue + } + if ru.Scheme != gu.Scheme { + continue + } + if !strings.EqualFold(ru.Hostname(), gu.Hostname()) { + continue + } + if ru.Path != gu.Path { + continue + } + if isLoopbackHost(gu.Hostname()) { + return true + } + if ru.Port() == gu.Port() { + return true + } + } + return false +} + +// validRegistrationRedirectURI enforces the redirect URIs a client may register: +// HTTPS anywhere, or plain HTTP only for loopback (native/CLI clients). +func validRegistrationRedirectURI(s string) bool { + u, err := url.Parse(s) + if err != nil || u.Host == "" { + return false + } + switch u.Scheme { + case "https": + return true + case "http": + return isLoopbackHost(u.Hostname()) + default: + return false + } +} diff --git a/pkce_test.go b/pkce_test.go new file mode 100644 index 0000000..55a399b --- /dev/null +++ b/pkce_test.go @@ -0,0 +1,86 @@ +package main + +import "testing" + +func TestVerifyPKCE(t *testing.T) { + verifier, challenge := pkcePair() + + cases := []struct { + name string + verifier string + chal string + method string + want bool + }{ + {"valid S256", verifier, challenge, "S256", true}, + {"wrong verifier", "not-the-verifier", challenge, "S256", false}, + {"empty verifier", "", challenge, "S256", false}, + {"empty challenge", verifier, "", "S256", false}, + {"plain method rejected", verifier, verifier, "plain", false}, + {"empty method rejected", verifier, challenge, "", false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := verifyPKCE(c.verifier, c.chal, c.method); got != c.want { + t.Errorf("verifyPKCE(%q,%q,%q) = %v, want %v", c.verifier, c.chal, c.method, got, c.want) + } + }) + } +} + +func TestIsLoopbackHost(t *testing.T) { + for host, want := range map[string]bool{ + "127.0.0.1": true, + "::1": true, + "localhost": true, + "example.com": false, + "10.0.0.1": false, + "": false, + } { + if got := isLoopbackHost(host); got != want { + t.Errorf("isLoopbackHost(%q) = %v, want %v", host, got, want) + } + } +} + +func TestRedirectURIAllowed(t *testing.T) { + cases := []struct { + name string + registered []string + got string + want bool + }{ + {"exact https", []string{"https://app.example.com/cb"}, "https://app.example.com/cb", true}, + {"loopback ignores port", []string{"http://127.0.0.1:1234/callback"}, "http://127.0.0.1:55001/callback", true}, + {"localhost ignores port", []string{"http://localhost:1/callback"}, "http://localhost:9999/callback", true}, + {"loopback path mismatch", []string{"http://127.0.0.1:1234/callback"}, "http://127.0.0.1:55001/other", false}, + {"https port mismatch", []string{"https://app.example.com:443/cb"}, "https://app.example.com:8443/cb", false}, + {"scheme mismatch", []string{"https://app.example.com/cb"}, "http://app.example.com/cb", false}, + {"host mismatch", []string{"https://app.example.com/cb"}, "https://evil.example.com/cb", false}, + {"not in list", []string{"https://a.example.com/cb"}, "https://b.example.com/cb", false}, + {"matches second entry", []string{"https://a.example.com/cb", "http://127.0.0.1:1/cb"}, "http://127.0.0.1:42/cb", true}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := redirectURIAllowed(c.registered, c.got); got != c.want { + t.Errorf("redirectURIAllowed(%v, %q) = %v, want %v", c.registered, c.got, got, c.want) + } + }) + } +} + +func TestValidRegistrationRedirectURI(t *testing.T) { + for uri, want := range map[string]bool{ + "https://app.example.com/cb": true, + "http://127.0.0.1:1234/cb": true, + "http://localhost/cb": true, + "http://app.example.com/cb": false, // plain http only allowed for loopback + "ftp://example.com": false, + "not-a-url": false, + "": false, + } { + if got := validRegistrationRedirectURI(uri); got != want { + t.Errorf("validRegistrationRedirectURI(%q) = %v, want %v", uri, got, want) + } + } +} diff --git a/redirect_test.go b/redirect_test.go new file mode 100644 index 0000000..5337d1c --- /dev/null +++ b/redirect_test.go @@ -0,0 +1,205 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func getReq(t *testing.T, q url.Values) *http.Request { + t.Helper() + return httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) +} + +func findCookie(rec *httptest.ResponseRecorder, name string) *http.Cookie { + for _, c := range rec.Result().Cookies() { + if c.Name == name { + return c + } + } + return nil +} + +// --- OLD FLOW (regression): confidential client redirect to Google --- + +func TestRedirectHandler_Confidential_Success(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + + q := url.Values{ + "client_id": {"web"}, + "state": {"cbstate"}, + "redirect_uri": {"https://app.example.com/cb"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "googleclient", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "https://accounts.google.com/o/oauth2/auth?") { + t.Errorf("Location = %q, want Google authorize URL", loc) + } + if !strings.Contains(loc, "client_id=googleclient") { + t.Errorf("Location missing google client_id: %q", loc) + } + if !strings.Contains(loc, url.QueryEscape("https://auth.test/callback")) { + t.Errorf("Location missing callback redirect_uri: %q", loc) + } + if c := findCookie(rec, "s"); c == nil || c.Value == "" { + t.Error("session cookie 's' not set") + } + if n := countRows(t, ctx, "oauth2_sessions"); n != 1 { + t.Errorf("oauth2_sessions = %d, want 1", n) + } +} + +func TestRedirectHandler_MissingParams(t *testing.T) { + t.Parallel() + // All rejected before any DB access, so no test DB is needed. + cases := map[string]url.Values{ + "missing client_id": {"state": {"s"}, "redirect_uri": {"https://a/cb"}}, + "missing state": {"client_id": {"web"}, "redirect_uri": {"https://a/cb"}}, + "missing redirect_uri": {"client_id": {"web"}, "state": {"s"}}, + "invalid redirect_uri": {"client_id": {"web"}, "state": {"s"}, "redirect_uri": {"not-a-url"}}, + } + for name, q := range cases { + t.Run(name, func(t *testing.T) { + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}.ServeHTTP(rec, getReq(t, q)) + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", rec.Code) + } + }) + } +} + +func TestRedirectHandler_UnknownClient(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() // empty DB + + q := url.Values{"client_id": {"ghost"}, "state": {"s"}, "redirect_uri": {"https://a/cb"}} + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } +} + +func TestRedirectHandler_Confidential_RedirectNotAllowed(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + + q := url.Values{"client_id": {"web"}, "state": {"s"}, "redirect_uri": {"https://evil.example.com/cb"}} + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } +} + +// --- NEW FLOW: public client requires PKCE + exact (loopback) redirect --- + +func TestRedirectHandler_Public_Success(t *testing.T) { + t.Parallel() + _, challenge := pkcePair() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") + + q := url.Values{ + "client_id": {"cli"}, + "state": {"cbstate"}, + "redirect_uri": {"http://127.0.0.1:55001/callback"}, // different loopback port, allowed + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "googleclient", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302; body=%s", rec.Code, rec.Body.String()) + } + if !strings.HasPrefix(rec.Header().Get("Location"), "https://accounts.google.com/") { + t.Errorf("Location = %q", rec.Header().Get("Location")) + } + // The PKCE challenge must be persisted on the session for the callback. + challengeStored, methodStored, cbURL := oneSessionPKCE(t, ctx) + if challengeStored != challenge || methodStored != "S256" { + t.Errorf("session PKCE = (%q,%q), want (%q,S256)", challengeStored, methodStored, challenge) + } + if cbURL != "http://127.0.0.1:55001/callback" { + t.Errorf("session callback_url = %q", cbURL) + } +} + +func TestRedirectHandler_Public_MissingChallenge(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") + + q := url.Values{"client_id": {"cli"}, "state": {"s"}, "redirect_uri": {"http://127.0.0.1:55001/callback"}} + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (PKCE required for public clients)", rec.Code) + } +} + +func TestRedirectHandler_Public_UnsupportedChallengeMethod(t *testing.T) { + t.Parallel() + _, challenge := pkcePair() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") + + q := url.Values{ + "client_id": {"cli"}, + "state": {"s"}, + "redirect_uri": {"http://127.0.0.1:55001/callback"}, + "code_challenge": {challenge}, + "code_challenge_method": {"plain"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (only S256)", rec.Code) + } +} + +func TestRedirectHandler_Public_RedirectNotRegistered(t *testing.T) { + t.Parallel() + _, challenge := pkcePair() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:1234/callback") + + q := url.Values{ + "client_id": {"cli"}, + "state": {"s"}, + "redirect_uri": {"https://evil.example.com/callback"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + rec := httptest.NewRecorder() + RedirectHandler{OAuth2ClientID: "g", BaseURL: "https://auth.test"}. + ServeHTTP(rec, getReq(t, q).WithContext(ctx)) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (redirect not registered)", rec.Code) + } +} diff --git a/register.go b/register.go new file mode 100644 index 0000000..af549a9 --- /dev/null +++ b/register.go @@ -0,0 +1,85 @@ +package main + +import ( + "encoding/json" + "log/slog" + "net/http" + "time" +) + +// RegisterHandler implements OAuth 2.0 Dynamic Client Registration (RFC 7591). +// It only issues public clients (token_endpoint_auth_method "none"); they +// authenticate via PKCE, so no client secret is generated. This is the path MCP +// CLIs use to obtain a client_id without manual provisioning. +type RegisterHandler struct { + BaseURL string +} + +func (h RegisterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var req struct { + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + registrationError(w, "invalid_client_metadata", "invalid request body") + return + } + + if req.TokenEndpointAuthMethod != "" && req.TokenEndpointAuthMethod != "none" { + registrationError(w, "invalid_client_metadata", "only token_endpoint_auth_method \"none\" is supported") + return + } + if len(req.RedirectURIs) == 0 { + registrationError(w, "invalid_redirect_uri", "at least one redirect_uri is required") + return + } + for _, uri := range req.RedirectURIs { + if !validRegistrationRedirectURI(uri) { + registrationError(w, "invalid_redirect_uri", "redirect_uri must be https or http loopback: "+uri) + return + } + } + + ctx := r.Context() + clientID := generateClientID() + err := insertOAuth2Client(ctx, &OAuth2Client{ + ID: clientID, + RedirectURIs: req.RedirectURIs, + TokenEndpointAuthMethod: "none", + ClientName: req.ClientName, + }) + if err != nil { + slog.ErrorContext(ctx, "register: insert oauth2 client", "error", err) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]string{"error": "server_error"}) + return + } + + resp := map[string]any{ + "client_id": clientID, + "client_id_issued_at": time.Now().Unix(), + "client_name": req.ClientName, + "redirect_uris": req.RedirectURIs, + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + "token_endpoint_auth_method": "none", + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(resp) +} + +func registrationError(w http.ResponseWriter, code, desc string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{ + "error": code, + "error_description": desc, + }) +} diff --git a/register_test.go b/register_test.go new file mode 100644 index 0000000..d932259 --- /dev/null +++ b/register_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func postJSON(t *testing.T, path, body string) *http.Request { + t.Helper() + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + return req +} + +func TestRegisterHandler_Success(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + + body := `{"client_name":"my cli","redirect_uris":["http://127.0.0.1:1234/cb"]}` + rec := httptest.NewRecorder() + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, postJSON(t, "/register", body).WithContext(ctx)) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + ClientID string `json:"client_id"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + RedirectURIs []string `json:"redirect_uris"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.ClientID == "" { + t.Fatal("client_id is empty") + } + if resp.TokenEndpointAuthMethod != "none" { + t.Errorf("token_endpoint_auth_method = %q, want none", resp.TokenEndpointAuthMethod) + } + if len(resp.RedirectURIs) != 1 || resp.RedirectURIs[0] != "http://127.0.0.1:1234/cb" { + t.Errorf("redirect_uris = %v", resp.RedirectURIs) + } + // The client is persisted as a public client with the registered redirect. + if uris, ok := clientRedirectURIs(t, ctx, resp.ClientID); !ok || len(uris) != 1 || uris[0] != "http://127.0.0.1:1234/cb" { + t.Errorf("persisted redirect_uris = %v (ok=%v)", uris, ok) + } +} + +func TestRegisterHandler_HTTPSRedirectAllowed(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + + body := `{"redirect_uris":["https://app.example.com/cb"]}` + rec := httptest.NewRecorder() + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, postJSON(t, "/register", body).WithContext(ctx)) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) + } + if n := countRows(t, ctx, "oauth2_clients"); n != 1 { + t.Errorf("oauth2_clients = %d, want 1", n) + } +} + +func TestRegisterHandler_Rejections(t *testing.T) { + t.Parallel() + // All rejected before any DB write, so no test DB is needed. + cases := map[string]string{ + "invalid json": `{not json`, + "no redirect_uris": `{"client_name":"x"}`, + "non-loopback http": `{"redirect_uris":["http://app.example.com/cb"]}`, + "confidential rejected": `{"redirect_uris":["https://app.example.com/cb"],"token_endpoint_auth_method":"client_secret_post"}`, + } + for name, body := range cases { + t.Run(name, func(t *testing.T) { + rec := httptest.NewRecorder() + RegisterHandler{BaseURL: "https://auth.test"}.ServeHTTP(rec, postJSON(t, "/register", body)) + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + }) + } +} diff --git a/schema.sql b/schema.sql index c9155d5..d754914 100644 --- a/schema.sql +++ b/schema.sql @@ -1,28 +1,38 @@ create table oauth2_clients ( - id string, - secret string not null, - redirect_uri string not null, - created_at timestamptz not null default now(), + id string, + secret string, + redirect_uri string not null default '', + redirect_uris string[] not null default array[]::string[], + token_endpoint_auth_method string not null default 'client_secret_post', + client_name string not null default '', + created_at timestamptz not null default now(), primary key (id) ); create table oauth2_codes ( - id string, - client_id string not null, - email string not null, - created_at timestamptz not null default now(), + id string, + client_id string not null, + email string not null, + code_challenge string not null default '', + code_challenge_method string not null default '', + redirect_uri string not null default '', + resource string not null default '', + created_at timestamptz not null default now(), primary key (id), foreign key (client_id) references oauth2_clients (id) on delete cascade ); create index oauth2_codes_created_at_idx on oauth2_codes (created_at); create table oauth2_sessions ( - id string, - client_id string not null, - state string not null, - callback_state string not null, - callback_url string not null, - created_at timestamptz not null default now(), + id string, + client_id string not null, + state string not null, + callback_state string not null, + callback_url string not null, + code_challenge string not null default '', + code_challenge_method string not null default '', + resource string not null default '', + created_at timestamptz not null default now(), primary key (id) ); create index oauth2_sessions_created_at_idx on oauth2_sessions (created_at); diff --git a/schema/01_init.sql b/schema/01_init.sql new file mode 100644 index 0000000..083b135 --- /dev/null +++ b/schema/01_init.sql @@ -0,0 +1,36 @@ +create table oauth2_clients ( + id string, + secret string not null, + redirect_uri string not null, + created_at timestamptz not null default now(), + primary key (id) +); + +create table oauth2_codes ( + id string, + client_id string not null, + email string not null, + created_at timestamptz not null default now(), + primary key (id), + foreign key (client_id) references oauth2_clients (id) on delete cascade +); +create index oauth2_codes_created_at_idx on oauth2_codes (created_at); + +create table oauth2_sessions ( + id string, + client_id string not null, + state string not null, + callback_state string not null, + callback_url string not null, + created_at timestamptz not null default now(), + primary key (id) +); +create index oauth2_sessions_created_at_idx on oauth2_sessions (created_at); + +create table user_tokens ( + token string, + email string not null, + expires_at timestamptz not null, + created_at timestamptz not null default now(), + primary key (token) +); diff --git a/schema/02_mcp.sql b/schema/02_mcp.sql new file mode 100644 index 0000000..7e62402 --- /dev/null +++ b/schema/02_mcp.sql @@ -0,0 +1,16 @@ +-- MCP / OAuth 2.1 support (public clients, PKCE, DCR, resource indicators). + +alter table oauth2_clients alter column secret drop not null; +alter table oauth2_clients alter column redirect_uri set default ''; +alter table oauth2_clients add column if not exists redirect_uris string[] not null default array[]::string[]; +alter table oauth2_clients add column if not exists token_endpoint_auth_method string not null default 'client_secret_post'; +alter table oauth2_clients add column if not exists client_name string not null default ''; + +alter table oauth2_codes add column if not exists code_challenge string not null default ''; +alter table oauth2_codes add column if not exists code_challenge_method string not null default ''; +alter table oauth2_codes add column if not exists redirect_uri string not null default ''; +alter table oauth2_codes add column if not exists resource string not null default ''; + +alter table oauth2_sessions add column if not exists code_challenge string not null default ''; +alter table oauth2_sessions add column if not exists code_challenge_method string not null default ''; +alter table oauth2_sessions add column if not exists resource string not null default ''; diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 0000000..2d6257c --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,76 @@ +package schema + +import ( + "context" + "database/sql" + "embed" +) + +//go:embed *.sql +var fs embed.FS + +func Migrate(ctx context.Context, db *sql.DB) error { + if err := setupMigrationTable(ctx, db); err != nil { + return err + } + + list, err := fs.ReadDir(".") + if err != nil { + return err + } + + for _, x := range list { + migrateID := x.Name() + migrated, err := isMigrated(ctx, db, migrateID) + if err != nil { + return err + } + if migrated { + continue + } + + b, err := fs.ReadFile(migrateID) + if err != nil { + return err + } + if _, err = db.ExecContext(ctx, string(b)); err != nil { + return err + } + if err = stampMigrated(ctx, db, migrateID, string(b)); err != nil { + return err + } + } + + return nil +} + +func setupMigrationTable(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` + create table if not exists migrations ( + id varchar not null, + content varchar not null, + ts timestamptz not null default now(), + primary key (id) + ) + `) + return err +} + +func isMigrated(ctx context.Context, db *sql.DB, migrateID string) (bool, error) { + var migrated bool + err := db.QueryRowContext(ctx, ` + select exists ( + select 1 from migrations where id = $1 + ) + `, migrateID).Scan(&migrated) + return migrated, err +} + +func stampMigrated(ctx context.Context, db *sql.DB, migrateID, content string) error { + _, err := db.ExecContext(ctx, ` + insert into migrations (id, content) + values ($1, $2) + on conflict (id) do nothing + `, migrateID, content) + return err +} diff --git a/setup_test.go b/setup_test.go new file mode 100644 index 0000000..1c858e3 --- /dev/null +++ b/setup_test.go @@ -0,0 +1,17 @@ +package main + +import ( + "testing" + + "github.com/deploys-app/auth/tu" +) + +// newTestDB starts a fresh CockroachDB instance for the test and tears it down +// on completion. Each test gets its own isolated database so they can run in +// parallel without sharing rows or worrying about cleanup. +func newTestDB(t *testing.T) *tu.Context { + t.Helper() + c := tu.Setup() + t.Cleanup(c.Teardown) + return c +} diff --git a/testutil_test.go b/testutil_test.go new file mode 100644 index 0000000..9c080e5 --- /dev/null +++ b/testutil_test.go @@ -0,0 +1,151 @@ +package main + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/base64" + "testing" + + "github.com/acoshift/pgsql/pgctx" + "github.com/lib/pq" +) + +// pkcePair returns a PKCE verifier and its S256 challenge. +func pkcePair() (verifier, challenge string) { + verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge +} + +// fakeIDToken builds a JWT-shaped token whose payload carries the given email, +// matching what extractEmailFromIDToken expects (RawStdEncoding-encoded body). +func fakeIDToken(email string) string { + payload := base64.RawStdEncoding.EncodeToString([]byte(`{"email":"` + email + `"}`)) + return "header." + payload + ".signature" +} + +// --- seeding helpers (operate on a context already carrying the DB) --- + +func seedConfidentialClient(t *testing.T, ctx context.Context, id, secret, redirectGlob string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_clients (id, secret, redirect_uri, token_endpoint_auth_method) + values ($1, $2, $3, 'client_secret_post') + `, id, secret, redirectGlob) + if err != nil { + t.Fatalf("seed confidential client: %v", err) + } +} + +func seedPublicClient(t *testing.T, ctx context.Context, id string, redirectURIs ...string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_clients (id, redirect_uris, token_endpoint_auth_method) + values ($1, $2, 'none') + `, id, pq.Array(redirectURIs)) + if err != nil { + t.Fatalf("seed public client: %v", err) + } +} + +func seedCode(t *testing.T, ctx context.Context, id, clientID, email, challenge, method, redirect, resource string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_codes (id, client_id, email, code_challenge, code_challenge_method, redirect_uri, resource) + values ($1, $2, $3, $4, $5, $6, $7) + `, id, clientID, email, challenge, method, redirect, resource) + if err != nil { + t.Fatalf("seed code: %v", err) + } +} + +func seedSession(t *testing.T, ctx context.Context, id, clientID, state, cbState, cbURL, challenge, method, resource string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into oauth2_sessions (id, client_id, state, callback_state, callback_url, code_challenge, code_challenge_method, resource) + values ($1, $2, $3, $4, $5, $6, $7, $8) + `, id, clientID, state, cbState, cbURL, challenge, method, resource) + if err != nil { + t.Fatalf("seed session: %v", err) + } +} + +// pgctxExec runs an arbitrary statement against the test DB (for one-off seeds +// that the typed helpers don't cover, e.g. an already-expired token). +func pgctxExec(t *testing.T, ctx context.Context, query string, args ...any) (sql.Result, error) { + t.Helper() + return pgctx.Exec(ctx, query, args...) +} + +func seedToken(t *testing.T, ctx context.Context, hashedToken, email string) { + t.Helper() + _, err := pgctx.Exec(ctx, ` + insert into user_tokens (token, email, expires_at) + values ($1, $2, now() + interval '7 days') + `, hashedToken, email) + if err != nil { + t.Fatalf("seed token: %v", err) + } +} + +// --- assertion helpers --- + +func tokenEmail(t *testing.T, ctx context.Context, hashedToken string) (string, bool) { + t.Helper() + var email string + err := pgctx.QueryRow(ctx, `select email from user_tokens where token = $1`, hashedToken).Scan(&email) + if err != nil { + return "", false + } + return email, true +} + +func countRows(t *testing.T, ctx context.Context, table string) int { + t.Helper() + var n int + if err := pgctx.QueryRow(ctx, `select count(*) from `+table).Scan(&n); err != nil { + t.Fatalf("count %s: %v", table, err) + } + return n +} + +func clientRedirectURIs(t *testing.T, ctx context.Context, id string) ([]string, bool) { + t.Helper() + var uris []string + err := pgctx.QueryRow(ctx, `select redirect_uris from oauth2_clients where id = $1`, id).Scan(&uris) + if err != nil { + return nil, false + } + return uris, true +} + +// codeEmailPKCE reads back a minted oauth2_codes row (without consuming it). +func codeEmailPKCE(t *testing.T, ctx context.Context, id string) (email, challenge, method string) { + t.Helper() + err := pgctx.QueryRow(ctx, ` + select email, code_challenge, code_challenge_method + from oauth2_codes + where id = $1 + `, id).Scan(&email, &challenge, &method) + if err != nil { + t.Fatalf("read code %q: %v", id, err) + } + return email, challenge, method +} + +// oneSessionPKCE returns the PKCE challenge/method and callback URL of the only +// session row (the redirect handler is expected to have created exactly one). +func oneSessionPKCE(t *testing.T, ctx context.Context) (challenge, method, callbackURL string) { + t.Helper() + err := pgctx.QueryRow(ctx, ` + select code_challenge, code_challenge_method, callback_url + from oauth2_sessions + limit 1 + `).Scan(&challenge, &method, &callbackURL) + if err != nil { + t.Fatalf("read session: %v", err) + } + return challenge, method, callbackURL +} diff --git a/token.go b/token.go index 92431c5..4f8c15b 100644 --- a/token.go +++ b/token.go @@ -11,6 +11,10 @@ import ( const tokenPrefix = "deploys-api." +// tokenTTLSeconds mirrors the 7-day TTL applied to user_tokens rows; reported +// to clients as expires_in. +const tokenTTLSeconds = 7 * 24 * 60 * 60 + func generateBase64RandomString(s int) string { b := make([]byte, s) _, err := rand.Read(b[:]) @@ -36,6 +40,10 @@ func generateSessionID() string { return generateBase64RandomString(32) } +func generateClientID() string { + return generateBase64RandomString(16) +} + func hashToken(token string) string { h := sha256.New() h.Write([]byte(token)) @@ -55,3 +63,14 @@ func deleteToken(ctx context.Context, token string) error { _, err := pgctx.Exec(ctx, `delete from user_tokens where token = $1`, token) return err } + +// lookupToken resolves a hashed token to its owner email and expiry (unix +// seconds), returning sql.ErrNoRows when the token is unknown or expired. +func lookupToken(ctx context.Context, hashedToken string) (email string, exp int64, err error) { + err = pgctx.QueryRow(ctx, ` + select email, extract(epoch from expires_at)::bigint + from user_tokens + where token = $1 and expires_at > now() + `, hashedToken).Scan(&email, &exp) + return +} diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..4107946 --- /dev/null +++ b/token_test.go @@ -0,0 +1,279 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func postForm(t *testing.T, path string, form url.Values) *http.Request { + t.Helper() + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req +} + +// --- OLD FLOW (regression): confidential client + client_secret --- + +func TestTokenHandler_Confidential_Success(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + seedCode(t, ctx, "abc", "web", "user@example.com", "", "", "https://app.example.com/cb", "") + + form := url.Values{"client_id": {"web"}, "client_secret": {"topsecret"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + // Backward compatibility: the legacy client reads refresh_token. + if !strings.HasPrefix(resp.RefreshToken, tokenPrefix) { + t.Errorf("refresh_token %q missing prefix %q", resp.RefreshToken, tokenPrefix) + } + if resp.AccessToken != resp.RefreshToken { + t.Errorf("access_token (%q) should equal refresh_token (%q)", resp.AccessToken, resp.RefreshToken) + } + if resp.ExpiresIn != tokenTTLSeconds { + t.Errorf("expires_in = %d, want %d", resp.ExpiresIn, tokenTTLSeconds) + } + if resp.TokenType != "Bearer" { + t.Errorf("token_type = %q, want Bearer", resp.TokenType) + } + // The token is persisted (hashed) against the right email. + if email, ok := tokenEmail(t, ctx, hashToken(resp.RefreshToken)); !ok || email != "user@example.com" { + t.Errorf("persisted token email = %q (ok=%v), want user@example.com", email, ok) + } + // The code is single-use. + if n := countRows(t, ctx, "oauth2_codes"); n != 0 { + t.Errorf("oauth2_codes = %d, want 0 (code consumed)", n) + } +} + +func TestTokenHandler_Confidential_WrongSecret(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + seedCode(t, ctx, "abc", "web", "user@example.com", "", "", "https://app.example.com/cb", "") + + form := url.Values{"client_id": {"web"}, "client_secret": {"wrong"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", rec.Code) + } + assertOAuthError(t, rec, "invalid_client") + // The code must not be consumed on a failed exchange. + if n := countRows(t, ctx, "oauth2_codes"); n != 1 { + t.Errorf("oauth2_codes = %d, want 1 (code preserved)", n) + } +} + +func TestTokenHandler_Confidential_MissingSecret(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + + form := url.Values{"client_id": {"web"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_request") +} + +// --- NEW FLOW: public client + PKCE --- + +func TestTokenHandler_Public_PKCE_Success(t *testing.T) { + t.Parallel() + verifier, challenge := pkcePair() + const redirect = "http://127.0.0.1:5000/callback" + + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", redirect) + seedCode(t, ctx, "abc", "cli", "user@example.com", challenge, "S256", redirect, "") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {"cli"}, + "code": {"abc"}, + "code_verifier": {verifier}, + "redirect_uri": {redirect}, + } + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var resp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.AccessToken == "" { + t.Error("access_token must be populated") + } + if resp.TokenType != "Bearer" { + t.Errorf("token_type = %q, want Bearer", resp.TokenType) + } + if resp.ExpiresIn != tokenTTLSeconds { + t.Errorf("expires_in = %d, want %d", resp.ExpiresIn, tokenTTLSeconds) + } + if email, ok := tokenEmail(t, ctx, hashToken(resp.AccessToken)); !ok || email != "user@example.com" { + t.Errorf("persisted token email = %q (ok=%v), want user@example.com", email, ok) + } +} + +func TestTokenHandler_Public_PKCE_BadVerifier(t *testing.T) { + t.Parallel() + _, challenge := pkcePair() + const redirect = "http://127.0.0.1:5000/callback" + + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", redirect) + seedCode(t, ctx, "abc", "cli", "user@example.com", challenge, "S256", redirect, "") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {"cli"}, + "code": {"abc"}, + "code_verifier": {"this-is-the-wrong-verifier"}, + "redirect_uri": {redirect}, + } + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + assertOAuthError(t, rec, "invalid_grant") + if n := countRows(t, ctx, "user_tokens"); n != 0 { + t.Errorf("user_tokens = %d, want 0 (no token issued on PKCE failure)", n) + } +} + +func TestTokenHandler_Public_RedirectMismatch(t *testing.T) { + t.Parallel() + verifier, challenge := pkcePair() + const redirect = "http://127.0.0.1:5000/callback" + + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", redirect) + seedCode(t, ctx, "abc", "cli", "user@example.com", challenge, "S256", redirect, "") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {"cli"}, + "code": {"abc"}, + "code_verifier": {verifier}, + "redirect_uri": {"http://127.0.0.1:5000/different"}, + } + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_grant") +} + +func TestTokenHandler_Public_MissingVerifier(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedPublicClient(t, ctx, "cli", "http://127.0.0.1:5000/callback") + + form := url.Values{"grant_type": {"authorization_code"}, "client_id": {"cli"}, "code": {"abc"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_request") +} + +// --- shared request validation --- + +func TestTokenHandler_UnsupportedGrantType(t *testing.T) { + t.Parallel() + form := url.Values{"grant_type": {"password"}, "client_id": {"x"}, "code": {"y"}} + rec := httptest.NewRecorder() + // Rejected before any DB access. + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "unsupported_grant_type") +} + +func TestTokenHandler_UnknownClient(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() // empty DB: client lookup misses + + form := url.Values{"client_id": {"ghost"}, "client_secret": {"x"}, "code": {"y"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", rec.Code) + } + assertOAuthError(t, rec, "invalid_client") +} + +func TestTokenHandler_InvalidCode(t *testing.T) { + t.Parallel() + tdb := newTestDB(t) + ctx := tdb.Ctx() + seedConfidentialClient(t, ctx, "web", "topsecret", "https://app.example.com/*") + + form := url.Values{"client_id": {"web"}, "client_secret": {"topsecret"}, "code": {"gone"}} + rec := httptest.NewRecorder() + TokenHandler{}.ServeHTTP(rec, postForm(t, "/token", form).WithContext(ctx)) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + assertOAuthError(t, rec, "invalid_grant") +} + +func assertOAuthError(t *testing.T, rec *httptest.ResponseRecorder, wantCode string) { + t.Helper() + var body struct { + Error string `json:"error"` + } + if err := json.NewDecoder(rec.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if body.Error != wantCode { + t.Errorf("error = %q, want %q", body.Error, wantCode) + } +} diff --git a/tu/tu.go b/tu/tu.go new file mode 100644 index 0000000..fa7f9c8 --- /dev/null +++ b/tu/tu.go @@ -0,0 +1,65 @@ +// Package tu is the test utility for spinning up an isolated CockroachDB. +package tu + +import ( + "context" + "database/sql" + + "github.com/acoshift/pgsql/pgctx" + "github.com/cockroachdb/cockroach-go/v2/testserver" + + "github.com/deploys-app/auth/schema" +) + +// Context holds the test server and DB connection. +type Context struct { + ts testserver.TestServer + DB *sql.DB +} + +func (c *Context) setup() { + var err error + defer func() { + if err != nil { + panic(err) + } + }() + + c.ts, err = testserver.NewTestServer() + if err != nil { + return + } + defer func() { + if err != nil { + c.Teardown() + } + }() + + c.DB, err = sql.Open("postgres", c.ts.PGURL().String()+"&enable_implicit_transaction_for_batch_statements=off") + if err != nil { + return + } + + err = schema.Migrate(context.Background(), c.DB) +} + +func (c *Context) Teardown() { + if c.DB != nil { + c.DB.Close() + } + if c.ts != nil { + c.ts.Stop() + } +} + +// Ctx returns a context with the DB injected, the way pgctx.Middleware would. +func (c *Context) Ctx() context.Context { + return pgctx.NewContext(context.Background(), c.DB) +} + +// Setup starts a CockroachDB test server and runs the schema migration. +func Setup() *Context { + c := &Context{} + c.setup() + return c +}