Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (c *Client) Initialize(
}
// Add elicitation capability if handler is configured
if c.elicitationHandler != nil {
capabilities.Elicitation = &struct{}{}
capabilities.Elicitation = &mcp.ElicitationCapability{}
}

// Ensure we send a params object with all required fields
Expand Down Expand Up @@ -629,6 +629,10 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request
}
}

if err := params.Validate(); err != nil {
return nil, fmt.Errorf("invalid elicitation params: %w", err)
}

// Create the MCP request
mcpRequest := mcp.ElicitationRequest{
Request: mcp.Request{
Expand Down
4 changes: 2 additions & 2 deletions client/elicitation.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
type ElicitationHandler interface {
// Elicit handles an elicitation request from the server and returns the user's response.
// The implementation should:
// 1. Present the request message to the user
// 2. Validate input against the requested schema
// 1. Present the request message to the user (and URL if in URL mode)
// 2. Validate input against the requested schema (for form mode)
// 3. Allow the user to accept, decline, or cancel
// 4. Return the appropriate response
Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error)
Expand Down
2 changes: 1 addition & 1 deletion client/inprocess_elicitation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func TestInProcessElicitation(t *testing.T) {
Version: "1.0.0",
},
Capabilities: mcp.ClientCapabilities{
Elicitation: &struct{}{},
Elicitation: &mcp.ElicitationCapability{},
},
},
})
Expand Down
102 changes: 99 additions & 3 deletions examples/elicitation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os/signal"
"sync/atomic"

"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
Expand Down Expand Up @@ -129,7 +130,7 @@ func main() {
server.WithElicitation(), // Enable elicitation
)

// Add a tool that uses elicitation
// Add a tool that uses elicitation (Form Mode)
mcpServer.AddTool(
mcp.NewTool(
"create_project",
Expand All @@ -138,7 +139,7 @@ func main() {
demoElicitationHandler(mcpServer),
)

// Add another tool that demonstrates conditional elicitation
// Add another tool that demonstrates conditional elicitation (Form Mode)
mcpServer.AddTool(
mcp.NewTool(
"process_data",
Expand Down Expand Up @@ -236,7 +237,102 @@ func main() {
},
)

// Create and start stdio server
// Add a tool that uses URL elicitation (auth flow)
mcpServer.AddTool(
mcp.NewTool(
"auth_via_url",
mcp.WithDescription("Demonstrates out-of-band authentication via URL"),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
session := server.ClientSessionFromContext(ctx)
if session == nil {
return nil, fmt.Errorf("no active session")
}

// Generate unique elicitation ID
elicitationID := uuid.New().String()

// Create URL with elicitation ID for tracking
// In a real application, you would store the ID and associate it with the user session
url := fmt.Sprintf("https://myserver.com/set-api-key?elicitationId=%s", elicitationID)

// Request URL mode elicitation
result, err := mcpServer.RequestURLElicitation(
ctx,
session,
elicitationID,
url,
"Please authenticate in your browser to continue.",
)
if err != nil {
return nil, fmt.Errorf("URL elicitation failed: %w", err)
}

if result.Action == mcp.ElicitationResponseActionAccept {
// User consented to open the URL
// They will complete the flow in their browser
// Server will store credentials when user submits the form

// Simulate sending completion notification
// NOTE: In production, this notification would be sent after
// the server receives the authentication callback from the browser.
// Here we simulate immediate completion for demonstration purposes.
if err := mcpServer.SendElicitationComplete(ctx, session, elicitationID); err != nil {
// Log error but continue
fmt.Fprintf(os.Stderr, "Failed to send completion notification: %v\n", err)
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent("Authentication flow initiated. User accepted URL open request."),
},
}, nil
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent(fmt.Sprintf("User declined authentication: %s", result.Action)),
},
}, nil
},
)

// Add a tool that demonstrates returning URLElicitationRequiredError
mcpServer.AddTool(
mcp.NewTool(
"protected_action",
mcp.WithDescription("A protected action that requires prior authorization"),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// TODO: In production, check actual authorization state
// For demo purposes, we always trigger elicitation
isAuthorized := false // Always false to demonstrate error flow

if !isAuthorized {
// When a request needs authorization that hasn't been set up
elicitationID := uuid.New().String()

// Return a special error that tells the client to start elicitation
return nil, mcp.URLElicitationRequiredError{
Elicitations: []mcp.ElicitationParams{
{
Mode: mcp.ElicitationModeURL,
ElicitationID: elicitationID,
URL: fmt.Sprintf("https://myserver.com/authorize?id=%s", elicitationID),
Message: "Authorization is required to access this resource.",
},
},
}
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.NewTextContent("Action completed successfully!"),
},
}, nil
},
)

