From 197eaed913253ea555ae102b723dd9e037eff3ae Mon Sep 17 00:00:00 2001 From: Ravish Date: Fri, 27 Mar 2026 15:59:49 -0700 Subject: [PATCH] mcp: add ErrorHandler to ServerOptions Add an optional ErrorHandler callback to ServerOptions for receiving out-of-band errors that occur during server operation. These include keepalive ping failures, notification delivery errors, and internal JSON-RPC protocol errors. When ErrorHandler is nil, errors continue to be logged at the appropriate level using the configured Logger. Fixes #218 --- mcp/client.go | 6 +++--- mcp/server.go | 26 ++++++++++++++++++++++---- mcp/server_test.go | 25 +++++++++++++++++++++++++ mcp/shared.go | 15 ++++++++++++--- mcp/transport.go | 8 ++++++-- 5 files changed, 68 insertions(+), 12 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 74900b1c..199e9c9c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -253,7 +253,7 @@ func (c *Client) capabilities(protocolVersion string) *ClientCapabilities { // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) { - cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) + cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil, nil) if err != nil { return nil, err } @@ -404,7 +404,7 @@ func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await // startKeepalive starts the keepalive mechanism for this client session. func (cs *ClientSession) startKeepalive(interval time.Duration) { - startKeepalive(cs, interval, &cs.keepaliveCancel) + startKeepalive(cs, interval, &cs.keepaliveCancel, nil) } // AddRoots adds the given roots to the client, @@ -441,7 +441,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change } } c.mu.Unlock() - notifySessions(sessions, notification, params, c.opts.Logger) + notifySessions(sessions, notification, params, c.opts.Logger, nil) } // shouldSendListChangedNotification checks if the client's capabilities allow diff --git a/mcp/server.go b/mcp/server.go index e3c03e27..d5d81b8a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -135,6 +135,15 @@ type ServerOptions struct { // trade-offs and usage guidance. SchemaCache *SchemaCache + // ErrorHandler, if non-nil, is called with out-of-band errors that occur + // during server operation but are not associated with a specific request. + // Examples include keepalive ping failures, notification delivery errors, + // and internal JSON-RPC protocol errors. + // + // If nil, these errors are logged using [ServerOptions.Logger] at the + // appropriate level. + ErrorHandler func(error) + // GetSessionID provides the next session ID to use for an incoming request. // If nil, a default randomly generated ID will be used. // @@ -198,6 +207,15 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { } } +// reportError reports an out-of-band error via the ErrorHandler, or logs it. +func (s *Server) reportError(err error) { + if h := s.opts.ErrorHandler; h != nil { + h(err) + } else { + s.opts.Logger.Error("out-of-band error", "error", err) + } +} + // AddPrompt adds a [Prompt] to the server, or replaces one with the same name. func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { // Assume there was a change, since add replaces existing items. @@ -644,7 +662,7 @@ func (s *Server) notifySessions(n string) { sessions := slices.Clone(s.sessions) s.pendingNotifications[n] = nil s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. - notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) + notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger, s.opts.ErrorHandler) } // shouldSendListChangedNotification checks if the server's capabilities allow @@ -873,7 +891,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot subscribedSessions := s.resourceSubscriptions[params.URI] sessions := slices.Collect(maps.Keys(subscribedSessions)) s.mu.Unlock() - notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger) + notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger, s.opts.ErrorHandler) s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) return nil } @@ -1015,7 +1033,7 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp } s.opts.Logger.Info("server connecting") - ss, err := connect(ctx, t, s, state, onClose) + ss, err := connect(ctx, t, s, state, onClose, s.opts.ErrorHandler) if err != nil { s.opts.Logger.Error("server connect error", "error", err) return nil, err @@ -1515,7 +1533,7 @@ func (ss *ServerSession) Wait() error { // startKeepalive starts the keepalive mechanism for this server session. func (ss *ServerSession) startKeepalive(interval time.Duration) { - startKeepalive(ss, interval, &ss.keepaliveCancel) + startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.ErrorHandler) } // pageToken is the internal structure for the opaque pagination cursor. diff --git a/mcp/server_test.go b/mcp/server_test.go index 227e7be3..1fc0b8c1 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "log" "log/slog" "slices" @@ -1001,3 +1002,27 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +func TestErrorHandler(t *testing.T) { + t.Run("reportError calls ErrorHandler", func(t *testing.T) { + var got error + s := NewServer(testImpl, &ServerOptions{ + ErrorHandler: func(err error) { got = err }, + }) + s.reportError(errors.New("test error")) + if got == nil || got.Error() != "test error" { + t.Errorf("ErrorHandler got %v, want 'test error'", got) + } + }) + + t.Run("reportError falls back to logger", func(t *testing.T) { + var buf bytes.Buffer + s := NewServer(testImpl, &ServerOptions{ + Logger: slog.New(slog.NewTextHandler(&buf, nil)), + }) + s.reportError(errors.New("logged error")) + if !strings.Contains(buf.String(), "logged error") { + t.Errorf("log output = %q, want containing 'logged error'", buf.String()) + } + }) +} diff --git a/mcp/shared.go b/mcp/shared.go index bda00c20..cdf3175b 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -391,7 +391,8 @@ const ( // notifySessions calls Notify on all the sessions. // Should be called on a copy of the peer sessions. // The logger must be non-nil. -func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger) { +// If onError is non-nil, it is called for each notification error instead of logging. +func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger, onError func(error)) { if sessions == nil { return } @@ -406,7 +407,11 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P, for _, s := range sessions { req := newRequest(s, params) if err := handleNotify(ctx, method, req); err != nil { - logger.Warn(fmt.Sprintf("calling %s: %v", method, err)) + if onError != nil { + onError(fmt.Errorf("calling %s: %w", method, err)) + } else { + logger.Warn(fmt.Sprintf("calling %s: %v", method, err)) + } } } } @@ -581,7 +586,8 @@ type keepaliveSession interface { // startKeepalive starts the keepalive mechanism for a session. // It assigns the cancel function to the provided cancelPtr and starts a goroutine // that sends ping messages at the specified interval. -func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) { +// If onError is non-nil, it is called when a ping fails before the session is closed. +func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, onError func(error)) { ctx, cancel := context.WithCancel(context.Background()) // Assign cancel function before starting goroutine to avoid race condition. // We cannot return it because the caller may need to cancel during the @@ -601,6 +607,9 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr err := session.Ping(pingCtx, nil) pingCancel() if err != nil { + if onError != nil { + onError(fmt.Errorf("keepalive ping failed: %w", err)) + } // Ping failed, close the session _ = session.Close() return diff --git a/mcp/transport.go b/mcp/transport.go index 5f2a5007..3827a69e 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -152,7 +152,7 @@ type handler interface { handle(ctx context.Context, req *jsonrpc.Request) (any, error) } -func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func(), onError func(error)) (H, error) { var zero H mcpConn, err := t.Connect(ctx) if err != nil { @@ -169,6 +169,10 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } + onInternalError := func(err error) { log.Printf("jsonrpc2 error: %v", err) } + if onError != nil { + onInternalError = func(err error) { onError(fmt.Errorf("jsonrpc2: %w", err)) } + } _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ Reader: reader, Writer: writer, @@ -178,7 +182,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, OnDone: func() { b.disconnect(h) }, - OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, + OnInternalError: onInternalError, }) assert(preempter.conn != nil, "unbound preempter") return h, nil