Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,19 +274,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque

// Allow multiple 'Accept' headers.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax
accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",")
var jsonOK, streamOK bool
for _, c := range accept {
switch strings.TrimSpace(c) {
case "application/json", "application/*":
jsonOK = true
case "text/event-stream", "text/*":
streamOK = true
case "*/*":
jsonOK = true
streamOK = true
}
}
jsonOK, streamOK := streamableAccepts(req.Header.Values("Accept"))

if req.Method == http.MethodGet {
if !streamOK {
Expand Down Expand Up @@ -551,6 +539,26 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
sessInfo.transport.ServeHTTP(w, req)
}

func streamableAccepts(values []string) (jsonOK, streamOK bool) {
for _, value := range values {
for _, raw := range strings.Split(value, ",") {
token := strings.TrimSpace(raw)
// Ignore Accept parameters like ";charset=utf-8"; match the base media type.
base, _, _ := strings.Cut(token, ";")
switch strings.ToLower(strings.TrimSpace(base)) {
case "application/json", "application/*":
jsonOK = true
case "text/event-stream", "text/*":
streamOK = true
case "*/*":
jsonOK = true
streamOK = true
}
}
}
return jsonOK, streamOK
}

// A StreamableServerTransport implements the server side of the MCP streamable
// transport.
//
Expand Down
55 changes: 53 additions & 2 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -894,12 +894,26 @@ func TestStreamableServerTransport(t *testing.T) {
wantStatusCode: http.StatusOK,
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)},
},
{
method: "POST",
headers: http.Header{"Accept": {"application/json;charset=utf-8, text/event-stream"}},
messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "tool"})},
wantStatusCode: http.StatusOK,
wantMessages: []jsonrpc.Message{resp(5, &CallToolResult{Content: []Content{}}, nil)},
},
{
method: "POST",
headers: http.Header{"Accept": {"application/json;charset=utf-8", "text/event-stream"}},
messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "tool"})},
wantStatusCode: http.StatusOK,
wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)},
},
{
method: "POST",
headers: http.Header{"Accept": {"text/*, application/*"}},
messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})},
messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{Name: "tool"})},
wantStatusCode: http.StatusOK,
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)},
wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)},
},
},
wantSessions: 1,
Expand Down Expand Up @@ -1960,12 +1974,49 @@ func TestStreamableGETWithoutSession(t *testing.T) {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

// GET without session should return 400 Bad Request, not 405 Method Not Allowed,
// because GET is a valid method - it just requires a session ID.
if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
t.Errorf("status code: got %d, want %d", got, want)
}
if got, want := strings.TrimSpace(string(body)), "Bad Request: GET requires an Mcp-Session-Id header"; got != want {
t.Errorf("body: got %q, want %q", got, want)
}
}

func TestStreamableGETWithoutEventStreamAccept(t *testing.T) {
server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

req, err := http.NewRequest("GET", httpServer.URL, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Accept", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
t.Errorf("status code: got %d, want %d", got, want)
}
if got, want := strings.TrimSpace(string(body)), "Accept must contain 'text/event-stream' for GET requests"; got != want {
t.Errorf("body: got %q, want %q", got, want)
}
}

func TestStreamableClientContextPropagation(t *testing.T) {
Expand Down