Skip to content

Commit 4af6c16

Browse files
mcp: fix race condition in ServerSession.startKeepalive
1 parent: 755b9ed · commit: 4af6c16

3 files changed

Lines changed: 36 additions & 52 deletions

File tree

mcp/server.go

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1028,6 +1028,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp
10281028
s.opts.Logger.Error("server connect error", "error", err)
10291029
return nil, err
10301030
}
1031+
1032+
// Start keepalive before returning the session to avoid race conditions with Close.
1033+
// This is safe because the spec allows sending pings before initialization (see ServerSession.handle for details).
1034+
if s.opts.KeepAlive > 0 {
1035+
ss.startKeepalive(ss.server.opts.KeepAlive)
1036+
}
1037+
10311038
return ss, nil
10321039
}
10331040

@@ -1055,9 +1062,6 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
10551062
ss.server.opts.Logger.Error("duplicate initialized notification")
10561063
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
10571064
}
1058-
if ss.server.opts.KeepAlive > 0 {
1059-
ss.startKeepalive(ss.server.opts.KeepAlive)
1060-
}
10611065
if h := ss.server.opts.InitializedHandler; h != nil {
10621066
h(ctx, serverRequestFor(ss, params))
10631067
}
@@ -1107,7 +1111,7 @@ type ServerSession struct {
11071111
server *Server
11081112
conn *jsonrpc2.Connection
11091113
mcpConn Connection
1110-
keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded
1114+
keepaliveCancel context.CancelFunc
11111115

11121116
mu sync.Mutex
11131117
state ServerSessionState

mcp/server_test.go

Lines changed: 0 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -508,54 +508,6 @@ func TestServerAddResourceTemplate(t *testing.T) {
508508
}
509509
}
510510

511-
// TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once,
512-
// ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism.
513-
func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) {
514-
// Set KeepAlive to a long duration to ensure the keepalive
515-
// goroutine stays alive for the duration of the test without actually sending
516-
// ping requests, since we don't have a real client connection established.
517-
server := NewServer(testImpl, &ServerOptions{KeepAlive: 5 * time.Second})
518-
ss := &ServerSession{server: server}
519-
520-
// 1. Initialize the session.
521-
_, err := ss.initialize(context.Background(), &InitializeParams{})
522-
if err != nil {
523-
t.Fatalf("ServerSession initialize failed: %v", err)
524-
}
525-
526-
// 2. Call 'initialized' for the first time. This should start the keepalive mechanism.
527-
_, err = ss.initialized(context.Background(), &InitializedParams{})
528-
if err != nil {
529-
t.Fatalf("First initialized call failed: %v", err)
530-
}
531-
if ss.keepaliveCancel == nil {
532-
t.Fatalf("expected ServerSession.keepaliveCancel to be set after the first call of initialized")
533-
}
534-
535-
// Save the cancel function and use defer to ensure resources are cleaned up.
536-
firstCancel := ss.keepaliveCancel
537-
defer firstCancel()
538-
539-
// 3. Manually set the field to nil.
540-
// Do this to facilitate the test's core assertion. The goal is to verify that
541-
// 'ss.keepaliveCancel' is not assigned a second time. By setting it to nil,
542-
// we can easily check after the next call if a new keepalive goroutine was started.
543-
ss.keepaliveCancel = nil
544-
545-
// 4. Call 'initialized' for the second time. This should return an error.
546-
_, err = ss.initialized(context.Background(), &InitializedParams{})
547-
if err == nil {
548-
t.Fatalf("Expected 'duplicate initialized received' error on second call, got nil")
549-
}
550-
551-
// 5. Re-check the field to ensure it remains nil.
552-
// Since 'initialized' correctly returned an error and did not call
553-
// 'startKeepalive', the field should remain unchanged.
554-
if ss.keepaliveCancel != nil {
555-
t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized")
556-
}
557-
}
558-
559511
// panicks reports whether f() panics.
560512
func panics(f func()) (b bool) {
561513
defer func() {

mcp/streamable_test.go

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -367,6 +367,34 @@ func TestStreamableServerShutdown(t *testing.T) {
367367
}
368368
}
369369

370+
// TestStreamableStatelessKeepaliveRace verifies that there is no data race between
371+
// ServerSession.startKeepalive and ServerSession.Close in stateless servers.
372+
func TestStreamableStatelessKeepaliveRace(t *testing.T) {
373+
ctx := context.Background()
374+
server := NewServer(testImpl, &ServerOptions{KeepAlive: time.Hour})
375+
AddTool(server, &Tool{Name: "greet"}, sayHi)
376+
handler := NewStreamableHTTPHandler(
377+
func(*http.Request) *Server { return server },
378+
&StreamableHTTPOptions{Stateless: true},
379+
)
380+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
381+
defer httpServer.Close()
382+
383+
for range 50 {
384+
cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{
385+
Endpoint: httpServer.URL,
386+
}, nil)
387+
if err != nil {
388+
t.Fatalf("NewClient() failed: %v", err)
389+
}
390+
_, _ = cs.CallTool(ctx, &CallToolParams{
391+
Name: "greet",
392+
Arguments: map[string]any{"Name": "world"},
393+
})
394+
_ = cs.Close()
395+
}
396+
}
397+
370398
// TestClientReplay verifies that the client can recover from a mid-stream
371399
// network failure and receive replayed messages (if replay is configured). It
372400
// uses a proxy that is killed and restarted to simulate a recoverable network

0 commit comments

Comments (0)