Skip to content

Commit b5e8e6f

Browse files
authored
feat(mcp): do mcp tool call from cli args (#1216)
* make json rpc 2.0 call based on cli args - basic sse resposne parsing - remove zero / nil arguments before making the request * move tool request to mcp_request.go * print tool response according to output schema * print structuredContent from JSON RPC response * only register mcp command if SRC_EXPERIMENT_MCP=true * rename ParseToolResponse to DecodeToolResponse
1 parent 7bc3f56 commit b5e8e6f

File tree

3 files changed

+172
-18
lines changed

3 files changed

+172
-18
lines changed

cmd/src/mcp.go

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
package main
22

33
import (
4+
"context"
5+
"encoding/json"
46
"flag"
57
"fmt"
8+
"os"
69
"strings"
710

11+
"github.com/sourcegraph/src-cli/internal/api"
812
"github.com/sourcegraph/src-cli/internal/mcp"
13+
14+
"github.com/sourcegraph/sourcegraph/lib/errors"
915
)
1016

1117
func init() {
12-
flagSet := flag.NewFlagSet("mcp", flag.ExitOnError)
13-
commands = append(commands, &command{
14-
flagSet: flagSet,
15-
handler: mcpMain,
16-
})
18+
if os.Getenv("SRC_EXPERIMENT_MCP") == "true" {
19+
flagSet := flag.NewFlagSet("mcp", flag.ExitOnError)
20+
commands = append(commands, &command{
21+
flagSet: flagSet,
22+
handler: mcpMain,
23+
})
24+
}
1725
}
1826
func mcpMain(args []string) error {
1927
fmt.Println("NOTE: This command is still experimental")
@@ -44,37 +52,73 @@ func mcpMain(args []string) error {
4452
if !ok {
4553
return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd)
4654
}
47-
return handleMcpTool(tool, args[1:])
48-
}
4955

50-
func handleMcpTool(tool *mcp.ToolDef, args []string) error {
51-
fs, vars, err := mcp.BuildArgFlagSet(tool)
56+
flagArgs := args[1:] // skip subcommand name
57+
if len(args) > 1 && args[1] == "schema" {
58+
return printSchemas(tool)
59+
}
60+
61+
flags, vars, err := mcp.BuildArgFlagSet(tool)
5262
if err != nil {
5363
return err
5464
}
65+
if err := flags.Parse(flagArgs); err != nil {
66+
return err
67+
}
68+
mcp.DerefFlagValues(vars)
5569

56-
if err := fs.Parse(args); err != nil {
70+
if err := validateToolArgs(tool.InputSchema, args, vars); err != nil {
5771
return err
5872
}
5973

60-
inputSchema := tool.InputSchema
74+
apiClient := cfg.apiClient(nil, flags.Output())
75+
return handleMcpTool(context.Background(), apiClient, tool, vars)
76+
}
77+
78+
func printSchemas(tool *mcp.ToolDef) error {
79+
input, err := json.MarshalIndent(tool.InputSchema, "", " ")
80+
if err != nil {
81+
return err
82+
}
83+
output, err := json.MarshalIndent(tool.OutputSchema, "", " ")
84+
if err != nil {
85+
return err
86+
}
6187

88+
fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output))
89+
return nil
90+
}
91+
92+
func validateToolArgs(inputSchema mcp.SchemaObject, args []string, vars map[string]any) error {
6293
for _, reqName := range inputSchema.Required {
6394
if vars[reqName] == nil {
64-
return fmt.Errorf("no value provided for required flag --%s", reqName)
95+
return errors.Newf("no value provided for required flag --%s", reqName)
6596
}
6697
}
6798

6899
if len(args) < len(inputSchema.Required) {
69-
return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n"))
100+
return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n"))
70101
}
71102

72-
mcp.DerefFlagValues(vars)
103+
return nil
104+
}
73105

