diff --git a/mcp/streamable.go b/mcp/streamable.go index 7f09b40b..32482857 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -274,19 +274,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // Allow multiple 'Accept' headers. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax - accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") - var jsonOK, streamOK bool - for _, c := range accept { - switch strings.TrimSpace(c) { - case "application/json", "application/*": - jsonOK = true - case "text/event-stream", "text/*": - streamOK = true - case "*/*": - jsonOK = true - streamOK = true - } - } + jsonOK, streamOK := streamableAccepts(req.Header.Values("Accept")) if req.Method == http.MethodGet { if !streamOK { @@ -551,6 +539,26 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque sessInfo.transport.ServeHTTP(w, req) } +func streamableAccepts(values []string) (jsonOK, streamOK bool) { + for _, value := range values { + for _, raw := range strings.Split(value, ",") { + token := strings.TrimSpace(raw) + // Ignore Accept parameters like ";charset=utf-8"; match the base media type. + base, _, _ := strings.Cut(token, ";") + switch strings.ToLower(strings.TrimSpace(base)) { + case "application/json", "application/*": + jsonOK = true + case "text/event-stream", "text/*": + streamOK = true + case "*/*": + jsonOK = true + streamOK = true + } + } + } + return jsonOK, streamOK +} + // A StreamableServerTransport implements the server side of the MCP streamable // transport. // diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d1dc482a..e26c5cc2 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -894,12 +894,26 @@ func TestStreamableServerTransport(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, }, + { + method: "POST", + headers: http.Header{"Accept": {"application/json;charset=utf-8, text/event-stream"}}, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(5, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{"Accept": {"application/json;charset=utf-8", "text/event-stream"}}, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + }, { method: "POST", headers: http.Header{"Accept": {"text/*, application/*"}}, - messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, }, }, wantSessions: 1, @@ -1960,12 +1974,49 @@ func TestStreamableGETWithoutSession(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } // GET without session should return 400 Bad Request, not 405 Method Not Allowed, // because GET is a valid method - it just requires a session ID. if got, want := resp.StatusCode, http.StatusBadRequest; got != want { t.Errorf("status code: got %d, want %d", got, want) } + if got, want := strings.TrimSpace(string(body)), "Bad Request: GET requires an Mcp-Session-Id header"; got != want { + t.Errorf("body: got %q, want %q", got, want) + } +} + +func TestStreamableGETWithoutEventStreamAccept(t *testing.T) { + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + req, err := http.NewRequest("GET", httpServer.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + t.Errorf("status code: got %d, want %d", got, want) + } + if got, want := strings.TrimSpace(string(body)), "Accept must contain 'text/event-stream' for GET requests"; got != want { + t.Errorf("body: got %q, want %q", got, want) + } } func TestStreamableClientContextPropagation(t *testing.T) {