diff --git a/README.md b/README.md index dc063f22c..8997790a5 100644 --- a/README.md +++ b/README.md @@ -1092,6 +1092,14 @@ The following sets of tools are available: - `repo`: Repository name (string, required) - `title`: PR title (string, required) +- **get_pull_request_review_threads_batch** - Get batch pull request review threads + - **Required OAuth Scopes**: `repo` + - `afterByPullNumber`: Optional per-PR cursor map keyed by stringified pull request number. Each value should be the endCursor returned for that pull request in a previous batch response. (object, optional) + - `owner`: Repository owner (string, required) + - `perPage`: Review threads per pull request page (min 1, max 100) (integer, optional) + - `pullNumbers`: Explicit pull request numbers to hydrate. Accepts up to 20 items. (integer[], required) + - `repo`: Repository name (string, required) + - **list_pull_requests** - List pull requests - **Required OAuth Scopes**: `repo` - `base`: Filter by base branch (string, optional) diff --git a/pkg/github/__toolsnaps__/get_pull_request_review_threads_batch.snap b/pkg/github/__toolsnaps__/get_pull_request_review_threads_batch.snap new file mode 100644 index 000000000..2ea78546f --- /dev/null +++ b/pkg/github/__toolsnaps__/get_pull_request_review_threads_batch.snap @@ -0,0 +1,49 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "Get batch pull request review threads" + }, + "description": "Get review threads for an explicit list of pull requests in a GitHub repository. Returns partial success with per-PR errors and supports per-PR cursors via afterByPullNumber.", + "inputSchema": { + "properties": { + "afterByPullNumber": { + "additionalProperties": { + "type": "string" + }, + "description": "Optional per-PR cursor map keyed by stringified pull request number. Each value should be the endCursor returned for that pull request in a previous batch response.", + "type": "object" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "perPage": { + "description": "Review threads per pull request page (min 1, max 100)", + "maximum": 100, + "minimum": 1, + "type": "integer" + }, + "pullNumbers": { + "description": "Explicit pull request numbers to hydrate. Accepts up to 20 items.", + "items": { + "minimum": 1, + "type": "integer" + }, + "maxItems": 20, + "minItems": 1, + "type": "array" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumbers" + ], + "type": "object" + }, + "name": "get_pull_request_review_threads_batch" +} \ No newline at end of file diff --git a/pkg/github/pullrequests_batch_review_threads.go b/pkg/github/pullrequests_batch_review_threads.go new file mode 100644 index 000000000..e90492811 --- /dev/null +++ b/pkg/github/pullrequests_batch_review_threads.go @@ -0,0 +1,228 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/github/github-mcp-server/pkg/ifc" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/github/github-mcp-server/pkg/scopes" +) + +const maxPullRequestReviewThreadsBatchSize = 20 + +type batchPullRequestReviewThreadsItem struct { + PullNumber int `json:"pull_number"` + ReviewThreads MinimalReviewThreadsResponse `json:"review_threads"` +} + +type batchPullRequestReviewThreadsError struct { + PullNumber int `json:"pull_number"` + Message string `json:"message"` +} + +type batchPullRequestReviewThreadsResponse struct { + Results []batchPullRequestReviewThreadsItem `json:"results"` + Errors []batchPullRequestReviewThreadsError `json:"errors,omitempty"` +} + +func GetPullRequestReviewThreadsBatch(t translations.TranslationHelperFunc) inventory.ServerTool { + schema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "pullNumbers": { + Type: "array", + Description: fmt.Sprintf("Explicit pull request numbers to hydrate. Accepts up to %d items.", maxPullRequestReviewThreadsBatchSize), + MinItems: jsonschema.Ptr(1), + MaxItems: jsonschema.Ptr(maxPullRequestReviewThreadsBatchSize), + Items: &jsonschema.Schema{ + Type: "integer", + Minimum: jsonschema.Ptr(1.0), + }, + }, + "perPage": { + Type: "integer", + Description: "Review threads per pull request page (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + "afterByPullNumber": { + Type: "object", + Description: "Optional per-PR cursor map keyed by stringified pull request number. Each value should be the endCursor returned for that pull request in a previous batch response.", + AdditionalProperties: &jsonschema.Schema{Type: "string"}, + }, + }, + Required: []string{"owner", "repo", "pullNumbers"}, + } + + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ + Name: "get_pull_request_review_threads_batch", + Description: t("TOOL_GET_PULL_REQUEST_REVIEW_THREADS_BATCH_DESCRIPTION", "Get review threads for an explicit list of pull requests in a GitHub repository. Returns partial success with per-PR errors and supports per-PR cursors via afterByPullNumber."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_PULL_REQUEST_REVIEW_THREADS_BATCH_USER_TITLE", "Get batch pull request review threads"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + []scopes.Scope{scopes.Repo}, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumbers, err := requiredReviewThreadBatchPullNumbers(args, "pullNumbers", maxPullRequestReviewThreadsBatchSize) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + basePagination := CursorPaginationParams{PerPage: perPage} + afterByPullNumber, err := optionalAfterByPullNumberParam(args, "afterByPullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } + + result := batchPullRequestReviewThreadsResponse{ + Results: make([]batchPullRequestReviewThreadsItem, 0, len(pullNumbers)), + Errors: make([]batchPullRequestReviewThreadsError, 0), + } + + for _, pullNumber := range pullNumbers { + pagination := basePagination + if cursor, ok := afterByPullNumber[pullNumber]; ok { + pagination.After = cursor + } + + toolResult, err := GetPullRequestReviewComments(ctx, gqlClient, deps, owner, repo, pullNumber, pagination) + if err != nil { + return utils.NewToolResultErrorFromErr(fmt.Sprintf("failed to get review threads for pull request %d", pullNumber), err), nil, nil + } + if toolResult == nil { + result.Errors = append(result.Errors, batchPullRequestReviewThreadsError{PullNumber: pullNumber, Message: "failed to get pull request review threads"}) + continue + } + if toolResult.IsError { + result.Errors = append(result.Errors, batchPullRequestReviewThreadsError{PullNumber: pullNumber, Message: getCallToolText(toolResult)}) + continue + } + + var reviewThreads MinimalReviewThreadsResponse + if err := json.Unmarshal([]byte(getCallToolText(toolResult)), &reviewThreads); err != nil { + result.Errors = append(result.Errors, batchPullRequestReviewThreadsError{PullNumber: pullNumber, Message: fmt.Sprintf("failed to decode review thread response: %v", err)}) + continue + } + + result.Results = append(result.Results, batchPullRequestReviewThreadsItem{ + PullNumber: pullNumber, + ReviewThreads: reviewThreads, + }) + } + + return attachRepoVisibilityIFCLabelLazy(ctx, deps, owner, repo, MarshalledTextResult(result), ifc.LabelListIssues), nil, nil + }, + ) +} + +func requiredReviewThreadBatchPullNumbers(args map[string]any, key string, maxItems int) ([]int, error) { + raw, ok := args[key] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s", key) + } + + values, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("parameter %s could not be coerced to []int, is %T", key, raw) + } + if len(values) == 0 { + return nil, fmt.Errorf("parameter %s must contain at least one pull request number", key) + } + if len(values) > maxItems { + return nil, fmt.Errorf("parameter %s exceeds the maximum batch size of %d", key, maxItems) + } + + pullNumbers := make([]int, 0, len(values)) + seen := make(map[int]struct{}, len(values)) + for i, value := range values { + number, ok := value.(float64) + if !ok { + return nil, fmt.Errorf("parameter %s element %d is not a number, is %T", key, i, value) + } + if number < 1 || number != float64(int(number)) { + return nil, fmt.Errorf("parameter %s element %d must be a positive integer", key, i) + } + intNumber := int(number) + if _, ok := seen[intNumber]; ok { + continue + } + seen[intNumber] = struct{}{} + pullNumbers = append(pullNumbers, intNumber) + } + + return pullNumbers, nil +} + +func optionalAfterByPullNumberParam(args map[string]any, key string) (map[int]string, error) { + raw, ok := args[key] + if !ok || raw == nil { + return map[int]string{}, nil + } + + values, ok := raw.(map[string]any) + if !ok { + return nil, fmt.Errorf("parameter %s could not be coerced to map[string]string, is %T", key, raw) + } + + result := make(map[int]string, len(values)) + for pullNumber, cursorValue := range values { + cursor, ok := cursorValue.(string) + if !ok { + return nil, fmt.Errorf("parameter %s[%s] is not a string, is %T", key, pullNumber, cursorValue) + } + parsedPullNumber, err := strconv.Atoi(pullNumber) + if err != nil || parsedPullNumber < 1 { + return nil, fmt.Errorf("parameter %s contains invalid pull request key %q", key, pullNumber) + } + result[parsedPullNumber] = cursor + } + + return result, nil +} + +func getCallToolText(result *mcp.CallToolResult) string { + if result == nil || len(result.Content) == 0 { + return "" + } + text, ok := result.Content[0].(*mcp.TextContent) + if !ok { + return "" + } + return text.Text +} diff --git a/pkg/github/pullrequests_batch_review_threads_test.go b/pkg/github/pullrequests_batch_review_threads_test.go new file mode 100644 index 000000000..ad47cbd74 --- /dev/null +++ b/pkg/github/pullrequests_batch_review_threads_test.go @@ -0,0 +1,292 @@ +package github + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/github/github-mcp-server/internal/githubv4mock" + "github.com/github/github-mcp-server/internal/toolsnaps" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/jsonschema-go/jsonschema" + "github.com/shurcooL/githubv4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetPullRequestReviewThreadsBatch(t *testing.T) { + serverTool := GetPullRequestReviewThreadsBatch(translations.NullTranslationHelper) + tool := serverTool.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "get_pull_request_review_threads_batch", tool.Name) + schema := tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "pullNumbers") + assert.Contains(t, schema.Properties, "perPage") + assert.Contains(t, schema.Properties, "afterByPullNumber") + assert.ElementsMatch(t, schema.Required, []string{"owner", "repo", "pullNumbers"}) + + tests := []struct { + name string + gqlHTTPClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + validateResult func(t *testing.T, textContent string) + }{ + { + name: "successful batch review thread fetch with per-pr cursor forwarding", + gqlHTTPClient: newBatchReviewThreadsHTTPClient(t, + map[int]MinimalReviewThreadsResponse{ + 42: { + ReviewThreads: []MinimalReviewThread{{ID: "RT-42", TotalCount: 1, Comments: []MinimalReviewComment{{Body: "Looks good", Path: "file1.go", Author: "reviewer1", HTMLURL: "https://github.com/owner/repo/pull/42#discussion_r42"}}}}, + TotalCount: 1, + PageInfo: MinimalPageInfo{HasNextPage: false, EndCursor: "cursor-42-next"}, + }, + 18: { + ReviewThreads: []MinimalReviewThread{{ID: "RT-18", TotalCount: 1, Comments: []MinimalReviewComment{{Body: "Needs update", Path: "file2.go", Author: "reviewer2", HTMLURL: "https://github.com/owner/repo/pull/18#discussion_r18"}}}}, + TotalCount: 1, + PageInfo: MinimalPageInfo{HasNextPage: true, EndCursor: "cursor-18-next"}, + }, + }, + map[int]string{42: "", 18: "cursor-18-prev"}, + nil, + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(42), float64(18)}, + "afterByPullNumber": map[string]any{ + "18": "cursor-18-prev", + }, + }, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestReviewThreadsResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.Results, 2) + assert.Empty(t, result.Errors) + assert.Equal(t, 42, result.Results[0].PullNumber) + assert.Equal(t, "RT-42", result.Results[0].ReviewThreads.ReviewThreads[0].ID) + assert.Equal(t, 18, result.Results[1].PullNumber) + assert.Equal(t, "cursor-18-next", result.Results[1].ReviewThreads.PageInfo.EndCursor) + }, + }, + { + name: "partial GraphQL failures become per-pr errors", + gqlHTTPClient: newBatchReviewThreadsHTTPClient(t, + map[int]MinimalReviewThreadsResponse{ + 42: { + ReviewThreads: []MinimalReviewThread{{ID: "RT-42", TotalCount: 1, Comments: []MinimalReviewComment{{Body: "Looks good", Path: "file1.go", Author: "reviewer1", HTMLURL: "https://github.com/owner/repo/pull/42#discussion_r42"}}}}, + TotalCount: 1, + PageInfo: MinimalPageInfo{}, + }, + }, + map[int]string{42: "", 999: ""}, + map[int]string{999: "Could not resolve to a PullRequest with the number of 999."}, + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(42), float64(999)}, + }, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestReviewThreadsResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.Results, 1) + assert.Equal(t, 42, result.Results[0].PullNumber) + assert.Len(t, result.Errors, 1) + assert.Equal(t, 999, result.Errors[0].PullNumber) + assert.Contains(t, result.Errors[0].Message, "failed to get pull request review threads") + }, + }, + { + name: "duplicate pull numbers are deduplicated before hydration", + gqlHTTPClient: newBatchReviewThreadsHTTPClient(t, + map[int]MinimalReviewThreadsResponse{ + 42: { + ReviewThreads: []MinimalReviewThread{{ID: "RT-42", TotalCount: 1, Comments: []MinimalReviewComment{{Body: "Looks good", Path: "file1.go", Author: "reviewer1", HTMLURL: "https://github.com/owner/repo/pull/42#discussion_r42"}}}}, + TotalCount: 1, + PageInfo: MinimalPageInfo{}, + }, + }, + map[int]string{42: ""}, + nil, + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(42), float64(42), float64(42)}, + }, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestReviewThreadsResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.Results, 1) + assert.Empty(t, result.Errors) + assert.Equal(t, 42, result.Results[0].PullNumber) + }, + }, + { + name: "empty pullNumbers fails validation", + gqlHTTPClient: githubv4mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{}, + }, + expectError: true, + expectedErrMsg: "must contain at least one pull request number", + }, + { + name: "oversized pullNumbers fails validation", + gqlHTTPClient: githubv4mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": oversizedReviewThreadArgs(maxPullRequestReviewThreadsBatchSize + 1), + }, + expectError: true, + expectedErrMsg: "exceeds the maximum batch size", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gqlClient := githubv4.NewClient(tc.gqlHTTPClient) + deps := BaseDeps{ + GQLClient: gqlClient, + Client: mustNewGHClient(t, MockHTTPClientWithHandlers(map[string]http.HandlerFunc{})), + } + handler := serverTool.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) + + if tc.expectError { + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, tc.expectedErrMsg) + return + } + + require.False(t, result.IsError) + text := getTextResult(t, result) + tc.validateResult(t, text.Text) + }) + } +} + +func oversizedReviewThreadArgs(count int) []any { + values := make([]any, 0, count) + for i := range count { + values = append(values, float64(i+1)) + } + return values +} + +type batchReviewThreadsRoundTripper struct { + t *testing.T + responses map[int]MinimalReviewThreadsResponse + expectedAfter map[int]string + errorMessages map[int]string +} + +func newBatchReviewThreadsHTTPClient(t *testing.T, responses map[int]MinimalReviewThreadsResponse, expectedAfter map[int]string, errorMessages map[int]string) *http.Client { + return &http.Client{Transport: &batchReviewThreadsRoundTripper{t: t, responses: responses, expectedAfter: expectedAfter, errorMessages: errorMessages}} +} + +func (rt *batchReviewThreadsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.t.Helper() + body, err := io.ReadAll(req.Body) + require.NoError(rt.t, err) + _ = req.Body.Close() + + var gqlReq struct { + Variables map[string]any `json:"variables"` + } + require.NoError(rt.t, json.Unmarshal(body, &gqlReq)) + + prNum := int(gqlReq.Variables["prNum"].(float64)) + var actualAfter string + if afterValue, ok := gqlReq.Variables["after"]; ok && afterValue != nil { + actualAfter = afterValue.(string) + } + assert.Equal(rt.t, rt.expectedAfter[prNum], actualAfter) + + var payload map[string]any + if errMsg, ok := rt.errorMessages[prNum]; ok { + payload = map[string]any{"errors": []map[string]any{{"message": errMsg}}} + } else { + payload = map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "reviewThreads": batchReviewThreadsPayload(rt.responses[prNum]), + }, + }, + }, + } + } + + jsonBody, err := json.Marshal(payload) + require.NoError(rt.t, err) + + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(jsonBody)), + }, nil +} + +func batchReviewThreadsPayload(resp MinimalReviewThreadsResponse) map[string]any { + nodes := make([]map[string]any, 0, len(resp.ReviewThreads)) + for _, thread := range resp.ReviewThreads { + comments := make([]map[string]any, 0, len(thread.Comments)) + for _, comment := range thread.Comments { + commentNode := map[string]any{ + "id": comment.HTMLURL, + "body": comment.Body, + "path": comment.Path, + "url": comment.HTMLURL, + "author": map[string]any{ + "login": comment.Author, + }, + } + if comment.Line != nil { + commentNode["line"] = *comment.Line + } + if comment.CreatedAt != "" { + commentNode["createdAt"] = comment.CreatedAt + commentNode["updatedAt"] = comment.CreatedAt + } + comments = append(comments, commentNode) + } + + nodes = append(nodes, map[string]any{ + "id": thread.ID, + "isResolved": thread.IsResolved, + "isOutdated": thread.IsOutdated, + "isCollapsed": thread.IsCollapsed, + "comments": map[string]any{ + "totalCount": thread.TotalCount, + "nodes": comments, + }, + }) + } + + return map[string]any{ + "nodes": nodes, + "pageInfo": map[string]any{ + "hasNextPage": resp.PageInfo.HasNextPage, + "hasPreviousPage": resp.PageInfo.HasPreviousPage, + "startCursor": resp.PageInfo.StartCursor, + "endCursor": resp.PageInfo.EndCursor, + }, + "totalCount": resp.TotalCount, + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 906fa777d..59ac4b3ef 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -221,6 +221,7 @@ func AllTools(t translations.TranslationHelperFunc) []inventory.ServerTool { // Pull request tools PullRequestRead(t), + GetPullRequestReviewThreadsBatch(t), ListPullRequests(t), SearchPullRequests(t), MergePullRequest(t),