Skip to content

Commit d9263f7

Browse files
markus-kusanoMarkus Kusano
andauthored
mcp: support elicitation defaults (#644)
Fixes #617 --------- Co-authored-by: Markus Kusano <kusano@google.com>
1 parent 1943780 commit d9263f7

File tree

3 files changed

+177
-9
lines changed

3 files changed

+177
-9
lines changed

mcp/client.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,13 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult,
321321
if err != nil {
322322
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
323323
}
324-
325324
if err := resolved.Validate(res.Content); err != nil {
326325
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
327326
}
327+
err = resolved.ApplyDefaults(&res.Content)
328+
if err != nil {
329+
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err))
330+
}
328331
}
329332

330333
return res, nil
@@ -341,6 +344,9 @@ func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) {
341344
if err := remarshal(wireSchema, &schema); err != nil {
342345
return nil, err
343346
}
347+
if schema == nil {
348+
return nil, nil
349+
}
344350

345351
// The root schema must be of type "object" if specified
346352
if schema.Type != "" && schema.Type != "object" {
@@ -369,7 +375,6 @@ func validateElicitProperty(propName string, propSchema *jsonschema.Schema) erro
369375
if len(propSchema.Properties) > 0 {
370376
return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName)
371377
}
372-
373378
// Validate based on the property type - only primitives are supported
374379
switch propSchema.Type {
375380
case "string":
@@ -439,7 +444,7 @@ func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema
439444
}
440445
}
441446

442-
return nil
447+
return validateDefaultProperty[string](propName, propSchema)
443448
}
444449

445450
// validateElicitNumberProperty validates number and integer-type properties.
@@ -450,19 +455,28 @@ func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema
450455
}
451456
}
452457

458+
intDefaultError := validateDefaultProperty[int](propName, propSchema)
459+
floatDefaultError := validateDefaultProperty[float64](propName, propSchema)
460+
if intDefaultError != nil && floatDefaultError != nil {
461+
return fmt.Errorf("elicit schema property %q has default value that cannot be interpreted as an int or float", propName)
462+
}
463+
453464
return nil
454465
}
455466

456467
// validateElicitBooleanProperty validates boolean-type properties.
457468
func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error {
458-
// Validate default value if specified - must be a valid boolean
469+
return validateDefaultProperty[bool](propName, propSchema)
470+
}
471+
472+
func validateDefaultProperty[T any](propName string, propSchema *jsonschema.Schema) error {
473+
// Validate default value if specified - must be a valid T
459474
if propSchema.Default != nil {
460-
var defaultValue bool
475+
var defaultValue T
461476
if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil {
462-
return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err)
477+
return fmt.Errorf("elicit schema property %q has invalid default value, must be a %T: %v", propName, defaultValue, err)
463478
}
464479
}
465-
466480
return nil
467481
}
468482

mcp/mcp_test.go

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,37 @@ func TestElicitationSchemaValidation(t *testing.T) {
12541254
"enabled": {Type: "boolean", Default: json.RawMessage(`"not-a-boolean"`)},
12551255
},
12561256
},
1257-
expectedError: "elicit schema property \"enabled\" has invalid default value, must be a boolean",
1257+
expectedError: "elicit schema property \"enabled\" has invalid default value, must be a bool",
1258+
},
1259+
{
1260+
name: "string with invalid default",
1261+
schema: &jsonschema.Schema{
1262+
Type: "object",
1263+
Properties: map[string]*jsonschema.Schema{
1264+
"enabled": {Type: "string", Default: json.RawMessage("true")},
1265+
},
1266+
},
1267+
expectedError: "elicit schema property \"enabled\" has invalid default value, must be a string",
1268+
},
1269+
{
1270+
name: "integer with invalid default",
1271+
schema: &jsonschema.Schema{
1272+
Type: "object",
1273+
Properties: map[string]*jsonschema.Schema{
1274+
"enabled": {Type: "integer", Default: json.RawMessage("true")},
1275+
},
1276+
},
1277+
expectedError: "elicit schema property \"enabled\" has default value that cannot be interpreted as an int or float",
1278+
},
1279+
{
1280+
name: "number with invalid default",
1281+
schema: &jsonschema.Schema{
1282+
Type: "object",
1283+
Properties: map[string]*jsonschema.Schema{
1284+
"enabled": {Type: "number", Default: json.RawMessage("true")},
1285+
},
1286+
},
1287+
expectedError: "elicit schema property \"enabled\" has default value that cannot be interpreted as an int or float",
12581288
},
12591289
{
12601290
name: "enum with mismatched enumNames length",
@@ -1459,6 +1489,100 @@ func TestElicitationCapabilityDeclaration(t *testing.T) {
14591489
})
14601490
}
14611491

