From 434f67d8f625f83533756aa89db1ac24aeec25da Mon Sep 17 00:00:00 2001 From: Mladen Todorovic Date: Fri, 12 Dec 2025 16:48:41 +0100 Subject: [PATCH] Add support for nodes --- internal/toolsets/vulnerability/nodes.go | 213 ++++++++++ internal/toolsets/vulnerability/nodes_test.go | 369 ++++++++++++++++++ internal/toolsets/vulnerability/toolset.go | 1 + .../toolsets/vulnerability/toolset_test.go | 3 +- 4 files changed, 585 insertions(+), 1 deletion(-) create mode 100644 internal/toolsets/vulnerability/nodes.go create mode 100644 internal/toolsets/vulnerability/nodes_test.go diff --git a/internal/toolsets/vulnerability/nodes.go b/internal/toolsets/vulnerability/nodes.go new file mode 100644 index 0000000..038e4d8 --- /dev/null +++ b/internal/toolsets/vulnerability/nodes.go @@ -0,0 +1,213 @@ +package vulnerability + +import ( + "context" + "fmt" + "io" + "sort" + "strings" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/client/auth" + "github.com/stackrox/stackrox-mcp/internal/logging" + "github.com/stackrox/stackrox-mcp/internal/toolsets" + "google.golang.org/grpc" +) + +// getNodesForCVEInput defines the input parameters for get_nodes_for_cve tool. +type getNodesForCVEInput struct { + CVEName string `json:"cveName"` + FilterClusterID string `json:"filterClusterId,omitempty"` +} + +func (input *getNodesForCVEInput) validate() error { + if input.CVEName == "" { + return errors.New("CVE name is required") + } + + return nil +} + +// NodeGroupResult contains aggregated node information by cluster and OS. +type NodeGroupResult struct { + ClusterID string `json:"clusterId"` + ClusterName string `json:"clusterName"` + OperatingSystem string `json:"operatingSystem"` + Count int `json:"count"` +} + +// getNodesForCVEOutput defines the output structure for get_nodes_for_cve tool. +type getNodesForCVEOutput struct { + NodeGroups []NodeGroupResult `json:"nodeGroups"` +} + +// getNodesForCVETool implements the get_nodes_for_cve tool. +type getNodesForCVETool struct { + name string + client *client.Client +} + +// NewGetNodesForCVETool creates a new get_nodes_for_cve tool. +func NewGetNodesForCVETool(c *client.Client) toolsets.Tool { + return &getNodesForCVETool{ + name: "get_nodes_for_cve", + client: c, + } +} + +// IsReadOnly returns true as this tool only reads data. +func (t *getNodesForCVETool) IsReadOnly() bool { + return true +} + +// GetName returns the tool name. +func (t *getNodesForCVETool) GetName() string { + return t.name +} + +// GetTool returns the MCP Tool definition. +func (t *getNodesForCVETool) GetTool() *mcp.Tool { + return &mcp.Tool{ + Name: t.name, + Description: "Get aggregated node groups affected by a specific CVE, grouped by cluster and operating system image", + InputSchema: getNodesForCVEInputSchema(), + } +} + +// getNodesForCVEInputSchema returns the JSON schema for input validation. +func getNodesForCVEInputSchema() *jsonschema.Schema { + schema, err := jsonschema.For[getNodesForCVEInput](nil) + if err != nil { + logging.Fatal("Could not get jsonschema for get_nodes_for_cve input", err) + + return nil + } + + // CVE name is required. + schema.Required = []string{"cveName"} + + schema.Properties["cveName"].Description = "CVE name to filter nodes (e.g., CVE-2020-26159)" + schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter nodes" + + return schema +} + +// RegisterWith registers the get_nodes_for_cve tool handler with the MCP server. +func (t *getNodesForCVETool) RegisterWith(server *mcp.Server) { + mcp.AddTool(server, t.GetTool(), t.handle) +} + +// buildNodeQuery builds query used to search nodes in StackRox Central. +// We will quote values to have strict match. Without quote: CVE-2025-10, would match CVE-2025-101. +func buildNodeQuery(input getNodesForCVEInput) string { + queryParts := []string{fmt.Sprintf("CVE:%q", input.CVEName)} + + if input.FilterClusterID != "" { + queryParts = append(queryParts, fmt.Sprintf("Cluster ID:%q", input.FilterClusterID)) + } + + return strings.Join(queryParts, "+") +} + +// aggregateNodeGroups consumes entire stream and aggregates nodes by cluster and OS. +func aggregateNodeGroups( + stream grpc.ServerStreamingClient[v1.ExportNodeResponse], +) ([]NodeGroupResult, error) { + // Map key: "clusterId|osImage" + // Map value: NodeGroupResult with count and clusterName. + groups := make(map[string]*NodeGroupResult) + + for { + resp, err := stream.Recv() + + // Stream ended - no more nodes. + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return nil, errors.Wrap(err, "error receiving from stream") + } + + node := resp.GetNode() + if node == nil { + continue + } + + // Create unique key for this cluster+OS combination. + key := fmt.Sprintf("%s|%s", node.GetClusterId(), node.GetOsImage()) + if group, exists := groups[key]; exists { + group.Count++ + + continue + } + + groups[key] = &NodeGroupResult{ + ClusterID: node.GetClusterId(), + ClusterName: node.GetClusterName(), + OperatingSystem: node.GetOsImage(), + Count: 1, + } + } + + result := make([]NodeGroupResult, 0, len(groups)) + for _, group := range groups { + result = append(result, *group) + } + + // Sort for consistent ordering (by clusterId, then OS). + sort.Slice(result, func(i, j int) bool { + if result[i].ClusterID != result[j].ClusterID { + return result[i].ClusterID < result[j].ClusterID + } + + return result[i].OperatingSystem < result[j].OperatingSystem + }) + + return result, nil +} + +// handle is the handler for get_nodes_for_cve tool. +func (t *getNodesForCVETool) handle( + ctx context.Context, + req *mcp.CallToolRequest, + input getNodesForCVEInput, +) (*mcp.CallToolResult, *getNodesForCVEOutput, error) { + err := input.validate() + if err != nil { + return nil, nil, err + } + + conn, err := t.client.ReadyConn(ctx) + if err != nil { + return nil, nil, errors.Wrap(err, "unable to connect to server") + } + + callCtx := auth.WithMCPRequestContext(ctx, req) + nodeClient := v1.NewNodeServiceClient(conn) + + query := buildNodeQuery(input) + exportReq := &v1.ExportNodeRequest{ + Query: query, + } + + stream, err := nodeClient.ExportNodes(callCtx, exportReq) + if err != nil { + return nil, nil, client.NewError(err, "ExportNodes") + } + + nodeGroups, err := aggregateNodeGroups(stream) + if err != nil { + return nil, nil, err + } + + output := &getNodesForCVEOutput{ + NodeGroups: nodeGroups, + } + + return nil, output, nil +} diff --git a/internal/toolsets/vulnerability/nodes_test.go b/internal/toolsets/vulnerability/nodes_test.go new file mode 100644 index 0000000..1cd39ed --- /dev/null +++ b/internal/toolsets/vulnerability/nodes_test.go @@ -0,0 +1,369 @@ +package vulnerability + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestNewGetNodesForCVETool(t *testing.T) { + tool := NewGetNodesForCVETool(&client.Client{}) + require.NotNil(t, tool) + assert.Equal(t, "get_nodes_for_cve", tool.GetName()) +} + +func TestGetNodesForCVETool_IsReadOnly(t *testing.T) { + c := &client.Client{} + tool := NewGetNodesForCVETool(c) + + assert.True(t, tool.IsReadOnly(), "get_nodes_for_cve should be read-only") +} + +func TestGetNodesForCVETool_GetTool(t *testing.T) { + c := &client.Client{} + tool := NewGetNodesForCVETool(c) + + mcpTool := tool.GetTool() + + require.NotNil(t, mcpTool) + assert.Equal(t, "get_nodes_for_cve", mcpTool.Name) + assert.Contains(t, mcpTool.Description, "aggregated") + assert.NotNil(t, mcpTool.InputSchema) +} + +func TestGetNodesForCVETool_RegisterWith(t *testing.T) { + c := &client.Client{} + tool := NewGetNodesForCVETool(c) + server := mcp.NewServer( + &mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, + &mcp.ServerOptions{}, + ) + + // Should not panic. + assert.NotPanics(t, func() { + tool.RegisterWith(server) + }) +} + +// Unit tests for input validate method. +func TestNodeInputValidate(t *testing.T) { + tests := map[string]struct { + input getNodesForCVEInput + expectError bool + errorMsg string + }{ + "valid input with CVE only": { + input: getNodesForCVEInput{CVEName: "CVE-2021-44228"}, + expectError: false, + }, + "missing CVE name (empty string)": { + input: getNodesForCVEInput{CVEName: ""}, + expectError: true, + errorMsg: "CVE name is required", + }, + "missing CVE name (zero value)": { + input: getNodesForCVEInput{}, + expectError: true, + errorMsg: "CVE name is required", + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + err := testCase.input.validate() + + if !testCase.expectError { + require.NoError(t, err) + + return + } + + require.Error(t, err) + assert.Contains(t, err.Error(), testCase.errorMsg) + }) + } +} + +// Mock infrastructure for gRPC testing. + +// mockNodeService implements v1.NodeServiceServer for testing. +type mockNodeService struct { + v1.UnimplementedNodeServiceServer + + nodes []*storage.Node + err error + + lastCallQuery string +} + +func (m *mockNodeService) ExportNodes( + req *v1.ExportNodeRequest, + stream grpc.ServerStreamingServer[v1.ExportNodeResponse], +) error { + m.lastCallQuery = req.GetQuery() + + if m.err != nil { + return m.err + } + + // Send all nodes through the stream. + for _, node := range m.nodes { + resp := &v1.ExportNodeResponse{Node: node} + if err := stream.Send(resp); err != nil { + return errors.Wrap(err, "sending node over stream failed") + } + } + + return nil +} + +// Integration tests for handle method. +func TestNodeHandle_MissingCVE(t *testing.T) { + mockService := &mockNodeService{ + nodes: []*storage.Node{}, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterNodeServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + inputWithoutCVEName := getNodesForCVEInput{} + + result, output, err := tool.handle(ctx, req, inputWithoutCVEName) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "CVE name is required") +} + +func TestNodeHandle_EmptyResults(t *testing.T) { + mockService := &mockNodeService{ + nodes: []*storage.Node{}, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterNodeServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getNodesForCVEInput{ + CVEName: "CVE-9999-99999", + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Empty(t, output.NodeGroups, "Should have empty nodeGroups array") +} + +func TestNodeHandle_ExportNodesError(t *testing.T) { + mockService := &mockNodeService{ + err: status.Error(codes.Internal, "database error"), + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterNodeServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + } + + result, output, err := tool.handle(ctx, req, input) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "database error") +} + +func TestNodeHandle_Aggregation(t *testing.T) { + mockService := &mockNodeService{ + nodes: []*storage.Node{ + {Name: "n1", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 20.04"}, + {Name: "n2", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 20.04"}, + {Name: "n3", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 22.04"}, + {Name: "n4", ClusterId: "c2", ClusterName: "Dev", OsImage: "Ubuntu 20.04"}, + {Name: "n5", ClusterId: "c2", ClusterName: "Dev", OsImage: "Ubuntu 20.04"}, + {Name: "n6", ClusterId: "c1", ClusterName: "Prod", OsImage: "Ubuntu 20.04"}, + }, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterNodeServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + + // Should have 3 groups: + // - c1 + Ubuntu 20.04: count=2 + // - c1 + Ubuntu 22.04: count=1 + // - c2 + Ubuntu 20.04: count=1 + require.Len(t, output.NodeGroups, 3) + + // Verify first group (sorted by cluster, then OS). + assert.Equal(t, "c1", output.NodeGroups[0].ClusterID) + assert.Equal(t, "Prod", output.NodeGroups[0].ClusterName) + assert.Equal(t, "Ubuntu 20.04", output.NodeGroups[0].OperatingSystem) + assert.Equal(t, 3, output.NodeGroups[0].Count) + + // Verify second group (c1 + Ubuntu 22.04). + assert.Equal(t, "c1", output.NodeGroups[1].ClusterID) + assert.Equal(t, "Prod", output.NodeGroups[1].ClusterName) + assert.Equal(t, "Ubuntu 22.04", output.NodeGroups[1].OperatingSystem) + assert.Equal(t, 1, output.NodeGroups[1].Count) + + // Verify third group (c2 + Ubuntu 20.04). + assert.Equal(t, "c2", output.NodeGroups[2].ClusterID) + assert.Equal(t, "Dev", output.NodeGroups[2].ClusterName) + assert.Equal(t, "Ubuntu 20.04", output.NodeGroups[2].OperatingSystem) + assert.Equal(t, 2, output.NodeGroups[2].Count) +} + +func TestNodeHandle_Sorting(t *testing.T) { + mockService := &mockNodeService{ + nodes: []*storage.Node{ + {Name: "n1", ClusterId: "z-cluster", ClusterName: "A", OsImage: "Ubuntu 20.04"}, + {Name: "n2", ClusterId: "a-cluster", ClusterName: "Z", OsImage: "RHEL 8"}, + {Name: "n3", ClusterId: "a-cluster", ClusterName: "Z", OsImage: "CentOS 7"}, + {Name: "n4", ClusterId: "z-cluster", ClusterName: "A", OsImage: ""}, + }, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterNodeServiceServer(grpcServer, mockService) + + testClient := createTestClient(t, listener) + tool, ok := NewGetNodesForCVETool(testClient).(*getNodesForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + + require.Len(t, output.NodeGroups, 4) + + // Should be sorted by cluster ID first, then OS. + // Expected order: + // 1. a-cluster + CentOS 7 + // 2. a-cluster + RHEL 8 + // 3. z-cluster + + // 4. z-cluster + Ubuntu 20.04 + assert.Equal(t, "a-cluster", output.NodeGroups[0].ClusterID) + assert.Equal(t, "CentOS 7", output.NodeGroups[0].OperatingSystem) + + assert.Equal(t, "a-cluster", output.NodeGroups[1].ClusterID) + assert.Equal(t, "RHEL 8", output.NodeGroups[1].OperatingSystem) + + assert.Equal(t, "z-cluster", output.NodeGroups[2].ClusterID) + assert.Empty(t, output.NodeGroups[2].OperatingSystem) + + assert.Equal(t, "z-cluster", output.NodeGroups[3].ClusterID) + assert.Equal(t, "Ubuntu 20.04", output.NodeGroups[3].OperatingSystem) +} + +func TestNodeHandle_WithFilters(t *testing.T) { + mockService := &mockNodeService{ + nodes: []*storage.Node{ + {Name: "n1", ClusterId: "cluster-1", ClusterName: "C1", OsImage: "Ubuntu 20.04"}, + }, + } + + grpcServer, listener := setupMockDeploymentServer(&mockDeploymentService{}) + defer grpcServer.Stop() + + v1.RegisterNodeServiceServer(grpcServer, mockService) + + tool, ok := NewGetNodesForCVETool(createTestClient(t, listener)).(*getNodesForCVETool) + require.True(t, ok) + + tests := map[string]struct { + input getNodesForCVEInput + expectedQuery string + }{ + "CVE only": { + input: getNodesForCVEInput{CVEName: "CVE-2021-44228"}, + expectedQuery: `CVE:"CVE-2021-44228"`, + }, + "CVE with cluster": { + input: getNodesForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + }, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-123"`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, testCase.input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Len(t, output.NodeGroups, 1) + assert.Equal(t, testCase.expectedQuery, mockService.lastCallQuery) + }) + } +} diff --git a/internal/toolsets/vulnerability/toolset.go b/internal/toolsets/vulnerability/toolset.go index ded3c78..0ab8f0d 100644 --- a/internal/toolsets/vulnerability/toolset.go +++ b/internal/toolsets/vulnerability/toolset.go @@ -19,6 +19,7 @@ func NewToolset(cfg *config.Config, c *client.Client) *Toolset { cfg: cfg, tools: []toolsets.Tool{ NewGetDeploymentsForCVETool(c), + NewGetNodesForCVETool(c), }, } } diff --git a/internal/toolsets/vulnerability/toolset_test.go b/internal/toolsets/vulnerability/toolset_test.go index c1b58eb..9eccf0d 100644 --- a/internal/toolsets/vulnerability/toolset_test.go +++ b/internal/toolsets/vulnerability/toolset_test.go @@ -38,8 +38,9 @@ func TestToolset_IsEnabled_True(t *testing.T) { tools := toolset.GetTools() require.NotEmpty(t, tools, "Should return tools when enabled") - require.Len(t, tools, 1, "Should have tools") + require.Len(t, tools, 2, "Should have all vulnerability tools") assert.Equal(t, "get_deployments_for_cve", tools[0].GetName()) + assert.Equal(t, "get_nodes_for_cve", tools[1].GetName()) } func TestToolset_IsEnabled_False(t *testing.T) {