diff --git a/auth_test.go b/auth_test.go index fc804e3..13b2a0d 100644 --- a/auth_test.go +++ b/auth_test.go @@ -71,3 +71,39 @@ func TestAddAuth_RejectsHS512(t *testing.T) { require.Equal(t, http.StatusUnauthorized, w.Code) } +func TestAddAuth_MissingHeader(t *testing.T) { + secretHex := strings.Repeat("c", 256) + t.Setenv("DRAND_AUTH_KEY", secretHex) + + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + protected := AddAuth(next) + + r := httptest.NewRequest(http.MethodGet, "/v2/chains", nil) + w := httptest.NewRecorder() + + protected.ServeHTTP(w, r) + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, "Missing JWT\n", w.Body.String()) +} + +func TestAddAuth_InvalidToken(t *testing.T) { + secretHex := strings.Repeat("d", 256) + t.Setenv("DRAND_AUTH_KEY", secretHex) + + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + protected := AddAuth(next) + + r := httptest.NewRequest(http.MethodGet, "/v2/chains", nil) + r.Header.Set("Authorization", "Bearer not-a-valid-jwt") + w := httptest.NewRecorder() + + protected.ServeHTTP(w, r) + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, "Invalid JWT\n", w.Body.String()) +} + + diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..3db551d --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAddCommonHeaders(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := addCommonHeaders(next) + + r := httptest.NewRequest(http.MethodGet, "/v2/chains", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, r) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, version, w.Header().Get("Server")) + require.Equal(t, "application/json", w.Header().Get("Content-Type")) + require.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) +} + +func TestGetLogLevel(t *testing.T) { + old := *verbose + defer func() { *verbose = old }() + + *verbose = false + require.Equal(t, slog.LevelInfo, getLogLevel()) + + *verbose = true + require.Equal(t, slog.LevelDebug, getLogLevel()) +} diff --git a/routes_test.go b/routes_test.go new file mode 100644 index 0000000..56bb817 --- /dev/null +++ b/routes_test.go @@ -0,0 +1,93 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestDisplayRoutes_FiltersAndFallsBack(t *testing.T) { + // prepare a stable set of routes + allRoutes = []string{ + "GET /v1/info", + "GET /v2/chains", + "GET /v2/chains/{chainhash}/info", + } + + t.Run("filters by prefix", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v2", nil) + rr := httptest.NewRecorder() + + DisplayRoutes(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) + } + + body := rr.Body.String() + if !strings.Contains(body, "GET /v2/chains") { + t.Fatalf("expected filtered routes to contain v2 route, got %q", body) + } + if strings.Contains(body, "GET /v1/info") { + t.Fatalf("expected v1 route to be filtered out, got %q", body) + } + }) + + t.Run("falls back to all routes when no prefix matches", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/nonexistent", nil) + rr := httptest.NewRecorder() + + DisplayRoutes(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) + } + + body := rr.Body.String() + for _, route := range allRoutes { + if !strings.Contains(body, route) { + t.Fatalf("expected fallback to include %q, got %q", route, body) + } + } + }) +} + +func TestDisplayRoutes_HeadersAndOrder(t *testing.T) { + allRoutes = []string{ + "GET /v1/info", + "GET /v2/chains", + "GET /v2/chains/{chainhash}/info", + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + + DisplayRoutes(rr, req) + + // Content-Type and other headers + if ct := rr.Header().Get("Content-Type"); ct != "text/plain; charset=utf-8" { + t.Fatalf("expected Content-Type text/plain; charset=utf-8, got %q", ct) + } + if rr.Header().Get("X-Content-Type-Options") != "nosniff" { + t.Fatalf("expected X-Content-Type-Options: nosniff") + } + + // V2 routes must appear after V1 (DisplayRoutes sort order) + body := rr.Body.String() + lines := strings.Split(strings.TrimSpace(body), "\n") + v1Idx, v2Idx := -1, -1 + for i, line := range lines { + if strings.HasPrefix(line, "GET /v1") { + v1Idx = i + } + if strings.HasPrefix(line, "GET /v2") { + if v2Idx == -1 { + v2Idx = i + } + } + } + if v1Idx >= 0 && v2Idx >= 0 && v1Idx > v2Idx { + t.Fatalf("expected v2 routes after v1: v1 at %d, v2 at %d; body:\n%s", v1Idx, v2Idx, body) + } +}