diff --git a/CLAUDE.md b/CLAUDE.md index 2ddddf7..7b4932d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,35 +19,49 @@ Standard Go HTTP server (not Cloudflare Workers) serving as a temporary file upl **Environment variables (required):** - `db_url` — PostgreSQL connection string - `bucket_name` — GCS bucket name +- `sign_key` — HMAC key for signing download tokens (`handler.go:signFilename`). Rotating it invalidates every outstanding URL — that's the right behavior if a key leaks. **Environment variables (optional):** - `base_url` — download URL prefix (default: `https://dropbox.deploys.app/files/`) - `api_endpoint` — deploys.app API base URL (default: `https://api.deploys.app`, override with internal address in production) -- `cdn_base_url` — full URL prefix (including scheme and trailing slash, e.g. `https://cdn.example.com/`). When set, `GET /files/{fn}` records the download metric (counted against `attrs.Size`) and 307-redirects to `{cdn_base_url}{fn}`. The CDN edge is expected to fetch its origin at `https://dropbox.deploys.app/_cdn/{fn}`, which streams the file unauthenticated and without metrics. In-cluster callers (private/loopback/link-local `X-Real-Ip`) bypass the redirect and stream directly. Unset = original streaming behavior. +- `cdn_base_url` — full URL prefix (including scheme and trailing slash, e.g. `https://cdn.example.com/`). When set, `GET /files/{token}` records the download metric (counted against `attrs.Size`) and 307-redirects to `{cdn_base_url}{token}`. The CDN edge is expected to fetch its origin at `https://dropbox.deploys.app/_cdn/{token}`, which streams the file (no auth, no metrics) after re-verifying the same HMAC. 
Each `/_cdn` response sets `Cache-Control` for the edge: success uses `public, max-age={remaining TTL}, immutable` so the edge cache lines up with the file's actual lifetime; 410/404 for known-dead URLs use `public, max-age=3600` so repeat probes are absorbed at the edge; invalid tokens get no `Cache-Control` because each garbage URL is unique and caching just burns edge slots. In-cluster callers (private/loopback/link-local `X-Real-Ip`) bypass the redirect and stream directly. Unset = original streaming behavior. - `PORT` — listen port (default: `8080`) - `log_level` — slog level (default: info) +**Download token scheme (`handler.go`):** +- The URL path component is a `token` = `fn` + `"-"` + `sig`, currently 45 chars total. `fn` is 24 random chars `[0-9A-Za-z]` (~143 bits of entropy); `sig` is 20 hex chars of HMAC-SHA256 truncated to 80 bits, keyed by `sign_key`. +- The `-` separator lets us change `fnLen` later without invalidating tokens that are already in circulation — `parseToken` splits structurally on the separator, not by fixed position. Since `fn` is alphanumeric and `sig` is hex, neither side can contain a `-`. +- `parseToken(SignKey, token)` runs first in both `fileHandler` and `cdnFileHandler` and 404s on any mismatch — DDoS attempts that don't know `sign_key` never reach the DB or GCS. +- `fn` is what we store in the bucket and the `files.fn` column. The full token only appears in URLs. + **Request flow (`handler.go`):** 1. Parse `Authorization` header + `project`/`projectId` from query params or `param-*` headers (query params take precedence) 2. Authorize via `checkAuth()` in `auth.go` 3. Parse TTL (1–7 days, default 1) and optional filename the same way -4. Generate a crypto-random 86-char URL-safe base64 filename with TTL digit prepended (e.g., `1ABC…`) -5. Stream body to GCS with cache-control and optional content-disposition +4. Generate a 24-char alphanumeric `fn` (`generateFilename`, rejection-sampled to stay unbiased) +5. 
Stream body to GCS with cache-control and optional content-disposition, keyed by `fn` 6. Insert metadata into PostgreSQL via `pgctx.Exec` -7. Return JSON: `{"ok": true, "result": {"downloadUrl": "...", "expiresAt": "..."}}` +7. Return JSON: `{"ok": true, "result": {"downloadUrl": "{base_url}{fn}-{sig}", "expiresAt": "..."}}` **Auth (`auth.go`):** - No `Authorization` header → alpha mode, project ID hardcoded as `"alpha"` (TODO: remove) - With token → POST to `https://api.deploys.app/me.authorized` for `dropbox.upload` permission, checking `authorized` + `billingAccount.active` -- Results cached in-process for 30 seconds via `cachestore` +- Results cached in-process for 30 seconds via `cachestore`; the external call is wrapped in `sf.Do` so concurrent uploads from the same caller collapse to one round-trip at the cache-miss edge. + +**DDoS protection ladder:** see `fileHandler` / `cdnFileHandler` in `files.go`. In order from cheapest to most expensive: +1. `parseToken` HMAC check — pure CPU, no I/O. +2. `lookupFile` cache — 60s in-process cache of `(project_id, expires_at, bucket_missing)` per fn. +3. `sf.Do` around the DB `SELECT` — collapses any thundering herd at the cache-miss edge. +4. `Bucket.Attributes` — only reached for tokens that survive 1–3. 
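The token scheme and the first rung of the ladder can be sketched in Go as follows. This is a minimal illustration, not the code in `handler.go`: the names `signSketch`, `makeTokenSketch`, and `parseTokenSketch` are hypothetical, and the sketch assumes the HMAC is computed over `fn` alone and hex-encoded before truncation, which the notes above do not spell out.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// sigLen is the signature length in hex chars (20 hex chars = 80 bits).
const sigLen = 20

// signSketch returns the first sigLen hex chars of HMAC-SHA256(key, fn).
// Assumption: the real signFilename may feed different input to the MAC.
func signSketch(key []byte, fn string) string {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(fn))
	return hex.EncodeToString(mac.Sum(nil))[:sigLen]
}

// makeTokenSketch builds the URL path component: fn + "-" + sig.
func makeTokenSketch(key []byte, fn string) string {
	return fn + "-" + signSketch(key, fn)
}

// parseTokenSketch splits on the separator rather than at a fixed offset,
// so the length of fn can change without breaking older tokens. It does
// no I/O: a token that fails here never reaches the DB or the bucket.
func parseTokenSketch(key []byte, token string) (string, bool) {
	fn, sig, ok := strings.Cut(token, "-")
	if !ok || fn == "" || len(sig) != sigLen {
		return "", false
	}
	// hmac.Equal compares in constant time, so the check does not leak
	// how many leading characters of the signature matched.
	if !hmac.Equal([]byte(sig), []byte(signSketch(key, fn))) {
		return "", false
	}
	return fn, true
}

func main() {
	key := []byte("example-sign-key")  // placeholder, not a real key
	fn := "0123456789abcdefghijABCD" // 24 alphanumeric chars
	token := makeTokenSketch(key, fn)

	parsed, ok := parseTokenSketch(key, token)
	fmt.Println(len(token), parsed == fn, ok)

	// A token signed with a different key is rejected.
	_, ok = parseTokenSketch([]byte("rotated-key"), token)
	fmt.Println(ok)
}
```

Because verification depends only on `sign_key` and the token itself, rotating the key invalidates every outstanding URL at once, which matches the behaviour described for `sign_key` above.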
**Key libraries (same pattern as `moonrhythm/registry`):** - `parapet` — HTTP server with middleware chain (healthz, logger, pgctx) - `pgctx` — context-aware PostgreSQL access (`pgctx.Exec`, middleware injects DB into context) -- `cachestore` — in-process TTL cache for auth results +- `cachestore` — in-process TTL cache for auth results and per-fn metadata +- `sf` — generic context-aware singleflight (`github.com/moonrhythm/sf`); used in `lookupFile` and `checkAuth` to dedupe concurrent backend calls - `configfile` — env-var config reader (`config.MustString`, `config.StringDefault`) ## Notes - `schema.sql` targets PostgreSQL; `project_id` is `text` (the API returns string IDs) -- `base_url` is the public download prefix (`https://dropbox.deploys.app/files/`); it shares the service host and resolves to the `GET /files/{fn}` route, which streams directly or 307s to the CDN (see `cdn_base_url`) +- `base_url` is the public download prefix (`https://dropbox.deploys.app/files/`); it shares the service host and resolves to the `GET /files/{token}` route, which streams directly or 307s to the CDN (see `cdn_base_url`) diff --git a/README.md b/README.md index 2ca7101..35f8610 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ Docker image is built and pushed automatically on push to `main`. See `.github/w |---------------|----------------------------------------------------------| | `db_url` | PostgreSQL connection string | | `bucket_name` | GCS bucket name | +| `sign_key` | HMAC key for signing download tokens. Rotating invalidates every outstanding URL. 
| | `base_url` | Download URL prefix (default: `https://dropbox.deploys.app/files/`) | | `PORT` | Listen port (default: `8080`) | @@ -74,12 +75,14 @@ File data binary { "ok": true, "result": { - "downloadUrl": "https://dropbox.deploys.app/files/", + "downloadUrl": "https://dropbox.deploys.app/files/<token>", "expiresAt": "2020-01-01T01:01:01Z" } } ``` +`<token>` is `{fn}-{sig}` (currently 45 chars): a 24-char random alphanumeric filename, a `-` separator, and a 20-char HMAC-SHA256 signature (keyed by `sign_key`). Tampered or made-up tokens are rejected before any DB or storage lookup. The separator means future changes to filename length stay backward-compatible. + ##### Unauthorized ```json diff --git a/auth.go b/auth.go index 331842d..f7b1b6f 100644 --- a/auth.go +++ b/auth.go @@ -8,6 +8,7 @@ import ( "time" "github.com/moonrhythm/cachestore" + "github.com/moonrhythm/sf" ) var apiEndpoint = "https://api.deploys.app" @@ -44,57 +45,72 @@ func checkAuth(ctx context.Context, auth, project, projectID string) AuthResult return v } - body, _ := json.Marshal(struct { - Project string `json:"project,omitempty"` - ProjectID string `json:"projectId,omitempty"` - Permissions []string `json:"permissions"` - }{ - Project: project, - ProjectID: projectID, - Permissions: []string{permission}, - }) + // Same singleflight pattern as lookupFile: collapse a thundering herd + // of concurrent uploads from the same caller into a single + // /me.authorized round-trip. The result is cached for 30s + // (cacheTTL); sf.Do dedupe matters at the cold-cache edge and right + // after that 30s entry expires under load. + result, _, _ := sf.Do(ctx, cacheKey, func(ctx context.Context) (AuthResult, error) { + // Re-check the cache: a sibling caller may have populated it + // while we were queued behind sf's mutex. 
+ if v, ok := cachestore.Get[AuthResult](cacheKey); ok { + return v, nil + } - req, _ := http.NewRequest(http.MethodPost, apiEndpoint+"/me.authorized", bytes.NewReader(body)) - req.Header.Set("Authorization", auth) - req.Header.Set("Content-Type", "application/json") + body, _ := json.Marshal(struct { + Project string `json:"project,omitempty"` + ProjectID string `json:"projectId,omitempty"` + Permissions []string `json:"permissions"` + }{ + Project: project, + ProjectID: projectID, + Permissions: []string{permission}, + }) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return AuthResult{} - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return AuthResult{} - } + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiEndpoint+"/me.authorized", bytes.NewReader(body)) + req.Header.Set("Authorization", auth) + req.Header.Set("Content-Type", "application/json") - var res struct { - OK bool `json:"ok"` - Result struct { - Authorized bool `json:"authorized"` - Project struct { - ID string `json:"id"` - Project string `json:"project"` - BillingAccount struct { - Active bool `json:"active"` - } `json:"billingAccount"` - } `json:"project"` - } `json:"result"` - } - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { - return AuthResult{} - } + resp, err := http.DefaultClient.Do(req) + if err != nil { + // Don't cache transport failures — let the next caller retry. 
+ return AuthResult{}, nil + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return AuthResult{}, nil + } - var result AuthResult - if res.OK && res.Result.Authorized && res.Result.Project.BillingAccount.Active { - result = AuthResult{ - Authorized: true, - Project: Project{ - ID: res.Result.Project.ID, - Project: res.Result.Project.Project, - }, + var res struct { + OK bool `json:"ok"` + Result struct { + Authorized bool `json:"authorized"` + Project struct { + ID string `json:"id"` + Project string `json:"project"` + BillingAccount struct { + Active bool `json:"active"` + } `json:"billingAccount"` + } `json:"project"` + } `json:"result"` + } + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return AuthResult{}, nil + } + + var result AuthResult + if res.OK && res.Result.Authorized && res.Result.Project.BillingAccount.Active { + result = AuthResult{ + Authorized: true, + Project: Project{ + ID: res.Result.Project.ID, + Project: res.Result.Project.Project, + }, + } } - } - cachestore.Set(cacheKey, result, &cachestore.SetOptions{TTL: cacheTTL}) + cachestore.Set(cacheKey, result, &cachestore.SetOptions{TTL: cacheTTL}) + return result, nil + }) return result } diff --git a/auth_test.go b/auth_test.go index b74e073..411480c 100644 --- a/auth_test.go +++ b/auth_test.go @@ -8,6 +8,7 @@ import ( "sync" "sync/atomic" "testing" + "time" ) // A single mock server is started lazily and serves every auth test. Each test @@ -215,6 +216,49 @@ func TestCheckAuth_CachesResult(t *testing.T) { } } +func TestCheckAuth_SingleflightCollapsesConcurrentCalls(t *testing.T) { + // Same shape as TestLookupFile_SingleflightCollapsesConcurrentCalls: + // 50 goroutines race on a cold-cache (auth, project, projectId) + // triple. sf.Do must collapse them into a single /me.authorized + // round-trip — otherwise a thundering herd of uploads from one + // caller (e.g. 
parallel CI jobs holding the same bearer) hammers + // the deploys.app API on every cache-miss edge. + t.Parallel() + var calls atomic.Int64 + token := "Bearer " + t.Name() + registerAuthMock(t, token, func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + // A short sleep widens the singleflight window so the test + // reliably catches a regression where dedupe is broken. + time.Sleep(50 * time.Millisecond) + jsonAuthMock(true, true)(w, r) + }) + + const N = 50 + results := make([]AuthResult, N) + var wg sync.WaitGroup + start := make(chan struct{}) + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + results[idx] = checkAuth(context.Background(), token, "sfproject", "") + }(i) + } + close(start) + wg.Wait() + + for i, r := range results { + if !r.Authorized { + t.Errorf("results[%d] not authorized", i) + } + } + if got := calls.Load(); got >= N { + t.Errorf("auth API called %d times, want <%d (sf should have collapsed the herd)", got, N) + } +} + func TestCheckAuth_CacheKeyDistinguishesTokens(t *testing.T) { t.Parallel() var calls atomic.Int64 diff --git a/files.go b/files.go index 3a1bc2b..1fd0a2b 100644 --- a/files.go +++ b/files.go @@ -4,26 +4,91 @@ import ( "context" "database/sql" "errors" + "fmt" "io" "log/slog" "net/http" "strconv" + "sync" + "sync/atomic" "time" "github.com/acoshift/pgsql/pgctx" "github.com/moonrhythm/cachestore" + "github.com/moonrhythm/sf" "gocloud.dev/blob" "gocloud.dev/gcerrors" ) -const fileProjectCacheTTL = 60 * time.Second +const fileMetaCacheTTL = 60 * time.Second + +// fileMeta is the file's row in the files table plus a couple of derived +// negative-cache flags. Found is false when there is no row for the fn — +// that happens when the metadata insert failed during upload; we still +// serve the bucket bytes in that case rather than 404'ing, since we have +// no expiration to enforce against. 
BucketMissing is set by the file +// handlers (not by lookupFile) after a confirmed Bucket.Attributes +// NotFound, so subsequent requests within the cache TTL can return 404 +// without re-hitting GCS — that's what stops a DDoS against a dead-but- +// valid-format fn from amplifying into a GCS read flood. +type fileMeta struct { + ProjectID string + ExpiresAt time.Time + Found bool + BucketMissing bool +} + +// Expired reports whether the file has a recorded expiry that is now in +// the past. Files without a DB row (Found == false) and files with a NULL +// expires_at are treated as non-expired — we have no basis to refuse them. +func (m fileMeta) Expired() bool { + return m.Found && !m.ExpiresAt.IsZero() && m.ExpiresAt.Before(time.Now()) +} + +func fileCacheKey(fn string) string { return "fn|" + fn } func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { - fn := r.PathValue("fn") + token := r.PathValue("token") + + // Verify the HMAC tag in the token before any I/O. A flood of random + // /files/{garbage} requests by an attacker who doesn't have SignKey + // is rejected in CPU here — no DB query, no GCS Attributes call. + // This is the primary DDoS shield; the per-fn cache and the + // negative bucket cache only kick in once we've established the + // token is one we issued. + fn, ok := parseToken(a.SignKey, token) + if !ok { + http.NotFound(w, r) + return + } + + // Check the DB *before* touching GCS so a DDoS against an expired fn + // is absorbed by the in-process cache (60s TTL on an immutable row) + // instead of becoming one billed Bucket.Attributes call per request. + // For non-expired files this just reorders two calls we'd make anyway; + // for files with no DB row (insert-failed uploads) we fall through to + // the bucket exactly as before. 
+ meta := lookupFile(r.Context(), fn) + if meta.Expired() { + http.Error(w, "file expired", http.StatusGone) + return + } + if meta.BucketMissing { + // Cached negative from a previous bucket NotFound — see the + // Attributes branch below. + http.NotFound(w, r) + return + } attrs, err := a.Bucket.Attributes(r.Context(), fn) if err != nil { if gcerrors.Code(err) == gcerrors.NotFound { + // Cache the negative so a flood against the same dead fn + // doesn't re-bill a Class B op per request. Safe even in the + // rare orphan case (Found=true, bucket gone): the file is + // unreachable either way until an operator cleans it up. + meta.BucketMissing = true + cachestore.Set(fileCacheKey(fn), meta, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) http.NotFound(w, r) return } @@ -31,33 +96,76 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { return } - projectID := lookupFileProject(r.Context(), fn) - // When a CDN is configured and the caller is not in-cluster, record the // download as if we were serving the full file (the CDN will handle the - // actual bytes) and redirect. Cache hits never come back to this origin, - // so the bill is an over-approximation; that matches the registry - // pattern and is what /_internal/calculate-dropbox-usages assumes. + // actual bytes) and redirect. The redirect target preserves the full + // signed token so the CDN's origin-fetch hits cdnFileHandler with a + // URL we can re-verify. Cache hits never come back to this origin, so + // the bill is an over-approximation; that matches the registry pattern + // and is what /_internal/calculate-dropbox-usages assumes. 
if a.CDNBaseURL != "" && !isInternalClient(r) { - downloadCount.WithLabelValues(projectID).Inc() - egressBytes.WithLabelValues(projectID).Add(float64(attrs.Size)) - http.Redirect(w, r, a.CDNBaseURL+fn, http.StatusTemporaryRedirect) + downloadCount.WithLabelValues(meta.ProjectID).Inc() + egressBytes.WithLabelValues(meta.ProjectID).Add(float64(attrs.Size)) + http.Redirect(w, r, a.CDNBaseURL+token, http.StatusTemporaryRedirect) return } - a.streamFile(w, r, fn, attrs, projectID) + a.streamFile(w, r, fn, attrs, meta.ProjectID) } -// cdnFileHandler is the origin endpoint the CDN edge fetches from. It is -// unauthenticated — file URLs are 86-char crypto-random and only reachable -// if you know them — and skips the redirect/metrics so the CDN sees a -// plain streaming response. +// cdnDeadResponseTTL is how long the CDN edge should cache "this URL is +// dead" responses (410 expired, 404 bucket-missing). Long enough that a +// shared-then-expired URL stops hammering origin, short enough that the +// answer can update if we ever change behavior. +const cdnDeadResponseTTL = time.Hour + +// cdnFileHandler is the origin endpoint the CDN edge fetches from. The +// token in the URL is the same signed token we issued to the user, so +// the HMAC check still gates the edge — only URLs we issued can reach +// the bucket. Streams the body without metrics so the CDN sees a plain +// response; expired files are refused here too so the CDN can't refresh +// its cache with bytes the user is no longer entitled to. +// +// Each response sets Cache-Control to tell the edge how long to cache: +// - success (200): public, max-age={remaining TTL}, immutable — fn is +// unique per upload so the body never changes; cap at the file's +// expires_at so the edge stops serving past end-of-life. +// - expired (410) and bucket-missing (404): public, max-age=3600 — +// the answer is permanent, so let the edge absorb repeat probes. +// - invalid token (404): no Cache-Control. 
Every garbage token is a +// unique URL, so caching doesn't reduce origin load and would just +// burn edge cache slots on attacker traffic. func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { - fn := r.PathValue("fn") + token := r.PathValue("token") + + // Same DDoS-protection ladder as fileHandler: verify the signature, + // then consult the per-fn cache, then GCS. The edge would otherwise + // amplify a random-garbage flood or an expired-URL probe into one + // GCS op per request. + fn, ok := parseToken(a.SignKey, token) + if !ok { + http.NotFound(w, r) + return + } + + meta := lookupFile(r.Context(), fn) + if meta.Expired() { + setDeadCacheControl(w) + http.Error(w, "file expired", http.StatusGone) + return + } + if meta.BucketMissing { + setDeadCacheControl(w) + http.NotFound(w, r) + return + } attrs, err := a.Bucket.Attributes(r.Context(), fn) if err != nil { if gcerrors.Code(err) == gcerrors.NotFound { + meta.BucketMissing = true + cachestore.Set(fileCacheKey(fn), meta, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) + setDeadCacheControl(w) http.NotFound(w, r) return } @@ -65,9 +173,24 @@ func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { return } + // Pre-set Cache-Control so streamFile's attrs.CacheControl fallback + // doesn't overwrite it. If we have no expires_at to work from + // (insert-failed upload), fall through to the bucket's value. + if !meta.ExpiresAt.IsZero() { + remaining := time.Until(meta.ExpiresAt) + if remaining < 0 { + remaining = 0 + } + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d, immutable", int(remaining.Seconds()))) + } + a.streamFile(w, r, fn, attrs, "") } +func setDeadCacheControl(w http.ResponseWriter) { + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(cdnDeadResponseTTL.Seconds()))) +} + // streamFile copies the object body to w with the cached headers. 
When // projectID is non-empty, download_count and egress_bytes are bumped by // the actual transferred bytes. Used by both the direct-stream path @@ -84,7 +207,10 @@ func (a *App) streamFile(w http.ResponseWriter, r *http.Request, fn string, attr } defer reader.Close() - if attrs.CacheControl != "" { + // Only fall back to the bucket's CacheControl if the caller hasn't + // already chosen one. cdnFileHandler pre-sets a TTL-aligned policy; + // the non-CDN fileHandler leaves it unset and gets the bucket value. + if attrs.CacheControl != "" && w.Header().Get("Cache-Control") == "" { w.Header().Set("Cache-Control", attrs.CacheControl) } if attrs.ContentDisposition != "" { @@ -104,21 +230,77 @@ func (a *App) streamFile(w http.ResponseWriter, r *http.Request, fn string, attr } } -func lookupFileProject(ctx context.Context, fn string) string { - cacheKey := "fn|" + fn - if v, ok := cachestore.Get[string](cacheKey); ok { +// lookupFile fetches the file's project owner and expiration from the DB. +// Both fields are immutable once written, so the in-process cache only +// needs to be short enough to absorb bursts on the same fn — not to +// track any state that can change underneath us. +// +// The DB query runs inside sf.Do so a thundering herd at the cache-miss +// edge (cold start, or the moment a 60s entry expires under load) +// collapses to a single Postgres round-trip. Every concurrent caller +// for the same fn gets the same result without each one issuing its +// own query. 
+func lookupFile(ctx context.Context, fn string) fileMeta { + cacheKey := fileCacheKey(fn) + if v, ok := cachestore.Get[fileMeta](cacheKey); ok { return v } - var projectID string - err := pgctx.QueryRow(ctx, ` - SELECT project_id FROM files WHERE fn = $1 - `, fn).Scan(&projectID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - slog.Error("lookup file project", "fn", fn, "error", err) - return "" - } + m, _, _ := sf.Do(ctx, "lookupFile|"+fn, func(ctx context.Context) (fileMeta, error) { + // Re-check the cache: a sibling caller may have populated it + // while we were queued behind sf's mutex. + if v, ok := cachestore.Get[fileMeta](cacheKey); ok { + return v, nil + } + + recordLookupFileDBQuery(fn) + var ( + m fileMeta + expires sql.NullTime + ) + err := pgctx.QueryRow(ctx, ` + SELECT project_id, expires_at FROM files WHERE fn = $1 + `, fn).Scan(&m.ProjectID, &expires) + switch { + case err == nil: + m.Found = true + if expires.Valid { + m.ExpiresAt = expires.Time + } + case errors.Is(err, sql.ErrNoRows): + // Found=false — a confirmed miss is fine to cache. + default: + // Transient DB error: don't cache-poison. Returning here + // also skips the cache write below; the next caller will + // retry the query. + slog.Error("lookup file", "fn", fn, "error", err) + return fileMeta{}, nil + } + + cachestore.Set(cacheKey, m, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) + return m, nil + }) + return m +} - cachestore.Set(cacheKey, projectID, &cachestore.SetOptions{TTL: fileProjectCacheTTL}) - return projectID +// lookupFileDBQueries counts the number of actual Postgres SELECTs +// issued by lookupFile, keyed by fn. The counter is only bumped inside +// the singleflight closure, so a successful sf.Do collapse leaves it at +// 1 even after N concurrent callers. Production code only writes here; +// only tests read it. Per-fn so parallel tests don't trample each other. 
+var lookupFileDBQueries sync.Map // map[string]*atomic.Uint64 + +func recordLookupFileDBQuery(fn string) { + v, _ := lookupFileDBQueries.LoadOrStore(fn, new(atomic.Uint64)) + v.(*atomic.Uint64).Add(1) +} + +// lookupFileDBQueryCount returns how many DB queries lookupFile has +// issued for fn since process start. Test-only. +func lookupFileDBQueryCount(fn string) uint64 { + v, ok := lookupFileDBQueries.Load(fn) + if !ok { + return 0 + } + return v.(*atomic.Uint64).Load() } diff --git a/files_test.go b/files_test.go index 0efd4d1..47bc518 100644 --- a/files_test.go +++ b/files_test.go @@ -1,9 +1,11 @@ package main import ( + "fmt" "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/acoshift/pgsql/pgctx" @@ -16,7 +18,8 @@ func TestFileHandler_Success(t *testing.T) { bkt := newTestBucket(t) app := newTestApp(bkt, authorized) - bw, err := bkt.NewWriter(t.Context(), "testfile", &blob.WriterOptions{ + fn := validTestFn("testfile") + bw, err := bkt.NewWriter(t.Context(), fn, &blob.WriterOptions{ CacheControl: "public, max-age=86400", ContentDisposition: `attachment; filename="hello.txt"`, }) @@ -30,9 +33,9 @@ func TestFileHandler_Success(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/testfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "testfile") + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -58,9 +61,10 @@ func TestFileHandler_NotFound(t *testing.T) { db := newTestDB(t) app := newTestApp(newTestBucket(t), authorized) - r := httptest.NewRequest(http.MethodGet, "/files/doesnotexist", nil) + fn := validTestFn("notexist") + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "doesnotexist") + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -75,7 +79,8 @@ func 
TestFileHandler_RouteIntegration(t *testing.T) { bkt := newTestBucket(t) app := newTestApp(bkt, authorized) - bw, err := bkt.NewWriter(t.Context(), "routefile", nil) + fn := validTestFn("routefile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -84,7 +89,7 @@ func TestFileHandler_RouteIntegration(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/routefile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -100,7 +105,8 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { bkt := newTestBucket(t) app := newTestApp(bkt, authorized) - bw, err := bkt.NewWriter(t.Context(), "plain", nil) + fn := validTestFn("plain") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -109,9 +115,9 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/plain", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "plain") + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -126,7 +132,7 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { } } -func TestLookupFileProject_FromDB(t *testing.T) { +func TestLookupFile_FromDB(t *testing.T) { t.Parallel() db := newTestDB(t) ctx := db.Ctx() @@ -138,21 +144,277 @@ func TestLookupFileProject_FromDB(t *testing.T) { t.Fatal(err) } - if got := lookupFileProject(ctx, "lookupfile"); got != "proj-xyz" { - t.Errorf("lookupFileProject = %q, want proj-xyz", got) + got := lookupFile(ctx, "lookupfile") + if !got.Found { + t.Fatal("Found = false, want true") + } + if got.ProjectID != "proj-xyz" { + t.Errorf("ProjectID = %q, want proj-xyz", got.ProjectID) + } + if got.Expired() { + t.Errorf("Expired() = true, want false (expires_at 
is in the future)") } // Second call exercises the cache hit path. - if got := lookupFileProject(ctx, "lookupfile"); got != "proj-xyz" { - t.Errorf("lookupFileProject (cached) = %q, want proj-xyz", got) + if got := lookupFile(ctx, "lookupfile"); got.ProjectID != "proj-xyz" { + t.Errorf("ProjectID (cached) = %q, want proj-xyz", got.ProjectID) + } +} + +func TestLookupFile_NotFound(t *testing.T) { + t.Parallel() + db := newTestDB(t) + + got := lookupFile(db.Ctx(), "missingfile") + if got.Found { + t.Errorf("Found = true, want false for missing fn") + } + if got.ProjectID != "" { + t.Errorf("ProjectID = %q, want empty for missing fn", got.ProjectID) + } + if got.Expired() { + t.Errorf("Expired() = true, want false when not found (nothing to enforce)") + } +} + +func TestLookupFile_SingleflightCollapsesConcurrentCalls(t *testing.T) { + // 50 goroutines race on the same cold-cache fn. sf.Do must collapse + // them into a single Postgres SELECT — anything more and the DDoS + // shield against a thundering herd at the cache-miss edge is gone. + t.Parallel() + db := newTestDB(t) + ctx := db.Ctx() + + fn := "sf-thundering-herd-fn" + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-sf', 1, 'x', 1, now() + interval '1 day') + `, fn); err != nil { + t.Fatal(err) + } + + const N = 50 + results := make([]fileMeta, N) + var wg sync.WaitGroup + start := make(chan struct{}) + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start // release all goroutines at once to maximise overlap + results[idx] = lookupFile(ctx, fn) + }(i) + } + close(start) + wg.Wait() + + // Correctness: every caller got the same correct result. + for i, r := range results { + if !r.Found || r.ProjectID != "proj-sf" { + t.Errorf("results[%d] = %+v, want Found=true ProjectID=proj-sf", i, r) + } + } + + // Dedupe: sf collapsed the herd. 
We allow >1 because the per-caller + // cache check before sf.Do can race in a way that lets a few + // stragglers issue their own queries when the first one finishes + // fast and the cache populates before later goroutines reach sf — + // but it must be far below N. + if got := lookupFileDBQueryCount(fn); got >= N { + t.Errorf("DB queries for %s = %d, want <%d (singleflight should have collapsed the herd)", fn, got, N) + } +} + +func TestLookupFile_Expired(t *testing.T) { + t.Parallel() + db := newTestDB(t) + ctx := db.Ctx() + + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredmeta', 'proj-xyz', 1, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + got := lookupFile(ctx, "expiredmeta") + if !got.Found { + t.Fatal("Found = false, want true") + } + if !got.Expired() { + t.Errorf("Expired() = false, want true (expires_at is in the past)") + } +} + +func TestFileHandler_ExpiredReturnsGone(t *testing.T) { + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + + // Object still in the bucket — GC hasn't run yet — but the DB row + // says it expired an hour ago. 
+ fn := validTestFn("expiredfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale bytes")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-xyz', 11, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) + r = r.WithContext(ctx) + r.SetPathValue("token", signedToken(fn)) + w := httptest.NewRecorder() + app.fileHandler(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410", w.Code) + } + if body := w.Body.String(); strings.Contains(body, "stale bytes") { + t.Errorf("body leaked object content: %q", body) + } +} + +func TestFileHandler_ExpiredSkipsBucket(t *testing.T) { + // Expired DB row, *empty bucket* — proves the expired check runs + // before Bucket.Attributes. Under DDoS this is the difference between + // every request burning a GCS Class B operation and every request + // being a cheap cache hit. If the order regresses, this test starts + // returning 404 (bucket-NotFound) instead of 410. 
+ t.Parallel() + db := newTestDB(t) + app := newTestApp(newTestBucket(t), authorized) + + fn := validTestFn("expiredghost") + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) + r = r.WithContext(ctx) + r.SetPathValue("token", signedToken(fn)) + w := httptest.NewRecorder() + app.fileHandler(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (expired check must run before bucket call)", w.Code) } } -func TestLookupFileProject_NotFound(t *testing.T) { +func TestCDNFileHandler_ExpiredSkipsBucket(t *testing.T) { + // Same DDoS-protection assertion for the /_cdn/{token} origin path. t.Parallel() db := newTestDB(t) + app := newTestApp(newTestBucket(t), authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("expiredghostcdn") + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (expired check must run before bucket call)", w.Code) + } +} + +func TestFileHandler_ExpiredOverridesCDNRedirect(t *testing.T) { + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("expiredcdn") + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err
:= pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-xyz', 5, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) + r = r.WithContext(ctx) + r.SetPathValue("token", signedToken(fn)) + w := httptest.NewRecorder() + app.fileHandler(w, r) - if got := lookupFileProject(db.Ctx(), "missingfile"); got != "" { - t.Errorf("lookupFileProject = %q, want empty for missing fn", got) + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (must not redirect to CDN)", w.Code) + } + if loc := w.Header().Get("Location"); loc != "" { + t.Errorf("Location = %q, want empty (no redirect for expired files)", loc) + } +} + +func TestCDNFileHandler_ExpiredReturnsGone(t *testing.T) { + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("expiredorigin") + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale origin")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-xyz', 12, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (CDN origin must refuse expired files)", w.Code) + } + if body := w.Body.String(); strings.Contains(body, "stale origin") { + t.Errorf("body leaked object content to CDN edge: %q", body) } } @@ -163,7 +425,8 @@ func TestFileHandler_CDNRedirect(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" 
- bw, err := bkt.NewWriter(t.Context(), "cdnfile", nil) + fn := validTestFn("cdnfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -172,22 +435,24 @@ func TestFileHandler_CDNRedirect(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/cdnfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "cdnfile") + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) if w.Code != http.StatusTemporaryRedirect { t.Fatalf("status = %d, want 307", w.Code) } - if loc := w.Header().Get("Location"); loc != "https://cdn.example.com/cdnfile" { - t.Errorf("Location = %q, want https://cdn.example.com/cdnfile", loc) + // Redirect preserves the *full* signed token so the CDN's + // origin-fetch hits cdnFileHandler with a URL we can re-verify. + if want := "https://cdn.example.com/" + signedToken(fn); w.Header().Get("Location") != want { + t.Errorf("Location = %q, want %q", w.Header().Get("Location"), want) } - if got := w.Body.Len(); got > 100 { - // the default redirect body is small (""); we shouldn't - // be streaming the object. - t.Errorf("body length = %d, expected a short redirect body", got) + // The default http.Redirect body is `Temporary Redirect` — + // it must not contain the object bytes. 
+ if got := w.Body.String(); strings.Contains(got, "hello world") { + t.Errorf("body streamed the object instead of redirecting: %q", got) } } @@ -198,7 +463,8 @@ func TestFileHandler_CDNInternalClientStreams(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "internalfile", nil) + fn := validTestFn("internalfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -207,9 +473,9 @@ func TestFileHandler_CDNInternalClientStreams(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/internalfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "internalfile") + r.SetPathValue("token", signedToken(fn)) r.Header.Set("X-Real-Ip", "10.0.0.5") // private IP -> internal w := httptest.NewRecorder() app.fileHandler(w, r) @@ -229,7 +495,8 @@ func TestFileHandler_CDNRedirectPublicXRealIP(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "publicfile", nil) + fn := validTestFn("publicfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -238,9 +505,9 @@ func TestFileHandler_CDNRedirectPublicXRealIP(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/publicfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "publicfile") + r.SetPathValue("token", signedToken(fn)) r.Header.Set("X-Real-Ip", "203.0.113.5") // public IP -> CDN path w := httptest.NewRecorder() app.fileHandler(w, r) @@ -257,7 +524,8 @@ func TestCDNFileHandler_Streams(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "origin", &blob.WriterOptions{ + fn := validTestFn("origin") + 
bw, err := bkt.NewWriter(t.Context(), fn, &blob.WriterOptions{ CacheControl: "public, max-age=86400", }) if err != nil { @@ -268,7 +536,7 @@ func TestCDNFileHandler_Streams(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/origin", nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -290,7 +558,8 @@ func TestCDNFileHandler_NotFound(t *testing.T) { app := newTestApp(newTestBucket(t), authorized) app.CDNBaseURL = "https://cdn.example.com/" - r := httptest.NewRequest(http.MethodGet, "/_cdn/nope", nil) + fn := validTestFn("cdnnope") + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -298,6 +567,255 @@ func TestCDNFileHandler_NotFound(t *testing.T) { if w.Code != http.StatusNotFound { t.Errorf("status = %d, want 404", w.Code) } + // Bucket NotFound for a signed token → cache the negative at the + // edge so the same dead URL doesn't keep hitting origin. + if cc := w.Header().Get("Cache-Control"); cc != "public, max-age=3600" { + t.Errorf("Cache-Control = %q, want public, max-age=3600", cc) + } +} + +func TestCDNFileHandler_SuccessCacheControlUsesRemainingTTL(t *testing.T) { + // The CDN response must cap max-age at the file's remaining TTL so + // the edge stops serving past expires_at. fn is unique per upload + // so we mark the body immutable. 
+ t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("cdncc") + bw, err := bkt.NewWriter(t.Context(), fn, &blob.WriterOptions{ + CacheControl: "public, max-age=86400", // bucket default — must be overridden + }) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("data")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-cc', 4, 'x', 7, now() + interval '7 days') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + cc := w.Header().Get("Cache-Control") + if !strings.HasPrefix(cc, "public, max-age=") || !strings.HasSuffix(cc, ", immutable") { + t.Fatalf("Cache-Control = %q, want shape 'public, max-age=N, immutable'", cc) + } + // Extract and sanity-check the max-age — must reflect ~7 days, not + // the bucket default of 86400. + var maxAge int + if _, err := fmt.Sscanf(cc, "public, max-age=%d, immutable", &maxAge); err != nil { + t.Fatalf("parse max-age from %q: %v", cc, err) + } + const sevenDays = 7 * 24 * 3600 + if maxAge < sevenDays-3600 || maxAge > sevenDays { + t.Errorf("max-age = %d, want roughly 7 days (%d ± 1h)", maxAge, sevenDays) + } +} + +func TestCDNFileHandler_ExpiredCacheControl(t *testing.T) { + // 410 response must carry a short Cache-Control so the edge stops + // asking origin for an expired URL that's still in circulation. 
+ t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("cdnccexp") + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-cc', 5, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410", w.Code) + } + if cc := w.Header().Get("Cache-Control"); cc != "public, max-age=3600" { + t.Errorf("Cache-Control = %q, want public, max-age=3600", cc) + } +} + +func TestCDNFileHandler_InvalidTokenNoCacheControl(t *testing.T) { + // Forged or garbage tokens must NOT get a Cache-Control: each + // token is a unique URL so caching wouldn't reduce origin load and + // would just fill edge slots on attacker traffic. + t.Parallel() + app := newTestApp(newTestBucket(t), authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + r := httptest.NewRequest(http.MethodGet, "/_cdn/garbage", nil) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", w.Code) + } + if cc := w.Header().Get("Cache-Control"); cc != "" { + t.Errorf("Cache-Control = %q, want empty (attacker traffic shouldn't be cached)", cc) + } +} + +func TestFileHandler_InvalidTokenRejected(t *testing.T) { + // Tokens that don't HMAC-verify must 404 from parseToken without + // touching DB or GCS. The test app has no pgctx attached, so if the + // signature shortcut ever regresses, lookupFile will panic and the + // test blows up loudly, which is the right failure mode for the + // primary DDoS shield going away.
+ t.Parallel() + app := newTestApp(newTestBucket(t), authorized) + + goodFn := validTestFn("good") + goodToken := signedToken(goodFn) + tamperedSig := goodToken[:len(goodToken)-1] + "0" + if tamperedSig == goodToken { + tamperedSig = goodToken[:len(goodToken)-1] + "1" + } + forgedDiffKey := makeToken([]byte("not-the-real-key"), goodFn) + + cases := map[string]string{ + "empty": "", + "short": "short", + "path-traversal": "../../etc/passwd", + "no-separator": strings.Repeat("a", fnLen+sigLen), + "empty-fn": tokenSep + signFilename(testSignKey, ""), + "empty-sig": goodFn + tokenSep, + "tampered-sig": tamperedSig, + "forged-other-key": forgedDiffKey, + } + for name, token := range cases { + t.Run(name, func(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/files/x", nil) + r.SetPathValue("token", token) + w := httptest.NewRecorder() + app.fileHandler(w, r) + if w.Code != http.StatusNotFound { + t.Errorf("token=%q: status = %d, want 404", token, w.Code) + } + }) + } +} + +func TestCDNFileHandler_InvalidTokenRejected(t *testing.T) { + t.Parallel() + app := newTestApp(newTestBucket(t), authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + r := httptest.NewRequest(http.MethodGet, "/_cdn/short", nil) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } +} + +func TestFileHandler_BucketMissNegativeCached(t *testing.T) { + // First request misses everywhere → 404 and caches BucketMissing. + // Then we write the object out-of-band into the bucket. The second + // request must still 404, proving Bucket.Attributes was not called + // — that's the DDoS protection for a flood against a known-bad fn + // (e.g. a URL the attacker held after expiry+GC). 
+ t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + + fn := validTestFn("missneg") + + r1 := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) + r1 = r1.WithContext(db.Ctx()) + r1.SetPathValue("token", signedToken(fn)) + w1 := httptest.NewRecorder() + app.fileHandler(w1, r1) + if w1.Code != http.StatusNotFound { + t.Fatalf("first request: status = %d, want 404", w1.Code) + } + + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("late arrival")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + r2 := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) + r2 = r2.WithContext(db.Ctx()) + r2.SetPathValue("token", signedToken(fn)) + w2 := httptest.NewRecorder() + app.fileHandler(w2, r2) + if w2.Code != http.StatusNotFound { + t.Fatalf("second request: status = %d, want 404 (negative cache must mask the out-of-band write)", w2.Code) + } + if strings.Contains(w2.Body.String(), "late arrival") { + t.Errorf("body leaked out-of-band object: %q", w2.Body.String()) + } +} + +func TestCDNFileHandler_BucketMissNegativeCached(t *testing.T) { + // Same assertion for /_cdn/{token}. The CDN origin needs the same + // shield since the edge can hammer it on cache misses.
+ t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("missnegcdn") + + r1 := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r1 = r1.WithContext(db.Ctx()) + w1 := httptest.NewRecorder() + app.routes().ServeHTTP(w1, r1) + if w1.Code != http.StatusNotFound { + t.Fatalf("first request: status = %d, want 404", w1.Code) + } + + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("late arrival")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + r2 := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r2 = r2.WithContext(db.Ctx()) + w2 := httptest.NewRecorder() + app.routes().ServeHTTP(w2, r2) + if w2.Code != http.StatusNotFound { + t.Fatalf("second request: status = %d, want 404 (negative cache must mask the out-of-band write)", w2.Code) + } } func TestIsInternalClient(t *testing.T) { diff --git a/go.mod b/go.mod index 7e5185e..e0585c7 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/lib/pq v1.12.3 github.com/moonrhythm/cachestore v0.0.0-20241226112208-fd22884c5b60 github.com/moonrhythm/parapet v0.13.6 + github.com/moonrhythm/sf v1.0.2 github.com/prometheus/client_golang v1.20.5 gocloud.dev v0.45.0 ) diff --git a/go.sum b/go.sum index bef8554..4431b8f 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/moonrhythm/parapet v0.13.6 h1:xpoXIEI1JiNVTs8F2edmZn75gMV2Dg6KQdpOlUt github.com/moonrhythm/parapet v0.13.6/go.mod h1:QEhl1pk8viGxC3xz1XLmzU8UZRT5qxHXgUln5hXlRCQ= github.com/moonrhythm/pq v0.0.0-20230504040008-09b0644d6569 h1:VDLLIg19HS2A7HuMY8ukFNrdgrRqGn5O0xQ4zrSgD/0= github.com/moonrhythm/pq v0.0.0-20230504040008-09b0644d6569/go.mod h1:V43Namo2SgpLTmiaQXjiKPuKvqGamIhPawsQF+pv8xw= +github.com/moonrhythm/sf v1.0.2 h1:W7uHinFjiO1A+7UNDANka78J9jk1nQz+W7IzK0bpj68= +github.com/moonrhythm/sf v1.0.2/go.mod 
h1:GVFIuH7VwfIn+wVi+kcjxAUcFsQ2JmiUrQuIg5yuono= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= diff --git a/handler.go b/handler.go index 9d6c48a..eb509a5 100644 --- a/handler.go +++ b/handler.go @@ -2,8 +2,10 @@ package main import ( "context" + "crypto/hmac" "crypto/rand" - "encoding/base64" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -23,6 +25,7 @@ type App struct { BaseURL string CDNBaseURL string InternalSecret string + SignKey []byte checkAuth func(ctx context.Context, auth, project, projectID string) AuthResult } @@ -32,8 +35,8 @@ func (a *App) routes() http.Handler { w.Write([]byte("Deploys.app Dropbox Service")) }) mux.HandleFunc("POST /{$}", a.uploadHandler) - mux.HandleFunc("GET /files/{fn}", a.fileHandler) - mux.HandleFunc("GET /_cdn/{fn}", a.cdnFileHandler) + mux.HandleFunc("GET /files/{token}", a.fileHandler) + mux.HandleFunc("GET /_cdn/{token}", a.cdnFileHandler) mux.HandleFunc("POST /internal/gc", a.gcHandler) return mux } @@ -112,16 +115,95 @@ func (a *App) uploadHandler(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{ "ok": true, "result": map[string]any{ - "downloadUrl": a.BaseURL + fn, + "downloadUrl": a.BaseURL + makeToken(a.SignKey, fn), "expiresAt": expiresAt.Format(time.RFC3339), }, }) } +// Filename + signature scheme. +// +// `fn` is what we store in the bucket and the DB — a random 24-char +// alphanumeric string (~143 bits of entropy, drawn from +// [0-9A-Za-z] with rejection sampling). No special chars, so URLs are +// clean and we don't need URL-safe-base64 tricks. 
+// +// `token` is what appears in the public URL: `fn` + "-" + 20-char +// lowercase-hex HMAC-SHA256 tag (80-bit) keyed by App.SignKey. The +// handlers check the tag before any DB or GCS work, so a flood of +// random `/files/{token}` requests by an attacker who doesn't know +// the key gets 404'd in CPU. The "-" separator means we can change +// `fnLen` later without invalidating old tokens — parseToken splits +// on it instead of relying on a fixed cut point. fn is alphanumeric +// and sig is hex, so neither side can contain a "-". +const ( + fnLen = 24 + sigLen = 20 // 80 bits of HMAC, hex-encoded + tokenSep = "-" +) + +const fnAlphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + +// generateFilename returns a fresh fnLen-char random alphanumeric fn. +// Uses rejection sampling so each char is uniformly drawn from +// fnAlphabet (avoiding the mod-bias you'd get from `b%62`). func generateFilename() string { - b := make([]byte, 64) - rand.Read(b) - return base64.RawURLEncoding.EncodeToString(b) + out := make([]byte, fnLen) + buf := make([]byte, fnLen*2) // slack for rejection + pos := 0 + for pos < fnLen { + if _, err := rand.Read(buf); err != nil { + panic(err) // crypto/rand failing is unrecoverable + } + for _, b := range buf { + if pos >= fnLen { + break + } + // 62 * 4 = 248 is the largest unbiased range under 256; + // reject 248..255. + if b < 248 { + out[pos] = fnAlphabet[b%62] + pos++ + } + } + } + return string(out) +} + +// signFilename returns the sigLen-char lowercase-hex HMAC tag for fn +// under key. Caller's responsibility to keep `key` secret. +func signFilename(key []byte, fn string) string { + mac := hmac.New(sha256.New, key) + mac.Write([]byte(fn)) + sum := mac.Sum(nil) + return hex.EncodeToString(sum[:sigLen/2]) +} + +// makeToken builds the public URL token for fn: `fn + "-" + sig`. 
+func makeToken(key []byte, fn string) string { + return fn + tokenSep + signFilename(key, fn) +} + +// parseToken splits a public URL token on tokenSep, verifies the HMAC +// in constant time, and returns the fn on success. On any failure — +// missing separator, empty side, bad sig, anything — it returns +// (_, false) and the handler 404s without touching the DB or GCS. +// This is the primary DDoS shield against random `/files/{garbage}` +// floods. The split is structural (on the separator) rather than +// positional, so changing fnLen later won't invalidate old tokens +// still in circulation. +func parseToken(key []byte, token string) (string, bool) { + idx := strings.IndexByte(token, tokenSep[0]) + if idx <= 0 || idx == len(token)-1 { + return "", false + } + fn := token[:idx] + sig := token[idx+1:] + expected := signFilename(key, fn) + if !hmac.Equal([]byte(sig), []byte(expected)) { + return "", false + } + return fn, true } func escapeFilename(s string) string { diff --git a/handler_test.go b/handler_test.go index f65dab8..9c9f50f 100644 --- a/handler_test.go +++ b/handler_test.go @@ -23,6 +23,42 @@ func newTestBucket(t *testing.T) *blob.Bucket { return bkt } +// validTestFn returns an fnLen-char fn drawn from fnAlphabet, with the +// descriptive suffix preserved so failing tests are still readable. +// Distinct suffixes produce distinct fns, which keeps parallel tests +// isolated in the in-process per-fn cache. +func validTestFn(suffix string) string { + if len(suffix) > fnLen { + panic("validTestFn: suffix too long: " + suffix) + } + return strings.Repeat("a", fnLen-len(suffix)) + suffix +} + +// testSignKey is the fixed HMAC key used in every test app. Tests build +// signed tokens with makeToken(testSignKey, fn) and the handlers verify +// against this same key. +var testSignKey = []byte("test-sign-key-do-not-use-in-prod") + +// signedToken wraps a test fn into the public URL token by HMAC-signing +// it with testSignKey. 
Tests put this in the URL path; the handler +// parses it back into fn for the bucket/DB lookup. +func signedToken(fn string) string { + return makeToken(testSignKey, fn) +} + +// tokenToFn is the inverse: pull the token off a downloadUrl, verify +// it, and return the fn that's actually stored in the bucket and DB. +// Tests use this when they need to inspect the upload's bucket object. +func tokenToFn(t *testing.T, downloadURL string) string { + t.Helper() + token := strings.TrimPrefix(downloadURL, "https://example.com/") + fn, ok := parseToken(testSignKey, token) + if !ok { + t.Fatalf("downloadUrl %q does not contain a valid signed token", downloadURL) + } + return fn +} + func countObjects(t *testing.T, bkt *blob.Bucket) int { t.Helper() iter := bkt.List(nil) @@ -50,6 +86,7 @@ func newTestApp(bkt *blob.Bucket, authFn func(context.Context, string, string, s return &App{ Bucket: bkt, BaseURL: "https://example.com/", + SignKey: testSignKey, checkAuth: authFn, } } @@ -301,7 +338,7 @@ func TestUpload_FilenameFromQuery(t *testing.T) { t.Fatal("expected ok=true") } - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -331,7 +368,7 @@ func TestUpload_FilenameFromHeader(t *testing.T) { t.Fatal("expected ok=true") } - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -360,7 +397,7 @@ func TestUpload_FilenameWithQuotesEscaped(t *testing.T) { t.Fatal("expected ok=true") } - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -385,7 +422,7 @@ func TestUpload_NoFilenameNoContentDisposition(t *testing.T) { var resp uploadResp 
json.NewDecoder(w.Body).Decode(&resp) - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -452,12 +489,143 @@ func TestGenerateFilename(t *testing.T) { if a == b { t.Error("expected unique filenames") } - if len(a) != 86 { - t.Errorf("filename length = %d, want 86", len(a)) + if len(a) != fnLen { + t.Errorf("filename length = %d, want %d", len(a), fnLen) } + // Must be strictly alphanumeric — no special chars at all. for _, c := range a { - if c == '+' || c == '/' || c == '=' { - t.Errorf("filename contains non-URL-safe char %q: %s", c, a) + switch { + case c >= '0' && c <= '9': + case c >= 'A' && c <= 'Z': + case c >= 'a' && c <= 'z': + default: + t.Errorf("filename contains non-alphanumeric char %q: %s", c, a) + } + } +} + +func TestGenerateFilename_DistributionLooksUniform(t *testing.T) { + t.Parallel() + // Rough check that rejection sampling doesn't bias the alphabet. + // With 1000 * fnLen = 24000 chars over 62 symbols, expected count + // per symbol is ~387. We just assert every symbol appears at least + // once, which would fail catastrophically if a whole half of the + // alphabet got dropped (e.g. by an off-by-one in the reject bound). 
+ counts := map[byte]int{} + for i := 0; i < 1000; i++ { + fn := generateFilename() + for j := 0; j < len(fn); j++ { + counts[fn[j]]++ + } + } + for _, c := range []byte(fnAlphabet) { + if counts[c] == 0 { + t.Errorf("symbol %q never appeared in 1000 generated fns", c) + } + } +} + +func TestSignFilename_Deterministic(t *testing.T) { + t.Parallel() + fn := validTestFn("sigtest") + a := signFilename(testSignKey, fn) + b := signFilename(testSignKey, fn) + if a != b { + t.Errorf("signFilename not deterministic: %q vs %q", a, b) + } + if len(a) != sigLen { + t.Errorf("sig length = %d, want %d", len(a), sigLen) + } +} + +func TestSignFilename_KeyMatters(t *testing.T) { + t.Parallel() + fn := validTestFn("keytest") + a := signFilename(testSignKey, fn) + b := signFilename([]byte("different-key"), fn) + if a == b { + t.Errorf("sig should differ across keys; got %q for both", a) + } +} + +func TestParseToken_RoundTripsGeneratedToken(t *testing.T) { + t.Parallel() + for i := 0; i < 16; i++ { + fn := generateFilename() + token := makeToken(testSignKey, fn) + if want := fnLen + 1 + sigLen; len(token) != want { + t.Fatalf("token length = %d, want %d", len(token), want) + } + if !strings.Contains(token, tokenSep) { + t.Fatalf("token %q missing separator %q", token, tokenSep) + } + got, ok := parseToken(testSignKey, token) + if !ok || got != fn { + t.Errorf("parseToken(%q) = (%q, %v), want (%q, true)", token, got, ok, fn) + } + } +} + +func TestParseToken_AcceptsDifferentFnLengths(t *testing.T) { + // The whole point of the separator: parseToken must work for fns of + // any length, as long as the sig was produced by signFilename. This + // guards future changes to fnLen without breaking outstanding URLs.
+ t.Parallel() + for _, fn := range []string{ + "a", + "abc", + validTestFn("short"), + validTestFn("default"), + strings.Repeat("z", 64), // hypothetically larger fn + } { + token := makeToken(testSignKey, fn) + got, ok := parseToken(testSignKey, token) + if !ok || got != fn { + t.Errorf("parseToken round-trip failed for fn=%q: got=(%q, %v)", fn, got, ok) + } + } +} + +func TestParseToken_RejectsForgeries(t *testing.T) { + t.Parallel() + fn := validTestFn("forgery") + good := makeToken(testSignKey, fn) + + // Wrong key — handler must not trust whatever the attacker sends. + if _, ok := parseToken([]byte("wrong-key"), good); ok { + t.Error("parseToken accepted token signed under a different key") + } + + // Tamper with the sig portion (last char). + tampered := good[:len(good)-1] + "0" + if tampered == good { + tampered = good[:len(good)-1] + "1" + } + if _, ok := parseToken(testSignKey, tampered); ok { + t.Errorf("parseToken accepted tampered sig: %q", tampered) + } + + // Tamper with the fn portion (the sig no longer matches what we'd + // compute over the rewritten fn). + swapFn := "b" + good[1:] + if swapFn == good { + swapFn = "c" + good[1:] + } + if _, ok := parseToken(testSignKey, swapFn); ok { + t.Errorf("parseToken accepted token with rewritten fn: %q", swapFn) + } + + // Structural failures: missing separator, empty fn/sig, etc. 
+ for _, bad := range []string{ + "", + "short", + "no-separator-but-also-clearly-not-a-real-token", + tokenSep + signFilename(testSignKey, ""), // empty fn + fn + tokenSep, // empty sig + fn, // sig missing entirely + strings.Repeat("a", fnLen+1+sigLen), // right total length, no separator at all + } { + if _, ok := parseToken(testSignKey, bad); ok { + t.Errorf("parseToken accepted bad token %q", bad) } } } diff --git a/main.go b/main.go index 7988da1..44f6db4 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,10 @@ func main() { BaseURL: config.StringDefault("base_url", "https://dropbox.deploys.app/files/"), CDNBaseURL: config.String("cdn_base_url"), InternalSecret: config.String("internal_secret"), + // sign_key is the HMAC key used to sign download tokens. + // Required — rotating it invalidates all outstanding URLs (which + // is the right behavior if a key ever leaks). + SignKey: []byte(config.MustString("sign_key")), } promAddr := config.StringDefault("prom_addr", ":9187")
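The token scheme exercised by the tests above can be tried in isolation. The following is a minimal standalone sketch, not the code from `handler.go`: it reuses the names `signFilename`, `makeToken`, `parseToken` and `sigLen` from the diff for readability, but it splits with `strings.Cut` instead of `strings.IndexByte`, and the key and fn values in `main` are made-up placeholders.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// sigLen is the number of hex chars kept from the HMAC-SHA256 tag
// (20 hex chars = 80 bits), matching the constant in the diff.
const sigLen = 20

// signFilename returns the truncated, hex-encoded HMAC-SHA256 of fn.
func signFilename(key []byte, fn string) string {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(fn))
	return hex.EncodeToString(mac.Sum(nil)[:sigLen/2])
}

// makeToken builds fn + "-" + sig.
func makeToken(key []byte, fn string) string {
	return fn + "-" + signFilename(key, fn)
}

// parseToken splits on the first "-" and verifies the tag in constant
// time. Assumption: strings.Cut stands in for the IndexByte split used
// in the diff; both cut at the first separator.
func parseToken(key []byte, token string) (string, bool) {
	fn, sig, found := strings.Cut(token, "-")
	if !found || fn == "" || sig == "" {
		return "", false
	}
	if !hmac.Equal([]byte(sig), []byte(signFilename(key, fn))) {
		return "", false
	}
	return fn, true
}

func main() {
	key := []byte("example-key-not-for-production") // placeholder key
	token := makeToken(key, "abc123")               // placeholder fn

	fn, ok := parseToken(key, token)
	fmt.Println(fn, ok) // prints "abc123 true"

	_, ok = parseToken([]byte("some-other-key"), token)
	fmt.Println(ok) // prints "false"
}
```

Because the split is on the separator and not on a fixed offset, the same `parseToken` accepts an fn of any length, which is the property `TestParseToken_AcceptsDifferentFnLengths` checks.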