diff --git a/pkg/github/__toolsnaps__/get_file_contents.snap b/pkg/github/__toolsnaps__/get_file_contents.snap index b3975abbc..6ee4b8a8e 100644 --- a/pkg/github/__toolsnaps__/get_file_contents.snap +++ b/pkg/github/__toolsnaps__/get_file_contents.snap @@ -3,15 +3,19 @@ "title": "Get file or directory contents", "readOnlyHint": true }, - "description": "Get the contents of a file or directory from a GitHub repository", + "description": "Get the contents of a file or directory from a GitHub repository. To ensure the file SHA is returned and prevent fallback to raw content, set `allow_raw_fallback` to `false`.", "inputSchema": { "properties": { + "allow_raw_fallback": { + "description": "Whether to allow falling back to getting raw content when the file is too large. When this is false, the file\'s SHA will always be returned.", + "type": "boolean" + }, "owner": { "description": "Repository owner (username or organization)", "type": "string" }, "path": { - "description": "Path to file/directory (directories must end with a slash '/')", + "description": "Path to file/directory (directories must end with a slash \'/\')", "type": "string" }, "ref": { @@ -35,4 +39,4 @@ "type": "object" }, "name": "get_file_contents" -} \ No newline at end of file +} diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index fa5d7338a..8487c867f 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" @@ -445,7 +444,7 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_file_contents", - mcp.WithDescription(t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository")), + mcp.WithDescription(t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository. To ensure the file SHA is returned and prevent fallback to raw content, set `allow_raw_fallback` to `false`.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_GET_FILE_CONTENTS_USER_TITLE", "Get file or directory contents"), ReadOnlyHint: ToBoolPtr(true), @@ -468,6 +467,9 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t mcp.WithString("sha", mcp.Description("Accepts optional git sha, if sha is specified it will be used instead of ref"), ), + mcp.WithBoolean("allow_raw_fallback", + mcp.Description("Whether to allow falling back to getting raw content when the file is too large. When this is false, the file's SHA will always be returned."), + ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := RequiredParam[string](request, "owner") @@ -490,91 +492,66 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t if err != nil { return mcp.NewToolResultError(err.Error()), nil } + allowRawFallback, err := OptionalParam[bool](request, "allow_raw_fallback") + if err != nil { + // If the parameter is not present, default to true + allowRawFallback = true + } - rawOpts := &raw.RawContentOpts{} - - if strings.HasPrefix(ref, "refs/pull/") { - prNumber := strings.TrimSuffix(strings.TrimPrefix(ref, "refs/pull/"), "/head") - if len(prNumber) > 0 { - // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - prNum, err := strconv.Atoi(prNumber) - if err != nil { - return nil, fmt.Errorf("invalid pull request number: %w", err) - } - pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum) - if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) - } - sha = pr.GetHead().GetSHA() - ref = "" - } + var refOrSha string + if sha != "" { + refOrSha = sha + } else { + refOrSha = ref } - rawOpts.SHA = sha - rawOpts.Ref = ref - // If the path is (most likely) not to be a directory, we will first try to get the raw content from the GitHub raw content API. - if path != "" && !strings.HasSuffix(path, "/") { + rawOpts := &raw.RawContentOpts{} + if refOrSha != "" { + rawOpts.Ref = refOrSha + } + // If the path is (most likely) not to be a directory, we will first try to get the raw content from the GitHub raw content API. + if path != "" && !strings.HasSuffix(path, "/") && allowRawFallback { rawClient, err := getRawClient(ctx) if err != nil { return mcp.NewToolResultError("failed to get GitHub raw content client"), nil } resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) if err != nil { - return mcp.NewToolResultError("failed to get raw repository content"), nil - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - // If the raw content is not found, we will fall back to the GitHub API (in case it is a directory) + // Fallback to the GitHub API if there is an error } else { - // If the raw content is found, return it directly - body, err := io.ReadAll(resp.Body) - if err != nil { - return mcp.NewToolResultError("failed to read response body"), nil - } - contentType := resp.Header.Get("Content-Type") + defer func() { + _ = resp.Body.Close() + }() - var resourceURI string - switch { - case sha != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, "sha", sha, "contents", path) + if resp.StatusCode == http.StatusOK { + // If the raw content is found, return it directly + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to create resource URI: %w", err) + return mcp.NewToolResultError("failed to read response body"), nil } - case ref != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, ref, "contents", path) + contentType := resp.Header.Get("Content-Type") + + resourceURI, err := url.JoinPath("repo://", owner, repo, refOrSha, "contents", path) if err != nil { return nil, fmt.Errorf("failed to create resource URI: %w", err) } - default: - resourceURI, err = url.JoinPath("repo://", owner, repo, "contents", path) - if err != nil { - return nil, fmt.Errorf("failed to create resource URI: %w", err) + + if strings.HasPrefix(contentType, "application") || strings.HasPrefix(contentType, "text") { + return mcp.NewToolResultResource("successfully downloaded text file", mcp.TextResourceContents{ + URI: resourceURI, + Text: string(body), + MIMEType: contentType, + }), nil } - } - if strings.HasPrefix(contentType, "application") || strings.HasPrefix(contentType, "text") { - return mcp.NewToolResultResource("successfully downloaded text file", mcp.TextResourceContents{ + return mcp.NewToolResultResource("successfully downloaded binary file", mcp.BlobResourceContents{ URI: resourceURI, - Text: string(body), + Blob: base64.StdEncoding.EncodeToString(body), MIMEType: contentType, }), nil } - - return mcp.NewToolResultResource("successfully downloaded binary file", mcp.BlobResourceContents{ - URI: resourceURI, - Blob: base64.StdEncoding.EncodeToString(body), - MIMEType: contentType, - }), nil - } } @@ -583,35 +560,30 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t return mcp.NewToolResultError("failed to get GitHub client"), nil } - if sha != "" { - ref = sha + opts := &github.RepositoryContentGetOptions{} + if refOrSha != "" { + opts.Ref = refOrSha } - if strings.HasSuffix(path, "/") { - opts := &github.RepositoryContentGetOptions{Ref: ref} - _, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) - if err != nil { - return mcp.NewToolResultError("failed to get file contents"), nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return mcp.NewToolResultError("failed to read response body"), nil - } - return mcp.NewToolResultError(fmt.Sprintf("failed to get file contents: %s", string(body))), nil - } + fileContent, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get file contents", resp, err), nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(dirContent) - if err != nil { - return mcp.NewToolResultError("failed to marshal response"), nil - } - return mcp.NewToolResultText(string(r)), nil + var r []byte + if dirContent != nil { + r, err = json.Marshal(dirContent) + } else { + r, err = json.Marshal(fileContent) + } + if err != nil { + return mcp.NewToolResultError("failed to marshal response"), nil } - return mcp.NewToolResultError("Failed to get file contents. The path does not point to a file or directory, or the file does not exist in the repository."), nil + return mcp.NewToolResultText(string(r)), nil } } + // ForkRepository creates a tool to fork a repository. func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("fork_repository", diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index b621cec43..f51b2cbcd 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -33,6 +33,7 @@ func Test_GetFileContents(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "path") assert.Contains(t, tool.InputSchema.Properties, "ref") assert.Contains(t, tool.InputSchema.Properties, "sha") + assert.Contains(t, tool.InputSchema.Properties, "allow_raw_fallback") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path"}) // Mock response for raw content @@ -218,6 +219,153 @@ func Test_GetFileContents(t *testing.T) { } } +func Test_GetFileContentsWithAllowRawFallback(t *testing.T) { + t.Parallel() + + // Mock content for the file + mockFileContent := "Hello, GitHub Copilot!" + mockFileSHA := "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6e7f8a9b0" + + tests := []struct { + name string + allowFallback *bool + mockResponses []mock.MockBackendOption + expectError bool + expectedSHA string + expectedContent string + }{ + { + name: "allow_raw_fallback is true (default behavior)", + allowFallback: boolPtr(true), + mockResponses: []mock.MockBackendOption{ + // First, attempt to get content with application/vnd.github.v3+json + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Header.Get("Accept"), "application/vnd.github.v3+json") + w.WriteHeader(http.StatusNotFound) // Simulate file not found in tree endpoint + }), + ), + // Then, fallback to raw content + mock.WithRequestMatchHandler( + raw.GetRawReposContentsByOwnerByRepoByBranchByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Header.Get("Accept"), "application/vnd.github.v3.raw") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(mockFileContent)) + }), + ), + }, + expectError: false, + expectedSHA: "", // Raw content does not return SHA + expectedContent: mockFileContent, + }, + { + name: "allow_raw_fallback is false (force SHA retrieval)", + allowFallback: boolPtr(false), + mockResponses: []mock.MockBackendOption{ + // Only attempt to get content with application/vnd.github.v3+json + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Header.Get("Accept"), "application/vnd.github.v3+json") + content := github.RepositoryContent{ + SHA: github.Ptr(mockFileSHA), + Content: github.Ptr(base64.StdEncoding.EncodeToString([]byte(mockFileContent))), + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(content) + }), + ), + }, + expectError: false, + expectedSHA: mockFileSHA, + expectedContent: mockFileContent, + }, + { + name: "allow_raw_fallback is nil (default behavior, fallback allowed)", + allowFallback: nil, + mockResponses: []mock.MockBackendOption{ + // First, attempt to get content with application/vnd.github.v3+json + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Header.Get("Accept"), "application/vnd.github.v3+json") + w.WriteHeader(http.StatusNotFound) // Simulate file not found in tree endpoint + }), + ), + // Then, fallback to raw content + mock.WithRequestMatchHandler( + raw.GetRawReposContentsByOwnerByRepoByBranchByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Header.Get("Accept"), "application/vnd.github.v3.raw") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(mockFileContent)) + }), + ), + }, + expectError: false, + expectedSHA: "", // Raw content does not return SHA + expectedContent: mockFileContent, + }, + { + name: "allow_raw_fallback is false and content fetch fails", + allowFallback: boolPtr(false), + mockResponses: []mock.MockBackendOption{ + // Only attempt to get content with application/vnd.github.v3+json + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Header.Get("Accept"), "application/vnd.github.v3+json") + w.WriteHeader(http.StatusNotFound) // Simulate file not found + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + }, + expectError: true, + expectedContent: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(mock.NewMockedHTTPClient(tc.mockResponses...)) + mockRawClient := raw.NewClient(client, &url.URL{Scheme: "https", Host: "raw.example.com", Path: "/"}) + _, handler := GetFileContents(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper) + + requestArgs := map[string]interface{}{ + "owner": "test-owner", + "repo": "test-repo", + "path": "test.md", + "ref": "main", + } + if tc.allowFallback != nil { + requestArgs["allow_raw_fallback"] = *tc.allowFallback + } + + request := createMCPRequest(requestArgs) + result, err := handler(context.Background(), request) + + if tc.expectError { + require.Error(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getErrorResult(t, result).Text, "Failed to get file contents") + return + } + + require.NoError(t, err) + assert.False(t, result.IsError) + + textContent := getTextResult(t, result) + assert.Equal(t, tc.expectedContent, textContent.Text) + assert.Equal(t, tc.expectedSHA, textContent.URI) // For raw content, URI contains SHA if returned, otherwise empty. This needs careful assertion. + if tc.expectedSHA != "" { + assert.Contains(t, textContent.URI, tc.expectedSHA) + } + }) + } +} + func Test_ForkRepository(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) @@ -2049,3 +2197,8 @@ func Test_GetTag(t *testing.T) { }) } } + +// Helper to get a pointer to a boolean value +func boolPtr(b bool) *bool { + return &b +}