diff --git a/client/client.go b/client/client.go index 1d75218f0..6ebb8a935 100644 --- a/client/client.go +++ b/client/client.go @@ -199,7 +199,7 @@ func (c *Client) Initialize( } // Add elicitation capability if handler is configured if c.elicitationHandler != nil { - capabilities.Elicitation = &struct{}{} + capabilities.Elicitation = &mcp.ElicitationCapability{} } // Ensure we send a params object with all required fields @@ -629,6 +629,10 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request } } + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("invalid elicitation params: %w", err) + } + // Create the MCP request mcpRequest := mcp.ElicitationRequest{ Request: mcp.Request{ diff --git a/client/elicitation.go b/client/elicitation.go index 92f519bf9..514878728 100644 --- a/client/elicitation.go +++ b/client/elicitation.go @@ -11,8 +11,8 @@ import ( type ElicitationHandler interface { // Elicit handles an elicitation request from the server and returns the user's response. // The implementation should: - // 1. Present the request message to the user - // 2. Validate input against the requested schema + // 1. Present the request message to the user (and URL if in URL mode) + // 2. Validate input against the requested schema (for form mode) // 3. Allow the user to accept, decline, or cancel // 4. Return the appropriate response Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) diff --git a/client/inprocess_elicitation_test.go b/client/inprocess_elicitation_test.go index f659bbb10..9f83a1d1e 100644 --- a/client/inprocess_elicitation_test.go +++ b/client/inprocess_elicitation_test.go @@ -124,7 +124,7 @@ func TestInProcessElicitation(t *testing.T) { Version: "1.0.0", }, Capabilities: mcp.ClientCapabilities{ - Elicitation: &struct{}{}, + Elicitation: &mcp.ElicitationCapability{}, }, }, }) diff --git a/examples/elicitation/main.go b/examples/elicitation/main.go index 742d036ad..5075f6ceb 100644 --- a/examples/elicitation/main.go +++ b/examples/elicitation/main.go @@ -8,6 +8,7 @@ import ( "os/signal" "sync/atomic" + "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) @@ -129,7 +130,7 @@ func main() { server.WithElicitation(), // Enable elicitation ) - // Add a tool that uses elicitation + // Add a tool that uses elicitation (Form Mode) mcpServer.AddTool( mcp.NewTool( "create_project", @@ -138,7 +139,7 @@ func main() { demoElicitationHandler(mcpServer), ) - // Add another tool that demonstrates conditional elicitation + // Add another tool that demonstrates conditional elicitation (Form Mode) mcpServer.AddTool( mcp.NewTool( "process_data", @@ -236,7 +237,102 @@ func main() { }, ) - // Create and start stdio server + // Add a tool that uses URL elicitation (auth flow) + mcpServer.AddTool( + mcp.NewTool( + "auth_via_url", + mcp.WithDescription("Demonstrates out-of-band authentication via URL"), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Generate unique elicitation ID + elicitationID := uuid.New().String() + + // Create URL with elicitation ID for tracking + // In a real application, you would store the ID and associate it with the user session + url := fmt.Sprintf("https://myserver.com/set-api-key?elicitationId=%s", elicitationID) + + // Request URL mode elicitation + result, err := mcpServer.RequestURLElicitation( + ctx, + session, + elicitationID, + url, + "Please authenticate in your browser to continue.", + ) + if err != nil { + return nil, fmt.Errorf("URL elicitation failed: %w", err) + } + + if result.Action == mcp.ElicitationResponseActionAccept { + // User consented to open the URL + // They will complete the flow in their browser + // Server will store credentials when user submits the form + + // Simulate sending completion notification + // NOTE: In production, this notification would be sent after + // the server receives the authentication callback from the browser. + // Here we simulate immediate completion for demonstration purposes. + if err := mcpServer.SendElicitationComplete(ctx, session, elicitationID); err != nil { + // Log error but continue + fmt.Fprintf(os.Stderr, "Failed to send completion notification: %v\n", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Authentication flow initiated. User accepted URL open request."), + }, + }, nil + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("User declined authentication: %s", result.Action)), + }, + }, nil + }, + ) + + // Add a tool that demonstrates returning URLElicitationRequiredError + mcpServer.AddTool( + mcp.NewTool( + "protected_action", + mcp.WithDescription("A protected action that requires prior authorization"), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // TODO: In production, check actual authorization state + // For demo purposes, we always trigger elicitation + isAuthorized := false // Always false to demonstrate error flow + + if !isAuthorized { + // When a request needs authorization that hasn't been set up + elicitationID := uuid.New().String() + + // Return a special error that tells the client to start elicitation + return nil, mcp.URLElicitationRequiredError{ + Elicitations: []mcp.ElicitationParams{ + { + Mode: mcp.ElicitationModeURL, + ElicitationID: elicitationID, + URL: fmt.Sprintf("https://myserver.com/authorize?id=%s", elicitationID), + Message: "Authorization is required to access this resource.", + }, + }, + } + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Action completed successfully!"), + }, + }, nil + }, + ) + stdioServer := server.NewStdioServer(mcpServer) // Handle graceful shutdown diff --git a/mcp/consts.go b/mcp/consts.go index 66eb3803b..058619c15 100644 --- a/mcp/consts.go +++ b/mcp/consts.go @@ -6,4 +6,7 @@ const ( ContentTypeAudio = "audio" ContentTypeLink = "resource_link" ContentTypeResource = "resource" + + ElicitationModeForm = "form" + ElicitationModeURL = "url" ) diff --git a/mcp/elicitation_test.go b/mcp/elicitation_test.go new file mode 100644 index 000000000..49e1b700e --- /dev/null +++ b/mcp/elicitation_test.go @@ -0,0 +1,94 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestElicitationParamsSerialization(t *testing.T) { + tests := []struct { + name string + params ElicitationParams + expected string + }{ + { + name: "Form Mode Default", + params: ElicitationParams{ + Message: "Please enter data", + RequestedSchema: map[string]any{ + "type": "string", + }, + }, + expected: `{"message":"Please enter data","requestedSchema":{"type":"string"}}`, + }, + { + name: "Form Mode Explicit", + params: ElicitationParams{ + Mode: ElicitationModeForm, + Message: "Please enter data", + RequestedSchema: map[string]any{ + "type": "string", + }, + }, + expected: `{"mode":"form","message":"Please enter data","requestedSchema":{"type":"string"}}`, + }, + { + name: "URL Mode", + params: ElicitationParams{ + Mode: ElicitationModeURL, + Message: "Please auth", + ElicitationID: "123", + URL: "https://example.com/auth", + }, + expected: `{"mode":"url","message":"Please auth","elicitationId":"123","url":"https://example.com/auth"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.params) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(data)) + + // Round trip + var decoded ElicitationParams + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, tt.params.Message, decoded.Message) + assert.Equal(t, tt.params.Mode, decoded.Mode) + + if tt.params.Mode == ElicitationModeURL { + assert.Equal(t, tt.params.ElicitationID, decoded.ElicitationID) + assert.Equal(t, tt.params.URL, decoded.URL) + } + }) + } +} + +func TestElicitationCapabilitySerialization(t *testing.T) { + // Test empty struct for backward compatibility + cap := ElicitationCapability{} + data, err := json.Marshal(cap) + require.NoError(t, err) + assert.JSONEq(t, "{}", string(data)) + + // Test with Form support + cap = ElicitationCapability{ + Form: &struct{}{}, + } + data, err = json.Marshal(cap) + require.NoError(t, err) + assert.JSONEq(t, `{"form":{}}`, string(data)) + + // Test with URL support + cap = ElicitationCapability{ + URL: &struct{}{}, + } + data, err = json.Marshal(cap) + require.NoError(t, err) + assert.JSONEq(t, `{"url":{}}`, string(data)) +} diff --git a/mcp/elicitation_validation_test.go b/mcp/elicitation_validation_test.go new file mode 100644 index 000000000..254e03025 --- /dev/null +++ b/mcp/elicitation_validation_test.go @@ -0,0 +1,89 @@ +package mcp_test + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +func TestElicitationParams_Validate(t *testing.T) { + tests := []struct { + name string + params mcp.ElicitationParams + wantErr bool + }{ + { + name: "Valid Form Mode", + params: mcp.ElicitationParams{ + Mode: mcp.ElicitationModeForm, + Message: "Fill this form", + RequestedSchema: map[string]any{"type": "object"}, + }, + wantErr: false, + }, + { + name: "Valid URL Mode", + params: mcp.ElicitationParams{ + Mode: mcp.ElicitationModeURL, + Message: "Click this link", + ElicitationID: "123", + URL: "https://example.com/auth", + }, + wantErr: false, + }, + { + name: "Implicit Form Form Mode (Default)", + params: mcp.ElicitationParams{ + Mode: "", + Message: "Fill this form", + RequestedSchema: map[string]any{"type": "object"}, + }, + wantErr: false, // Should default to form and validate schema + }, + { + name: "Invalid Mode", + params: mcp.ElicitationParams{ + Mode: "invalid-mode", + }, + wantErr: true, + }, + { + name: "Form Mode Missing Schema", + params: mcp.ElicitationParams{ + Mode: mcp.ElicitationModeForm, + Message: "Missing schema", + }, + wantErr: true, + }, + { + name: "URL Mode Missing URL", + params: mcp.ElicitationParams{ + Mode: mcp.ElicitationModeURL, + ElicitationID: "123", + Message: "Missing URL", + }, + wantErr: true, + }, + { + name: "URL Mode Missing ElicitationID", + params: mcp.ElicitationParams{ + Mode: mcp.ElicitationModeURL, + URL: "https://example.com", + Message: "Missing ID", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.params.Validate() + if tt.wantErr { + require.Error(t, err, "expected error for test case: %s", tt.name) + } else { + require.NoError(t, err, "unexpected error for test case: %s", tt.name) + } + }) + } +} diff --git a/mcp/errors.go b/mcp/errors.go index aead24744..de39aabbe 100644 --- a/mcp/errors.go +++ b/mcp/errors.go @@ -1,6 +1,7 @@ package mcp import ( + "encoding/json" "errors" "fmt" ) @@ -27,8 +28,31 @@ var ( // ErrResourceNotFound indicates a requested resource was not found (code: RESOURCE_NOT_FOUND). ErrResourceNotFound = errors.New("resource not found") + ) +// URLElicitationRequiredError is returned when the server requires URL elicitation to proceed. +type URLElicitationRequiredError struct { + Elicitations []ElicitationParams `json:"elicitations"` +} + +func (e URLElicitationRequiredError) Error() string { + return fmt.Sprintf("URL elicitation required: %d elicitation(s) needed", len(e.Elicitations)) +} + +func (e URLElicitationRequiredError) JSONRPCError() JSONRPCError { + return JSONRPCError{ + JSONRPC: JSONRPC_VERSION, + Error: JSONRPCErrorDetails{ + Code: URL_ELICITATION_REQUIRED, + Message: e.Error(), + Data: map[string]any{ + "elicitations": e.Elicitations, + }, + }, + } +} + // UnsupportedProtocolVersionError is returned when the server responds with // a protocol version that the client doesn't support. type UnsupportedProtocolVersionError struct { @@ -39,6 +63,12 @@ func (e UnsupportedProtocolVersionError) Error() string { return fmt.Sprintf("unsupported protocol version: %q", e.Version) } +// Is implements the errors.Is interface for better error handling +func (e URLElicitationRequiredError) Is(target error) bool { + _, ok := target.(URLElicitationRequiredError) + return ok +} + // Is implements the errors.Is interface for better error handling func (e UnsupportedProtocolVersionError) Is(target error) bool { _, ok := target.(UnsupportedProtocolVersionError) @@ -72,6 +102,24 @@ func (e *JSONRPCErrorDetails) AsError() error { err = ErrRequestInterrupted case RESOURCE_NOT_FOUND: err = ErrResourceNotFound + case URL_ELICITATION_REQUIRED: + // Attempt to reconstruct URLElicitationRequiredError from Data + if e.Data != nil { + // Round-trip through JSON to parse into struct + // This handles both map[string]any (from unmarshal) and other forms + if dataBytes, marshalErr := json.Marshal(e.Data); marshalErr == nil { + var data struct { + Elicitations []ElicitationParams `json:"elicitations"` + } + if unmarshalErr := json.Unmarshal(dataBytes, &data); unmarshalErr == nil { + return URLElicitationRequiredError{ + Elicitations: data.Elicitations, + } + } + } + } + // Fallback if data is missing or invalid + return URLElicitationRequiredError{} default: return errors.New(e.Message) } diff --git a/mcp/errors_test.go b/mcp/errors_test.go index 00ce4dc53..cab413745 100644 --- a/mcp/errors_test.go +++ b/mcp/errors_test.go @@ -168,3 +168,65 @@ func TestErrorChaining(t *testing.T) { // But the original error should require.True(t, errors.Is(err, ErrMethodNotFound)) } + +func TestURLElicitationRequiredError(t *testing.T) { + t.Parallel() + + err := URLElicitationRequiredError{ + Elicitations: []ElicitationParams{ + { + Mode: ElicitationModeURL, + ElicitationID: "123", + URL: "https://example.com/auth", + Message: "Auth required", + }, + }, + } + + // Test Error() string + expectedMsg := "URL elicitation required: 1 elicitation(s) needed" + require.Equal(t, expectedMsg, err.Error()) + + // Test JSONRPCError conversion + jsonRPCError := err.JSONRPCError() + require.Equal(t, URL_ELICITATION_REQUIRED, jsonRPCError.Error.Code) + require.Equal(t, expectedMsg, jsonRPCError.Error.Message) + + dataMap, ok := jsonRPCError.Error.Data.(map[string]any) + require.True(t, ok, "Expected Data to be map[string]any") + + elicitations, ok := dataMap["elicitations"].([]ElicitationParams) + require.True(t, ok, "Expected elicitations in Data") + + require.Equal(t, 1, len(elicitations)) + require.Equal(t, "123", elicitations[0].ElicitationID) +} + +func TestJSONRPCErrorDetails_AsError_URLElicitationRequired(t *testing.T) { + t.Parallel() + + elicitations := []ElicitationParams{ + { + Mode: ElicitationModeURL, + ElicitationID: "123", + URL: "https://example.com/auth", + }, + } + + details := &JSONRPCErrorDetails{ + Code: URL_ELICITATION_REQUIRED, + Message: "URL elicitation required...", + Data: map[string]any{ + "elicitations": elicitations, + }, + } + + err := details.AsError() + require.Error(t, err) + + var urlErr URLElicitationRequiredError + require.True(t, errors.As(err, &urlErr), "Expected error to be URLElicitationRequiredError") + require.Equal(t, 1, len(urlErr.Elicitations)) + require.Equal(t, "123", urlErr.Elicitations[0].ElicitationID) + require.Equal(t, "https://example.com/auth", urlErr.Elicitations[0].URL) +} diff --git a/mcp/types.go b/mcp/types.go index ab3436057..80881c449 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -58,6 +58,9 @@ const ( // MethodElicitationCreate requests additional information from the user during interactions. // https://modelcontextprotocol.io/docs/concepts/elicitation MethodElicitationCreate MCPMethod = "elicitation/create" + + // MethodNotificationElicitationComplete notifies when a URL mode elicitation completes. + MethodNotificationElicitationComplete MCPMethod = "notifications/elicitation/complete" // MethodListRoots requests roots list from the client during interactions. // https://modelcontextprotocol.io/specification/2025-06-18/client/roots @@ -415,8 +418,11 @@ const ( // MCP error codes const ( - // RESOURCE_NOT_FOUND indicates a requested resource was not found. + // RESOURCE_NOT_FOUND indicates that the requested resource was not found. RESOURCE_NOT_FOUND = -32002 + + // URL_ELICITATION_REQUIRED is the error code for when URL elicitation is required. + URL_ELICITATION_REQUIRED = -32042 ) /* Empty result */ @@ -510,7 +516,7 @@ type ClientCapabilities struct { // Present if the client supports sampling from an LLM. Sampling *struct{} `json:"sampling,omitempty"` // Present if the client supports elicitation requests from the server. - Elicitation *struct{} `json:"elicitation,omitempty"` + Elicitation *ElicitationCapability `json:"elicitation,omitempty"` // Present if the client supports task-based execution. Tasks *TasksCapability `json:"tasks,omitempty"` } @@ -544,7 +550,7 @@ type ServerCapabilities struct { ListChanged bool `json:"listChanged,omitempty"` } `json:"tools,omitempty"` // Present if the server supports elicitation requests to the client. - Elicitation *struct{} `json:"elicitation,omitempty"` + Elicitation *ElicitationCapability `json:"elicitation,omitempty"` // Present if the server supports roots requests to the client. Roots *struct{} `json:"roots,omitempty"` // Present if the server supports task-based execution. @@ -910,10 +916,49 @@ type ElicitationRequest struct { // ElicitationParams contains the parameters for an elicitation request. type ElicitationParams struct { + Meta *Meta `json:"_meta,omitempty"` + // Mode specifies the type of elicitation: "form" or "url". Defaults to "form". + Mode string `json:"mode,omitempty"` // A human-readable message explaining what information is being requested and why. Message string `json:"message"` + + // Form mode fields + // A JSON Schema defining the expected structure of the user's response. - RequestedSchema any `json:"requestedSchema"` + RequestedSchema any `json:"requestedSchema,omitempty"` + + // URL mode fields + + // ElicitationID is a unique identifier for the elicitation request. + ElicitationID string `json:"elicitationId,omitempty"` + // URL is the URL to be opened by the user. + URL string `json:"url,omitempty"` +} + +// Validate checks if the elicitation parameters are valid. +func (p ElicitationParams) Validate() error { + mode := p.Mode + if mode == "" { + mode = ElicitationModeForm + } + + switch mode { + case ElicitationModeForm: + if p.RequestedSchema == nil { + return fmt.Errorf("requestedSchema is required for form elicitation") + } + case ElicitationModeURL: + if p.ElicitationID == "" { + return fmt.Errorf("elicitationId is required for url elicitation") + } + if p.URL == "" { + return fmt.Errorf("url is required for url elicitation") + } + default: + return fmt.Errorf("invalid elicitation mode: %s", mode) + } + + return nil } // ElicitationResult represents the result of an elicitation request. @@ -1460,3 +1505,25 @@ func UnmarshalContent(data []byte) (Content, error) { return nil, fmt.Errorf("unknown content type: %s", contentType) } } + +// ElicitationCapability represents the elicitation capabilities of a client or server. +type ElicitationCapability struct { + Form *struct{} `json:"form,omitempty"` // Supports form mode + URL *struct{} `json:"url,omitempty"` // Supports URL mode +} + +// NewElicitationCompleteNotification creates a new elicitation complete notification. +func NewElicitationCompleteNotification(elicitationID string) JSONRPCNotification { + return JSONRPCNotification{ + JSONRPC: JSONRPC_VERSION, + Notification: Notification{ + Method: string(MethodNotificationElicitationComplete), + Params: NotificationParams{ + AdditionalFields: map[string]any{ + "elicitationId": elicitationID, + }, + }, + }, + } +} + diff --git a/server/elicitation.go b/server/elicitation.go index d3e6d3d4c..728d8ad5b 100644 --- a/server/elicitation.go +++ b/server/elicitation.go @@ -25,8 +25,63 @@ func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.Elicitat // Check if the session supports elicitation requests if elicitationSession, ok := session.(SessionWithElicitation); ok { + if err := request.Params.Validate(); err != nil { + return nil, err + } return elicitationSession.RequestElicitation(ctx, request) } return nil, ErrElicitationNotSupported } + +// RequestURLElicitation sends a URL mode elicitation request to the client. +// This is used when the server needs the user to perform an out-of-band interaction. +func (s *MCPServer) RequestURLElicitation( + ctx context.Context, + session ClientSession, + elicitationID string, + url string, + message string, +) (*mcp.ElicitationResult, error) { + if session == nil { + return nil, ErrNoActiveSession + } + + params := mcp.ElicitationParams{ + Mode: mcp.ElicitationModeURL, + Message: message, + ElicitationID: elicitationID, + URL: url, + } + + if err := params.Validate(); err != nil { + return nil, err + } + + request := mcp.ElicitationRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodElicitationCreate), + }, + Params: params, + } + + if elicitationSession, ok := session.(SessionWithElicitation); ok { + return elicitationSession.RequestElicitation(ctx, request) + } + return nil, ErrElicitationNotSupported +} + +// SendElicitationComplete sends a notification that a URL mode elicitation has completed +// SendElicitationComplete sends a notification that a URL mode elicitation has completed +func (s *MCPServer) SendElicitationComplete( + ctx context.Context, + session ClientSession, + elicitationID string, +) error { + if session == nil { + return ErrNoActiveSession + } + + jsonRPCNotif := mcp.NewElicitationCompleteNotification(elicitationID) + return s.sendNotificationCore(ctx, session, jsonRPCNotif) +} diff --git a/server/elicitation_test.go b/server/elicitation_test.go index 5356b3c25..05dc3e55d 100644 --- a/server/elicitation_test.go +++ b/server/elicitation_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" @@ -31,9 +32,11 @@ func (m *mockBasicSession) Initialized() bool { // mockElicitationSession implements SessionWithElicitation for testing type mockElicitationSession struct { - sessionID string - result *mcp.ElicitationResult - err error + sessionID string + result *mcp.ElicitationResult + err error + lastRequest mcp.ElicitationRequest + notifyChan chan mcp.JSONRPCNotification } func (m *mockElicitationSession) SessionID() string { @@ -41,7 +44,11 @@ func (m *mockElicitationSession) SessionID() string { } func (m *mockElicitationSession) NotificationChannel() chan<- mcp.JSONRPCNotification { - return make(chan mcp.JSONRPCNotification, 1) + if m.notifyChan == nil { + // Buffer of 100 to avoid blocking during tests with multiple notifications + m.notifyChan = make(chan mcp.JSONRPCNotification, 100) + } + return m.notifyChan } func (m *mockElicitationSession) Initialize() {} @@ -51,6 +58,7 @@ func (m *mockElicitationSession) Initialized() bool { } func (m *mockElicitationSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + m.lastRequest = request if m.err != nil { return nil, m.err } @@ -72,13 +80,8 @@ func TestMCPServer_RequestElicitation_NoSession(t *testing.T) { _, err := server.RequestElicitation(context.Background(), request) - if err == nil { - t.Error("expected error when no session available") - } - - if !errors.Is(err, ErrNoActiveSession) { - t.Errorf("expected ErrNoActiveSession, got %v", err) - } + require.Error(t, err) + assert.True(t, errors.Is(err, ErrNoActiveSession), "expected ErrNoActiveSession, got %v", err) } func TestMCPServer_RequestElicitation_SessionDoesNotSupportElicitation(t *testing.T) { @@ -101,13 +104,8 @@ func TestMCPServer_RequestElicitation_SessionDoesNotSupportElicitation(t *testin _, err := server.RequestElicitation(ctx, request) - if err == nil { - t.Error("expected error when session doesn't support elicitation") - } - - if !errors.Is(err, ErrElicitationNotSupported) { - t.Errorf("expected ErrElicitationNotSupported, got %v", err) - } + require.Error(t, err) + assert.True(t, errors.Is(err, ErrElicitationNotSupported), "expected ErrElicitationNotSupported, got %v", err) } func TestMCPServer_RequestElicitation_Success(t *testing.T) { @@ -146,28 +144,13 @@ func TestMCPServer_RequestElicitation_Success(t *testing.T) { result, err := server.RequestElicitation(ctx, request) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if result == nil { - t.Error("expected result, got nil") - return - } - - if result.Action != mcp.ElicitationResponseActionAccept { - t.Errorf("expected response type %q, got %q", mcp.ElicitationResponseActionAccept, result.Action) - } + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, mcp.ElicitationResponseActionAccept, result.Action) value, ok := result.Content.(map[string]any) - if !ok { - t.Error("expected value to be a map") - return - } - - if value["projectName"] != "my-project" { - t.Errorf("expected projectName %q, got %q", "my-project", value["projectName"]) - } + require.True(t, ok, "expected value to be a map") + assert.Equal(t, "my-project", value["projectName"]) } func TestRequestElicitation(t *testing.T) { @@ -226,7 +209,7 @@ func TestRequestElicitation(t *testing.T) { }, { name: "session does not support elicitation", - session: &fakeSession{sessionID: "test-3"}, + session: &mockBasicSession{sessionID: "test-3"}, request: mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Need info", @@ -255,8 +238,74 @@ func TestRequestElicitation(t *testing.T) { assert.Equal(t, tt.expectedType, result.Action) if tt.expectedType == mcp.ElicitationResponseActionAccept { - assert.NotNil(t, result.Action) + assert.NotNil(t, result.Content) } }) } } + +func TestRequestURLElicitation(t *testing.T) { + s := NewMCPServer("test", "1.0", WithElicitation()) + + mockSession := &mockElicitationSession{ + sessionID: "test-url-1", + result: &mcp.ElicitationResult{ + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionAccept, + }, + }, + } + + ctx := context.Background() + _, err := s.RequestURLElicitation(ctx, mockSession, "id-123", "https://example.com/auth", "Please auth") + require.NoError(t, err) + + assert.Equal(t, mcp.ElicitationModeURL, mockSession.lastRequest.Params.Mode) + assert.Equal(t, "id-123", mockSession.lastRequest.Params.ElicitationID) + assert.Equal(t, "https://example.com/auth", mockSession.lastRequest.Params.URL) + + notifyChan := make(chan mcp.JSONRPCNotification, 1) + mockSessionWithChan := &mockElicitationSession{ + sessionID: "test-url-2", + notifyChan: notifyChan, + } + + err = s.SendElicitationComplete(ctx, mockSessionWithChan, "id-123") + require.NoError(t, err) + + select { + case notif := <-notifyChan: + assert.Equal(t, "notifications/elicitation/complete", notif.Method) + // Validate elicitationId is included in params + elicitationID, ok := notif.Params.AdditionalFields["elicitationId"] + assert.True(t, ok, "expected elicitationId in notification params") + assert.Equal(t, "id-123", elicitationID) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } +} + +func TestSendElicitationComplete_NoPriorRequest(t *testing.T) { + s := NewMCPServer("test", "1.0", WithElicitation()) + + notifyChan := make(chan mcp.JSONRPCNotification, 1) + mockSession := &mockElicitationSession{ + sessionID: "test-session-complete", + notifyChan: notifyChan, + } + + // Call SendElicitationComplete directly without any prior request state + // This verifies the server can send completion notifications independently + err := s.SendElicitationComplete(context.Background(), mockSession, "independent-id-999") + require.NoError(t, err) + + select { + case notif := <-notifyChan: + assert.Equal(t, "notifications/elicitation/complete", notif.Method) + elicitationID, ok := notif.Params.AdditionalFields["elicitationId"] + assert.True(t, ok, "expected elicitationId in notification params") + assert.Equal(t, "independent-id-999", elicitationID) + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected notification was not received") + } +} diff --git a/server/server.go b/server/server.go index bb56386bd..6a8a0b055 100644 --- a/server/server.go +++ b/server/server.go @@ -736,7 +736,7 @@ func (s *MCPServer) handleInitialize( } if s.capabilities.elicitation != nil && *s.capabilities.elicitation { - capabilities.Elicitation = &struct{}{} + capabilities.Elicitation = &mcp.ElicitationCapability{} } if s.capabilities.roots != nil && *s.capabilities.roots {