From 6f6be15387b3e8d03e8ad7ede09a5385a14e9fb6 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 09:32:02 +0700 Subject: [PATCH 1/7] Refuse expired files with 410 instead of redirecting to the CDN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this change, GET /files/{fn} for a file whose expires_at had passed would still redirect to the CDN (or stream directly) until the next GC run dropped the object — billing the project for the egress and, worse, potentially priming the CDN cache with bytes the user is no longer entitled to. Extend the per-fn DB lookup to read expires_at alongside project_id (both columns are immutable once written, so the existing 60s cache TTL is still safe) and refuse expired files with 410 Gone in both fileHandler and cdnFileHandler. The refusal runs *before* the download/egress metric increments and before the 307, so expired requests don't bill the project and the CDN never sees a redirect that could refresh its cache. cdnFileHandler gets the same check so an attacker who knows the crypto-random fn can't bypass via the unauthenticated /_cdn/{fn} origin endpoint. Files in the bucket without a DB row (insert-failed uploads) and files with a NULL expires_at are treated as non-expired — we have no basis to refuse them. Co-Authored-By: Claude Opus 4.7 --- files.go | 79 +++++++++++++++++++----- files_test.go | 165 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 220 insertions(+), 24 deletions(-) diff --git a/files.go b/files.go index 3a1bc2b..9a46e4e 100644 --- a/files.go +++ b/files.go @@ -16,7 +16,24 @@ import ( "gocloud.dev/gcerrors" ) -const fileProjectCacheTTL = 60 * time.Second +const fileMetaCacheTTL = 60 * time.Second + +// fileMeta is the file's row in the files table. Found is false when there +// is no row for the fn — that happens when the metadata insert failed +// during upload; we still serve the bucket bytes in that case rather than +// 404'ing, since we have no expiration to enforce against. +type fileMeta struct { + ProjectID string + ExpiresAt time.Time + Found bool +} + +// Expired reports whether the file has a recorded expiry that is now in +// the past. Files without a DB row (Found == false) and files with a NULL +// expires_at are treated as non-expired — we have no basis to refuse them. +func (m fileMeta) Expired() bool { + return m.Found && !m.ExpiresAt.IsZero() && m.ExpiresAt.Before(time.Now()) +} func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { fn := r.PathValue("fn") @@ -31,7 +48,15 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { return } - projectID := lookupFileProject(r.Context(), fn) + meta := lookupFile(r.Context(), fn) + if meta.Expired() { + // Refuse before counting metrics or redirecting — the GC will + // catch up and drop the object, but until then we don't want the + // CDN to cache (or the caller to receive) bytes the user is no + // longer entitled to. + http.Error(w, "file expired", http.StatusGone) + return + } // When a CDN is configured and the caller is not in-cluster, record the // download as if we were serving the full file (the CDN will handle the @@ -39,19 +64,20 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { // so the bill is an over-approximation; that matches the registry // pattern and is what /_internal/calculate-dropbox-usages assumes. if a.CDNBaseURL != "" && !isInternalClient(r) { - downloadCount.WithLabelValues(projectID).Inc() - egressBytes.WithLabelValues(projectID).Add(float64(attrs.Size)) + downloadCount.WithLabelValues(meta.ProjectID).Inc() + egressBytes.WithLabelValues(meta.ProjectID).Add(float64(attrs.Size)) http.Redirect(w, r, a.CDNBaseURL+fn, http.StatusTemporaryRedirect) return } - a.streamFile(w, r, fn, attrs, projectID) + a.streamFile(w, r, fn, attrs, meta.ProjectID) } // cdnFileHandler is the origin endpoint the CDN edge fetches from. It is // unauthenticated — file URLs are 86-char crypto-random and only reachable // if you know them — and skips the redirect/metrics so the CDN sees a -// plain streaming response. +// plain streaming response. Expired files are refused here too so the CDN +// can't refresh its cache with bytes the user is no longer entitled to. func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { fn := r.PathValue("fn") @@ -65,6 +91,11 @@ func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { return } + if lookupFile(r.Context(), fn).Expired() { + http.Error(w, "file expired", http.StatusGone) + return + } + a.streamFile(w, r, fn, attrs, "") } @@ -104,21 +135,37 @@ func (a *App) streamFile(w http.ResponseWriter, r *http.Request, fn string, attr } } -func lookupFileProject(ctx context.Context, fn string) string { +// lookupFile fetches the file's project owner and expiration from the DB. +// Both fields are immutable once written, so the in-process cache only +// needs to be short enough to absorb bursts on the same fn — not to track +// any state that can change underneath us. +func lookupFile(ctx context.Context, fn string) fileMeta { cacheKey := "fn|" + fn - if v, ok := cachestore.Get[string](cacheKey); ok { + if v, ok := cachestore.Get[fileMeta](cacheKey); ok { return v } - var projectID string + var ( + m fileMeta + expires sql.NullTime + ) err := pgctx.QueryRow(ctx, ` - SELECT project_id FROM files WHERE fn = $1 - `, fn).Scan(&projectID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - slog.Error("lookup file project", "fn", fn, "error", err) - return "" + SELECT project_id, expires_at FROM files WHERE fn = $1 + `, fn).Scan(&m.ProjectID, &expires) + switch { + case err == nil: + m.Found = true + if expires.Valid { + m.ExpiresAt = expires.Time + } + case errors.Is(err, sql.ErrNoRows): + // Leave m zero-valued (Found=false). Don't cache-poison on + // transient DB errors below, but a confirmed miss is fine. + default: + slog.Error("lookup file", "fn", fn, "error", err) + return fileMeta{} } - cachestore.Set(cacheKey, projectID, &cachestore.SetOptions{TTL: fileProjectCacheTTL}) - return projectID + cachestore.Set(cacheKey, m, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) + return m } diff --git a/files_test.go b/files_test.go index 0efd4d1..5b298c3 100644 --- a/files_test.go +++ b/files_test.go @@ -126,7 +126,7 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { } } -func TestLookupFileProject_FromDB(t *testing.T) { +func TestLookupFile_FromDB(t *testing.T) { t.Parallel() db := newTestDB(t) ctx := db.Ctx() @@ -138,21 +138,170 @@ func TestLookupFileProject_FromDB(t *testing.T) { t.Fatal(err) } - if got := lookupFileProject(ctx, "lookupfile"); got != "proj-xyz" { - t.Errorf("lookupFileProject = %q, want proj-xyz", got) + got := lookupFile(ctx, "lookupfile") + if !got.Found { + t.Fatal("Found = false, want true") + } + if got.ProjectID != "proj-xyz" { + t.Errorf("ProjectID = %q, want proj-xyz", got.ProjectID) + } + if got.Expired() { + t.Errorf("Expired() = true, want false (expires_at is in the future)") } // Second call exercises the cache hit path. - if got := lookupFileProject(ctx, "lookupfile"); got != "proj-xyz" { - t.Errorf("lookupFileProject (cached) = %q, want proj-xyz", got) + if got := lookupFile(ctx, "lookupfile"); got.ProjectID != "proj-xyz" { + t.Errorf("ProjectID (cached) = %q, want proj-xyz", got.ProjectID) + } +} + +func TestLookupFile_NotFound(t *testing.T) { + t.Parallel() + db := newTestDB(t) + + got := lookupFile(db.Ctx(), "missingfile") + if got.Found { + t.Errorf("Found = true, want false for missing fn") + } + if got.ProjectID != "" { + t.Errorf("ProjectID = %q, want empty for missing fn", got.ProjectID) + } + if got.Expired() { + t.Errorf("Expired() = true, want false when not found (nothing to enforce)") + } +} + +func TestLookupFile_Expired(t *testing.T) { + t.Parallel() + db := newTestDB(t) + ctx := db.Ctx() + + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredmeta', 'proj-xyz', 1, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + got := lookupFile(ctx, "expiredmeta") + if !got.Found { + t.Fatal("Found = false, want true") + } + if !got.Expired() { + t.Errorf("Expired() = false, want true (expires_at is in the past)") + } +} + +func TestFileHandler_ExpiredReturnsGone(t *testing.T) { + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + + // Object still in the bucket — GC hasn't run yet — but the DB row + // says it expired an hour ago. + bw, err := bkt.NewWriter(t.Context(), "expiredfile", nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale bytes")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredfile', 'proj-xyz', 11, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/files/expiredfile", nil) + r = r.WithContext(ctx) + r.SetPathValue("fn", "expiredfile") + w := httptest.NewRecorder() + app.fileHandler(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410", w.Code) + } + if body := w.Body.String(); strings.Contains(body, "stale bytes") { + t.Errorf("body leaked object content: %q", body) } } -func TestLookupFileProject_NotFound(t *testing.T) { +func TestFileHandler_ExpiredOverridesCDNRedirect(t *testing.T) { t.Parallel() db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + bw, err := bkt.NewWriter(t.Context(), "expiredcdn", nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredcdn', 'proj-xyz', 5, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/files/expiredcdn", nil) + r = r.WithContext(ctx) + r.SetPathValue("fn", "expiredcdn") + w := httptest.NewRecorder() + app.fileHandler(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (must not redirect to CDN)", w.Code) + } + if loc := w.Header().Get("Location"); loc != "" { + t.Errorf("Location = %q, want empty (no redirect for expired files)", loc) + } +} - if got := lookupFileProject(db.Ctx(), "missingfile"); got != "" { - t.Errorf("lookupFileProject = %q, want empty for missing fn", got) +func TestCDNFileHandler_ExpiredReturnsGone(t *testing.T) { + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + bw, err := bkt.NewWriter(t.Context(), "expiredorigin", nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("stale origin")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredorigin', 'proj-xyz', 12, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/expiredorigin", nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (CDN origin must refuse expired files)", w.Code) + } + if body := w.Body.String(); strings.Contains(body, "stale origin") { + t.Errorf("body leaked object content to CDN edge: %q", body) } } From d5fd48838bd126e3512c3dac8c815fbfe6e1f253 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 09:39:29 +0700 Subject: [PATCH 2/7] Check expired cache before Bucket.Attributes to absorb DDoS in-process MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous ordering — Bucket.Attributes then lookupFile — meant that a flood of requests against a single expired fn cached the DB row fine (60s TTL on an immutable row protects Postgres) but still hit GCS on every request, because the bucket call ran first. That's one billed Class B operation per request and direct exposure to GCS read quotas. Move the lookupFile/Expired check ahead of Bucket.Attributes in both fileHandler and cdnFileHandler. After the first request hydrates the cache, subsequent requests for the same expired fn return 410 entirely from memory and never touch GCS. For non-expired files this just reorders two calls we'd make anyway — no change in latency. For files with no DB row (insert-failed uploads) lookupFile returns Found=false and we fall through to the bucket exactly as before. Added two tests that prove the bucket is bypassed by inserting an expired DB row against an *empty* bucket: a regression would surface as 404 (bucket-NotFound) instead of 410. Co-Authored-By: Claude Opus 4.7 --- files.go | 34 ++++++++++++++++++-------------- files_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/files.go b/files.go index 9a46e4e..77dfa04 100644 --- a/files.go +++ b/files.go @@ -38,6 +38,18 @@ func (m fileMeta) Expired() bool { func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { fn := r.PathValue("fn") + // Check the DB *before* touching GCS so a DDoS against an expired fn + // is absorbed by the in-process cache (60s TTL on an immutable row) + // instead of becoming one billed Bucket.Attributes call per request. + // For non-expired files this just reorders two calls we'd make anyway; + // for files with no DB row (insert-failed uploads) we fall through to + // the bucket exactly as before. + meta := lookupFile(r.Context(), fn) + if meta.Expired() { + http.Error(w, "file expired", http.StatusGone) + return + } + attrs, err := a.Bucket.Attributes(r.Context(), fn) if err != nil { if gcerrors.Code(err) == gcerrors.NotFound { @@ -48,16 +60,6 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { return } - meta := lookupFile(r.Context(), fn) - if meta.Expired() { - // Refuse before counting metrics or redirecting — the GC will - // catch up and drop the object, but until then we don't want the - // CDN to cache (or the caller to receive) bytes the user is no - // longer entitled to. - http.Error(w, "file expired", http.StatusGone) - return - } - // When a CDN is configured and the caller is not in-cluster, record the // download as if we were serving the full file (the CDN will handle the // actual bytes) and redirect. Cache hits never come back to this origin, @@ -81,6 +83,13 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { fn := r.PathValue("fn") + // Same DDoS-protection ordering as fileHandler: cached expired check + // first so the edge can't amplify into GCS by chasing an expired URL. + if lookupFile(r.Context(), fn).Expired() { + http.Error(w, "file expired", http.StatusGone) + return + } + attrs, err := a.Bucket.Attributes(r.Context(), fn) if err != nil { if gcerrors.Code(err) == gcerrors.NotFound { @@ -91,11 +100,6 @@ func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { return } - if lookupFile(r.Context(), fn).Expired() { - http.Error(w, "file expired", http.StatusGone) - return - } - a.streamFile(w, r, fn, attrs, "") } diff --git a/files_test.go b/files_test.go index 5b298c3..1b8c1a3 100644 --- a/files_test.go +++ b/files_test.go @@ -230,6 +230,60 @@ func TestFileHandler_ExpiredReturnsGone(t *testing.T) { } } +func TestFileHandler_ExpiredSkipsBucket(t *testing.T) { + // Expired DB row, *empty bucket* — proves the expired check runs + // before Bucket.Attributes. Under DDoS this is the difference between + // every request burning a GCS Class B operation and every request + // being a cheap cache hit. If the order regresses, this test starts + // returning 404 (bucket-NotFound) instead of 410. + t.Parallel() + db := newTestDB(t) + app := newTestApp(newTestBucket(t), authorized) + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredghost', 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/files/expiredghost", nil) + r = r.WithContext(ctx) + r.SetPathValue("fn", "expiredghost") + w := httptest.NewRecorder() + app.fileHandler(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (expired check must run before bucket call)", w.Code) + } +} + +func TestCDNFileHandler_ExpiredSkipsBucket(t *testing.T) { + // Same DDoS-protection assertion for the /_cdn/{fn} origin path. + t.Parallel() + db := newTestDB(t) + app := newTestApp(newTestBucket(t), authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ('expiredghostcdn', 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') + `); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/expiredghostcdn", nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410 (expired check must run before bucket call)", w.Code) + } +} + func TestFileHandler_ExpiredOverridesCDNRedirect(t *testing.T) { t.Parallel() db := newTestDB(t) From 881cecc5fa7c9176a9f942298931b07d6338c0ff Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 10:02:57 +0700 Subject: [PATCH 3/7] Shield DDoS against non-existent fns with format validation + negative cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two attack vectors the previous code didn't cover: 1) Random-garbage fns. Each request was a unique cache key, so the per-fn DB cache never helped. Every request burned one Postgres query + one GCS Class B op. An attacker can generate millions of unique garbage fns per second per box. 2) Same-valid-but-dead fn (e.g. a URL the attacker held after expiry and GC). The negative DB cache amortized Postgres after request 1, but Bucket.Attributes still ran on every subsequent request. Two-layer fix in both fileHandler and cdnFileHandler: - isValidFilename(fn): generateFilename produces exactly 86 chars of URL-safe base64 (no padding). Anything else can't possibly exist in the system, so reject in CPU before any I/O. Kills vector (1) entirely — a flood of garbage fns becomes pure ServeHTTP overhead. - BucketMissing flag on fileMeta: after a confirmed Bucket.Attributes NotFound for a fn with no DB row, write the negative back into the per-fn cache. Subsequent requests within the 60s TTL return 404 from memory without re-hitting GCS. Fixes vector (2). Also extracted fileCacheKey(fn) so lookupFile and the new cache write share one source of truth. Test changes: - New TestIsValidFilename links the validator to generateFilename so the two can't drift, plus alphabet/length boundary cases. - New TestFileHandler_InvalidFilenameRejected / CDN variant: garbage fns must 404 from the validator without touching DB or GCS. - New TestFileHandler_BucketMissNegativeCached / CDN variant: request a missing fn, then write the object out-of-band into the bucket; a second request must still 404, proving Bucket.Attributes was bypassed by the cached negative. - validTestFn(suffix) helper pads descriptive test names out to 86 chars so existing happy-path tests still pass the validator while remaining readable on failure. - TestFileHandler_CDNRedirect body-size threshold replaced with a direct "body must not contain the object bytes" check, since the default http.Redirect body now embeds an 86-char URL. Co-Authored-By: Claude Opus 4.7 --- files.go | 64 ++++++++++--- files_test.go | 239 ++++++++++++++++++++++++++++++++++++++---------- handler.go | 25 +++++ handler_test.go | 46 ++++++++++ 4 files changed, 314 insertions(+), 60 deletions(-) diff --git a/files.go b/files.go index 77dfa04..4bb2cbf 100644 --- a/files.go +++ b/files.go @@ -18,14 +18,20 @@ import ( const fileMetaCacheTTL = 60 * time.Second -// fileMeta is the file's row in the files table. Found is false when there -// is no row for the fn — that happens when the metadata insert failed -// during upload; we still serve the bucket bytes in that case rather than -// 404'ing, since we have no expiration to enforce against. +// fileMeta is the file's row in the files table plus a couple of derived +// negative-cache flags. Found is false when there is no row for the fn — +// that happens when the metadata insert failed during upload; we still +// serve the bucket bytes in that case rather than 404'ing, since we have +// no expiration to enforce against. BucketMissing is set by the file +// handlers (not by lookupFile) after a confirmed Bucket.Attributes +// NotFound, so subsequent requests within the cache TTL can return 404 +// without re-hitting GCS — that's what stops a DDoS against a dead-but- +// valid-format fn from amplifying into a GCS read flood. type fileMeta struct { - ProjectID string - ExpiresAt time.Time - Found bool + ProjectID string + ExpiresAt time.Time + Found bool + BucketMissing bool } // Expired reports whether the file has a recorded expiry that is now in @@ -35,9 +41,19 @@ func (m fileMeta) Expired() bool { return m.Found && !m.ExpiresAt.IsZero() && m.ExpiresAt.Before(time.Now()) } +func fileCacheKey(fn string) string { return "fn|" + fn } + func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { fn := r.PathValue("fn") + // Reject obviously-invalid fns in CPU. A flood of random-garbage fns + // would otherwise miss the per-fn cache on every request (each key is + // unique) and burn one DB query + one GCS Class B op apiece. + if !isValidFilename(fn) { + http.NotFound(w, r) + return + } + // Check the DB *before* touching GCS so a DDoS against an expired fn // is absorbed by the in-process cache (60s TTL on an immutable row) // instead of becoming one billed Bucket.Attributes call per request. @@ -49,10 +65,22 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "file expired", http.StatusGone) return } + if meta.BucketMissing { + // Cached negative from a previous bucket NotFound — see the + // Attributes branch below. + http.NotFound(w, r) + return + } attrs, err := a.Bucket.Attributes(r.Context(), fn) if err != nil { if gcerrors.Code(err) == gcerrors.NotFound { + // Cache the negative so a flood against the same dead fn + // doesn't re-bill a Class B op per request. Safe even in the + // rare orphan case (Found=true, bucket gone): the file is + // unreachable either way until an operator cleans it up. + meta.BucketMissing = true + cachestore.Set(fileCacheKey(fn), meta, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) http.NotFound(w, r) return } @@ -83,16 +111,30 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { fn := r.PathValue("fn") - // Same DDoS-protection ordering as fileHandler: cached expired check - // first so the edge can't amplify into GCS by chasing an expired URL. - if lookupFile(r.Context(), fn).Expired() { + // Same DDoS-protection ladder as fileHandler: validate format, then + // consult the per-fn cache, then GCS. The edge would otherwise + // amplify a random-garbage flood or an expired-URL probe into one + // GCS op per request. + if !isValidFilename(fn) { + http.NotFound(w, r) + return + } + + meta := lookupFile(r.Context(), fn) + if meta.Expired() { http.Error(w, "file expired", http.StatusGone) return } + if meta.BucketMissing { + http.NotFound(w, r) + return + } attrs, err := a.Bucket.Attributes(r.Context(), fn) if err != nil { if gcerrors.Code(err) == gcerrors.NotFound { + meta.BucketMissing = true + cachestore.Set(fileCacheKey(fn), meta, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) http.NotFound(w, r) return } @@ -144,7 +186,7 @@ func (a *App) streamFile(w http.ResponseWriter, r *http.Request, fn string, attr // needs to be short enough to absorb bursts on the same fn — not to track // any state that can change underneath us. func lookupFile(ctx context.Context, fn string) fileMeta { - cacheKey := "fn|" + fn + cacheKey := fileCacheKey(fn) if v, ok := cachestore.Get[fileMeta](cacheKey); ok { return v } diff --git a/files_test.go b/files_test.go index 1b8c1a3..23eb544 100644 --- a/files_test.go +++ b/files_test.go @@ -16,7 +16,8 @@ func TestFileHandler_Success(t *testing.T) { bkt := newTestBucket(t) app := newTestApp(bkt, authorized) - bw, err := bkt.NewWriter(t.Context(), "testfile", &blob.WriterOptions{ + fn := validTestFn("testfile") + bw, err := bkt.NewWriter(t.Context(), fn, &blob.WriterOptions{ CacheControl: "public, max-age=86400", ContentDisposition: `attachment; filename="hello.txt"`, }) @@ -30,9 +31,9 @@ func TestFileHandler_Success(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/testfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "testfile") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -58,9 +59,10 @@ func TestFileHandler_NotFound(t *testing.T) { db := newTestDB(t) app := newTestApp(newTestBucket(t), authorized) - r := httptest.NewRequest(http.MethodGet, "/files/doesnotexist", nil) + fn := validTestFn("notexist") + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "doesnotexist") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -75,7 +77,8 @@ func TestFileHandler_RouteIntegration(t *testing.T) { bkt := newTestBucket(t) app := newTestApp(bkt, authorized) - bw, err := bkt.NewWriter(t.Context(), "routefile", nil) + fn := validTestFn("routefile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -84,7 +87,7 @@ func TestFileHandler_RouteIntegration(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/routefile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -100,7 +103,8 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { bkt := newTestBucket(t) app := newTestApp(bkt, authorized) - bw, err := bkt.NewWriter(t.Context(), "plain", nil) + fn := validTestFn("plain") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -109,9 +113,9 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/plain", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "plain") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -199,7 +203,8 @@ func TestFileHandler_ExpiredReturnsGone(t *testing.T) { // Object still in the bucket — GC hasn't run yet — but the DB row // says it expired an hour ago. - bw, err := bkt.NewWriter(t.Context(), "expiredfile", nil) + fn := validTestFn("expiredfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -211,14 +216,14 @@ func TestFileHandler_ExpiredReturnsGone(t *testing.T) { ctx := db.Ctx() if _, err := pgctx.Exec(ctx, ` INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) - VALUES ('expiredfile', 'proj-xyz', 11, 'x', 1, now() - interval '1 hour') - `); err != nil { + VALUES ($1, 'proj-xyz', 11, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/expiredfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(ctx) - r.SetPathValue("fn", "expiredfile") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -240,17 +245,18 @@ func TestFileHandler_ExpiredSkipsBucket(t *testing.T) { db := newTestDB(t) app := newTestApp(newTestBucket(t), authorized) + fn := validTestFn("expiredghost") ctx := db.Ctx() if _, err := pgctx.Exec(ctx, ` INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) - VALUES ('expiredghost', 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') - `); err != nil { + VALUES ($1, 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/expiredghost", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(ctx) - r.SetPathValue("fn", "expiredghost") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -266,15 +272,16 @@ func TestCDNFileHandler_ExpiredSkipsBucket(t *testing.T) { app := newTestApp(newTestBucket(t), authorized) app.CDNBaseURL = "https://cdn.example.com/" + fn := validTestFn("expiredghostcdn") ctx := db.Ctx() if _, err := pgctx.Exec(ctx, ` INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) - VALUES ('expiredghostcdn', 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') - `); err != nil { + VALUES ($1, 'proj-xyz', 99, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/expiredghostcdn", nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) r = r.WithContext(ctx) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -291,7 +298,8 @@ func TestFileHandler_ExpiredOverridesCDNRedirect(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "expiredcdn", nil) + fn := validTestFn("expiredcdn") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -303,14 +311,14 @@ func TestFileHandler_ExpiredOverridesCDNRedirect(t *testing.T) { ctx := db.Ctx() if _, err := pgctx.Exec(ctx, ` INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) - VALUES ('expiredcdn', 'proj-xyz', 5, 'x', 1, now() - interval '1 hour') - `); err != nil { + VALUES ($1, 'proj-xyz', 5, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/expiredcdn", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(ctx) - r.SetPathValue("fn", "expiredcdn") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -329,7 +337,8 @@ func TestCDNFileHandler_ExpiredReturnsGone(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "expiredorigin", nil) + fn := validTestFn("expiredorigin") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -341,12 +350,12 @@ func TestCDNFileHandler_ExpiredReturnsGone(t *testing.T) { ctx := db.Ctx() if _, err := pgctx.Exec(ctx, ` INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) - VALUES ('expiredorigin', 'proj-xyz', 12, 'x', 1, now() - interval '1 hour') - `); err != nil { + VALUES ($1, 'proj-xyz', 12, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/expiredorigin", nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) r = r.WithContext(ctx) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -366,7 +375,8 @@ func TestFileHandler_CDNRedirect(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "cdnfile", nil) + fn := validTestFn("cdnfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -375,22 +385,22 @@ func TestFileHandler_CDNRedirect(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/cdnfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "cdnfile") + r.SetPathValue("fn", fn) w := httptest.NewRecorder() app.fileHandler(w, r) if w.Code != http.StatusTemporaryRedirect { t.Fatalf("status = %d, want 307", w.Code) } - if loc := w.Header().Get("Location"); loc != "https://cdn.example.com/cdnfile" { - t.Errorf("Location = %q, want https://cdn.example.com/cdnfile", loc) + if loc := w.Header().Get("Location"); loc != "https://cdn.example.com/"+fn { + t.Errorf("Location = %q, want https://cdn.example.com/%s", loc, fn) } - if got := w.Body.Len(); got > 100 { - // the default redirect body is small (""); we shouldn't - // be streaming the object. - t.Errorf("body length = %d, expected a short redirect body", got) + // The default http.Redirect body is `Temporary Redirect` — + // it must not contain the object bytes. + if got := w.Body.String(); strings.Contains(got, "hello world") { + t.Errorf("body streamed the object instead of redirecting: %q", got) } } @@ -401,7 +411,8 @@ func TestFileHandler_CDNInternalClientStreams(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "internalfile", nil) + fn := validTestFn("internalfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -410,9 +421,9 @@ func TestFileHandler_CDNInternalClientStreams(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/internalfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "internalfile") + r.SetPathValue("fn", fn) r.Header.Set("X-Real-Ip", "10.0.0.5") // private IP -> internal w := httptest.NewRecorder() app.fileHandler(w, r) @@ -432,7 +443,8 @@ func TestFileHandler_CDNRedirectPublicXRealIP(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "publicfile", nil) + fn := validTestFn("publicfile") + bw, err := bkt.NewWriter(t.Context(), fn, nil) if err != nil { t.Fatal(err) } @@ -441,9 +453,9 @@ func TestFileHandler_CDNRedirectPublicXRealIP(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/publicfile", nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", "publicfile") + r.SetPathValue("fn", fn) r.Header.Set("X-Real-Ip", "203.0.113.5") // public IP -> CDN path w := httptest.NewRecorder() app.fileHandler(w, r) @@ -460,7 +472,8 @@ func TestCDNFileHandler_Streams(t *testing.T) { app := newTestApp(bkt, authorized) app.CDNBaseURL = "https://cdn.example.com/" - bw, err := bkt.NewWriter(t.Context(), "origin", &blob.WriterOptions{ + fn := validTestFn("origin") + bw, err := bkt.NewWriter(t.Context(), fn, &blob.WriterOptions{ CacheControl: "public, max-age=86400", }) if err != nil { @@ -471,7 +484,7 @@ func TestCDNFileHandler_Streams(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/origin", nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -493,7 +506,8 @@ func TestCDNFileHandler_NotFound(t *testing.T) { app := newTestApp(newTestBucket(t), authorized) app.CDNBaseURL = "https://cdn.example.com/" - r := httptest.NewRequest(http.MethodGet, "/_cdn/nope", nil) + fn := validTestFn("cdnnope") + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -503,6 +517,133 @@ func TestCDNFileHandler_NotFound(t *testing.T) { } } +func TestFileHandler_InvalidFilenameRejected(t *testing.T) { + // Garbage fns should 404 from the validator without touching DB or + // GCS. We pass a nil-context request: if the validator ever stops + // short-circuiting, lookupFile will panic on the missing pgctx and + // blow the test up — which is the right failure mode for a + // regression that removes the DDoS shield. + t.Parallel() + app := newTestApp(newTestBucket(t), authorized) + + cases := []string{ + "", + "short", + "../../etc/passwd", + strings.Repeat("a", 85), // 1 short + strings.Repeat("a", 87), // 1 long + strings.Repeat("a", 85) + "!", // bad char + strings.Repeat("a", 85) + "+", // standard-base64 char, not URL-safe + strings.Repeat("a", 85) + "/", // ditto + } + for _, fn := range cases { + t.Run("fn="+fn, func(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/files/x", nil) + r.SetPathValue("fn", fn) + w := httptest.NewRecorder() + app.fileHandler(w, r) + if w.Code != http.StatusNotFound { + t.Errorf("fn=%q: status = %d, want 404", fn, w.Code) + } + }) + } +} + +func TestCDNFileHandler_InvalidFilenameRejected(t *testing.T) { + t.Parallel() + app := newTestApp(newTestBucket(t), authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + r := httptest.NewRequest(http.MethodGet, "/_cdn/short", nil) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } +} + +func TestFileHandler_BucketMissNegativeCached(t *testing.T) { + // First request misses everywhere → 404 and caches BucketMissing. + // Then we write the object out-of-band into the bucket. The second + // request must still 404, proving Bucket.Attributes was not called + // — that's the DDoS protection for a flood against a known-bad fn + // (e.g. a URL the attacker held after expiry+GC). + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + + fn := validTestFn("missneg") + + r1 := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r1 = r1.WithContext(db.Ctx()) + r1.SetPathValue("fn", fn) + w1 := httptest.NewRecorder() + app.fileHandler(w1, r1) + if w1.Code != http.StatusNotFound { + t.Fatalf("first request: status = %d, want 404", w1.Code) + } + + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("late arrival")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + r2 := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r2 = r2.WithContext(db.Ctx()) + r2.SetPathValue("fn", fn) + w2 := httptest.NewRecorder() + app.fileHandler(w2, r2) + if w2.Code != http.StatusNotFound { + t.Fatalf("second request: status = %d, want 404 (negative cache must mask the out-of-band write)", w2.Code) + } + if strings.Contains(w2.Body.String(), "late arrival") { + t.Errorf("body leaked out-of-band object: %q", w2.Body.String()) + } +} + +func TestCDNFileHandler_BucketMissNegativeCached(t *testing.T) { + // Same assertion for /_cdn/{fn}. The CDN origin needs the same + // shield since the edge can hammer it on cache misses. + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("missnegcdn") + + r1 := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r1 = r1.WithContext(db.Ctx()) + w1 := httptest.NewRecorder() + app.routes().ServeHTTP(w1, r1) + if w1.Code != http.StatusNotFound { + t.Fatalf("first request: status = %d, want 404", w1.Code) + } + + bw, err := bkt.NewWriter(t.Context(), fn, nil) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("late arrival")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + r2 := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r2 = r2.WithContext(db.Ctx()) + w2 := httptest.NewRecorder() + app.routes().ServeHTTP(w2, r2) + if w2.Code != http.StatusNotFound { + t.Fatalf("second request: status = %d, want 404 (negative cache must mask the out-of-band write)", w2.Code) + } +} + func TestIsInternalClient(t *testing.T) { t.Parallel() cases := []struct { diff --git a/handler.go b/handler.go index 9d6c48a..4707140 100644 --- a/handler.go +++ b/handler.go @@ -124,6 +124,31 @@ func generateFilename() string { return base64.RawURLEncoding.EncodeToString(b) } +// isValidFilename returns true iff fn matches what generateFilename +// produces: 86 chars of URL-safe base64 (RawURLEncoding of 64 random +// bytes, no padding). Anything else can't possibly exist in our system, +// so the file handlers can 404 it in CPU and skip both the Postgres +// query and the GCS Attributes call. That absorbs a flood of +// random-garbage fns (which would otherwise miss the per-fn cache, since +// every key is unique) at zero backend cost. +func isValidFilename(fn string) bool { + if len(fn) != 86 { + return false + } + for i := 0; i < len(fn); i++ { + c := fn[i] + switch { + case c >= 'A' && c <= 'Z': + case c >= 'a' && c <= 'z': + case c >= '0' && c <= '9': + case c == '-', c == '_': + default: + return false + } + } + return true +} + func escapeFilename(s string) string { return strings.ReplaceAll(s, `"`, "") } diff --git a/handler_test.go b/handler_test.go index f65dab8..09ddc37 100644 --- a/handler_test.go +++ b/handler_test.go @@ -23,6 +23,17 @@ func newTestBucket(t *testing.T) *blob.Bucket { return bkt } +// validTestFn returns an 86-char fn that passes isValidFilename, with +// the descriptive suffix preserved so failing tests are still readable. +// Distinct suffixes produce distinct fns, which keeps parallel tests +// isolated in the in-process per-fn cache. +func validTestFn(suffix string) string { + if len(suffix) > 86 { + panic("validTestFn: suffix too long: " + suffix) + } + return strings.Repeat("a", 86-len(suffix)) + suffix +} + func countObjects(t *testing.T, bkt *blob.Bucket) int { t.Helper() iter := bkt.List(nil) @@ -462,6 +473,41 @@ func TestGenerateFilename(t *testing.T) { } } +func TestIsValidFilename(t *testing.T) { + t.Parallel() + // Generated fns must always validate — this links the producer and + // the validator so a future change to either can't drift. + for i := 0; i < 16; i++ { + if fn := generateFilename(); !isValidFilename(fn) { + t.Errorf("generateFilename() = %q failed isValidFilename", fn) + } + } + + good := strings.Repeat("a", 86) + cases := []struct { + fn string + want bool + }{ + {good, true}, + {"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_abcdefghijklmnopqrstuv", true}, // 86 chars, full alphabet + {"", false}, + {"short", false}, + {strings.Repeat("a", 85), false}, + {strings.Repeat("a", 87), false}, + {strings.Repeat("a", 85) + "!", false}, // bad char + {strings.Repeat("a", 85) + "+", false}, // standard-base64 only + {strings.Repeat("a", 85) + "/", false}, + {strings.Repeat("a", 85) + "=", false}, // padding not allowed by RawURLEncoding + {strings.Repeat("a", 85) + " ", false}, + {strings.Repeat("a", 85) + ".", false}, + } + for _, tc := range cases { + if got := isValidFilename(tc.fn); got != tc.want { + t.Errorf("isValidFilename(%q) = %v, want %v", tc.fn, got, tc.want) + } + } +} + func TestEscapeFilename(t *testing.T) { t.Parallel() cases := []struct{ in, want string }{ From 8af461646bc8c69095a67cdb6c049b10d6724e12 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 10:12:33 +0700 Subject: [PATCH 4/7] Singleflight the DB lookup so a thundering herd collapses to one query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fileMeta cache has a 60s TTL and is loaded lazily on miss. Under load, the moment an entry expires (or at cold start), every concurrent request for the same fn sees a cache miss before any of them finishes writing — each then issues its own Postgres SELECT. With CDN cache TTL and an in-cluster origin call pattern, that "herd at the edge" is a very real shape against /files/{fn} and /_cdn/{fn}. Wrap the DB query inside lookupFile with github.com/moonrhythm/sf.Do (a generic, context-aware singleflight). Concurrent callers for the same fn join one in-flight execution and share its result; only the first caller's SELECT actually hits Postgres. The closure re-checks the cache so a sibling caller that won the race a tick earlier also benefits. sf preserves context values via WithoutCancel, so pgctx still resolves the DB connection inside the closure, and per-caller cancellation returns ctx.Err() to that caller without killing the shared work the remaining waiters depend on. Test: TestLookupFile_SingleflightCollapsesConcurrentCalls fires 50 goroutines at a single cold-cache fn behind a release barrier and asserts the per-fn DB-query counter stays below N. Without sf the counter would equal N; with sf it's effectively 1. Per-fn counter (sync.Map of *atomic.Uint64) so the test stays t.Parallel-safe. Stable over 20 -race iterations. Co-Authored-By: Claude Opus 4.7 --- files.go | 87 ++++++++++++++++++++++++++++++++++++++------------- files_test.go | 49 +++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 117 insertions(+), 22 deletions(-) diff --git a/files.go b/files.go index 4bb2cbf..8511cbe 100644 --- a/files.go +++ b/files.go @@ -8,10 +8,13 @@ import ( "log/slog" "net/http" "strconv" + "sync" + "sync/atomic" "time" "github.com/acoshift/pgsql/pgctx" "github.com/moonrhythm/cachestore" + "github.com/moonrhythm/sf" "gocloud.dev/blob" "gocloud.dev/gcerrors" ) @@ -183,35 +186,75 @@ func (a *App) streamFile(w http.ResponseWriter, r *http.Request, fn string, attr // lookupFile fetches the file's project owner and expiration from the DB. // Both fields are immutable once written, so the in-process cache only -// needs to be short enough to absorb bursts on the same fn — not to track -// any state that can change underneath us. +// needs to be short enough to absorb bursts on the same fn — not to +// track any state that can change underneath us. +// +// The DB query runs inside sf.Do so a thundering herd at the cache-miss +// edge (cold start, or the moment a 60s entry expires under load) +// collapses to a single Postgres round-trip. Every concurrent caller +// for the same fn gets the same result without each one issuing its +// own query. func lookupFile(ctx context.Context, fn string) fileMeta { cacheKey := fileCacheKey(fn) if v, ok := cachestore.Get[fileMeta](cacheKey); ok { return v } - var ( - m fileMeta - expires sql.NullTime - ) - err := pgctx.QueryRow(ctx, ` - SELECT project_id, expires_at FROM files WHERE fn = $1 - `, fn).Scan(&m.ProjectID, &expires) - switch { - case err == nil: - m.Found = true - if expires.Valid { - m.ExpiresAt = expires.Time + m, _, _ := sf.Do(ctx, "lookupFile|"+fn, func(ctx context.Context) (fileMeta, error) { + // Re-check the cache: a sibling caller may have populated it + // while we were queued behind sf's mutex. + if v, ok := cachestore.Get[fileMeta](cacheKey); ok { + return v, nil + } + + recordLookupFileDBQuery(fn) + var ( + m fileMeta + expires sql.NullTime + ) + err := pgctx.QueryRow(ctx, ` + SELECT project_id, expires_at FROM files WHERE fn = $1 + `, fn).Scan(&m.ProjectID, &expires) + switch { + case err == nil: + m.Found = true + if expires.Valid { + m.ExpiresAt = expires.Time + } + case errors.Is(err, sql.ErrNoRows): + // Found=false — a confirmed miss is fine to cache. + default: + // Transient DB error: don't cache-poison. Returning here + // also skips the cache write below; the next caller will + // retry the query. + slog.Error("lookup file", "fn", fn, "error", err) + return fileMeta{}, nil } - case errors.Is(err, sql.ErrNoRows): - // Leave m zero-valued (Found=false). Don't cache-poison on - // transient DB errors below, but a confirmed miss is fine. - default: - slog.Error("lookup file", "fn", fn, "error", err) - return fileMeta{} - } - cachestore.Set(cacheKey, m, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) + cachestore.Set(cacheKey, m, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) + return m, nil + }) return m } + +// lookupFileDBQueries counts the number of actual Postgres SELECTs +// issued by lookupFile, keyed by fn. The counter is only bumped inside +// the singleflight closure, so a successful sf.Do collapse leaves it at +// 1 even after N concurrent callers. Production code only writes here; +// only tests read it. Per-fn so parallel tests don't trample each other. +var lookupFileDBQueries sync.Map // map[string]*atomic.Uint64 + +func recordLookupFileDBQuery(fn string) { + v, _ := lookupFileDBQueries.LoadOrStore(fn, new(atomic.Uint64)) + v.(*atomic.Uint64).Add(1) +} + +// lookupFileDBQueryCount returns how many DB queries lookupFile has +// issued for fn since process start. Test-only. +func lookupFileDBQueryCount(fn string) uint64 { + v, ok := lookupFileDBQueries.Load(fn) + if !ok { + return 0 + } + return v.(*atomic.Uint64).Load() +} diff --git a/files_test.go b/files_test.go index 23eb544..e3a1b13 100644 --- a/files_test.go +++ b/files_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/acoshift/pgsql/pgctx" @@ -174,6 +175,54 @@ func TestLookupFile_NotFound(t *testing.T) { } } +func TestLookupFile_SingleflightCollapsesConcurrentCalls(t *testing.T) { + // 50 goroutines race on the same cold-cache fn. sf.Do must collapse + // them into a single Postgres SELECT — anything more and the DDoS + // shield against a thundering herd at the cache-miss edge is gone. + t.Parallel() + db := newTestDB(t) + ctx := db.Ctx() + + fn := "sf-thundering-herd-fn" + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-sf', 1, 'x', 1, now() + interval '1 day') + `, fn); err != nil { + t.Fatal(err) + } + + const N = 50 + results := make([]fileMeta, N) + var wg sync.WaitGroup + start := make(chan struct{}) + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start // release all goroutines at once to maximise overlap + results[idx] = lookupFile(ctx, fn) + }(i) + } + close(start) + wg.Wait() + + // Correctness: every caller got the same correct result. + for i, r := range results { + if !r.Found || r.ProjectID != "proj-sf" { + t.Errorf("results[%d] = %+v, want Found=true ProjectID=proj-sf", i, r) + } + } + + // Dedupe: sf collapsed the herd. We allow >1 because the per-caller + // cache check before sf.Do can race in a way that lets a few + // stragglers issue their own queries when the first one finishes + // fast and the cache populates before later goroutines reach sf — + // but it must be far below N. + if got := lookupFileDBQueryCount(fn); got >= N { + t.Errorf("DB queries for %s = %d, want <%d (singleflight should have collapsed the herd)", fn, got, N) + } +} + func TestLookupFile_Expired(t *testing.T) { t.Parallel() db := newTestDB(t) diff --git a/go.mod b/go.mod index 7e5185e..e0585c7 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/lib/pq v1.12.3 github.com/moonrhythm/cachestore v0.0.0-20241226112208-fd22884c5b60 github.com/moonrhythm/parapet v0.13.6 + github.com/moonrhythm/sf v1.0.2 github.com/prometheus/client_golang v1.20.5 gocloud.dev v0.45.0 ) diff --git a/go.sum b/go.sum index bef8554..4431b8f 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/moonrhythm/parapet v0.13.6 h1:xpoXIEI1JiNVTs8F2edmZn75gMV2Dg6KQdpOlUt github.com/moonrhythm/parapet v0.13.6/go.mod h1:QEhl1pk8viGxC3xz1XLmzU8UZRT5qxHXgUln5hXlRCQ= github.com/moonrhythm/pq v0.0.0-20230504040008-09b0644d6569 h1:VDLLIg19HS2A7HuMY8ukFNrdgrRqGn5O0xQ4zrSgD/0= github.com/moonrhythm/pq v0.0.0-20230504040008-09b0644d6569/go.mod h1:V43Namo2SgpLTmiaQXjiKPuKvqGamIhPawsQF+pv8xw= +github.com/moonrhythm/sf v1.0.2 h1:W7uHinFjiO1A+7UNDANka78J9jk1nQz+W7IzK0bpj68= +github.com/moonrhythm/sf v1.0.2/go.mod h1:GVFIuH7VwfIn+wVi+kcjxAUcFsQ2JmiUrQuIg5yuono= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= From ac4b82c6c931ab99d2ef0ec6ad3ca9f55211813e Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 10:30:59 +0700 Subject: [PATCH 5/7] Sign download URLs, shorten filenames, drop random-token DDoS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three changes that share a single goal: an attacker who hits /files/ with random garbage shouldn't be able to bill the project or burn DB / GCS quota. 1) sf around the checkAuth external call Mirrors the lookupFile change in the previous commit. cacheTTL is only 30s, so under upload load from one caller (parallel CI jobs holding the same bearer) the cache-miss edge would otherwise thunder the /me.authorized endpoint. sf.Do collapses concurrent identical (auth, project, projectId) triples to one round-trip, with the same re-check-cache-inside-the-closure pattern. 2) New filename scheme: short, alphanumeric, signed generateFilename was 64 random bytes → 86-char URL-safe base64 (alphabet [A-Za-z0-9_-]). The base64 special chars never bought anything and the length was generous to a fault. Replaced with: fn = 24 chars from [0-9A-Za-z], rejection-sampled to be unbiased (~143 bits of entropy, well over any guessing or birthday-collision threshold). sig = 20 chars of lowercase hex (HMAC-SHA256 truncated to 80 bits) keyed by App.SignKey. The URL token is `fn + sig` = 44 chars total — down from 86, and strictly alphanumeric. fn is what we store in the bucket and the files.fn column; the full token only ever appears in URLs. 3) HMAC verification before any backend hit fileHandler / cdnFileHandler now run parseToken first. It checks length + constant-time HMAC; on any mismatch we 404 in CPU. This is the *primary* DDoS shield — random /files/{garbage} floods by anyone who doesn't have sign_key never reach lookupFile, never reach Bucket.Attributes. The negative bucket cache and per-fn metadata cache still matter for the residual case of a replayed real URL after expiry+GC. isValidFilename is gone; parseToken subsumes its job (length + alphabet implicitly enforced by HMAC matching) with stronger guarantees. Config: new required env var sign_key (HMAC key). Rotating it invalidates every outstanding URL — the correct behavior if a key ever leaks. Tests use a fixed testSignKey threaded through newTestApp; updated existing tests to put signed tokens (not raw fns) in URLs and SetPathValue("token"). Added direct tests for generateFilename alphabet+uniformity, signFilename determinism+key-sensitivity, parseToken round-trip and forgery rejection, plus an sf-collapses-thundering-herd test for checkAuth that mirrors the lookupFile one. CDN redirect now preserves the full signed token so the CDN's origin-fetch hits /_cdn/{token} with a URL the origin can re-verify. Updated TestFileHandler_CDNRedirect accordingly. README and CLAUDE.md updated to describe the new token format, sign_key requirement, and the four-tier DDoS protection ladder (parseToken → cache → sf.Do → GCS). Co-Authored-By: Claude Opus 4.7 --- CLAUDE.md | 27 ++++++-- README.md | 5 +- auth.go | 108 +++++++++++++++++------------- auth_test.go | 44 +++++++++++++ files.go | 44 ++++++++----- files_test.go | 114 +++++++++++++++++--------------- handler.go | 107 ++++++++++++++++++++++-------- handler_test.go | 170 +++++++++++++++++++++++++++++++++++++----------- main.go | 4 ++ 9 files changed, 434 insertions(+), 189 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 2ddddf7..1c133f7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,35 +19,48 @@ Standard Go HTTP server (not Cloudflare Workers) serving as a temporary file upl **Environment variables (required):** - `db_url` — PostgreSQL connection string - `bucket_name` — GCS bucket name +- `sign_key` — HMAC key for signing download tokens (`handler.go:signFilename`). Rotating it invalidates every outstanding URL — that's the right behavior if a key leaks. **Environment variables (optional):** - `base_url` — download URL prefix (default: `https://dropbox.deploys.app/files/`) - `api_endpoint` — deploys.app API base URL (default: `https://api.deploys.app`, override with internal address in production) -- `cdn_base_url` — full URL prefix (including scheme and trailing slash, e.g. `https://cdn.example.com/`). When set, `GET /files/{fn}` records the download metric (counted against `attrs.Size`) and 307-redirects to `{cdn_base_url}{fn}`. The CDN edge is expected to fetch its origin at `https://dropbox.deploys.app/_cdn/{fn}`, which streams the file unauthenticated and without metrics. In-cluster callers (private/loopback/link-local `X-Real-Ip`) bypass the redirect and stream directly. Unset = original streaming behavior. +- `cdn_base_url` — full URL prefix (including scheme and trailing slash, e.g. `https://cdn.example.com/`). When set, `GET /files/{token}` records the download metric (counted against `attrs.Size`) and 307-redirects to `{cdn_base_url}{token}`. The CDN edge is expected to fetch its origin at `https://dropbox.deploys.app/_cdn/{token}`, which streams the file (no auth, no metrics) after re-verifying the same HMAC. In-cluster callers (private/loopback/link-local `X-Real-Ip`) bypass the redirect and stream directly. Unset = original streaming behavior. - `PORT` — listen port (default: `8080`) - `log_level` — slog level (default: info) +**Download token scheme (`handler.go`):** +- The URL path component is a 44-char `token` = `fn` (24 chars random `[0-9A-Za-z]`, ~143 bits of entropy) concatenated with `sig` (20 hex chars of HMAC-SHA256 truncated to 80 bits, keyed by `sign_key`). +- `parseToken(SignKey, token)` runs first in both `fileHandler` and `cdnFileHandler` and 404s on any mismatch — DDoS attempts that don't know `sign_key` never reach the DB or GCS. +- `fn` is what we store in the bucket and the `files.fn` column. The full token only appears in URLs. + **Request flow (`handler.go`):** 1. Parse `Authorization` header + `project`/`projectId` from query params or `param-*` headers (query params take precedence) 2. Authorize via `checkAuth()` in `auth.go` 3. Parse TTL (1–7 days, default 1) and optional filename the same way -4. Generate a crypto-random 86-char URL-safe base64 filename with TTL digit prepended (e.g., `1ABC…`) -5. Stream body to GCS with cache-control and optional content-disposition +4. Generate a 24-char alphanumeric `fn` (`generateFilename`, rejection-sampled to stay unbiased) +5. Stream body to GCS with cache-control and optional content-disposition, keyed by `fn` 6. Insert metadata into PostgreSQL via `pgctx.Exec` -7. Return JSON: `{"ok": true, "result": {"downloadUrl": "...", "expiresAt": "..."}}` +7. Return JSON: `{"ok": true, "result": {"downloadUrl": "{base_url}{fn}{sig}", "expiresAt": "..."}}` **Auth (`auth.go`):** - No `Authorization` header → alpha mode, project ID hardcoded as `"alpha"` (TODO: remove) - With token → POST to `https://api.deploys.app/me.authorized` for `dropbox.upload` permission, checking `authorized` + `billingAccount.active` -- Results cached in-process for 30 seconds via `cachestore` +- Results cached in-process for 30 seconds via `cachestore`; the external call is wrapped in `sf.Do` so concurrent uploads from the same caller collapse to one round-trip at the cache-miss edge. + +**DDoS protection ladder:** see `fileHandler` / `cdnFileHandler` in `files.go`. In order from cheapest to most expensive: +1. `parseToken` HMAC check — pure CPU, no I/O. +2. `lookupFile` cache — 60s in-process cache of `(project_id, expires_at, bucket_missing)` per fn. +3. `sf.Do` around the DB `SELECT` — collapses any thundering herd at the cache-miss edge. +4. `Bucket.Attributes` — only reached for tokens that survive 1–3. **Key libraries (same pattern as `moonrhythm/registry`):** - `parapet` — HTTP server with middleware chain (healthz, logger, pgctx) - `pgctx` — context-aware PostgreSQL access (`pgctx.Exec`, middleware injects DB into context) -- `cachestore` — in-process TTL cache for auth results +- `cachestore` — in-process TTL cache for auth results and per-fn metadata +- `sf` — generic context-aware singleflight (`github.com/moonrhythm/sf`); used in `lookupFile` and `checkAuth` to dedupe concurrent backend calls - `configfile` — env-var config reader (`config.MustString`, `config.StringDefault`) ## Notes - `schema.sql` targets PostgreSQL; `project_id` is `text` (the API returns string IDs) -- `base_url` is the public download prefix (`https://dropbox.deploys.app/files/`); it shares the service host and resolves to the `GET /files/{fn}` route, which streams directly or 307s to the CDN (see `cdn_base_url`) +- `base_url` is the public download prefix (`https://dropbox.deploys.app/files/`); it shares the service host and resolves to the `GET /files/{token}` route, which streams directly or 307s to the CDN (see `cdn_base_url`) diff --git a/README.md b/README.md index 2ca7101..77189e3 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ Docker image is built and pushed automatically on push to `main`. See `.github/w |---------------|----------------------------------------------------------| | `db_url` | PostgreSQL connection string | | `bucket_name` | GCS bucket name | +| `sign_key` | HMAC key for signing download tokens. Rotating invalidates every outstanding URL. | | `base_url` | Download URL prefix (default: `https://dropbox.deploys.app/files/`) | | `PORT` | Listen port (default: `8080`) | @@ -74,12 +75,14 @@ File data binary { "ok": true, "result": { - "downloadUrl": "https://dropbox.deploys.app/files/", + "downloadUrl": "https://dropbox.deploys.app/files/", "expiresAt": "2020-01-01T01:01:01Z" } } ``` +`` is a 44-char string: a 24-char random alphanumeric filename concatenated with a 20-char HMAC-SHA256 signature (keyed by `sign_key`). Tampered or made-up tokens are rejected before any DB or storage lookup. + ##### Unauthorized ```json diff --git a/auth.go b/auth.go index 331842d..f7b1b6f 100644 --- a/auth.go +++ b/auth.go @@ -8,6 +8,7 @@ import ( "time" "github.com/moonrhythm/cachestore" + "github.com/moonrhythm/sf" ) var apiEndpoint = "https://api.deploys.app" @@ -44,57 +45,72 @@ func checkAuth(ctx context.Context, auth, project, projectID string) AuthResult return v } - body, _ := json.Marshal(struct { - Project string `json:"project,omitempty"` - ProjectID string `json:"projectId,omitempty"` - Permissions []string `json:"permissions"` - }{ - Project: project, - ProjectID: projectID, - Permissions: []string{permission}, - }) + // Same singleflight pattern as lookupFile: collapse a thundering herd + // of concurrent uploads from the same caller into a single + // /me.authorized round-trip. The result is cached for 30s + // (cacheTTL); sf.Do dedupe matters at the cold-cache edge and right + // after that 30s entry expires under load. + result, _, _ := sf.Do(ctx, cacheKey, func(ctx context.Context) (AuthResult, error) { + // Re-check the cache: a sibling caller may have populated it + // while we were queued behind sf's mutex. + if v, ok := cachestore.Get[AuthResult](cacheKey); ok { + return v, nil + } - req, _ := http.NewRequest(http.MethodPost, apiEndpoint+"/me.authorized", bytes.NewReader(body)) - req.Header.Set("Authorization", auth) - req.Header.Set("Content-Type", "application/json") + body, _ := json.Marshal(struct { + Project string `json:"project,omitempty"` + ProjectID string `json:"projectId,omitempty"` + Permissions []string `json:"permissions"` + }{ + Project: project, + ProjectID: projectID, + Permissions: []string{permission}, + }) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return AuthResult{} - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return AuthResult{} - } + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, apiEndpoint+"/me.authorized", bytes.NewReader(body)) + req.Header.Set("Authorization", auth) + req.Header.Set("Content-Type", "application/json") - var res struct { - OK bool `json:"ok"` - Result struct { - Authorized bool `json:"authorized"` - Project struct { - ID string `json:"id"` - Project string `json:"project"` - BillingAccount struct { - Active bool `json:"active"` - } `json:"billingAccount"` - } `json:"project"` - } `json:"result"` - } - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { - return AuthResult{} - } + resp, err := http.DefaultClient.Do(req) + if err != nil { + // Don't cache transport failures — let the next caller retry. + return AuthResult{}, nil + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return AuthResult{}, nil + } - var result AuthResult - if res.OK && res.Result.Authorized && res.Result.Project.BillingAccount.Active { - result = AuthResult{ - Authorized: true, - Project: Project{ - ID: res.Result.Project.ID, - Project: res.Result.Project.Project, - }, + var res struct { + OK bool `json:"ok"` + Result struct { + Authorized bool `json:"authorized"` + Project struct { + ID string `json:"id"` + Project string `json:"project"` + BillingAccount struct { + Active bool `json:"active"` + } `json:"billingAccount"` + } `json:"project"` + } `json:"result"` + } + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return AuthResult{}, nil + } + + var result AuthResult + if res.OK && res.Result.Authorized && res.Result.Project.BillingAccount.Active { + result = AuthResult{ + Authorized: true, + Project: Project{ + ID: res.Result.Project.ID, + Project: res.Result.Project.Project, + }, + } } - } - cachestore.Set(cacheKey, result, &cachestore.SetOptions{TTL: cacheTTL}) + cachestore.Set(cacheKey, result, &cachestore.SetOptions{TTL: cacheTTL}) + return result, nil + }) return result } diff --git a/auth_test.go b/auth_test.go index b74e073..411480c 100644 --- a/auth_test.go +++ b/auth_test.go @@ -8,6 +8,7 @@ import ( "sync" "sync/atomic" "testing" + "time" ) // A single mock server is started lazily and serves every auth test. Each test @@ -215,6 +216,49 @@ func TestCheckAuth_CachesResult(t *testing.T) { } } +func TestCheckAuth_SingleflightCollapsesConcurrentCalls(t *testing.T) { + // Same shape as TestLookupFile_SingleflightCollapsesConcurrentCalls: + // 50 goroutines race on a cold-cache (auth, project, projectId) + // triple. sf.Do must collapse them into a single /me.authorized + // round-trip — otherwise a thundering herd of uploads from one + // caller (e.g. parallel CI jobs holding the same bearer) hammers + // the deploys.app API on every cache-miss edge. + t.Parallel() + var calls atomic.Int64 + token := "Bearer " + t.Name() + registerAuthMock(t, token, func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + // A short sleep widens the singleflight window so the test + // reliably catches a regression where dedupe is broken. + time.Sleep(50 * time.Millisecond) + jsonAuthMock(true, true)(w, r) + }) + + const N = 50 + results := make([]AuthResult, N) + var wg sync.WaitGroup + start := make(chan struct{}) + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + results[idx] = checkAuth(context.Background(), token, "sfproject", "") + }(i) + } + close(start) + wg.Wait() + + for i, r := range results { + if !r.Authorized { + t.Errorf("results[%d] not authorized", i) + } + } + if got := calls.Load(); got >= N { + t.Errorf("auth API called %d times, want <%d (sf should have collapsed the herd)", got, N) + } +} + func TestCheckAuth_CacheKeyDistinguishesTokens(t *testing.T) { t.Parallel() var calls atomic.Int64 diff --git a/files.go b/files.go index 8511cbe..7dab289 100644 --- a/files.go +++ b/files.go @@ -47,12 +47,16 @@ func (m fileMeta) Expired() bool { func fileCacheKey(fn string) string { return "fn|" + fn } func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { - fn := r.PathValue("fn") + token := r.PathValue("token") - // Reject obviously-invalid fns in CPU. A flood of random-garbage fns - // would otherwise miss the per-fn cache on every request (each key is - // unique) and burn one DB query + one GCS Class B op apiece. - if !isValidFilename(fn) { + // Verify the HMAC tag in the token before any I/O. A flood of random + // /files/{garbage} requests by an attacker who doesn't have SignKey + // is rejected in CPU here — no DB query, no GCS Attributes call. + // This is the primary DDoS shield; the per-fn cache and the + // negative bucket cache only kick in once we've established the + // token is one we issued. + fn, ok := parseToken(a.SignKey, token) + if !ok { http.NotFound(w, r) return } @@ -93,32 +97,36 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { // When a CDN is configured and the caller is not in-cluster, record the // download as if we were serving the full file (the CDN will handle the - // actual bytes) and redirect. Cache hits never come back to this origin, - // so the bill is an over-approximation; that matches the registry - // pattern and is what /_internal/calculate-dropbox-usages assumes. + // actual bytes) and redirect. The redirect target preserves the full + // signed token so the CDN's origin-fetch hits cdnFileHandler with a + // URL we can re-verify. Cache hits never come back to this origin, so + // the bill is an over-approximation; that matches the registry pattern + // and is what /_internal/calculate-dropbox-usages assumes. if a.CDNBaseURL != "" && !isInternalClient(r) { downloadCount.WithLabelValues(meta.ProjectID).Inc() egressBytes.WithLabelValues(meta.ProjectID).Add(float64(attrs.Size)) - http.Redirect(w, r, a.CDNBaseURL+fn, http.StatusTemporaryRedirect) + http.Redirect(w, r, a.CDNBaseURL+token, http.StatusTemporaryRedirect) return } a.streamFile(w, r, fn, attrs, meta.ProjectID) } -// cdnFileHandler is the origin endpoint the CDN edge fetches from. It is -// unauthenticated — file URLs are 86-char crypto-random and only reachable -// if you know them — and skips the redirect/metrics so the CDN sees a -// plain streaming response. Expired files are refused here too so the CDN -// can't refresh its cache with bytes the user is no longer entitled to. +// cdnFileHandler is the origin endpoint the CDN edge fetches from. The +// token in the URL is the same signed token we issued to the user, so +// the HMAC check still gates the edge — only URLs we issued can reach +// the bucket. Streams the body without metrics so the CDN sees a plain +// response; expired files are refused here too so the CDN can't refresh +// its cache with bytes the user is no longer entitled to. func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { - fn := r.PathValue("fn") + token := r.PathValue("token") - // Same DDoS-protection ladder as fileHandler: validate format, then - // consult the per-fn cache, then GCS. The edge would otherwise + // Same DDoS-protection ladder as fileHandler: verify the signature, + // then consult the per-fn cache, then GCS. The edge would otherwise // amplify a random-garbage flood or an expired-URL probe into one // GCS op per request. - if !isValidFilename(fn) { + fn, ok := parseToken(a.SignKey, token) + if !ok { http.NotFound(w, r) return } diff --git a/files_test.go b/files_test.go index e3a1b13..db4446a 100644 --- a/files_test.go +++ b/files_test.go @@ -32,9 +32,9 @@ func TestFileHandler_Success(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -61,9 +61,9 @@ func TestFileHandler_NotFound(t *testing.T) { app := newTestApp(newTestBucket(t), authorized) fn := validTestFn("notexist") - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -88,7 +88,7 @@ func TestFileHandler_RouteIntegration(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -114,9 +114,9 @@ func TestFileHandler_NoHeadersWhenAttrsEmpty(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -270,9 +270,9 @@ func TestFileHandler_ExpiredReturnsGone(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(ctx) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -303,9 +303,9 @@ func TestFileHandler_ExpiredSkipsBucket(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(ctx) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -330,7 +330,7 @@ func TestCDNFileHandler_ExpiredSkipsBucket(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r = r.WithContext(ctx) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -365,9 +365,9 @@ func TestFileHandler_ExpiredOverridesCDNRedirect(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(ctx) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) @@ -404,7 +404,7 @@ func TestCDNFileHandler_ExpiredReturnsGone(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r = r.WithContext(ctx) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -434,17 +434,19 @@ func TestFileHandler_CDNRedirect(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) w := httptest.NewRecorder() app.fileHandler(w, r) if w.Code != http.StatusTemporaryRedirect { t.Fatalf("status = %d, want 307", w.Code) } - if loc := w.Header().Get("Location"); loc != "https://cdn.example.com/"+fn { - t.Errorf("Location = %q, want https://cdn.example.com/%s", loc, fn) + // Redirect preserves the *full* signed token so the CDN's + // origin-fetch hits cdnFileHandler with a URL we can re-verify. + if want := "https://cdn.example.com/" + signedToken(fn); w.Header().Get("Location") != want { + t.Errorf("Location = %q, want %q", w.Header().Get("Location"), want) } // The default http.Redirect body is `Temporary Redirect` — // it must not contain the object bytes. @@ -470,9 +472,9 @@ func TestFileHandler_CDNInternalClientStreams(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) r.Header.Set("X-Real-Ip", "10.0.0.5") // private IP -> internal w := httptest.NewRecorder() app.fileHandler(w, r) @@ -502,9 +504,9 @@ func TestFileHandler_CDNRedirectPublicXRealIP(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) - r.SetPathValue("fn", fn) + r.SetPathValue("token", signedToken(fn)) r.Header.Set("X-Real-Ip", "203.0.113.5") // public IP -> CDN path w := httptest.NewRecorder() app.fileHandler(w, r) @@ -533,7 +535,7 @@ func TestCDNFileHandler_Streams(t *testing.T) { t.Fatal(err) } - r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -556,7 +558,7 @@ func TestCDNFileHandler_NotFound(t *testing.T) { app.CDNBaseURL = "https://cdn.example.com/" fn := validTestFn("cdnnope") - r := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r = r.WithContext(db.Ctx()) w := httptest.NewRecorder() app.routes().ServeHTTP(w, r) @@ -566,40 +568,48 @@ func TestCDNFileHandler_NotFound(t *testing.T) { } } -func TestFileHandler_InvalidFilenameRejected(t *testing.T) { - // Garbage fns should 404 from the validator without touching DB or - // GCS. We pass a nil-context request: if the validator ever stops - // short-circuiting, lookupFile will panic on the missing pgctx and - // blow the test up — which is the right failure mode for a - // regression that removes the DDoS shield. +func TestFileHandler_InvalidTokenRejected(t *testing.T) { + // Tokens that don't HMAC-verify must 404 from parseToken without + // touching DB or GCS. The test app has no pgctx attached, so if the + // signature shortcut ever regresses, lookupFile will panic and the + // test blows up loudly — which is the right failure mode for the + // primary DDoS shield going away. t.Parallel() app := newTestApp(newTestBucket(t), authorized) - cases := []string{ - "", - "short", - "../../etc/passwd", - strings.Repeat("a", 85), // 1 short - strings.Repeat("a", 87), // 1 long - strings.Repeat("a", 85) + "!", // bad char - strings.Repeat("a", 85) + "+", // standard-base64 char, not URL-safe - strings.Repeat("a", 85) + "/", // ditto - } - for _, fn := range cases { - t.Run("fn="+fn, func(t *testing.T) { + goodFn := validTestFn("good") + goodToken := signedToken(goodFn) + tamperedSig := goodToken[:tokenLen-1] + "0" + if tamperedSig == goodToken { + tamperedSig = goodToken[:tokenLen-1] + "1" + } + forgedDiffKey := makeToken([]byte("not-the-real-key"), goodFn) + + cases := map[string]string{ + "empty": "", + "short": "short", + "path-traversal": "../../etc/passwd", + "one-char-short": strings.Repeat("a", tokenLen-1), + "one-char-long": strings.Repeat("a", tokenLen+1), + "right-len-bad-sig": strings.Repeat("a", tokenLen), + "tampered-sig": tamperedSig, + "forged-other-key": forgedDiffKey, + } + for name, token := range cases { + t.Run(name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest(http.MethodGet, "/files/x", nil) - r.SetPathValue("fn", fn) + r.SetPathValue("token", token) w := httptest.NewRecorder() app.fileHandler(w, r) if w.Code != http.StatusNotFound { - t.Errorf("fn=%q: status = %d, want 404", fn, w.Code) + t.Errorf("token=%q: status = %d, want 404", token, w.Code) } }) } } -func TestCDNFileHandler_InvalidFilenameRejected(t *testing.T) { +func TestCDNFileHandler_InvalidTokenRejected(t *testing.T) { t.Parallel() app := newTestApp(newTestBucket(t), authorized) app.CDNBaseURL = "https://cdn.example.com/" @@ -625,9 +635,9 @@ func TestFileHandler_BucketMissNegativeCached(t *testing.T) { fn := validTestFn("missneg") - r1 := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r1 := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r1 = r1.WithContext(db.Ctx()) - r1.SetPathValue("fn", fn) + r1.SetPathValue("token", signedToken(fn)) w1 := httptest.NewRecorder() app.fileHandler(w1, r1) if w1.Code != http.StatusNotFound { @@ -643,9 +653,9 @@ func TestFileHandler_BucketMissNegativeCached(t *testing.T) { t.Fatal(err) } - r2 := httptest.NewRequest(http.MethodGet, "/files/"+fn, nil) + r2 := httptest.NewRequest(http.MethodGet, "/files/"+signedToken(fn), nil) r2 = r2.WithContext(db.Ctx()) - r2.SetPathValue("fn", fn) + r2.SetPathValue("token", signedToken(fn)) w2 := httptest.NewRecorder() app.fileHandler(w2, r2) if w2.Code != http.StatusNotFound { @@ -667,7 +677,7 @@ func TestCDNFileHandler_BucketMissNegativeCached(t *testing.T) { fn := validTestFn("missnegcdn") - r1 := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r1 := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r1 = r1.WithContext(db.Ctx()) w1 := httptest.NewRecorder() app.routes().ServeHTTP(w1, r1) @@ -684,7 +694,7 @@ func TestCDNFileHandler_BucketMissNegativeCached(t *testing.T) { t.Fatal(err) } - r2 := httptest.NewRequest(http.MethodGet, "/_cdn/"+fn, nil) + r2 := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) r2 = r2.WithContext(db.Ctx()) w2 := httptest.NewRecorder() app.routes().ServeHTTP(w2, r2) diff --git a/handler.go b/handler.go index 4707140..a5f7360 100644 --- a/handler.go +++ b/handler.go @@ -2,8 +2,10 @@ package main import ( "context" + "crypto/hmac" "crypto/rand" - "encoding/base64" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -23,6 +25,7 @@ type App struct { BaseURL string CDNBaseURL string InternalSecret string + SignKey []byte checkAuth func(ctx context.Context, auth, project, projectID string) AuthResult } @@ -32,8 +35,8 @@ func (a *App) routes() http.Handler { w.Write([]byte("Deploys.app Dropbox Service")) }) mux.HandleFunc("POST /{$}", a.uploadHandler) - mux.HandleFunc("GET /files/{fn}", a.fileHandler) - mux.HandleFunc("GET /_cdn/{fn}", a.cdnFileHandler) + mux.HandleFunc("GET /files/{token}", a.fileHandler) + mux.HandleFunc("GET /_cdn/{token}", a.cdnFileHandler) mux.HandleFunc("POST /internal/gc", a.gcHandler) return mux } @@ -112,41 +115,89 @@ func (a *App) uploadHandler(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{ "ok": true, "result": map[string]any{ - "downloadUrl": a.BaseURL + fn, + "downloadUrl": a.BaseURL + makeToken(a.SignKey, fn), "expiresAt": expiresAt.Format(time.RFC3339), }, }) } +// Filename + signature scheme. +// +// `fn` is what we store in the bucket and the DB — a random 24-char +// alphanumeric string (~143 bits of entropy, drawn from +// [0-9A-Za-z] with rejection sampling). No special chars, so URLs are +// clean and we don't need URL-safe-base64 tricks. +// +// `token` is what appears in the public URL: `fn` concatenated with a +// 20-char lowercase-hex HMAC-SHA256 tag (80-bit) keyed by App.SignKey. +// The handlers check the tag before any DB or GCS work, so a flood of +// random `/files/{token}` requests by an attacker who doesn't know the +// key gets 404'd in CPU. Length is fixed (fnLen + sigLen = 44) so we +// don't need a separator. +const ( + fnLen = 24 + sigLen = 20 // 80 bits of HMAC, hex-encoded + tokenLen = fnLen + sigLen +) + +const fnAlphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + +// generateFilename returns a fresh fnLen-char random alphanumeric fn. +// Uses rejection sampling so each char is uniformly drawn from +// fnAlphabet (avoiding the mod-bias you'd get from `b%62`). func generateFilename() string { - b := make([]byte, 64) - rand.Read(b) - return base64.RawURLEncoding.EncodeToString(b) + out := make([]byte, fnLen) + buf := make([]byte, fnLen*2) // slack for rejection + pos := 0 + for pos < fnLen { + if _, err := rand.Read(buf); err != nil { + panic(err) // crypto/rand failing is unrecoverable + } + for _, b := range buf { + if pos >= fnLen { + break + } + // 62 * 4 = 248 is the largest unbiased range under 256; + // reject 248..255. + if b < 248 { + out[pos] = fnAlphabet[b%62] + pos++ + } + } + } + return string(out) } -// isValidFilename returns true iff fn matches what generateFilename -// produces: 86 chars of URL-safe base64 (RawURLEncoding of 64 random -// bytes, no padding). Anything else can't possibly exist in our system, -// so the file handlers can 404 it in CPU and skip both the Postgres -// query and the GCS Attributes call. That absorbs a flood of -// random-garbage fns (which would otherwise miss the per-fn cache, since -// every key is unique) at zero backend cost. -func isValidFilename(fn string) bool { - if len(fn) != 86 { - return false +// signFilename returns the sigLen-char lowercase-hex HMAC tag for fn +// under key. Caller's responsibility to keep `key` secret. +func signFilename(key []byte, fn string) string { + mac := hmac.New(sha256.New, key) + mac.Write([]byte(fn)) + sum := mac.Sum(nil) + return hex.EncodeToString(sum[:sigLen/2]) +} + +// makeToken builds the public URL token (fn + sig) for fn. +func makeToken(key []byte, fn string) string { + return fn + signFilename(key, fn) +} + +// parseToken splits a public URL token into (fn, sig), verifies the +// HMAC in constant time, and returns the fn on success. On any +// failure — wrong length, bad sig, anything — it returns (_, false) +// and the handler 404s without touching the DB or GCS. This is the +// primary DDoS shield against random `/files/{garbage}` floods. +func parseToken(key []byte, token string) (string, bool) { + if len(token) != tokenLen { + return "", false } - for i := 0; i < len(fn); i++ { - c := fn[i] - switch { - case c >= 'A' && c <= 'Z': - case c >= 'a' && c <= 'z': - case c >= '0' && c <= '9': - case c == '-', c == '_': - default: - return false - } + fn := token[:fnLen] + sig := token[fnLen:] + expected := signFilename(key, fn) + if !hmac.Equal([]byte(sig), []byte(expected)) { + return "", false } - return true + return fn, true } func escapeFilename(s string) string { diff --git a/handler_test.go b/handler_test.go index 09ddc37..68698d9 100644 --- a/handler_test.go +++ b/handler_test.go @@ -23,15 +23,40 @@ func newTestBucket(t *testing.T) *blob.Bucket { return bkt } -// validTestFn returns an 86-char fn that passes isValidFilename, with -// the descriptive suffix preserved so failing tests are still readable. +// validTestFn returns an fnLen-char fn drawn from fnAlphabet, with the +// descriptive suffix preserved so failing tests are still readable. // Distinct suffixes produce distinct fns, which keeps parallel tests // isolated in the in-process per-fn cache. func validTestFn(suffix string) string { - if len(suffix) > 86 { + if len(suffix) > fnLen { panic("validTestFn: suffix too long: " + suffix) } - return strings.Repeat("a", 86-len(suffix)) + suffix + return strings.Repeat("a", fnLen-len(suffix)) + suffix +} + +// testSignKey is the fixed HMAC key used in every test app. Tests build +// signed tokens with makeToken(testSignKey, fn) and the handlers verify +// against this same key. +var testSignKey = []byte("test-sign-key-do-not-use-in-prod") + +// signedToken wraps a test fn into the public URL token by HMAC-signing +// it with testSignKey. Tests put this in the URL path; the handler +// parses it back into fn for the bucket/DB lookup. +func signedToken(fn string) string { + return makeToken(testSignKey, fn) +} + +// tokenToFn is the inverse: pull the token off a downloadUrl, verify +// it, and return the fn that's actually stored in the bucket and DB. +// Tests use this when they need to inspect the upload's bucket object. +func tokenToFn(t *testing.T, downloadURL string) string { + t.Helper() + token := strings.TrimPrefix(downloadURL, "https://example.com/") + fn, ok := parseToken(testSignKey, token) + if !ok { + t.Fatalf("downloadUrl %q does not contain a valid signed token", downloadURL) + } + return fn } func countObjects(t *testing.T, bkt *blob.Bucket) int { @@ -61,6 +86,7 @@ func newTestApp(bkt *blob.Bucket, authFn func(context.Context, string, string, s return &App{ Bucket: bkt, BaseURL: "https://example.com/", + SignKey: testSignKey, checkAuth: authFn, } } @@ -312,7 +338,7 @@ func TestUpload_FilenameFromQuery(t *testing.T) { t.Fatal("expected ok=true") } - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -342,7 +368,7 @@ func TestUpload_FilenameFromHeader(t *testing.T) { t.Fatal("expected ok=true") } - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -371,7 +397,7 @@ func TestUpload_FilenameWithQuotesEscaped(t *testing.T) { t.Fatal("expected ok=true") } - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -396,7 +422,7 @@ func TestUpload_NoFilenameNoContentDisposition(t *testing.T) { var resp uploadResp json.NewDecoder(w.Body).Decode(&resp) - fn := strings.TrimPrefix(resp.Result.DownloadURL, "https://example.com/") + fn := tokenToFn(t, resp.Result.DownloadURL) attrs, err := bkt.Attributes(t.Context(), fn) if err != nil { t.Fatal(err) @@ -463,47 +489,117 @@ func TestGenerateFilename(t *testing.T) { if a == b { t.Error("expected unique filenames") } - if len(a) != 86 { - t.Errorf("filename length = %d, want 86", len(a)) + if len(a) != fnLen { + t.Errorf("filename length = %d, want %d", len(a), fnLen) } + // Must be strictly alphanumeric — no special chars at all. for _, c := range a { - if c == '+' || c == '/' || c == '=' { - t.Errorf("filename contains non-URL-safe char %q: %s", c, a) + switch { + case c >= '0' && c <= '9': + case c >= 'A' && c <= 'Z': + case c >= 'a' && c <= 'z': + default: + t.Errorf("filename contains non-alphanumeric char %q: %s", c, a) } } } -func TestIsValidFilename(t *testing.T) { +func TestGenerateFilename_DistributionLooksUniform(t *testing.T) { + t.Parallel() + // Rough check that rejection sampling doesn't bias the alphabet. + // With 1000 * fnLen = 24000 chars over 62 symbols, expected count + // per symbol is ~387. We just assert every symbol appears at least + // once, which would fail catastrophically if a whole half of the + // alphabet got dropped (e.g. by an off-by-one in the reject bound). + counts := map[byte]int{} + for i := 0; i < 1000; i++ { + for j := 0; j < len(generateFilename()); j++ { + counts[generateFilename()[j]]++ + } + } + for _, c := range []byte(fnAlphabet) { + if counts[c] == 0 { + t.Errorf("symbol %q never appeared in 1000 generated fns", c) + } + } +} + +func TestSignFilename_Deterministic(t *testing.T) { + t.Parallel() + fn := validTestFn("sigtest") + a := signFilename(testSignKey, fn) + b := signFilename(testSignKey, fn) + if a != b { + t.Errorf("signFilename not deterministic: %q vs %q", a, b) + } + if len(a) != sigLen { + t.Errorf("sig length = %d, want %d", len(a), sigLen) + } +} + +func TestSignFilename_KeyMatters(t *testing.T) { + t.Parallel() + fn := validTestFn("keytest") + a := signFilename(testSignKey, fn) + b := signFilename([]byte("different-key"), fn) + if a == b { + t.Errorf("sig should differ across keys; got %q for both", a) + } +} + +func TestParseToken_RoundTripsGeneratedToken(t *testing.T) { t.Parallel() - // Generated fns must always validate — this links the producer and - // the validator so a future change to either can't drift. for i := 0; i < 16; i++ { - if fn := generateFilename(); !isValidFilename(fn) { - t.Errorf("generateFilename() = %q failed isValidFilename", fn) + fn := generateFilename() + token := makeToken(testSignKey, fn) + if len(token) != tokenLen { + t.Fatalf("token length = %d, want %d", len(token), tokenLen) + } + got, ok := parseToken(testSignKey, token) + if !ok || got != fn { + t.Errorf("parseToken(%q) = (%q, %v), want (%q, true)", token, got, ok, fn) } } +} - good := strings.Repeat("a", 86) - cases := []struct { - fn string - want bool - }{ - {good, true}, - {"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_abcdefghijklmnopqrstuv", true}, // 86 chars, full alphabet - {"", false}, - {"short", false}, - {strings.Repeat("a", 85), false}, - {strings.Repeat("a", 87), false}, - {strings.Repeat("a", 85) + "!", false}, // bad char - {strings.Repeat("a", 85) + "+", false}, // standard-base64 only - {strings.Repeat("a", 85) + "/", false}, - {strings.Repeat("a", 85) + "=", false}, // padding not allowed by RawURLEncoding - {strings.Repeat("a", 85) + " ", false}, - {strings.Repeat("a", 85) + ".", false}, +func TestParseToken_RejectsForgeries(t *testing.T) { + t.Parallel() + fn := validTestFn("forgery") + good := makeToken(testSignKey, fn) + + // Wrong key — handler must not trust whatever the attacker sends. + if _, ok := parseToken([]byte("wrong-key"), good); ok { + t.Error("parseToken accepted token signed under a different key") } - for _, tc := range cases { - if got := isValidFilename(tc.fn); got != tc.want { - t.Errorf("isValidFilename(%q) = %v, want %v", tc.fn, got, tc.want) + + // Tamper with the sig portion. + tampered := good[:tokenLen-1] + "0" + if tampered == good { + tampered = good[:tokenLen-1] + "1" + } + if _, ok := parseToken(testSignKey, tampered); ok { + t.Errorf("parseToken accepted tampered sig: %q", tampered) + } + + // Tamper with the fn portion (now the sig no longer matches). + swapFn := "b" + good[1:] + if swapFn == good { + swapFn = "c" + good[1:] + } + if _, ok := parseToken(testSignKey, swapFn); ok { + t.Errorf("parseToken accepted token with rewritten fn: %q", swapFn) + } + + // Wrong length. + for _, bad := range []string{ + "", + "short", + good + "x", + good[:tokenLen-1], + strings.Repeat("a", tokenLen), + } { + if _, ok := parseToken(testSignKey, bad); ok { + t.Errorf("parseToken accepted bad token %q", bad) } } } diff --git a/main.go b/main.go index 7988da1..44f6db4 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,10 @@ func main() { BaseURL: config.StringDefault("base_url", "https://dropbox.deploys.app/files/"), CDNBaseURL: config.String("cdn_base_url"), InternalSecret: config.String("internal_secret"), + // sign_key is the HMAC key used to sign download tokens. + // Required — rotating it invalidates all outstanding URLs (which + // is the right behavior if a key ever leaks). + SignKey: []byte(config.MustString("sign_key")), } promAddr := config.StringDefault("prom_addr", ":9187") From 536866a165cb8308b6f979ce8f9af0098318f163 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 10:47:20 +0700 Subject: [PATCH 6/7] Separate fn and sig with "-" so fn length can change without breaking URLs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous token layout — fn (fnLen=24 chars) concatenated directly with sig (sigLen=20 chars) — relied on a fixed cut point at index fnLen for parseToken. That means any future bump or trim of fnLen would 404 every outstanding URL: parseToken would slice the old token at the wrong position and compute HMAC over the wrong fn. Insert "-" between them. fn is [0-9A-Za-z] and sig is hex, so neither side can ever contain "-" — the separator is unambiguous. parseToken now splits on strings.IndexByte('-') instead of a positional slice, so the verifier no longer cares what fnLen was when a given token was issued: as long as sig is what signFilename(key, fn) produces today, the token validates. tokenLen is gone as a single constant (token length is now fn-length-dependent). Tests that asserted positional layout were rewritten to assert structural properties: - TestParseToken_RoundTripsGeneratedToken now checks for the separator's presence and uses fnLen+1+sigLen as the *current* expected length, not a hard invariant. - New TestParseToken_AcceptsDifferentFnLengths feeds fns of 1, 3, 24, and 64 chars through makeToken/parseToken to lock in the property that motivated this change. - TestParseToken_RejectsForgeries and TestFileHandler_InvalidTokenRejected swapped their length-based bad cases for structural ones: missing separator, empty fn, empty sig — a 44-char string with no "-" now fails for the right reason. CLAUDE.md and README updated to mention the "-" and the backward-compatibility property it buys. Co-Authored-By: Claude Opus 4.7 --- CLAUDE.md | 5 +++-- README.md | 2 +- files_test.go | 10 +++++----- handler.go | 40 +++++++++++++++++++++++----------------- handler_test.go | 46 ++++++++++++++++++++++++++++++++++++---------- 5 files changed, 68 insertions(+), 35 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 1c133f7..4c80217 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,7 +29,8 @@ Standard Go HTTP server (not Cloudflare Workers) serving as a temporary file upl - `log_level` — slog level (default: info) **Download token scheme (`handler.go`):** -- The URL path component is a 44-char `token` = `fn` (24 chars random `[0-9A-Za-z]`, ~143 bits of entropy) concatenated with `sig` (20 hex chars of HMAC-SHA256 truncated to 80 bits, keyed by `sign_key`). +- The URL path component is a `token` = `fn` + `"-"` + `sig`, currently 45 chars total. `fn` is 24 random chars `[0-9A-Za-z]` (~143 bits of entropy); `sig` is 20 hex chars of HMAC-SHA256 truncated to 80 bits, keyed by `sign_key`. +- The `-` separator lets us change `fnLen` later without invalidating tokens that are already in circulation — `parseToken` splits structurally on the separator, not by fixed position. Since `fn` is alphanumeric and `sig` is hex, neither side can contain a `-`. - `parseToken(SignKey, token)` runs first in both `fileHandler` and `cdnFileHandler` and 404s on any mismatch — DDoS attempts that don't know `sign_key` never reach the DB or GCS. - `fn` is what we store in the bucket and the `files.fn` column. The full token only appears in URLs. @@ -40,7 +41,7 @@ Standard Go HTTP server (not Cloudflare Workers) serving as a temporary file upl 4. Generate a 24-char alphanumeric `fn` (`generateFilename`, rejection-sampled to stay unbiased) 5. Stream body to GCS with cache-control and optional content-disposition, keyed by `fn` 6. Insert metadata into PostgreSQL via `pgctx.Exec` -7. Return JSON: `{"ok": true, "result": {"downloadUrl": "{base_url}{fn}{sig}", "expiresAt": "..."}}` +7. Return JSON: `{"ok": true, "result": {"downloadUrl": "{base_url}{fn}-{sig}", "expiresAt": "..."}}` **Auth (`auth.go`):** - No `Authorization` header → alpha mode, project ID hardcoded as `"alpha"` (TODO: remove) diff --git a/README.md b/README.md index 77189e3..35f8610 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ File data binary } ``` -`` is a 44-char string: a 24-char random alphanumeric filename concatenated with a 20-char HMAC-SHA256 signature (keyed by `sign_key`). Tampered or made-up tokens are rejected before any DB or storage lookup. +`` is `{fn}-{sig}` (currently 45 chars): a 24-char random alphanumeric filename, a `-` separator, and a 20-char HMAC-SHA256 signature (keyed by `sign_key`). Tampered or made-up tokens are rejected before any DB or storage lookup. The separator means future changes to filename length stay backward-compatible. ##### Unauthorized diff --git a/files_test.go b/files_test.go index db4446a..eccbff9 100644 --- a/files_test.go +++ b/files_test.go @@ -579,9 +579,9 @@ func TestFileHandler_InvalidTokenRejected(t *testing.T) { goodFn := validTestFn("good") goodToken := signedToken(goodFn) - tamperedSig := goodToken[:tokenLen-1] + "0" + tamperedSig := goodToken[:len(goodToken)-1] + "0" if tamperedSig == goodToken { - tamperedSig = goodToken[:tokenLen-1] + "1" + tamperedSig = goodToken[:len(goodToken)-1] + "1" } forgedDiffKey := makeToken([]byte("not-the-real-key"), goodFn) @@ -589,9 +589,9 @@ func TestFileHandler_InvalidTokenRejected(t *testing.T) { "empty": "", "short": "short", "path-traversal": "../../etc/passwd", - "one-char-short": strings.Repeat("a", tokenLen-1), - "one-char-long": strings.Repeat("a", tokenLen+1), - "right-len-bad-sig": strings.Repeat("a", tokenLen), + "no-separator": strings.Repeat("a", fnLen+sigLen), + "empty-fn": tokenSep + signFilename(testSignKey, ""), + "empty-sig": goodFn + tokenSep, "tampered-sig": tamperedSig, "forged-other-key": forgedDiffKey, } diff --git a/handler.go b/handler.go index a5f7360..eb509a5 100644 --- a/handler.go +++ b/handler.go @@ -128,16 +128,18 @@ func (a *App) uploadHandler(w http.ResponseWriter, r *http.Request) { // [0-9A-Za-z] with rejection sampling). No special chars, so URLs are // clean and we don't need URL-safe-base64 tricks. // -// `token` is what appears in the public URL: `fn` concatenated with a -// 20-char lowercase-hex HMAC-SHA256 tag (80-bit) keyed by App.SignKey. -// The handlers check the tag before any DB or GCS work, so a flood of -// random `/files/{token}` requests by an attacker who doesn't know the -// key gets 404'd in CPU. Length is fixed (fnLen + sigLen = 44) so we -// don't need a separator. +// `token` is what appears in the public URL: `fn` + "-" + 20-char +// lowercase-hex HMAC-SHA256 tag (80-bit) keyed by App.SignKey. The +// handlers check the tag before any DB or GCS work, so a flood of +// random `/files/{token}` requests by an attacker who doesn't know +// the key gets 404'd in CPU. The "-" separator means we can change +// `fnLen` later without invalidating old tokens — parseToken splits +// on it instead of relying on a fixed cut point. fn is alphanumeric +// and sig is hex, so neither side can contain a "-". const ( fnLen = 24 sigLen = 20 // 80 bits of HMAC, hex-encoded - tokenLen = fnLen + sigLen + tokenSep = "-" ) const fnAlphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" @@ -177,22 +179,26 @@ func signFilename(key []byte, fn string) string { return hex.EncodeToString(sum[:sigLen/2]) } -// makeToken builds the public URL token (fn + sig) for fn. +// makeToken builds the public URL token for fn: `fn + "-" + sig`. func makeToken(key []byte, fn string) string { - return fn + signFilename(key, fn) + return fn + tokenSep + signFilename(key, fn) } -// parseToken splits a public URL token into (fn, sig), verifies the -// HMAC in constant time, and returns the fn on success. On any -// failure — wrong length, bad sig, anything — it returns (_, false) -// and the handler 404s without touching the DB or GCS. This is the -// primary DDoS shield against random `/files/{garbage}` floods. +// parseToken splits a public URL token on tokenSep, verifies the HMAC +// in constant time, and returns the fn on success. On any failure — +// missing separator, empty side, bad sig, anything — it returns +// (_, false) and the handler 404s without touching the DB or GCS. +// This is the primary DDoS shield against random `/files/{garbage}` +// floods. The split is structural (on the separator) rather than +// positional, so changing fnLen later won't invalidate old tokens +// still in circulation. func parseToken(key []byte, token string) (string, bool) { - if len(token) != tokenLen { + idx := strings.IndexByte(token, tokenSep[0]) + if idx <= 0 || idx == len(token)-1 { return "", false } - fn := token[:fnLen] - sig := token[fnLen:] + fn := token[:idx] + sig := token[idx+1:] expected := signFilename(key, fn) if !hmac.Equal([]byte(sig), []byte(expected)) { return "", false diff --git a/handler_test.go b/handler_test.go index 68698d9..9c9f50f 100644 --- a/handler_test.go +++ b/handler_test.go @@ -552,8 +552,11 @@ func TestParseToken_RoundTripsGeneratedToken(t *testing.T) { for i := 0; i < 16; i++ { fn := generateFilename() token := makeToken(testSignKey, fn) - if len(token) != tokenLen { - t.Fatalf("token length = %d, want %d", len(token), tokenLen) + if want := fnLen + 1 + sigLen; len(token) != want { + t.Fatalf("token length = %d, want %d", len(token), want) + } + if !strings.Contains(token, tokenSep) { + t.Fatalf("token %q missing separator %q", token, tokenSep) } got, ok := parseToken(testSignKey, token) if !ok || got != fn { @@ -562,6 +565,26 @@ func TestParseToken_RoundTripsGeneratedToken(t *testing.T) { } } +func TestParseToken_AcceptsDifferentFnLengths(t *testing.T) { + // The whole point of the separator: parseToken must work for fns of + // any length, as long as the sig was produced by signFilename. This + // guards future changes to fnLen without breaking outstanding URLs. + t.Parallel() + for _, fn := range []string{ + "a", + "abc", + validTestFn("short"), + validTestFn("default"), + strings.Repeat("z", 64), // hypothetically larger fn + } { + token := makeToken(testSignKey, fn) + got, ok := parseToken(testSignKey, token) + if !ok || got != fn { + t.Errorf("parseToken round-trip failed for fn=%q: got=(%q, %v)", fn, got, ok) + } + } +} + func TestParseToken_RejectsForgeries(t *testing.T) { t.Parallel() fn := validTestFn("forgery") @@ -572,16 +595,17 @@ func TestParseToken_RejectsForgeries(t *testing.T) { t.Error("parseToken accepted token signed under a different key") } - // Tamper with the sig portion. - tampered := good[:tokenLen-1] + "0" + // Tamper with the sig portion (last char). + tampered := good[:len(good)-1] + "0" if tampered == good { - tampered = good[:tokenLen-1] + "1" + tampered = good[:len(good)-1] + "1" } if _, ok := parseToken(testSignKey, tampered); ok { t.Errorf("parseToken accepted tampered sig: %q", tampered) } - // Tamper with the fn portion (now the sig no longer matches). + // Tamper with the fn portion (the sig no longer matches what we'd + // compute over the rewritten fn). swapFn := "b" + good[1:] if swapFn == good { swapFn = "c" + good[1:] @@ -590,13 +614,15 @@ func TestParseToken_RejectsForgeries(t *testing.T) { t.Errorf("parseToken accepted token with rewritten fn: %q", swapFn) } - // Wrong length. + // Structural failures: missing separator, empty fn/sig, etc. for _, bad := range []string{ "", "short", - good + "x", - good[:tokenLen-1], - strings.Repeat("a", tokenLen), + "no-separator-but-also-clearly-not-a-real-token", + tokenSep + signFilename(testSignKey, ""), // empty fn + fn + tokenSep, // empty sig + fn, // sig missing entirely + strings.Repeat("a", fnLen+1+sigLen), // right total length, no separator at all } { if _, ok := parseToken(testSignKey, bad); ok { t.Errorf("parseToken accepted bad token %q", bad) From 01e6064ab0edea7b81bbe54d9040ddf31a813609 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 24 May 2026 11:08:48 +0700 Subject: [PATCH 7/7] Set Cache-Control on every /_cdn response MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this change /_cdn relied on streamFile to copy attrs.CacheControl ("public, max-age=86400") onto the response. That has two problems for the edge: 1) The bucket's 24h max-age caps CDN cache at one day even for files whose TTL is 7 days — every cached object re-fetches origin six extra times before expiry. 2) Error responses (410, 404) carried no Cache-Control at all, so a shared-then-expired URL kept hitting origin on every probe. cdnFileHandler now sets Cache-Control explicitly per outcome: - success (200): public, max-age={remaining TTL}, immutable. fn is unique per upload so the body genuinely never changes; capping max-age at expires_at means the edge stops serving past end-of-life instead of holding the bytes for its own cache TTL. Falls back to attrs.CacheControl if expires_at is unknown (the rare insert-failed-upload case). - expired (410) and bucket-missing (404): public, max-age=3600. The answer is permanent; let the edge absorb repeat probes. - invalid token (404): no Cache-Control. Each garbage token is a unique URL, so caching wouldn't reduce origin load and would just fill edge slots with attacker traffic. streamFile now only falls back to attrs.CacheControl when the caller hasn't already chosen one — that lets cdnFileHandler pre-set the TTL-aligned policy without streamFile clobbering it. fileHandler is unchanged: it never sets Cache-Control beforehand, so streamFile still uses the bucket's value (correct for the direct-stream / browser case). Tests: - TestCDNFileHandler_SuccessCacheControlUsesRemainingTTL inserts a 7-day-TTL row alongside the bucket's 24h CacheControl and asserts the response's max-age is ~7 days, not 86400 — locks in the override. - TestCDNFileHandler_ExpiredCacheControl asserts 410 carries public, max-age=3600. - TestCDNFileHandler_NotFound extended to assert the same on the bucket-NotFound 404 path. - TestCDNFileHandler_InvalidTokenNoCacheControl asserts garbage-token 404s leave Cache-Control empty so attacker traffic can't pollute edge cache. - TestCDNFileHandler_Streams (existing) still passes — it has no DB row, so meta.ExpiresAt is zero and we correctly fall through to attrs.CacheControl. Co-Authored-By: Claude Opus 4.7 --- CLAUDE.md | 2 +- files.go | 40 +++++++++++++++++- files_test.go | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 2 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 4c80217..7b4932d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -24,7 +24,7 @@ Standard Go HTTP server (not Cloudflare Workers) serving as a temporary file upl **Environment variables (optional):** - `base_url` — download URL prefix (default: `https://dropbox.deploys.app/files/`) - `api_endpoint` — deploys.app API base URL (default: `https://api.deploys.app`, override with internal address in production) -- `cdn_base_url` — full URL prefix (including scheme and trailing slash, e.g. `https://cdn.example.com/`). When set, `GET /files/{token}` records the download metric (counted against `attrs.Size`) and 307-redirects to `{cdn_base_url}{token}`. The CDN edge is expected to fetch its origin at `https://dropbox.deploys.app/_cdn/{token}`, which streams the file (no auth, no metrics) after re-verifying the same HMAC. In-cluster callers (private/loopback/link-local `X-Real-Ip`) bypass the redirect and stream directly. Unset = original streaming behavior. +- `cdn_base_url` — full URL prefix (including scheme and trailing slash, e.g. `https://cdn.example.com/`). When set, `GET /files/{token}` records the download metric (counted against `attrs.Size`) and 307-redirects to `{cdn_base_url}{token}`. The CDN edge is expected to fetch its origin at `https://dropbox.deploys.app/_cdn/{token}`, which streams the file (no auth, no metrics) after re-verifying the same HMAC. Each `/_cdn` response sets `Cache-Control` for the edge: success uses `public, max-age={remaining TTL}, immutable` so the edge cache lines up with the file's actual lifetime; 410/404 for known-dead URLs use `public, max-age=3600` so repeat probes are absorbed at the edge; invalid tokens get no `Cache-Control` because each garbage URL is unique and caching just burns edge slots. In-cluster callers (private/loopback/link-local `X-Real-Ip`) bypass the redirect and stream directly. Unset = original streaming behavior. - `PORT` — listen port (default: `8080`) - `log_level` — slog level (default: info) diff --git a/files.go b/files.go index 7dab289..1fd0a2b 100644 --- a/files.go +++ b/files.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "io" "log/slog" "net/http" @@ -112,12 +113,28 @@ func (a *App) fileHandler(w http.ResponseWriter, r *http.Request) { a.streamFile(w, r, fn, attrs, meta.ProjectID) } +// cdnDeadResponseTTL is how long the CDN edge should cache "this URL is +// dead" responses (410 expired, 404 bucket-missing). Long enough that a +// shared-then-expired URL stops hammering origin, short enough that the +// answer can update if we ever change behavior. +const cdnDeadResponseTTL = time.Hour + // cdnFileHandler is the origin endpoint the CDN edge fetches from. The // token in the URL is the same signed token we issued to the user, so // the HMAC check still gates the edge — only URLs we issued can reach // the bucket. Streams the body without metrics so the CDN sees a plain // response; expired files are refused here too so the CDN can't refresh // its cache with bytes the user is no longer entitled to. +// +// Each response sets Cache-Control to tell the edge how long to cache: +// - success (200): public, max-age={remaining TTL}, immutable — fn is +// unique per upload so the body never changes; cap at the file's +// expires_at so the edge stops serving past end-of-life. +// - expired (410) and bucket-missing (404): public, max-age=3600 — +// the answer is permanent, so let the edge absorb repeat probes. +// - invalid token (404): no Cache-Control. Every garbage token is a +// unique URL, so caching doesn't reduce origin load and would just +// burn edge cache slots on attacker traffic. func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { token := r.PathValue("token") @@ -133,10 +150,12 @@ func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { meta := lookupFile(r.Context(), fn) if meta.Expired() { + setDeadCacheControl(w) http.Error(w, "file expired", http.StatusGone) return } if meta.BucketMissing { + setDeadCacheControl(w) http.NotFound(w, r) return } @@ -146,6 +165,7 @@ func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { if gcerrors.Code(err) == gcerrors.NotFound { meta.BucketMissing = true cachestore.Set(fileCacheKey(fn), meta, &cachestore.SetOptions{TTL: fileMetaCacheTTL}) + setDeadCacheControl(w) http.NotFound(w, r) return } @@ -153,9 +173,24 @@ func (a *App) cdnFileHandler(w http.ResponseWriter, r *http.Request) { return } + // Pre-set Cache-Control so streamFile's attrs.CacheControl fallback + // doesn't overwrite it. If we have no expires_at to work from + // (insert-failed upload), fall through to the bucket's value. + if !meta.ExpiresAt.IsZero() { + remaining := time.Until(meta.ExpiresAt) + if remaining < 0 { + remaining = 0 + } + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d, immutable", int(remaining.Seconds()))) + } + a.streamFile(w, r, fn, attrs, "") } +func setDeadCacheControl(w http.ResponseWriter) { + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(cdnDeadResponseTTL.Seconds()))) +} + // streamFile copies the object body to w with the cached headers. When // projectID is non-empty, download_count and egress_bytes are bumped by // the actual transferred bytes. Used by both the direct-stream path @@ -172,7 +207,10 @@ func (a *App) streamFile(w http.ResponseWriter, r *http.Request, fn string, attr } defer reader.Close() - if attrs.CacheControl != "" { + // Only fall back to the bucket's CacheControl if the caller hasn't + // already chosen one. cdnFileHandler pre-sets a TTL-aligned policy; + // the non-CDN fileHandler leaves it unset and gets the bucket value. + if attrs.CacheControl != "" && w.Header().Get("Cache-Control") == "" { w.Header().Set("Cache-Control", attrs.CacheControl) } if attrs.ContentDisposition != "" { diff --git a/files_test.go b/files_test.go index eccbff9..47bc518 100644 --- a/files_test.go +++ b/files_test.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "net/http" "net/http/httptest" "strings" @@ -566,6 +567,120 @@ func TestCDNFileHandler_NotFound(t *testing.T) { if w.Code != http.StatusNotFound { t.Errorf("status = %d, want 404", w.Code) } + // Bucket NotFound for a signed token → cache the negative at the + // edge so the same dead URL doesn't keep hitting origin. + if cc := w.Header().Get("Cache-Control"); cc != "public, max-age=3600" { + t.Errorf("Cache-Control = %q, want public, max-age=3600", cc) + } +} + +func TestCDNFileHandler_SuccessCacheControlUsesRemainingTTL(t *testing.T) { + // The CDN response must cap max-age at the file's remaining TTL so + // the edge stops serving past expires_at. fn is unique per upload + // so we mark the body immutable. + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("cdncc") + bw, err := bkt.NewWriter(t.Context(), fn, &blob.WriterOptions{ + CacheControl: "public, max-age=86400", // bucket default — must be overridden + }) + if err != nil { + t.Fatal(err) + } + bw.Write([]byte("data")) + if err := bw.Close(); err != nil { + t.Fatal(err) + } + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-cc', 4, 'x', 7, now() + interval '7 days') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + cc := w.Header().Get("Cache-Control") + if !strings.HasPrefix(cc, "public, max-age=") || !strings.HasSuffix(cc, ", immutable") { + t.Fatalf("Cache-Control = %q, want shape 'public, max-age=N, immutable'", cc) + } + // Extract and sanity-check the max-age — must reflect ~7 days, not + // the bucket default of 86400. + var maxAge int + if _, err := fmt.Sscanf(cc, "public, max-age=%d, immutable", &maxAge); err != nil { + t.Fatalf("parse max-age from %q: %v", cc, err) + } + const sevenDays = 7 * 24 * 3600 + if maxAge < sevenDays-3600 || maxAge > sevenDays { + t.Errorf("max-age = %d, want roughly 7 days (%d ± 1h)", maxAge, sevenDays) + } +} + +func TestCDNFileHandler_ExpiredCacheControl(t *testing.T) { + // 410 response must carry a short Cache-Control so the edge stops + // asking origin for an expired URL that's still in circulation. + t.Parallel() + db := newTestDB(t) + bkt := newTestBucket(t) + app := newTestApp(bkt, authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + fn := validTestFn("cdnccexp") + bw, _ := bkt.NewWriter(t.Context(), fn, nil) + bw.Write([]byte("stale")) + bw.Close() + + ctx := db.Ctx() + if _, err := pgctx.Exec(ctx, ` + INSERT INTO files (fn, project_id, size, filename, ttl, expires_at) + VALUES ($1, 'proj-cc', 5, 'x', 1, now() - interval '1 hour') + `, fn); err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/_cdn/"+signedToken(fn), nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusGone { + t.Fatalf("status = %d, want 410", w.Code) + } + if cc := w.Header().Get("Cache-Control"); cc != "public, max-age=3600" { + t.Errorf("Cache-Control = %q, want public, max-age=3600", cc) + } +} + +func TestCDNFileHandler_InvalidTokenNoCacheControl(t *testing.T) { + // Forgeable garbage tokens must NOT get a Cache-Control — each + // token is a unique URL so caching wouldn't reduce origin load and + // would just fill edge slots on attacker traffic. + t.Parallel() + app := newTestApp(newTestBucket(t), authorized) + app.CDNBaseURL = "https://cdn.example.com/" + + r := httptest.NewRequest(http.MethodGet, "/_cdn/garbage", nil) + w := httptest.NewRecorder() + app.routes().ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", w.Code) + } + if cc := w.Header().Get("Cache-Control"); cc != "" { + t.Errorf("Cache-Control = %q, want empty (attacker traffic shouldn't be cached)", cc) + } } func TestFileHandler_InvalidTokenRejected(t *testing.T) {