stdioServer := server.NewStdioServer(mcpServer)

// Handle graceful shutdown
Expand Down
3 changes: 3 additions & 0 deletions mcp/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ const (
ContentTypeAudio = "audio"
ContentTypeLink = "resource_link"
ContentTypeResource = "resource"

ElicitationModeForm = "form"
ElicitationModeURL = "url"
)
94 changes: 94 additions & 0 deletions mcp/elicitation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package mcp

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestElicitationParamsSerialization(t *testing.T) {
tests := []struct {
name string
params ElicitationParams
expected string
}{
{
name: "Form Mode Default",
params: ElicitationParams{
Message: "Please enter data",
RequestedSchema: map[string]any{
"type": "string",
},
},
expected: `{"message":"Please enter data","requestedSchema":{"type":"string"}}`,
},
{
name: "Form Mode Explicit",
params: ElicitationParams{
Mode: ElicitationModeForm,
Message: "Please enter data",
RequestedSchema: map[string]any{
"type": "string",
},
},
expected: `{"mode":"form","message":"Please enter data","requestedSchema":{"type":"string"}}`,
},
{
name: "URL Mode",
params: ElicitationParams{
Mode: ElicitationModeURL,
Message: "Please auth",
ElicitationID: "123",
URL: "https://example.com/auth",
},
expected: `{"mode":"url","message":"Please auth","elicitationId":"123","url":"https://example.com/auth"}`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.params)
require.NoError(t, err)
assert.JSONEq(t, tt.expected, string(data))

// Round trip
var decoded ElicitationParams
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)

assert.Equal(t, tt.params.Message, decoded.Message)
assert.Equal(t, tt.params.Mode, decoded.Mode)

if tt.params.Mode == ElicitationModeURL {
assert.Equal(t, tt.params.ElicitationID, decoded.ElicitationID)
assert.Equal(t, tt.params.URL, decoded.URL)
}
})
}
}

func TestElicitationCapabilitySerialization(t *testing.T) {
// Test empty struct for backward compatibility
cap := ElicitationCapability{}
data, err := json.Marshal(cap)
require.NoError(t, err)
assert.JSONEq(t, "{}", string(data))

// Test with Form support
cap = ElicitationCapability{
Form: &struct{}{},
}
data, err = json.Marshal(cap)
require.NoError(t, err)
assert.JSONEq(t, `{"form":{}}`, string(data))

// Test with URL support
cap = ElicitationCapability{
URL: &struct{}{},
}
data, err = json.Marshal(cap)
require.NoError(t, err)
assert.JSONEq(t, `{"url":{}}`, string(data))
}
89 changes: 89 additions & 0 deletions mcp/elicitation_validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package mcp_test

import (
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
)

func TestElicitationParams_Validate(t *testing.T) {
tests := []struct {
name string
params mcp.ElicitationParams
wantErr bool
}{
{
name: "Valid Form Mode",
params: mcp.ElicitationParams{
Mode: mcp.ElicitationModeForm,
Message: "Fill this form",
RequestedSchema: map[string]any{"type": "object"},
},
wantErr: false,
},
{
name: "Valid URL Mode",
params: mcp.ElicitationParams{
Mode: mcp.ElicitationModeURL,
Message: "Click this link",
ElicitationID: "123",
URL: "https://example.com/auth",
},
wantErr: false,
},
{
name: "Implicit Form Form Mode (Default)",
params: mcp.ElicitationParams{
Mode: "",
Message: "Fill this form",
RequestedSchema: map[string]any{"type": "object"},
},
wantErr: false, // Should default to form and validate schema
},
{
name: "Invalid Mode",
params: mcp.ElicitationParams{
Mode: "invalid-mode",
},
wantErr: true,
},
{
name: "Form Mode Missing Schema",
params: mcp.ElicitationParams{
Mode: mcp.ElicitationModeForm,
Message: "Missing schema",
},
wantErr: true,
},
{
name: "URL Mode Missing URL",
params: mcp.ElicitationParams{
Mode: mcp.ElicitationModeURL,
ElicitationID: "123",
Message: "Missing URL",
},
wantErr: true,
},
{
name: "URL Mode Missing ElicitationID",
params: mcp.ElicitationParams{
Mode: mcp.ElicitationModeURL,
URL: "https://example.com",
Message: "Missing ID",
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.params.Validate()
if tt.wantErr {
require.Error(t, err, "expected error for test case: %s", tt.name)
} else {
require.NoError(t, err, "unexpected error for test case: %s", tt.name)
}
})
}
}
Loading