From 0eb484cd363bb8fe276c75abfd4f781a65737c46 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 15:51:24 -0400 Subject: [PATCH 1/7] Only record response body in log Fix header/status order fix url shadowing just log tmpl errors since the status order is tricky typo --- main.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 530d97e..d7c9fc4 100644 --- a/main.go +++ b/main.go @@ -322,8 +322,8 @@ func (bc *CaptchaProtect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { bc.serveChallengePage(rw, encodedURI) return } - url := fmt.Sprintf("%s?destination=%s", bc.config.ChallengeURL, encodedURI) - http.Redirect(rw, req, url, http.StatusFound) + redirectURL := fmt.Sprintf("%s?destination=%s", bc.config.ChallengeURL, encodedURI) + http.Redirect(rw, req, redirectURL, http.StatusFound) } func (bc *CaptchaProtect) serveChallengePage(rw http.ResponseWriter, destination string) { @@ -344,7 +344,8 @@ func (bc *CaptchaProtect) serveChallengePage(rw http.ResponseWriter, destination err := bc.tmpl.Execute(rw, d) if err != nil { log.Error("Unable to execute go template", "tmpl", bc.config.ChallengeTmpl, "err", err) - http.Error(rw, "Internal error", http.StatusInternalServerError) + // Can't change status code here, already written + _, _ = rw.Write([]byte("\n")) } } @@ -360,7 +361,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. body.Add("response", response) resp, err := http.PostForm(bc.captchaConfig.validate, body) if err != nil { - log.Error("Unable to validate captcha", "url", bc.captchaConfig.validate, "body", body, "err", err) + log.Error("Unable to validate captcha", "url", bc.captchaConfig.validate, "response", response, "err", err) http.Error(rw, "Internal error", http.StatusInternalServerError) return http.StatusInternalServerError } @@ -408,11 +409,11 @@ func (bc *CaptchaProtect) serveStatsPage(rw http.ResponseWriter, ip string) { return } - rw.WriteHeader(http.StatusOK) rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) _, err = rw.Write(jsonData) if err != nil { - log.Error("failed to write JSON on stats reques", "err", err) + log.Error("failed to write JSON on stats request", "err", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } From e60b2d7f9a276b830125f9324659a9a2558f0d7b Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 15:59:02 -0400 Subject: [PATCH 2/7] General code cleanup --- main.go | 179 ++++++++++++++++++++++++++------------------------- main_test.go | 16 +++-- 2 files changed, 102 insertions(+), 93 deletions(-) diff --git a/main.go b/main.go index d7c9fc4..7854a82 100644 --- a/main.go +++ b/main.go @@ -23,17 +23,14 @@ import ( lru "github.com/patrickmn/go-cache" ) -var ( - log *slog.Logger -) - type Config struct { - RateLimit uint `json:"rateLimit"` - Window int64 `json:"window"` - IPv4SubnetMask int `json:"ipv4subnetMask"` - IPv6SubnetMask int `json:"ipv6subnetMask"` - IPForwardedHeader string `json:"ipForwardedHeader"` - IPDepth int `json:"ipDepth"` + RateLimit uint `json:"rateLimit"` + Window int64 `json:"window"` + IPv4SubnetMask int `json:"ipv4subnetMask"` + IPv6SubnetMask int `json:"ipv6subnetMask"` + IPForwardedHeader string `json:"ipForwardedHeader"` + IPDepth int `json:"ipDepth"` + // ProtectParameters is a string instead of bool due to Traefik's label parsing limitations ProtectParameters string `json:"protectParameters"` ProtectRoutes []string `json:"protectRoutes"` ExcludeRoutes []string `json:"excludeRoutes"` @@ -48,16 +45,19 @@ type Config struct { CaptchaProvider string `json:"captchaProvider"` SiteKey string `json:"siteKey"` SecretKey string `json:"secretKey"` - EnableStatsPage string `json:"enableStatsPage"` - LogLevel string `json:"loglevel,omitempty"` - PersistentStateFile string `json:"persistentStateFile"` - Mode string `json:"mode"` + // EnableStatsPage is a string instead of bool due to Traefik's label parsing limitations + EnableStatsPage string `json:"enableStatsPage"` + LogLevel string `json:"loglevel,omitempty"` + PersistentStateFile string `json:"persistentStateFile"` + Mode string `json:"mode"` } type CaptchaProtect struct { next http.Handler name string config *Config + log *slog.Logger + httpClient *http.Client rateCache *lru.Cache verifiedCache *lru.Cache botCache *lru.Cache @@ -111,7 +111,18 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, name string) (*CaptchaProtect, error) { - log = plog.New(config.LogLevel) + log := plog.New(config.LogLevel) + + // Validate required config + if config.SiteKey == "" { + return nil, fmt.Errorf("siteKey is required") + } + if config.SecretKey == "" { + return nil, fmt.Errorf("secretKey is required") + } + if config.Window <= 0 { + return nil, fmt.Errorf("window must be positive, got %d", config.Window) + } expiration := time.Duration(config.Window) * time.Second log.Debug("Captcha config", "config", config) @@ -164,11 +175,11 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n "HEAD", } } - config.ParseHttpMethods() + config.ParseHttpMethods(log) var tmpl *template.Template if _, err := os.Stat(config.ChallengeTmpl); os.IsNotExist(err) { - log.Warn("Unable to find template file. Using default template.", "challengeTmpl", config.ChallengeTmpl) + log.Warn("Unable to find template file. Using default template", "challengeTmpl", config.ChallengeTmpl) ts := helper.GetDefaultTmpl() tmpl, err = template.New("challenge").Parse(ts) if err != nil { @@ -183,6 +194,8 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n } } + // Always protect HTML files by default to ensure the main content is rate-limited. + // This prevents users from accidentally excluding HTML, which would break the protection. if !slices.Contains(config.ProtectFileExtensions, "html") { config.ProtectFileExtensions = append(config.ProtectFileExtensions, "html") } @@ -206,9 +219,13 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n } bc := CaptchaProtect{ - next: next, - name: name, - config: config, + next: next, + name: name, + config: config, + log: log, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, rateCache: lru.New(expiration, 1*time.Minute), botCache: lru.New(expiration, 1*time.Hour), verifiedCache: lru.New(expiration, 1*time.Hour), @@ -269,7 +286,7 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n go bc.saveState(childCtx) go func() { <-ctx.Done() - log.Debug("Context canceled, calling child cancel...") + bc.log.Debug("Context canceled, calling child cancel") cancel() }() } @@ -283,24 +300,24 @@ func (bc *CaptchaProtect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if challengeOnPage && req.Method == http.MethodPost { if req.URL.Query().Get("challenge") != "" { statusCode := bc.verifyChallengePage(rw, req, clientIP) - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) return } } else if req.URL.Path == bc.config.ChallengeURL { switch req.Method { case http.MethodGet: destination := req.URL.Query().Get("destination") - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "destination", destination, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "destination", destination, "useragent", req.UserAgent()) bc.serveChallengePage(rw, destination) case http.MethodPost: statusCode := bc.verifyChallengePage(rw, req, clientIP) - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "status", statusCode, "useragent", req.UserAgent()) default: http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed) } return } else if req.URL.Path == "/captcha-protect/stats" && bc.config.EnableStatsPage == "true" { - log.Info("Captcha stats", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) + bc.log.Info("Captcha stats", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) bc.serveStatsPage(rw, clientIP) return } @@ -318,7 +335,7 @@ func (bc *CaptchaProtect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { encodedURI := url.QueryEscape(req.RequestURI) if bc.ChallengeOnPage() { - log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) + bc.log.Info("Captcha challenge", "clientIP", clientIP, "method", req.Method, "path", req.URL.Path, "useragent", req.UserAgent()) bc.serveChallengePage(rw, encodedURI) return } @@ -343,7 +360,7 @@ func (bc *CaptchaProtect) serveChallengePage(rw http.ResponseWriter, destination err := bc.tmpl.Execute(rw, d) if err != nil { - log.Error("Unable to execute go template", "tmpl", bc.config.ChallengeTmpl, "err", err) + bc.log.Error("unable to execute go template", "tmpl", bc.config.ChallengeTmpl, "err", err) // Can't change status code here, already written _, _ = rw.Write([]byte("\n")) } @@ -359,9 +376,9 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. var body = url.Values{} body.Add("secret", bc.config.SecretKey) body.Add("response", response) - resp, err := http.PostForm(bc.captchaConfig.validate, body) + resp, err := bc.httpClient.PostForm(bc.captchaConfig.validate, body) if err != nil { - log.Error("Unable to validate captcha", "url", bc.captchaConfig.validate, "response", response, "err", err) + bc.log.Error("unable to validate captcha", "url", bc.captchaConfig.validate, "err", err) http.Error(rw, "Internal error", http.StatusInternalServerError) return http.StatusInternalServerError } @@ -370,7 +387,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. var captchaResponse captchaResponse err = json.NewDecoder(resp.Body).Decode(&captchaResponse) if err != nil { - log.Error("Unable to unmarshal captcha response", "url", bc.captchaConfig.validate, "err", err) + bc.log.Error("unable to unmarshal captcha response", "url", bc.captchaConfig.validate, "err", err) http.Error(rw, "Internal error", http.StatusInternalServerError) return http.StatusInternalServerError } @@ -382,7 +399,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. } u, err := url.QueryUnescape(destination) if err != nil { - log.Error("Unable to unescape destination", "destination", destination, "err", err) + bc.log.Error("unable to unescape destination", "destination", destination, "err", err) u = "/" } http.Redirect(rw, req, u, http.StatusFound) @@ -404,7 +421,7 @@ func (bc *CaptchaProtect) serveStatsPage(rw http.ResponseWriter, ip string) { state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) jsonData, err := json.Marshal(state) if err != nil { - log.Error("failed to marshal JSON", "err", err) + bc.log.Error("failed to marshal JSON", "err", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } @@ -413,7 +430,7 @@ func (bc *CaptchaProtect) serveStatsPage(rw http.ResponseWriter, ip string) { rw.WriteHeader(http.StatusOK) _, err = rw.Write(jsonData) if err != nil { - log.Error("failed to write JSON on stats request", "err", err) + bc.log.Error("failed to write JSON on stats request", "err", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } @@ -453,6 +470,22 @@ func (bc *CaptchaProtect) shouldApply(req *http.Request, clientIP string) bool { return bc.RouteIsProtectedPrefix(req.URL.Path) } +// isExtensionProtected checks if a file extension should be protected based on the configured list. +// Returns true if the path has no extension (likely HTML) or if the extension matches the protected list. +func (bc *CaptchaProtect) isExtensionProtected(path string) bool { + ext := filepath.Ext(path) + ext = strings.TrimPrefix(ext, ".") + if ext == "" { + return true + } + for _, protectedExt := range bc.config.ProtectFileExtensions { + if strings.EqualFold(ext, protectedExt) { + return true + } + } + return false +} + func (bc *CaptchaProtect) RouteIsProtectedPrefix(path string) bool { protected: for _, route := range bc.config.ProtectRoutes { @@ -467,19 +500,7 @@ protected: } } - // if this path isn't a file, go ahead and mark this path as protected - ext := filepath.Ext(path) - ext = strings.TrimPrefix(ext, ".") - if ext == "" { - return true - } - - // if we have a file extension, see if we should protect this file extension type - for _, protectedExtensions := range bc.config.ProtectFileExtensions { - if strings.EqualFold(ext, protectedExtensions) { - return true - } - } + return bc.isExtensionProtected(path) } return false @@ -504,18 +525,7 @@ protected: } } - // if this path isn't a file, go ahead and mark this path as protected - ext = strings.TrimPrefix(ext, ".") - if ext == "" { - return true - } - - // if we have a file extension, see if we should protect this file extension type - for _, protectedExtensions := range bc.config.ProtectFileExtensions { - if strings.EqualFold(ext, protectedExtensions) { - return true - } - } + return bc.isExtensionProtected(path) } return false @@ -547,17 +557,7 @@ protected: } } - ext := filepath.Ext(path) - ext = strings.TrimPrefix(ext, ".") - if ext == "" { - return true - } - - for _, protectedExtension := range bc.config.ProtectFileExtensions { - if strings.EqualFold(ext, protectedExtension) { - return true - } - } + return bc.isExtensionProtected(path) } return false @@ -566,7 +566,7 @@ protected: func (bc *CaptchaProtect) trippedRateLimit(ip string) bool { v, ok := bc.rateCache.Get(ip) if !ok { - log.Error("IP not found, but should already be set", "ip", ip) + bc.log.Error("IP not found, but should already be set", "ip", ip) return false } return v.(uint) > bc.config.RateLimit @@ -580,7 +580,7 @@ func (bc *CaptchaProtect) registerRequest(ip string) { _, err = bc.rateCache.IncrementUint(ip, uint(1)) if err != nil { - log.Error("Unable to set rate cache", "ip", ip) + bc.log.Error("unable to set rate cache", "ip", ip) } } @@ -602,18 +602,22 @@ func (bc *CaptchaProtect) getClientIP(req *http.Request) (string, string) { depth-- } if ip == "" { - log.Debug("No non-exempt IPs in header. req.RemoteAddr", "ipDepth", bc.config.IPDepth, bc.config.IPForwardedHeader, req.Header.Get(bc.config.IPForwardedHeader)) + bc.log.Debug("No non-exempt IPs in header. req.RemoteAddr", "ipDepth", bc.config.IPDepth, bc.config.IPForwardedHeader, req.Header.Get(bc.config.IPForwardedHeader)) ip = req.RemoteAddr } } else { if bc.config.IPForwardedHeader != "" { - log.Debug("Received a blank header value. Defaulting to real IP") + bc.log.Debug("Received a blank header value. Defaulting to real IP") } ip = req.RemoteAddr } if strings.Contains(ip, ":") { - host, _, _ := net.SplitHostPort(ip) - ip = host + host, _, err := net.SplitHostPort(ip) + if err != nil { + bc.log.Warn("Failed to parse port from IP", "ip", ip, "err", err) + } else { + ip = host + } } return bc.ParseIp(ip) @@ -637,7 +641,7 @@ func (bc *CaptchaProtect) ParseIp(ip string) (string, string) { return ip, subnet.String() } - log.Warn("Unknown ip version", "ip", ip) + bc.log.Warn("Unknown ip version", "ip", ip) return ip, ip } @@ -681,14 +685,15 @@ func (bc *CaptchaProtect) SetExemptIps(exemptIps []*net.IPNet) { bc.exemptIps = exemptIps } -// log a warning if protected methods contains an invalid method -func (c *Config) ParseHttpMethods() { +// ParseHttpMethods logs a warning if protected methods contains an invalid method. +// Note: This method is called during initialization, validation is informational only. +func (c *Config) ParseHttpMethods(log *slog.Logger) { for _, method := range c.ProtectHttpMethods { switch method { case "GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT", "OPTIONS", "TRACE": continue default: - log.Warn("unknown http method", "method", method) + log.Warn("Unknown HTTP method", "method", method) } } } @@ -699,7 +704,7 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - log.Error("Unable to save state. Could not open or create file", "stateFile", bc.config.PersistentStateFile, "err", err) + bc.log.Error("unable to save state, could not open or create file", "stateFile", bc.config.PersistentStateFile, "err", err) return } // we made sure the file is writable, we can continue in our loop @@ -708,20 +713,20 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { for { select { case <-ticker.C: - log.Debug("Saving state") + bc.log.Debug("Saving state") state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) jsonData, err := json.Marshal(state) if err != nil { - log.Error("failed unmarshalling state data", "err", err) + bc.log.Error("failed to marshal state data", "err", err) break } err = os.WriteFile(bc.config.PersistentStateFile, jsonData, 0644) if err != nil { - log.Error("failed saving state data", "err", err) + bc.log.Error("failed to save state data", "err", err) } case <-ctx.Done(): - log.Debug("Context cancelled, stopping saveState") + bc.log.Debug("Context cancelled, stopping saveState") return } } @@ -730,14 +735,14 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { func (bc *CaptchaProtect) loadState() { fileContent, err := os.ReadFile(bc.config.PersistentStateFile) if err != nil || len(fileContent) == 0 { - log.Warn("Failed to load state file.", "err", err) + bc.log.Warn("failed to load state file", "err", err) return } var state state.State err = json.Unmarshal(fileContent, &state) if err != nil { - log.Error("Failed to unmarshal state file", "err", err) + bc.log.Error("failed to unmarshal state file", "err", err) return } @@ -753,7 +758,7 @@ func (bc *CaptchaProtect) loadState() { bc.verifiedCache.Set(k, v, lru.DefaultExpiration) } - log.Info("Loaded previous state") + bc.log.Info("Loaded previous state") } func (bc *CaptchaProtect) ChallengeOnPage() bool { diff --git a/main_test.go b/main_test.go index ccf1190..e2be5fa 100644 --- a/main_test.go +++ b/main_test.go @@ -12,12 +12,6 @@ import ( "testing" ) -func init() { - log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) -} - func TestParseIp(t *testing.T) { tests := []struct { name string @@ -221,6 +215,8 @@ func TestRouteIsProtected(t *testing.T) { t.Run(tt.name+"_"+mode, func(t *testing.T) { c := CreateConfig() c.Mode = mode + c.SiteKey = "test-site-key" + c.SecretKey = "test-secret-key" c.ProtectFileExtensions = append(c.ProtectFileExtensions, tt.config.ProtectFileExtensions...) if useRegex { @@ -340,6 +336,8 @@ func TestRouteIsProtectedSuffix(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := CreateConfig() + c.SiteKey = "test-site-key" + c.SecretKey = "test-secret-key" c.ProtectRoutes = append(c.ProtectRoutes, tt.config.ProtectRoutes...) c.ExcludeRoutes = append(c.ExcludeRoutes, tt.config.ExcludeRoutes...) c.Mode = "suffix" @@ -448,6 +446,8 @@ func TestGetClientIP(t *testing.T) { req.RemoteAddr = tc.remoteAddr c := CreateConfig() + c.SiteKey = "test-site-key" + c.SecretKey = "test-secret-key" c.IPForwardedHeader = tc.config.IPForwardedHeader c.IPDepth = tc.config.IPDepth c.ProtectRoutes = []string{"/"} @@ -518,6 +518,8 @@ func TestServeHTTP(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + config.SiteKey = "test-site-key" + config.SecretKey = "test-secret-key" config.RateLimit = tc.rateLimit config.CaptchaProvider = "turnstile" config.ProtectRoutes = []string{"/"} @@ -573,6 +575,8 @@ func TestIsGoodUserAgent(t *testing.T) { {"Empty exempt list", []string{}, "Mozilla/5.0", false}, } config := CreateConfig() + config.SiteKey = "test-site-key" + config.SecretKey = "test-secret-key" config.ProtectRoutes = []string{"/"} for _, tc := range tests { config.ExemptUserAgents = tc.exemptUserAgents From a66a747395e70d1f31f26077703a10069a218468 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 16:08:08 -0400 Subject: [PATCH 3/7] more tests --- .github/workflows/lint-test.yml | 2 +- go.mod | 2 +- main_test.go | 384 ++++++++++++++++++++++++++++++++ 3 files changed, 386 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 0b552d4..2af769d 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6 with: - go-version: '>=1.24.0' + go-version: ">=1.25.0" - name: golangci-lint uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8 diff --git a/go.mod b/go.mod index 1b3534a..4f45b71 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/libops/captcha-protect -go 1.24.0 +go 1.25.0 require github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/main_test.go b/main_test.go index e2be5fa..51dac98 100644 --- a/main_test.go +++ b/main_test.go @@ -2,14 +2,17 @@ package captcha_protect import ( "context" + "encoding/json" "log/slog" "net" "net/http" "net/http/httptest" "os" + "path/filepath" "regexp" "strings" "testing" + "time" ) func TestParseIp(t *testing.T) { @@ -590,3 +593,384 @@ func TestIsGoodUserAgent(t *testing.T) { } } } + +func TestNewCaptchaProtectValidation(t *testing.T) { + tests := []struct { + name string + modifyConfig func(*Config) + expectError string + }{ + { + name: "Missing SiteKey", + modifyConfig: func(c *Config) { c.SiteKey = "" }, + expectError: "siteKey is required", + }, + { + name: "Missing SecretKey", + modifyConfig: func(c *Config) { c.SecretKey = "" }, + expectError: "secretKey is required", + }, + { + name: "Zero Window", + modifyConfig: func(c *Config) { c.Window = 0 }, + expectError: "window must be positive", + }, + { + name: "Negative Window", + modifyConfig: func(c *Config) { c.Window = -1 }, + expectError: "window must be positive", + }, + { + name: "Invalid CAPTCHA Provider", + modifyConfig: func(c *Config) { c.CaptchaProvider = "invalid" }, + expectError: "invalid captcha provider", + }, + { + name: "Invalid regex in ProtectRoutes", + modifyConfig: func(c *Config) { + c.Mode = "regex" + c.ProtectRoutes = []string{"[invalid"} + }, + expectError: "invalid regex in protectRoutes", + }, + { + name: "Invalid regex in ExcludeRoutes", + modifyConfig: func(c *Config) { + c.Mode = "regex" + c.ExcludeRoutes = []string{"[invalid"} + }, + expectError: "invalid regex in excludeRoutes", + }, + { + name: "ChallengeURL is /", + modifyConfig: func(c *Config) { c.ChallengeURL = "/" }, + expectError: "challenge URL can not be the entire site", + }, + { + name: "Invalid mode", + modifyConfig: func(c *Config) { c.Mode = "invalid" }, + expectError: "unknown mode", + }, + { + name: "Invalid IPv4 mask - too small", + modifyConfig: func(c *Config) { c.IPv4SubnetMask = 5 }, + expectError: "invalid ipv4 mask", + }, + { + name: "Invalid IPv4 mask - too large", + modifyConfig: func(c *Config) { c.IPv4SubnetMask = 33 }, + expectError: "invalid ipv4 mask", + }, + { + name: "Invalid IPv6 mask - too small", + modifyConfig: func(c *Config) { c.IPv6SubnetMask = 5 }, + expectError: "invalid ipv6 mask", + }, + { + name: "Invalid IPv6 mask - too large", + modifyConfig: func(c *Config) { c.IPv6SubnetMask = 200 }, + expectError: "invalid ipv6 mask", + }, + { + name: "Invalid CIDR in ExemptIPs", + modifyConfig: func(c *Config) { + c.ExemptIPs = []string{"not-a-cidr"} + }, + expectError: "error parsing cidr", + }, + { + name: "No protected routes in prefix mode", + modifyConfig: func(c *Config) { c.ProtectRoutes = []string{} }, + expectError: "you must protect at least one route", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := CreateConfig() + c.SiteKey = "test" + c.SecretKey = "test" + c.ProtectRoutes = []string{"/"} + tt.modifyConfig(c) + + _, err := NewCaptchaProtect(context.Background(), nil, c, "test") + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.expectError) + } else if !strings.Contains(err.Error(), tt.expectError) { + t.Errorf("Expected error containing %q, got %q", tt.expectError, err.Error()) + } + }) + } +} + +func TestRateLimiting(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.RateLimit = 5 + config.Window = 10 + + bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Fatal(err) + } + + subnet := "192.168.0.0" + + // Register 5 requests (at rate limit) + for i := 0; i < 5; i++ { + bc.registerRequest(subnet) + if bc.trippedRateLimit(subnet) { + t.Errorf("Should not trip rate limit at %d requests", i+1) + } + } + + // 6th request should trip + bc.registerRequest(subnet) + if !bc.trippedRateLimit(subnet) { + t.Error("Should trip rate limit after exceeding") + } + + // Different subnet should not be affected + differentSubnet := "10.0.0.0" + bc.registerRequest(differentSubnet) + if bc.trippedRateLimit(differentSubnet) { + t.Error("Different subnet should not be rate limited") + } +} + +func TestIsGoodBotWithParameters(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectParameters = "true" + config.GoodBots = []string{"googlebot.com"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Mock bot cache to simulate good bot + bc.botCache.Set("1.2.3.4", true, 1*time.Hour) + + tests := []struct { + name string + url string + expected bool + }{ + {"URL without params - good bot allowed", "http://example.com/page", true}, + {"URL with params - good bot blocked", "http://example.com/page?foo=bar", false}, + {"URL with multiple params - good bot blocked", "http://example.com/page?foo=bar&baz=qux", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.url, nil) + result := bc.isGoodBot(req, "1.2.3.4") + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestVerifiedCacheBypasses(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.RateLimit = 0 // Always challenge unless verified + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + req := httptest.NewRequest("GET", "http://example.com/test", nil) + clientIP := "1.2.3.4" + + // Should apply before verification + if !bc.shouldApply(req, clientIP) { + t.Error("Should apply protection before verification") + } + + // Add to verified cache + bc.verifiedCache.Set(clientIP, true, 1*time.Hour) + + // Should not apply after verification + if bc.shouldApply(req, clientIP) { + t.Error("Should not apply protection after verification") + } +} + +func TestStatsPage(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.EnableStatsPage = "true" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Add some test data + bc.rateCache.Set("192.168.0.0", uint(10), 1*time.Hour) + bc.verifiedCache.Set("1.2.3.4", true, 1*time.Hour) + + tests := []struct { + name string + clientIP string + expectedStatus int + }{ + {"Exempt IP can access", "192.168.1.1", http.StatusOK}, + {"Private IP can access", "10.0.0.1", http.StatusOK}, + {"Non-exempt IP forbidden", "1.2.3.4", http.StatusForbidden}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + + bc.serveStatsPage(rr, tt.clientIP) + + if rr.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedStatus == http.StatusOK { + // Verify JSON response + var stats map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &stats); err != nil { + t.Errorf("Failed to parse JSON: %v", err) + } + // Check that we have expected keys + if _, ok := stats["rate"]; !ok { + t.Error("Stats JSON missing 'rate' key") + } + if _, ok := stats["verified"]; !ok { + t.Error("Stats JSON missing 'verified' key") + } + } + }) + } +} + +func TestProtectHttpMethods(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectHttpMethods = []string{"GET", "POST"} + config.RateLimit = 0 // Always challenge + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + tests := []struct { + method string + expected bool + }{ + {"GET", true}, + {"POST", true}, + {"PUT", false}, + {"DELETE", false}, + {"PATCH", false}, + {"HEAD", false}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "http://example.com/test", nil) + result := bc.shouldApply(req, "1.2.3.4") + if result != tt.expected { + t.Errorf("Method %s: expected %v, got %v", tt.method, tt.expected, result) + } + }) + } +} + +func TestIsExtensionProtected(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectFileExtensions = []string{"html", "php", "json"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + tests := []struct { + path string + expected bool + }{ + {"/index.html", true}, + {"/api.json", true}, + {"/script.php", true}, + {"/style.css", false}, + {"/image.jpg", false}, + {"/no-extension", true}, // No extension = protected + {"/path/to/file.HTML", true}, // Case insensitive + {"/path/to/file.JSON", true}, + {"/path/to/file.Php", true}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := bc.isExtensionProtected(tt.path) + if result != tt.expected { + t.Errorf("Path %s: expected %v, got %v", tt.path, tt.expected, result) + } + }) + } +} + +func TestStatePersistence(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.PersistentStateFile = tmpFile + + // Don't pass a context to avoid starting background goroutines + bc1, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Add some state + bc1.rateCache.Set("192.168.0.0", uint(10), 1*time.Hour) + bc1.verifiedCache.Set("1.2.3.4", true, 1*time.Hour) + bc1.botCache.Set("5.6.7.8", false, 1*time.Hour) + + // Manually save state by writing the file directly + // This tests the state format without relying on the background goroutine + jsonData, _ := json.Marshal(map[string]interface{}{ + "rate": map[string]uint{ + "192.168.0.0": 10, + }, + "verified": map[string]bool{ + "1.2.3.4": true, + }, + "bots": map[string]bool{ + "5.6.7.8": false, + }, + }) + err := os.WriteFile(tmpFile, jsonData, 0644) + if err != nil { + t.Fatalf("Failed to write state file: %v", err) + } + + // Create new instance - should load state + bc2, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Check rate cache + val, found := bc2.rateCache.Get("192.168.0.0") + if !found || val.(uint) != 10 { + t.Error("Rate cache state not persisted correctly") + } + + // Check verified cache + _, found = bc2.verifiedCache.Get("1.2.3.4") + if !found { + t.Error("Verified cache state not persisted correctly") + } + + // Check bot cache + botVal, found := bc2.botCache.Get("5.6.7.8") + if !found || botVal.(bool) != false { + t.Error("Bot cache state not persisted correctly") + } +} From bde8a1e38f48c3678efe88b95980b5c7e3a1577a Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 16:10:42 -0400 Subject: [PATCH 4/7] f --- CLAUDE.md | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++ main_test.go | 2 +- 2 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..19d5ef9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,171 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +This is a Traefik middleware plugin that protects websites from bot traffic by challenging individual IPs with CAPTCHAs when traffic spikes are detected from their subnet. The plugin supports Cloudflare Turnstile, Google reCAPTCHA, and hCaptcha. + +**Key concept**: Instead of rate limiting individual IPs, this plugin monitors traffic at the subnet level (e.g., /16 for IPv4, /64 for IPv6) and only challenges specific IPs when their entire subnet exceeds a configured rate limit. + +## Architecture + +### Core Components + +- **main.go** (`main.go:1-761`): Contains the entire middleware implementation in a single file + - `CaptchaProtect` struct: Main middleware handler with rate limiting, bot detection, and challenge serving + - `Config` struct: Configuration from Traefik labels + - Three in-memory caches (using `github.com/patrickmn/go-cache`): + - `rateCache`: Tracks request counts per subnet + - `verifiedCache`: Stores IPs that have passed challenges (24h default TTL) + - `botCache`: Caches reverse DNS lookups for bot verification + +### Request Flow Decision Tree + +The middleware follows this decision order (see `shouldApply()` at `main.go:422-453`): + +1. Check if HTTP method is protected (default: GET, HEAD) +2. Check if IP already verified (passed challenge recently) +3. Check if IP is in exemptIps (private ranges + configured exemptions) +4. Check if IP is a good bot (reverse DNS matches goodBots list) +5. Check if user agent is exempt +6. Check if route matches protection rules (prefix/suffix/regex matching) +7. If protected, increment subnet counter and check rate limit +8. If rate limit exceeded, serve challenge or redirect to challenge page + +### Internal Packages + +- **internal/helper/**: Utility functions + - `ip.go`: IP parsing, CIDR matching, reverse DNS lookups for bot verification + - `tmpl.go`: Default challenge template (embedded fallback) +- **internal/log/**: Structured logging with slog +- **internal/state/**: State serialization for persistent storage across restarts + +### Challenge Modes + +Two modes for serving challenges: + +1. **Redirect mode** (default): `challengeURL: "/challenge"` - Redirects to dedicated challenge page +2. **Inline mode**: `challengeURL: ""` - Serves challenge on the same page that triggered rate limit + +## Development Commands + +### Running Tests + +```bash +# Run unit tests +go test -v -race ./... + +# Run single test +go test -v -race -run TestParseIp + +# Run integration tests (requires Docker) +cd ci && go run test.go +``` + +### Linting and Formatting + +```bash +# Run golangci-lint locally +golangci-lint run + +# Format code +gofmt -w . + +# Check if go.mod is tidy +go mod tidy && git diff --exit-code go.mod go.sum + +# Update vendored dependencies +go mod vendor +``` + +### CI/CD + +The GitHub Actions workflow (`.github/workflows/lint-test.yml`) runs on every push: +1. golangci-lint +2. Validates `.traefik.yml` with yq +3. Checks `go mod tidy` and `go mod vendor` are up-to-date +4. Runs unit tests with race detector +5. Runs integration tests against Traefik v2.11, v3.0, v3.1, v3.2, v3.3, v3.4 + +### Integration Testing + +The `ci/` directory contains a full integration test: +- Spins up Traefik + nginx with docker-compose +- Generates 100 unique public IPs from different subnets +- Makes parallel requests to verify rate limiting behavior +- Tests state persistence across container restarts +- Validates stats endpoint JSON + +To run: `cd ci && go run test.go` + +## Key Implementation Details + +### Route Matching Modes + +Three modes configured via `mode` parameter (defaults to "prefix"): + +1. **prefix**: Fast string prefix matching (`strings.HasPrefix`) +2. **suffix**: Matches route suffixes (useful for specific endpoints) +3. **regex**: Full regex support (13x slower than prefix, use only when needed) + +Regex is significantly slower (~41ns vs ~3.4ns per operation) - see README benchmark section. + +### IP Subnet Calculation + +- IPv4: Masks IPs to configured subnet (default /16 means `192.168.x.x` → `192.168.0.0`) +- IPv6: Default /64 subnet mask +- Implementation at `main.go:621-642` + +### State Persistence + +When `persistentStateFile` is configured: +- State saves every 1 minute to JSON file (`saveState()` at `main.go:695-727`) +- On startup, loads previous state from file (`loadState()` at `main.go:729-756`) +- Contains: rate limits per subnet, bot verification cache, verified IPs + +### Good Bot Detection + +To avoid SEO impact, the plugin allows "good bots" to bypass rate limits: +- Performs reverse DNS lookup on IP (`internal/helper/ip.go`) +- Checks if hostname ends with configured second-level domain (e.g., "googlebot.com") +- Results cached in `botCache` to avoid repeated DNS lookups +- Optional `protectParameters: "true"` forces rate limiting even for good bots if URL contains query parameters + +### File Extension Filtering + +By default, only HTML files are rate-limited (to prevent CSS/JS/images from consuming rate limit quota). Configure `protectFileExtensions` to add more file types. + +## Configuration + +Configuration comes from Traefik labels. See `.traefik.yml` for the plugin manifest. + +Key defaults: +- `rateLimit: 20` requests per subnet +- `window: 86400` seconds (24 hours) +- `ipv4subnetMask: 16` (/16 = 65,536 IPs) +- `ipv6subnetMask: 64` +- `challengeStatusCode: 200` (or 429 for inline challenges) + +## Testing Strategy + +Unit tests (`main_test.go`) cover: +- IP parsing and subnet masking +- Route protection logic (prefix/suffix/regex) +- Client IP extraction from forwarded headers with depth traversal +- User agent exemption matching +- Challenge page serving with different status codes + +Integration tests (`ci/test.go`) verify: +- Full request lifecycle with real Traefik/nginx +- Rate limiting behavior across multiple subnets +- State persistence across container restarts +- Stats endpoint functionality + +## Traefik Plugin Constraints + +- Must implement `http.Handler` interface +- Entry point: `New(ctx context.Context, next http.Handler, config *Config, name string)` +- Plugin loaded via Traefik's `--experimental.plugins` flag +- No external state allowed (must use in-memory caches or file persistence) +- Must be compatible with Traefik v2.11.1+ diff --git a/main_test.go b/main_test.go index 51dac98..ad84bfd 100644 --- a/main_test.go +++ b/main_test.go @@ -902,7 +902,7 @@ func TestIsExtensionProtected(t *testing.T) { {"/script.php", true}, {"/style.css", false}, {"/image.jpg", false}, - {"/no-extension", true}, // No extension = protected + {"/no-extension", true}, // No extension = protected {"/path/to/file.HTML", true}, // Case insensitive {"/path/to/file.JSON", true}, {"/path/to/file.Php", true}, From 424a1eedf3e958706c4bddc36eee66c3488b3886 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 16:15:44 -0400 Subject: [PATCH 5/7] more tests --- internal/helper/ip_test.go | 52 +++++++++ internal/helper/tmpl_test.go | 47 ++++++++ internal/log/log_test.go | 73 ++++++++++++ internal/state/state_test.go | 93 +++++++++++++++ main_test.go | 218 +++++++++++++++++++++++++++++++++++ 5 files changed, 483 insertions(+) create mode 100644 internal/helper/tmpl_test.go create mode 100644 internal/log/log_test.go create mode 100644 internal/state/state_test.go diff --git a/internal/helper/ip_test.go b/internal/helper/ip_test.go index cebbbcb..540a226 100644 --- a/internal/helper/ip_test.go +++ b/internal/helper/ip_test.go @@ -191,3 +191,55 @@ func parseCIDR(cidr string, t *testing.T) *net.IPNet { } return block } + +func TestParseCIDR(t *testing.T) { + tests := []struct { + name string + cidr string + expectErr bool + }{ + { + name: "Valid IPv4 CIDR", + cidr: "192.168.1.0/24", + expectErr: false, + }, + { + name: "Valid IPv6 CIDR", + cidr: "2001:db8::/32", + expectErr: false, + }, + { + name: "Invalid CIDR - no mask", + cidr: "192.168.1.0", + expectErr: true, + }, + { + name: "Invalid CIDR - bad format", + cidr: "not-a-cidr", + expectErr: true, + }, + { + name: "Invalid CIDR - empty string", + cidr: "", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseCIDR(tt.cidr) + if tt.expectErr { + if err == nil { + t.Errorf("Expected error for CIDR %q, got nil", tt.cidr) + } + } else { + if err != nil { + t.Errorf("Unexpected error for CIDR %q: %v", tt.cidr, err) + } + if result == nil { + t.Errorf("Expected non-nil result for valid CIDR %q", tt.cidr) + } + } + }) + } +} diff --git a/internal/helper/tmpl_test.go b/internal/helper/tmpl_test.go new file mode 100644 index 0000000..7873e43 --- /dev/null +++ b/internal/helper/tmpl_test.go @@ -0,0 +1,47 @@ +package helper + +import ( + "strings" + "testing" +) + +func TestGetDefaultTmpl(t *testing.T) { + tmpl := GetDefaultTmpl() + + // Verify it returns a non-empty string + if tmpl == "" { + t.Error("GetDefaultTmpl returned empty string") + } + + // Verify it contains expected HTML elements + expectedElements := []string{ + "", + "", + "", + "", + "", + "", + "", + "{{ .FrontendJS }}", + "{{ .SiteKey }}", + "{{ .ChallengeURL }}", + "{{ .Destination }}", + "{{ .FrontendKey }}", + "captchaCallback", + } + + for _, elem := range expectedElements { + if !strings.Contains(tmpl, elem) { + t.Errorf("Template missing expected element: %s", elem) + } + } + + // Verify it's valid HTML structure (basic check) + if !strings.HasPrefix(tmpl, "") { + t.Error("Template should start with ") + } + if !strings.HasSuffix(strings.TrimSpace(tmpl), "") { + t.Error("Template should end with ") + } +} diff --git a/internal/log/log_test.go b/internal/log/log_test.go new file mode 100644 index 0000000..277511c --- /dev/null +++ b/internal/log/log_test.go @@ -0,0 +1,73 @@ +package log + +import ( + "log/slog" + "testing" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + levelStr string + expectedLevel slog.Level + }{ + {"DEBUG level", "DEBUG", slog.LevelDebug}, + {"INFO level", "INFO", slog.LevelInfo}, + {"WARN level", "WARN", slog.LevelWarn}, + {"WARNING level", "WARNING", slog.LevelWarn}, + {"ERROR level", "ERROR", slog.LevelError}, + {"debug lowercase", "debug", slog.LevelDebug}, + {"Unknown level defaults to INFO", "UNKNOWN", slog.LevelInfo}, + {"Empty level defaults to INFO", "", slog.LevelInfo}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := New(tt.levelStr) + if logger == nil { + t.Error("Expected non-nil logger") + } + // Logger is created successfully, we can't easily test the exact level + // but we verify it doesn't panic or error + }) + } +} + +func TestParseLogLevel(t *testing.T) { + tests := []struct { + name string + level string + expected slog.Level + expectErr bool + }{ + {"DEBUG", "DEBUG", slog.LevelDebug, false}, + {"debug lowercase", "debug", slog.LevelDebug, false}, + {"INFO", "INFO", slog.LevelInfo, false}, + {"info lowercase", "info", slog.LevelInfo, false}, + {"WARN", "WARN", slog.LevelWarn, false}, + {"WARNING", "WARNING", slog.LevelWarn, false}, + {"warning lowercase", "warning", slog.LevelWarn, false}, + {"ERROR", "ERROR", slog.LevelError, false}, + {"error lowercase", "error", slog.LevelError, false}, + {"Unknown level", "INVALID", slog.LevelInfo, true}, + {"Empty string", "", slog.LevelInfo, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + level, err := parseLogLevel(tt.level) + if tt.expectErr { + if err == nil { + t.Errorf("Expected error for level %q, got nil", tt.level) + } + } else { + if err != nil { + t.Errorf("Unexpected error for level %q: %v", tt.level, err) + } + } + if level != tt.expected { + t.Errorf("Expected level %v, got %v", tt.expected, level) + } + }) + } +} diff --git a/internal/state/state_test.go b/internal/state/state_test.go new file mode 100644 index 0000000..4237f80 --- /dev/null +++ b/internal/state/state_test.go @@ -0,0 +1,93 @@ +package state + +import ( + "testing" + "time" + + lru "github.com/patrickmn/go-cache" +) + +func TestGetState(t *testing.T) { + // Create test caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + // Add test data + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) + + botCache.Set("1.2.3.4", true, lru.DefaultExpiration) + botCache.Set("5.6.7.8", false, lru.DefaultExpiration) + + verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) + + // Get state + state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + + // Verify rate cache data + if len(state.Rate) != 2 { + t.Errorf("Expected 2 rate entries, got %d", len(state.Rate)) + } + if state.Rate["192.168.0.0"] != 10 { + t.Errorf("Expected rate 10 for 192.168.0.0, got %d", state.Rate["192.168.0.0"]) + } + if state.Rate["10.0.0.0"] != 5 { + t.Errorf("Expected rate 5 for 10.0.0.0, got %d", state.Rate["10.0.0.0"]) + } + + // Verify bot cache data + if len(state.Bots) != 2 { + t.Errorf("Expected 2 bot entries, got %d", len(state.Bots)) + } + if state.Bots["1.2.3.4"] != true { + t.Error("Expected bot 1.2.3.4 to be true") + } + if state.Bots["5.6.7.8"] != false { + t.Error("Expected bot 5.6.7.8 to be false") + } + + // Verify verified cache data + if len(state.Verified) != 1 { + t.Errorf("Expected 1 verified entry, got %d", len(state.Verified)) + } + if state.Verified["9.9.9.9"] != true { + t.Error("Expected 9.9.9.9 to be verified") + } + + // Verify memory tracking exists + if len(state.Memory) != 3 { + t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + } + if state.Memory["rate"] == 0 { + t.Error("Expected non-zero memory for rate cache") + } + if state.Memory["bot"] == 0 { + t.Error("Expected non-zero memory for bot cache") + } + if state.Memory["verified"] == 0 { + t.Error("Expected non-zero memory for verified cache") + } +} + +func TestGetStateEmpty(t *testing.T) { + // Create empty caches + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + + if len(state.Rate) != 0 { + t.Errorf("Expected 0 rate entries, got %d", len(state.Rate)) + } + if len(state.Bots) != 0 { + t.Errorf("Expected 0 bot entries, got %d", len(state.Bots)) + } + if len(state.Verified) != 0 { + t.Errorf("Expected 0 verified entries, got %d", len(state.Verified)) + } + if len(state.Memory) != 3 { + t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + } +} diff --git a/main_test.go b/main_test.go index ad84bfd..7be8c8b 100644 --- a/main_test.go +++ b/main_test.go @@ -974,3 +974,221 @@ func TestStatePersistence(t *testing.T) { t.Error("Bot cache state not persisted correctly") } } + +func TestVerifyChallengePage(t *testing.T) { + tests := []struct { + name string + provider string + formValues map[string]string + mockResponse string + expectedStatus int + shouldSetCache bool + }{ + { + name: "Missing captcha response", + provider: "turnstile", + formValues: map[string]string{}, + expectedStatus: http.StatusBadRequest, + shouldSetCache: false, + }, + { + name: "Successful verification with destination", + provider: "turnstile", + formValues: map[string]string{ + "cf-turnstile-response": "valid-token", + "destination": "%2Fhome", + }, + mockResponse: `{"success":true}`, + expectedStatus: http.StatusFound, + shouldSetCache: true, + }, + { + name: "Successful verification without destination", + provider: "recaptcha", + formValues: map[string]string{ + "g-recaptcha-response": "valid-token", + }, + mockResponse: `{"success":true}`, + expectedStatus: http.StatusFound, + shouldSetCache: true, + }, + { + name: "Failed verification", + provider: "hcaptcha", + formValues: map[string]string{ + "h-captcha-response": "invalid-token", + }, + mockResponse: `{"success":false}`, + expectedStatus: http.StatusForbidden, + shouldSetCache: false, + }, + { + name: "Invalid destination URL", + provider: "turnstile", + formValues: map[string]string{ + "cf-turnstile-response": "valid-token", + "destination": "%ZZ", + }, + mockResponse: `{"success":true}`, + expectedStatus: http.StatusFound, + shouldSetCache: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tt.mockResponse)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = tt.provider + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Override the validation URL to point to our mock server + bc.captchaConfig.validate = mockServer.URL + + // Create request with form values + req := httptest.NewRequest("POST", "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + for k, v := range tt.formValues { + req.Form.Set(k, v) + } + + rr := httptest.NewRecorder() + clientIP := "1.2.3.4" + + status := bc.verifyChallengePage(rr, req, clientIP) + + if status != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, status) + } + + // Check if IP was added to verified cache + _, found := bc.verifiedCache.Get(clientIP) + if found != tt.shouldSetCache { + t.Errorf("Expected cache set=%v, got=%v", tt.shouldSetCache, found) + } + }) + } +} + +func TestVerifyChallengePageHTTPError(t *testing.T) { + // Test HTTP client error + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + + // Set invalid URL to trigger HTTP error + bc.captchaConfig.validate = "http://invalid-domain-that-does-not-exist-12345.com" + + req := httptest.NewRequest("POST", "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "token") + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusInternalServerError { + t.Errorf("Expected status %d for HTTP error, got %d", http.StatusInternalServerError, status) + } +} + +func TestVerifyChallengePageInvalidJSON(t *testing.T) { + // Test invalid JSON response + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest("POST", "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "token") + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusInternalServerError { + t.Errorf("Expected status %d for JSON error, got %d", http.StatusInternalServerError, status) + } +} + +func TestServeHTTPMethodNotAllowed(t *testing.T) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ChallengeURL = "/challenge" + + bc, _ := NewCaptchaProtect(context.Background(), next, config, "test") + + req := httptest.NewRequest("DELETE", "http://example.com/challenge", nil) + req.RequestURI = "/challenge" + rr := httptest.NewRecorder() + + bc.ServeHTTP(rr, req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } +} + +func TestLoadStateInvalidJSON(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "invalid.json") + + // Write invalid JSON + _ = os.WriteFile(tmpFile, []byte(`{invalid json`), 0644) + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.PersistentStateFile = tmpFile + + // Should not panic, just log error + bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Errorf("Should not fail on invalid state JSON: %v", err) + } + + // Caches should be empty + if bc.rateCache.ItemCount() != 0 { + t.Error("Rate cache should be empty after failed load") + } +} + +func TestParseHttpMethodsInvalid(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.ProtectHttpMethods = []string{"GET", "INVALID_METHOD", "POST"} + + // Should not fail, just log warning + _, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Errorf("Should not fail on invalid HTTP method: %v", err) + } +} From 4cee313fb02af44e525134d004e78a85a61283bc Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 16:17:34 -0400 Subject: [PATCH 6/7] codecov --- .github/workflows/lint-test.yml | 11 +++++++++++ README.md | 1 + 2 files changed, 12 insertions(+) diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 2af769d..ce96425 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -42,6 +42,17 @@ jobs: - name: unit test run: go test -v -race ./... + - name: generate coverage + run: go test -coverprofile=coverage.out -covermode=atomic ./... + + - name: upload coverage to codecov + uses: codecov/codecov-action@v5 + with: + files: ./coverage.out + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + integration-test: needs: [run] permissions: diff --git a/README.md b/README.md index e44145d..643e088 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Captcha Protect [![lint-test](https://github.com/libops/captcha-protect/actions/workflows/lint-test.yml/badge.svg)](https://github.com/libops/captcha-protect/actions/workflows/lint-test.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/libops/captcha-protect)](https://goreportcard.com/report/github.com/libops/captcha-protect) +[![codecov](https://codecov.io/gh/libops/captcha-protect/branch/main/graph/badge.svg)](https://codecov.io/gh/libops/captcha-protect) Traefik middleware to challenge individual IPs in a subnet when traffic spikes are detected from that subnet, using a captcha of your choice for the challenge (turnstile, recaptcha, or hcaptcha). **Requires traefik `v2.11.1` or above** From 6c917157dd8418d1d56256aeaac0690fbcc085d1 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Sat, 25 Oct 2025 16:19:32 -0400 Subject: [PATCH 7/7] fixup cleanup --- main_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index 7be8c8b..b2b2c33 100644 --- a/main_test.go +++ b/main_test.go @@ -1156,10 +1156,13 @@ func TestServeHTTPMethodNotAllowed(t *testing.T) { } func TestLoadStateInvalidJSON(t *testing.T) { - tmpFile := filepath.Join(t.TempDir(), "invalid.json") + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "invalid.json") // Write invalid JSON - _ = os.WriteFile(tmpFile, []byte(`{invalid json`), 0644) + if err := os.WriteFile(tmpFile, []byte(`{invalid json`), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } config := CreateConfig() config.SiteKey = "test" @@ -1177,6 +1180,9 @@ func TestLoadStateInvalidJSON(t *testing.T) { if bc.rateCache.ItemCount() != 0 { t.Error("Rate cache should be empty after failed load") } + + // Clean up the file before temp dir cleanup + _ = os.Remove(tmpFile) } func TestParseHttpMethodsInvalid(t *testing.T) {