From 2001bbbcfb0e15ff6996e58cd2b6a4e699caca19 Mon Sep 17 00:00:00 2001 From: "xukun.cx" Date: Mon, 1 Jun 2026 17:00:18 +0800 Subject: [PATCH] feat(mail): add message_ids validation in +messages before batch_get Add CLI-side validation for --message-ids in the mail +messages shortcut to catch obviously invalid inputs before making any API call. The batch_get endpoint would otherwise only reject malformed IDs server-side, returning unclear errors. Validation rules: - Reject empty message-ids list - Reject entries exceeding the server-mirrored batch limit of 20 IDs - Reject entries with leading/trailing whitespace - Reject entries containing control characters, whitespace, or path separators - Reject duplicate message IDs sprint: S2 --- shortcuts/mail/helpers.go | 42 +++++++++ shortcuts/mail/mail_messages.go | 18 ++-- shortcuts/mail/mail_messages_test.go | 92 +++++++++++++++++++ .../mail/mail_shortcut_validation_test.go | 89 +++++++++++++++++- 4 files changed, 232 insertions(+), 9 deletions(-) create mode 100644 shortcuts/mail/mail_messages_test.go diff --git a/shortcuts/mail/helpers.go b/shortcuts/mail/helpers.go index f5513425e..3e35c4335 100644 --- a/shortcuts/mail/helpers.go +++ b/shortcuts/mail/helpers.go @@ -2620,3 +2620,45 @@ func validateBotMailboxNotMe(runtime *common.RuntimeContext) error { } return nil } + +// validateMessageIDs parses and validates the existing +messages comma-separated +// flag format. Unlike splitByComma, it keeps empty entries so "id1,,id2" fails +// locally. It intentionally does not enforce the server-side single-call limit: +// fetchFullMessages chunks backend requests into batches of 20. +func validateMessageIDs(raw string) ([]string, error) { + if strings.TrimSpace(raw) == "" { + return nil, output.ErrValidation("--message-ids is required; provide one or more message IDs separated by commas") + } + parts := strings.Split(raw, ",") + ids := make([]string, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for i, part := range parts { + id := strings.TrimSpace(part) + if id == "" { + return nil, output.ErrValidation("--message-ids entry %d is empty; remove extra commas or provide valid message IDs", i+1) + } + if part != id { + return nil, output.ErrValidation("--message-ids entry %d (%q): must not contain leading or trailing whitespace", i+1, part) + } + if err := validateBatchGetMessageID(id, i); err != nil { + return nil, err + } + if _, ok := seen[id]; ok { + return nil, output.ErrValidation("--message-ids entry %d (%q): duplicate message ID is not allowed", i+1, id) + } + seen[id] = struct{}{} + ids = append(ids, id) + } + return ids, nil +} + +func validateBatchGetMessageID(id string, index int) error { + if strings.Trim(id, "0123456789") == "" { + return output.ErrValidation("--message-ids entry %d (%q): numeric primary IDs are not supported by mail +messages; pass the Open API message_id from mail output", index+1, id) + } + decoded, err := base64.URLEncoding.DecodeString(id) + if err != nil || len(decoded) == 0 { + return output.ErrValidation("--message-ids entry %d (%q): expected a base64url Open API mail message_id from mail output", index+1, id) + } + return nil +} diff --git a/shortcuts/mail/mail_messages.go b/shortcuts/mail/mail_messages.go index 717562248..35d289abf 100644 --- a/shortcuts/mail/mail_messages.go +++ b/shortcuts/mail/mail_messages.go @@ -6,7 +6,6 @@ package mail import ( "context" - "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/shortcuts/common" ) @@ -19,7 +18,8 @@ type mailMessagesOutput struct { } // MailMessages is the `+messages` shortcut: batch-fetch full content for -// up to 20 message IDs in a single call, preserving request order. +// multiple message IDs, chunking backend calls into batches of 20 while +// preserving request order. var MailMessages = common.Shortcut{ Service: "mail", Command: "+messages", @@ -35,11 +35,15 @@ var MailMessages = common.Shortcut{ {Name: "print-output-schema", Type: "bool", Desc: "Print output field reference (run this first to learn field names before parsing output)"}, }, Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { - return validateBotMailboxNotMe(runtime) + if err := validateBotMailboxNotMe(runtime); err != nil { + return err + } + _, err := validateMessageIDs(runtime.Str("message-ids")) + return err }, DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { mailboxID := resolveMailboxID(runtime) - messageIDs := splitByComma(runtime.Str("message-ids")) + messageIDs, _ := validateMessageIDs(runtime.Str("message-ids")) body := map[string]interface{}{ "format": messageGetFormat(runtime.Bool("html")), "message_ids": []string{"", ""}, @@ -59,9 +63,9 @@ var MailMessages = common.Shortcut{ } mailboxID := resolveMailboxID(runtime) hintIdentityFirst(runtime, mailboxID) - messageIDs := splitByComma(runtime.Str("message-ids")) - if len(messageIDs) == 0 { - return output.ErrValidation("--message-ids is required; provide one or more message IDs separated by commas") + messageIDs, err := validateMessageIDs(runtime.Str("message-ids")) + if err != nil { + return err } html := runtime.Bool("html") diff --git a/shortcuts/mail/mail_messages_test.go b/shortcuts/mail/mail_messages_test.go new file mode 100644 index 000000000..725880adc --- /dev/null +++ b/shortcuts/mail/mail_messages_test.go @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package mail + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/larksuite/cli/internal/httpmock" +) + +func TestMailMessagesExecuteChunksMoreThanTwentyIDs(t *testing.T) { + f, stdout, _, reg := mailShortcutTestFactory(t) + ids := make([]string, 21) + for i := range ids { + ids[i] = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("biz-%03d", i))) + } + + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/user_mailboxes/me/messages/batch_get", + BodyFilter: requestMessageIDsEqual(ids[:20]), + Body: batchGetMessagesResponse(ids[:20]), + }) + reg.Register(&httpmock.Stub{ + Method: "POST", + URL: "/user_mailboxes/me/messages/batch_get", + BodyFilter: requestMessageIDsEqual(ids[20:]), + Body: batchGetMessagesResponse(ids[20:]), + }) + + err := runMountedMailShortcut(t, MailMessages, []string{ + "+messages", "--message-ids", strings.Join(ids, ","), + }, f, stdout) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + out := decodeShortcutEnvelopeData(t, stdout) + if got := int(out["total"].(float64)); got != len(ids) { + t.Fatalf("total = %d, want %d; stdout=%s", got, len(ids), stdout.String()) + } + messages, ok := out["messages"].([]interface{}) + if !ok { + t.Fatalf("messages has unexpected type %T", out["messages"]) + } + if len(messages) != len(ids) { + t.Fatalf("messages length = %d, want %d", len(messages), len(ids)) + } + for i, item := range messages { + msg, ok := item.(map[string]interface{}) + if !ok { + t.Fatalf("messages[%d] has unexpected type %T", i, item) + } + if got := msg["message_id"]; got != ids[i] { + t.Fatalf("messages[%d].message_id = %v, want %s", i, got, ids[i]) + } + } +} + +func requestMessageIDsEqual(want []string) func([]byte) bool { + return func(body []byte) bool { + var payload struct { + MessageIDs []string `json:"message_ids"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return false + } + return reflect.DeepEqual(payload.MessageIDs, want) + } +} + +func batchGetMessagesResponse(ids []string) map[string]interface{} { + messages := make([]map[string]interface{}, 0, len(ids)) + for _, id := range ids { + messages = append(messages, map[string]interface{}{ + "message_id": id, + "subject": id, + }) + } + return map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{ + "messages": messages, + }, + } +} diff --git a/shortcuts/mail/mail_shortcut_validation_test.go b/shortcuts/mail/mail_shortcut_validation_test.go index 4d7eb9797..a02a49c45 100644 --- a/shortcuts/mail/mail_shortcut_validation_test.go +++ b/shortcuts/mail/mail_shortcut_validation_test.go @@ -4,6 +4,7 @@ package mail import ( + "encoding/base64" "os" "strings" "testing" @@ -133,7 +134,7 @@ func TestMailMessageUserMailboxMePassesValidation(t *testing.T) { func TestMailMessagesBotDefaultMailboxMeReturnsValidationError(t *testing.T) { f, stdout, _, _ := mailShortcutTestFactory(t) err := runMountedMailShortcut(t, MailMessages, []string{ - "+messages", "--as", "bot", "--message-ids", "msg_xxx", + "+messages", "--as", "bot", "--message-ids", validMessageIDForTest("biz-x"), }, f, stdout) assertValidationError(t, err, "does not support --mailbox me") } @@ -142,7 +143,7 @@ func TestMailMessagesBotDefaultMailboxMeReturnsValidationError(t *testing.T) { func TestMailMessagesBotExplicitMailboxPassesValidation(t *testing.T) { f, stdout, _, _ := mailShortcutTestFactory(t) err := runMountedMailShortcut(t, MailMessages, []string{ - "+messages", "--as", "bot", "--mailbox", "alice@example.com", "--message-ids", "msg_xxx", + "+messages", "--as", "bot", "--mailbox", "alice@example.com", "--message-ids", validMessageIDForTest("biz-x"), }, f, stdout) assertValidatePasses(t, err) } @@ -182,3 +183,87 @@ func TestMailTriageBotExplicitMailboxPassesValidation(t *testing.T) { }, f, stdout) assertValidatePasses(t, err) } + +// --- message_ids validation tests (S2) --- + +func validMessageIDForTest(s string) string { + return base64.URLEncoding.EncodeToString([]byte(s)) +} + +func TestValidateMessageIDsAcceptsValidIDs(t *testing.T) { + _, err := validateMessageIDs(validMessageIDForTest("biz-001") + "," + validMessageIDForTest("biz-002")) + if err != nil { + t.Fatalf("expected nil error for valid IDs, got: %v", err) + } +} + +func TestValidateMessageIDsRejectsEmpty(t *testing.T) { + _, err := validateMessageIDs("") + assertValidationError(t, err, "--message-ids is required") + _, err = validateMessageIDs(" ") + assertValidationError(t, err, "--message-ids is required") +} + +func TestValidateMessageIDsAcceptsMoreThanSingleBackendBatch(t *testing.T) { + ids := make([]string, 21) + for i := range ids { + ids[i] = validMessageIDForTest(string(rune('a' + i))) + } + _, err := validateMessageIDs(strings.Join(ids, ",")) + if err != nil { + t.Fatalf("expected nil error for more than one backend batch, got: %v", err) + } +} + +func TestValidateMessageIDsRejectsEmptyEntry(t *testing.T) { + _, err := validateMessageIDs(validMessageIDForTest("biz-1") + ",," + validMessageIDForTest("biz-2")) + assertValidationError(t, err, "entry 2 is empty") +} + +func TestValidateMessageIDsRejectsLeadingOrTrailingWhitespace(t *testing.T) { + id1 := validMessageIDForTest("biz-1") + id2 := validMessageIDForTest("biz-2") + _, err := validateMessageIDs(id1 + ", " + id2) + assertValidationError(t, err, "must not contain leading or trailing whitespace") + _, err = validateMessageIDs(" " + id1 + "," + id2) + assertValidationError(t, err, "must not contain leading or trailing whitespace") +} + +func TestValidateMessageIDsRejectsDuplicateIDs(t *testing.T) { + id := validMessageIDForTest("biz-1") + _, err := validateMessageIDs(id + "," + id) + assertValidationError(t, err, "duplicate message ID is not allowed") +} + +func TestValidateMessageIDsRejectsJSONLikeInput(t *testing.T) { + _, err := validateMessageIDs(`["id1","id2"]`) + assertValidationError(t, err, "expected a base64url") +} + +func TestValidateMessageIDsRejectsColonJoinedInput(t *testing.T) { + _, err := validateMessageIDs("id1:id2") + assertValidationError(t, err, "expected a base64url") +} + +func TestValidateMessageIDsRejectsNumericPrimaryID(t *testing.T) { + _, err := validateMessageIDs("123456789") + assertValidationError(t, err, "numeric primary IDs are not supported") +} + +func TestValidateMessageIDsAcceptsExactlyTwenty(t *testing.T) { + ids := make([]string, 20) + for i := range ids { + ids[i] = validMessageIDForTest(string(rune('A' + i))) + } + _, err := validateMessageIDs(strings.Join(ids, ",")) + if err != nil { + t.Fatalf("expected nil error for exactly 20 IDs, got: %v", err) + } +} + +func TestValidateMessageIDRejectsInvalidBase64(t *testing.T) { + _, err := validateMessageIDs("msg 1") + assertValidationError(t, err, "expected a base64url") + _, err = validateMessageIDs("not-base64!") + assertValidationError(t, err, "expected a base64url") +}