diff --git a/.github/workflows/github-release.yml b/.github/workflows/github-release.yml index 20c8973..901e003 100644 --- a/.github/workflows/github-release.yml +++ b/.github/workflows/github-release.yml @@ -5,27 +5,13 @@ on: - main types: - closed -permissions: - contents: write - actions: write jobs: release: - if: github.event.pull_request.merged == true && !contains(github.event.pull_request.title, '[skip-release]') - runs-on: ubuntu-24.04 - steps: - - name: Checkout - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 - with: - fetch-depth: 0 - - - name: install autotag binary - run: curl -sL https://git.io/autotag-install | sudo sh -s -- -b /usr/bin - - - name: create release - run: |- - TAG=$(autotag) - git push origin v$TAG - gh release create v$TAG --title "v$TAG" --generate-notes - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - + if: github.event.pull_request.merged == true && !contains(github.event.pull_request.title, 'skip-release') + uses: libops/actions/.github/workflows/bump-release.yaml@main + with: + prefix: v + permissions: + contents: write + actions: write + secrets: inherit diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index ce96425..799fccd 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -53,14 +53,36 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - integration-test: + integration-test-latest: needs: [run] permissions: contents: read runs-on: ubuntu-24.04 strategy: matrix: - traefik: [v2.11, v3.0, v3.1, v3.2, v3.3, v3.4] + traefik: [latest] + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + + - name: run + run: go run test.go + working-directory: ./ci + env: + TRAEFIK_TAG: ${{ matrix.traefik }} + + - name: cleanup + if: ${{ always() }} + run: docker compose logs --tail 100 nginx nginx2 traefik && docker compose down + working-directory: ./ci + + integration-test: + needs: [integration-test-latest] + permissions: + contents: read + runs-on: ubuntu-24.04 + strategy: + matrix: + traefik: [v2.11, v3.0, v3.1, v3.2, v3.3, v3.4, v3.5] steps: - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 @@ -72,5 +94,5 @@ jobs: - name: cleanup if: ${{ always() }} - run: docker compose down + run: docker compose logs --tail 100 nginx nginx2 traefik && docker compose down working-directory: ./ci diff --git a/.traefik.yml b/.traefik.yml index 21bb038..3cd0a51 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -11,3 +11,4 @@ testData: CaptchaProvider: turnstile SiteKey: 1x00000000000000000000AA SecretKey: 1x0000000000000000000000000000000AA + EnableStateReconciliation: "false" diff --git a/CLAUDE.md b/CLAUDE.md index 19d5ef9..32d635c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -16,9 +16,10 @@ This is a Traefik middleware plugin that protects websites from bot traffic by c - `CaptchaProtect` struct: Main middleware handler with rate limiting, bot detection, and challenge serving - `Config` struct: Configuration from Traefik labels - Three in-memory caches (using `github.com/patrickmn/go-cache`): - - `rateCache`: Tracks request counts per subnet + - `rateCache`: Tracks request counts per subnet (TTL = `window` config value) - `verifiedCache`: Stores IPs that have passed challenges (24h default TTL) - - `botCache`: Caches reverse DNS lookups for bot verification + - `botCache`: Caches reverse DNS lookups for bot verification (1h TTL) + - **Why go-cache instead of sync.Map?** The plugin requires automatic TTL-based expiration for all caches. `sync.Map` has no built-in expiration mechanism, requiring manual cleanup goroutines. `go-cache` provides thread-safe maps with automatic expiration and cleanup. ### Request Flow Decision Tree @@ -120,9 +121,18 @@ Regex is significantly slower (~41ns vs ~3.4ns per operation) - see README bench ### State Persistence When `persistentStateFile` is configured: -- State saves every 1 minute to JSON file (`saveState()` at `main.go:695-727`) -- On startup, loads previous state from file (`loadState()` at `main.go:729-756`) +- State saves every 10 seconds (with 0-2s random jitter) to JSON file (`saveState()` at `main.go:716-746`) +- Uses file locking (`.lock` files) to prevent concurrent writes (`internal/state/state.go:61-129`) +- On startup, loads previous state from file (`loadState()` at `main.go:729-761`) - Contains: rate limits per subnet, bot verification cache, verified IPs +- **Important**: Each middleware instance runs its own save goroutine. If multiple instances share the same `persistentStateFile`, they will write more frequently (e.g., 2 instances = writes every ~5 seconds) +- **State Reconciliation**: When `enableStateReconciliation: "true"`, each save performs a read-modify-write cycle to merge state from other instances. This adds I/O overhead but prevents data loss in multi-instance deployments (see `internal/state/state.go:86-100`) + +**Why not Redis?** Traefik plugins are loaded via Yaegi (a Go interpreter), which has significant limitations: +- Yaegi cannot interpret Go packages that use `unsafe`, cgo, or complex reflection patterns +- Popular Redis clients like `go-redis/redis` are incompatible with Yaegi + +**Current solution**: File-based persistence with reconciliation avoids these issues. Local caches remain fast (no network overhead), state saves are batched (every 10s), and reconciliation handles conflicts without complex coordination. The tradeoff is accepting slightly stale data across instances (max 10s delay) rather than the complexity and performance cost of real-time Redis synchronization. ### Good Bot Detection diff --git a/README.md b/README.md index 643e088..a44c33b 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ services: | `enableStatsPage` | `string` | `"false"` | Allows `exemptIps` to access `/captcha-protect/stats` to monitor the rate limiter. | | `logLevel` | `string` | `"INFO"` | Log level for the middleware. Options: `ERROR`, `WARNING`, `INFO`, or `DEBUG`. | | `persistentStateFile` | `string` | `""` | File path to persist rate limiter state across Traefik restarts. In Docker, mount this file from the host. | +| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, reads and merges disk state before each save to prevent multiple instances from overwriting data. Adds extra I/O overhead. Only enable for multi-instance deployments sharing state. | ### Good Bots diff --git a/ci/.env b/ci/.env index a23749a..a2cbe21 100644 --- a/ci/.env +++ b/ci/.env @@ -1,4 +1,4 @@ -TRAEFIK_TAG=v3.3.3 +TRAEFIK_TAG=v3.5 NGINX_TAG=1.27.4-alpine3.21 TURNSTILE_SITE_KEY=1x00000000000000000000AA TURNSTILE_SECRET_KEY=1x0000000000000000000000000000000AA diff --git a/ci/docker-compose.yml b/ci/docker-compose.yml index 8f5a0e1..28dde67 100644 --- a/ci/docker-compose.yml +++ b/ci/docker-compose.yml @@ -21,6 +21,7 @@ services: traefik.http.middlewares.captcha-protect.plugin.captcha-protect.goodBots: "" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectRoutes: "/" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.persistentStateFile: "/tmp/state.json" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableStateReconciliation: "true" healthcheck: test: curl -fs http://localhost/healthz | grep -q OK || exit 1 volumes: @@ -28,7 +29,36 @@ services: networks: default: aliases: - - nginx + - nginx + nginx2: + image: nginx:${NGINX_TAG} + labels: + traefik.enable: true + traefik.http.routers.nginx2.entrypoints: http + traefik.http.routers.nginx2.service: nginx2 + traefik.http.routers.nginx2.rule: Host(`localhost`) && PathPrefix(`/app2`) + traefik.http.services.nginx2.loadbalancer.server.port: 80 + traefik.http.routers.nginx2.middlewares: captcha-protect@docker + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.captchaProvider: turnstile + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.window: 120 + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.rateLimit: ${RATE_LIMIT} + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.siteKey: ${TURNSTILE_SITE_KEY} + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.secretKey: ${TURNSTILE_SECRET_KEY} + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableStatsPage: "true" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.ipForwardedHeader: "X-Forwarded-For" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.logLevel: "DEBUG" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.goodBots: "" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectRoutes: "/" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.persistentStateFile: "/tmp/state.json" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableStateReconciliation: "true" + healthcheck: + test: curl -fs http://localhost/healthz | grep -q OK || exit 1 + volumes: + - ./conf/nginx/default.conf:/etc/nginx/conf.d/default.conf:r + networks: + default: + aliases: + - nginx2 traefik: image: traefik:${TRAEFIK_TAG} command: >- diff --git a/ci/test.go b/ci/test.go index 4c2276f..6f8ee67 100755 --- a/ci/test.go +++ b/ci/test.go @@ -25,7 +25,6 @@ var ( const numIPs = 100 const parallelism = 10 -const expectedRedirectURL = "http://localhost/challenge?destination=%2F" func main() { _ips := []string{ @@ -48,24 +47,19 @@ func main() { fmt.Println("Bringing traefik/nginx online") runCommand("docker", "compose", "up", "-d") waitForService("http://localhost") + waitForService("http://localhost/app2") fmt.Printf("Making sure %d attempt(s) pass\n", rateLimit) - runParallelChecks(ips, rateLimit) + runParallelChecks(ips, rateLimit, "http://localhost") - fmt.Printf("Making sure attempt #%d causes a redirect to the challenge page\n", rateLimit+1) - ensureRedirect(ips) + time.Sleep(cp.StateSaveInterval + cp.StateSaveJitter + (1 * time.Second)) + runCommand("jq", ".", "tmp/state.json") - fmt.Println("Sleeping for 2m") - time.Sleep(125 * time.Second) - fmt.Println("Making sure one attempt passes after 2m window") - runParallelChecks(ips, 1) - fmt.Println("All good 🚀") + fmt.Printf("Making sure attempt #%d causes a redirect to the challenge page\n", rateLimit+1) + ensureRedirect(ips, "http://localhost") - // make sure the state has time to save - fmt.Println("Waiting for state to save") - runCommand("jq", ".", "tmp/state.json") - time.Sleep(80 * time.Second) - runCommand("jq", ".", "tmp/state.json") + fmt.Println("\nTesting state sharing between nginx instances...") + testStateSharing(ips) runCommand("docker", "container", "stats", "--no-stream") @@ -138,7 +132,7 @@ func waitForService(url string) { } } -func runParallelChecks(ips []string, rateLimit int) { +func runParallelChecks(ips []string, rateLimit int, url string) { var wg sync.WaitGroup sem := make(chan struct{}, parallelism) @@ -151,7 +145,7 @@ func runParallelChecks(ips []string, rateLimit int) { defer func() { <-sem }() fmt.Printf("Checking %s\n", ip) - output := httpRequest(ip) + output := httpRequest(ip, url) if output != "" { slog.Error("Unexpected output", "ip", ip, "output", output) os.Exit(1) @@ -164,13 +158,19 @@ func runParallelChecks(ips []string, rateLimit int) { wg.Wait() } -func ensureRedirect(ips []string) { +func ensureRedirect(ips []string, url string) { + expectedURL := url + "/challenge?destination=%2F" + if url != "http://localhost" { + // For /app2, the destination should be the app2 path + expectedURL = "http://localhost/challenge?destination=%2Fapp2" + } + for _, ip := range ips { fmt.Printf("Checking %s\n", ip) - output := httpRequest(ip) + output := httpRequest(ip, url) - if output != expectedRedirectURL { - slog.Error("Unexpected output", "ip", ip, "output", output) + if output != expectedURL { + slog.Error("Unexpected output", "ip", ip, "output", output, "expected", expectedURL) os.Exit(1) } @@ -178,7 +178,27 @@ func ensureRedirect(ips []string) { } } -func httpRequest(ip string) string { +func testStateSharing(ips []string) { + // Use first IP to test state sharing + testIP := ips[0] + + fmt.Printf("Testing with IP: %s\n", testIP) + + // The IP should already be at rate limit from previous tests on localhost/ + // Now verify it's also rate limited on localhost/app2 (shared state) + fmt.Println("Verifying IP is rate limited on /app2 (state should be shared)...") + output := httpRequest(testIP, "http://localhost/app2") + expectedURL := "http://localhost/challenge?destination=%2Fapp2" + + if output != expectedURL { + slog.Error("State NOT shared between instances!", "ip", testIP, "output", output, "expected", expectedURL) + os.Exit(1) + } + + fmt.Println("✓ State is correctly shared between nginx instances!") +} + +func httpRequest(ip, url string) string { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { // Capture the redirect URL and stop following it @@ -189,7 +209,7 @@ func httpRequest(ip string) string { }, } - req, err := http.NewRequest("GET", "http://localhost", nil) + req, err := http.NewRequest("GET", url, nil) if err != nil { slog.Error("Failed to create request", "err", err) os.Exit(1) diff --git a/internal/state/lock.go b/internal/state/lock.go new file mode 100644 index 0000000..849abb3 --- /dev/null +++ b/internal/state/lock.go @@ -0,0 +1,126 @@ +package state + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +// FileLock represents an exclusive file lock using lock file creation +// This implementation doesn't use syscall.Flock which is not available in Traefik plugins +type FileLock struct { + lockPath string + pid int +} + +// NewFileLock creates a new file lock for the given path. +// It uses a separate .lock file to coordinate access. +func NewFileLock(path string) (*FileLock, error) { + return &FileLock{ + lockPath: path, + pid: os.Getpid(), + }, nil +} + +// Lock acquires an exclusive lock by creating a lock file. +// It will retry for up to 5 seconds if the lock is held by another process. +func (fl *FileLock) Lock() error { + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + return fmt.Errorf("timeout waiting for file lock") + case <-ticker.C: + // Try to create lock file exclusively + f, err := os.OpenFile(fl.lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) + if err == nil { + // Successfully created lock file + _, err = f.WriteString(strconv.Itoa(fl.pid)) + f.Close() + // Check for write error + if err != nil { + // We got the lock but failed to write. + // Best effort to clean up, then return the error. + _ = os.Remove(fl.lockPath) + return fmt.Errorf("failed to write pid to lock file: %v", err) + } + // We hold the lock + return nil + } + + // If we're here, os.OpenFile failed, likely because the file exists. + // Check if lock file is stale (older than 10 seconds) + info, statErr := os.Stat(fl.lockPath) + if statErr == nil { + if time.Since(info.ModTime()) > 10*time.Second { + // Lock file is stale, try to remove it + err = os.Remove(fl.lockPath) + + if err != nil && !os.IsNotExist(err) { + // If we can't remove it (and it's not 'not exist'), + // something is wrong (e.g., permissions). + return fmt.Errorf("unable to remove stale lock: %v", err) + } + } + } + // If stat failed (e.g., file removed between OpenFile and Stat) + // or lock is not stale, just loop and wait for next tick. + } + } +} + +// Unlock releases the exclusive lock by removing the lock file +// This is now safer and checks the PID. +func (fl *FileLock) Unlock() error { + content, err := os.ReadFile(fl.lockPath) + if err != nil { + if os.IsNotExist(err) { + return nil // Already unlocked + } + return fmt.Errorf("failed to read lock file on unlock: %v", err) + } + + lockPIDStr := string(content) + myPIDStr := strconv.Itoa(fl.pid) + + if lockPIDStr != myPIDStr { + // This is not our lock. Do not remove it. + return fmt.Errorf("cannot unlock file held by different process (my_pid: %s, lock_pid: %s)", myPIDStr, lockPIDStr) + } + + // It is our lock, remove it. + err = os.Remove(fl.lockPath) + if err != nil && !os.IsNotExist(err) { + // Failed to remove, and not because it was already gone + return fmt.Errorf("failed to remove our lock file: %v", err) + } + + // Succeeded, or it was already gone (which is fine) + return nil +} + +// Close is an alias for Unlock for compatibility. +// It will not return an error if the lock is held by another process. +func (fl *FileLock) Close() error { + err := fl.Unlock() + + // If Unlock fails, we only want to suppress the error + // if it's because the lock is held by someone else. + // In the context of Close(), this is fine. + if err != nil { + if strings.Contains(err.Error(), "cannot unlock file held by different process") { + return nil + } + // IsNotExist is already handled by Unlock, but this is safe. + if os.IsNotExist(err) { + return nil + } + } + + return err +} diff --git a/internal/state/lock_test.go b/internal/state/lock_test.go new file mode 100644 index 0000000..61e6a5d --- /dev/null +++ b/internal/state/lock_test.go @@ -0,0 +1,354 @@ +// File: filelock_test.go +package state + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestFileLock_LockUnlock tests the basic Lock and Unlock functionality. +func TestFileLock_LockUnlock(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + fl, err := NewFileLock(lockPath) + if err != nil { + t.Fatalf("NewFileLock() error = %v", err) + } + + if err := fl.Lock(); err != nil { + t.Fatalf("Lock() error = %v", err) + } + + if _, err := os.Stat(lockPath); err != nil { + t.Fatalf("lock file was not created: %v", err) + } + + content, err := os.ReadFile(lockPath) + if err != nil { + t.Fatalf("could not read lock file: %v", err) + } + expectedPID := strconv.Itoa(os.Getpid()) + if string(content) != expectedPID { + t.Errorf("lock file contains wrong PID: got %q, want %q", string(content), expectedPID) + } + + if err := fl.Unlock(); err != nil { + t.Fatalf("Unlock() error = %v", err) + } + + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Fatal("lock file was not removed after Unlock()") + } +} + +// TestFileLock_Close tests the Close functionality, including idempotency. +func TestFileLock_Close(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + fl, err := NewFileLock(lockPath) + if err != nil { + t.Fatalf("NewFileLock() error = %v", err) + } + + if err := fl.Lock(); err != nil { + t.Fatalf("Lock() error = %v", err) + } + + if _, err := os.Stat(lockPath); err != nil { + t.Fatalf("lock file was not created: %v", err) + } + + if err := fl.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Fatal("lock file was not removed after Close()") + } + + // Close again (should be idempotent and not return an error) + if err := fl.Close(); err != nil { + t.Fatalf("second Close() returned an error: %v", err) + } +} + +// TestFileLock_Contention tests that a second process waits for the first to unlock. +func TestFileLock_Contention(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + fl1, _ := NewFileLock(lockPath) + fl2, _ := NewFileLock(lockPath) + + var wg sync.WaitGroup + wg.Add(2) + + // Use channels to synchronize the goroutines and ensure + // g2 doesn't try to lock until g1 *definitely* has the lock. + g1Locked := make(chan struct{}) + + // Goroutine 1: Acquires lock first, holds it, then releases + go func() { + defer wg.Done() + if err := fl1.Lock(); err != nil { + t.Errorf("g1: Lock() error = %v", err) + return + } + close(g1Locked) + + time.Sleep(100 * time.Millisecond) + + if err := fl1.Unlock(); err != nil { + t.Errorf("g1: Unlock() error = %v", err) + } + }() + + // Goroutine 2: Waits for g1 to get the lock, then tries to acquire it + go func() { + defer wg.Done() + <-g1Locked + + startTime := time.Now() + if err := fl2.Lock(); err != nil { + t.Errorf("g2: Lock() error = %v", err) + return + } + elapsed := time.Since(startTime) + + if elapsed < 90*time.Millisecond { // Give some buffer + t.Errorf("g2 did not wait for g1 to unlock; elapsed = %v", elapsed) + } + + if err := fl2.Unlock(); err != nil { + t.Errorf("g2: Unlock() error = %v", err) + } + }() + + wg.Wait() + + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Fatal("lock file was not removed after all goroutines finished") + } +} + +// TestFileLock_Timeout tests that Lock() returns an error if it can't +// acquire the lock within the timeout period. +func TestFileLock_Timeout(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + fl1, _ := NewFileLock(lockPath) + if err := fl1.Lock(); err != nil { + t.Fatalf("fl1: Lock() error = %v", err) + } + // Defer a Close() for cleanup. This is safer now. + defer fl1.Close() + + // Try to acquire the same lock in another goroutine + fl2, _ := NewFileLock(lockPath) + + startTime := time.Now() + err := fl2.Lock() + elapsed := time.Since(startTime) + + if err == nil { + t.Fatal("fl2: Lock() did not return an error, expected timeout") + fl2.Unlock() //nolint:errcheck + return + } + if err.Error() != "timeout waiting for file lock" { + t.Errorf("fl2: Lock() returned wrong error: got %q, want %q", err.Error(), "timeout waiting for file lock") + } + + if elapsed < 4*time.Second || elapsed > 6*time.Second { + t.Errorf("fl2: timeout duration was not ~5s: got %v", elapsed) + } +} + +// TestFileLock_StaleLock tests that a lock file older than 10 seconds is removed. +func TestFileLock_StaleLock(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + if err := os.WriteFile(lockPath, []byte("12345"), 0644); err != nil { + t.Fatalf("failed to create stale lock file: %v", err) + } + + staleTime := time.Now().Add(-15 * time.Second) + if err := os.Chtimes(lockPath, staleTime, staleTime); err != nil { + t.Fatalf("failed to set stale time: %v", err) + } + + fl, _ := NewFileLock(lockPath) + if err := fl.Lock(); err != nil { + t.Fatalf("Lock() failed to acquire stale lock: %v", err) + } + defer fl.Unlock() //nolint:errcheck + + content, err := os.ReadFile(lockPath) + if err != nil { + t.Fatalf("could not read new lock file: %v", err) + } + expectedPID := strconv.Itoa(os.Getpid()) + if string(content) != expectedPID { + t.Errorf("lock file not overwritten with new PID: got %q, want %q", string(content), expectedPID) + } +} + +// TestFileLock_StaleLockRace tests for a "Check-Then-Act" race condition +// when multiple processes detect a stale lock at the same time. +func TestFileLock_StaleLockRace(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + if err := os.WriteFile(lockPath, []byte("stale-pid"), 0644); err != nil { + t.Fatalf("failed to create stale lock file: %v", err) + } + staleTime := time.Now().Add(-15 * time.Second) + if err := os.Chtimes(lockPath, staleTime, staleTime); err != nil { + t.Fatalf("failed to set stale time: %v", err) + } + + numGoroutines := 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + readyGate := &sync.WaitGroup{} + readyGate.Add(numGoroutines) + releaseGate := &sync.WaitGroup{} + releaseGate.Add(1) + + var activeLocks int32 + var maxActiveLocks int32 + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + + readyGate.Done() + releaseGate.Wait() + + fl, err := NewFileLock(lockPath) + if err != nil { + return + } + + // We expect most of these to fail with a timeout, which is fine. + // The critical part is that they don't *all* succeed. + if err := fl.Lock(); err != nil { + return + } + + // --- CRITICAL SECTION --- + currentActive := atomic.AddInt32(&activeLocks, 1) + if current := atomic.LoadInt32(&maxActiveLocks); current < currentActive { + atomic.CompareAndSwapInt32(&maxActiveLocks, current, currentActive) + } + + time.Sleep(10 * time.Millisecond) + + atomic.AddInt32(&activeLocks, -1) + // --- END CRITICAL SECTION --- + + fl.Unlock() //nolint:errcheck + }() + } + + readyGate.Wait() + releaseGate.Done() + wg.Wait() + + finalMax := atomic.LoadInt32(&maxActiveLocks) + if finalMax > 1 { + t.Errorf("RACE CONDITION DETECTED: %d goroutines held the lock simultaneously", finalMax) + } + + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Error("lock file was not removed after test completion") + os.Remove(lockPath) //nolint:errcheck + } +} + +// --- NEW TESTS --- + +// TestFileLock_UnlockSafety verifies that a lock cannot be unlocked by +// a process that does not own it (PID mismatch). +func TestFileLock_UnlockSafety(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + // Create a lock file manually with a fake PID + fakePID := "-12345" + if err := os.WriteFile(lockPath, []byte(fakePID), 0644); err != nil { + t.Fatalf("failed to create fake lock file: %v", err) + } + defer os.Remove(lockPath) //nolint:errcheck + + fl, err := NewFileLock(lockPath) + if err != nil { + t.Fatalf("NewFileLock() error = %v", err) + } + + // Try to Unlock() a file we don't own + err = fl.Unlock() + if err == nil { + t.Fatal("Unlock() did not return an error when PID did not match") + } + + // Check for the specific error + expectedErr := fmt.Sprintf("cannot unlock file held by different process (my_pid: %d, lock_pid: %s)", fl.pid, fakePID) + if err.Error() != expectedErr { + t.Errorf("Unlock() returned wrong error: \ngot: %q\nwant: %q", err.Error(), expectedErr) + } + + // Crucially, verify the lock file was NOT deleted + if _, err := os.Stat(lockPath); err != nil { + t.Fatalf("lock file was removed by unsafe Unlock(): %v", err) + } +} + +// TestFileLock_CloseSafety verifies that Close() does not return an error +// and does not delete the lock file if it's owned by another process. +func TestFileLock_CloseSafety(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + lockPath := filepath.Join(tempDir, "test.lock") + + // Create a lock file manually with a fake PID + fakePID := "-12345" + if err := os.WriteFile(lockPath, []byte(fakePID), 0644); err != nil { + t.Fatalf("failed to create fake lock file: %v", err) + } + defer os.Remove(lockPath) //nolint:errcheck + + fl, err := NewFileLock(lockPath) + if err != nil { + t.Fatalf("NewFileLock() error = %v", err) + } + + // Try to Close() a file we don't own + err = fl.Close() + if err != nil { + t.Fatalf("Close() returned an error when PID did not match: %v", err) + } + + // Verify the lock file was NOT deleted + if _, err := os.Stat(lockPath); err != nil { + t.Fatalf("lock file was removed by unsafe Close(): %v", err) + } +} diff --git a/internal/state/state.go b/internal/state/state.go index 1ac6fea..291e5fa 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -1,16 +1,27 @@ package state import ( + "encoding/json" + "fmt" + "log/slog" + "os" "reflect" + "time" lru "github.com/patrickmn/go-cache" ) +// CacheEntry represents a cache item with its expiration time +type CacheEntry struct { + Value interface{} `json:"value"` + Expiration int64 `json:"expiration"` // Unix timestamp in nanoseconds, 0 means no expiration +} + type State struct { - Rate map[string]uint `json:"rate"` - Bots map[string]bool `json:"bots"` - Verified map[string]bool `json:"verified"` - Memory map[string]uintptr `json:"memory"` + Rate map[string]CacheEntry `json:"rate"` + Bots map[string]CacheEntry `json:"bots"` + Verified map[string]CacheEntry `json:"verified"` + Memory map[string]uintptr `json:"memory"` } func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { @@ -18,32 +29,293 @@ func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { Memory: make(map[string]uintptr, 3), } - state.Rate = make(map[string]uint, len(rateCache)) - state.Memory["rate"] = reflect.TypeOf(state.Rate).Size() - for k, v := range rateCache { - state.Rate[k] = v.Object.(uint) - state.Memory["rate"] += reflect.TypeOf(k).Size() - state.Memory["rate"] += reflect.TypeOf(v).Size() - state.Memory["rate"] += uintptr(len(k)) + state.Rate, state.Memory["rate"] = getCacheEntries[uint](rateCache) + state.Bots, state.Memory["bot"] = getCacheEntries[bool](botCache) + state.Verified, state.Memory["verified"] = getCacheEntries[bool](verifiedCache) + + return state +} + +// SetState loads state data into the provided caches, preserving expiration times. +// If an entry has already expired (expiration < now), it will be skipped. +func SetState(state State, rateCache, botCache, verifiedCache *lru.Cache) { + loadCacheEntries(state.Rate, rateCache, convertRateValue) + loadCacheEntries(state.Bots, botCache, convertBoolValue) + loadCacheEntries(state.Verified, verifiedCache, convertBoolValue) +} + +// ReconcileState merges file-based state with in-memory state. +func ReconcileState(fileState State, rateCache, botCache, verifiedCache *lru.Cache) { + rateItems := rateCache.Items() + botItems := botCache.Items() + verifiedItems := verifiedCache.Items() + + // Use "max value wins" for rate cache + reconcileRateCache(fileState.Rate, rateItems, rateCache, convertRateValue) + + // Use "later expiration wins" for bot and verified caches + reconcileCacheEntries(fileState.Bots, botItems, botCache, convertBoolValue) + reconcileCacheEntries(fileState.Verified, verifiedItems, verifiedCache, convertBoolValue) +} + +// SaveStateToFile saves state to a file with locking and optional reconciliation. +// When reconcile is true, it reads and merges existing file state before saving. +// Returns timing metrics for debugging. +func SaveStateToFile( + filePath string, + reconcile bool, + rateCache, botCache, verifiedCache *lru.Cache, + log *slog.Logger, +) (lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs int64, err error) { + startTime := time.Now() + + lock, err := NewFileLock(filePath + ".lock") + if err != nil { + return 0, 0, 0, 0, 0, 0, fmt.Errorf("failed to create lock: %w", err) } + defer lock.Close() - state.Bots = make(map[string]bool, len(botCache)) - state.Memory["bot"] = reflect.TypeOf(state.Bots).Size() - for k, v := range botCache { - state.Bots[k] = v.Object.(bool) - state.Memory["bot"] += reflect.TypeOf(k).Size() - state.Memory["bot"] += reflect.TypeOf(v).Size() - state.Memory["bot"] += uintptr(len(k)) + if err := lock.Lock(); err != nil { + return 0, 0, 0, 0, 0, 0, fmt.Errorf("failed to acquire lock: %w", err) } + lockDuration := time.Since(startTime) + + var readDuration, reconcileDuration, marshalDuration, writeDuration time.Duration - state.Verified = make(map[string]bool, len(verifiedCache)) - state.Memory["verified"] = reflect.TypeOf(state.Verified).Size() - for k, v := range verifiedCache { - state.Verified[k] = v.Object.(bool) - state.Memory["verified"] += reflect.TypeOf(k).Size() - state.Memory["verified"] += reflect.TypeOf(v).Size() - state.Memory["verified"] += uintptr(len(k)) + // Reconcile with existing file state if enabled + if reconcile { + readStart := time.Now() + fileContent, readErr := os.ReadFile(filePath) + readDuration = time.Since(readStart) + + if readErr == nil && len(fileContent) > 0 { + reconcileStart := time.Now() + var fileState State + if unmarshalErr := json.Unmarshal(fileContent, &fileState); unmarshalErr == nil { + log.Debug("Reconciling state before save", "fileBytes", len(fileContent)) + ReconcileState(fileState, rateCache, botCache, verifiedCache) + } + reconcileDuration = time.Since(reconcileStart) + } } - return state + // Marshal current state + marshalStart := time.Now() + currentState := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + jsonData, err := json.Marshal(currentState) + marshalDuration = time.Since(marshalStart) + + if err != nil { + return lockDuration.Milliseconds(), readDuration.Milliseconds(), + reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), + 0, 0, err + } + + // Write to disk + writeStart := time.Now() + err = os.WriteFile(filePath, jsonData, 0644) + writeDuration = time.Since(writeStart) + + if err != nil { + return lockDuration.Milliseconds(), readDuration.Milliseconds(), + reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), + writeDuration.Milliseconds(), 0, err + } + + totalDuration := time.Since(startTime) + return lockDuration.Milliseconds(), readDuration.Milliseconds(), + reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), + writeDuration.Milliseconds(), totalDuration.Milliseconds(), nil +} + +// LoadStateFromFile loads state from a file with locking. +func LoadStateFromFile( + filePath string, + rateCache, botCache, verifiedCache *lru.Cache, +) error { + lock, err := NewFileLock(filePath + ".lock") + if err != nil { + return err + } + defer lock.Close() + + if err := lock.Lock(); err != nil { + return err + } + + fileContent, err := os.ReadFile(filePath) + if err != nil || len(fileContent) == 0 { + return err + } + + var loadedState State + err = json.Unmarshal(fileContent, &loadedState) + if err != nil { + return err + } + + // Use SetState which properly handles expiration times + SetState(loadedState, rateCache, botCache, verifiedCache) + + return nil +} + +func calculateDuration(expiration int64, now int64) time.Duration { + if expiration == 0 { + return lru.NoExpiration + } + return time.Duration(expiration - now) +} + +func convertRateValue(v interface{}) (uint, bool) { + switch val := v.(type) { + case uint: + return val, true + case float64: + return uint(val), true + case int: + return uint(val), true + default: + return 0, false + } +} + +func convertBoolValue(v interface{}) (bool, bool) { + switch val := v.(type) { + case bool: + return val, true + default: + return false, false + } +} + +func getCacheEntries[T any](items map[string]lru.Item) (map[string]CacheEntry, uintptr) { + entries := make(map[string]CacheEntry, len(items)) + var memoryUsage uintptr + memoryUsage = reflect.TypeOf(entries).Size() + + for k, v := range items { + entries[k] = CacheEntry{ + Value: v.Object.(T), + Expiration: v.Expiration, + } + memoryUsage += reflect.TypeOf(k).Size() + memoryUsage += reflect.TypeOf(v).Size() + memoryUsage += uintptr(len(k)) + } + return entries, memoryUsage +} + +func loadCacheEntries[T any]( + entries map[string]CacheEntry, + cache *lru.Cache, + converter func(interface{}) (T, bool), +) { + now := time.Now().UnixNano() + for k, entry := range entries { + if entry.Expiration > 0 && entry.Expiration <= now { + continue + } + value, ok := converter(entry.Value) + if !ok { + continue + } + duration := calculateDuration(entry.Expiration, now) + cache.Set(k, value, duration) + } +} + +// reconcileCacheEntries implements "later expiration wins" +// This is correct for bool flags (Verified, Bots). +func reconcileCacheEntries[T any]( + fileEntries map[string]CacheEntry, + memItems map[string]lru.Item, + cache *lru.Cache, + converter func(interface{}) (T, bool), +) { + now := time.Now().UnixNano() + for k, fileEntry := range fileEntries { + if fileEntry.Expiration > 0 && fileEntry.Expiration <= now { + continue + } + + value, ok := converter(fileEntry.Value) + if !ok { + continue + } + + duration := calculateDuration(fileEntry.Expiration, now) + + memItem, exists := memItems[k] + if !exists { + cache.Set(k, value, duration) + continue + } + + if fileEntry.Expiration > memItem.Expiration { + cache.Set(k, value, duration) + } + } +} + +// reconcileRateCache implements "max value wins" and "max expiration wins". +// This prevents runaway growth (from summing) and accepts data loss +// (under-counting) as the safer alternative. +func reconcileRateCache( + fileEntries map[string]CacheEntry, + memItems map[string]lru.Item, + cache *lru.Cache, + converter func(interface{}) (uint, bool), +) { + now := time.Now().UnixNano() + for k, fileEntry := range fileEntries { + if fileEntry.Expiration > 0 && fileEntry.Expiration <= now { + continue + } + + fileValue, ok := converter(fileEntry.Value) + if !ok { + continue + } + + memItem, exists := memItems[k] + if !exists { + // Entry only in file, just add it + duration := calculateDuration(fileEntry.Expiration, now) + cache.Set(k, fileValue, duration) + continue + } + + // Entry in both, combine them + memValue, ok := memItem.Object.(uint) + if !ok { + // In-memory object is not uint, something is wrong. + // Overwrite with file value as a fallback. + duration := calculateDuration(fileEntry.Expiration, now) + cache.Set(k, fileValue, duration) + continue + } + + // Use the HIGHEST value, not the sum + combinedValue := maxUint(fileValue, memValue) + // Use the LATER expiration + laterExpiration := max(fileEntry.Expiration, memItem.Expiration) + + duration := calculateDuration(laterExpiration, now) + cache.Set(k, combinedValue, duration) + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func maxUint(a, b uint) uint { + if a > b { + return a + } + return b } diff --git a/internal/state/state_test.go b/internal/state/state_test.go index 4237f80..b1639af 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -1,7 +1,11 @@ package state import ( + "encoding/json" + "log/slog" + "os" "testing" + "testing/synctest" "time" lru "github.com/patrickmn/go-cache" @@ -29,31 +33,43 @@ func TestGetState(t *testing.T) { if len(state.Rate) != 2 { t.Errorf("Expected 2 rate entries, got %d", len(state.Rate)) } - if state.Rate["192.168.0.0"] != 10 { - t.Errorf("Expected rate 10 for 192.168.0.0, got %d", state.Rate["192.168.0.0"]) + if state.Rate["192.168.0.0"].Value != uint(10) { + t.Errorf("Expected rate 10 for 192.168.0.0, got %v", state.Rate["192.168.0.0"].Value) } - if state.Rate["10.0.0.0"] != 5 { - t.Errorf("Expected rate 5 for 10.0.0.0, got %d", state.Rate["10.0.0.0"]) + if state.Rate["10.0.0.0"].Value != uint(5) { + t.Errorf("Expected rate 5 for 10.0.0.0, got %v", state.Rate["10.0.0.0"].Value) + } + // Verify expiration timestamps are set + if state.Rate["192.168.0.0"].Expiration == 0 { + t.Error("Expected non-zero expiration for 192.168.0.0") } // Verify bot cache data if len(state.Bots) != 2 { t.Errorf("Expected 2 bot entries, got %d", len(state.Bots)) } - if state.Bots["1.2.3.4"] != true { + if state.Bots["1.2.3.4"].Value != true { t.Error("Expected bot 1.2.3.4 to be true") } - if state.Bots["5.6.7.8"] != false { + if state.Bots["5.6.7.8"].Value != false { t.Error("Expected bot 5.6.7.8 to be false") } + // Verify expiration timestamps are set + if state.Bots["1.2.3.4"].Expiration == 0 { + t.Error("Expected non-zero expiration for bot 1.2.3.4") + } // Verify verified cache data if len(state.Verified) != 1 { t.Errorf("Expected 1 verified entry, got %d", len(state.Verified)) } - if state.Verified["9.9.9.9"] != true { + if state.Verified["9.9.9.9"].Value != true { t.Error("Expected 9.9.9.9 to be verified") } + // Verify expiration timestamp is set + if state.Verified["9.9.9.9"].Expiration == 0 { + t.Error("Expected non-zero expiration for verified 9.9.9.9") + } // Verify memory tracking exists if len(state.Memory) != 3 { @@ -91,3 +107,834 @@ func TestGetStateEmpty(t *testing.T) { t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) } } + +func TestSetState(t *testing.T) { + // Create state with expiration times + now := time.Now().UnixNano() + futureExpiration := now + int64(1*time.Hour) + pastExpiration := now - int64(1*time.Hour) + + state := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, + "10.0.0.0": {Value: uint(5), Expiration: pastExpiration}, // expired + }, + Bots: map[string]CacheEntry{ + "1.2.3.4": {Value: true, Expiration: futureExpiration}, + "5.6.7.8": {Value: false, Expiration: pastExpiration}, // expired + }, + Verified: map[string]CacheEntry{ + "9.9.9.9": {Value: true, Expiration: futureExpiration}, + "8.8.8.8": {Value: true, Expiration: pastExpiration}, // expired + "7.7.7.7": {Value: true, Expiration: 0}, // no expiration + }, + } + + // Create caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + // Set state + SetState(state, rateCache, botCache, verifiedCache) + + // Verify only non-expired entries were loaded + if rateCache.ItemCount() != 1 { + t.Errorf("Expected 1 rate entry (expired filtered out), got %d", rateCache.ItemCount()) + } + if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 10 { + t.Error("Expected rate 10 for 192.168.0.0") + } + if _, ok := rateCache.Get("10.0.0.0"); ok { + t.Error("Expected expired entry 10.0.0.0 to be filtered out") + } + + if botCache.ItemCount() != 1 { + t.Errorf("Expected 1 bot entry (expired filtered out), got %d", botCache.ItemCount()) + } + if v, ok := botCache.Get("1.2.3.4"); !ok || v.(bool) != true { + t.Error("Expected bot 1.2.3.4 to be true") + } + + if verifiedCache.ItemCount() != 2 { + t.Errorf("Expected 2 verified entries (1 expired filtered out), got %d", verifiedCache.ItemCount()) + } + if v, ok := verifiedCache.Get("9.9.9.9"); !ok || v.(bool) != true { + t.Error("Expected 9.9.9.9 to be verified") + } + if v, ok := verifiedCache.Get("7.7.7.7"); !ok || v.(bool) != true { + t.Error("Expected 7.7.7.7 to be verified (no expiration)") + } + if _, ok := verifiedCache.Get("8.8.8.8"); ok { + t.Error("Expected expired entry 8.8.8.8 to be filtered out") + } +} + +func TestReconcileState(t *testing.T) { + now := time.Now().UnixNano() + oldExpiration := now + int64(30*time.Minute) + newExpiration := now + int64(1*time.Hour) + + // Create file state with some entries + fileState := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": {Value: uint(15), Expiration: newExpiration}, // newer than memory + "10.0.0.0": {Value: uint(3), Expiration: oldExpiration}, // older than memory + "172.16.0.0": {Value: uint(7), Expiration: newExpiration}, // only in file + }, + Verified: map[string]CacheEntry{ + "1.1.1.1": {Value: true, Expiration: newExpiration}, // only in file + "2.2.2.2": {Value: true, Expiration: oldExpiration}, // older than memory + }, + } + + // Create memory caches with some overlapping data + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), time.Duration(oldExpiration-now)) // older, should be replaced + rateCache.Set("10.0.0.0", uint(5), time.Duration(newExpiration-now)) // newer, should be kept + rateCache.Set("8.8.8.8", uint(20), time.Duration(newExpiration-now)) // only in memory + + verifiedCache.Set("2.2.2.2", true, time.Duration(newExpiration-now)) // newer, should be kept + + // Reconcile + ReconcileState(fileState, rateCache, botCache, verifiedCache) + + // Verify reconciliation results + // 192.168.0.0 should be updated to file's value (newer expiration) + if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 15 { + t.Errorf("Expected rate 15 for 192.168.0.0 after reconciliation, got %v", v) + } + + // 10.0.0.0 should keep memory value (newer expiration) + if v, ok := rateCache.Get("10.0.0.0"); !ok || v.(uint) != 5 { + t.Errorf("Expected rate 5 for 10.0.0.0 (memory kept), got %v", v) + } + + // 172.16.0.0 should be added from file + if v, ok := rateCache.Get("172.16.0.0"); !ok || v.(uint) != 7 { + t.Error("Expected 172.16.0.0 to be added from file") + } + + // 8.8.8.8 should still exist (only in memory) + if v, ok := rateCache.Get("8.8.8.8"); !ok || v.(uint) != 20 { + t.Error("Expected 8.8.8.8 to still exist in memory") + } + + // 1.1.1.1 should be added from file + if v, ok := verifiedCache.Get("1.1.1.1"); !ok || v.(bool) != true { + t.Error("Expected 1.1.1.1 to be added from file") + } + + // 2.2.2.2 should keep memory value (newer expiration) + if v, ok := verifiedCache.Get("2.2.2.2"); !ok || v.(bool) != true { + t.Error("Expected 2.2.2.2 to be kept from memory") + } +} + +func TestSaveStateToFile(t *testing.T) { + t.Run("Basic save without reconciliation", func(t *testing.T) { + // Create temp file + tmpFile := t.TempDir() + "/state.json" + + // Create caches with test data + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + botCache.Set("1.2.3.4", false, lru.DefaultExpiration) + verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) + + // Save without reconciliation + lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + tmpFile, + false, // no reconciliation + rateCache, + botCache, + verifiedCache, + testLogger(), + ) + + if err != nil { + t.Fatalf("SaveStateToFile failed: %v", err) + } + + // Verify timing metrics + if lockMs < 0 || readMs < 0 || reconcileMs < 0 || marshalMs < 0 || writeMs < 0 || totalMs < 0 { + t.Error("Expected all timing metrics to be non-negative") + } + + // Verify reconcileMs is 0 when reconciliation is disabled + if reconcileMs != 0 { + t.Errorf("Expected reconcileMs to be 0 when reconciliation disabled, got %d", reconcileMs) + } + + // Verify readMs is 0 when reconciliation is disabled + if readMs != 0 { + t.Errorf("Expected readMs to be 0 when reconciliation disabled, got %d", readMs) + } + + // Verify file was created and contains data + fileInfo, err := os.Stat(tmpFile) + if err != nil { + t.Fatalf("Failed to stat file: %v", err) + } + if fileInfo.Size() == 0 { + t.Error("State file is empty") + } + + // Load and verify the saved data + savedData, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("Failed to read saved file: %v", err) + } + + var savedState State + if err := json.Unmarshal(savedData, &savedState); err != nil { + t.Fatalf("Failed to unmarshal saved state: %v", err) + } + + if len(savedState.Rate) != 1 { + t.Errorf("Expected 1 rate entry, got %d", len(savedState.Rate)) + } + if len(savedState.Bots) != 1 { + t.Errorf("Expected 1 bot entry, got %d", len(savedState.Bots)) + } + if len(savedState.Verified) != 1 { + t.Errorf("Expected 1 verified entry, got %d", len(savedState.Verified)) + } + }) + + t.Run("Save with reconciliation", func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + // Create initial state file + now := time.Now().UnixNano() + futureExpiration := now + int64(1*time.Hour) + initialState := State{ + Rate: map[string]CacheEntry{ + "10.0.0.0": {Value: uint(5), Expiration: futureExpiration}, + }, + Bots: map[string]CacheEntry{}, + Verified: map[string]CacheEntry{}, + Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + } + initialData, _ := json.Marshal(initialState) + if err := os.WriteFile(tmpFile, initialData, 0644); err != nil { + t.Fatalf("Failed to write initial state: %v", err) + } + + // Create caches with different data + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + + // Save with reconciliation enabled + lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + tmpFile, + true, // enable reconciliation + rateCache, + botCache, + verifiedCache, + testLogger(), + ) + + if err != nil { + t.Fatalf("SaveStateToFile with reconciliation failed: %v", err) + } + + // Verify timing metrics (all should be non-negative) + if lockMs < 0 { + t.Error("Expected non-negative lockMs") + } + if readMs < 0 { + t.Error("Expected non-negative readMs when reconciliation is enabled") + } + if reconcileMs < 0 { + t.Error("Expected non-negative reconcileMs when reconciliation is enabled") + } + if marshalMs < 0 { + t.Error("Expected non-negative marshalMs") + } + if writeMs < 0 { + t.Error("Expected non-negative writeMs") + } + if totalMs < 0 { + t.Error("Expected non-negative totalMs") + } + + // Verify both entries are in the saved file (reconciled) + savedData, _ := os.ReadFile(tmpFile) + var savedState State + err = json.Unmarshal(savedData, &savedState) + if err != nil { + t.Errorf("Unable to unmarshal state %v", err) + } + + if len(savedState.Rate) != 2 { + t.Errorf("Expected 2 rate entries after reconciliation, got %d", len(savedState.Rate)) + } + }) + + t.Run("File write error", func(t *testing.T) { + // Use invalid path to trigger error + invalidPath := "/invalid/directory/that/does/not/exist/state.json" + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + _, _, _, _, _, _, err := SaveStateToFile( + invalidPath, + false, + rateCache, + botCache, + verifiedCache, + testLogger(), + ) + + if err == nil { + t.Error("Expected error for invalid file path, got nil") + } + }) +} + +func TestLoadStateFromFile(t *testing.T) { + t.Run("Load valid state file", func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + // Create state file + now := time.Now().UnixNano() + futureExpiration := now + int64(1*time.Hour) + testState := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, + "10.0.0.0": {Value: uint(5), Expiration: futureExpiration}, + }, + Bots: map[string]CacheEntry{ + "1.2.3.4": {Value: true, Expiration: futureExpiration}, + }, + Verified: map[string]CacheEntry{ + "5.6.7.8": {Value: true, Expiration: futureExpiration}, + }, + Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + } + + data, _ := json.Marshal(testState) + if err := os.WriteFile(tmpFile, data, 0644); err != nil { + t.Fatalf("Failed to write test state: %v", err) + } + + // Load into empty caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + if err != nil { + t.Fatalf("LoadStateFromFile failed: %v", err) + } + + // Verify caches were populated + if rateCache.ItemCount() != 2 { + t.Errorf("Expected 2 rate entries, got %d", rateCache.ItemCount()) + } + if botCache.ItemCount() != 1 { + t.Errorf("Expected 1 bot entry, got %d", botCache.ItemCount()) + } + if verifiedCache.ItemCount() != 1 { + t.Errorf("Expected 1 verified entry, got %d", verifiedCache.ItemCount()) + } + + // Verify specific values + if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 10 { + t.Error("Expected rate 10 for 192.168.0.0") + } + if v, ok := botCache.Get("1.2.3.4"); !ok || v.(bool) != true { + t.Error("Expected bot 1.2.3.4 to be true") + } + if v, ok := verifiedCache.Get("5.6.7.8"); !ok || v.(bool) != true { + t.Error("Expected 5.6.7.8 to be verified") + } + }) + + t.Run("Load expired entries", func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + // Create state with expired entries + now := time.Now().UnixNano() + pastExpiration := now - int64(1*time.Hour) + testState := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": {Value: uint(10), Expiration: pastExpiration}, // expired + }, + Bots: map[string]CacheEntry{}, + Verified: map[string]CacheEntry{}, + Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + } + + data, _ := json.Marshal(testState) + err := os.WriteFile(tmpFile, data, 0644) + if err != nil { + t.Fatalf("Unable to write file: %v", err) + } + // Load into empty caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + err = LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + if err != nil { + t.Fatalf("LoadStateFromFile failed: %v", err) + } + + // Expired entries should be filtered out + if rateCache.ItemCount() != 0 { + t.Errorf("Expected 0 entries (expired filtered out), got %d", rateCache.ItemCount()) + } + }) + + t.Run("File does not exist", func(t *testing.T) { + nonExistentFile := t.TempDir() + "/does-not-exist.json" + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + err := LoadStateFromFile(nonExistentFile, rateCache, botCache, verifiedCache) + if err == nil { + t.Error("Expected error for non-existent file, got nil") + } + }) + + t.Run("Invalid JSON", func(t *testing.T) { + tmpFile := t.TempDir() + "/invalid.json" + + // Write invalid JSON + if err := os.WriteFile(tmpFile, []byte(`{invalid json`), 0644); err != nil { + t.Fatalf("Failed to write invalid JSON: %v", err) + } + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + if err == nil { + t.Error("Expected error for invalid JSON, got nil") + } + + // Caches should remain empty + if rateCache.ItemCount() != 0 { + t.Error("Expected empty cache after failed load") + } + }) + + t.Run("Empty file", func(t *testing.T) { + tmpFile := t.TempDir() + "/empty.json" + + // Write empty file + if err := os.WriteFile(tmpFile, []byte{}, 0644); err != nil { + t.Fatalf("Failed to write empty file: %v", err) + } + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + // Empty file returns nil (no state to load, which is fine) + err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + if err != nil { + t.Errorf("Unexpected error for empty file: %v", err) + } + + // Caches should remain empty + if rateCache.ItemCount() != 0 { + t.Error("Expected empty cache after loading empty file") + } + }) +} + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelError, // Only show errors during tests + })) +} + +// TestSetStateWithExpiration_Synctest uses synctest to verify expiration logic +func TestSetStateWithExpiration_Synctest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Initial time: midnight UTC 2000-01-01 + start := time.Now() + + // Create state with entries expiring at different times + state := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": { + Value: uint(10), + Expiration: start.Add(5 * time.Second).UnixNano(), // expires in 5s + }, + "10.0.0.0": { + Value: uint(5), + Expiration: start.Add(10 * time.Second).UnixNano(), // expires in 10s + }, + }, + Bots: map[string]CacheEntry{ + "1.2.3.4": { + Value: true, + Expiration: start.Add(3 * time.Second).UnixNano(), // expires in 3s + }, + }, + Verified: map[string]CacheEntry{ + "9.9.9.9": { + Value: true, + Expiration: 0, // never expires + }, + }, + } + + // Create empty caches (no cleanup interval to avoid background goroutines) + rateCache := lru.New(1*time.Hour, lru.NoExpiration) + botCache := lru.New(1*time.Hour, lru.NoExpiration) + verifiedCache := lru.New(1*time.Hour, lru.NoExpiration) + + // Load state + SetState(state, rateCache, botCache, verifiedCache) + + // Verify all entries are loaded + if rateCache.ItemCount() != 2 { + t.Errorf("Expected 2 rate entries, got %d", rateCache.ItemCount()) + } + if botCache.ItemCount() != 1 { + t.Errorf("Expected 1 bot entry, got %d", botCache.ItemCount()) + } + if verifiedCache.ItemCount() != 1 { + t.Errorf("Expected 1 verified entry, got %d", verifiedCache.ItemCount()) + } + + // Advance time by 4 seconds (bot entry should expire, rate entries still valid) + time.Sleep(4 * time.Second) + synctest.Wait() + + // Bot cache should be empty (expired at 3s) + if _, found := botCache.Get("1.2.3.4"); found { + t.Error("Bot entry should have expired after 3 seconds") + } + + // Rate entries should still be present + if _, found := rateCache.Get("192.168.0.0"); !found { + t.Error("Rate entry 192.168.0.0 should not expire until 5 seconds") + } + if _, found := rateCache.Get("10.0.0.0"); !found { + t.Error("Rate entry 10.0.0.0 should not expire until 10 seconds") + } + + // Advance time by 2 more seconds (total 6s, first rate entry should expire) + time.Sleep(2 * time.Second) + synctest.Wait() + + // First rate entry should be expired + if _, found := rateCache.Get("192.168.0.0"); found { + t.Error("Rate entry 192.168.0.0 should have expired after 5 seconds") + } + + // Second rate entry should still be present + if _, found := rateCache.Get("10.0.0.0"); !found { + t.Error("Rate entry 10.0.0.0 should not expire until 10 seconds") + } + + // Verified entry with no expiration should still be present + if _, found := verifiedCache.Get("9.9.9.9"); !found { + t.Error("Verified entry with no expiration should never expire") + } + + // Advance time by 5 more seconds (total 11s, all time-based entries expired) + time.Sleep(5 * time.Second) + synctest.Wait() + + // All time-based entries should be expired + if _, found := rateCache.Get("10.0.0.0"); found { + t.Error("Rate entry 10.0.0.0 should have expired after 10 seconds") + } + + // Only the never-expiring verified entry should remain + if _, found := verifiedCache.Get("9.9.9.9"); !found { + t.Error("Verified entry with no expiration should still be present after 11 seconds") + } + }) +} + +// TestReconcileStateWithExpiration_Synctest tests reconciliation with time control +func TestReconcileStateWithExpiration_Synctest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + start := time.Now() + + // Create file state with entries expiring at different times + fileState := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": { + Value: uint(15), + Expiration: start.Add(10 * time.Second).UnixNano(), // newer expiration + }, + "10.0.0.0": { + Value: uint(3), + Expiration: start.Add(5 * time.Second).UnixNano(), // older expiration + }, + }, + } + + // Create memory caches with overlapping data (no cleanup interval to avoid background goroutines) + rateCache := lru.New(1*time.Hour, lru.NoExpiration) + botCache := lru.New(1*time.Hour, lru.NoExpiration) + verifiedCache := lru.New(1*time.Hour, lru.NoExpiration) + + // Memory entry with older expiration (should be replaced) + rateCache.Set("192.168.0.0", uint(10), 5*time.Second) + // Memory entry with newer expiration (should be kept) + rateCache.Set("10.0.0.0", uint(5), 10*time.Second) + + // Reconcile + ReconcileState(fileState, rateCache, botCache, verifiedCache) + + // 192.168.0.0 should have file's value (newer expiration) + if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 15 { + t.Errorf("Expected rate 15 for 192.168.0.0, got %v", v) + } + + // 10.0.0.0 should have memory's value (newer expiration) + if v, ok := rateCache.Get("10.0.0.0"); !ok || v.(uint) != 5 { + t.Errorf("Expected rate 5 for 10.0.0.0 (memory kept), got %v", v) + } + + // Advance time by 6 seconds + time.Sleep(6 * time.Second) + synctest.Wait() + + // Both entries should still be present (both have 10s expiration from reconciliation) + // - 192.168.0.0 has file's value (15) with 10s expiration + // - 10.0.0.0 has memory's value (5) with 10s expiration + if _, found := rateCache.Get("10.0.0.0"); !found { + t.Error("Entry 10.0.0.0 should not expire until 10 seconds (memory had newer expiration)") + } + + if _, found := rateCache.Get("192.168.0.0"); !found { + t.Error("Entry 192.168.0.0 should not expire until 10 seconds (file had newer expiration)") + } + + // Advance time by 5 more seconds (total 11s) + time.Sleep(5 * time.Second) + synctest.Wait() + + // All entries should be expired (verify by trying to get them) + if _, found := rateCache.Get("192.168.0.0"); found { + t.Error("Entry 192.168.0.0 should have expired after 10 seconds") + } + // Manually trigger cleanup since we're not using automatic janitor + rateCache.DeleteExpired() + if rateCache.ItemCount() != 0 { + t.Errorf("Expected all entries expired, got %d entries", rateCache.ItemCount()) + } + }) +} + +// TestSaveAndLoadStateWithExpiration_Synctest tests full save/load cycle with time control +func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + // Create caches with entries expiring at different times (no cleanup interval to avoid background goroutines) + rateCache1 := lru.New(1*time.Hour, lru.NoExpiration) + botCache1 := lru.New(1*time.Hour, lru.NoExpiration) + verifiedCache1 := lru.New(1*time.Hour, lru.NoExpiration) + + rateCache1.Set("192.168.0.0", uint(10), 5*time.Second) + rateCache1.Set("10.0.0.0", uint(5), 10*time.Second) + botCache1.Set("1.2.3.4", true, 3*time.Second) + verifiedCache1.Set("9.9.9.9", true, lru.NoExpiration) + + // Save state + _, _, _, _, _, _, err := SaveStateToFile( + tmpFile, + false, + rateCache1, + botCache1, + verifiedCache1, + testLogger(), + ) + if err != nil { + t.Fatalf("SaveStateToFile failed: %v", err) + } + + // Advance time by 4 seconds (bot expires, rates still valid) + time.Sleep(4 * time.Second) + synctest.Wait() + + // Load into new caches (no cleanup interval to avoid background goroutines) + rateCache2 := lru.New(1*time.Hour, lru.NoExpiration) + botCache2 := lru.New(1*time.Hour, lru.NoExpiration) + verifiedCache2 := lru.New(1*time.Hour, lru.NoExpiration) + + err = LoadStateFromFile(tmpFile, rateCache2, botCache2, verifiedCache2) + if err != nil { + t.Fatalf("LoadStateFromFile failed: %v", err) + } + + // Bot entry should be filtered out (expired 1 second ago) + if botCache2.ItemCount() != 0 { + t.Errorf("Expected 0 bot entries (expired), got %d", botCache2.ItemCount()) + } + + // First rate entry should be loaded (expires at 5s, we're at 4s) + if _, found := rateCache2.Get("192.168.0.0"); !found { + t.Error("Rate entry 192.168.0.0 should be loaded (not yet expired)") + } + + // Second rate entry should be loaded (expires at 10s, we're at 4s) + if _, found := rateCache2.Get("10.0.0.0"); !found { + t.Error("Rate entry 10.0.0.0 should be loaded (not yet expired)") + } + + // Verified entry should be loaded (no expiration) + if _, found := verifiedCache2.Get("9.9.9.9"); !found { + t.Error("Verified entry should be loaded (no expiration)") + } + + // Advance time by 2 more seconds (total 6s, first rate entry expires) + time.Sleep(2 * time.Second) + synctest.Wait() + + // First rate entry should be expired + if _, found := rateCache2.Get("192.168.0.0"); found { + t.Error("Rate entry 192.168.0.0 should have expired") + } + + // Second rate entry should still exist + if _, found := rateCache2.Get("10.0.0.0"); !found { + t.Error("Rate entry 10.0.0.0 should still be present") + } + }) +} + +// TestReconcilePreservesNewerData_Synctest verifies reconciliation keeps fresher data +func TestReconcilePreservesNewerData_Synctest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + // Create initial state file with data expiring in 5 seconds (no cleanup interval to avoid background goroutines) + initialCache := lru.New(1*time.Hour, lru.NoExpiration) + initialCache.Set("192.168.0.0", uint(100), 5*time.Second) + + _, _, _, _, _, _, err := SaveStateToFile( + tmpFile, + false, + initialCache, + lru.New(1*time.Hour, lru.NoExpiration), + lru.New(1*time.Hour, lru.NoExpiration), + testLogger(), + ) + if err != nil { + t.Fatalf("Initial save failed: %v", err) + } + + // Advance time by 2 seconds + time.Sleep(2 * time.Second) + synctest.Wait() + + // Create new in-memory data with expiration in 10 seconds from original start + // This represents fresher data (no cleanup interval to avoid background goroutines) + newCache := lru.New(1*time.Hour, lru.NoExpiration) + newCache.Set("192.168.0.0", uint(200), 8*time.Second) // expires at start+10s + + // Save with reconciliation enabled + _, _, _, _, _, _, err = SaveStateToFile( + tmpFile, + true, // reconcile + newCache, + lru.New(1*time.Hour, lru.NoExpiration), + lru.New(1*time.Hour, lru.NoExpiration), + testLogger(), + ) + if err != nil { + t.Fatalf("Reconciled save failed: %v", err) + } + + // Load back and verify we got the newer value (no cleanup interval to avoid background goroutines) + loadedCache := lru.New(1*time.Hour, lru.NoExpiration) + err = LoadStateFromFile( + tmpFile, + loadedCache, + lru.New(1*time.Hour, lru.NoExpiration), + lru.New(1*time.Hour, lru.NoExpiration), + ) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + // Should have the newer value (200 with later expiration) + if v, found := loadedCache.Get("192.168.0.0"); !found || v.(uint) != 200 { + t.Errorf("Expected value 200 (newer data), got %v (found=%v)", v, found) + } + + // Advance time by 4 more seconds (total 6s from start) + // Old data would have expired at 5s, new data expires at 10s + time.Sleep(4 * time.Second) + synctest.Wait() + + // New data should still be valid + if _, found := loadedCache.Get("192.168.0.0"); !found { + t.Error("Newer data should still be valid (expires at 10s, we're at 6s)") + } + + // Advance time by 5 more seconds (total 11s from start) + time.Sleep(5 * time.Second) + synctest.Wait() + + // Now the newer data should also be expired + if _, found := loadedCache.Get("192.168.0.0"); found { + t.Error("Newer data should have expired after 10 seconds") + } + }) +} + +// TestCacheCleanupInterval_Synctest verifies go-cache cleanup runs on schedule +// NOTE: This test is skipped because it tests the janitor goroutine which is incompatible with synctest +func TestCacheCleanupInterval_Synctest(t *testing.T) { + t.Skip("Skipping test that requires janitor goroutine (incompatible with synctest)") + synctest.Test(t, func(t *testing.T) { + // Create cache with 1 minute cleanup interval + cleanupInterval := 1 * time.Minute + cache := lru.New(5*time.Second, cleanupInterval) + + // Add entry that expires in 3 seconds + cache.Set("test-key", uint(42), 3*time.Second) + + // Verify entry exists + if _, found := cache.Get("test-key"); !found { + t.Fatal("Entry should exist immediately after Set") + } + + // Advance time by 4 seconds (entry expired but cleanup hasn't run) + time.Sleep(4 * time.Second) + synctest.Wait() + + // Entry is expired but might still be in cache (cleanup hasn't run yet) + // The Get should return false because go-cache checks expiration on Get + if _, found := cache.Get("test-key"); found { + t.Error("Entry should be expired after 3 seconds") + } + + // Advance time to trigger cleanup (cleanup runs every 1 minute) + time.Sleep(57 * time.Second) // Total 61 seconds, cleanup should have run + synctest.Wait() + + // Entry should definitely be cleaned up now + if cache.ItemCount() != 0 { + t.Errorf("Cache should be empty after cleanup, got %d items", cache.ItemCount()) + } + }) +} diff --git a/main.go b/main.go index 7854a82..7debc1d 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log/slog" + "math/rand" "net" "net/http" "net/url" @@ -23,6 +24,13 @@ import ( lru "github.com/patrickmn/go-cache" ) +const ( + // StateSaveInterval is how often the persistent state file is written to disk + StateSaveInterval = 10 * time.Second + // StateSaveJitter is the maximum random jitter added to save interval to prevent thundering herd + StateSaveJitter = 2 * time.Second +) + type Config struct { RateLimit uint `json:"rateLimit"` Window int64 `json:"window"` @@ -49,7 +57,12 @@ type Config struct { EnableStatsPage string `json:"enableStatsPage"` LogLevel string `json:"loglevel,omitempty"` PersistentStateFile string `json:"persistentStateFile"` - Mode string `json:"mode"` + // EnableStateReconciliation is a string instead of bool due to Traefik's label parsing limitations + // When enabled, the plugin will read and merge state from disk before each save to prevent + // multiple instances from overwriting each other's data. This adds extra I/O overhead. + // Only enable this if running multiple plugin instances sharing the same state file. + EnableStateReconciliation string `json:"enableStateReconciliation"` + Mode string `json:"mode"` } type CaptchaProtect struct { @@ -82,27 +95,28 @@ type captchaResponse struct { func CreateConfig() *Config { return &Config{ - RateLimit: 20, - Window: 86400, - IPv4SubnetMask: 16, - IPv6SubnetMask: 64, - IPForwardedHeader: "", - ProtectParameters: "false", - ProtectRoutes: []string{}, - ExcludeRoutes: []string{}, - ProtectHttpMethods: []string{}, - ProtectFileExtensions: []string{}, - GoodBots: []string{}, - ExemptIPs: []string{}, - ExemptUserAgents: []string{}, - ChallengeURL: "/challenge", - ChallengeTmpl: "challenge.tmpl.html", - ChallengeStatusCode: 0, - EnableStatsPage: "false", - LogLevel: "INFO", - IPDepth: 0, - CaptchaProvider: "turnstile", - Mode: "prefix", + RateLimit: 20, + Window: 86400, + IPv4SubnetMask: 16, + IPv6SubnetMask: 64, + IPForwardedHeader: "", + ProtectParameters: "false", + ProtectRoutes: []string{}, + ExcludeRoutes: []string{}, + ProtectHttpMethods: []string{}, + ProtectFileExtensions: []string{}, + GoodBots: []string{}, + ExemptIPs: []string{}, + ExemptUserAgents: []string{}, + ChallengeURL: "/challenge", + ChallengeTmpl: "challenge.tmpl.html", + ChallengeStatusCode: 0, + EnableStatsPage: "false", + LogLevel: "INFO", + IPDepth: 0, + CaptchaProvider: "turnstile", + Mode: "prefix", + EnableStateReconciliation: "false", } } @@ -393,6 +407,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. } if captchaResponse.Success { bc.verifiedCache.Set(ip, true, lru.DefaultExpiration) + destination := req.FormValue("destination") if destination == "" { destination = "%2F" @@ -580,7 +595,7 @@ func (bc *CaptchaProtect) registerRequest(ip string) { _, err = bc.rateCache.IncrementUint(ip, uint(1)) if err != nil { - bc.log.Error("unable to set rate cache", "ip", ip) + bc.log.Error("unable to set rate cache", "ip", ip, "err", err) } } @@ -699,7 +714,13 @@ func (c *Config) ParseHttpMethods(log *slog.Logger) { } func (bc *CaptchaProtect) saveState(ctx context.Context) { - ticker := time.NewTicker(1 * time.Minute) + // Add random jitter to prevent multiple instances from trying to save simultaneously + jitter := time.Duration(rand.Intn(int(StateSaveJitter.Milliseconds()))) * time.Millisecond + interval := StateSaveInterval + jitter + + bc.log.Debug("State save configured", "baseInterval", StateSaveInterval, "jitter", jitter, "actualInterval", interval) + + ticker := time.NewTicker(interval) defer ticker.Stop() file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0644) @@ -713,49 +734,61 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { for { select { case <-ticker.C: - bc.log.Debug("Saving state") - state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) - jsonData, err := json.Marshal(state) - if err != nil { - bc.log.Error("failed to marshal state data", "err", err) - break - } - err = os.WriteFile(bc.config.PersistentStateFile, jsonData, 0644) - if err != nil { - bc.log.Error("failed to save state data", "err", err) - } + bc.log.Debug("Periodic state save triggered") + bc.saveStateNow() case <-ctx.Done(): - bc.log.Debug("Context cancelled, stopping saveState") + bc.log.Debug("Context cancelled, running saveState before shutdown") + bc.saveStateNow() return } } } -func (bc *CaptchaProtect) loadState() { - fileContent, err := os.ReadFile(bc.config.PersistentStateFile) - if err != nil || len(fileContent) == 0 { - bc.log.Warn("failed to load state file", "err", err) - return - } +// saveStateNow performs an immediate state save using the state package. +func (bc *CaptchaProtect) saveStateNow() { + reconcile := bc.config.EnableStateReconciliation == "true" + + lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := state.SaveStateToFile( + bc.config.PersistentStateFile, + reconcile, + bc.rateCache, + bc.botCache, + bc.verifiedCache, + bc.log, + ) - var state state.State - err = json.Unmarshal(fileContent, &state) if err != nil { - bc.log.Error("failed to unmarshal state file", "err", err) + bc.log.Error("failed to save state", "err", err) return } - for k, v := range state.Rate { - bc.rateCache.Set(k, v, lru.DefaultExpiration) - } + // Get current state for logging (already marshaled in SaveStateToFile, but we need counts) + currentState := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) + bc.log.Debug("State saved successfully", + "rateEntries", len(currentState.Rate), + "botEntries", len(currentState.Bots), + "verifiedEntries", len(currentState.Verified), + "lockMs", lockMs, + "readMs", readMs, + "reconcileMs", reconcileMs, + "marshalMs", marshalMs, + "writeMs", writeMs, + "totalMs", totalMs, + ) +} - for k, v := range state.Bots { - bc.botCache.Set(k, v, lru.DefaultExpiration) - } +func (bc *CaptchaProtect) loadState() { + err := state.LoadStateFromFile( + bc.config.PersistentStateFile, + bc.rateCache, + bc.botCache, + bc.verifiedCache, + ) - for k, v := range state.Verified { - bc.verifiedCache.Set(k, v, lru.DefaultExpiration) + if err != nil { + bc.log.Warn("failed to load state file", "err", err) + return } bc.log.Info("Loaded previous state") diff --git a/main_test.go b/main_test.go index b2b2c33..a15713c 100644 --- a/main_test.go +++ b/main_test.go @@ -937,15 +937,26 @@ func TestStatePersistence(t *testing.T) { // Manually save state by writing the file directly // This tests the state format without relying on the background goroutine + // Use the new CacheEntry format with expiration timestamps + futureExpiration := time.Now().Add(1 * time.Hour).UnixNano() jsonData, _ := json.Marshal(map[string]interface{}{ - "rate": map[string]uint{ - "192.168.0.0": 10, + "rate": map[string]map[string]interface{}{ + "192.168.0.0": { + "value": uint(10), + "expiration": float64(futureExpiration), + }, }, - "verified": map[string]bool{ - "1.2.3.4": true, + "verified": map[string]map[string]interface{}{ + "1.2.3.4": { + "value": true, + "expiration": float64(futureExpiration), + }, }, - "bots": map[string]bool{ - "5.6.7.8": false, + "bots": map[string]map[string]interface{}{ + "5.6.7.8": { + "value": false, + "expiration": float64(futureExpiration), + }, }, }) err := os.WriteFile(tmpFile, jsonData, 0644)