diff --git a/cmd/git-credential-github-app-sts/main.go b/cmd/git-credential-github-app-sts/main.go new file mode 100644 index 0000000..090409e --- /dev/null +++ b/cmd/git-credential-github-app-sts/main.go @@ -0,0 +1,152 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "strconv" + "strings" + "time" + + "github.com/alecthomas/kong" + + "github.com/etsy/github-app-sts/internal/client" + "github.com/etsy/github-app-sts/internal/client/tokensource" +) + +var args struct { + STSUrl string `name:"sts-url" required:"" help:"STS server URL." env:"GITHUB_APP_STS_URL"` + Scope string `name:"scope" required:"" help:"Policy name." env:"GITHUB_APP_STS_SCOPE"` + TokenSource string `name:"token-source" help:"Token source provider (gcp)." env:"GITHUB_APP_STS_TOKEN_SOURCE" xor:"token"` + TokenFile string `name:"token-file" help:"Path to file containing OIDC identity token." env:"GITHUB_APP_STS_TOKEN_FILE" xor:"token"` + Operation string `arg:"" help:"Git credential helper operation (get, store, erase)."` +} + +func main() { + k := kong.Parse(&args, + kong.Name("git-credential-github-app-sts"), + kong.Description("Git credential helper that uses GitHub App STS for authentication."), + ) + + if args.Operation != "get" { + return + } + + attrs, err := readCredentialAttributes(os.Stdin) + + k.FatalIfErrorf(err) + + outAttrs, err := get(attrs) + + k.FatalIfErrorf(err) + + writeCredentialAttributes(os.Stdout, outAttrs) +} + +type credentialAttribute struct { + key string + value string +} + +func get(in []credentialAttribute) ([]credentialAttribute, error) { + input := make(map[string]string) + + for _, a := range in { + input[a.key] = a.value + } + + if input["protocol"] != "https" { + return nil, nil + } + + host := input["host"] + + if host == "" { + return nil, fmt.Errorf("missing host in credential input") + } + + path := input["path"] + + if path == "" { + return nil, fmt.Errorf("missing path in credential input") + } + + resource := fmt.Sprintf("https://%s/%s", host, path) + + ctx := context.Background() + + subjectToken, err := resolveToken(ctx, args.STSUrl) + + if err != nil { + return nil, err + } + + resp, err := client.ExchangeToken(ctx, client.ExchangeRequest{ + STSUrl: args.STSUrl, + SubjectToken: subjectToken, + Resource: resource, + Scope: args.Scope, + }) + + if err != nil { + return nil, err + } + + expiryUTC := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).UTC() + + return []credentialAttribute{ + {"username", "x-access-token"}, + {"password", resp.AccessToken}, + {"password_expiry_utc", strconv.FormatInt(expiryUTC.Unix(), 10)}, + }, nil +} + +func resolveToken(ctx context.Context, audience string) (string, error) { + if args.TokenFile != "" { + return tokensource.FromFile(args.TokenFile) + } + + switch args.TokenSource { + case "gcp": + return tokensource.FromGCP(ctx, audience) + default: + return "", fmt.Errorf("unknown token source: %s", args.TokenSource) + } +} + +func readCredentialAttributes(r io.Reader) ([]credentialAttribute, error) { + s := bufio.NewScanner(r) + a := make([]credentialAttribute, 0) + + for s.Scan() { + line := s.Text() + + if line == "" { + break + } + + key, value, ok := strings.Cut(line, "=") + + if !ok { + return nil, fmt.Errorf("missing '=' character in git credential input") + } + + a = append(a, credentialAttribute{key, value}) + } + + if err := s.Err(); err != nil { + return nil, err + } + + return a, nil +} + +func writeCredentialAttributes(w io.Writer, attrs []credentialAttribute) { + for _, attr := range attrs { + fmt.Fprintf(w, "%s=%s\n", attr.key, attr.value) + } + + fmt.Fprintln(w) +} diff --git a/cmd/github-app-sts-client/main.go b/cmd/github-app-sts-client/main.go new file mode 100644 index 0000000..2d65fad --- /dev/null +++ b/cmd/github-app-sts-client/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/alecthomas/kong" + + "github.com/etsy/github-app-sts/internal/client" + "github.com/etsy/github-app-sts/internal/client/tokensource" +) + +var args struct { + STSUrl string `name:"sts-url" required:"" help:"STS server URL." env:"GITHUB_APP_STS_URL"` + Resource string `name:"resource" required:"" help:"Target resource URL (e.g. https://github.com/org/repo)." env:"GITHUB_APP_STS_RESOURCE"` + Scope string `name:"scope" required:"" help:"Policy name." env:"GITHUB_APP_STS_SCOPE"` + TokenSource string `name:"token-source" help:"Token source provider (gcp)." env:"GITHUB_APP_STS_TOKEN_SOURCE" xor:"token"` + TokenFile string `name:"token-file" help:"Path to file containing OIDC identity token." env:"GITHUB_APP_STS_TOKEN_FILE" xor:"token"` +} + +func main() { + k := kong.Parse(&args, + kong.Name("github-app-sts-client"), + kong.Description("Exchange an OIDC token for a GitHub App installation token."), + ) + + ctx := context.Background() + + subjectToken, err := resolveToken(ctx, args.STSUrl) + k.FatalIfErrorf(err) + + resp, err := client.ExchangeToken(ctx, client.ExchangeRequest{ + STSUrl: args.STSUrl, + SubjectToken: subjectToken, + Resource: args.Resource, + Scope: args.Scope, + }) + + k.FatalIfErrorf(err) + + fmt.Println(resp.AccessToken) +} + +func resolveToken(ctx context.Context, audience string) (string, error) { + if args.TokenFile != "" { + return tokensource.FromFile(args.TokenFile) + } + + switch args.TokenSource { + case "gcp": + return tokensource.FromGCP(ctx, audience) + default: + fmt.Fprintf(os.Stderr, "error: unknown token source: %s\n", args.TokenSource) + os.Exit(1) + return "", nil + } +} diff --git a/go.mod b/go.mod index 6d48931..cc55129 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect cloud.google.com/go/longrunning v0.8.0 // indirect + github.com/alecthomas/kong v1.15.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect diff --git a/go.sum b/go.sum index 56cae3c..36449b0 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ cloud.google.com/go/kms v1.26.0 h1:cK9mN2cf+9V63D3H1f6koxTatWy39aTI/hCjz1I+adU= cloud.google.com/go/kms v1.26.0/go.mod h1:pHKOdFJm63hxBsiPkYtowZPltu9dW0MWvBa6IA4HM58= cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= +github.com/alecthomas/kong v1.15.0 h1:BVJstKbpO73zKpmIu+m/aLRrNmWwxXPIGTNin9VmLVI= +github.com/alecthomas/kong v1.15.0/go.mod h1:wrlbXem1CWqUV5Vbmss5ISYhsVPkBb1Yo7YKJghju2I= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000..ee5d5d3 --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,79 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/etsy/github-app-sts/internal/protocol" +) + +type ExchangeRequest struct { + STSUrl string + SubjectToken string + Resource string + Scope string +} + +func ExchangeToken(ctx context.Context, req ExchangeRequest) (*protocol.SuccessResponse, error) { + tokenURL := strings.TrimRight(req.STSUrl, "/") + "/token" + + form := url.Values{ + "grant_type": {protocol.GrantType}, + "subject_token_type": {protocol.SubjectTokenType}, + "subject_token": {req.SubjectToken}, + "resource": {req.Resource}, + "scope": {req.Scope}, + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode())) + + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(httpReq) + + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var parsed protocol.ErrorResponseBody + + if json.Unmarshal(body, &parsed) == nil && parsed.Error != "" { + return nil, &protocol.ErrorResponse{ + StatusCode: resp.StatusCode, + ErrorResponseBody: parsed, + } + } + + return nil, fmt.Errorf("STS returned %d: %s", resp.StatusCode, body) + } + + var result protocol.SuccessResponse + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + if result.AccessToken == "" { + return nil, fmt.Errorf("missing access_token in response") + } + + return &result, nil +} diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..add0c9c --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,216 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/etsy/github-app-sts/internal/protocol" +) + +func TestExchangeToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("expected Content-Type application/x-www-form-urlencoded, got %s", ct) + } + + if err := r.ParseForm(); err != nil { + t.Fatalf("parsing form: %v", err) + } + + assertFormValue(t, r.PostForm, "grant_type", protocol.GrantType) + assertFormValue(t, r.PostForm, "subject_token_type", protocol.SubjectTokenType) + assertFormValue(t, r.PostForm, "subject_token", "my-jwt-token") + assertFormValue(t, r.PostForm, "resource", "https://github.com/myorg/myrepo") + assertFormValue(t, r.PostForm, "scope", "my-policy") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(protocol.SuccessResponse{ + AccessToken: "ghs_test123", + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + })) + defer server.Close() + + resp, err := ExchangeToken(context.Background(), ExchangeRequest{ + STSUrl: server.URL, + SubjectToken: "my-jwt-token", + Resource: "https://github.com/myorg/myrepo", + Scope: "my-policy", + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resp.AccessToken != "ghs_test123" { + t.Errorf("expected access_token ghs_test123, got %s", resp.AccessToken) + } + + if resp.ExpiresIn != 3600 { + t.Errorf("expected expires_in 3600, got %d", resp.ExpiresIn) + } +} + +func TestExchangeToken_TrailingSlashNormalized(t *testing.T) { + var requestPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(protocol.SuccessResponse{ + AccessToken: "ghs_test", + TokenType: "Bearer", + }) + })) + + defer server.Close() + + _, err := ExchangeToken(context.Background(), ExchangeRequest{ + STSUrl: server.URL + "///", + SubjectToken: "token", + Resource: "https://github.com/org", + Scope: "policy", + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if requestPath != "/token" { + t.Errorf("expected path /token, got %s", requestPath) + } +} + +func TestExchangeToken_JSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{ + "error": "invalid_grant", + "error_description": "token expired", + }) + })) + + defer server.Close() + + _, err := ExchangeToken(context.Background(), ExchangeRequest{ + STSUrl: server.URL, + SubjectToken: "bad-token", + Resource: "https://github.com/org", + Scope: "policy", + }) + + if err == nil { + t.Fatal("expected error, got nil") + } + + var stsErr *protocol.ErrorResponse + + if !errors.As(err, &stsErr) { + t.Fatalf("expected *protocol.ErrorResponse, got %T: %v", err, err) + } + + if stsErr.StatusCode != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", stsErr.StatusCode) + } + + if stsErr.ErrorResponseBody.Error != "invalid_grant" { + t.Errorf("expected error code invalid_grant, got %s", stsErr.ErrorResponseBody.Error) + } + + if stsErr.Description != "token expired" { + t.Errorf("expected description 'token expired', got %s", stsErr.Description) + } +} + +func TestExchangeToken_NonJSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte("Bad Gateway")) + })) + + defer server.Close() + + _, err := ExchangeToken(context.Background(), ExchangeRequest{ + STSUrl: server.URL, + SubjectToken: "token", + Resource: "https://github.com/org", + Scope: "policy", + }) + + if err == nil { + t.Fatal("expected error, got nil") + } + + var stsErr *protocol.ErrorResponse + + if errors.As(err, &stsErr) { + t.Fatalf("expected non-STS error, got ErrorResponse: %v", stsErr) + } +} + +func TestExchangeToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token_type": "Bearer", + }) + })) + + defer server.Close() + + _, err := ExchangeToken(context.Background(), ExchangeRequest{ + STSUrl: server.URL, + SubjectToken: "token", + Resource: "https://github.com/org", + Scope: "policy", + }) + + if err == nil { + t.Fatal("expected error, got nil") + } + + var stsErr *protocol.ErrorResponse + + if errors.As(err, &stsErr) { + t.Fatalf("expected non-STS error, got ErrorResponse: %v", stsErr) + } +} + +func TestExchangeToken_NetworkError(t *testing.T) { + _, err := ExchangeToken(context.Background(), ExchangeRequest{ + STSUrl: "http://localhost:1", + SubjectToken: "token", + Resource: "https://github.com/org", + Scope: "policy", + }) + + if err == nil { + t.Fatal("expected error, got nil") + } + + var stsErr *protocol.ErrorResponse + + if errors.As(err, &stsErr) { + t.Fatalf("expected non-STS error, got ErrorResponse: %v", stsErr) + } +} + +func assertFormValue(t *testing.T, values url.Values, key, expected string) { + t.Helper() + + if got := values.Get(key); got != expected { + t.Errorf("form value %s: expected %q, got %q", key, expected, got) + } +} diff --git a/internal/client/tokensource/file.go b/internal/client/tokensource/file.go new file mode 100644 index 0000000..c55b069 --- /dev/null +++ b/internal/client/tokensource/file.go @@ -0,0 +1,24 @@ +package tokensource + +import ( + "fmt" + "os" + "strings" +) + +// FromFile reads an OIDC identity token from the given file path. +func FromFile(path string) (string, error) { + data, err := os.ReadFile(path) + + if err != nil { + return "", fmt.Errorf("reading token file: %w", err) + } + + token := strings.TrimSpace(string(data)) + + if token == "" { + return "", fmt.Errorf("token file %s is empty", path) + } + + return token, nil +} diff --git a/internal/client/tokensource/gcp.go b/internal/client/tokensource/gcp.go new file mode 100644 index 0000000..9c37c61 --- /dev/null +++ b/internal/client/tokensource/gcp.go @@ -0,0 +1,32 @@ +package tokensource + +import ( + "context" + "fmt" + + "google.golang.org/api/idtoken" +) + +// FromGCP fetches an OIDC identity token using GCP Application Default +// Credentials. This works on GCE, Cloud Run, GKE, and anywhere a service +// account key or workload identity federation is configured. +func FromGCP(ctx context.Context, audience string) (string, error) { + ts, err := idtoken.NewTokenSource(ctx, audience) + + if err != nil { + return "", fmt.Errorf("creating GCP token source: %w", err) + } + + token, err := ts.Token() + + if err != nil { + return "", fmt.Errorf("fetching GCP identity token: %w", err) + } + + // it's a little weird, but the ID token is set as the access token field + if token.AccessToken == "" { + return "", fmt.Errorf("GCP token source returned an empty token") + } + + return token.AccessToken, nil +} diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go new file mode 100644 index 0000000..1a2aeea --- /dev/null +++ b/internal/protocol/protocol.go @@ -0,0 +1,35 @@ +// Package protocol defines the RFC 8693 token exchange wire format shared +// between the STS client and server. +package protocol + +import "fmt" + +const ( + GrantType = "urn:ietf:params:oauth:grant-type:token-exchange" + SubjectTokenType = "urn:ietf:params:oauth:token-type:jwt" + IssuedTokenType = "urn:ietf:params:oauth:token-type:access_token" +) + +// SuccessResponse is the JSON body for a successful token exchange. +type SuccessResponse struct { + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +// ErrorResponseBody is the JSON body for an OAuth 2.0 error response. +type ErrorResponseBody struct { + Error string `json:"error"` + Description string `json:"error_description"` +} + +// ErrorResponse pairs the wire format with the HTTP status code. +type ErrorResponse struct { + StatusCode int + ErrorResponseBody +} + +func (e *ErrorResponse) Error() string { + return fmt.Sprintf("STS error (%s): %s", e.ErrorResponseBody.Error, e.Description) +} diff --git a/internal/server/token_handler.go b/internal/server/token_handler.go index f6bf1ea..a4a6f14 100644 --- a/internal/server/token_handler.go +++ b/internal/server/token_handler.go @@ -13,6 +13,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + "github.com/etsy/github-app-sts/internal/protocol" "github.com/etsy/github-app-sts/internal/server/sts" ) @@ -23,13 +24,6 @@ var exchangeCounter, _ = meter.Int64Counter( metric.WithDescription("Total token exchange attempts"), ) -const ( - expectedGrantType = "urn:ietf:params:oauth:grant-type:token-exchange" - expectedSubjectTokenType = "urn:ietf:params:oauth:token-type:jwt" - - issuedTokenType = "urn:ietf:params:oauth:token-type:access_token" -) - type service interface { ExchangeToken(ctx context.Context, subjectToken, resource, scope string) (*sts.Result, error) } @@ -43,64 +37,64 @@ func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if r.Method != http.MethodPost { - h.writeError(w, &errorResponse{http.StatusMethodNotAllowed, "invalid_request", "method not allowed"}) + h.writeError(w, newErrorResponse(http.StatusMethodNotAllowed, "invalid_request", "method not allowed")) return } if err := r.ParseForm(); err != nil { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "could not parse form body"}) + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "could not parse form body")) return } grantType := r.FormValue("grant_type") if grantType == "" { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "missing grant_type parameter"}) + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "missing grant_type parameter")) return } - if grantType != expectedGrantType { - h.writeError(w, &errorResponse{http.StatusBadRequest, "unsupported_grant_type", "grant_type must be " + expectedGrantType}) + if grantType != protocol.GrantType { + h.writeError(w, newErrorResponse(http.StatusBadRequest, "unsupported_grant_type", "grant_type must be "+protocol.GrantType)) return } subjectTokenType := r.FormValue("subject_token_type") if subjectTokenType == "" { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "missing subject_token_type parameter"}) + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "missing subject_token_type parameter")) return } - if subjectTokenType != expectedSubjectTokenType { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "subject_token_type must be " + expectedSubjectTokenType}) + if subjectTokenType != protocol.SubjectTokenType { + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "subject_token_type must be "+protocol.SubjectTokenType)) return } subjectToken := r.FormValue("subject_token") if subjectToken == "" { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "missing subject_token parameter"}) + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "missing subject_token parameter")) return } resource := r.FormValue("resource") if resource == "" { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "missing resource parameter"}) + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "missing resource parameter")) return } scope := r.FormValue("scope") if scope == "" { - h.writeError(w, &errorResponse{http.StatusBadRequest, "invalid_request", "missing scope parameter"}) + h.writeError(w, newErrorResponse(http.StatusBadRequest, "invalid_request", "missing scope parameter")) return } result, err := h.sts.ExchangeToken(ctx, subjectToken, resource, scope) if err != nil { - response := &errorResponse{http.StatusInternalServerError, "server_error", "internal server error"} + response := newErrorResponse(http.StatusInternalServerError, "server_error", "internal server error") attrs := []attribute.KeyValue{ attribute.String("error.type", fmt.Sprintf("%T", err)), @@ -113,15 +107,15 @@ func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch stsError.Kind { case sts.ErrKindInvalidToken: - response = &errorResponse{http.StatusBadRequest, "invalid_grant", stsError.Message} + response = newErrorResponse(http.StatusBadRequest, "invalid_grant", stsError.Message) logLevel = slog.LevelWarn case sts.ErrKindInvalidResource: - response = &errorResponse{http.StatusBadRequest, "invalid_target", stsError.Message} + response = newErrorResponse(http.StatusBadRequest, "invalid_target", stsError.Message) logLevel = slog.LevelInfo case sts.ErrKindInvalidScope: - response = &errorResponse{http.StatusBadRequest, "invalid_scope", stsError.Message} + response = newErrorResponse(http.StatusBadRequest, "invalid_scope", stsError.Message) logLevel = slog.LevelInfo case sts.ErrKindInternal: @@ -130,7 +124,7 @@ func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { attrs = append( attrs, - attribute.String("sts.error", response.Error), + attribute.String("sts.error", response.ErrorResponseBody.Error), attribute.String("sts.target_org", stsError.Details.TargetOrganization), attribute.String("sts.target_repo", stsError.Details.TargetRepository), ) @@ -153,9 +147,9 @@ func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.logExchange(ctx, slog.LevelInfo, result.Details, nil) - h.writeSuccess(w, &successResponse{ + h.writeSuccess(w, &protocol.SuccessResponse{ AccessToken: result.Token.AccessToken, - IssuedTokenType: issuedTokenType, + IssuedTokenType: protocol.IssuedTokenType, TokenType: "Bearer", ExpiresIn: int(time.Until(result.Token.ExpiresAt).Seconds()), }) @@ -188,28 +182,22 @@ func (h *tokenHandler) logExchange(ctx context.Context, level slog.Level, d *sts h.logger.LogAttrs(ctx, level, "token exchange succeeded", attrs...) } -type errorResponse struct { - Status int - Error string `json:"error"` - Description string `json:"error_description"` +func newErrorResponse(statusCode int, errorCode, description string) *protocol.ErrorResponse { + return &protocol.ErrorResponse{ + StatusCode: statusCode, + ErrorResponseBody: protocol.ErrorResponseBody{Error: errorCode, Description: description}, + } } -func (h *tokenHandler) writeError(w http.ResponseWriter, response *errorResponse) { +func (h *tokenHandler) writeError(w http.ResponseWriter, response *protocol.ErrorResponse) { w.Header().Set("Content-Type", "application/json") - w.WriteHeader(response.Status) + w.WriteHeader(response.StatusCode) // SAFETY: this is infallible. - json.NewEncoder(w).Encode(response) -} - -type successResponse struct { - AccessToken string `json:"access_token"` - IssuedTokenType string `json:"issued_token_type"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` + json.NewEncoder(w).Encode(response.ErrorResponseBody) } -func (h *tokenHandler) writeSuccess(w http.ResponseWriter, response *successResponse) { +func (h *tokenHandler) writeSuccess(w http.ResponseWriter, response *protocol.SuccessResponse) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/internal/server/token_handler_test.go b/internal/server/token_handler_test.go index d022924..b049177 100644 --- a/internal/server/token_handler_test.go +++ b/internal/server/token_handler_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/etsy/github-app-sts/internal/protocol" "github.com/etsy/github-app-sts/internal/server/github" "github.com/etsy/github-app-sts/internal/server/sts" ) @@ -27,8 +28,8 @@ func (f *fakeExchanger) ExchangeToken(_ context.Context, _, _, _ string) (*sts.R func validFormValues() url.Values { return url.Values{ - "grant_type": {expectedGrantType}, - "subject_token_type": {expectedSubjectTokenType}, + "grant_type": {protocol.GrantType}, + "subject_token_type": {protocol.SubjectTokenType}, "subject_token": {"fake.jwt.token"}, "resource": {"https://github.com/myorg/myrepo"}, "scope": {"my-policy"}, @@ -270,8 +271,8 @@ func TestTokenHandler(t *testing.T) { t.Errorf("token_type = %q, want %q", resp["token_type"], "Bearer") } - if resp["issued_token_type"] != issuedTokenType { - t.Errorf("issued_token_type = %q, want %q", resp["issued_token_type"], issuedTokenType) + if resp["issued_token_type"] != protocol.IssuedTokenType { + t.Errorf("issued_token_type = %q, want %q", resp["issued_token_type"], protocol.IssuedTokenType) } if _, ok := resp["expires_in"]; !ok {