From c7c0a9757bf1f2595cafefd3a258a99144553750 Mon Sep 17 00:00:00 2001 From: "lijiayi.2333" Date: Wed, 3 Jun 2026 11:31:51 +0800 Subject: [PATCH] mail: add local rate limits --- internal/client/api_errors.go | 4 + internal/client/api_errors_test.go | 15 ++ internal/client/client.go | 26 ++- internal/client/client_test.go | 189 +++++++++++++++ internal/ratelimit/canonical.go | 73 ++++++ internal/ratelimit/canonical_test.go | 150 ++++++++++++ internal/ratelimit/errors.go | 51 +++++ internal/ratelimit/limiter.go | 191 ++++++++++++++++ internal/ratelimit/limiter_test.go | 329 +++++++++++++++++++++++++++ internal/ratelimit/process_test.go | 92 ++++++++ internal/ratelimit/rules.go | 86 +++++++ internal/ratelimit/store.go | 219 ++++++++++++++++++ internal/ratelimit/store_test.go | 206 +++++++++++++++++ shortcuts/common/common.go | 4 + shortcuts/mail/mail_shortcut_test.go | 191 ++++++++++++++++ shortcuts/mail/mail_triage.go | 4 + 16 files changed, 1829 insertions(+), 1 deletion(-) create mode 100644 internal/ratelimit/canonical.go create mode 100644 internal/ratelimit/canonical_test.go create mode 100644 internal/ratelimit/errors.go create mode 100644 internal/ratelimit/limiter.go create mode 100644 internal/ratelimit/limiter_test.go create mode 100644 internal/ratelimit/process_test.go create mode 100644 internal/ratelimit/rules.go create mode 100644 internal/ratelimit/store.go create mode 100644 internal/ratelimit/store_test.go diff --git a/internal/client/api_errors.go b/internal/client/api_errors.go index 6b9244141..6b5a48260 100644 --- a/internal/client/api_errors.go +++ b/internal/client/api_errors.go @@ -14,6 +14,7 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/internal/ratelimit" ) // rawAPIJSONHint guides users when an SDK or response body parse fails. The @@ -30,6 +31,9 @@ func WrapDoAPIError(err error) error { if err == nil { return nil } + if ratelimit.IsLocalRateLimit(err) { + return err + } // (1) Pass-through any typed errs.* error. if _, ok := errs.ProblemOf(err); ok { diff --git a/internal/client/api_errors_test.go b/internal/client/api_errors_test.go index 65845adc5..9d56c9fc2 100644 --- a/internal/client/api_errors_test.go +++ b/internal/client/api_errors_test.go @@ -196,6 +196,21 @@ func TestWrapDoAPIError_Nil(t *testing.T) { } } +func TestWrapDoAPIError_PreservesLocalMailRateLimit(t *testing.T) { + original := output.ErrAPI(output.LarkErrRateLimit, "rate limited", map[string]any{ + "source": "local_ratelimit", + "retry_after_ms": 100, + }) + err := WrapDoAPIError(original) + if err != original { + t.Fatalf("WrapDoAPIError returned %p, want original %p", err, original) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want preserved rate_limit ExitError", err) + } +} + // ───────────────────────────────────────────────────────────────────────────── // WrapJSONResponseParseError: typed error contract. // diff --git a/internal/client/client.go b/internal/client/client.go index dc4a0e89e..90f347222 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -25,6 +25,7 @@ import ( "github.com/larksuite/cli/internal/errclass" "github.com/larksuite/cli/internal/errcompat" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/ratelimit" "github.com/larksuite/cli/internal/util" ) @@ -130,6 +131,10 @@ func (c *APIClient) buildApiReq(request RawApiRequest) (*larkcore.ApiReq, []lark func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as core.Identity, extraOpts ...larkcore.RequestOptionFunc) (*larkcore.ApiResp, error) { var opts []larkcore.RequestOptionFunc + if err := ratelimit.Allow(ctx, c.rateLimitRequest(req)); err != nil { + return nil, err + } + token, err := c.resolveAccessToken(ctx, as) if err != nil { // WrapDoAPIError is idempotent on already-classified errors: @@ -166,6 +171,10 @@ func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as c func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core.Identity, opts ...Option) (*http.Response, error) { cfg := buildConfig(opts) + if err := ratelimit.Allow(ctx, c.rateLimitRequest(req)); err != nil { + return nil, err + } + // Resolve auth token, err := c.resolveAccessToken(ctx, as) if err != nil { @@ -250,6 +259,21 @@ func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core. return resp, nil } +func (c *APIClient) rateLimitRequest(req *larkcore.ApiReq) ratelimit.Request { + if req == nil { + return ratelimit.Request{} + } + limitReq := ratelimit.Request{ + Method: req.HttpMethod, + Path: req.ApiPath, + } + if c != nil && c.Config != nil { + limitReq.Brand = c.Config.Brand + limitReq.AppID = c.Config.AppID + } + return limitReq +} + func streamLogID(header http.Header) string { logID := strings.TrimSpace(header.Get(larkcore.HttpHeaderKeyLogId)) if logID == "" { @@ -379,7 +403,7 @@ func (c *APIClient) paginateLoop(ctx context.Context, request RawApiRequest, opt ExtraOpts: request.ExtraOpts, }) if err != nil { - if page == 1 { + if page == 1 || ratelimit.IsLocalRateLimit(err) { return nil, err } fmt.Fprintf(c.ErrOut, "[page %d] error, stopping pagination\n", page) diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 8cf38d95b..fba1b60e8 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -24,6 +24,7 @@ import ( "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/ratelimit" ) // roundTripFunc is an adapter to use a function as http.RoundTripper. @@ -48,6 +49,15 @@ func (s *staticTokenResolver) ResolveToken(_ context.Context, _ credential.Token return &credential.TokenResult{Token: "test-token"}, nil } +type countingTokenResolver struct { + count int +} + +func (s *countingTokenResolver) ResolveToken(_ context.Context, _ credential.TokenSpec) (*credential.TokenResult, error) { + s.count++ + return &credential.TokenResult{Token: "test-token"}, nil +} + // newTestAPIClient creates an APIClient with a mock HTTP transport. func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Buffer) { t.Helper() @@ -68,6 +78,14 @@ func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Bu }, errBuf } +func TestRateLimitRequestNilNoops(t *testing.T) { + ac := &APIClient{Config: &core.CliConfig{AppID: "test-app", Brand: core.BrandFeishu}} + req := ac.rateLimitRequest(nil) + if req.Brand != "" || req.AppID != "" || req.Method != "" || req.Path != "" { + t.Fatalf("rateLimitRequest(nil) = %#v, want empty request", req) + } +} + func TestIsJSONContentType(t *testing.T) { tests := []struct { ct string @@ -234,6 +252,48 @@ func TestPaginateAll_PageLimitStopsPagination(t *testing.T) { } } +func TestPaginateAll_ReturnsMailRateLimitAfterFirstPage(t *testing.T) { + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + apiCalls := 0 + ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { + apiCalls++ + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": apiCalls}}, + "has_more": true, + "page_token": "next", + }, + }), nil + })) + + _, err := ac.PaginateAll(context.Background(), RawApiRequest{ + Method: "GET", + URL: "/open-apis/mail/v1/user_mailboxes/me/messages", + As: core.AsBot, + }, PaginationOptions{PageLimit: 10, PageDelay: -1}) + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } + if apiCalls != 1 { + t.Fatalf("api calls = %d, want 1", apiCalls) + } +} + func TestPaginateAll_NaturalEndClearsPageToken(t *testing.T) { apiCalls := 0 rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { @@ -464,6 +524,60 @@ func TestDoStream_TransportFailureSplitsSubtype(t *testing.T) { } } +func TestDoStream_MailRateLimitRunsBeforeTokenAndHTTP(t *testing.T) { + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + httpCalls := 0 + ac := &APIClient{ + HTTP: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + httpCalls++ + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + })}, + Config: &core.CliConfig{AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu}, + } + resolver := &countingTokenResolver{} + ac.Credential = credential.NewCredentialProvider(nil, nil, resolver, nil) + + newReq := func() *larkcore.ApiReq { + return &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1", + } + } + resp, err := ac.DoStream(context.Background(), newReq(), core.AsBot) + if err != nil { + t.Fatalf("first DoStream err = %v", err) + } + resp.Body.Close() + + _, err = ac.DoStream(context.Background(), newReq(), core.AsBot) + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } + if httpCalls != 1 { + t.Fatalf("http calls = %d, want 1", httpCalls) + } + if resolver.count != 1 { + t.Fatalf("token resolutions = %d, want 1", resolver.count) + } +} + // failingTokenResolver always returns TokenUnavailableError, exercising the // auth/credential failure path through resolveAccessToken. type failingTokenResolver struct{} @@ -582,6 +696,81 @@ func TestDoSDKRequest_AuthFailureSurfacesTypedAuthenticationError(t *testing.T) } } +func TestDoSDKRequest_MailRateLimitRunsBeforeTokenAndSDK(t *testing.T) { + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + httpCalls := 0 + ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { + httpCalls++ + return jsonResponse(map[string]interface{}{"code": 0, "msg": "ok"}), nil + })) + resolver := &countingTokenResolver{} + ac.Credential = credential.NewCredentialProvider(nil, nil, resolver, nil) + + newReq := func() *larkcore.ApiReq { + return &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1", + } + } + if _, err := ac.DoSDKRequest(context.Background(), newReq(), core.AsBot); err != nil { + t.Fatalf("first DoSDKRequest err = %v", err) + } + _, err := ac.DoSDKRequest(context.Background(), newReq(), core.AsBot) + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } + if httpCalls != 1 { + t.Fatalf("http calls = %d, want 1", httpCalls) + } + if resolver.count != 1 { + t.Fatalf("token resolutions = %d, want 1", resolver.count) + } +} + +func TestDoSDKRequest_NonMailAndUnconfiguredMailStillSend(t *testing.T) { + rule := ratelimit.Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, time.Now)) + defer restore() + + httpCalls := 0 + ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { + httpCalls++ + return jsonResponse(map[string]interface{}{"code": 0, "msg": "ok"}), nil + })) + + for _, path := range []string{ + "/open-apis/contact/v3/users/u1", + "/open-apis/mail/v1/user_mailboxes/me/settings", + } { + if _, err := ac.DoSDKRequest(context.Background(), &larkcore.ApiReq{HttpMethod: http.MethodGet, ApiPath: path}, core.AsBot); err != nil { + t.Fatalf("DoSDKRequest(%s) err = %v", path, err) + } + } + if httpCalls != 2 { + t.Fatalf("http calls = %d, want 2", httpCalls) + } +} + // TestDoSDKRequest_TransportFailureWrapsAsNetwork pins that genuinely untyped // SDK transport errors get the typed network classification via WrapDoAPIError. // io.ErrUnexpectedEOF from a RoundTripper surfaces through net/http as a diff --git a/internal/ratelimit/canonical.go b/internal/ratelimit/canonical.go new file mode 100644 index 000000000..618226dd4 --- /dev/null +++ b/internal/ratelimit/canonical.go @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "net/url" + "regexp" + "strings" + + "github.com/larksuite/cli/internal/validate" +) + +type compiledRule struct { + rule *Rule + method string + pattern *regexp.Regexp +} + +var compiledBuiltinRules = compileRules(builtinRules) + +func Canonicalize(method, rawPath string) (string, *Rule, bool) { + method = normalizeMethod(method) + path := normalizePath(rawPath) + for _, entry := range compiledBuiltinRules { + if method != entry.method { + continue + } + if path == entry.rule.CanonicalPath || entry.pattern.MatchString(path) { + return entry.rule.CanonicalPath, entry.rule, true + } + } + return "", nil, false +} + +func normalizePath(rawPath string) string { + rawPath = strings.TrimSpace(rawPath) + if rawPath == "" { + return "" + } + if parsed, err := url.Parse(rawPath); err == nil && (parsed.IsAbs() || parsed.Host != "") { + if escaped := parsed.EscapedPath(); escaped != "" { + return escaped + } + return parsed.Path + } + return validate.StripQueryFragment(rawPath) +} + +func compileRules(rules []Rule) []compiledRule { + compiled := make([]compiledRule, 0, len(rules)) + for i := range rules { + rule := &rules[i] + compiled = append(compiled, compiledRule{ + rule: rule, + method: normalizeMethod(rule.Method), + pattern: regexp.MustCompile("^" + canonicalPattern(rule.CanonicalPath) + "$"), + }) + } + return compiled +} + +func canonicalPattern(canonical string) string { + segments := strings.Split(canonical, "/") + for i, segment := range segments { + if strings.HasPrefix(segment, ":") { + segments[i] = "[^/]+" + continue + } + segments[i] = regexp.QuoteMeta(segment) + } + return strings.Join(segments, "/") +} diff --git a/internal/ratelimit/canonical_test.go b/internal/ratelimit/canonical_test.go new file mode 100644 index 000000000..54ac5221c --- /dev/null +++ b/internal/ratelimit/canonical_test.go @@ -0,0 +1,150 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "strings" + "testing" + "time" +) + +func TestCanonicalize(t *testing.T) { + tests := []struct { + name string + method string + path string + want string + wantMatch bool + }{ + { + name: "non mail path is ignored", + method: "GET", + path: "/open-apis/contact/v3/users/u1", + wantMatch: false, + }, + { + name: "unconfigured mail path is ignored", + method: "GET", + path: "/open-apis/mail/v1/user_mailboxes/me/settings", + wantMatch: false, + }, + { + name: "full URL query and fragment canonicalize", + method: "get", + path: "https://open.feishu.cn/open-apis/mail/v1/user_mailboxes/me/messages/msg_1?format=full#body", + want: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + wantMatch: true, + }, + { + name: "relative path query and fragment canonicalize", + method: "GET", + path: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1?format=metadata#body", + want: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + wantMatch: true, + }, + { + name: "concrete batch get canonicalizes", + method: "POST", + path: "/open-apis/mail/v1/user_mailboxes/me/messages/batch_get", + want: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/batch_get", + wantMatch: true, + }, + { + name: "SDK template path matches exactly", + method: "GET", + path: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + want: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + wantMatch: true, + }, + { + name: "method mismatch is ignored", + method: "POST", + path: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, rule, ok := Canonicalize(tt.method, tt.path) + if ok != tt.wantMatch { + t.Fatalf("match = %v, want %v", ok, tt.wantMatch) + } + if !ok { + return + } + if got != tt.want { + t.Fatalf("canonical = %q, want %q", got, tt.want) + } + if rule == nil { + t.Fatal("expected rule") + } + }) + } +} + +func TestBuiltinRulesUseConfirmedThreshold(t *testing.T) { + type threshold struct { + window time.Duration + limit int + } + want := map[string][]threshold{ + "GET /open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id": { + {window: time.Minute, limit: 100}, + }, + "POST /open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/batch_get": { + {window: time.Second, limit: 10}, + }, + "GET /open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages": { + {window: time.Second, limit: 10}, + }, + "POST /open-apis/mail/v1/user_mailboxes/:user_mailbox_id/search": { + {window: time.Minute, limit: 1000}, + {window: time.Second, limit: 50}, + }, + } + seen := make(map[string][]threshold) + for _, rule := range builtinRules { + key := rule.Method + " " + rule.CanonicalPath + if _, ok := want[key]; !ok { + t.Fatalf("unexpected builtin rule %s", key) + } + if rule.Window <= 0 { + t.Fatalf("%s window must be positive", key) + } + if rule.Limit <= 0 { + t.Fatalf("%s limit must be positive", key) + } + if rule.Scope != ScopeApp { + t.Fatalf("%s scope = %q, want %q", key, rule.Scope, ScopeApp) + } + if rule.Method == "" { + t.Fatalf("%s method must not be empty", key) + } + if !strings.HasPrefix(rule.CanonicalPath, "/open-apis/mail/") { + t.Fatalf("%s canonical path must be under /open-apis/mail/", key) + } + seen[key] = append(seen[key], threshold{window: rule.Window, limit: rule.Limit}) + } + if len(builtinRules) != 5 { + t.Fatalf("builtinRules len = %d, want 5", len(builtinRules)) + } + for key, thresholds := range want { + if len(seen[key]) != len(thresholds) { + t.Fatalf("missing builtin rule %s", key) + } + for _, threshold := range thresholds { + found := false + for _, got := range seen[key] { + if got == threshold { + found = true + break + } + } + if !found { + t.Fatalf("%s missing threshold window=%s limit=%d; got %#v", key, threshold.window, threshold.limit, seen[key]) + } + } + } +} diff --git a/internal/ratelimit/errors.go b/internal/ratelimit/errors.go new file mode 100644 index 000000000..d0a634346 --- /dev/null +++ b/internal/ratelimit/errors.go @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "errors" + "fmt" + "math" + "time" + + "github.com/larksuite/cli/internal/output" +) + +const errorSource = "local_ratelimit" + +func newRateLimitError(rule *Rule, retryAfter time.Duration) error { + retryAfterMs := RetryAfterMs(retryAfter) + windowSeconds := int(math.Ceil(rule.Window.Seconds())) + msg := fmt.Sprintf("local rate limit: %s %s exceeded %d requests per %ds", + rule.Method, rule.CanonicalPath, rule.Limit, windowSeconds) + return output.ErrAPI(output.LarkErrRateLimit, msg, map[string]any{ + "source": errorSource, + "retry_after_ms": retryAfterMs, + "method": rule.Method, + "api_path": rule.CanonicalPath, + }) +} + +func IsLocalRateLimit(err error) bool { + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil { + return false + } + if exitErr.Detail.Type != "rate_limit" || exitErr.Detail.Code != output.LarkErrRateLimit { + return false + } + detail, ok := exitErr.Detail.Detail.(map[string]any) + return ok && detail["source"] == errorSource +} + +func RetryAfterMs(d time.Duration) int64 { + if d <= 0 { + return 1 + } + ms := int64(math.Ceil(float64(d) / float64(time.Millisecond))) + if ms < 1 { + return 1 + } + return ms +} diff --git a/internal/ratelimit/limiter.go b/internal/ratelimit/limiter.go new file mode 100644 index 000000000..3d85d9578 --- /dev/null +++ b/internal/ratelimit/limiter.go @@ -0,0 +1,191 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "sync" + "time" + + "github.com/larksuite/cli/internal/core" +) + +type Request struct { + Brand core.LarkBrand + AppID string + Method string + Path string +} + +type Limiter struct { + store *stateFile + compiled []compiledRule + now func() time.Time +} + +var ( + defaultLimiterMu sync.Mutex + defaultLimiterOverride *Limiter +) + +func newLimiter(store *stateFile, rules []Rule, now func() time.Time) *Limiter { + return newLimiterWithCompiled(store, compileRules(rules), now) +} + +func newLimiterWithCompiled(store *stateFile, compiled []compiledRule, now func() time.Time) *Limiter { + if now == nil { + now = time.Now + } + return &Limiter{store: store, compiled: compiled, now: now} +} + +func NewLimiterForDir(dir string, rules []Rule, now func() time.Time) *Limiter { + return newLimiter(newStateFile(dir), rules, now) +} + +func Allow(ctx context.Context, req Request) error { + defaultLimiterMu.Lock() + limiter := defaultLimiterOverride + defaultLimiterMu.Unlock() + if limiter == nil { + limiter = newLimiterWithCompiled(defaultStateFile(), compiledBuiltinRules, time.Now) + } + return limiter.Allow(ctx, req) +} + +func SetDefaultLimiterForTest(limiter *Limiter) func() { + defaultLimiterMu.Lock() + previous := defaultLimiterOverride + defaultLimiterOverride = limiter + defaultLimiterMu.Unlock() + return func() { + defaultLimiterMu.Lock() + defaultLimiterOverride = previous + defaultLimiterMu.Unlock() + } +} + +func (l *Limiter) Allow(ctx context.Context, req Request) error { + if l == nil { + return nil + } + if l.store == nil { + l.store = defaultStateFile() + } + rules, canonical, ok := l.match(req.Method, req.Path) + if !ok { + return nil + } + rules = usableRules(rules) + if len(rules) == 0 { + return nil + } + if req.AppID == "" { + return nil + } + nowFn := l.now + if nowFn == nil { + nowFn = time.Now + } + key := buildKey(req.Brand, req.AppID, normalizeMethod(req.Method), canonical) + return l.store.WithKeyLock(ctx, key, func(entries []int64) ([]int64, error) { + now := nowFn() + cutoff := now.Add(-maxMatchedWindow(rules)).UnixMilli() + kept := prune(entries, cutoff) + for _, rule := range rules { + if retryAfter, limited := retryAfterForRule(kept, now, rule); limited { + return nil, newRateLimitError(rule, retryAfter) + } + } + return append(kept, now.UnixMilli()), nil + }) +} + +func usableRules(rules []*Rule) []*Rule { + // Fail open on local rule mistakes: bad built-in rules must not block user requests. + usable := rules[:0] + for _, rule := range rules { + if rule.Scope != ScopeApp || rule.Limit <= 0 || rule.Window <= 0 { + continue + } + usable = append(usable, rule) + } + return usable +} + +func (l *Limiter) match(method, rawPath string) ([]*Rule, string, bool) { + method = normalizeMethod(method) + path := normalizePath(rawPath) + if path == "" { + return nil, "", false + } + var rules []*Rule + var canonical string + for _, entry := range l.compiled { + if method != entry.method { + continue + } + if path == entry.rule.CanonicalPath || entry.pattern.MatchString(path) { + if canonical == "" { + canonical = entry.rule.CanonicalPath + } + if entry.rule.CanonicalPath == canonical { + rules = append(rules, entry.rule) + } + } + } + if len(rules) == 0 { + return nil, "", false + } + return rules, canonical, true +} + +func prune(values []int64, cutoff int64) []int64 { + if len(values) == 0 { + return nil + } + kept := values[:0] + for _, value := range values { + if value > cutoff { + kept = append(kept, value) + } + } + return kept +} + +func buildKey(brand core.LarkBrand, appID, method, canonicalPath string) string { + sum := sha256.Sum256([]byte(string(brand) + "\x00" + appID + "\x00" + method + "\x00" + canonicalPath)) + return hex.EncodeToString(sum[:]) +} + +func maxMatchedWindow(rules []*Rule) time.Duration { + var max time.Duration + for _, rule := range rules { + if rule.Window > max { + max = rule.Window + } + } + return max +} + +func retryAfterForRule(entries []int64, now time.Time, rule *Rule) (time.Duration, bool) { + cutoff := now.Add(-rule.Window).UnixMilli() + count := 0 + var oldest int64 + for _, entry := range entries { + if entry <= cutoff { + continue + } + if count == 0 || entry < oldest { + oldest = entry + } + count++ + } + if count < rule.Limit { + return 0, false + } + return time.UnixMilli(oldest).Add(rule.Window).Sub(now), true +} diff --git a/internal/ratelimit/limiter_test.go b/internal/ratelimit/limiter_test.go new file mode 100644 index 000000000..80f0cc81a --- /dev/null +++ b/internal/ratelimit/limiter_test.go @@ -0,0 +1,329 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" +) + +func testRule() Rule { + return Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: 2 * time.Second, + Limit: 2, + Scope: ScopeApp, + } +} + +func testRequest() Request { + return Request{ + Brand: core.BrandFeishu, + AppID: "app-1", + Method: "GET", + Path: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1", + } +} + +func TestLimiterAllowsUntilLimitThenReturnsRateLimit(t *testing.T) { + now := time.Unix(100, 0) + limiter := NewLimiterForDir(t.TempDir(), []Rule{testRule()}, func() time.Time { return now }) + ctx := context.Background() + + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatalf("first check err = %v", err) + } + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatalf("second check err = %v", err) + } + err := limiter.Allow(ctx, testRequest()) + if err == nil { + t.Fatal("expected rate limit error") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected ExitError, got %T", err) + } + if exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" || exitErr.Detail.Code != output.LarkErrRateLimit { + t.Fatalf("unexpected detail: %#v", exitErr.Detail) + } + detail, ok := exitErr.Detail.Detail.(map[string]any) + if !ok { + t.Fatalf("expected detail map, got %T", exitErr.Detail.Detail) + } + if got := detail["retry_after_ms"]; got != int64(2000) { + t.Fatalf("retry_after_ms = %v, want 2000", got) + } + if got := detail["source"]; got != "local_ratelimit" { + t.Fatalf("source = %v, want local_ratelimit", got) + } +} + +func TestLimiterPrunesWindowAndAllowsAgain(t *testing.T) { + now := time.Unix(100, 0) + limiter := NewLimiterForDir(t.TempDir(), []Rule{testRule()}, func() time.Time { return now }) + ctx := context.Background() + + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatal(err) + } + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatal(err) + } + now = now.Add(2*time.Second + time.Millisecond) + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatalf("check after window err = %v", err) + } +} + +func TestLimiterChecksMultipleRulesForSameKey(t *testing.T) { + now := time.Unix(100, 0) + shortRule := testRule() + shortRule.Window = time.Second + shortRule.Limit = 2 + longRule := testRule() + longRule.Window = 10 * time.Second + longRule.Limit = 3 + limiter := NewLimiterForDir(t.TempDir(), []Rule{shortRule, longRule}, func() time.Time { return now }) + ctx := context.Background() + + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatal(err) + } + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatal(err) + } + if err := limiter.Allow(ctx, testRequest()); !IsLocalRateLimit(err) { + t.Fatalf("third check err = %v, want local rate_limit", err) + } + + now = now.Add(time.Second + time.Millisecond) + if err := limiter.Allow(ctx, testRequest()); err != nil { + t.Fatalf("check after short window err = %v", err) + } + err := limiter.Allow(ctx, testRequest()) + if !IsLocalRateLimit(err) { + t.Fatalf("check after long window limit err = %v, want local rate_limit", err) + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected ExitError, got %T", err) + } + detail, ok := exitErr.Detail.Detail.(map[string]any) + if !ok { + t.Fatalf("expected detail map, got %T", exitErr.Detail.Detail) + } + if got := detail["retry_after_ms"]; got != int64(8999) { + t.Fatalf("retry_after_ms = %v, want 8999", got) + } +} + +func TestLimiterKeyIsIsolatedByBrandAppMethodAndCanonicalPath(t *testing.T) { + now := time.Unix(100, 0) + rules := []Rule{ + { + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: 2 * time.Second, + Limit: 1, + Scope: ScopeApp, + }, + { + Method: "POST", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/search", + Window: 2 * time.Second, + Limit: 1, + Scope: ScopeApp, + }, + } + limiter := NewLimiterForDir(t.TempDir(), rules, func() time.Time { return now }) + ctx := context.Background() + base := testRequest() + if err := limiter.Allow(ctx, base); err != nil { + t.Fatal(err) + } + + cases := []Request{ + {Brand: core.BrandLark, AppID: base.AppID, Method: base.Method, Path: base.Path}, + {Brand: base.Brand, AppID: "app-2", Method: base.Method, Path: base.Path}, + {Brand: base.Brand, AppID: base.AppID, Method: "POST", Path: "/open-apis/mail/v1/user_mailboxes/me/search"}, + } + for _, req := range cases { + if err := limiter.Allow(ctx, req); err != nil { + t.Fatalf("isolated request %#v err = %v", req, err) + } + } + if err := limiter.Allow(ctx, base); err == nil { + t.Fatal("expected original key to remain limited") + } +} + +func TestLimiterNoopsForNonMailAndUnconfiguredMail(t *testing.T) { + limiter := NewLimiterForDir(t.TempDir(), []Rule{testRule()}, time.Now) + ctx := context.Background() + requests := []Request{ + {Method: "GET", Path: "/open-apis/contact/v3/users/u1"}, + {Method: "GET", Path: "/open-apis/mail/v1/user_mailboxes/me/settings"}, + } + for _, req := range requests { + if err := limiter.Allow(ctx, req); err != nil { + t.Fatalf("request %#v err = %v", req, err) + } + } +} + +func TestLimiterSkipsMissingAppID(t *testing.T) { + now := time.Unix(100, 0) + rule := testRule() + rule.Limit = 1 + limiter := NewLimiterForDir(t.TempDir(), []Rule{rule}, func() time.Time { return now }) + req := testRequest() + req.AppID = "" + + for i := 0; i < 2; i++ { + if err := limiter.Allow(context.Background(), req); err != nil { + t.Fatalf("allow %d err = %v", i+1, err) + } + } +} + +func TestLimiterSupportsConfiguredNonMailRules(t *testing.T) { + rule := Rule{ + Method: "GET", + CanonicalPath: "/open-apis/contact/v3/users/:user_id", + Window: time.Second, + Limit: 1, + Scope: ScopeApp, + } + limiter := NewLimiterForDir(t.TempDir(), []Rule{rule}, time.Now) + req := Request{ + Brand: core.BrandFeishu, + AppID: "app-1", + Method: "GET", + Path: "/open-apis/contact/v3/users/u1", + } + if err := limiter.Allow(context.Background(), req); err != nil { + t.Fatalf("first allow err = %v", err) + } + if err := limiter.Allow(context.Background(), req); !IsLocalRateLimit(err) { + t.Fatalf("second allow err = %v, want local rate_limit", err) + } +} + +func TestLimiterSkipsUnsupportedScope(t *testing.T) { + rule := testRule() + rule.Scope = Scope("user") + limiter := NewLimiterForDir(t.TempDir(), []Rule{rule}, time.Now) + for i := 0; i < 2; i++ { + if err := limiter.Allow(context.Background(), testRequest()); err != nil { + t.Fatalf("allow %d err = %v", i+1, err) + } + } +} + +func TestLimiterSkipsInvalidRuleParams(t *testing.T) { + cases := []struct { + name string + mutate func(*Rule) + }{ + { + name: "zero limit", + mutate: func(rule *Rule) { + rule.Limit = 0 + }, + }, + { + name: "negative limit", + mutate: func(rule *Rule) { + rule.Limit = -1 + }, + }, + { + name: "zero window", + mutate: func(rule *Rule) { + rule.Window = 0 + }, + }, + { + name: "negative window", + mutate: func(rule *Rule) { + rule.Window = -time.Second + }, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + rule := testRule() + tt.mutate(&rule) + limiter := NewLimiterForDir(t.TempDir(), []Rule{rule}, time.Now) + for i := 0; i < 2; i++ { + if err := limiter.Allow(context.Background(), testRequest()); err != nil { + t.Fatalf("allow %d err = %v", i+1, err) + } + } + }) + } +} + +func TestLimiterUsesValidRulesWhenMixedWithInvalidRules(t *testing.T) { + now := time.Unix(100, 0) + invalid := testRule() + invalid.Scope = Scope("user") + valid := testRule() + valid.Limit = 1 + limiter := NewLimiterForDir(t.TempDir(), []Rule{invalid, valid}, func() time.Time { return now }) + + if err := limiter.Allow(context.Background(), testRequest()); err != nil { + t.Fatalf("first allow err = %v", err) + } + if err := limiter.Allow(context.Background(), testRequest()); !IsLocalRateLimit(err) { + t.Fatalf("second allow err = %v, want local rate_limit", err) + } +} + +func TestLimiterConcurrentSameKeyAllowsOnlyLimit(t *testing.T) { + now := time.Unix(100, 0) + rule := testRule() + rule.Limit = 3 + limiter := NewLimiterForDir(t.TempDir(), []Rule{rule}, func() time.Time { return now }) + + const workers = 20 + errs := make(chan error, workers) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + errs <- limiter.Allow(context.Background(), testRequest()) + }() + } + wg.Wait() + close(errs) + + allowed := 0 + limited := 0 + for err := range errs { + switch { + case err == nil: + allowed++ + case IsLocalRateLimit(err): + limited++ + default: + t.Fatalf("unexpected error: %v", err) + } + } + if allowed != rule.Limit { + t.Fatalf("allowed = %d, want %d", allowed, rule.Limit) + } + if limited != workers-rule.Limit { + t.Fatalf("limited = %d, want %d", limited, workers-rule.Limit) + } +} diff --git a/internal/ratelimit/process_test.go b/internal/ratelimit/process_test.go new file mode 100644 index 000000000..629aff17e --- /dev/null +++ b/internal/ratelimit/process_test.go @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "os/exec" + "strings" + "testing" + "time" + + "github.com/larksuite/cli/internal/output" +) + +func TestCrossProcessWindowIsShared(t *testing.T) { + if os.Getenv("MAIL_RATELIMIT_HELPER") == "1" { + runProcessHelper() + return + } + + dir := t.TempDir() + const workers = 5 + type helper struct { + cmd *exec.Cmd + stdout bytes.Buffer + stderr bytes.Buffer + } + helpers := make([]*helper, 0, workers) + for i := 0; i < workers; i++ { + cmd := exec.Command(os.Args[0], "-test.run=TestCrossProcessWindowIsShared") + h := &helper{cmd: cmd} + cmd.Env = append(os.Environ(), + "MAIL_RATELIMIT_HELPER=1", + "MAIL_RATELIMIT_STATE_DIR="+dir, + ) + cmd.Stdout = &h.stdout + cmd.Stderr = &h.stderr + if err := cmd.Start(); err != nil { + t.Fatalf("start helper %d: %v", i, err) + } + helpers = append(helpers, h) + } + allowed := 0 + limited := 0 + for i, h := range helpers { + err := h.cmd.Wait() + out := h.stdout.String() + if err != nil { + t.Fatalf("helper %d failed: %v\nstdout=%s\nstderr=%s", i, err, out, h.stderr.String()) + } + switch strings.TrimSpace(out) { + case "allowed": + allowed++ + case "limited": + limited++ + default: + t.Fatalf("unexpected helper output %q", out) + } + } + if allowed > 2 { + t.Fatalf("allowed helpers = %d, want <= 2", allowed) + } + if limited == 0 { + t.Fatal("expected at least one helper to be locally rate limited") + } +} + +func runProcessHelper() { + dir := os.Getenv("MAIL_RATELIMIT_STATE_DIR") + if dir == "" { + fmt.Fprint(os.Stderr, "missing MAIL_RATELIMIT_STATE_DIR") + os.Exit(2) + } + limiter := NewLimiterForDir(dir, []Rule{testRule()}, func() time.Time { return time.Unix(100, 0) }) + err := limiter.Allow(context.Background(), testRequest()) + if err == nil { + fmt.Print("allowed") + os.Exit(0) + } + var exitErr *output.ExitError + if errors.As(err, &exitErr) && exitErr.Detail != nil && exitErr.Detail.Type == "rate_limit" { + fmt.Print("limited") + os.Exit(0) + } + fmt.Fprintf(os.Stderr, "unexpected helper err: %v", err) + os.Exit(2) +} diff --git a/internal/ratelimit/rules.go b/internal/ratelimit/rules.go new file mode 100644 index 000000000..75f9129da --- /dev/null +++ b/internal/ratelimit/rules.go @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "strings" + "time" +) + +type Scope string + +const ScopeApp Scope = "app" + +const ( + tier3Window = time.Minute + tier3Limit = 100 + tier4MinuteWindow = time.Minute + tier4MinuteLimit = 1000 + tier4SecondWindow = time.Second + tier4SecondLimit = 50 + tier7Window = time.Second + tier7Limit = 10 +) + +type Rule struct { + Method string + CanonicalPath string + Window time.Duration + Limit int + Scope Scope +} + +var builtinRules = []Rule{ + // Online mail API YAML rateLimit tiers: + // tier 3 = 100/min, tier 4 = 1000/min and 50/s, tier 7 = 10/s. + { + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: tier3Window, + Limit: tier3Limit, + Scope: ScopeApp, + }, + { + Method: "POST", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/batch_get", + Window: tier7Window, + Limit: tier7Limit, + Scope: ScopeApp, + }, + { + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages", + Window: tier7Window, + Limit: tier7Limit, + Scope: ScopeApp, + }, + { + Method: "POST", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/search", + Window: tier4MinuteWindow, + Limit: tier4MinuteLimit, + Scope: ScopeApp, + }, + { + Method: "POST", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/search", + Window: tier4SecondWindow, + Limit: tier4SecondLimit, + Scope: ScopeApp, + }, +} + +func maxRuleWindow(rules []Rule) time.Duration { + var max time.Duration + for _, rule := range rules { + if rule.Window > max { + max = rule.Window + } + } + return max +} + +func normalizeMethod(method string) string { + return strings.ToUpper(strings.TrimSpace(method)) +} diff --git a/internal/ratelimit/store.go b/internal/ratelimit/store.go new file mode 100644 index 000000000..4112306d5 --- /dev/null +++ b/internal/ratelimit/store.go @@ -0,0 +1,219 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/lockfile" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" +) + +const stateVersion = 1 + +var ( + lockWaitTimeout = 5 * time.Second + lockRetryInitial = 10 * time.Millisecond + lockRetryMax = 50 * time.Millisecond + stateGCInterval = time.Hour + stateGCGrace = time.Minute + stateGCMu sync.Mutex + stateGCLastRun = map[string]time.Time{} +) + +type stateFile struct { + dir string +} + +type keyState struct { + Version int `json:"version"` + Entries []int64 `json:"entries"` +} + +var stateKeyRe = regexp.MustCompile(`^[a-f0-9]{64}$`) + +func defaultStateFile() *stateFile { + // Runtime dir is workspace-aware by design; local rate limit state is shared + // across processes in the same lark-cli workspace/profile runtime. + return newStateFile(filepath.Join(core.GetRuntimeDir(), "ratelimit")) +} + +func newStateFile(dir string) *stateFile { + return &stateFile{dir: dir} +} + +func (s *stateFile) WithKeyLock(ctx context.Context, key string, fn func([]int64) ([]int64, error)) error { + if key == "" { + return internalStateError("rate limit key is empty") + } + if err := vfs.MkdirAll(s.dir, 0700); err != nil { + return internalStateError("create rate limit state dir: %v", err) + } + s.maybeGC(time.Now()) + statePath, lockPath := s.pathsForKey(key) + lock, err := s.lockKey(ctx, lockPath) + if err != nil { + return err + } + defer lock.Unlock() //nolint:errcheck // best-effort release; operation result is already decided. + + entries, err := s.loadKeyState(statePath) + if err != nil { + return err + } + entries, err = fn(entries) + if err != nil { + return err + } + return s.saveKeyState(statePath, entries) +} + +func (s *stateFile) lockKey(ctx context.Context, lockPath string) (*lockfile.LockFile, error) { + lock := lockfile.New(lockPath) + lockCtx, cancel := context.WithTimeout(ctx, lockWaitTimeout) + defer cancel() + delay := lockRetryInitial + for { + if err := lock.TryLock(); err != nil { + if !errors.Is(err, lockfile.ErrHeld) { + return nil, internalStateError("lock rate limit state: %v", err) + } + select { + case <-lockCtx.Done(): + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, internalStateError("timed out waiting for rate limit state lock") + case <-time.After(delay): + if delay < lockRetryMax { + delay += lockRetryInitial + if delay > lockRetryMax { + delay = lockRetryMax + } + } + continue + } + } + return lock, nil + } +} + +func (s *stateFile) pathsForKey(key string) (string, string) { + safe := key + if !stateKeyRe.MatchString(key) { + sum := sha256.Sum256([]byte(key)) + safe = hex.EncodeToString(sum[:]) + } + return filepath.Join(s.dir, safe+".json"), filepath.Join(s.dir, safe+".lock") +} + +func (s *stateFile) maybeGC(now time.Time) { + stateGCMu.Lock() + last := stateGCLastRun[s.dir] + if !last.IsZero() && now.Sub(last) < stateGCInterval { + stateGCMu.Unlock() + return + } + stateGCLastRun[s.dir] = now + stateGCMu.Unlock() + + s.gcExpired(now) +} + +func (s *stateFile) gcExpired(now time.Time) { + entries, err := vfs.ReadDir(s.dir) + if err != nil { + return + } + maxAge := maxRuleWindow(builtinRules) + stateGCGrace + for _, entry := range entries { + name := entry.Name() + if entry.IsDir() || filepath.Ext(name) != ".json" { + continue + } + key := strings.TrimSuffix(name, ".json") + if !stateKeyRe.MatchString(key) { + continue + } + info, err := entry.Info() + if err != nil || now.Sub(info.ModTime()) < maxAge { + continue + } + statePath, lockPath := s.pathsForKey(key) + lock := lockfile.New(lockPath) + if err := lock.TryLock(); err != nil { + continue + } + func() { + defer lock.Unlock() //nolint:errcheck // best-effort cleanup. + info, err := vfs.Stat(statePath) + if err != nil || now.Sub(info.ModTime()) < maxAge { + return + } + _ = vfs.Remove(statePath) + }() + } +} + +func (s *stateFile) loadKeyState(path string) ([]int64, error) { + data, err := vfs.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, internalStateError("read rate limit state: %v", err) + } + var st keyState + if err := json.Unmarshal(data, &st); err != nil { + return nil, corruptStateError(err) + } + if st.Version != stateVersion { + return nil, corruptStateError(fmt.Errorf("unsupported version %d", st.Version)) + } + return st.Entries, nil +} + +func (s *stateFile) saveKeyState(path string, entries []int64) error { + if len(entries) == 0 { + if err := vfs.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { + return internalStateError("remove empty rate limit state: %v", err) + } + return nil + } + data, err := json.MarshalIndent(keyState{Version: stateVersion, Entries: entries}, "", " ") + if err != nil { + return internalStateError("encode rate limit state: %v", err) + } + data = append(data, '\n') + if err := validate.AtomicWrite(path, data, 0600); err != nil { + return internalStateError("write rate limit state: %v", err) + } + return nil +} + +func internalStateError(format string, args ...any) error { + return output.ErrWithHint(output.ExitInternal, "internal", + fmt.Sprintf(format, args...), + "delete ratelimit/*.json under the lark-cli runtime directory and retry") +} + +func corruptStateError(err error) error { + return output.ErrWithHint(output.ExitInternal, "internal", + fmt.Sprintf("rate limit state is invalid: %v", err), + "delete ratelimit/*.json under the lark-cli runtime directory and retry") +} diff --git a/internal/ratelimit/store_test.go b/internal/ratelimit/store_test.go new file mode 100644 index 000000000..88cc56833 --- /dev/null +++ b/internal/ratelimit/store_test.go @@ -0,0 +1,206 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package ratelimit + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/larksuite/cli/internal/lockfile" + "github.com/larksuite/cli/internal/output" +) + +func TestStoreWritesReadableStateWithRestrictivePermissions(t *testing.T) { + dir := t.TempDir() + rule := testRule() + req := testRequest() + limiter := NewLimiterForDir(dir, []Rule{rule}, time.Now) + if err := limiter.Allow(context.Background(), req); err != nil { + t.Fatalf("check err = %v", err) + } + statePath := testStatePath(dir, rule, req) + data, err := os.ReadFile(statePath) + if err != nil { + t.Fatalf("read state: %v", err) + } + if len(data) == 0 { + t.Fatal("expected non-empty state") + } + if runtime.GOOS != "windows" { + info, err := os.Stat(statePath) + if err != nil { + t.Fatalf("stat state: %v", err) + } + if got := info.Mode().Perm(); got != 0600 { + t.Fatalf("state mode = %o, want 0600", got) + } + } +} + +func TestStoreCorruptJSONReturnsInternalError(t *testing.T) { + dir := t.TempDir() + rule := testRule() + req := testRequest() + if err := os.WriteFile(testStatePath(dir, rule, req), []byte("{bad"), 0600); err != nil { + t.Fatalf("write corrupt state: %v", err) + } + limiter := NewLimiterForDir(dir, []Rule{rule}, time.Now) + err := limiter.Allow(context.Background(), req) + if err == nil { + t.Fatal("expected error") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected ExitError, got %T", err) + } + if exitErr.Code != output.ExitInternal || exitErr.Detail == nil || exitErr.Detail.Type != "internal" { + t.Fatalf("unexpected detail: %#v", exitErr.Detail) + } +} + +func TestTwoLimitersShareStateFile(t *testing.T) { + dir := t.TempDir() + now := time.Unix(100, 0) + rule := testRule() + rule.Limit = 1 + first := NewLimiterForDir(dir, []Rule{rule}, func() time.Time { return now }) + second := NewLimiterForDir(dir, []Rule{rule}, func() time.Time { return now }) + + if err := first.Allow(context.Background(), testRequest()); err != nil { + t.Fatalf("first check err = %v", err) + } + if err := second.Allow(context.Background(), testRequest()); err == nil { + t.Fatal("expected second limiter to see shared state") + } +} + +func TestStoreDeletedStateFileStartsFreshWindow(t *testing.T) { + dir := t.TempDir() + now := time.Unix(100, 0) + rule := testRule() + rule.Limit = 1 + req := testRequest() + limiter := NewLimiterForDir(dir, []Rule{rule}, func() time.Time { return now }) + + if err := limiter.Allow(context.Background(), req); err != nil { + t.Fatalf("first check err = %v", err) + } + statePath := testStatePath(dir, rule, req) + if err := os.Remove(statePath); err != nil { + t.Fatalf("remove state: %v", err) + } + if err := limiter.Allow(context.Background(), req); err != nil { + t.Fatalf("check after deleted state err = %v", err) + } + if _, err := os.Stat(statePath); err != nil { + t.Fatalf("state file should be recreated: %v", err) + } +} + +func TestDifferentKeysUseDifferentStateFiles(t *testing.T) { + dir := t.TempDir() + now := time.Unix(100, 0) + rule := testRule() + limiter := NewLimiterForDir(dir, []Rule{rule}, func() time.Time { return now }) + first := testRequest() + second := first + second.AppID = "app-2" + + if err := limiter.Allow(context.Background(), first); err != nil { + t.Fatalf("first check err = %v", err) + } + if err := limiter.Allow(context.Background(), second); err != nil { + t.Fatalf("second check err = %v", err) + } + + if _, err := os.Stat(testStatePath(dir, rule, first)); err != nil { + t.Fatalf("first state file missing: %v", err) + } + if _, err := os.Stat(testStatePath(dir, rule, second)); err != nil { + t.Fatalf("second state file missing: %v", err) + } +} + +func testStatePath(dir string, rule Rule, req Request) string { + return filepath.Join(dir, buildKey(req.Brand, req.AppID, rule.Method, rule.CanonicalPath)+".json") +} + +func TestStoreGCOldKeyStateFiles(t *testing.T) { + dir := t.TempDir() + rule := testRule() + oldReq := testRequest() + freshReq := oldReq + freshReq.AppID = "app-2" + oldPath := testStatePath(dir, rule, oldReq) + freshPath := testStatePath(dir, rule, freshReq) + writeTestKeyState(t, oldPath, []int64{1}) + writeTestKeyState(t, freshPath, []int64{2}) + + now := time.Now() + oldTime := now.Add(-maxRuleWindow(builtinRules) - stateGCGrace - time.Second) + if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil { + t.Fatalf("chtimes old state: %v", err) + } + + newStateFile(dir).gcExpired(now) + + if _, err := os.Stat(oldPath); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("old state file stat err = %v, want not exist", err) + } + if _, err := os.Stat(freshPath); err != nil { + t.Fatalf("fresh state file should remain: %v", err) + } +} + +func TestStoreLockTimesOut(t *testing.T) { + dir := t.TempDir() + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatal(err) + } + store := newStateFile(dir) + key := buildKey(testRequest().Brand, testRequest().AppID, testRule().Method, testRule().CanonicalPath) + _, lockPath := store.pathsForKey(key) + lock := lockfile.New(lockPath) + if err := lock.TryLock(); err != nil { + t.Fatalf("hold lock: %v", err) + } + defer lock.Unlock() + + oldTimeout := lockWaitTimeout + lockWaitTimeout = 20 * time.Millisecond + defer func() { lockWaitTimeout = oldTimeout }() + + err := store.WithKeyLock(context.Background(), key, func(entries []int64) ([]int64, error) { + t.Fatal("lock callback should not run") + return entries, nil + }) + if err == nil { + t.Fatal("expected lock timeout") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Code != output.ExitInternal { + t.Fatalf("err = %v, want internal ExitError", err) + } + if !strings.Contains(err.Error(), "timed out waiting") { + t.Fatalf("err = %v, want timeout message", err) + } +} + +func writeTestKeyState(t *testing.T, path string, entries []int64) { + t.Helper() + data, err := json.Marshal(keyState{Version: stateVersion, Entries: entries}) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write key state: %v", err) + } +} diff --git a/shortcuts/common/common.go b/shortcuts/common/common.go index eeb11f585..af911bfcd 100644 --- a/shortcuts/common/common.go +++ b/shortcuts/common/common.go @@ -10,6 +10,7 @@ import ( "time" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/ratelimit" "github.com/larksuite/cli/internal/util" ) @@ -166,6 +167,9 @@ func CheckApiError(w io.Writer, result interface{}, action string) bool { // HandleApiResult checks for network/API errors and returns the "data" field. func HandleApiResult(result interface{}, err error, action string) (map[string]interface{}, error) { if err != nil { + if ratelimit.IsLocalRateLimit(err) { + return nil, err + } return nil, output.Errorf(output.ExitAPI, "api_error", "%s: %s", action, err) } resultMap, _ := result.(map[string]interface{}) diff --git a/shortcuts/mail/mail_shortcut_test.go b/shortcuts/mail/mail_shortcut_test.go index 7d7ebb96c..cec56004e 100644 --- a/shortcuts/mail/mail_shortcut_test.go +++ b/shortcuts/mail/mail_shortcut_test.go @@ -5,8 +5,10 @@ package mail import ( "bytes" + "context" "encoding/base64" "encoding/json" + "errors" "os" "testing" "time" @@ -18,6 +20,8 @@ import ( "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/httpmock" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/ratelimit" "github.com/larksuite/cli/shortcuts/common" ) @@ -89,6 +93,193 @@ func encodeFixtureEMLForMailTest(raw string) string { return base64.URLEncoding.EncodeToString([]byte(raw)) } +func TestMailMessageShortcutUsesLocalMailRateLimit(t *testing.T) { + f, stdout, _, reg := mailShortcutTestFactory(t) + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + stub := &httpmock.Stub{ + Method: "GET", + URL: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1", + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "data": map[string]interface{}{ + "message": map[string]interface{}{ + "message_id": "msg_1", + "subject": "hello", + "body_plain_text": encodeFixtureEMLForMailTest("hello"), + "message_state": "READ", + }, + }, + }, + } + reg.Register(stub) + + args := []string{"+message", "--message-id", "msg_1", "--html=false", "--as", "user"} + if err := runMountedMailShortcut(t, MailMessage, args, f, stdout); err != nil { + t.Fatalf("first +message err = %v", err) + } + if len(stub.CapturedBodies) != 1 { + t.Fatalf("HTTP calls after first run = %d, want 1", len(stub.CapturedBodies)) + } + + err := runMountedMailShortcut(t, MailMessage, args, f, stdout) + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } + if len(stub.CapturedBodies) != 1 { + t.Fatalf("HTTP calls after local rate limit = %d, want 1", len(stub.CapturedBodies)) + } +} + +func TestMailMessagesShortcutUsesLocalMailRateLimit(t *testing.T) { + f, stdout, _, reg := mailShortcutTestFactory(t) + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "POST", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/batch_get", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + stub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/mail/v1/user_mailboxes/me/messages/batch_get", + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "data": map[string]interface{}{ + "messages": []interface{}{ + map[string]interface{}{ + "message_id": "msg_1", + "subject": "hello", + "body_plain_text": encodeFixtureEMLForMailTest("hello"), + "message_state": "READ", + }, + }, + }, + }, + } + reg.Register(stub) + + args := []string{"+messages", "--message-ids", "msg_1", "--html=false", "--as", "user"} + if err := runMountedMailShortcut(t, MailMessages, args, f, stdout); err != nil { + t.Fatalf("first +messages err = %v", err) + } + if len(stub.CapturedBodies) != 1 { + t.Fatalf("HTTP calls after first run = %d, want 1", len(stub.CapturedBodies)) + } + + err := runMountedMailShortcut(t, MailMessages, args, f, stdout) + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } + if len(stub.CapturedBodies) != 1 { + t.Fatalf("HTTP calls after local rate limit = %d, want 1", len(stub.CapturedBodies)) + } +} + +func TestMailTriageShortcutPreservesLocalMailRateLimit(t *testing.T) { + f, stdout, _, reg := mailShortcutTestFactory(t) + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "POST", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/search", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + stub := &httpmock.Stub{ + Method: "POST", + URL: "/open-apis/mail/v1/user_mailboxes/me/search", + Body: map[string]interface{}{ + "code": 0, + "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{}, + "has_more": false, + }, + }, + } + reg.Register(stub) + + args := []string{"+triage", "--query", "hello", "--format", "data", "--as", "user"} + if err := runMountedMailShortcut(t, MailTriage, args, f, stdout); err != nil { + t.Fatalf("first +triage err = %v", err) + } + if len(stub.CapturedBodies) != 1 { + t.Fatalf("HTTP calls after first run = %d, want 1", len(stub.CapturedBodies)) + } + + err := runMountedMailShortcut(t, MailTriage, args, f, stdout) + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } + if len(stub.CapturedBodies) != 1 { + t.Fatalf("HTTP calls after local rate limit = %d, want 1", len(stub.CapturedBodies)) + } +} + +func TestMailWatchFetchMessageUsesLocalMailRateLimit(t *testing.T) { + f, _, _, _ := mailShortcutTestFactory(t) + now := time.Unix(100, 0) + rule := ratelimit.Rule{ + Method: "GET", + CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id", + Window: 2 * time.Second, + Limit: 1, + Scope: ratelimit.ScopeApp, + } + restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now })) + defer restore() + + if err := ratelimit.Allow(context.Background(), ratelimit.Request{ + Brand: core.BrandFeishu, + AppID: "test-app", + Method: "GET", + Path: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1", + }); err != nil { + t.Fatalf("pre-consume rate limit slot err = %v", err) + } + + runtime := common.TestNewRuntimeContextForAPI(context.Background(), &cobra.Command{Use: "test"}, mailTestConfig(), f, core.AsUser) + _, err := fetchMessageForWatch(runtime, "me", "msg_1", "metadata") + if err == nil { + t.Fatal("expected local rate limit") + } + var exitErr *output.ExitError + if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" { + t.Fatalf("err = %v, want rate_limit ExitError", err) + } +} + // chdirTemp changes the working directory to a fresh temp directory and // restores it when the test finishes. This allows SafeInputPath/SafeOutputPath // to accept relative file paths created in the temp directory. diff --git a/shortcuts/mail/mail_triage.go b/shortcuts/mail/mail_triage.go index 08507247b..173df1b9a 100644 --- a/shortcuts/mail/mail_triage.go +++ b/shortcuts/mail/mail_triage.go @@ -14,6 +14,7 @@ import ( "time" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/ratelimit" "github.com/larksuite/cli/shortcuts/common" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" ) @@ -1106,6 +1107,9 @@ func doJSONAPI(runtime *common.RuntimeContext, req *larkcore.ApiReq, action stri return nil, handleErr } } else { + if ratelimit.IsLocalRateLimit(err) { + return nil, err + } lastErr = output.Errorf(output.ExitAPI, "api_error", "%s: %s", action, err) if attempt == triageAPIRetries { return nil, lastErr