diff --git a/main.go b/main.go index 63cd14d..9958d99 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,14 @@ var ( ) func init() { + // Skip flag parsing during tests + if len(os.Args) > 0 { + for _, arg := range os.Args { + if arg == "-test.v" || arg == "-test.run" || strings.HasPrefix(arg, "-test.") { + return + } + } + } flag.Parse() slog.SetLogLoggerLevel(getLogLevel()) } diff --git a/routes_handler.go b/routes_handler.go index 50eb637..914766f 100644 --- a/routes_handler.go +++ b/routes_handler.go @@ -17,15 +17,35 @@ import ( "github.com/go-chi/chi/v5" ) +// readRound parses and validates the round parameter from the HTTP request. +func readRound(r *http.Request) (uint64, error) { + roundStr := chi.URLParam(r, "round") + if roundStr == "" { + return 0, fmt.Errorf("round parameter is required") + } + + roundN, err := strconv.ParseUint(roundStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid round number %q: %w", roundStr, err) + } + + // Set a reasonable maximum to prevent issues with time calculations + const maxReasonableRound = uint64(1 << 60) + if roundN > maxReasonableRound { + return 0, fmt.Errorf("round number %d exceeds maximum allowed value %d", roundN, maxReasonableRound) + } + + return roundN, nil +} + func GetBeacon(c *grpc.Client, isV2 bool) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - roundStr := chi.URLParam(r, "round") - round, err := strconv.ParseUint(roundStr, 10, 64) + round, err := readRound(r) if err != nil { w.Header().Set("Cache-Control", "public, max-age=604800, immutable") - slog.Error("unable to parse round", "error", err) - http.Error(w, "Failed to parse round. Err: "+err.Error(), http.StatusBadRequest) + slog.Error("unable to parse round", "error", err, "client", r.RemoteAddr, "path", r.URL.Path) + http.Error(w, err.Error(), http.StatusBadRequest) return } diff --git a/routes_handler_test.go b/routes_handler_test.go new file mode 100644 index 0000000..76115b0 --- /dev/null +++ b/routes_handler_test.go @@ -0,0 +1,159 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/require" +) + +func TestReadRound(t *testing.T) { + testCases := []struct { + name string + round string + expectError bool + expectedMsg string + expectedVal uint64 + description string + }{ + { + name: "empty round parameter", + round: "", + expectError: true, + expectedMsg: "round parameter is required", + description: "empty round should return error", + }, + { + name: "invalid format - non-numeric", + round: "abc", + expectError: true, + expectedMsg: "invalid round number", + description: "non-numeric round should return error", + }, + { + name: "invalid format - negative", + round: "-1", + expectError: true, + expectedMsg: "invalid round number", + description: "negative round should return error", + }, + { + name: "invalid format - decimal", + round: "1.5", + expectError: true, + expectedMsg: "invalid round number", + description: "decimal round should return error", + }, + { + name: "extremely large number - parse error", + round: "999999999999999999999999999999999999999999999999999999999999999999999999", + expectError: true, + expectedMsg: "invalid round number", + description: "extremely large round (parse error) should return error", + }, + { + name: "large but parseable number exceeding reasonable max", + round: "1152921504606846977", // 2^60 + 1 + expectError: true, + expectedMsg: "exceeds maximum allowed value", + description: "round exceeding reasonable maximum should return error", + }, + { + name: "valid round number", + round: "1", + expectError: false, + expectedVal: 1, + description: "valid round should parse successfully", + }, + { + name: "valid large round number", + round: "1152921504606846976", // 2^60 (max allowed) + expectError: false, + expectedVal: 1152921504606846976, + description: "valid large round at maximum should parse successfully", + }, + { + name: "zero round", + round: "0", + expectError: false, + expectedVal: 0, + description: "zero round should parse successfully", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a chi router and request + r := chi.NewRouter() + r.Get("/test/{round}", func(w http.ResponseWriter, r *http.Request) { + round, err := readRound(r) + if tc.expectError { + require.Error(t, err, tc.description) + if tc.expectedMsg != "" { + require.Contains(t, err.Error(), tc.expectedMsg, tc.description) + } + } else { + require.NoError(t, err, tc.description) + require.Equal(t, tc.expectedVal, round, tc.description) + } + }) + + req := httptest.NewRequest("GET", "/test/"+tc.round, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + }) + } +} + +func TestReadRoundDirect(t *testing.T) { + // Test empty round parameter + r := chi.NewRouter() + r.Get("/test/{round}", func(w http.ResponseWriter, r *http.Request) { + round, err := readRound(r) + require.Error(t, err) + require.Contains(t, err.Error(), "round parameter is required") + require.Equal(t, uint64(0), round) + }) + req := httptest.NewRequest("GET", "/test/", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Test valid round + r = chi.NewRouter() + r.Get("/test/{round}", func(w http.ResponseWriter, r *http.Request) { + round, err := readRound(r) + require.NoError(t, err) + require.Equal(t, uint64(123), round) + }) + req = httptest.NewRequest("GET", "/test/123", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Test invalid round + r = chi.NewRouter() + r.Get("/test/{round}", func(w http.ResponseWriter, r *http.Request) { + round, err := readRound(r) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid round number") + require.Equal(t, uint64(0), round) + }) + req = httptest.NewRequest("GET", "/test/abc", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Test round exceeding maximum + maxRound := strconv.FormatUint(uint64(1<<60)+1, 10) + r = chi.NewRouter() + r.Get("/test/{round}", func(w http.ResponseWriter, r *http.Request) { + round, err := readRound(r) + require.Error(t, err) + require.Contains(t, err.Error(), "exceeds maximum allowed value") + require.Equal(t, uint64(0), round) + }) + req = httptest.NewRequest("GET", "/test/"+maxRound, nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) +}