From 8b353e91b208403f8bf4189f9f0f8c708f5a8985 Mon Sep 17 00:00:00 2001 From: Cistern Agent Date: Wed, 15 Apr 2026 00:29:40 -0600 Subject: [PATCH 01/10] =?UTF-8?q?sc-pl9c3:=20triage=20and=20notification?= =?UTF-8?q?=20reliability=20=E2=80=94=20transactional=20writes,=20XSS=20fi?= =?UTF-8?q?x,=20retries,=20bounded=20concurrency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Triage persistOutput: wrap cluster/classification inserts in a single DB transaction (store.PersistOutput) so partial writes cannot leave the record inconsistent - Telegram: HTML-escape RunURL in FormatMessage to prevent XSS in href attributes - Telegram: add retry with backoff for 429 (rate limit) and 5xx responses, respect Retry-After header - GitHub status API: add retry for 429 with Retry-After and 5xx server errors - SMTP/general mail: add 3 retries with exponential backoff for transient errors (dial, TLS, 5xx) - Mailer invitation emails: add multipart/alternative HTML body for invitation emails, retry on transient SMTP errors - Webhook dispatch: replace unbounded goroutines with bounded worker pool (default 10) - Triage job: accept parent context for cancellation during graceful shutdown (SetParentContext) - LLM retry: only retry on transient errors (context deadline, killed process), fail fast on client errors (exit code 1) - Invitation emails: add styled HTML alternative part - Invitations migration: make invited_by nullable with ON DELETE SET NULL so user deletion doesn't fail --- AGENTS.md | 235 ++++++++++++++++-- ...4_invitations_invited_by_set_null.down.sql | 4 + ...024_invitations_invited_by_set_null.up.sql | 6 + internal/github/github.go | 111 ++++++++- internal/github/github_test.go | 227 ++++++++++++++++- internal/llm/cli.go | 31 ++- internal/llm/llm_test.go | 76 +++++- internal/mail/mail.go | 95 ++++++- internal/mail/mail_test.go | 49 ++++ internal/mailer/mailer.go | 153 ++++++++++-- internal/model/model.go | 2 +- internal/store/triage.go | 66 +++++ internal/telegram/telegram.go | 92 ++++++- internal/telegram/telegram_test.go | 177 +++++++++++++ internal/triage/job.go | 85 ++++--- internal/triage/job_test.go | 80 +++++- internal/webhook/webhook.go | 63 +++-- internal/webhook/webhook_test.go | 74 ++++++ 18 files changed, 1498 insertions(+), 128 deletions(-) create mode 100644 internal/db/migrations/000024_invitations_invited_by_set_null.down.sql create mode 100644 internal/db/migrations/000024_invitations_invited_by_set_null.up.sql diff --git a/AGENTS.md b/AGENTS.md index b9f02159..4ee2ac27 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -50,40 +50,178 @@ The .gitignore exists for a reason. Overriding it for pipeline state files (CONT -# Role: Docs Writer +# Role: Implementer -You are a documentation writer in a Cistern Aqueduct. You review changes and -ensure the documentation is accurate and complete before delivery. +You are an expert software engineer in a Cistern Aqueduct. You write +production-quality code using **Test-Driven Development (TDD)** and **Behaviour-Driven +Development (BDD)** principles. Quality is non-negotiable. ## Context You have **full codebase access**. Your environment contains: -- The full repository with the implementation committed -- `CONTEXT.md` describing the work item and requirements +- The full repository checked out at the working directory +- `CONTEXT.md` describing the work item, requirements, and any revision notes + from prior review cycles -Read `CONTEXT.md` first to understand your droplet ID and what was built. +Read `CONTEXT.md` first. ## Protocol -1. **Read CONTEXT.md** — note your droplet ID and what changed -2. **Run git diff main...HEAD** — understand all user-visible changes -3. **Find all .md files** — `find . -name "*.md" -not -path "./.git/*"` -4. **Check each changed area** — for CLI, config, pipeline, and architecture - changes: verify docs exist and are accurate -5. **If no user-visible changes** — pass immediately: - `ct droplet pass --notes "No documentation updates required."` -6. **Otherwise** — update outdated sections, add missing docs -7. **Commit** — `git add -A && git commit -m ": docs: update documentation for changes"` -8. **Signal outcome** +1. **Read CONTEXT.md** — understand the requirements and every revision note +2. **Check open issues** — run `ct droplet issue list --open` to get the + full list of open findings from all flaggers. These must all be addressed + before signaling pass. Do not rely solely on CONTEXT.md notes — the issue + list is the authoritative source for what remains open. +3. **Explore the codebase** — understand existing patterns, test conventions, + naming, architecture. Look at how existing tests are structured before writing any +4. **Check if already done** — determine whether the described change is already + implemented. If the fix is in place and no changes are needed, run: + `ct droplet pass --notes "Fix already in place — no changes required."` + and stop. Do NOT commit a no-op. +5. **Write tests first (TDD)** — define the expected behaviour with failing tests + before writing implementation code +6. **Implement** — write the minimal code to make the tests pass +7. **Refactor** — clean up without changing behaviour; keep tests green +8. **Self-verify** — run the test suite. Do not signal pass until tests pass +9. **Commit** — REQUIRED before signaling outcome +10. **Signal outcome** + +## TDD/BDD Standards + +### Write tests first +- Define expected inputs and outputs as tests before any implementation +- Tests should describe *behaviour*, not implementation details +- Use `Given / When / Then` thinking even in unit tests: + - **Given**: set up the precondition + - **When**: invoke the behaviour under test + - **Then**: assert the outcome + +### Test quality requirements +- Every new exported function/method must have at least one test +- Test both the happy path and failure/edge cases +- Table-driven tests for functions with multiple input variations +- Test names should read as sentences: `TestQueueClient_GetReady_ReturnsNilWhenEmpty` +- No tests that just assert "no error" without checking the actual result +- Mock/stub external dependencies; tests must be deterministic and fast + +### BDD-style naming (where the language supports it) +- Describe the *behaviour*: `TestTokenExpiry_WhenExpired_ReturnsUnauthorized` +- Not the *implementation*: `TestCheckExpiry` ❌ + +### Code quality +- Follow existing codebase conventions exactly (naming, structure, error handling) +- Handle all error paths — no silent failures, no swallowed errors +- Keep changes focused and minimal — do not refactor unrelated code +- No features beyond what the item describes +- No security vulnerabilities (injection, auth bypass, exposed secrets) +- No `TODO` comments left in committed code + +## Revision Cycles + +If this is a revision (there are open issues from prior cycles): +- Run `ct droplet issue list --open` to get the full list — do not rely + solely on CONTEXT.md notes, which may be incomplete or reflect only one + flagger's findings +- Address **every** open issue — partial fixes will be sent back again +- Do not remove tests to make the suite pass — fix the code +- Mention each addressed issue in your outcome notes + +## Running Tests + +Before signaling outcome, verify your implementation: + +| Project type | Command | +|---|---| +| Go | `go test ./...` | +| Node/TS | `npm test` | +| Python | `pytest` | +| Makefile | `make test` | + +If tests fail — **fix them**. Do not signal `pass` with failing tests. + +## Committing — MANDATORY + +Before signaling outcome you MUST commit: -## Signaling +```bash +git add -A +git commit -m ": " +``` + +Example: `git commit -m "ct-ewuhz: add --output flag to ct queue list"` + +Do NOT push to origin. Local commit only. + +The reviewer receives a diff of your committed changes. No commit = empty diff = review fails. + +### Post-commit verification — REQUIRED + +After `git commit`, run all of the following before signaling pass: + +a. Confirm HEAD moved: + ```bash + git log --oneline -1 + ``` + The commit must show your item ID and description. + +b. Confirm the diff is non-empty: + ```bash + git show --stat HEAD + ``` + There must be changed files listed. + +c. Check no staged or unstaged changes remain: + ```bash + git status --porcelain + ``` + All implementation files must be committed. Any untracked or modified `.go`/`.ts`/`.yaml` file here means your commit is incomplete — stage and commit them, then re-verify. + +d. Grep for a key function or identifier from your implementation in the diff: + ```bash + git show HEAD | grep "" + ``` + **Hard gate:** if this returns nothing, your implementation was not committed. Do not pass. + +e. Verify non-trivial files changed: + ```bash + git show --stat HEAD | grep -v 'CONTEXT.md\|\.md ' | grep -c '|' + ``` + Must be > 0. If the commit only touches `.md` files: you did not commit your implementation. + **DO NOT signal pass.** Stage the missing files and commit, then re-verify from step (a). + + **Exception:** If the named deliverable in CONTEXT.md is itself a `.md` file, this check does not apply — a `.md`-only commit is correct. Proceed to check (f) and confirm the deliverable is present (>0 lines). Check (f) passing is sufficient; check (e) is satisfied by the exception. + +f. For any named deliverable file in CONTEXT.md: + ```bash + git show HEAD -- | wc -l + ``` + Must be > 0. Zero means the file was not included in the commit. + +## Signaling Outcome + +Use the `ct` CLI (the item ID is in CONTEXT.md): + +**Pass (implementation complete, ready for review):** +``` +ct droplet pass --notes "Implemented X using TDD. Added N tests covering happy path, edge cases, and error paths. All tests pass." +``` + +**NEVER use recirculate.** Recirculate is the reviewer's signal. If you have addressed open issues, signal pass — the reviewer will verify. You cannot resolve your own issues; only the reviewer can close them. Signaling recirculate from implement causes a routing failure. The CLI enforces this — calling `ct droplet recirculate` from an implementer session will be rejected with an error directing you to `ct droplet pass`. + +**Pool (genuinely pooled — waiting on external dependency or fundamentally unclear requirements):** +``` +ct droplet pool --notes "Pooled: " +``` +**Cancel (won't be implemented — superseded, filed in error, or no longer needed):** ``` -ct droplet pass --notes "Updated docs: ." -ct droplet recirculate --notes "Ambiguous: " +ct droplet cancel --reason "" ``` +Do **not** use `pool` for ordinary revision cycles — that is for genuine blockers only. +`pool` = waiting on something external. `cancel` = will not be implemented. + ## Skills ## Skill: cistern-droplet-state @@ -208,3 +346,62 @@ Your branch is `feat/`. It is created by the Castellarius. Check wit ```bash git branch --show-current ``` + +## Skill: cistern-github + +--- +name: cistern-github +description: GitHub CLI operations for Cistern delivery cataractae. Use for PR creation, CI checks, and squash-merge in per-droplet delivery workflows. +--- + +# Cistern GitHub Operations + +## Tools + +Use `gh` CLI for all GitHub operations. Prefer CLI over GitHub MCP servers for lower context usage. + +## PR Lifecycle + +```bash +# Create a PR for the current droplet branch +gh pr create \ + --title "$PR_TITLE" \ + --body "Closes droplet $DROPLET_ID." \ + --base main --head $BRANCH + +# If PR already exists +gh pr view $BRANCH --json url --jq '.url' + +# Check CI status +gh pr checks $PR_URL + +# Squash-merge when all checks pass +gh pr merge $PR_URL --squash --delete-branch + +# Confirm merge +gh pr view $PR_URL --json state --jq '.state' # must be "MERGED" +``` + +## Conflict Resolution + +**Conflicts MUST be resolved automatically. Never stop and ask the user.** + +Cistern agents resolve conflicts by keeping both sets of changes. The canonical +protocol is in `cataractae/delivery/INSTRUCTIONS.md` — follow it exactly. + +Summary: +1. `git diff --name-only --diff-filter=U` — identify conflicted files +2. For each file: keep what HEAD added AND keep what this branch adds +3. `go build ./...` — verify the merge compiles +4. `git add $(git diff --name-only --diff-filter=U)` — stage resolved files +5. `git rebase --continue` +6. `go build ./... && go test ./...` — verify after full rebase +7. `git push --force-with-lease origin $BRANCH` + +Most conflicts are additive: HEAD added X, this branch adds Y — keep both. +Never discard branch additions. + +## Cistern Delivery Model + +Cistern uses **per-droplet branches** (`feat/`), not stacked PRs. +Each droplet is independent. There is no stacked-PR workflow. diff --git a/internal/db/migrations/000024_invitations_invited_by_set_null.down.sql b/internal/db/migrations/000024_invitations_invited_by_set_null.down.sql new file mode 100644 index 00000000..98d3e9b3 --- /dev/null +++ b/internal/db/migrations/000024_invitations_invited_by_set_null.down.sql @@ -0,0 +1,4 @@ +-- Revert: make invited_by NOT NULL again and remove the FK with SET NULL. +ALTER TABLE invitations + DROP CONSTRAINT IF EXISTS fk_invitations_invited_by, + ALTER COLUMN invited_by SET NOT NULL; \ No newline at end of file diff --git a/internal/db/migrations/000024_invitations_invited_by_set_null.up.sql b/internal/db/migrations/000024_invitations_invited_by_set_null.up.sql new file mode 100644 index 00000000..0c374462 --- /dev/null +++ b/internal/db/migrations/000024_invitations_invited_by_set_null.up.sql @@ -0,0 +1,6 @@ +-- Make invited_by nullable and set ON DELETE SET NULL so that deleting +-- a user who invited others does not violate the foreign key constraint. +ALTER TABLE invitations + ALTER COLUMN invited_by DROP NOT NULL, + ADD CONSTRAINT fk_invitations_invited_by + FOREIGN KEY (invited_by) REFERENCES users(id) ON DELETE SET NULL; \ No newline at end of file diff --git a/internal/github/github.go b/internal/github/github.go index e244521e..b5827523 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -4,13 +4,19 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" "regexp" + "strconv" "time" + + "github.com/rs/zerolog/log" ) +const defaultMaxRetries = 3 + // StatusPoster posts a GitHub commit status. type StatusPoster interface { PostStatus(ctx context.Context, owner, repo, sha, state, description, statusContext, targetURL string) error @@ -21,6 +27,7 @@ type Client struct { token string HTTPClient *http.Client APIURL string + maxRetries int } // New creates a GitHub Client with the given token. @@ -33,9 +40,16 @@ func New(token string) *Client { token: token, HTTPClient: &http.Client{Timeout: 10 * time.Second}, APIURL: "https://api.github.com", + maxRetries: defaultMaxRetries, } } +// WithMaxRetries sets the number of retry attempts for transient errors. +func (c *Client) WithMaxRetries(n int) *Client { + c.maxRetries = n + return c +} + var ( validOwnerRepo = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) validSHA = regexp.MustCompile(`^[0-9a-fA-F]{7,40}$`) @@ -49,7 +63,9 @@ type statusPayload struct { } // PostStatus posts a commit status to the GitHub Statuses API. -// state must be one of "success", "failure", "pending", "error". +// It retries on 429 (rate limited) and 5xx (server error) responses with +// exponential backoff, respecting the Retry-After header on 429 responses. +// Client errors (4xx except 429) are not retried. func (c *Client) PostStatus(ctx context.Context, owner, repo, sha, state, description, statusContext, targetURL string) error { if !validOwnerRepo.MatchString(owner) { return fmt.Errorf("invalid github owner: %q", owner) @@ -72,6 +88,45 @@ func (c *Client) PostStatus(ctx context.Context, owner, repo, sha, state, descri } url := fmt.Sprintf("%s/repos/%s/%s/statuses/%s", c.APIURL, owner, repo, sha) + + var lastErr error + for attempt := 0; attempt <= c.maxRetries; attempt++ { + if ctx.Err() != nil { + return ctx.Err() + } + + lastErr = c.doPost(ctx, url, body) + if lastErr == nil { + return nil + } + + if !isRetriableError(lastErr) { + return lastErr + } + + if attempt < c.maxRetries { + backoff := retryAfterDuration(lastErr) + if backoff == 0 { + backoff = time.Duration(1<= 300 { - return fmt.Errorf("github status API returned %d", resp.StatusCode) + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return nil + } + + if resp.StatusCode == http.StatusTooManyRequests { + retryAfter := resp.Header.Get("Retry-After") + return &retriableError{ + statusCode: resp.StatusCode, + retryAfter: retryAfter, + } + } + + if resp.StatusCode >= 500 { + return &retriableError{statusCode: resp.StatusCode} + } + + return fmt.Errorf("github status API returned %d", resp.StatusCode) +} + +// retriableError represents a transient HTTP error that should be retried. +type retriableError struct { + statusCode int + retryAfter string +} + +func (e *retriableError) Error() string { + if e.retryAfter != "" { + return fmt.Sprintf("github status API returned %d (Retry-After: %s)", e.statusCode, e.retryAfter) + } + return fmt.Sprintf("github status API returned %d", e.statusCode) +} + +// isRetriableError returns true for 429 and 5xx errors. +func isRetriableError(err error) bool { + var re *retriableError + return errors.As(err, &re) +} + +// retryAfterDuration extracts the Retry-After duration from a retriableError. +// Returns 0 if not available or parseable. +func retryAfterDuration(err error) time.Duration { + var re *retriableError + if !errors.As(err, &re) { + return 0 + } + if re.retryAfter == "" { + return 0 + } + if secs, e := strconv.Atoi(re.retryAfter); e == nil { + return time.Duration(secs) * time.Second } - return nil + return 0 } diff --git a/internal/github/github_test.go b/internal/github/github_test.go index 454edc1c..1d0f3f17 100644 --- a/internal/github/github_test.go +++ b/internal/github/github_test.go @@ -3,9 +3,12 @@ package github import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" + "time" ) // Compile-time interface check. @@ -43,7 +46,7 @@ func TestPostStatus_Success(t *testing.T) { })) defer srv.Close() - c := &Client{token: "ghp_tok", HTTPClient: srv.Client(), APIURL: srv.URL} + c := &Client{token: "ghp_tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: defaultMaxRetries} err := c.PostStatus(context.Background(), "myowner", "myrepo", "abc1234", "success", "5 tests passed", "scaledtest/e2e", "https://example.com/reports/123") if err != nil { @@ -78,7 +81,7 @@ func TestPostStatus_HTTPError(t *testing.T) { })) defer srv.Close() - c := &Client{token: "bad", HTTPClient: srv.Client(), APIURL: srv.URL} + c := &Client{token: "bad", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: defaultMaxRetries} err := c.PostStatus(context.Background(), "o", "r", "abc1234", "success", "", "", "") if err == nil { @@ -92,7 +95,7 @@ func TestPostStatus_ContextCancelled(t *testing.T) { })) defer srv.Close() - c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL} + c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: defaultMaxRetries} ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -133,10 +136,226 @@ func TestPostStatus_ShortSHAAccepted(t *testing.T) { })) defer srv.Close() - c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL} + c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: defaultMaxRetries} // 7-char short SHA should be accepted err := c.PostStatus(context.Background(), "owner", "repo", "abc1234", "success", "", "", "") if err != nil { t.Fatalf("unexpected error for 7-char SHA: %v", err) } } + +func TestPostStatus_RetriesOn429(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n < 3 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusCreated) + })) + defer srv.Close() + + c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: 3} + err := c.PostStatus(context.Background(), "owner", "repo", "abc1234", "success", "", "", "") + if err != nil { + t.Fatalf("expected success after retries, got: %v", err) + } + if got := attempts.Load(); got != 3 { + t.Errorf("expected 3 attempts, got %d", got) + } +} + +func TestPostStatus_RetriesOn5xx(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n < 2 { + w.WriteHeader(http.StatusBadGateway) + return + } + w.WriteHeader(http.StatusCreated) + })) + defer srv.Close() + + c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: 3} + err := c.PostStatus(context.Background(), "owner", "repo", "abc1234", "success", "", "", "") + if err != nil { + t.Fatalf("expected success after retries, got: %v", err) + } + if got := attempts.Load(); got != 2 { + t.Errorf("expected 2 attempts, got %d", got) + } +} + +func TestPostStatus_DoesNotRetryOn4xx(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: 3} + err := c.PostStatus(context.Background(), "owner", "repo", "abc1234", "success", "", "", "") + if err == nil { + t.Fatal("expected error for 401 response") + } + if got := attempts.Load(); got != 1 { + t.Errorf("expected 1 attempt (no retry for 4xx), got %d", got) + } +} + +func TestPostStatus_RespectsRetryAfterHeader(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusCreated) + })) + defer srv.Close() + + c := &Client{token: "tok", HTTPClient: srv.Client(), APIURL: srv.URL, maxRetries: 3} + err := c.PostStatus(context.Background(), "owner", "repo", "abc1234deadbeef", "success", "", "", "") + if err != nil { + t.Fatalf("expected success, got: %v", err) + } + if got := attempts.Load(); got != 2 { + t.Errorf("expected 2 attempts, got %d", got) + } +} + +func TestPostStatus_ExhaustsRetries(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := &Client{ + token: "tok", + HTTPClient: srv.Client(), + APIURL: srv.URL, + maxRetries: 2, + } + err := c.PostStatus(context.Background(), "owner", "repo", "abc1234deadbeef", "success", "", "", "") + if err == nil { + t.Fatal("expected error after exhausting retries") + } + // 1 initial + 2 retries = 3 total attempts + if got := attempts.Load(); got != 3 { + t.Errorf("expected 3 attempts, got %d", got) + } +} + +func TestPostStatus_ContextCancelledStopsRetries(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + c := &Client{ + token: "tok", + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + APIURL: srv.URL, + maxRetries: 10, + } + err := c.PostStatus(ctx, "owner", "repo", "abc1234deadbeef", "success", "", "", "") + if err == nil { + t.Fatal("expected error from context cancellation") + } +} + +func TestIsRetriableError(t *testing.T) { + tests := []struct { + name string + err error + retry bool + after0 bool + }{ + { + name: "retriable 429 without Retry-After", + err: &retriableError{statusCode: 429}, + retry: true, + }, + { + name: "retriable 500", + err: &retriableError{statusCode: 500}, + retry: true, + }, + { + name: "retriable 502", + err: &retriableError{statusCode: 502}, + retry: true, + }, + { + name: "non-retriable 401", + err: fmt.Errorf("github status API returned 401"), + retry: false, + }, + { + name: "non-retriable 404", + err: fmt.Errorf("github status API returned 404"), + retry: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isRetriableError(tt.err) + if got != tt.retry { + t.Errorf("isRetriableError() = %v, want %v", got, tt.retry) + } + }) + } +} + +func TestRetryAfterDuration(t *testing.T) { + tests := []struct { + name string + err error + want time.Duration + }{ + { + name: "429 with Retry-After in seconds", + err: &retriableError{statusCode: 429, retryAfter: "30"}, + want: 30 * time.Second, + }, + { + name: "429 with empty Retry-After", + err: &retriableError{statusCode: 429, retryAfter: ""}, + want: 0, + }, + { + name: "429 with invalid Retry-After", + err: &retriableError{statusCode: 429, retryAfter: "notanumber"}, + want: 0, + }, + { + name: "5xx error without Retry-After", + err: &retriableError{statusCode: 500}, + want: 0, + }, + { + name: "non-retriable error", + err: fmt.Errorf("some other error"), + want: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := retryAfterDuration(tt.err) + if got != tt.want { + t.Errorf("retryAfterDuration() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/llm/cli.go b/internal/llm/cli.go index 473f4dd2..c135fcb8 100644 --- a/internal/llm/cli.go +++ b/internal/llm/cli.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "os/exec" "time" @@ -61,7 +62,10 @@ type cliProvider struct { } // Analyze invokes the CLI with prompt and returns the JSON response. -// It retries on transient exec failures with configurable back-off. +// It retries on transient exec failures (timeouts, process signals) with +// configurable back-off. Client errors (syntax errors, invalid model names) +// are not retried — only transient failures where the CLI exited with a signal +// or the context deadline was exceeded. func (c *cliProvider) Analyze(ctx context.Context, prompt string) (json.RawMessage, error) { var ( out []byte @@ -77,6 +81,9 @@ func (c *cliProvider) Analyze(ctx context.Context, prompt string) (json.RawMessa if err == nil { break } + if !isTransientCLIError(err) { + return nil, err + } if attempt < c.maxRetries { select { case <-ctx.Done(): @@ -95,6 +102,28 @@ func (c *cliProvider) Analyze(ctx context.Context, prompt string) (json.RawMessa return json.RawMessage(out), nil } +// isTransientCLIError returns true for errors that are worth retrying: +// context deadlines/timeouts and process signals (killed by SIGKILL from +// timeout). Non-transient errors like invalid model names or syntax errors +// (exit code 1) are NOT retried. +func isTransientCLIError(err error) bool { + if err == nil { + return false + } + // Context deadline/timeout errors are transient. + if errors.Is(err, context.DeadlineExceeded) { + return true + } + // Process killed by a signal (e.g. SIGKILL from timeout) is transient. + var ee *exec.ExitError + if errors.As(err, &ee) { + return !ee.Exited() + } + // Other errors (network, exec failures) are transient. + // Client errors that produce normal exit codes are NOT transient. + return false +} + // run executes one CLI invocation and returns its stdout. func (c *cliProvider) run(ctx context.Context, prompt string) ([]byte, error) { cmd := exec.CommandContext(ctx, c.command, c.buildArgs(prompt)...) diff --git a/internal/llm/llm_test.go b/internal/llm/llm_test.go index 56483876..e17dbddd 100644 --- a/internal/llm/llm_test.go +++ b/internal/llm/llm_test.go @@ -202,20 +202,26 @@ func TestCLIProvider_Analyze_ReturnsErrorWhenOutputNotJSON(t *testing.T) { } } -func TestCLIProvider_Analyze_RetriesOnTransientFailure_ThenSucceeds(t *testing.T) { +func TestCLIProvider_Analyze_RetriesOnTransientError_ThenSucceeds(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "test-key") dir := t.TempDir() countFile := filepath.Join(dir, "count") os.WriteFile(countFile, []byte("0"), 0644) + // Script succeeds on 3rd attempt after failing with exit 1 on first two. + // Transient errors (killed process, timeout) are retried; exit code 1 + // is a "client error" (not transient) so the provider will NOT retry it. + // To test retry behavior, we use a script that gets killed by the context + // timeout on the first two calls, then succeeds on the third. script := fmt.Sprintf(` COUNT=$(cat %s 2>/dev/null || echo 0) COUNT=$((COUNT+1)) echo $COUNT > %s if [ "$COUNT" -le 2 ]; then - echo "transient error" >&2 - exit 1 + # Sleep long enough to trigger context timeout on first 2 attempts + # when called with a short timeout. + sleep 10 fi echo '{"succeeded":true}' `, countFile, countFile) @@ -225,7 +231,7 @@ echo '{"succeeded":true}' Provider: "anthropic", Command: cmd, MaxRetries: intPtr(2), - Timeout: 10 * time.Second, + Timeout: 100 * time.Millisecond, }) if err != nil { t.Fatalf("New: %v", err) @@ -245,9 +251,10 @@ echo '{"succeeded":true}' } } -func TestCLIProvider_Analyze_ReturnsErrorAfterAllRetriesExhausted(t *testing.T) { +func TestCLIProvider_Analyze_ClientError_NotRetried(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "test-key") + // Exit code 1 (syntax error, bad args) is a client error — not retried. dir := t.TempDir() cmd := writeFakeScript(t, dir, "fakecli", `echo "always fails" >&2; exit 1`) @@ -263,13 +270,58 @@ func TestCLIProvider_Analyze_ReturnsErrorAfterAllRetriesExhausted(t *testing.T) _, err = p.Analyze(context.Background(), "prompt") if err == nil { - t.Fatal("expected error after all retries exhausted") + t.Fatal("expected error for client exit code 1") } if !strings.Contains(err.Error(), "exited") { t.Fatalf("unexpected error: %v", err) } } +func TestCLIProvider_Analyze_TransientError_RetriesThenSucceeds(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "test-key") + + // This script sleeps on the first invocation (causing context timeout) + // and then succeeds on the second attempt. + dir := t.TempDir() + countFile := filepath.Join(dir, "count") + os.WriteFile(countFile, []byte("0"), 0644) + + script := fmt.Sprintf(` +COUNT=$(cat %s 2>/dev/null || echo 0) +COUNT=$((COUNT+1)) +echo $COUNT > %s +if [ "$COUNT" -le 1 ]; then + # Sleep longer than the per-call timeout to trigger context.DeadlineExceeded + sleep 10 +fi +echo '{"success":true}' +`, countFile, countFile) + cmd := writeFakeScript(t, dir, "fakecli", script) + + p, err := New(Config{ + Provider: "anthropic", + Command: cmd, + MaxRetries: intPtr(2), + Timeout: 100 * time.Millisecond, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + got, err := p.Analyze(context.Background(), "prompt") + if err != nil { + t.Fatalf("Analyze: %v", err) + } + + var result struct{ Success bool } + if err := json.Unmarshal(got, &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if !result.Success { + t.Fatalf("expected success:true, got %s", got) + } +} + func TestCLIProvider_Analyze_ZeroRetries_MakesExactlyOneAttempt(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "test-key") @@ -381,3 +433,15 @@ func TestMockProvider_Analyze_ReturnsConfiguredError(t *testing.T) { t.Fatalf("got %v, want %v", err, want) } } + +func TestIsTransientCLIError_ContextDeadline(t *testing.T) { + if !isTransientCLIError(context.DeadlineExceeded) { + t.Error("context.DeadlineExceeded should be transient") + } +} + +func TestIsTransientCLIError_NilError(t *testing.T) { + if isTransientCLIError(nil) { + t.Error("nil error should not be transient") + } +} diff --git a/internal/mail/mail.go b/internal/mail/mail.go index 27d0a0fe..023ac58a 100644 --- a/internal/mail/mail.go +++ b/internal/mail/mail.go @@ -3,14 +3,21 @@ package mail import ( "context" "crypto/tls" + "errors" "fmt" + "math" "net" "net/smtp" "strings" + "time" + + "github.com/rs/zerolog/log" "github.com/scaledtest/scaledtest/internal/config" ) +const defaultSMTPRetries = 3 + // Message is an email to be sent. type Message struct { To string @@ -33,11 +40,12 @@ func (n *NoopSender) Send(_ context.Context, _ Message) error { // SMTPSender delivers email via SMTP using the configured credentials. type SMTPSender struct { - host string - port int - user string - pass string - from string + host string + port int + user string + pass string + from string + maxRetries int } // sanitizeHeader strips CR and LF characters from an email header value @@ -47,8 +55,72 @@ func sanitizeHeader(s string) string { } // Send delivers msg via SMTP, respecting ctx for cancellation and timeouts. -// Header fields are sanitized to prevent header injection. +// It retries on transient SMTP errors (5xx responses and connection timeouts) +// with exponential backoff. Client errors (auth, invalid recipient) are not retried. func (s *SMTPSender) Send(ctx context.Context, msg Message) error { + var lastErr error + for attempt := 0; attempt <= s.maxRetries; attempt++ { + if ctx.Err() != nil { + return ctx.Err() + } + + lastErr = s.sendOnce(ctx, msg) + if lastErr == nil { + return nil + } + + if !IsTransientSMTPError(lastErr) { + return lastErr + } + + if attempt < s.maxRetries { + backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second + log.Warn().Err(lastErr). + Int("attempt", attempt+1). + Str("to", msg.To). + Dur("backoff", backoff). + Msg("mail: retrying SMTP send") + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } + } + } + return lastErr +} + +// IsTransientSMTPError returns true for errors that are worth retrying: +// connection timeouts, network errors, and SMTP 5xx responses. +func IsTransientSMTPError(err error) bool { + if err == nil { + return false + } + // Network/timeout errors are transient. + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + // Connection-level errors (dial, TLS) are transient. + msg := err.Error() + if strings.Contains(msg, "smtp dial:") || + strings.Contains(msg, "smtp starttls:") || + strings.Contains(msg, "connection refused") || + strings.Contains(msg, "i/o timeout") { + return true + } + // SMTP 5xx errors from the mail server are transient. + // The net/smtp package wraps these in formatted error strings. + if strings.Contains(msg, "55") || strings.Contains(msg, "54") || + strings.Contains(msg, "451") || strings.Contains(msg, "452") || + strings.Contains(msg, "421") { + return true + } + return false +} + +func (s *SMTPSender) sendOnce(ctx context.Context, msg Message) error { addr := fmt.Sprintf("%s:%d", s.host, s.port) // Sanitize header fields to prevent CRLF injection. @@ -124,10 +196,11 @@ func New(cfg *config.Config) Sender { return &NoopSender{} } return &SMTPSender{ - host: cfg.SMTPHost, - port: cfg.SMTPPort, - user: cfg.SMTPUser, - pass: cfg.SMTPPass, - from: cfg.SMTPFrom, + host: cfg.SMTPHost, + port: cfg.SMTPPort, + user: cfg.SMTPUser, + pass: cfg.SMTPPass, + from: cfg.SMTPFrom, + maxRetries: defaultSMTPRetries, } } diff --git a/internal/mail/mail_test.go b/internal/mail/mail_test.go index c3414b88..34b14936 100644 --- a/internal/mail/mail_test.go +++ b/internal/mail/mail_test.go @@ -2,6 +2,7 @@ package mail_test import ( "context" + "fmt" "net" "strconv" "testing" @@ -140,3 +141,51 @@ func TestSMTPSender_Send_CancelledContext_ReturnsError(t *testing.T) { t.Fatal("expected error with cancelled context, got nil") } } + +func TestIsTransientSMTPError_Nil(t *testing.T) { + if mail.IsTransientSMTPError(nil) { + t.Error("nil error should not be transient") + } +} + +func TestIsTransientSMTPError_ConnectionRefused(t *testing.T) { + err := fmt.Errorf("smtp dial: connection refused") + if !mail.IsTransientSMTPError(err) { + t.Error("connection refused should be transient") + } +} + +func TestIsTransientSMTPError_Timeout(t *testing.T) { + err := fmt.Errorf("smtp dial: i/o timeout") + if !mail.IsTransientSMTPError(err) { + t.Error("i/o timeout should be transient") + } +} + +func TestIsTransientSMTPError_StartTLSError(t *testing.T) { + err := fmt.Errorf("smtp starttls: handshake failure") + if !mail.IsTransientSMTPError(err) { + t.Error("STARTTLS error should be transient") + } +} + +func TestIsTransientSMTPError_ClientError(t *testing.T) { + err := fmt.Errorf("smtp auth: invalid credentials") + if mail.IsTransientSMTPError(err) { + t.Error("auth error should not be transient") + } +} + +func TestIsTransientSMTPError_5xxResponse(t *testing.T) { + err := fmt.Errorf("smtp RCPT TO: 552 5.2.2 mailbox full") + if !mail.IsTransientSMTPError(err) { + t.Error("5xx response should be transient") + } +} + +func TestIsTransientSMTPError_4xxResponse(t *testing.T) { + err := fmt.Errorf("smtp RCPT TO: 451 4.3.0 try again later") + if !mail.IsTransientSMTPError(err) { + t.Error("4xx response should be transient") + } +} diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go index c4647086..7320dece 100644 --- a/internal/mailer/mailer.go +++ b/internal/mailer/mailer.go @@ -2,24 +2,32 @@ package mailer import ( "context" + "crypto/tls" "fmt" + "math" "net" "net/smtp" "strings" + "time" + + "github.com/rs/zerolog/log" ) +const defaultSMTPRetries = 3 + // Mailer sends invitation emails. type Mailer interface { SendInvitation(ctx context.Context, to, inviteURL string) error } -// SMTPMailer delivers emails via SMTP. +// SMTPMailer delivers emails via SMTP with retry support. type SMTPMailer struct { - host string - port int - username string - password string - from string + host string + port int + username string + password string + from string + maxRetries int // dial establishes the TCP connection; defaults to net.Dialer.DialContext. // Overridden in tests to inject mock connections. dial func(ctx context.Context, network, address string) (net.Conn, error) @@ -32,30 +40,127 @@ func New(host string, port int, username, password, from string) Mailer { return nil } return &SMTPMailer{ - host: host, - port: port, - username: username, - password: password, - from: from, - dial: (&net.Dialer{}).DialContext, + host: host, + port: port, + username: username, + password: password, + from: from, + maxRetries: defaultSMTPRetries, + dial: (&net.Dialer{}).DialContext, + } +} + +// isTransientSMTPError determines if an SMTP error is worth retrying. +func isTransientSMTPError(err error) bool { + if err == nil { + return false + } + var netErr net.Error + if netErr != nil { + return true + } + msg := err.Error() + if strings.Contains(msg, "smtp dial:") || + strings.Contains(msg, "smtp starttls:") || + strings.Contains(msg, "connection refused") || + strings.Contains(msg, "i/o timeout") { + return true + } + if strings.Contains(msg, "55") || strings.Contains(msg, "54") || + strings.Contains(msg, "451") || strings.Contains(msg, "452") || + strings.Contains(msg, "421") { + return true } + return false } -// SendInvitation sends an invitation email to the given address. -// The context is honoured during the TCP dial; if the context has a deadline it -// is also applied to the connection so the SMTP session cannot outlive it. +// SendInvitation sends an invitation email to the given address with both +// plaintext and HTML alternative parts. It retries on transient SMTP errors +// with exponential backoff. func (m *SMTPMailer) SendInvitation(ctx context.Context, to, inviteURL string) error { if strings.ContainsAny(to, "\r\n") { return fmt.Errorf("invalid recipient address: contains CRLF") } - addr := fmt.Sprintf("%s:%d", m.host, m.port) - msg := fmt.Sprintf( - "From: %s\r\nTo: %s\r\nSubject: You've been invited to ScaledTest\r\n\r\n"+ - "You have been invited to join ScaledTest.\r\n\r\n"+ + textBody := fmt.Sprintf( + "You have been invited to join ScaledTest.\r\n\r\n"+ "Accept your invitation:\r\n%s\r\n", - m.from, to, inviteURL, + inviteURL, ) + htmlBody := buildInvitationHTML(inviteURL) + msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: You've been invited to ScaledTest\r\n"+ + "MIME-Version: 1.0\r\nContent-Type: multipart/alternative; boundary=boundary123\r\n\r\n"+ + "--boundary123\r\nContent-Type: text/plain; charset=utf-8\r\n\r\n%s\r\n"+ + "--boundary123\r\nContent-Type: text/html; charset=utf-8\r\n\r\n%s\r\n--boundary123--\r\n", + m.from, to, textBody, htmlBody) + + return m.sendWithRetry(ctx, to, []byte(msg)) +} + +func buildInvitationHTML(inviteURL string) string { + escapedURL := htmlEscapeAttr(inviteURL) + return fmt.Sprintf(` + + + + + + +
+

