Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,31 @@ func (c *SSE) Start(ctx context.Context) error {
go c.readSSE(resp.Body)

// Wait for the endpoint to be received
timeout := time.NewTimer(30 * time.Second)
defer timeout.Stop()
endpointTimeout := 30 * time.Second
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
// If context deadline has already passed, return immediately
if remaining <= 0 {
cancel()
return ctx.Err()
}
// Use the shorter of remaining time or default timeout
if remaining < endpointTimeout {
endpointTimeout = remaining
}
}

timer := time.NewTimer(endpointTimeout)
defer timer.Stop()

select {
case <-c.endpointChan:
// Endpoint received, proceed
case <-ctx.Done():
return fmt.Errorf("context cancelled while waiting for endpoint")
case <-timeout.C: // Add a timeout
return fmt.Errorf("context cancelled while waiting for endpoint: %w", ctx.Err())
case <-timer.C:
cancel()
return fmt.Errorf("timeout waiting for endpoint")
return fmt.Errorf("timeout waiting for endpoint after %v", endpointTimeout)
}

c.started.Store(true)
Expand Down Expand Up @@ -419,6 +434,7 @@ func (c *SSE) SendRequest(
resp.Body.Close()

if err != nil {
deleteResponseChan()
return nil, fmt.Errorf("failed to read response body: %w", err)
}

Expand All @@ -439,10 +455,32 @@ func (c *SSE) SendRequest(
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
}

// Calculate response timeout
responseTimeout := 60 * time.Second
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
// Check if context deadline has already passed
if remaining <= 0 {
deleteResponseChan()
return nil, ctx.Err()
}
// Use the shorter of remaining time or default timeout
if remaining < responseTimeout {
responseTimeout = remaining
}
}

timer := time.NewTimer(responseTimeout)
defer timer.Stop()

select {
case <-ctx.Done():
deleteResponseChan()
return nil, ctx.Err()
case <-timer.C:
// Timeout handling
deleteResponseChan()
return nil, fmt.Errorf("timeout waiting for SSE response after %v", responseTimeout)
case response, ok := <-responseChan:
if ok {
return response, nil
Expand Down
201 changes: 201 additions & 0 deletions client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1077,3 +1077,204 @@ func TestSSE_SendNotification_Unauthorized_StaticToken(t *testing.T) {
// Clean up
transport.Close()
}

func TestSSE_SendRequest_Timeout(t *testing.T) {
t.Run("TimeoutWhenServerNeverResponds", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Accept") == "text/event-stream" {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
fmt.Fprintf(w, "event: endpoint\ndata: /message\n\n")
flusher.Flush()
<-r.Context().Done()
return
}

if r.Method == http.MethodPost {
w.WriteHeader(http.StatusAccepted)
return
}
}))
defer server.Close()

transport, err := NewSSE(server.URL)
require.NoError(t, err)
defer transport.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

err = transport.Start(ctx)
require.NoError(t, err)

requestCtx, requestCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer requestCancel()

request := JSONRPCRequest{
JSONRPC: "2.0",
ID: mcp.NewRequestId(int64(1)),
Method: "test/timeout",
}

startTime := time.Now()
_, err = transport.SendRequest(requestCtx, request)
duration := time.Since(startTime)

require.Error(t, err, "Expected timeout error")
require.Contains(t, err.Error(), "timeout", "Error should mention timeout")
expectedTimeout := 2 * time.Second
require.GreaterOrEqual(t, duration, expectedTimeout*7/10) // 70% of expected
require.LessOrEqual(t, duration, expectedTimeout*13/10) // 130% of expected
})

t.Run("ContextDeadlineTakesPrecedence", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Accept") == "text/event-stream" {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
fmt.Fprintf(w, "event: endpoint\ndata: /message\n\n")
flusher.Flush()
<-r.Context().Done()
return
}

if r.Method == http.MethodPost {
w.WriteHeader(http.StatusAccepted)
return
}
}))
defer server.Close()

transport, err := NewSSE(server.URL)
require.NoError(t, err)
defer transport.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

err = transport.Start(ctx)
require.NoError(t, err)

requestCtx, requestCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer requestCancel()

request := JSONRPCRequest{
JSONRPC: "2.0",
ID: mcp.NewRequestId(int64(1)),
Method: "test/deadline",
}

startTime := time.Now()
_, err = transport.SendRequest(requestCtx, request)
duration := time.Since(startTime)

require.Error(t, err)
errMsg := err.Error()
require.True(t,
strings.Contains(errMsg, "timeout") || strings.Contains(errMsg, "deadline exceeded"),
"Error should mention timeout or deadline, got: %v", err)
require.LessOrEqual(t, duration, 1500*time.Millisecond, "Should respect context deadline of 1s")
})

t.Run("TimeoutCleansUpResponseChannel", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Accept") == "text/event-stream" {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
fmt.Fprintf(w, "event: endpoint\ndata: /message\n\n")
flusher.Flush()
<-r.Context().Done()
return
}

if r.Method == http.MethodPost {
w.WriteHeader(http.StatusAccepted)
return
}
}))
defer server.Close()

transport, err := NewSSE(server.URL)
require.NoError(t, err)
defer transport.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

err = transport.Start(ctx)
require.NoError(t, err)

transport.mu.RLock()
initialCount := len(transport.responses)
transport.mu.RUnlock()
require.Equal(t, 0, initialCount)

requestCtx, requestCancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer requestCancel()

request := JSONRPCRequest{
JSONRPC: "2.0",
ID: mcp.NewRequestId(int64(999)),
Method: "test/timeout",
}

_, err = transport.SendRequest(requestCtx, request)
require.Error(t, err)

time.Sleep(50 * time.Millisecond)

transport.mu.RLock()
finalCount := len(transport.responses)
transport.mu.RUnlock()

require.Equal(t, 0, finalCount)
})

t.Run("AlreadyExpiredDeadline", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Accept") == "text/event-stream" {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
fmt.Fprintf(w, "event: endpoint\ndata: /message\n\n")
flusher.Flush()
<-r.Context().Done()
return
}

if r.Method == http.MethodPost {
w.WriteHeader(http.StatusAccepted)
return
}
}))
defer server.Close()

transport, err := NewSSE(server.URL)
require.NoError(t, err)
defer transport.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

err = transport.Start(ctx)
require.NoError(t, err)

expiredCtx, expiredCancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second))
defer expiredCancel()

request := JSONRPCRequest{
JSONRPC: "2.0",
ID: mcp.NewRequestId(int64(1)),
Method: "test/expired",
}

_, err = transport.SendRequest(expiredCtx, request)

require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded),
"Expected context.DeadlineExceeded, got: %v", err)
})
}