diff --git a/pkg/webhook/mutating/middleware.go b/pkg/webhook/mutating/middleware.go index 1c1ee29947..d484f6af98 100644 --- a/pkg/webhook/mutating/middleware.go +++ b/pkg/webhook/mutating/middleware.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -88,9 +89,16 @@ func createMutatingHandler(executors []clientExecutor, serverName, transport str return } - // Read the request body to get the raw MCP request. + // Read the request body to get the raw MCP request, capped at MaxRequestSize + // to prevent unbounded memory consumption from oversized inbound requests. + r.Body = http.MaxBytesReader(w, r.Body, webhook.MaxRequestSize) bodyBytes, err := io.ReadAll(r.Body) if err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + sendErrorResponse(w, http.StatusRequestEntityTooLarge, "Request body exceeds maximum size", parsedMCP.ID) + return + } sendErrorResponse(w, http.StatusInternalServerError, "Failed to read request body", parsedMCP.ID) return } diff --git a/pkg/webhook/mutating/middleware_test.go b/pkg/webhook/mutating/middleware_test.go index 38bee34492..12670dd816 100644 --- a/pkg/webhook/mutating/middleware_test.go +++ b/pkg/webhook/mutating/middleware_test.go @@ -405,6 +405,81 @@ func TestMutatingMiddleware_SkipNonMCPRequests(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) } +// makeJSONBodyOfSize builds a syntactically-valid JSON body of exactly `size` +// bytes by padding the value of a "data" field with ASCII characters. The +// resulting bytes are valid JSON-RPC for use as an MCP request body. +func makeJSONBodyOfSize(tb testing.TB, size int) []byte { + tb.Helper() + const envelope = `{"jsonrpc":"2.0","method":"tools/call","id":1,"data":""}` + if size < len(envelope) { + tb.Fatalf("requested size %d is smaller than minimum envelope size %d", size, len(envelope)) + } + padding := bytes.Repeat([]byte("a"), size-len(envelope)) + body := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1,"data":"`) + body = append(body, padding...) + body = append(body, []byte(`"}`)...) + if len(body) != size { + tb.Fatalf("constructed body length %d != requested size %d", len(body), size) + } + return body +} + +func TestMutatingMiddleware_RequestBodySizeLimit(t *testing.T) { + t.Parallel() + + t.Run("body exceeding MaxRequestSize returns 413", func(t *testing.T) { + t.Parallel() + + cfg := makeConfig(closedServerURL, webhook.FailurePolicyIgnore) + mw := createMutatingHandler(makeExecutors(t, []webhook.Config{cfg}), "srv", "stdio") + + body := makeJSONBodyOfSize(t, webhook.MaxRequestSize+1) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{Method: "tools/call", ID: 1}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next must not be called for oversized requests") + assert.Equal(t, http.StatusRequestEntityTooLarge, rr.Code) + + var errResp map[string]interface{} + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &errResp)) + errObj, ok := errResp["Error"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, float64(http.StatusRequestEntityTooLarge), errObj["code"]) + assert.Equal(t, "Request body exceeds maximum size", errObj["message"]) + }) + + t.Run("body at MaxRequestSize boundary is accepted", func(t *testing.T) { + t.Parallel() + + // Use FailurePolicyIgnore against a closed port: the outbound webhook + // call fails and is ignored per fail-open, so next is invoked. This + // isolates the test from depending on a working webhook server. + cfg := makeConfig(closedServerURL, webhook.FailurePolicyIgnore) + mw := createMutatingHandler(makeExecutors(t, []webhook.Config{cfg}), "srv", "stdio") + + body := makeJSONBodyOfSize(t, webhook.MaxRequestSize) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{Method: "tools/call", ID: 1}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.True(t, nextCalled, "next should be called for boundary-size body (fail-open ignores webhook error)") + assert.Equal(t, http.StatusOK, rr.Code) + }) +} + func TestMiddlewareParams_Validate(t *testing.T) { t.Parallel() tests := []struct { diff --git a/pkg/webhook/types.go b/pkg/webhook/types.go index 1d8f6ef110..f9ef729d1f 100644 --- a/pkg/webhook/types.go +++ b/pkg/webhook/types.go @@ -30,6 +30,11 @@ const MinTimeout = 1 * time.Second // MaxResponseSize is the maximum allowed size in bytes for webhook responses (1 MB). const MaxResponseSize = 1 << 20 +// MaxRequestSize is the maximum allowed size in bytes for inbound webhook +// middleware request bodies (1 MB). Requests exceeding this size are +// rejected with HTTP 413 before the body is buffered or forwarded. +const MaxRequestSize = 1 << 20 + // Type indicates whether a webhook is validating or mutating. type Type string diff --git a/pkg/webhook/validating/middleware.go b/pkg/webhook/validating/middleware.go index 370791c570..a7783ae0bf 100644 --- a/pkg/webhook/validating/middleware.go +++ b/pkg/webhook/validating/middleware.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -97,9 +98,16 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s return } - // Read the request body to get the raw MCP request + // Read the request body to get the raw MCP request, capped at MaxRequestSize + // to prevent unbounded memory consumption from oversized inbound requests. + r.Body = http.MaxBytesReader(w, r.Body, webhook.MaxRequestSize) bodyBytes, err := io.ReadAll(r.Body) if err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + sendErrorResponse(w, http.StatusRequestEntityTooLarge, "Request body exceeds maximum size", parsedMCP.ID) + return + } sendErrorResponse(w, http.StatusInternalServerError, "Failed to read request body", parsedMCP.ID) return } diff --git a/pkg/webhook/validating/middleware_test.go b/pkg/webhook/validating/middleware_test.go index bfe60d0b6c..862b65938c 100644 --- a/pkg/webhook/validating/middleware_test.go +++ b/pkg/webhook/validating/middleware_test.go @@ -270,6 +270,102 @@ func TestValidatingMiddleware(t *testing.T) { }) } +// makeJSONBodyOfSize builds a syntactically-valid JSON body of exactly `size` +// bytes by padding the value of a "data" field with ASCII characters. The +// resulting bytes are valid JSON-RPC for use as an MCP request body. +func makeJSONBodyOfSize(tb testing.TB, size int) []byte { + tb.Helper() + // Envelope structure: {"jsonrpc":"2.0","method":"tools/call","id":1,"data":""} + const envelope = `{"jsonrpc":"2.0","method":"tools/call","id":1,"data":""}` + if size < len(envelope) { + tb.Fatalf("requested size %d is smaller than minimum envelope size %d", size, len(envelope)) + } + padding := bytes.Repeat([]byte("a"), size-len(envelope)) + body := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1,"data":"`) + body = append(body, padding...) + body = append(body, []byte(`"}`)...) + if len(body) != size { + tb.Fatalf("constructed body length %d != requested size %d", len(body), size) + } + return body +} + +func TestValidatingMiddleware_RequestBodySizeLimit(t *testing.T) { + t.Parallel() + + t.Run("body exceeding MaxRequestSize returns 413", func(t *testing.T) { + t.Parallel() + + // Use fail-open so that even if the middleware reached the webhook call + // (it must NOT, since the size check fires first), the test would still + // be unambiguous about the size rejection vs. a downstream error. + cfg := webhook.Config{ + Name: "test-webhook", + URL: closedServerURL, + Timeout: webhook.DefaultTimeout, + FailurePolicy: webhook.FailurePolicyIgnore, + TLSConfig: &webhook.TLSConfig{InsecureSkipVerify: true}, + } + client, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + mw := createValidatingHandler([]clientExecutor{{client: client, config: cfg}}, "srv", "stdio") + + body := makeJSONBodyOfSize(t, webhook.MaxRequestSize+1) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{Method: "tools/call", ID: 1}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next must not be called for oversized requests") + assert.Equal(t, http.StatusRequestEntityTooLarge, rr.Code) + + // The error response is a JSON-RPC envelope. + var errResp map[string]interface{} + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &errResp)) + errObj, ok := errResp["Error"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, float64(http.StatusRequestEntityTooLarge), errObj["code"]) + assert.Equal(t, "Request body exceeds maximum size", errObj["message"]) + }) + + t.Run("body at MaxRequestSize boundary is accepted", func(t *testing.T) { + t.Parallel() + + // Point at a closed port with FailurePolicyIgnore so that the outbound + // webhook call fails and is ignored per fail-open. This isolates the + // test from depending on a working webhook server. + cfg := webhook.Config{ + Name: "test-webhook", + URL: closedServerURL, + Timeout: webhook.DefaultTimeout, + FailurePolicy: webhook.FailurePolicyIgnore, + TLSConfig: &webhook.TLSConfig{InsecureSkipVerify: true}, + } + client, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + mw := createValidatingHandler([]clientExecutor{{client: client, config: cfg}}, "srv", "stdio") + + body := makeJSONBodyOfSize(t, webhook.MaxRequestSize) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{Method: "tools/call", ID: 1}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.True(t, nextCalled, "next should be called for boundary-size body (fail-open ignores webhook error)") + assert.Equal(t, http.StatusOK, rr.Code) + }) +} + func TestMiddlewareParams_Validate(t *testing.T) { t.Parallel() tests := []struct {