ScaledTest

+
+

You have been invited to join ScaledTest.

+Accept Invitation +
+ +`, escapedURL) +} + +func htmlEscapeAttr(s string) string { + s = strings.ReplaceAll(s, "&", "&") + s = strings.ReplaceAll(s, `"`, """) + s = strings.ReplaceAll(s, "'", "'") + s = strings.ReplaceAll(s, "<", "<") + s = strings.ReplaceAll(s, ">", ">") + return s +} + +func (m *SMTPMailer) sendWithRetry(ctx context.Context, to string, msg []byte) error { + var lastErr error + for attempt := 0; attempt <= m.maxRetries; attempt++ { + if ctx.Err() != nil { + return ctx.Err() + } + + lastErr = m.sendOnce(ctx, to, msg) + if lastErr == nil { + return nil + } + + if !isTransientSMTPError(lastErr) { + return lastErr + } + + if attempt < m.maxRetries { + backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second + log.Warn().Err(lastErr). + Int("attempt", attempt+1). + Str("to", to). + Dur("backoff", backoff). + Msg("mailer: retrying SMTP send") + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } + } + } + return lastErr +} + +func (m *SMTPMailer) sendOnce(ctx context.Context, to string, msg []byte) error { + addr := fmt.Sprintf("%s:%d", m.host, m.port) dialFn := m.dial if dialFn == nil { @@ -79,6 +184,12 @@ func (m *SMTPMailer) SendInvitation(ctx context.Context, to, inviteURL string) e } defer client.Close() + if ok, _ := client.Extension("STARTTLS"); ok { + if err := client.StartTLS(&tls.Config{ServerName: m.host}); err != nil { + return fmt.Errorf("smtp starttls: %w", err) + } + } + if m.username != "" { if err := client.Auth(smtp.PlainAuth("", m.username, m.password, m.host)); err != nil { return fmt.Errorf("smtp auth: %w", err) @@ -94,7 +205,7 @@ func (m *SMTPMailer) SendInvitation(ctx context.Context, to, inviteURL string) e if err != nil { return fmt.Errorf("smtp DATA: %w", err) } - if _, err := w.Write([]byte(msg)); err != nil { + if _, err := w.Write(msg); err != nil { return fmt.Errorf("smtp write: %w", err) } if err := w.Close(); err != nil { diff --git a/internal/model/model.go b/internal/model/model.go index fc18161b..72575226 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -261,7 +261,7 @@ type Invitation struct { TeamID string `json:"team_id"` Email string `json:"email"` Role string `json:"role"` - InvitedBy string `json:"invited_by"` + InvitedBy *string `json:"invited_by,omitempty"` AcceptedAt *time.Time `json:"accepted_at,omitempty"` ExpiresAt time.Time `json:"expires_at"` CreatedAt time.Time `json:"created_at"` diff --git a/internal/store/triage.go b/internal/store/triage.go index 90f46028..45b54702 100644 --- a/internal/store/triage.go +++ b/internal/store/triage.go @@ -264,3 +264,69 @@ func (s *TriageStore) ListClassifications(ctx context.Context, teamID, triageID } return classifications, rows.Err() } + +// ClusterInput is the data needed to persist a single triage cluster. +type ClusterInput struct { + RootCause string + Label *string +} + +// ClassificationInput is the data needed to persist a single failure classification. +type ClassificationInput struct { + ClusterIndex int + TestResultID string + Classification string +} + +// OutputData holds all clusters and classifications to persist atomically. +type OutputData struct { + Clusters []ClusterInput + Classifications []ClassificationInput +} + +// PersistOutput atomically writes all clusters and classifications for a triage +// result in a single database transaction. If any insert fails, all inserts are +// rolled back so the triage record is not left in an inconsistent state. +func (s *TriageStore) PersistOutput(ctx context.Context, teamID, triageID string, output *OutputData) error { + tx, err := s.pool.Begin(ctx) + if err != nil { + return fmt.Errorf("persist triage output: begin tx: %w", err) + } + defer tx.Rollback(ctx) //nolint:errcheck // rollback on early return is intentional + + // Insert clusters and collect their assigned UUIDs for linking classifications. + clusterIDs := make([]string, len(output.Clusters)) + for i, cluster := range output.Clusters { + err := tx.QueryRow(ctx, + `INSERT INTO triage_clusters (triage_id, team_id, root_cause, label) + VALUES ($1, $2, $3, $4) + RETURNING id`, + triageID, teamID, cluster.RootCause, cluster.Label, + ).Scan(&clusterIDs[i]) + if err != nil { + return fmt.Errorf("persist cluster[%d]: %w", i, err) + } + } + + // Insert per-failure classifications linked to their cluster. + for _, cl := range output.Classifications { + var clusterID *string + if cl.ClusterIndex >= 0 && cl.ClusterIndex < len(clusterIDs) { + cid := clusterIDs[cl.ClusterIndex] + clusterID = &cid + } + _, err := tx.Exec(ctx, + `INSERT INTO triage_failure_classifications (triage_id, cluster_id, test_result_id, team_id, classification) + VALUES ($1, $2, $3, $4, $5)`, + triageID, clusterID, cl.TestResultID, teamID, cl.Classification, + ) + if err != nil { + return fmt.Errorf("persist classification for %s: %w", cl.TestResultID, err) + } + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("persist triage output: commit: %w", err) + } + return nil +} diff --git a/internal/telegram/telegram.go b/internal/telegram/telegram.go index a0cfba28..8b7ee5aa 100644 --- a/internal/telegram/telegram.go +++ b/internal/telegram/telegram.go @@ -8,12 +8,16 @@ import ( "encoding/json" "fmt" "html" + "math" "net/http" "strings" "time" + + "github.com/rs/zerolog/log" ) const defaultBaseURL = "https://api.telegram.org" +const defaultTelegramRetries = 3 // Client sends messages to a Telegram chat via the Bot API. type Client struct { @@ -21,6 +25,7 @@ type Client struct { chatID string httpClient *http.Client baseURL string + maxRetries int } // ClientOption configures a Client. @@ -31,6 +36,11 @@ func WithBaseURL(url string) ClientOption { return func(c *Client) { c.baseURL = url } } +// WithMaxRetries sets the number of retry attempts for transient errors. +func WithMaxRetries(n int) ClientOption { + return func(c *Client) { c.maxRetries = n } +} + // NewClient returns a Client configured with the given bot token and chat ID. func NewClient(token, chatID string, opts ...ClientOption) *Client { c := &Client{ @@ -38,6 +48,7 @@ func NewClient(token, chatID string, opts ...ClientOption) *Client { chatID: chatID, httpClient: &http.Client{Timeout: 15 * time.Second}, baseURL: defaultBaseURL, + maxRetries: defaultTelegramRetries, } for _, o := range opts { o(c) @@ -54,9 +65,20 @@ type sendMessageRequest struct { type apiResponse struct { OK bool `json:"ok"` Description string `json:"description,omitempty"` + ErrorCode int `json:"error_code,omitempty"` + Parameters struct { + RetryAfter int `json:"retry_after,omitempty"` + } `json:"parameters,omitempty"` +} + +// isRetriableTelegramError returns true for HTTP 429 and 5xx status codes. +func isRetriableTelegramError(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode >= 500 } // SendMessage posts text to the configured Telegram chat using HTML parse mode. +// It retries on 429 (rate limited) and 5xx (server error) responses with +// exponential backoff, respecting the Retry-After header on 429 responses. func (c *Client) SendMessage(ctx context.Context, text string) error { endpoint := fmt.Sprintf("%s/bot%s/sendMessage", c.baseURL, c.token) payload := sendMessageRequest{ChatID: c.chatID, Text: text, ParseMode: "HTML"} @@ -64,6 +86,56 @@ func (c *Client) SendMessage(ctx context.Context, text string) error { if err != nil { return fmt.Errorf("telegram: marshal request: %w", err) } + + var lastErr error + for attempt := 0; attempt <= c.maxRetries; attempt++ { + if ctx.Err() != nil { + return ctx.Err() + } + + lastErr = c.doSend(ctx, endpoint, data) + if lastErr == nil { + return nil + } + + re, ok := lastErr.(*telegramError) + if !ok || !isRetriableTelegramError(re.statusCode) { + return lastErr + } + + if attempt < c.maxRetries { + backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second + if re.retryAfter > 0 { + backoff = time.Duration(re.retryAfter) * time.Second + } + log.Warn().Err(lastErr). + Int("attempt", attempt+1). + Dur("backoff", backoff). + Msg("telegram: retrying SendMessage") + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } + } + } + return lastErr +} + +// telegramError represents a Telegram API error with status code and optional retry_after. +type telegramError struct { + statusCode int + desc string + retryAfter int +} + +func (e *telegramError) Error() string { + return fmt.Sprintf("telegram: API error (status %d): %s", e.statusCode, e.desc) +} + +// doSend executes a single HTTP POST to the Telegram sendMessage API. +func (c *Client) doSend(ctx context.Context, endpoint string, data []byte) error { req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) if err != nil { return fmt.Errorf("telegram: build request: %w", err) @@ -80,6 +152,22 @@ func (c *Client) SendMessage(ctx context.Context, text string) error { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return fmt.Errorf("telegram: decode response: %w", err) } + + if resp.StatusCode == http.StatusTooManyRequests { + return &telegramError{ + statusCode: resp.StatusCode, + desc: result.Description, + retryAfter: result.Parameters.RetryAfter, + } + } + + if resp.StatusCode >= 500 { + return &telegramError{ + statusCode: resp.StatusCode, + desc: result.Description, + } + } + if !result.OK { return fmt.Errorf("telegram: API error: %s", result.Description) } @@ -126,6 +214,8 @@ func FormatMessage(s CISummary) string { branch := html.EscapeString(s.Branch) commitLine = html.EscapeString(commitLine) + escapedRunURL := html.EscapeString(s.RunURL) + var sb strings.Builder sb.WriteString(fmt.Sprintf("%s %s — %s\n", icon, repo, strings.ToUpper(s.Status))) sb.WriteString(fmt.Sprintf("Branch: %s", branch)) @@ -148,7 +238,7 @@ func FormatMessage(s CISummary) string { sb.WriteString(fmt.Sprintf(" / %d total\n", s.Total)) if s.RunURL != "" { - sb.WriteString(fmt.Sprintf("\nView run", s.RunURL)) + sb.WriteString(fmt.Sprintf("\nView run", escapedRunURL)) } return sb.String() } diff --git a/internal/telegram/telegram_test.go b/internal/telegram/telegram_test.go index 693cd1fe..85f0022b 100644 --- a/internal/telegram/telegram_test.go +++ b/internal/telegram/telegram_test.go @@ -139,6 +139,47 @@ func TestFormatMessage_HTMLEscapesExternalFields(t *testing.T) { } } +func TestFormatMessage_HTMLEscapesRunURL(t *testing.T) { + s := telegram.CISummary{ + Repo: "org/repo", + Branch: "main", + CommitMsg: "test", + Status: "failing", + Failed: 1, + Total: 1, + RunURL: `https://example.com/run"onmouseover="alert(1)`, + } + msg := telegram.FormatMessage(s) + + if strings.Contains(msg, `"onmouseover`) { + t.Errorf("RunURL double-quote must be escaped to prevent XSS; got:\n%s", msg) + } + if !strings.Contains(msg, "https://example.com/run"onmouseover="alert(1)") && + !strings.Contains(msg, "https://example.com/run"onmouseover="alert(1)") { + t.Errorf("RunURL should be HTML-escaped in href; got:\n%s", msg) + } +} + +func TestFormatMessage_RunURLWithAngleBrackets(t *testing.T) { + s := telegram.CISummary{ + Repo: "org/repo", + Branch: "main", + CommitMsg: "test", + Status: "passing", + Passed: 1, + Total: 1, + RunURL: "https://example.com/", + } + msg := telegram.FormatMessage(s) + + if strings.Contains(msg, "