diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 0b552d4..ce96425 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6 with: - go-version: '>=1.24.0' + go-version: ">=1.25.0" - name: golangci-lint uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8 @@ -42,6 +42,17 @@ jobs: - name: unit test run: go test -v -race ./... + - name: generate coverage + run: go test -coverprofile=coverage.out -covermode=atomic ./... + + - name: upload coverage to codecov + uses: codecov/codecov-action@v5 + with: + files: ./coverage.out + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + integration-test: needs: [run] permissions: diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..19d5ef9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,171 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +This is a Traefik middleware plugin that protects websites from bot traffic by challenging individual IPs with CAPTCHAs when traffic spikes are detected from their subnet. The plugin supports Cloudflare Turnstile, Google reCAPTCHA, and hCaptcha. + +**Key concept**: Instead of rate limiting individual IPs, this plugin monitors traffic at the subnet level (e.g., /16 for IPv4, /64 for IPv6) and only challenges specific IPs when their entire subnet exceeds a configured rate limit. + +## Architecture + +### Core Components + +- **main.go** (`main.go:1-761`): Contains the entire middleware implementation in a single file + - `CaptchaProtect` struct: Main middleware handler with rate limiting, bot detection, and challenge serving + - `Config` struct: Configuration from Traefik labels + - Three in-memory caches (using `github.com/patrickmn/go-cache`): + - `rateCache`: Tracks request counts per subnet + - `verifiedCache`: Stores IPs that have passed challenges (24h default TTL) + - `botCache`: Caches reverse DNS lookups for bot verification + +### Request Flow Decision Tree + +The middleware follows this decision order (see `shouldApply()` at `main.go:422-453`): + +1. Check if HTTP method is protected (default: GET, HEAD) +2. Check if IP already verified (passed challenge recently) +3. Check if IP is in exemptIps (private ranges + configured exemptions) +4. Check if IP is a good bot (reverse DNS matches goodBots list) +5. Check if user agent is exempt +6. Check if route matches protection rules (prefix/suffix/regex matching) +7. If protected, increment subnet counter and check rate limit +8. If rate limit exceeded, serve challenge or redirect to challenge page + +### Internal Packages + +- **internal/helper/**: Utility functions + - `ip.go`: IP parsing, CIDR matching, reverse DNS lookups for bot verification + - `tmpl.go`: Default challenge template (embedded fallback) +- **internal/log/**: Structured logging with slog +- **internal/state/**: State serialization for persistent storage across restarts + +### Challenge Modes + +Two modes for serving challenges: + +1. **Redirect mode** (default): `challengeURL: "/challenge"` - Redirects to dedicated challenge page +2. **Inline mode**: `challengeURL: ""` - Serves challenge on the same page that triggered rate limit + +## Development Commands + +### Running Tests + +```bash +# Run unit tests +go test -v -race ./... + +# Run single test +go test -v -race -run TestParseIp + +# Run integration tests (requires Docker) +cd ci && go run test.go +``` + +### Linting and Formatting + +```bash +# Run golangci-lint locally +golangci-lint run + +# Format code +gofmt -w . + +# Check if go.mod is tidy +go mod tidy && git diff --exit-code go.mod go.sum + +# Update vendored dependencies +go mod vendor +``` + +### CI/CD + +The GitHub Actions workflow (`.github/workflows/lint-test.yml`) runs on every push: +1. golangci-lint +2. Validates `.traefik.yml` with yq +3. Checks `go mod tidy` and `go mod vendor` are up-to-date +4. Runs unit tests with race detector +5. Runs integration tests against Traefik v2.11, v3.0, v3.1, v3.2, v3.3, v3.4 + +### Integration Testing + +The `ci/` directory contains a full integration test: +- Spins up Traefik + nginx with docker-compose +- Generates 100 unique public IPs from different subnets +- Makes parallel requests to verify rate limiting behavior +- Tests state persistence across container restarts +- Validates stats endpoint JSON + +To run: `cd ci && go run test.go` + +## Key Implementation Details + +### Route Matching Modes + +Three modes configured via `mode` parameter (defaults to "prefix"): + +1. **prefix**: Fast string prefix matching (`strings.HasPrefix`) +2. **suffix**: Matches route suffixes (useful for specific endpoints) +3. **regex**: Full regex support (13x slower than prefix, use only when needed) + +Regex is significantly slower (~41ns vs ~3.4ns per operation) - see README benchmark section. + +### IP Subnet Calculation + +- IPv4: Masks IPs to configured subnet (default /16 means `192.168.x.x` → `192.168.0.0`) +- IPv6: Default /64 subnet mask +- Implementation at `main.go:621-642` + +### State Persistence + +When `persistentStateFile` is configured: +- State saves every 1 minute to JSON file (`saveState()` at `main.go:695-727`) +- On startup, loads previous state from file (`loadState()` at `main.go:729-756`) +- Contains: rate limits per subnet, bot verification cache, verified IPs + +### Good Bot Detection + +To avoid SEO impact, the plugin allows "good bots" to bypass rate limits: +- Performs reverse DNS lookup on IP (`internal/helper/ip.go`) +- Checks if hostname ends with configured second-level domain (e.g., "googlebot.com") +- Results cached in `botCache` to avoid repeated DNS lookups +- Optional `protectParameters: "true"` forces rate limiting even for good bots if URL contains query parameters + +### File Extension Filtering + +By default, only HTML files are rate-limited (to prevent CSS/JS/images from consuming rate limit quota). Configure `protectFileExtensions` to add more file types. + +## Configuration + +Configuration comes from Traefik labels. See `.traefik.yml` for the plugin manifest. + +Key defaults: +- `rateLimit: 20` requests per subnet +- `window: 86400` seconds (24 hours) +- `ipv4subnetMask: 16` (/16 = 65,536 IPs) +- `ipv6subnetMask: 64` +- `challengeStatusCode: 200` (or 429 for inline challenges) + +## Testing Strategy + +Unit tests (`main_test.go`) cover: +- IP parsing and subnet masking +- Route protection logic (prefix/suffix/regex) +- Client IP extraction from forwarded headers with depth traversal +- User agent exemption matching +- Challenge page serving with different status codes + +Integration tests (`ci/test.go`) verify: +- Full request lifecycle with real Traefik/nginx +- Rate limiting behavior across multiple subnets +- State persistence across container restarts +- Stats endpoint functionality + +## Traefik Plugin Constraints + +- Must implement `http.Handler` interface +- Entry point: `New(ctx context.Context, next http.Handler, config *Config, name string)` +- Plugin loaded via Traefik's `--experimental.plugins` flag +- No external state allowed (must use in-memory caches or file persistence) +- Must be compatible with Traefik v2.11.1+ diff --git a/README.md b/README.md index e44145d..643e088 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Captcha Protect [![lint-test](https://github.com/libops/captcha-protect/actions/workflows/lint-test.yml/badge.svg)](https://github.com/libops/captcha-protect/actions/workflows/lint-test.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/libops/captcha-protect)](https://goreportcard.com/report/github.com/libops/captcha-protect) +[![codecov](https://codecov.io/gh/libops/captcha-protect/branch/main/graph/badge.svg)](https://codecov.io/gh/libops/captcha-protect) Traefik middleware to challenge individual IPs in a subnet when traffic spikes are detected from that subnet, using a captcha of your choice for the challenge (turnstile, recaptcha, or hcaptcha). **Requires traefik `v2.11.1` or above** diff --git a/go.mod b/go.mod index 1b3534a..4f45b71 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/libops/captcha-protect -go 1.24.0 +go 1.25.0 require github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/internal/helper/ip_test.go b/internal/helper/ip_test.go index cebbbcb..540a226 100644 --- a/internal/helper/ip_test.go +++ b/internal/helper/ip_test.go @@ -191,3 +191,55 @@ func parseCIDR(cidr string, t *testing.T) *net.IPNet { } return block } + +func TestParseCIDR(t *testing.T) { + tests := []struct { + name string + cidr string + expectErr bool + }{ + { + name: "Valid IPv4 CIDR", + cidr: "192.168.1.0/24", + expectErr: false, + }, + { + name: "Valid IPv6 CIDR", + cidr: "2001:db8::/32", + expectErr: false, + }, + { + name: "Invalid CIDR - no mask", + cidr: "192.168.1.0", + expectErr: true, + }, + { + name: "Invalid CIDR - bad format", + cidr: "not-a-cidr", + expectErr: true, + }, + { + name: "Invalid CIDR - empty string", + cidr: "", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseCIDR(tt.cidr) + if tt.expectErr { + if err == nil { + t.Errorf("Expected error for CIDR %q, got nil", tt.cidr) + } + } else { + if err != nil { + t.Errorf("Unexpected error for CIDR %q: %v", tt.cidr, err) + } + if result == nil { + t.Errorf("Expected non-nil result for valid CIDR %q", tt.cidr) + } + } + }) + } +} diff --git a/internal/helper/tmpl_test.go b/internal/helper/tmpl_test.go new file mode 100644 index 0000000..7873e43 --- /dev/null +++ b/internal/helper/tmpl_test.go @@ -0,0 +1,47 @@ +package helper + +import ( + "strings" + "testing" +) + +func TestGetDefaultTmpl(t *testing.T) { + tmpl := GetDefaultTmpl() + + // Verify it returns a non-empty string + if tmpl == "" { + t.Error("GetDefaultTmpl returned empty string") + } + + // Verify it contains expected HTML elements + expectedElements := []string{ + "", + "", + "", + "", + "", + "", + "", + "{{ .FrontendJS }}", + "{{ .SiteKey }}", + "{{ .ChallengeURL }}", + "{{ .Destination }}", + "{{ .FrontendKey }}", + "captchaCallback", + } + + for _, elem := range expectedElements { + if !strings.Contains(tmpl, elem) { + t.Errorf("Template missing expected element: %s", elem) + } + } + + // Verify it's valid HTML structure (basic check) + if !strings.HasPrefix(tmpl, "") { + t.Error("Template should start with ") + } + if !strings.HasSuffix(strings.TrimSpace(tmpl), "") { + t.Error("Template should end with ") + } +} diff --git a/internal/log/log_test.go b/internal/log/log_test.go new file mode 100644 index 0000000..277511c --- /dev/null +++ b/internal/log/log_test.go @@ -0,0 +1,73 @@ +package log + +import ( + "log/slog" + "testing" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + levelStr string + expectedLevel slog.Level + }{ + {"DEBUG level", "DEBUG", slog.LevelDebug}, + {"INFO level", "INFO", slog.LevelInfo}, + {"WARN level", "WARN", slog.LevelWarn}, + {"WARNING level", "WARNING", slog.LevelWarn}, + {"ERROR level", "ERROR", slog.LevelError}, + {"debug lowercase", "debug", slog.LevelDebug}, + {"Unknown level defaults to INFO", "UNKNOWN", slog.LevelInfo}, + {"Empty level defaults to INFO", "", slog.LevelInfo}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := New(tt.levelStr) + if logger == nil { + t.Error("Expected non-nil logger") + } + // Logger is created successfully, we can't easily test the exact level + // but we verify it doesn't panic or error + }) + } +} + +func TestParseLogLevel(t *testing.T) { + tests := []struct { + name string + level string + expected slog.Level + expectErr bool + }{ + {"DEBUG", "DEBUG", slog.LevelDebug, false}, + {"debug lowercase", "debug", slog.LevelDebug, false}, + {"INFO", "INFO", slog.LevelInfo, false}, + {"info lowercase", "info", slog.LevelInfo, false}, + {"WARN", "WARN", slog.LevelWarn, false}, + {"WARNING", "WARNING", slog.LevelWarn, false}, + {"warning lowercase", "warning", slog.LevelWarn, false}, + {"ERROR", "ERROR", slog.LevelError, false}, + {"error lowercase", "error", slog.LevelError, false}, + {"Unknown level", "INVALID", slog.LevelInfo, true}, + {"Empty string", "", slog.LevelInfo, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + level, err := parseLogLevel(tt.level) + if tt.expectErr { + if err == nil { + t.Errorf("Expected error for level %q, got nil", tt.level) + } + } else { + if err != nil { + t.Errorf("Unexpected error for level %q: %v", tt.level, err) + } + } + if level != tt.expected { + t.Errorf("Expected level %v, got %v", tt.expected, level) + } + }) + } +} diff --git a/internal/state/state_test.go b/internal/state/state_test.go new file mode 100644 index 0000000..4237f80 --- /dev/null +++ b/internal/state/state_test.go @@ -0,0 +1,93 @@ +package state + +import ( + "testing" + "time" + + lru "github.com/patrickmn/go-cache" +) + +func TestGetState(t *testing.T) { + // Create test caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + // Add test data + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) + + botCache.Set("1.2.3.4", true, lru.DefaultExpiration) + botCache.Set("5.6.7.8", false, lru.DefaultExpiration) + + verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) + + // Get state + state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + + // Verify rate cache data + if len(state.Rate) != 2 { + t.Errorf("Expected 2 rate entries, got %d", len(state.Rate)) + } + if state.Rate["192.168.0.0"] != 10 { + t.Errorf("Expected rate 10 for 192.168.0.0, got %d", state.Rate["192.168.0.0"]) + } + if state.Rate["10.0.0.0"] != 5 { + t.Errorf("Expected rate 5 for 10.0.0.0, got %d", state.Rate["10.0.0.0"]) + } + + // Verify bot cache data + if len(state.Bots) != 2 { + t.Errorf("Expected 2 bot entries, got %d", len(state.Bots)) + } + if state.Bots["1.2.3.4"] != true { + t.Error("Expected bot 1.2.3.4 to be true") + } + if state.Bots["5.6.7.8"] != false { + t.Error("Expected bot 5.6.7.8 to be false") + } + + // Verify verified cache data + if len(state.Verified) != 1 { + t.Errorf("Expected 1 verified entry, got %d", len(state.Verified)) + } + if state.Verified["9.9.9.9"] != true { + t.Error("Expected 9.9.9.9 to be verified") + } + + // Verify memory tracking exists + if len(state.Memory) != 3 { + t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + } + if state.Memory["rate"] == 0 { + t.Error("Expected non-zero memory for rate cache") + } + if state.Memory["bot"] == 0 { + t.Error("Expected non-zero memory for bot cache") + } + if state.Memory["verified"] == 0 { + t.Error("Expected non-zero memory for verified cache") + } +} + +func TestGetStateEmpty(t *testing.T) { + // Create empty caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + + if len(state.Rate) != 0 { + t.Errorf("Expected 0 rate entries, got %d", len(state.Rate)) + } + if len(state.Bots) != 0 { + t.Errorf("Expected 0 bot entries, got %d", len(state.Bots)) + } + if len(state.Verified) != 0 { + t.Errorf("Expected 0 verified entries, got %d", len(state.Verified)) + } + if len(state.Memory) != 3 { + t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + } +} diff --git a/main.go b/main.go index 530d97e..7854a82 100644 --- a/main.go +++ b/main.go @@ -23,17 +23,14 @@ import ( lru "github.com/patrickmn/go-cache" ) -var ( - log *slog.Logger -) - type Config struct { - RateLimit uint `json:"rateLimit"` - Window int64 `json:"window"` - IPv4SubnetMask int `json:"ipv4subnetMask"` - IPv6SubnetMask int `json:"ipv6subnetMask"` - IPForwardedHeader string `json:"ipForwardedHeader"` - IPDepth int `json:"ipDepth"` + RateLimit uint `json:"rateLimit"` + Window int64 `json:"window"` + IPv4SubnetMask int `json:"ipv4subnetMask"` + IPv6SubnetMask int `json:"ipv6subnetMask"` + IPForwardedHeader string `json:"ipForwardedHeader"` + IPDepth int `json:"ipDepth"` + // ProtectParameters is a string instead of bool due to Traefik's label parsing limitations ProtectParameters string `json:"protectParameters"` ProtectRoutes []string `json:"protectRoutes"` ExcludeRoutes []string `json:"excludeRoutes"` @@ -48,16 +45,19 @@ type Config struct { CaptchaProvider string `json:"captchaProvider"` SiteKey string `json:"siteKey"` SecretKey string `json:"secretKey"` - EnableStatsPage string `json:"enableStatsPage"` - LogLevel string `json:"loglevel,omitempty"` - PersistentStateFile string `json:"persistentStateFile"` - Mode string `json:"mode"` + // EnableStatsPage is a string instead of bool due to Traefik's label parsing limitations + EnableStatsPage string `json:"enableStatsPage"` + LogLevel string `json:"loglevel,omitempty"` + PersistentStateFile string `json:"persistentStateFile"` + Mode string `json:"mode"` } type CaptchaProtect struct { next http.Handler name string config *Config + log *slog.Logger + httpClient *http.Client rateCache *lru.Cache verifiedCache *lru.Cache botCache *lru.Cache @@ -111,7 +111,18 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, name string) (*CaptchaProtect, error) { - log = plog.New(config.LogLevel) + log := plog.New(config.LogLevel) + + // Validate required config + if config.SiteKey == "" { + return nil, fmt.Errorf("siteKey is required") + } + if config.SecretKey == "" { + return nil, fmt.Errorf("secretKey is required") + } + if config.Window <= 0 { + return nil, fmt.Errorf("window must be positive, got %d", config.Window) + } expiration := time.Duration(config.Window) * time.Second log.Debug("Captcha config", "config", config) @@ -164,11 +175,11 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n "HEAD", } } - config.ParseHttpMethods() + config.ParseHttpMethods(log) var tmpl *template.Template if _, err := os.Stat(config.ChallengeTmpl); os.IsNotExist(err) { - log.Warn("Unable to find template file. Using default template.", "challengeTmpl", config.ChallengeTmpl) + log.Warn("Unable to find template file. Using default template", "challengeTmpl", config.ChallengeTmpl) ts := helper.GetDefaultTmpl() tmpl, err = template.New("challenge").Parse(ts) if err != nil { @@ -183,6 +194,8 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n } } + // Always protect HTML files by default to ensure the main content is rate-limited. + // This prevents users from accidentally excluding HTML, which would break the protection. if !slices.Contains(config.ProtectFileExtensions, "html") { config.ProtectFileExtensions = append(config.ProtectFileExtensions, "html") } @@ -206,9 +219,13 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n } bc := CaptchaProtect{ - next: next, - name: name, - config: config, + next: next, + name: name, + config: config, + log: log, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, rateCache: lru.New(expiration, 1*time.Minute), botCache: lru.New(expiration, 1*time.Hour), verifiedCache: lru.New(expiration, 1*time.Hour), @@ -269,7 +286,7 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n go bc.saveState(childCtx) go func() { <-ctx.Done() - log.Debug("Context canceled, calling child cancel...") + bc.log.Debug("Context canceled, calling child cancel") cancel() }() } @@ -283,24 +300,24 @@ func (bc *CaptchaProtect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if challengeOnPage && req.Method == http.MethodPost { if req.URL.Query().Get("challenge") != "" { statusCode := bc.verifyChallengePage(rw, req, clientIP) - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) return } } else if req.URL.Path == bc.config.ChallengeURL { switch req.Method { case http.MethodGet: destination := req.URL.Query().Get("destination") - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "destination", destination, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "destination", destination, "useragent", req.UserAgent()) bc.serveChallengePage(rw, destination) case http.MethodPost: statusCode := bc.verifyChallengePage(rw, req, clientIP) - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) default: http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed) } return } else if req.URL.Path == "/captcha-protect/stats" && bc.config.EnableStatsPage == "true" { - log.Info("Captcha stats", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) + bc.log.Info("Captcha stats", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) bc.serveStatsPage(rw, clientIP) return } @@ -318,12 +335,12 @@ func (bc *CaptchaProtect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { encodedURI := url.QueryEscape(req.RequestURI) if bc.ChallengeOnPage() { - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) bc.serveChallengePage(rw, encodedURI) return } - url := fmt.Sprintf("%s?destination=%s", bc.config.ChallengeURL, encodedURI) - http.Redirect(rw, req, url, http.StatusFound) + redirectURL := fmt.Sprintf("%s?destination=%s", bc.config.ChallengeURL, encodedURI) + http.Redirect(rw, req, redirectURL, http.StatusFound) } func (bc *CaptchaProtect) serveChallengePage(rw http.ResponseWriter, destination string) { @@ -343,8 +360,9 @@ func (bc *CaptchaProtect) serveChallengePage(rw http.ResponseWriter, destination err := bc.tmpl.Execute(rw, d) if err != nil { - log.Error("Unable to execute go template", "tmpl", bc.config.ChallengeTmpl, "err", err) - http.Error(rw, "Internal error", http.StatusInternalServerError) + bc.log.Error("unable to execute go template", "tmpl", bc.config.ChallengeTmpl, "err", err) + // Can't change status code here, already written + _, _ = rw.Write([]byte("\n")) } } @@ -358,9 +376,9 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. var body = url.Values{} body.Add("secret", bc.config.SecretKey) body.Add("response", response) - resp, err := http.PostForm(bc.captchaConfig.validate, body) + resp, err := bc.httpClient.PostForm(bc.captchaConfig.validate, body) if err != nil { - log.Error("Unable to validate captcha", "url", bc.captchaConfig.validate, "body", body, "err", err) + bc.log.Error("unable to validate captcha", "url", bc.captchaConfig.validate, "err", err) http.Error(rw, "Internal error", http.StatusInternalServerError) return http.StatusInternalServerError } @@ -369,7 +387,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. var captchaResponse captchaResponse err = json.NewDecoder(resp.Body).Decode(&captchaResponse) if err != nil { - log.Error("Unable to unmarshal captcha response", "url", bc.captchaConfig.validate, "err", err) + bc.log.Error("unable to unmarshal captcha response", "url", bc.captchaConfig.validate, "err", err) http.Error(rw, "Internal error", http.StatusInternalServerError) return http.StatusInternalServerError } @@ -381,7 +399,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. } u, err := url.QueryUnescape(destination) if err != nil { - log.Error("Unable to unescape destination", "destination", destination, "err", err) + bc.log.Error("unable to unescape destination", "destination", destination, "err", err) u = "/" } http.Redirect(rw, req, u, http.StatusFound) @@ -403,16 +421,16 @@ func (bc *CaptchaProtect) serveStatsPage(rw http.ResponseWriter, ip string) { state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) jsonData, err := json.Marshal(state) if err != nil { - log.Error("failed to marshal JSON", "err", err) + bc.log.Error("failed to marshal JSON", "err", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } - rw.WriteHeader(http.StatusOK) rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) _, err = rw.Write(jsonData) if err != nil { - log.Error("failed to write JSON on stats reques", "err", err) + bc.log.Error("failed to write JSON on stats request", "err", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } @@ -452,6 +470,22 @@ func (bc *CaptchaProtect) shouldApply(req *http.Request, clientIP string) bool { return bc.RouteIsProtectedPrefix(req.URL.Path) } +// isExtensionProtected checks if a file extension should be protected based on the configured list. +// Returns true if the path has no extension (likely HTML) or if the extension matches the protected list. +func (bc *CaptchaProtect) isExtensionProtected(path string) bool { + ext := filepath.Ext(path) + ext = strings.TrimPrefix(ext, ".") + if ext == "" { + return true + } + for _, protectedExt := range bc.config.ProtectFileExtensions { + if strings.EqualFold(ext, protectedExt) { + return true + } + } + return false +} + func (bc *CaptchaProtect) RouteIsProtectedPrefix(path string) bool { protected: for _, route := range bc.config.ProtectRoutes { @@ -466,19 +500,7 @@ protected: } } - // if this path isn't a file, go ahead and mark this path as protected - ext := filepath.Ext(path) - ext = strings.TrimPrefix(ext, ".") - if ext == "" { - return true - } - - // if we have a file extension, see if we should protect this file extension type - for _, protectedExtensions := range bc.config.ProtectFileExtensions { - if strings.EqualFold(ext, protectedExtensions) { - return true - } - } + return bc.isExtensionProtected(path) } return false @@ -503,18 +525,7 @@ protected: } } - // if this path isn't a file, go ahead and mark this path as protected - ext = strings.TrimPrefix(ext, ".") - if ext == "" { - return true - } - - // if we have a file extension, see if we should protect this file extension type - for _, protectedExtensions := range bc.config.ProtectFileExtensions { - if strings.EqualFold(ext, protectedExtensions) { - return true - } - } + return bc.isExtensionProtected(path) } return false @@ -546,17 +557,7 @@ protected: } } - ext := filepath.Ext(path) - ext = strings.TrimPrefix(ext, ".") - if ext == "" { - return true - } - - for _, protectedExtension := range bc.config.ProtectFileExtensions { - if strings.EqualFold(ext, protectedExtension) { - return true - } - } + return bc.isExtensionProtected(path) } return false @@ -565,7 +566,7 @@ protected: func (bc *CaptchaProtect) trippedRateLimit(ip string) bool { v, ok := bc.rateCache.Get(ip) if !ok { - log.Error("IP not found, but should already be set", "ip", ip) + bc.log.Error("IP not found, but should already be set", "ip", ip) return false } return v.(uint) > bc.config.RateLimit @@ -579,7 +580,7 @@ func (bc *CaptchaProtect) registerRequest(ip string) { _, err = bc.rateCache.IncrementUint(ip, uint(1)) if err != nil { - log.Error("Unable to set rate cache", "ip", ip) + bc.log.Error("unable to set rate cache", "ip", ip) } } @@ -601,18 +602,22 @@ func (bc *CaptchaProtect) getClientIP(req *http.Request) (string, string) { depth-- } if ip == "" { - log.Debug("No non-exempt IPs in header. req.RemoteAddr", "ipDepth", bc.config.IPDepth, bc.config.IPForwardedHeader, req.Header.Get(bc.config.IPForwardedHeader)) + bc.log.Debug("No non-exempt IPs in header. req.RemoteAddr", "ipDepth", bc.config.IPDepth, bc.config.IPForwardedHeader, req.Header.Get(bc.config.IPForwardedHeader)) ip = req.RemoteAddr } } else { if bc.config.IPForwardedHeader != "" { - log.Debug("Received a blank header value. Defaulting to real IP") + bc.log.Debug("Received a blank header value. Defaulting to real IP") } ip = req.RemoteAddr } if strings.Contains(ip, ":") { - host, _, _ := net.SplitHostPort(ip) - ip = host + host, _, err := net.SplitHostPort(ip) + if err != nil { + bc.log.Warn("Failed to parse port from IP", "ip", ip, "err", err) + } else { + ip = host + } } return bc.ParseIp(ip) @@ -636,7 +641,7 @@ func (bc *CaptchaProtect) ParseIp(ip string) (string, string) { return ip, subnet.String() } - log.Warn("Unknown ip version", "ip", ip) + bc.log.Warn("Unknown ip version", "ip", ip) return ip, ip } @@ -680,14 +685,15 @@ func (bc *CaptchaProtect) SetExemptIps(exemptIps []*net.IPNet) { bc.exemptIps = exemptIps } -// log a warning if protected methods contains an invalid method -func (c *Config) ParseHttpMethods() { +// ParseHttpMethods logs a warning if protected methods contains an invalid method. +// Note: This method is called during initialization, validation is informational only. +func (c *Config) ParseHttpMethods(log *slog.Logger) { for _, method := range c.ProtectHttpMethods { switch method { case "GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT", "OPTIONS", "TRACE": continue default: - log.Warn("unknown http method", "method", method) + log.Warn("Unknown HTTP method", "method", method) } } } @@ -698,7 +704,7 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - log.Error("Unable to save state. Could not open or create file", "stateFile", bc.config.PersistentStateFile, "err", err) + bc.log.Error("unable to save state, could not open or create file", "stateFile", bc.config.PersistentStateFile, "err", err) return } // we made sure the file is writable, we can continue in our loop @@ -707,20 +713,20 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { for { select { case <-ticker.C: - log.Debug("Saving state") + bc.log.Debug("Saving state") state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) jsonData, err := json.Marshal(state) if err != nil { - log.Error("failed unmarshalling state data", "err", err) + bc.log.Error("failed to marshal state data", "err", err) break } err = os.WriteFile(bc.config.PersistentStateFile, jsonData, 0644) if err != nil { - log.Error("failed saving state data", "err", err) + bc.log.Error("failed to save state data", "err", err) } case <-ctx.Done(): - log.Debug("Context cancelled, stopping saveState") + bc.log.Debug("Context cancelled, stopping saveState") return } } @@ -729,14 +735,14 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { func (bc *CaptchaProtect) loadState() { fileContent, err := os.ReadFile(bc.config.PersistentStateFile) if err != nil || len(fileContent) == 0 { - log.Warn("Failed to load state file.", "err", err) + bc.log.Warn("failed to load state file", "err", err) return } var state state.State err = json.Unmarshal(fileContent, &state) if err != nil { - log.Error("Failed to unmarshal state file", "err", err) + bc.log.Error("failed to unmarshal state file", "err", err) return } @@ -752,7 +758,7 @@ func (bc *CaptchaProtect) loadState() { bc.verifiedCache.Set(k, v, lru.DefaultExpiration) } - log.Info("Loaded previous state") + bc.log.Info("Loaded previous state") } func (bc *CaptchaProtect) ChallengeOnPage() bool { diff --git a/main_test.go b/main_test.go index ccf1190..b2b2c33 100644 --- a/main_test.go +++ b/main_test.go @@ -2,22 +2,19 @@ package captcha_protect import ( "context" + "encoding/json" "log/slog" "net" "net/http" "net/http/httptest" "os" + "path/filepath" "regexp" "strings" "testing" + "time" ) -func init() { - log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) -} - func TestParseIp(t *testing.T) { tests := []struct { name string @@ -221,6 +218,8 @@ func TestRouteIsProtected(t *testing.T) { t.Run(tt.name+"_"+mode, func(t *testing.T) { c := CreateConfig() c.Mode = mode + c.SiteKey = "test-site-key" + c.SecretKey = "test-secret-key" c.ProtectFileExtensions = append(c.ProtectFileExtensions, tt.config.ProtectFileExtensions...) if useRegex { @@ -340,6 +339,8 @@ func TestRouteIsProtectedSuffix(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := CreateConfig() + c.SiteKey = "test-site-key" + c.SecretKey = "test-secret-key" c.ProtectRoutes = append(c.ProtectRoutes, tt.config.ProtectRoutes...) c.ExcludeRoutes = append(c.ExcludeRoutes, tt.config.ExcludeRoutes...) c.Mode = "suffix" @@ -448,6 +449,8 @@ func TestGetClientIP(t *testing.T) { req.RemoteAddr = tc.remoteAddr c := CreateConfig() + c.SiteKey = "test-site-key" + c.SecretKey = "test-secret-key" c.IPForwardedHeader = tc.config.IPForwardedHeader c.IPDepth = tc.config.IPDepth c.ProtectRoutes = []string{"/"} @@ -518,6 +521,8 @@ func TestServeHTTP(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + config.SiteKey = "test-site-key" + config.SecretKey = "test-secret-key" config.RateLimit = tc.rateLimit config.CaptchaProvider = "turnstile" config.ProtectRoutes = []string{"/"} @@ -573,6 +578,8 @@ func TestIsGoodUserAgent(t *testing.T) { {"Empty exempt list", []string{}, "Mozilla/5.0", false}, } config := CreateConfig() + config.SiteKey = "test-site-key" + config.SecretKey = "test-secret-key" config.ProtectRoutes = []string{"/"} for _, tc := range tests { config.ExemptUserAgents = tc.exemptUserAgents @@ -586,3 +593,608 @@ func TestIsGoodUserAgent(t *testing.T) { } } } + +func TestNewCaptchaProtectValidation(t *testing.T) { + tests := []struct { + name string + modifyConfig func(*Config) + expectError string + }{ + { + name: "Missing SiteKey", + modifyConfig: func(c *Config) { c.SiteKey = "" }, + expectError: "siteKey is required", + }, + { + name: "Missing SecretKey", + modifyConfig: func(c *Config) { c.SecretKey = "" }, + expectError: "secretKey is required", + }, + { + name: "Zero Window", + modifyConfig: func(c *Config) { c.Window = 0 }, + expectError: "window must be positive", + }, + { + name: "Negative Window", + modifyConfig: func(c *Config) { c.Window = -1 }, + expectError: "window must be positive", + }, + { + name: "Invalid CAPTCHA Provider", + modifyConfig: func(c *Config) { c.CaptchaProvider = "invalid" }, + expectError: "invalid captcha provider", + }, + { + name: "Invalid regex in ProtectRoutes", + modifyConfig: func(c *Config) { + c.Mode = "regex" + c.ProtectRoutes = []string{"[invalid"} + }, + expectError: "invalid regex in protectRoutes", + }, + { + name: "Invalid regex in ExcludeRoutes", + modifyConfig: func(c *Config) { + c.Mode = "regex" + c.ExcludeRoutes = []string{"[invalid"} + }, + expectError: "invalid regex in excludeRoutes", + }, + { + name: "ChallengeURL is /", + modifyConfig: func(c *Config) { c.ChallengeURL = "/" }, + expectError: "challenge URL can not be the entire site", + }, + { + name: "Invalid mode", + modifyConfig: func(c *Config) { c.Mode = "invalid" }, + expectError: "unknown mode", + }, + { + name: "Invalid IPv4 mask - too small", + modifyConfig: func(c *Config) { c.IPv4SubnetMask = 5 }, + expectError: "invalid ipv4 mask", + }, + { + name: "Invalid IPv4 mask - too large", + modifyConfig: func(c *Config) { c.IPv4SubnetMask = 33 }, + expectError: "invalid ipv4 mask", + }, + { + name: "Invalid IPv6 mask - too small", + modifyConfig: func(c *Config) { c.IPv6SubnetMask = 5 }, + expectError: "invalid ipv6 mask", + }, + { + name: "Invalid IPv6 mask - too large", + modifyConfig: func(c *Config) { c.IPv6SubnetMask = 200 }, + expectError: "invalid ipv6 mask", + }, + { + name: "Invalid CIDR in ExemptIPs", + modifyConfig: func(c *Config) { + c.ExemptIPs = []string{"not-a-cidr"} + }, + expectError: "error parsing cidr", + }, + { + name: "No protected routes in prefix mode", + modifyConfig: func(c *Config) { c.ProtectRoutes = []string{} }, + expectError: "you must protect at least one route", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := CreateConfig() + c.SiteKey = "test" + c.SecretKey = "test" + c.ProtectRoutes = []string{"/"} + tt.modifyConfig(c) + + _, err := NewCaptchaProtect(context.Background(), nil, c, "test") + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.expectError) + } else if !strings.Contains(err.Error(), tt.expectError) { + t.Errorf("Expected error containing %q, got %q", tt.expectError, err.Error()) + } + }) + } +} + +func TestRateLimiting(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.RateLimit = 5 + config.Window = 10 + + bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Fatal(err) + } + + subnet := "192.168.0.0" + + // Register 5 requests (at rate limit) + for i := 0; i < 5; i++ { + bc.registerRequest(subnet) + if bc.trippedRateLimit(subnet) { + t.Errorf("Should not trip rate limit at %d requests", i+1) + } + } + + // 6th request should trip + bc.registerRequest(subnet) + if !bc.trippedRateLimit(subnet) { + t.Error("Should trip rate limit after exceeding") + } + + // Different subnet should not be affected + differentSubnet := "10.0.0.0" + bc.registerRequest(differentSubnet) + if bc.trippedRateLimit(differentSubnet) { + t.Error("Different subnet should not be rate limited") + } +} + +func TestIsGoodBotWithParameters(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectParameters = "true" + config.GoodBots = []string{"googlebot.com"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Mock bot cache to simulate good bot + bc.botCache.Set("1.2.3.4", true, 1*time.Hour) + + tests := []struct { + name string + url string + expected bool + }{ + {"URL without params - good bot allowed", "http://example.com/page", true}, + {"URL with params - good bot blocked", "http://example.com/page?foo=bar", false}, + {"URL with multiple params - good bot blocked", "http://example.com/page?foo=bar&baz=qux", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.url, nil) + result := bc.isGoodBot(req, "1.2.3.4") + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestVerifiedCacheBypasses(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.RateLimit = 0 // Always challenge unless verified + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + req := httptest.NewRequest("GET", "http://example.com/test", nil) + clientIP := "1.2.3.4" + + // Should apply before verification + if !bc.shouldApply(req, clientIP) { + t.Error("Should apply protection before verification") + } + + // Add to verified cache + bc.verifiedCache.Set(clientIP, true, 1*time.Hour) + + // Should not apply after verification + if bc.shouldApply(req, clientIP) { + t.Error("Should not apply protection after verification") + } +} + +func TestStatsPage(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.EnableStatsPage = "true" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Add some test data + bc.rateCache.Set("192.168.0.0", uint(10), 1*time.Hour) + bc.verifiedCache.Set("1.2.3.4", true, 1*time.Hour) + + tests := []struct { + name string + clientIP string + expectedStatus int + }{ + {"Exempt IP can access", "192.168.1.1", http.StatusOK}, + {"Private IP can access", "10.0.0.1", http.StatusOK}, + {"Non-exempt IP forbidden", "1.2.3.4", http.StatusForbidden}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + + bc.serveStatsPage(rr, tt.clientIP) + + if rr.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedStatus == http.StatusOK { + // Verify JSON response + var stats map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &stats); err != nil { + t.Errorf("Failed to parse JSON: %v", err) + } + // Check that we have expected keys + if _, ok := stats["rate"]; !ok { + t.Error("Stats JSON missing 'rate' key") + } + if _, ok := stats["verified"]; !ok { + t.Error("Stats JSON missing 'verified' key") + } + } + }) + } +} + +func TestProtectHttpMethods(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectHttpMethods = []string{"GET", "POST"} + config.RateLimit = 0 // Always challenge + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + tests := []struct { + method string + expected bool + }{ + {"GET", true}, + {"POST", true}, + {"PUT", false}, + {"DELETE", false}, + {"PATCH", false}, + {"HEAD", false}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "http://example.com/test", nil) + result := bc.shouldApply(req, "1.2.3.4") + if result != tt.expected { + t.Errorf("Method %s: expected %v, got %v", tt.method, tt.expected, result) + } + }) + } +} + +func TestIsExtensionProtected(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectFileExtensions = []string{"html", "php", "json"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + tests := []struct { + path string + expected bool + }{ + {"/index.html", true}, + {"/api.json", true}, + {"/script.php", true}, + {"/style.css", false}, + {"/image.jpg", false}, + {"/no-extension", true}, // No extension = protected + {"/path/to/file.HTML", true}, // Case insensitive + {"/path/to/file.JSON", true}, + {"/path/to/file.Php", true}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := bc.isExtensionProtected(tt.path) + if result != tt.expected { + t.Errorf("Path %s: expected %v, got %v", tt.path, tt.expected, result) + } + }) + } +} + +func TestStatePersistence(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.PersistentStateFile = tmpFile + + // Don't pass a context to avoid starting background goroutines + bc1, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Add some state + bc1.rateCache.Set("192.168.0.0", uint(10), 1*time.Hour) + bc1.verifiedCache.Set("1.2.3.4", true, 1*time.Hour) + bc1.botCache.Set("5.6.7.8", false, 1*time.Hour) + + // Manually save state by writing the file directly + // This tests the state format without relying on the background goroutine + jsonData, _ := json.Marshal(map[string]interface{}{ + "rate": map[string]uint{ + "192.168.0.0": 10, + }, + "verified": map[string]bool{ + "1.2.3.4": true, + }, + "bots": map[string]bool{ + "5.6.7.8": false, + }, + }) + err := os.WriteFile(tmpFile, jsonData, 0644) + if err != nil { + t.Fatalf("Failed to write state file: %v", err) + } + + // Create new instance - should load state + bc2, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Check rate cache + val, found := bc2.rateCache.Get("192.168.0.0") + if !found || val.(uint) != 10 { + t.Error("Rate cache state not persisted correctly") + } + + // Check verified cache + _, found = bc2.verifiedCache.Get("1.2.3.4") + if !found { + t.Error("Verified cache state not persisted correctly") + } + + // Check bot cache + botVal, found := bc2.botCache.Get("5.6.7.8") + if !found || botVal.(bool) != false { + t.Error("Bot cache state not persisted correctly") + } +} + +func TestVerifyChallengePage(t *testing.T) { + tests := []struct { + name string + provider string + formValues map[string]string + mockResponse string + expectedStatus int + shouldSetCache bool + }{ + { + name: "Missing captcha response", + provider: "turnstile", + formValues: map[string]string{}, + expectedStatus: http.StatusBadRequest, + shouldSetCache: false, + }, + { + name: "Successful verification with destination", + provider: "turnstile", + formValues: map[string]string{ + "cf-turnstile-response": "valid-token", + "destination": "%2Fhome", + }, + mockResponse: `{"success":true}`, + expectedStatus: http.StatusFound, + shouldSetCache: true, + }, + { + name: "Successful verification without destination", + provider: "recaptcha", + formValues: map[string]string{ + "g-recaptcha-response": "valid-token", + }, + mockResponse: `{"success":true}`, + expectedStatus: http.StatusFound, + shouldSetCache: true, + }, + { + name: "Failed verification", + provider: "hcaptcha", + formValues: map[string]string{ + "h-captcha-response": "invalid-token", + }, + mockResponse: `{"success":false}`, + expectedStatus: http.StatusForbidden, + shouldSetCache: false, + }, + { + name: "Invalid destination URL", + provider: "turnstile", + formValues: map[string]string{ + "cf-turnstile-response": "valid-token", + "destination": "%ZZ", + }, + mockResponse: `{"success":true}`, + expectedStatus: http.StatusFound, + shouldSetCache: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tt.mockResponse)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = tt.provider + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Override the validation URL to point to our mock server + bc.captchaConfig.validate = mockServer.URL + + // Create request with form values + req := httptest.NewRequest("POST", "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + for k, v := range tt.formValues { + req.Form.Set(k, v) + } + + rr := httptest.NewRecorder() + clientIP := "1.2.3.4" + + status := bc.verifyChallengePage(rr, req, clientIP) + + if status != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, status) + } + + // Check if IP was added to verified cache + _, found := bc.verifiedCache.Get(clientIP) + if found != tt.shouldSetCache { + t.Errorf("Expected cache set=%v, got=%v", tt.shouldSetCache, found) + } + }) + } +} + +func TestVerifyChallengePageHTTPError(t *testing.T) { + // Test HTTP client error + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Set invalid URL to trigger HTTP error + bc.captchaConfig.validate = "http://invalid-domain-that-does-not-exist-12345.com" + + req := httptest.NewRequest("POST", "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "token") + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusInternalServerError { + t.Errorf("Expected status %d for HTTP error, got %d", http.StatusInternalServerError, status) + } +} + +func TestVerifyChallengePageInvalidJSON(t *testing.T) { + // Test invalid JSON response + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest("POST", "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "token") + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusInternalServerError { + t.Errorf("Expected status %d for JSON error, got %d", http.StatusInternalServerError, status) + } +} + +func TestServeHTTPMethodNotAllowed(t *testing.T) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ChallengeURL = "/challenge" + + bc, _ := NewCaptchaProtect(context.Background(), next, config, "test") + + req := httptest.NewRequest("DELETE", "http://example.com/challenge", nil) + req.RequestURI = "/challenge" + rr := httptest.NewRecorder() + + bc.ServeHTTP(rr, req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } +} + +func TestLoadStateInvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "invalid.json") + + // Write invalid JSON + if err := os.WriteFile(tmpFile, []byte(`{invalid json`), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.PersistentStateFile = tmpFile + + // Should not panic, just log error + bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Errorf("Should not fail on invalid state JSON: %v", err) + } + + // Caches should be empty + if bc.rateCache.ItemCount() != 0 { + t.Error("Rate cache should be empty after failed load") + } + + // Clean up the file before temp dir cleanup + _ = os.Remove(tmpFile) +} + +func TestParseHttpMethodsInvalid(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectHttpMethods = []string{"GET", "INVALID_METHOD", "POST"} + + // Should not fail, just log warning + _, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Errorf("Should not fail on invalid HTTP method: %v", err) + } +}