74-
fmt.Println("Flags")
75-
for name, val := range vars {
76-
fmt.Printf("--%s=%v\n", name, val)
106+
func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error {
107+
resp, err := mcp.DoToolRequest(ctx, client, tool, vars)
108+
if err != nil {
109+
return err
77110
}
78111

112+
result, err := mcp.DecodeToolResponse(resp)
113+
if err != nil {
114+
return err
115+
}
116+
defer resp.Body.Close()
117+
118+
output, err := json.MarshalIndent(result, "", " ")
119+
if err != nil {
120+
return err
121+
}
122+
fmt.Println(string(output))
79123
return nil
80124
}

internal/mcp/mcp_args.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,28 @@ func DerefFlagValues(vars map[string]any) {
3232
if slice, ok := vv.(strSliceFlag); ok {
3333
vv = slice.vals
3434
}
35-
vars[k] = vv
35+
if isNil(vv) {
36+
delete(vars, k)
37+
} else {
38+
vars[k] = vv
39+
}
3640
}
3741
}
3842
}
3943

44+
func isNil(v any) bool {
45+
if v == nil {
46+
return true
47+
}
48+
rv := reflect.ValueOf(v)
49+
switch rv.Kind() {
50+
case reflect.Slice, reflect.Map, reflect.Pointer, reflect.Interface:
51+
return rv.IsNil()
52+
default:
53+
return false
54+
}
55+
}
56+
4057
func BuildArgFlagSet(tool *ToolDef) (*flag.FlagSet, map[string]any, error) {
4158
if tool == nil {
4259
return nil, nil, errors.New("cannot build flagset on nil Tool Definition")

internal/mcp/mcp_request.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package mcp
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
10+
"github.com/sourcegraph/src-cli/internal/api"
11+
12+
"github.com/sourcegraph/sourcegraph/lib/errors"
13+
)
14+
15+
const McpURLPath = ".api/mcp/v1"
16+
17+
func DoToolRequest(ctx context.Context, client api.Client, tool *ToolDef, vars map[string]any) (*http.Response, error) {
18+
jsonRPC := struct {
19+
Version string `json:"jsonrpc"`
20+
ID int `json:"id"`
21+
Method string `json:"method"`
22+
Params any `json:"params"`
23+
}{
24+
Version: "2.0",
25+
ID: 1,
26+
Method: "tools/call",
27+
Params: struct {
28+
Name string `json:"name"`
29+
Arguments map[string]any `json:"arguments"`
30+
}{
31+
Name: tool.RawName,
32+
Arguments: vars,
33+
},
34+
}
35+
36+
buf := bytes.NewBuffer(nil)
37+
data, err := json.Marshal(jsonRPC)
38+
if err != nil {
39+
return nil, err
40+
}
41+
buf.Write(data)
42+
43+
req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpURLPath, buf)
44+
if err != nil {
45+
return nil, err
46+
}
47+
req.Header.Add("Content-Type", "application/json")
48+
req.Header.Add("Accept", "*/*")
49+
50+
return client.Do(req)
51+
}
52+
53+
func DecodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) {
54+
data, err := readSSEResponseData(resp)
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
if data == nil {
60+
return map[string]json.RawMessage{}, nil
61+
}
62+
63+
jsonRPCResp := struct {
64+
Version string `json:"jsonrpc"`
65+
ID int `json:"id"`
66+
Result struct {
67+
Content []json.RawMessage `json:"content"`
68+
StructuredContent map[string]json.RawMessage `json:"structuredContent"`
69+
} `json:"result"`
70+
}{}
71+
if err := json.Unmarshal(data, &jsonRPCResp); err != nil {
72+
return nil, errors.Wrapf(err, "failed to unmarshal MCP JSON-RPC response")
73+
}
74+
75+
return jsonRPCResp.Result.StructuredContent, nil
76+
}
77+
func readSSEResponseData(resp *http.Response) ([]byte, error) {
78+
data, err := io.ReadAll(resp.Body)
79+
if err != nil {
80+
return nil, err
81+
}
82+
// The response is an SSE reponse
83+
// event:
84+
// data:
85+
lines := bytes.SplitSeq(data, []byte("\n"))
86+
for line := range lines {
87+
if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
88+
return jsonData, nil
89+
}
90+
}
91+
return nil, errors.New("no data found in SSE response")
92+
93+
}

0 commit comments

Comments
 (0)