Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/client/api_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"

"github.com/larksuite/cli/errs"
"github.com/larksuite/cli/internal/ratelimit"
)

// rawAPIJSONHint guides users when an SDK or response body parse fails. The
Expand All @@ -30,6 +31,9 @@ func WrapDoAPIError(err error) error {
if err == nil {
return nil
}
if ratelimit.IsLocalRateLimit(err) {
return err
}

// (1) Pass-through any typed errs.* error.
if _, ok := errs.ProblemOf(err); ok {
Expand Down
15 changes: 15 additions & 0 deletions internal/client/api_errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,21 @@ func TestWrapDoAPIError_Nil(t *testing.T) {
}
}

func TestWrapDoAPIError_PreservesLocalMailRateLimit(t *testing.T) {
original := output.ErrAPI(output.LarkErrRateLimit, "rate limited", map[string]any{
"source": "local_ratelimit",
"retry_after_ms": 100,
})
err := WrapDoAPIError(original)
if err != original {
t.Fatalf("WrapDoAPIError returned %p, want original %p", err, original)
}
var exitErr *output.ExitError
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
t.Fatalf("err = %v, want preserved rate_limit ExitError", err)
}
}

// ─────────────────────────────────────────────────────────────────────────────
// WrapJSONResponseParseError: typed error contract.
//
Expand Down
26 changes: 25 additions & 1 deletion internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/larksuite/cli/internal/errclass"
"github.com/larksuite/cli/internal/errcompat"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/ratelimit"
"github.com/larksuite/cli/internal/util"
)

Expand Down Expand Up @@ -130,6 +131,10 @@ func (c *APIClient) buildApiReq(request RawApiRequest) (*larkcore.ApiReq, []lark
func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as core.Identity, extraOpts ...larkcore.RequestOptionFunc) (*larkcore.ApiResp, error) {
var opts []larkcore.RequestOptionFunc

if err := ratelimit.Allow(ctx, c.rateLimitRequest(req)); err != nil {
return nil, err
}

token, err := c.resolveAccessToken(ctx, as)
if err != nil {
// WrapDoAPIError is idempotent on already-classified errors:
Expand Down Expand Up @@ -166,6 +171,10 @@ func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as c
func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core.Identity, opts ...Option) (*http.Response, error) {
cfg := buildConfig(opts)

if err := ratelimit.Allow(ctx, c.rateLimitRequest(req)); err != nil {
return nil, err
}

// Resolve auth
token, err := c.resolveAccessToken(ctx, as)
if err != nil {
Expand Down Expand Up @@ -250,6 +259,21 @@ func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core.
return resp, nil
}

func (c *APIClient) rateLimitRequest(req *larkcore.ApiReq) ratelimit.Request {
if req == nil {
return ratelimit.Request{}
}
limitReq := ratelimit.Request{
Method: req.HttpMethod,
Path: req.ApiPath,
}
if c != nil && c.Config != nil {
limitReq.Brand = c.Config.Brand
limitReq.AppID = c.Config.AppID
}
return limitReq
}

func streamLogID(header http.Header) string {
logID := strings.TrimSpace(header.Get(larkcore.HttpHeaderKeyLogId))
if logID == "" {
Expand Down Expand Up @@ -379,7 +403,7 @@ func (c *APIClient) paginateLoop(ctx context.Context, request RawApiRequest, opt
ExtraOpts: request.ExtraOpts,
})
if err != nil {
if page == 1 {
if page == 1 || ratelimit.IsLocalRateLimit(err) {
return nil, err
}
fmt.Fprintf(c.ErrOut, "[page %d] error, stopping pagination\n", page)
Expand Down
189 changes: 189 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/credential"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/ratelimit"
)

// roundTripFunc is an adapter to use a function as http.RoundTripper.
Expand All @@ -48,6 +49,15 @@ func (s *staticTokenResolver) ResolveToken(_ context.Context, _ credential.Token
return &credential.TokenResult{Token: "test-token"}, nil
}

type countingTokenResolver struct {
count int
}

func (s *countingTokenResolver) ResolveToken(_ context.Context, _ credential.TokenSpec) (*credential.TokenResult, error) {
s.count++
return &credential.TokenResult{Token: "test-token"}, nil
}

// newTestAPIClient creates an APIClient with a mock HTTP transport.
func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Buffer) {
t.Helper()
Expand All @@ -68,6 +78,14 @@ func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Bu
}, errBuf
}

func TestRateLimitRequestNilNoops(t *testing.T) {
ac := &APIClient{Config: &core.CliConfig{AppID: "test-app", Brand: core.BrandFeishu}}
req := ac.rateLimitRequest(nil)
if req.Brand != "" || req.AppID != "" || req.Method != "" || req.Path != "" {
t.Fatalf("rateLimitRequest(nil) = %#v, want empty request", req)
}
}

func TestIsJSONContentType(t *testing.T) {
tests := []struct {
ct string
Expand Down Expand Up @@ -234,6 +252,48 @@ func TestPaginateAll_PageLimitStopsPagination(t *testing.T) {
}
}

func TestPaginateAll_ReturnsMailRateLimitAfterFirstPage(t *testing.T) {
now := time.Unix(100, 0)
rule := ratelimit.Rule{
Method: "GET",
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages",
Window: 2 * time.Second,
Limit: 1,
Scope: ratelimit.ScopeApp,
}
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now }))
defer restore()

apiCalls := 0
ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
apiCalls++
return jsonResponse(map[string]interface{}{
"code": 0, "msg": "ok",
"data": map[string]interface{}{
"items": []interface{}{map[string]interface{}{"id": apiCalls}},
"has_more": true,
"page_token": "next",
},
}), nil
}))

