diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 8c40de08..32f40a10 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1902,16 +1902,20 @@ func TestPointerArgEquivalence(t *testing.T) { type input struct { In string `json:",omitempty"` } + inputSchema := json.RawMessage(`{"type":"object","properties":{"In":{"type":"string"}},"additionalProperties":false}`) type output struct { Out string } + outputSchema := json.RawMessage(`{"type":"object","required":["Out"],"properties":{"Out":{"type":"string"}},"additionalProperties":false}`) cs, _, cleanup := basicConnection(t, func(s *Server) { - // Add two equivalent tools, one of which operates in the 'pointer' realm, - // the other of which does not. + // Add three equivalent tools: + // - one operates on pointers, with inferred schemas + // - one operates on pointers, with user-provided schemas + // - one operates on non-pointers // // We handle a few different types of results, to assert they behave the // same in all cases. - AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) { + handlePointers := func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) { switch in.In { case "": return nil, nil, fmt.Errorf("must provide input") @@ -1924,7 +1928,13 @@ func TestPointerArgEquivalence(t *testing.T) { default: panic("unreachable") } - }) + } + AddTool(s, &Tool{Name: "pointer-inferred"}, handlePointers) + AddTool(s, &Tool{ + Name: "pointer-provided", + InputSchema: inputSchema, + OutputSchema: outputSchema, + }, handlePointers) AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *CallToolRequest, in input) (*CallToolResult, output, error) { switch in.In { case "": @@ -1947,50 +1957,51 @@ func TestPointerArgEquivalence(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := len(tools.Tools), 2; got != want { + if got, want := len(tools.Tools), 3; got != want { t.Fatalf("got %d tools, want %d", got, want) } - t0 := tools.Tools[0] - t1 := tools.Tools[1] - - // First, check that the tool schemas don't differ. - if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" { - t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) - } - if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" { - t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) - } - // Then, check that we handle empty input equivalently. - for _, args := range []any{nil, struct{}{}} { - r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args}) - if err != nil { - t.Fatal(err) - } - r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args}) - if err != nil { - t.Fatal(err) + t0 := tools.Tools[0] + for _, t1 := range tools.Tools[1:] { + // First, check that the tool schemas don't differ. + if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" { + t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) } - if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { - t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) + if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" { + t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) } - } - // Then, check that we handle different types of output equivalently. - for _, in := range []string{"nil", "empty", "ok"} { - t.Run(in, func(t *testing.T) { - r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}}) + // Then, check that we handle empty input equivalently. + for _, args := range []any{nil, struct{}{}} { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args}) if err != nil { t.Fatal(err) } - r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}}) + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args}) if err != nil { t.Fatal(err) } if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { - t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) + t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) } - }) + } + + // Then, check that we handle different types of output equivalently. + for _, in := range []string{"nil", "empty", "ok"} { + t.Run(in, func(t *testing.T) { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { + t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) + } + }) + } } } diff --git a/mcp/server.go b/mcp/server.go index d4317222..6e2be0dd 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -359,21 +359,14 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // Pointers are treated equivalently to non-pointers when deriving the schema. // If an indirection occurred to derive the schema, a non-nil zero value is // returned to be used in place of the typed nil zero value. -// -// Note that if sfield already holds a schema, zero will be nil even if T is a -// pointer: if the user provided the schema, they may have intentionally -// derived it from the pointer type, and handling of zero values is up to them. -// -// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we -// should have a jsonschema.Zero(schema) helper? func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) { + rt := reflect.TypeFor[T]() + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } var internalSchema *jsonschema.Schema if *sfield == nil { - rt := reflect.TypeFor[T]() - if rt.Kind() == reflect.Pointer { - rt = rt.Elem() - zero = reflect.Zero(rt).Interface() - } // TODO: we should be able to pass nil opts here. internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) if err == nil {