Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (c *Client) capabilities(protocolVersion string) *ClientCapabilities {
// server, calls or notifications will return an error wrapping
// [ErrConnectionClosed].
func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) {
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -404,7 +404,7 @@ func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await

// startKeepalive starts the keepalive mechanism for this client session.
func (cs *ClientSession) startKeepalive(interval time.Duration) {
startKeepalive(cs, interval, &cs.keepaliveCancel)
startKeepalive(cs, interval, &cs.keepaliveCancel, nil)
}

// AddRoots adds the given roots to the client,
Expand Down Expand Up @@ -441,7 +441,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change
}
}
c.mu.Unlock()
notifySessions(sessions, notification, params, c.opts.Logger)
notifySessions(sessions, notification, params, c.opts.Logger, nil)
}

// shouldSendListChangedNotification checks if the client's capabilities allow
Expand Down
26 changes: 22 additions & 4 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ type ServerOptions struct {
// trade-offs and usage guidance.
SchemaCache *SchemaCache

// ErrorHandler, if non-nil, is called with out-of-band errors that occur
// during server operation but are not associated with a specific request.
// Examples include keepalive ping failures, notification delivery errors,
// and internal JSON-RPC protocol errors.
//
// If nil, these errors are logged using [ServerOptions.Logger] at the
// appropriate level.
ErrorHandler func(error)

// GetSessionID provides the next session ID to use for an incoming request.
// If nil, a default randomly generated ID will be used.
//
Expand Down Expand Up @@ -198,6 +207,15 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
}
}

// reportError reports an out-of-band error via the ErrorHandler, or logs it.
func (s *Server) reportError(err error) {
if h := s.opts.ErrorHandler; h != nil {
h(err)
} else {
s.opts.Logger.Error("out-of-band error", "error", err)
}
}

// AddPrompt adds a [Prompt] to the server, or replaces one with the same name.
func (s *Server) AddPrompt(p *Prompt, h PromptHandler) {
// Assume there was a change, since add replaces existing items.
Expand Down Expand Up @@ -644,7 +662,7 @@ func (s *Server) notifySessions(n string) {
sessions := slices.Clone(s.sessions)
s.pendingNotifications[n] = nil
s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock.
notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger)
notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger, s.opts.ErrorHandler)
}

// shouldSendListChangedNotification checks if the server's capabilities allow
Expand Down Expand Up @@ -873,7 +891,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot
subscribedSessions := s.resourceSubscriptions[params.URI]
sessions := slices.Collect(maps.Keys(subscribedSessions))
s.mu.Unlock()
notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger)
notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger, s.opts.ErrorHandler)
s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions))
return nil
}
Expand Down Expand Up @@ -1015,7 +1033,7 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp
}

s.opts.Logger.Info("server connecting")
ss, err := connect(ctx, t, s, state, onClose)
ss, err := connect(ctx, t, s, state, onClose, s.opts.ErrorHandler)
if err != nil {
s.opts.Logger.Error("server connect error", "error", err)
return nil, err
Expand Down Expand Up @@ -1515,7 +1533,7 @@ func (ss *ServerSession) Wait() error {

// startKeepalive starts the keepalive mechanism for this server session.
func (ss *ServerSession) startKeepalive(interval time.Duration) {
startKeepalive(ss, interval, &ss.keepaliveCancel)
startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.ErrorHandler)
}

// pageToken is the internal structure for the opaque pagination cursor.
Expand Down
25 changes: 25 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"log"
"log/slog"
"slices"
Expand Down Expand Up @@ -1001,3 +1002,27 @@ func TestServerCapabilitiesOverWire(t *testing.T) {
})
}
}

func TestErrorHandler(t *testing.T) {
t.Run("reportError calls ErrorHandler", func(t *testing.T) {
var got error
s := NewServer(testImpl, &ServerOptions{
ErrorHandler: func(err error) { got = err },
})
s.reportError(errors.New("test error"))
if got == nil || got.Error() != "test error" {
t.Errorf("ErrorHandler got %v, want 'test error'", got)
}
})

t.Run("reportError falls back to logger", func(t *testing.T) {
var buf bytes.Buffer
s := NewServer(testImpl, &ServerOptions{
Logger: slog.New(slog.NewTextHandler(&buf, nil)),
})
s.reportError(errors.New("logged error"))
if !strings.Contains(buf.String(), "logged error") {
t.Errorf("log output = %q, want containing 'logged error'", buf.String())
}
})
}
15 changes: 12 additions & 3 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ const (
// notifySessions calls Notify on all the sessions.
// Should be called on a copy of the peer sessions.
// The logger must be non-nil.
func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger) {
// If onError is non-nil, it is called for each notification error instead of logging.
func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger, onError func(error)) {
if sessions == nil {
return
}
Expand All @@ -406,7 +407,11 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P,
for _, s := range sessions {
req := newRequest(s, params)
if err := handleNotify(ctx, method, req); err != nil {
logger.Warn(fmt.Sprintf("calling %s: %v", method, err))
if onError != nil {
onError(fmt.Errorf("calling %s: %w", method, err))
} else {
logger.Warn(fmt.Sprintf("calling %s: %v", method, err))
}
}
}
}
Expand Down Expand Up @@ -581,7 +586,8 @@ type keepaliveSession interface {
// startKeepalive starts the keepalive mechanism for a session.
// It assigns the cancel function to the provided cancelPtr and starts a goroutine
// that sends ping messages at the specified interval.
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) {
// If onError is non-nil, it is called when a ping fails before the session is closed.
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, onError func(error)) {
ctx, cancel := context.WithCancel(context.Background())
// Assign cancel function before starting goroutine to avoid race condition.
// We cannot return it because the caller may need to cancel during the
Expand All @@ -601,6 +607,9 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr
err := session.Ping(pingCtx, nil)
pingCancel()
if err != nil {
if onError != nil {
onError(fmt.Errorf("keepalive ping failed: %w", err))
}
// Ping failed, close the session
_ = session.Close()
return
Expand Down
8 changes: 6 additions & 2 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ type handler interface {
handle(ctx context.Context, req *jsonrpc.Request) (any, error)
}

func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) {
func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func(), onError func(error)) (H, error) {
var zero H
mcpConn, err := t.Connect(ctx)
if err != nil {
Expand All @@ -169,6 +169,10 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
preempter.conn = conn
return jsonrpc2.HandlerFunc(h.handle)
}
onInternalError := func(err error) { log.Printf("jsonrpc2 error: %v", err) }
if onError != nil {
onInternalError = func(err error) { onError(fmt.Errorf("jsonrpc2: %w", err)) }
}
_ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{
Reader: reader,
Writer: writer,
Expand All @@ -178,7 +182,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H,
OnDone: func() {
b.disconnect(h)
},
OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) },
OnInternalError: onInternalError,
})
assert(preempter.conn != nil, "unbound preempter")
return h, nil
Expand Down