_, err := ac.PaginateAll(context.Background(), RawApiRequest{
Method: "GET",
URL: "/open-apis/mail/v1/user_mailboxes/me/messages",
As: core.AsBot,
}, PaginationOptions{PageLimit: 10, PageDelay: -1})
if err == nil {
t.Fatal("expected local rate limit")
}
var exitErr *output.ExitError
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
t.Fatalf("err = %v, want rate_limit ExitError", err)
}
if apiCalls != 1 {
t.Fatalf("api calls = %d, want 1", apiCalls)
}
}

func TestPaginateAll_NaturalEndClearsPageToken(t *testing.T) {
apiCalls := 0
rt := roundTripFunc(func(req *http.Request) (*http.Response, error) {
Expand Down Expand Up @@ -464,6 +524,60 @@ func TestDoStream_TransportFailureSplitsSubtype(t *testing.T) {
}
}

func TestDoStream_MailRateLimitRunsBeforeTokenAndHTTP(t *testing.T) {
now := time.Unix(100, 0)
rule := ratelimit.Rule{
Method: "GET",
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id",
Window: 2 * time.Second,
Limit: 1,
Scope: ratelimit.ScopeApp,
}
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now }))
defer restore()

httpCalls := 0
ac := &APIClient{
HTTP: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
httpCalls++
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
})},
Config: &core.CliConfig{AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu},
}
resolver := &countingTokenResolver{}
ac.Credential = credential.NewCredentialProvider(nil, nil, resolver, nil)

newReq := func() *larkcore.ApiReq {
return &larkcore.ApiReq{
HttpMethod: http.MethodGet,
ApiPath: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1",
}
}
resp, err := ac.DoStream(context.Background(), newReq(), core.AsBot)
if err != nil {
t.Fatalf("first DoStream err = %v", err)
}
resp.Body.Close()

_, err = ac.DoStream(context.Background(), newReq(), core.AsBot)
if err == nil {
t.Fatal("expected local rate limit")
}
var exitErr *output.ExitError
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
t.Fatalf("err = %v, want rate_limit ExitError", err)
}
if httpCalls != 1 {
t.Fatalf("http calls = %d, want 1", httpCalls)
}
if resolver.count != 1 {
t.Fatalf("token resolutions = %d, want 1", resolver.count)
}
}

// failingTokenResolver always returns TokenUnavailableError, exercising the
// auth/credential failure path through resolveAccessToken.
type failingTokenResolver struct{}
Expand Down Expand Up @@ -582,6 +696,81 @@ func TestDoSDKRequest_AuthFailureSurfacesTypedAuthenticationError(t *testing.T)
}
}

func TestDoSDKRequest_MailRateLimitRunsBeforeTokenAndSDK(t *testing.T) {
now := time.Unix(100, 0)
rule := ratelimit.Rule{
Method: "GET",
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id",
Window: 2 * time.Second,
Limit: 1,
Scope: ratelimit.ScopeApp,
}
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now }))
defer restore()

httpCalls := 0
ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
httpCalls++
return jsonResponse(map[string]interface{}{"code": 0, "msg": "ok"}), nil
}))
resolver := &countingTokenResolver{}
ac.Credential = credential.NewCredentialProvider(nil, nil, resolver, nil)

newReq := func() *larkcore.ApiReq {
return &larkcore.ApiReq{
HttpMethod: http.MethodGet,
ApiPath: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1",
}
}
if _, err := ac.DoSDKRequest(context.Background(), newReq(), core.AsBot); err != nil {
t.Fatalf("first DoSDKRequest err = %v", err)
}
_, err := ac.DoSDKRequest(context.Background(), newReq(), core.AsBot)
if err == nil {
t.Fatal("expected local rate limit")
}
var exitErr *output.ExitError
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
t.Fatalf("err = %v, want rate_limit ExitError", err)
}
if httpCalls != 1 {
t.Fatalf("http calls = %d, want 1", httpCalls)
}
if resolver.count != 1 {
t.Fatalf("token resolutions = %d, want 1", resolver.count)
}
}

func TestDoSDKRequest_NonMailAndUnconfiguredMailStillSend(t *testing.T) {
rule := ratelimit.Rule{
Method: "GET",
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id",
Window: time.Second,
Limit: 1,
Scope: ratelimit.ScopeApp,
}
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, time.Now))
defer restore()

httpCalls := 0
ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
httpCalls++
return jsonResponse(map[string]interface{}{"code": 0, "msg": "ok"}), nil
}))

for _, path := range []string{
"/open-apis/contact/v3/users/u1",
"/open-apis/mail/v1/user_mailboxes/me/settings",
} {
if _, err := ac.DoSDKRequest(context.Background(), &larkcore.ApiReq{HttpMethod: http.MethodGet, ApiPath: path}, core.AsBot); err != nil {
t.Fatalf("DoSDKRequest(%s) err = %v", path, err)
}
}
if httpCalls != 2 {
t.Fatalf("http calls = %d, want 2", httpCalls)
}
}

// TestDoSDKRequest_TransportFailureWrapsAsNetwork pins that genuinely untyped
// SDK transport errors get the typed network classification via WrapDoAPIError.
// io.ErrUnexpectedEOF from a RoundTripper surfaces through net/http as a
Expand Down
Loading
Loading