1492+
func TestElicitationDefaultValues(t *testing.T) {
1493+
ctx := context.Background()
1494+
ct, st := NewInMemoryTransports()
1495+
1496+
s := NewServer(testImpl, nil)
1497+
ss, err := s.Connect(ctx, st, nil)
1498+
if err != nil {
1499+
t.Fatal(err)
1500+
}
1501+
defer ss.Close()
1502+
1503+
c := NewClient(testImpl, &ClientOptions{
1504+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
1505+
return &ElicitResult{Action: "accept", Content: map[string]any{"default": "response"}}, nil
1506+
},
1507+
})
1508+
cs, err := c.Connect(ctx, ct, nil)
1509+
if err != nil {
1510+
t.Fatal(err)
1511+
}
1512+
defer cs.Close()
1513+
1514+
testcases := []struct {
1515+
name string
1516+
schema *jsonschema.Schema
1517+
expected map[string]any
1518+
}{
1519+
{
1520+
name: "boolean with default",
1521+
schema: &jsonschema.Schema{
1522+
Type: "object",
1523+
Properties: map[string]*jsonschema.Schema{
1524+
"key": {Type: "boolean", Default: json.RawMessage("true")},
1525+
},
1526+
},
1527+
expected: map[string]any{"key": true, "default": "response"},
1528+
},
1529+
{
1530+
name: "string with default",
1531+
schema: &jsonschema.Schema{
1532+
Type: "object",
1533+
Properties: map[string]*jsonschema.Schema{
1534+
"key": {Type: "string", Default: json.RawMessage("\"potato\"")},
1535+
},
1536+
},
1537+
expected: map[string]any{"key": "potato", "default": "response"},
1538+
},
1539+
{
1540+
name: "integer with default",
1541+
schema: &jsonschema.Schema{
1542+
Type: "object",
1543+
Properties: map[string]*jsonschema.Schema{
1544+
"key": {Type: "integer", Default: json.RawMessage("123")},
1545+
},
1546+
},
1547+
expected: map[string]any{"key": float64(123), "default": "response"},
1548+
},
1549+
{
1550+
name: "number with default",
1551+
schema: &jsonschema.Schema{
1552+
Type: "object",
1553+
Properties: map[string]*jsonschema.Schema{
1554+
"key": {Type: "number", Default: json.RawMessage("89.7")},
1555+
},
1556+
},
1557+
expected: map[string]any{"key": float64(89.7), "default": "response"},
1558+
},
1559+
{
1560+
name: "enum with default",
1561+
schema: &jsonschema.Schema{
1562+
Type: "object",
1563+
Properties: map[string]*jsonschema.Schema{
1564+
"key": {Type: "string", Enum: []any{"one", "two"}, Default: json.RawMessage("\"one\"")},
1565+
},
1566+
},
1567+
expected: map[string]any{"key": "one", "default": "response"},
1568+
},
1569+
}
1570+
for _, tc := range testcases {
1571+
t.Run(tc.name, func(t *testing.T) {
1572+
res, err := ss.Elicit(ctx, &ElicitParams{
1573+
Message: "Test schema with defaults: " + tc.name,
1574+
RequestedSchema: tc.schema,
1575+
})
1576+
if err != nil {
1577+
t.Fatalf("expected no error for default schema %q, got: %v", tc.name, err)
1578+
}
1579+
if diff := cmp.Diff(tc.expected, res.Content); diff != "" {
1580+
t.Errorf("%s: did not get expected value, -want +got:\n%s", tc.name, diff)
1581+
}
1582+
})
1583+
}
1584+
}
1585+
14621586
func TestKeepAliveFailure(t *testing.T) {
14631587
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
14641588
defer cancel()

mcp/server.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,37 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli
10181018
if err := ss.checkInitialized(methodElicit); err != nil {
10191019
return nil, err
10201020
}
1021-
return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))
1021+
1022+
res, err := handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))
1023+
if err != nil {
1024+
return nil, err
1025+
}
1026+
1027+
if params.RequestedSchema == nil {
1028+
return res, nil
1029+
}
1030+
schema, err := validateElicitSchema(params.RequestedSchema)
1031+
if err != nil {
1032+
return nil, err
1033+
}
1034+
if schema == nil {
1035+
return res, nil
1036+
}
1037+
1038+
resolved, err := schema.Resolve(nil)
1039+
if err != nil {
1040+
fmt.Printf(" resolve err: %s", err)
1041+
return nil, err
1042+
}
1043+
if err := resolved.Validate(res.Content); err != nil {
1044+
return nil, fmt.Errorf("elicitation result content does not match requested schema: %v", err)
1045+
}
1046+
err = resolved.ApplyDefaults(&res.Content)
1047+
if err != nil {
1048+
return nil, fmt.Errorf("failed to apply schema defalts to elicitation result: %v", err)
1049+
}
1050+
1051+
return res, nil
10221052
}
10231053

10241054
// Log sends a log message to the client.

0 commit comments

Comments
 (0)