diff --git a/bitbucket/rate_limit.go b/bitbucket/rate_limit.go new file mode 100644 index 0000000..3cb47a4 --- /dev/null +++ b/bitbucket/rate_limit.go @@ -0,0 +1,12 @@ +package bitbucket + +import ( + "context" + "fmt" + + forge "github.com/git-pkgs/forge" +) + +func (f *bitbucketForge) GetRateLimit(ctx context.Context) (*forge.RateLimit, error) { + return nil, fmt.Errorf("getting rate limit: %w", forge.ErrNotSupported) +} diff --git a/bitbucket/rate_limit_test.go b/bitbucket/rate_limit_test.go new file mode 100644 index 0000000..f5d1d00 --- /dev/null +++ b/bitbucket/rate_limit_test.go @@ -0,0 +1,17 @@ +package bitbucket + +import ( + "context" + "errors" + "testing" + + forge "github.com/git-pkgs/forge" +) + +func TestBitbucketRateLimitNotSupported(t *testing.T) { + f := New("test-token", nil) + _, err := f.GetRateLimit(context.Background()) + if !errors.Is(err, forge.ErrNotSupported) { + t.Fatalf("expected ErrNotSupported, got %v", err) + } +} diff --git a/forge.go b/forge.go index 274dc56..abbd4b7 100644 --- a/forge.go +++ b/forge.go @@ -45,6 +45,7 @@ type Forge interface { Secrets() SecretService Notifications() NotificationService Reviews() ReviewService + GetRateLimit(ctx context.Context) (*RateLimit, error) } // Client routes requests to the appropriate Forge based on the URL domain. diff --git a/forges_test.go b/forges_test.go index 565dd19..ad2de9d 100644 --- a/forges_test.go +++ b/forges_test.go @@ -477,6 +477,10 @@ func (m *mockForge) Reviews() ReviewService { return &mockReviewService{} } +func (m *mockForge) GetRateLimit(_ context.Context) (*RateLimit, error) { + return nil, ErrNotSupported +} + type mockRepoService struct { repo *Repository repos []Repository diff --git a/gitea/gitea.go b/gitea/gitea.go index cb34bf9..dbfcd65 100644 --- a/gitea/gitea.go +++ b/gitea/gitea.go @@ -9,7 +9,10 @@ import ( ) type giteaForge struct { - client *gitea.Client + client *gitea.Client + baseURL string + token string + httpClient *http.Client } // New creates a Gitea/Forgejo forge backend. @@ -22,7 +25,7 @@ func New(baseURL, token string, hc *http.Client) forge.Forge { opts = append(opts, gitea.SetHTTPClient(hc)) } c, _ := gitea.NewClient(baseURL, opts...) - return &giteaForge{client: c} + return &giteaForge{client: c, baseURL: baseURL, token: token, httpClient: hc} } type giteaRepoService struct { diff --git a/gitea/rate_limit.go b/gitea/rate_limit.go new file mode 100644 index 0000000..1117f8b --- /dev/null +++ b/gitea/rate_limit.go @@ -0,0 +1,67 @@ +package gitea + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + forge "github.com/git-pkgs/forge" +) + +type giteaRateLimitResponse struct { + Resources struct { + Core struct { + Limit int `json:"limit"` + Remaining int `json:"remaining"` + Reset int64 `json:"reset"` + } `json:"core"` + } `json:"resources"` +} + +func (f *giteaForge) GetRateLimit(ctx context.Context) (*forge.RateLimit, error) { + url := f.baseURL + "/api/v1/rate_limit" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if f.token != "" { + req.Header.Set("Authorization", "token "+f.token) + } + + hc := f.httpClient + if hc == nil { + hc = http.DefaultClient + } + + resp, err := hc.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("getting rate limit: %w", forge.ErrNotSupported) + } + if resp.StatusCode >= 400 { + return nil, &forge.HTTPError{StatusCode: resp.StatusCode, URL: url} + } + + var result giteaRateLimitResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + core := result.Resources.Core + var reset time.Time + if core.Reset > 0 { + reset = time.Unix(core.Reset, 0) + } + + return &forge.RateLimit{ + Limit: core.Limit, + Remaining: core.Remaining, + Reset: reset, + }, nil +} diff --git a/gitea/rate_limit_test.go b/gitea/rate_limit_test.go new file mode 100644 index 0000000..51e3a51 --- /dev/null +++ b/gitea/rate_limit_test.go @@ -0,0 +1,60 @@ +package gitea + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + forge "github.com/git-pkgs/forge" +) + +func TestGiteaGetRateLimit(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v1/version", giteaVersionHandler) + mux.HandleFunc("GET /api/v1/rate_limit", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "resources": map[string]any{ + "core": map[string]any{ + "limit": 100, + "remaining": 98, + "reset": 1717243200, + }, + }, + }) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + f := New(srv.URL, "test-token", nil) + rl, err := f.GetRateLimit(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assertEqualInt(t, "Limit", 100, rl.Limit) + assertEqualInt(t, "Remaining", 98, rl.Remaining) + if rl.Reset.Unix() != 1717243200 { + t.Errorf("Reset: want unix 1717243200, got %d", rl.Reset.Unix()) + } +} + +func TestGiteaGetRateLimitNotSupported(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v1/version", giteaVersionHandler) + mux.HandleFunc("GET /api/v1/rate_limit", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + f := New(srv.URL, "test-token", nil) + _, err := f.GetRateLimit(context.Background()) + if !errors.Is(err, forge.ErrNotSupported) { + t.Fatalf("expected ErrNotSupported, got %v", err) + } +} diff --git a/github/rate_limit.go b/github/rate_limit.go new file mode 100644 index 0000000..1209b9e --- /dev/null +++ b/github/rate_limit.go @@ -0,0 +1,25 @@ +package github + +import ( + "context" + + forge "github.com/git-pkgs/forge" +) + +func (f *gitHubForge) GetRateLimit(ctx context.Context) (*forge.RateLimit, error) { + limits, _, err := f.client.RateLimit.Get(ctx) + if err != nil { + return nil, err + } + + if limits == nil || limits.Core == nil { + return &forge.RateLimit{}, nil + } + + core := limits.Core + return &forge.RateLimit{ + Limit: core.Limit, + Remaining: core.Remaining, + Reset: core.Reset.Time, + }, nil +} diff --git a/github/rate_limit_test.go b/github/rate_limit_test.go new file mode 100644 index 0000000..cfdf3ab --- /dev/null +++ b/github/rate_limit_test.go @@ -0,0 +1,47 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-github/v82/github" +) + +func TestGitHubGetRateLimit(t *testing.T) { + resetTime := time.Date(2024, 6, 1, 12, 0, 0, 0, time.UTC) + + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v3/rate_limit", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "resources": map[string]any{ + "core": map[string]any{ + "limit": 5000, + "remaining": 4999, + "reset": resetTime.Unix(), + }, + }, + }) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := github.NewClient(nil) + c, _ = c.WithEnterpriseURLs(srv.URL+"/api/v3", srv.URL+"/api/v3") + f := &gitHubForge{client: c} + + rl, err := f.GetRateLimit(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assertEqualInt(t, "Limit", 5000, rl.Limit) + assertEqualInt(t, "Remaining", 4999, rl.Remaining) + if !rl.Reset.Equal(resetTime) { + t.Errorf("Reset: want %v, got %v", resetTime, rl.Reset) + } +} diff --git a/gitlab/rate_limit.go b/gitlab/rate_limit.go new file mode 100644 index 0000000..b0b649d --- /dev/null +++ b/gitlab/rate_limit.go @@ -0,0 +1,33 @@ +package gitlab + +import ( + "context" + "strconv" + "time" + + forge "github.com/git-pkgs/forge" +) + +func (f *gitLabForge) GetRateLimit(ctx context.Context) (*forge.RateLimit, error) { + // GitLab has no dedicated rate limit endpoint. Rate limit info comes + // from response headers on any API call, so we make a lightweight request. + _, resp, err := f.client.Version.GetVersion() + if err != nil { + return nil, err + } + + limit, _ := strconv.Atoi(resp.Header.Get("RateLimit-Limit")) + remaining, _ := strconv.Atoi(resp.Header.Get("RateLimit-Remaining")) + resetUnix, _ := strconv.ParseInt(resp.Header.Get("RateLimit-Reset"), 10, 64) + + var reset time.Time + if resetUnix > 0 { + reset = time.Unix(resetUnix, 0) + } + + return &forge.RateLimit{ + Limit: limit, + Remaining: remaining, + Reset: reset, + }, nil +} diff --git a/gitlab/rate_limit_test.go b/gitlab/rate_limit_test.go new file mode 100644 index 0000000..464ee66 --- /dev/null +++ b/gitlab/rate_limit_test.go @@ -0,0 +1,56 @@ +package gitlab + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGitLabGetRateLimit(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v4/version", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("RateLimit-Limit", "2000") + w.Header().Set("RateLimit-Remaining", "1999") + w.Header().Set("RateLimit-Reset", "1717243200") + _, _ = fmt.Fprintf(w, `{"version":"16.0.0","revision":"abc123"}`) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + f := New(srv.URL, "test-token", nil) + rl, err := f.GetRateLimit(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assertEqualInt(t, "Limit", 2000, rl.Limit) + assertEqualInt(t, "Remaining", 1999, rl.Remaining) + if rl.Reset.Unix() != 1717243200 { + t.Errorf("Reset: want unix 1717243200, got %d", rl.Reset.Unix()) + } +} + +func TestGitLabGetRateLimitNoHeaders(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v4/version", func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, `{"version":"16.0.0","revision":"abc123"}`) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + f := New(srv.URL, "test-token", nil) + rl, err := f.GetRateLimit(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assertEqualInt(t, "Limit", 0, rl.Limit) + assertEqualInt(t, "Remaining", 0, rl.Remaining) + if !rl.Reset.IsZero() { + t.Errorf("Reset: expected zero time, got %v", rl.Reset) + } +} diff --git a/internal/cli/rate_limit.go b/internal/cli/rate_limit.go new file mode 100644 index 0000000..0fd260a --- /dev/null +++ b/internal/cli/rate_limit.go @@ -0,0 +1,66 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/git-pkgs/forge/internal/output" + "github.com/git-pkgs/forge/internal/resolve" + "github.com/spf13/cobra" +) + +var rateLimitCmd = &cobra.Command{ + Use: "rate-limit", + Short: "Check API rate limit status", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + domain := domainFromFlags() + f, err := resolve.ForgeForDomain(domain) + if err != nil { + return err + } + + rl, err := f.GetRateLimit(cmd.Context()) + if err != nil { + return notSupported(err, "rate limit") + } + + p := printer() + if p.Format == output.JSON { + return p.PrintJSON(rl) + } + + if p.Format == output.Plain { + p.PrintPlain([]string{ + fmt.Sprintf("%d\t%d\t%s", rl.Limit, rl.Remaining, rl.Reset.Format(time.RFC3339)), + }) + return nil + } + + headers := []string{"LIMIT", "REMAINING", "RESETS AT"} + rows := [][]string{ + { + fmt.Sprintf("%d", rl.Limit), + fmt.Sprintf("%d", rl.Remaining), + formatReset(rl.Reset), + }, + } + p.PrintTable(headers, rows) + return nil + }, +} + +func formatReset(t time.Time) string { + if t.IsZero() { + return "-" + } + until := time.Until(t).Round(time.Second) + if until <= 0 { + return "now" + } + return fmt.Sprintf("%s (%s)", t.Format(time.RFC3339), until) +} + +func init() { + rootCmd.AddCommand(rateLimitCmd) +} diff --git a/internal/cli/rate_limit_test.go b/internal/cli/rate_limit_test.go new file mode 100644 index 0000000..ea0fa5f --- /dev/null +++ b/internal/cli/rate_limit_test.go @@ -0,0 +1,13 @@ +package cli + +import ( + "testing" +) + +func TestRateLimitCmdRejectsArgs(t *testing.T) { + rootCmd.SetArgs([]string{"rate-limit", "extra"}) + err := rootCmd.Execute() + if err == nil { + t.Fatal("expected error for extra args") + } +} diff --git a/types.go b/types.go index 13a11fb..bb76cb4 100644 --- a/types.go +++ b/types.go @@ -572,3 +572,10 @@ type SubmitReviewOpts struct { State ReviewState // approved, changes_requested, or commented Body string } + +// RateLimit holds normalized rate limit information for the current token. +type RateLimit struct { + Limit int `json:"limit"` + Remaining int `json:"remaining"` + Reset time.Time `json:"reset"` +}