From 32037c9edc082441243097f9e0bc503fbdb6d434 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 23 Mar 2026 15:51:16 -0700 Subject: [PATCH] mcp: accept parameterized Accept media types Normalize accept tokens before validation so valid headers with media type parameters still match their base media types. Add test coverage for parameterized and split accept headers. --- mcp/streamable.go | 34 ++++++++++++++++---------- mcp/streamable_test.go | 55 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 7f09b40b..32482857 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -274,19 +274,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // Allow multiple 'Accept' headers. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax - accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") - var jsonOK, streamOK bool - for _, c := range accept { - switch strings.TrimSpace(c) { - case "application/json", "application/*": - jsonOK = true - case "text/event-stream", "text/*": - streamOK = true - case "*/*": - jsonOK = true - streamOK = true - } - } + jsonOK, streamOK := streamableAccepts(req.Header.Values("Accept")) if req.Method == http.MethodGet { if !streamOK { @@ -551,6 +539,26 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque sessInfo.transport.ServeHTTP(w, req) } +func streamableAccepts(values []string) (jsonOK, streamOK bool) { + for _, value := range values { + for _, raw := range strings.Split(value, ",") { + token := strings.TrimSpace(raw) + // Ignore Accept parameters like ";charset=utf-8"; match the base media type. + base, _, _ := strings.Cut(token, ";") + switch strings.ToLower(strings.TrimSpace(base)) { + case "application/json", "application/*": + jsonOK = true + case "text/event-stream", "text/*": + streamOK = true + case "*/*": + jsonOK = true + streamOK = true + } + } + } + return jsonOK, streamOK +} + // A StreamableServerTransport implements the server side of the MCP streamable // transport. // diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d1dc482a..e26c5cc2 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -894,12 +894,26 @@ func TestStreamableServerTransport(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, }, + { + method: "POST", + headers: http.Header{"Accept": {"application/json;charset=utf-8, text/event-stream"}}, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(5, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{"Accept": {"application/json;charset=utf-8", "text/event-stream"}}, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + }, { method: "POST", headers: http.Header{"Accept": {"text/*, application/*"}}, - messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, }, }, wantSessions: 1, @@ -1960,12 +1974,49 @@ func TestStreamableGETWithoutSession(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } // GET without session should return 400 Bad Request, not 405 Method Not Allowed, // because GET is a valid method - it just requires a session ID. if got, want := resp.StatusCode, http.StatusBadRequest; got != want { t.Errorf("status code: got %d, want %d", got, want) } + if got, want := strings.TrimSpace(string(body)), "Bad Request: GET requires an Mcp-Session-Id header"; got != want { + t.Errorf("body: got %q, want %q", got, want) + } +} + +func TestStreamableGETWithoutEventStreamAccept(t *testing.T) { + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + req, err := http.NewRequest("GET", httpServer.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + t.Errorf("status code: got %d, want %d", got, want) + } + if got, want := strings.TrimSpace(string(body)), "Accept must contain 'text/event-stream' for GET requests"; got != want { + t.Errorf("body: got %q, want %q", got, want) + } } func TestStreamableClientContextPropagation(t *testing.T) {