Skip to content

Commit b057c48

Browse files
omgitsadsalmaleksia
andcommitted
Add resource completion for GitHub repository resources
Co-Authored-by: Ksenia Bobrova <almaleksia@github.com>
1 parent 7e19170 commit b057c48

File tree

3 files changed

+717
-9
lines changed

3 files changed

+717
-9
lines changed
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"strings"
8+
9+
"github.com/google/go-github/v79/github"
10+
"github.com/modelcontextprotocol/go-sdk/mcp"
11+
)
12+
13+
// CompleteHandler defines function signature for completion handlers
14+
type CompleteHandler func(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error)
15+
16+
// RepositoryResourceArgumentResolvers is a map of argument names to their completion handlers
17+
var RepositoryResourceArgumentResolvers = map[string]CompleteHandler{
18+
"owner": completeOwner,
19+
"repo": completeRepo,
20+
"branch": completeBranch,
21+
"sha": completeSHA,
22+
"tag": completeTag,
23+
"prNumber": completePRNumber,
24+
"path": completePath,
25+
}
26+
27+
// RepositoryResourceCompletionHandler returns a CompletionHandlerFunc for repository resource completions.
28+
func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) {
29+
return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) {
30+
fmt.Println("Handling repository resource completion for", req.Params.Argument.Name)
31+
32+
if req.Params.Ref.Type != "ref/resource" {
33+
return nil, nil // Not a resource completion
34+
}
35+
36+
argName := req.Params.Argument.Name
37+
argValue := req.Params.Argument.Value
38+
resolved := req.Params.Context.Arguments
39+
if resolved == nil {
40+
resolved = map[string]string{}
41+
}
42+
43+
client, err := getClient(ctx)
44+
if err != nil {
45+
return nil, err
46+
}
47+
48+
// Argument resolver functions
49+
resolvers := RepositoryResourceArgumentResolvers
50+
51+
resolver, ok := resolvers[argName]
52+
if !ok {
53+
return nil, errors.New("no resolver for argument: " + argName)
54+
}
55+
56+
values, err := resolver(ctx, client, resolved, argValue)
57+
if err != nil {
58+
return nil, err
59+
}
60+
if len(values) > 100 {
61+
values = values[:100]
62+
}
63+
64+
return &mcp.CompleteResult{
65+
Completion: mcp.CompletionResultDetails{
66+
Values: values,
67+
Total: len(values),
68+
HasMore: false,
69+
},
70+
}, nil
71+
}
72+
}
73+
74+
// --- Per-argument resolver functions ---
75+
76+
func completeOwner(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
77+
var values []string
78+
user, _, err := client.Users.Get(ctx, "")
79+
fmt.Printf("Found user: %v\n", err)
80+
if err == nil && user.GetLogin() != "" {
81+
values = append(values, user.GetLogin())
82+
fmt.Println("Fetching organizations for user " + user.GetLogin())
83+
}
84+
85+
orgs, _, err := client.Organizations.List(ctx, "", &github.ListOptions{PerPage: 100})
86+
if err != nil {
87+
return nil, err
88+
}
89+
for _, org := range orgs {
90+
fmt.Println("Found organization: " + org.GetLogin())
91+
values = append(values, org.GetLogin())
92+
}
93+
94+
// filter values based on argValue and replace values slice
95+
if argValue != "" {
96+
var filteredValues []string
97+
for _, value := range values {
98+
if strings.Contains(value, argValue) {
99+
filteredValues = append(filteredValues, value)
100+
}
101+
}
102+
values = filteredValues
103+
}
104+
if len(values) > 100 {
105+
values = values[:100]
106+
return values, nil // Limit to 100 results
107+
}
108+
// Else also do a client.Search.Users()
109+
if argValue == "" {
110+
return values, nil // No need to search if no argValue
111+
}
112+
users, _, err := client.Search.Users(ctx, argValue, &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100 - len(values)}})
113+
if err != nil || users == nil {
114+
return nil, err
115+
}
116+
for _, user := range users.Users {
117+
values = append(values, user.GetLogin())
118+
}
119+
120+
if len(values) > 100 {
121+
values = values[:100]
122+
}
123+
return values, nil
124+
}
125+
126+
func completeRepo(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
127+
var values []string
128+
owner := resolved["owner"]
129+
if owner == "" {
130+
return values, errors.New("owner not specified")
131+
}
132+
133+
query := fmt.Sprintf("org:%s", owner)
134+
135+
if argValue != "" {
136+
query = fmt.Sprintf("%s %s", query, argValue)
137+
}
138+
repos, _, err := client.Search.Repositories(ctx, query, &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}})
139+
if err != nil || repos == nil {
140+
return values, errors.New("failed to get repositories")
141+
}
142+
// filter repos based on argValue
143+
for _, repo := range repos.Repositories {
144+
name := repo.GetName()
145+
if argValue == "" || strings.HasPrefix(name, argValue) {
146+
values = append(values, name)
147+
}
148+
}
149+
150+
return values, nil
151+
}
152+
153+
func completeBranch(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
154+
var values []string
155+
owner := resolved["owner"]
156+
repo := resolved["repo"]
157+
if owner == "" || repo == "" {
158+
return values, errors.New("owner or repo not specified")
159+
}
160+
branches, _, _ := client.Repositories.ListBranches(ctx, owner, repo, nil)
161+
162+
for _, branch := range branches {
163+
if argValue == "" || strings.HasPrefix(branch.GetName(), argValue) {
164+
values = append(values, branch.GetName())
165+
}
166+
}
167+
if len(values) > 100 {
168+
values = values[:100]
169+
}
170+
return values, nil
171+
}
172+
173+
func completeSHA(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
174+
var values []string
175+
owner := resolved["owner"]
176+
repo := resolved["repo"]
177+
if owner == "" || repo == "" {
178+
return values, errors.New("owner or repo not specified")
179+
}
180+
commits, _, _ := client.Repositories.ListCommits(ctx, owner, repo, nil)
181+
182+
for _, commit := range commits {
183+
sha := commit.GetSHA()
184+
if argValue == "" || strings.HasPrefix(sha, argValue) {
185+
values = append(values, sha)
186+
}
187+
}
188+
if len(values) > 100 {
189+
values = values[:100]
190+
}
191+
return values, nil
192+
}
193+
194+
func completeTag(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
195+
owner := resolved["owner"]
196+
repo := resolved["repo"]
197+
if owner == "" || repo == "" {
198+
return nil, errors.New("owner or repo not specified")
199+
}
200+
tags, _, _ := client.Repositories.ListTags(ctx, owner, repo, nil)
201+
var values []string
202+
for _, tag := range tags {
203+
if argValue == "" || strings.Contains(tag.GetName(), argValue) {
204+
values = append(values, tag.GetName())
205+
}
206+
}
207+
if len(values) > 100 {
208+
values = values[:100]
209+
}
210+
return values, nil
211+
}
212+
213+
func completePRNumber(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
214+
var values []string
215+
owner := resolved["owner"]
216+
repo := resolved["repo"]
217+
if owner == "" || repo == "" {
218+
return values, errors.New("owner or repo not specified")
219+
}
220+
221+
prs, _, err := client.Search.Issues(ctx, fmt.Sprintf("repo:%s/%s is:open is:pr", owner, repo), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}})
222+
if err != nil {
223+
return values, err
224+
}
225+
for _, pr := range prs.Issues {
226+
num := fmt.Sprintf("%d", pr.GetNumber())
227+
if argValue == "" || strings.HasPrefix(num, argValue) {
228+
values = append(values, num)
229+
}
230+
}
231+
if len(values) > 100 {
232+
values = values[:100]
233+
}
234+
return values, nil
235+
}
236+
237+
func completePath(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
238+
owner := resolved["owner"]
239+
repo := resolved["repo"]
240+
if owner == "" || repo == "" {
241+
return nil, errors.New("owner or repo not specified")
242+
}
243+
refVal := resolved["branch"]
244+
if refVal == "" {
245+
refVal = resolved["sha"]
246+
}
247+
if refVal == "" {
248+
refVal = resolved["tag"]
249+
}
250+
if refVal == "" {
251+
refVal = "HEAD"
252+
}
253+
254+
// Determine the prefix to complete (directory path or file path)
255+
prefix := argValue
256+
if prefix != "" && !strings.HasSuffix(prefix, "/") {
257+
lastSlash := strings.LastIndex(prefix, "/")
258+
if lastSlash >= 0 {
259+
prefix = prefix[:lastSlash+1]
260+
} else {
261+
prefix = ""
262+
}
263+
}
264+
265+
// Get the tree for the ref (recursive)
266+
tree, _, err := client.Git.GetTree(ctx, owner, repo, refVal, true)
267+
if err != nil || tree == nil {
268+
return nil, errors.New("failed to get file tree")
269+
}
270+
271+
// Collect immediate children of the prefix (files and directories, no duplicates)
272+
dirs := map[string]struct{}{}
273+
files := map[string]struct{}{}
274+
prefixLen := len(prefix)
275+
for _, entry := range tree.Entries {
276+
if !strings.HasPrefix(entry.GetPath(), prefix) {
277+
continue
278+
}
279+
rel := entry.GetPath()[prefixLen:]
280+
if rel == "" {
281+
continue
282+
}
283+
// Only immediate children
284+
slashIdx := strings.Index(rel, "/")
285+
if slashIdx >= 0 {
286+
// Directory: only add the directory name (with trailing slash), prefixed with full path
287+
dirName := prefix + rel[:slashIdx+1]
288+
dirs[dirName] = struct{}{}
289+
} else if entry.GetType() == "blob" {
290+
// File: add as-is, prefixed with full path
291+
fileName := prefix + rel
292+
files[fileName] = struct{}{}
293+
}
294+
}
295+
296+
// Optionally filter by argValue (if user is typing after last slash)
297+
var filter string
298+
if argValue != "" {
299+
if lastSlash := strings.LastIndex(argValue, "/"); lastSlash >= 0 {
300+
filter = argValue[lastSlash+1:]
301+
} else {
302+
filter = argValue
303+
}
304+
}
305+
306+
var values []string
307+
// Add directories first, then files, both filtered
308+
for dir := range dirs {
309+
// Only filter on the last segment after the last slash
310+
if filter == "" {
311+
values = append(values, dir)
312+
} else {
313+
last := dir
314+
if idx := strings.LastIndex(strings.TrimRight(dir, "/"), "/"); idx >= 0 {
315+
last = dir[idx+1:]
316+
}
317+
if strings.HasPrefix(last, filter) {
318+
values = append(values, dir)
319+
}
320+
}
321+
}
322+
for file := range files {
323+
if filter == "" {
324+
values = append(values, file)
325+
} else {
326+
last := file
327+
if idx := strings.LastIndex(file, "/"); idx >= 0 {
328+
last = file[idx+1:]
329+
}
330+
if strings.HasPrefix(last, filter) {
331+
values = append(values, file)
332+
}
333+
}
334+
}
335+
336+
if len(values) > 100 {
337+
values = values[:100]
338+
}
339+
return values, nil
340+
}

0 commit comments

Comments
 (0)