From 7b932fce442d941305b52cd8cee4ca6ae05f1080 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 09:41:53 -0800 Subject: [PATCH 01/10] Update vmcp/README --- cmd/vmcp/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/vmcp/README.md b/cmd/vmcp/README.md index 30ac862ca2..10a60bef2a 100644 --- a/cmd/vmcp/README.md +++ b/cmd/vmcp/README.md @@ -6,7 +6,7 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC ## Features -### Implemented (Phase 1) +### Implemented - ✅ **Group-Based Backend Management**: Automatic workload discovery from ToolHive groups - ✅ **Tool Aggregation**: Combines tools from multiple MCP servers with conflict resolution (prefix, priority, manual) - ✅ **Resource & Prompt Aggregation**: Unified access to resources and prompts from all backends @@ -15,12 +15,14 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC - ✅ **Health Endpoints**: `/health` and `/ping` for service monitoring - ✅ **Configuration Validation**: `vmcp validate` command for config verification - ✅ **Observability**: OpenTelemetry metrics and traces for backend operations and workflow executions +- ✅ **Composite Tools**: Multi-step workflows with elicitation support ### In Progress - 🚧 **Incoming Authentication** (Issue #165): OIDC, local, anonymous authentication - 🚧 **Outgoing Authentication** (Issue #160): RFC 8693 token exchange for backend API access - 🚧 **Token Caching**: Memory and Redis cache providers - 🚧 **Health Monitoring** (Issue #166): Circuit breakers, backend health checks +- 🚧 **Optimizer**: Support the MCP optimizer in vMCP for context optimization on large toolsets
### Future (Phase 2+) - 📋 **Authorization**: Cedar policy-based access control From 610c82754246e07116c2d90587746c52435de313 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 11:00:34 -0800 Subject: [PATCH 02/10] refactor(vmcp): extract AddOnRegisterSession hook to named method Extract the inline AddOnRegisterSession hook to a dedicated handleSessionRegistration method for better testability and readability. This is a pure refactor with no functional changes - the same logic is executed, just moved to a named method on *Server. This prepares the codebase for the optimizer feature which will need to conditionally modify session registration behavior. --- pkg/vmcp/server/server.go | 185 +++++++++++++++++++++----------------- 1 file changed, 103 insertions(+), 82 deletions(-) diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 477b9a1234..2dc3eb0b2a 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -344,89 +344,9 @@ func New( } // Register OnRegisterSession hook to inject capabilities after SDK registers session. - // This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which - // fires BEFORE session registration), allowing us to safely call AddSessionTools/AddSessionResources. - // - // The discovery middleware populates capabilities in the context, which is available here. - // We inject them into the SDK session and store the routing table for subsequent requests. - // - // IMPORTANT: Session capabilities are immutable after injection. - // - Capabilities discovered during initialize are fixed for the session lifetime - // - Backend changes (new tools, removed resources) won't be reflected in existing sessions - // - Clients must create new sessions to see updated capabilities - // TODO(dynamic-capabilities): Consider implementing capability refresh mechanism when SDK supports it + // See handleSessionRegistration for implementation details. 
hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) { - sessionID := session.SessionID() - logger.Debugw("OnRegisterSession hook called", "session_id", sessionID) - - // Get capabilities from context (discovered by middleware) - caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) - if !ok || caps == nil { - logger.Warnw("no discovered capabilities in context for OnRegisterSession hook", - "session_id", sessionID) - return - } - - // Validate that routing table exists - if caps.RoutingTable == nil { - logger.Warnw("routing table is nil in discovered capabilities", - "session_id", sessionID) - return - } - - // Add composite tools to capabilities - // Composite tools are static (from configuration) and not discovered from backends - // They are added here to be exposed alongside backend tools in the session - if len(srv.workflowDefs) > 0 { - compositeTools := convertWorkflowDefsToTools(srv.workflowDefs) - - // Validate no conflicts between composite tool names and backend tool names - if err := validateNoToolConflicts(caps.Tools, compositeTools); err != nil { - logger.Errorw("composite tool name conflict detected", - "session_id", sessionID, - "error", err) - // Don't add composite tools if there are conflicts - // This prevents ambiguity in routing/execution - return - } - - caps.CompositeTools = compositeTools - logger.Debugw("added composite tools to session capabilities", - "session_id", sessionID, - "composite_tool_count", len(compositeTools)) - } - - // Store routing table in VMCPSession for subsequent requests - // This enables the middleware to reconstruct capabilities from session - // without re-running discovery for every request - vmcpSess, err := vmcpsession.GetVMCPSession(sessionID, sessionManager) - if err != nil { - logger.Errorw("failed to get VMCPSession for routing table storage", - "error", err, - "session_id", sessionID) - return - } - - vmcpSess.SetRoutingTable(caps.RoutingTable) - 
vmcpSess.SetTools(caps.Tools) - logger.Debugw("routing table and tools stored in VMCPSession", - "session_id", sessionID, - "tool_count", len(caps.RoutingTable.Tools), - "resource_count", len(caps.RoutingTable.Resources), - "prompt_count", len(caps.RoutingTable.Prompts)) - - // Inject capabilities into SDK session - if err := srv.injectCapabilities(sessionID, caps); err != nil { - logger.Errorw("failed to inject session capabilities", - "error", err, - "session_id", sessionID) - return - } - - logger.Infow("session capabilities injected", - "session_id", sessionID, - "tool_count", len(caps.Tools), - "resource_count", len(caps.Resources)) + srv.handleSessionRegistration(ctx, session, sessionManager) }) return srv, nil @@ -852,6 +772,107 @@ func (s *Server) injectCapabilities( return nil } +// handleSessionRegistration processes a new MCP session registration. +// +// This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which +// fires BEFORE session registration), allowing us to safely call AddSessionTools/AddSessionResources. +// +// The discovery middleware populates capabilities in the context, which is available here. +// We inject them into the SDK session and store the routing table for subsequent requests. +// +// This method performs the following steps: +// 1. Retrieves discovered capabilities from context +// 2. Adds composite tools from configuration +// 3. Stores routing table in VMCPSession for request routing +// 4. Injects capabilities into the SDK session +// +// IMPORTANT: Session capabilities are immutable after injection. 
+// - Capabilities discovered during initialize are fixed for the session lifetime +// - Backend changes (new tools, removed resources) won't be reflected in existing sessions +// - Clients must create new sessions to see updated capabilities +// +// TODO(dynamic-capabilities): Consider implementing capability refresh mechanism when SDK supports it +// +// The sessionManager parameter is passed explicitly because this method is called +// from a closure registered before the Server is fully constructed. +func (s *Server) handleSessionRegistration( + ctx context.Context, + session server.ClientSession, + sessionManager *transportsession.Manager, +) { + sessionID := session.SessionID() + logger.Debugw("OnRegisterSession hook called", "session_id", sessionID) + + // Get capabilities from context (discovered by middleware) + caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) + if !ok || caps == nil { + logger.Warnw("no discovered capabilities in context for OnRegisterSession hook", + "session_id", sessionID) + return + } + + // Validate that routing table exists + if caps.RoutingTable == nil { + logger.Warnw("routing table is nil in discovered capabilities", + "session_id", sessionID) + return + } + + // Add composite tools to capabilities + // Composite tools are static (from configuration) and not discovered from backends + // They are added here to be exposed alongside backend tools in the session + if len(s.workflowDefs) > 0 { + compositeTools := convertWorkflowDefsToTools(s.workflowDefs) + + // Validate no conflicts between composite tool names and backend tool names + if err := validateNoToolConflicts(caps.Tools, compositeTools); err != nil { + logger.Errorw("composite tool name conflict detected", + "session_id", sessionID, + "error", err) + // Don't add composite tools if there are conflicts + // This prevents ambiguity in routing/execution + return + } + + caps.CompositeTools = compositeTools + logger.Debugw("added composite tools to session 
capabilities", + "session_id", sessionID, + "composite_tool_count", len(compositeTools)) + } + + // Store routing table in VMCPSession for subsequent requests + // This enables the middleware to reconstruct capabilities from session + // without re-running discovery for every request + vmcpSess, err := vmcpsession.GetVMCPSession(sessionID, sessionManager) + if err != nil { + logger.Errorw("failed to get VMCPSession for routing table storage", + "error", err, + "session_id", sessionID) + return + } + + vmcpSess.SetRoutingTable(caps.RoutingTable) + vmcpSess.SetTools(caps.Tools) + logger.Debugw("routing table and tools stored in VMCPSession", + "session_id", sessionID, + "tool_count", len(caps.RoutingTable.Tools), + "resource_count", len(caps.RoutingTable.Resources), + "prompt_count", len(caps.RoutingTable.Prompts)) + + // Inject capabilities into SDK session + if err := s.injectCapabilities(sessionID, caps); err != nil { + logger.Errorw("failed to inject session capabilities", + "error", err, + "session_id", sessionID) + return + } + + logger.Infow("session capabilities injected", + "session_id", sessionID, + "tool_count", len(caps.Tools), + "resource_count", len(caps.Resources)) +} + // validateAndCreateExecutors validates workflow definitions and creates executors. 
// // This function: From 6951199c9c4e01a459ca39957e65339f45c57d2a Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 11:01:21 -0800 Subject: [PATCH 03/10] feat(vmcp): add optimizer interface and types Add the pkg/vmcp/optimizer package with: - Optimizer interface defining FindTool and CallTool methods - FindToolInput/FindToolOutput types for tool discovery - CallToolInput/CallToolResult types for tool invocation - ToolMatch for search results with relevance scoring - TokenMetrics for tracking token usage optimization This interface will be implemented by: - DummyOptimizer: exact string matching (testing) - EmbeddingOptimizer: semantic similarity (production) --- pkg/vmcp/optimizer/optimizer.go | 108 ++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 pkg/vmcp/optimizer/optimizer.go diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go new file mode 100644 index 0000000000..b98648826f --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer.go @@ -0,0 +1,108 @@ +// Package optimizer provides the Optimizer interface for intelligent tool discovery +// and invocation in the Virtual MCP Server. +// +// When the optimizer is enabled, vMCP exposes only two tools to clients: +// - find_tool: Semantic search over available tools +// - call_tool: Dynamic invocation of any backend tool +// +// This reduces token usage by avoiding the need to send all tool definitions +// to the LLM, instead allowing it to discover relevant tools on demand. +package optimizer + +import ( + "context" +) + +// Optimizer defines the interface for intelligent tool discovery and invocation. +// +// Implementations may use various strategies for tool matching: +// - DummyOptimizer: Case-insensitive substring matching (for testing) +// - EmbeddingOptimizer: Semantic similarity via embeddings (production) +type Optimizer interface { + // FindTool searches for tools matching the given description and keywords.
+ // Returns matching tools ranked by relevance score. + FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error) + + // CallTool invokes a tool by name with the given parameters. + // Returns the tool's result or an error if the tool is not found or execution fails. + CallTool(ctx context.Context, input CallToolInput) (*CallToolResult, error) +} + +// FindToolInput contains the parameters for finding tools. +type FindToolInput struct { + // ToolDescription is a natural language description of the tool to find. + ToolDescription string `json:"tool_description" description:"Natural language description of the tool to find"` + + // ToolKeywords is an optional list of keywords to narrow the search. + ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords to narrow search"` +} + +// FindToolOutput contains the results of a tool search. +type FindToolOutput struct { + // Tools contains the matching tools, ranked by relevance. + Tools []ToolMatch `json:"tools"` + + // TokenMetrics provides information about token savings from using the optimizer. + TokenMetrics TokenMetrics `json:"token_metrics"` +} + +// ToolMatch represents a tool that matched the search criteria. +type ToolMatch struct { + // Name is the unique identifier of the tool. + Name string `json:"name"` + + // Description is the human-readable description of the tool. + Description string `json:"description"` + + // Parameters is the JSON schema for the tool's input parameters. + Parameters map[string]any `json:"parameters"` + + // Score indicates how well this tool matches the search criteria (0.0-1.0). + Score float64 `json:"score"` +} + +// TokenMetrics provides information about token usage optimization. +type TokenMetrics struct { + // BaselineTokens is the estimated tokens if all tools were sent. + BaselineTokens int `json:"baseline_tokens"` + + // ReturnedTokens is the actual tokens for the returned tools. 
+ ReturnedTokens int `json:"returned_tokens"` + + // SavingsPercent is the percentage of tokens saved. + SavingsPercent float64 `json:"savings_percent"` +} + +// CallToolInput contains the parameters for calling a tool. +type CallToolInput struct { + // ToolName is the name of the tool to invoke. + ToolName string `json:"tool_name" description:"Name of the tool to call"` + + // Parameters are the arguments to pass to the tool. + Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` +} + +// CallToolResult contains the result of a tool invocation. +// This wraps the standard MCP CallToolResult content. +type CallToolResult struct { + // Content contains the tool's output. + Content []ContentBlock `json:"content"` + + // IsError indicates whether the tool execution resulted in an error. + IsError bool `json:"isError,omitempty"` +} + +// ContentBlock represents a single content item in a tool result. +type ContentBlock struct { + // Type is the content type (e.g., "text", "image", "resource"). + Type string `json:"type"` + + // Text is the text content (for type="text"). + Text string `json:"text,omitempty"` + + // Data is base64-encoded data (for type="image" or binary content). + Data string `json:"data,omitempty"` + + // MimeType is the MIME type of the content. 
+ MimeType string `json:"mimeType,omitempty"` +} From e5042df8ea50e61e1833e76aec693696c8c712a1 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 13:34:05 -0800 Subject: [PATCH 04/10] feat(vmcp): add optimizer interface and DummyOptimizer implementation Add the pkg/vmcp/optimizer package with: Optimizer interface: - FindTool: Search for tools by description (semantic search) - CallTool: Invoke a tool by name, returns map[string]any directly Types: - FindToolInput/FindToolOutput for tool discovery - CallToolInput for tool invocation - ToolMatch for search results with relevance scoring - TokenMetrics for tracking optimization (placeholder for now) DummyOptimizer implementation: - Case-insensitive substring matching on tool name and description - Routes tool calls via router interface to backend client - Backend tools only (composite tools not supported in v1) - Intended for testing; production will use EmbeddingOptimizer Note: Composite tools are excluded as they require the composer for execution, not the router/backendClient pipeline. This will be addressed in a future iteration. Signed-off-by: Jeremy Drouillard --- pkg/vmcp/optimizer/dummy_optimizer.go | 123 +++++++++++ pkg/vmcp/optimizer/dummy_optimizer_test.go | 226 +++++++++++++++++++++ pkg/vmcp/optimizer/optimizer.go | 28 +-- 3 files changed, 351 insertions(+), 26 deletions(-) create mode 100644 pkg/vmcp/optimizer/dummy_optimizer.go create mode 100644 pkg/vmcp/optimizer/dummy_optimizer_test.go diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go new file mode 100644 index 0000000000..db29dc835c --- /dev/null +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -0,0 +1,123 @@ +package optimizer + +import ( + "context" + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/router" +) + +// DummyOptimizer implements the Optimizer interface using case-insensitive substring matching.
+// +// This implementation is intended for testing and development. It performs +// case-insensitive substring matching on tool names and descriptions. +// +// For production use, see the EmbeddingOptimizer which uses semantic similarity. +type DummyOptimizer struct { + // tools contains all available tools indexed by name. + tools map[string]vmcp.Tool + + // router routes tool calls to backend servers. + router router.Router + + // backendClient executes tool calls on backend servers. + backendClient vmcp.BackendClient +} + +// NewDummyOptimizer creates a new DummyOptimizer with the given tools. +// +// The tools slice should only include backend tools (not composite tools). +// Composite tools are not supported in this initial implementation as they +// require execution through the composer, not the router/backendClient. +// TODO(jeremy): Add composite tool support. +// +// The router and backendClient are used by CallTool to route and execute +// tool invocations on backend servers. +// TODO: replace the dummy optimizer with a similarity search optimizer. +func NewDummyOptimizer( + tools []vmcp.Tool, + router router.Router, + backendClient vmcp.BackendClient, +) *DummyOptimizer { + toolMap := make(map[string]vmcp.Tool, len(tools)) + for _, tool := range tools { + // Skip composite tools (no backend) - not supported in this implementation + if tool.BackendID == "" { + continue + } + toolMap[tool.Name] = tool + } + + return &DummyOptimizer{ + tools: toolMap, + router: router, + backendClient: backendClient, + } +} + +// FindTool searches for tools using exact substring matching. +// +// The search is case-insensitive and matches against: +// - Tool name (substring match) +// - Tool description (substring match) +// +// Returns all matching tools with a score of 1.0 (exact match semantics). +// TokenMetrics are returned as zero values (not implemented in dummy). 
+func (d *DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { + if input.ToolDescription == "" { + return nil, fmt.Errorf("tool_description is required") + } + + searchTerm := strings.ToLower(input.ToolDescription) + + var matches []ToolMatch + for _, tool := range d.tools { + nameLower := strings.ToLower(tool.Name) + descLower := strings.ToLower(tool.Description) + + // Check if search term matches name or description + if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) { + matches = append(matches, ToolMatch{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.InputSchema, + Score: 1.0, // Exact match semantics + }) + } + } + + return &FindToolOutput{ + Tools: matches, + TokenMetrics: TokenMetrics{}, // Zero values for dummy + }, nil +} + +// CallTool invokes a tool by name using the router and backend client. +// +// The tool is looked up by exact name match. If found, the request is +// routed to the appropriate backend and executed. 
+func (d *DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (map[string]any, error) { + if input.ToolName == "" { + return nil, fmt.Errorf("tool_name is required") + } + + // Verify the tool exists + tool, exists := d.tools[input.ToolName] + if !exists { + return nil, fmt.Errorf("tool not found: %s", input.ToolName) + } + + // Route to the correct backend + target, err := d.router.RouteTool(ctx, tool.Name) + if err != nil { + return nil, fmt.Errorf("failed to route tool %s: %w", input.ToolName, err) + } + + // Get the backend name for this tool (handles conflict resolution renaming) + backendToolName := target.GetBackendCapabilityName(tool.Name) + + // Execute the tool call and return result directly + return d.backendClient.CallTool(ctx, target, backendToolName, input.Parameters) +} diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go new file mode 100644 index 0000000000..ffe5a701ea --- /dev/null +++ b/pkg/vmcp/optimizer/dummy_optimizer_test.go @@ -0,0 +1,226 @@ +package optimizer + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks" + routermocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" +) + +func TestDummyOptimizer_FindTool(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{ + { + Name: "fetch_url", + Description: "Fetch content from a URL", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "url": map[string]any{"type": "string"}, + }, + }, + BackendID: "backend1", + }, + { + Name: "read_file", + Description: "Read a file from the filesystem", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + BackendID: "backend2", + }, + { + Name: "write_file", + Description: "Write 
content to a file", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + "content": map[string]any{"type": "string"}, + }, + }, + BackendID: "backend2", + }, + } + + ctrl := gomock.NewController(t) + mockRouter := routermocks.NewMockRouter(ctrl) + mockClient := vmcpmocks.NewMockBackendClient(ctrl) + + opt := NewDummyOptimizer(tools, mockRouter, mockClient) + + tests := []struct { + name string + input FindToolInput + expectedNames []string + expectedError bool + errorContains string + }{ + { + name: "find by exact name", + input: FindToolInput{ + ToolDescription: "fetch_url", + }, + expectedNames: []string{"fetch_url"}, + }, + { + name: "find by description substring", + input: FindToolInput{ + ToolDescription: "file", + }, + expectedNames: []string{"read_file", "write_file"}, + }, + { + name: "case insensitive search", + input: FindToolInput{ + ToolDescription: "FETCH", + }, + expectedNames: []string{"fetch_url"}, + }, + { + name: "no matches", + input: FindToolInput{ + ToolDescription: "nonexistent", + }, + expectedNames: []string{}, + }, + { + name: "empty description", + input: FindToolInput{}, + expectedError: true, + errorContains: "tool_description is required", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result, err := opt.FindTool(context.Background(), tc.input) + + if tc.expectedError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorContains) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + + // Extract names from results + var names []string + for _, match := range result.Tools { + names = append(names, match.Name) + } + + assert.ElementsMatch(t, tc.expectedNames, names) + }) + } +} + +func TestDummyOptimizer_CallTool(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{ + { + Name: "test_tool", + Description: "A test tool", + InputSchema: map[string]any{ + "type": "object", + 
"properties": map[string]any{ + "input": map[string]any{"type": "string"}, + }, + }, + BackendID: "backend1", + }, + } + + tests := []struct { + name string + input CallToolInput + setupMocks func(*routermocks.MockRouter, *vmcpmocks.MockBackendClient) + expectedResult map[string]any + expectedError bool + errorContains string + }{ + { + name: "successful tool call", + input: CallToolInput{ + ToolName: "test_tool", + Parameters: map[string]any{"input": "hello"}, + }, + setupMocks: func(r *routermocks.MockRouter, c *vmcpmocks.MockBackendClient) { + target := &vmcp.BackendTarget{ + WorkloadID: "backend1", + WorkloadName: "backend1", + BaseURL: "http://localhost:8080", + } + r.EXPECT().RouteTool(gomock.Any(), "test_tool").Return(target, nil) + c.EXPECT().CallTool(gomock.Any(), target, "test_tool", map[string]any{"input": "hello"}). + Return(map[string]any{ + "content": []any{ + map[string]any{"type": "text", "text": "Hello, World!"}, + }, + }, nil) + }, + expectedResult: map[string]any{ + "content": []any{ + map[string]any{"type": "text", "text": "Hello, World!"}, + }, + }, + }, + { + name: "tool not found", + input: CallToolInput{ + ToolName: "nonexistent", + Parameters: map[string]any{}, + }, + setupMocks: func(_ *routermocks.MockRouter, _ *vmcpmocks.MockBackendClient) {}, + expectedError: true, + errorContains: "tool not found: nonexistent", + }, + { + name: "empty tool name", + input: CallToolInput{ + Parameters: map[string]any{}, + }, + setupMocks: func(_ *routermocks.MockRouter, _ *vmcpmocks.MockBackendClient) {}, + expectedError: true, + errorContains: "tool_name is required", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockRouter := routermocks.NewMockRouter(ctrl) + mockClient := vmcpmocks.NewMockBackendClient(ctrl) + + tc.setupMocks(mockRouter, mockClient) + + opt := NewDummyOptimizer(tools, mockRouter, mockClient) + result, err := opt.CallTool(context.Background(), 
tc.input) + + if tc.expectedError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorContains) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + }) + } +} diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index b98648826f..befdd2bbb9 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -25,7 +25,8 @@ type Optimizer interface { // CallTool invokes a tool by name with the given parameters. // Returns the tool's result or an error if the tool is not found or execution fails. - CallTool(ctx context.Context, input CallToolInput) (*CallToolResult, error) + // The return type matches BackendClient.CallTool for direct passthrough. + CallTool(ctx context.Context, input CallToolInput) (map[string]any, error) } // FindToolInput contains the parameters for finding tools. @@ -81,28 +82,3 @@ type CallToolInput struct { // Parameters are the arguments to pass to the tool. Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` } - -// CallToolResult contains the result of a tool invocation. -// This wraps the standard MCP CallToolResult content. -type CallToolResult struct { - // Content contains the tool's output. - Content []ContentBlock `json:"content"` - - // IsError indicates whether the tool execution resulted in an error. - IsError bool `json:"isError,omitempty"` -} - -// ContentBlock represents a single content item in a tool result. -type ContentBlock struct { - // Type is the content type (e.g., "text", "image", "resource"). - Type string `json:"type"` - - // Text is the text content (for type="text"). - Text string `json:"text,omitempty"` - - // Data is base64-encoded data (for type="image" or binary content). - Data string `json:"data,omitempty"` - - // MimeType is the MIME type of the content. 
- MimeType string `json:"mimeType,omitempty"` -} From fc7fc204f15abd1937d7f81a4ae618dffc1c25f5 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 13:38:07 -0800 Subject: [PATCH 05/10] feat(vmcp): add OptimizerConfig to vMCP configuration Add OptimizerConfig to pkg/vmcp/config/config.go to enable the MCP optimizer feature. When configured, vMCP exposes only find_tool and call_tool operations instead of all backend tools directly. OptimizerConfig fields: - EmbeddingService: Name of a K8s Service providing the embedding API for semantic tool discovery Generated artifacts updated: - CRD manifests (VirtualMCPServer) - DeepCopy implementations - CRD API documentation The config is automatically included in the VirtualMCPServer CRD through the embedded config.Config field. --- ...olhive.stacklok.dev_virtualmcpservers.yaml | 15 ++++++++++++++ ...olhive.stacklok.dev_virtualmcpservers.yaml | 15 ++++++++++++++ docs/operator/crd-api.md | 19 ++++++++++++++++++ pkg/vmcp/config/config.go | 19 ++++++++++++++++++ pkg/vmcp/config/zz_generated.deepcopy.go | 20 +++++++++++++++++++ 5 files changed, 88 insertions(+) diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml index bff67be5a8..bd70f3bffd 100644 --- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -862,6 +862,21 @@ spec: - default type: object type: object + optimizer: + description: |- + Optimizer configures the MCP optimizer for context optimization on large toolsets. + When enabled, vMCP exposes only find_tool and call_tool operations to clients + instead of all backend tools directly. This reduces token usage by allowing + LLMs to discover relevant tools on demand rather than receiving all tool definitions. 
+ properties: + embeddingService: + description: |- + EmbeddingService is the name of a Kubernetes Service that provides the embedding service + for semantic tool discovery. The service must implement the optimizer embedding API. + type: string + required: + - embeddingService + type: object outgoingAuth: description: OutgoingAuth configures how the virtual MCP server authenticates to backends. diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index 8e329f8584..b497666ce5 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -865,6 +865,21 @@ spec: - default type: object type: object + optimizer: + description: |- + Optimizer configures the MCP optimizer for context optimization on large toolsets. + When enabled, vMCP exposes only find_tool and call_tool operations to clients + instead of all backend tools directly. This reduces token usage by allowing + LLMs to discover relevant tools on demand rather than receiving all tool definitions. + properties: + embeddingService: + description: |- + EmbeddingService is the name of a Kubernetes Service that provides the embedding service + for semantic tool discovery. The service must implement the optimizer embedding API. + type: string + required: + - embeddingService + type: object outgoingAuth: description: OutgoingAuth configures how the virtual MCP server authenticates to backends. diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md index 14e8587c6a..de89fc4a7c 100644 --- a/docs/operator/crd-api.md +++ b/docs/operator/crd-api.md @@ -225,6 +225,7 @@ _Appears in:_ | `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. 
| | | | `telemetry` _[pkg.telemetry.Config](#pkgtelemetryconfig)_ | Telemetry configures OpenTelemetry-based observability for the Virtual MCP server
including distributed tracing, OTLP metrics export, and Prometheus metrics endpoint. | | | | `audit` _[pkg.audit.Config](#pkgauditconfig)_ | Audit configures audit logging for the Virtual MCP server.
When present, audit logs include MCP protocol operations.
See audit.Config for available configuration options. | | | +| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes only find_tool and call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | | #### vmcp.config.ConflictResolutionConfig @@ -343,6 +344,24 @@ _Appears in:_ | `failureHandling` _[vmcp.config.FailureHandlingConfig](#vmcpconfigfailurehandlingconfig)_ | FailureHandling configures failure handling. | | | +#### vmcp.config.OptimizerConfig + + + +OptimizerConfig configures the MCP optimizer. +When enabled, vMCP exposes only find_tool and call_tool operations to clients +instead of all backend tools directly. + + + +_Appears in:_ +- [vmcp.config.Config](#vmcpconfigconfig) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `embeddingService` _string_ | EmbeddingService is the name of a Kubernetes Service that provides the embedding service
for semantic tool discovery. The service must implement the optimizer embedding API. | | Required: \{\}
| + + #### vmcp.config.OutgoingAuthConfig diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index 9693643348..a6355cd5e8 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -108,6 +108,13 @@ type Config struct { // See audit.Config for available configuration options. // +optional Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"` + + // Optimizer configures the MCP optimizer for context optimization on large toolsets. + // When enabled, vMCP exposes only find_tool and call_tool operations to clients + // instead of all backend tools directly. This reduces token usage by allowing + // LLMs to discover relevant tools on demand rather than receiving all tool definitions. + // +optional + Optimizer *OptimizerConfig `json:"optimizer,omitempty" yaml:"optimizer,omitempty"` } // IncomingAuthConfig configures client authentication to the virtual MCP server. @@ -474,6 +481,18 @@ type OutputProperty struct { Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"` } +// OptimizerConfig configures the MCP optimizer. +// When enabled, vMCP exposes only find_tool and call_tool operations to clients +// instead of all backend tools directly. +// +kubebuilder:object:generate=true +// +gendoc +type OptimizerConfig struct { + // EmbeddingService is the name of a Kubernetes Service that provides the embedding service + // for semantic tool discovery. The service must implement the optimizer embedding API. + // +kubebuilder:validation:Required + EmbeddingService string `json:"embeddingService" yaml:"embeddingService"` +} + // Validator validates configuration. type Validator interface { // Validate checks if the configuration is valid. 
diff --git a/pkg/vmcp/config/zz_generated.deepcopy.go b/pkg/vmcp/config/zz_generated.deepcopy.go index b50fa118af..5afe9c656a 100644 --- a/pkg/vmcp/config/zz_generated.deepcopy.go +++ b/pkg/vmcp/config/zz_generated.deepcopy.go @@ -175,6 +175,11 @@ func (in *Config) DeepCopyInto(out *Config) { *out = new(audit.Config) (*in).DeepCopyInto(*out) } + if in.Optimizer != nil { + in, out := &in.Optimizer, &out.Optimizer + *out = new(OptimizerConfig) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Config. @@ -312,6 +317,21 @@ func (in *OperationalConfig) DeepCopy() *OperationalConfig { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *OptimizerConfig) DeepCopyInto(out *OptimizerConfig) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OptimizerConfig. +func (in *OptimizerConfig) DeepCopy() *OptimizerConfig { + if in == nil { + return nil + } + out := new(OptimizerConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *OutgoingAuthConfig) DeepCopyInto(out *OutgoingAuthConfig) { *out = *in From 8f35f89410a88753292ec2301b7783e9d02030be Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 13:40:30 -0800 Subject: [PATCH 06/10] refactor(operator): use DeepCopy for automatic config field passthrough Refactor the VirtualMCPServer config converter to use DeepCopy() for initializing the vmcp.Config from the embedded config.Config. This ensures that new fields added to config.Config (like Optimizer) are automatically passed through without requiring explicit mapping in the converter. Only fields requiring special handling (auth, aggregation, composite tools, telemetry) are explicitly converted. 
Benefits: - New config fields are automatically included - Reduces maintenance burden when adding config options - Less code duplication between CRD and config types --- cmd/thv-operator/pkg/vmcpconfig/converter.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cmd/thv-operator/pkg/vmcpconfig/converter.go b/cmd/thv-operator/pkg/vmcpconfig/converter.go index 7eead1d576..10c2538e07 100644 --- a/cmd/thv-operator/pkg/vmcpconfig/converter.go +++ b/cmd/thv-operator/pkg/vmcpconfig/converter.go @@ -61,15 +61,23 @@ func NewConverter(oidcResolver oidc.Resolver, k8sClient client.Client) (*Convert }, nil } -// Convert converts VirtualMCPServer CRD spec to vmcp Config +// Convert converts VirtualMCPServer CRD spec to vmcp Config. +// +// The conversion starts with a DeepCopy of the embedded config.Config from the CRD spec. +// This ensures that simple fields (like Optimizer, Metadata, etc.) are automatically +// passed through without explicit mapping. Only fields that require special handling +// (auth, aggregation, composite tools, telemetry) are explicitly converted below. func (c *Converter) Convert( ctx context.Context, vmcp *mcpv1alpha1.VirtualMCPServer, ) (*vmcpconfig.Config, error) { - config := &vmcpconfig.Config{ - Name: vmcp.Name, - Group: vmcp.Spec.Config.Group, - } + // Start with a deep copy of the embedded config for automatic field passthrough. + // This ensures new fields added to config.Config are automatically included + // without requiring explicit mapping in this converter. 
+ config := vmcp.Spec.Config.DeepCopy() + + // Override name with the CR name (authoritative source) + config.Name = vmcp.Name // Convert IncomingAuth - required field, no defaults if vmcp.Spec.IncomingAuth != nil { From a23f6ff06fa590e0074ef70668d9e9dab5498289 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 13:47:24 -0800 Subject: [PATCH 07/10] feat(vmcp): add reflection-based schema generation and input translation Add GenerateSchema[T]() and Translate[T]() functions to pkg/vmcp/schema: GenerateSchema[T](): - Generates JSON Schema from Go struct using reflection - Uses json tags for field names - Uses description tags for field descriptions - Uses omitempty to determine required vs optional fields - Supports string, integer, number, boolean, array, object types Translate[T](): - Converts untyped map[string]any to typed structs - Uses JSON marshal/unmarshal for reliable conversion - Simplifies MCP tool argument handling These functions ensure optimizer tool schemas (find_tool, call_tool) stay in sync with their Go struct definitions. --- pkg/vmcp/schema/reflect.go | 173 ++++++++++++++++++++++++++++++++ pkg/vmcp/schema/reflect_test.go | 141 ++++++++++++++++++++++++++ 2 files changed, 314 insertions(+) create mode 100644 pkg/vmcp/schema/reflect.go create mode 100644 pkg/vmcp/schema/reflect_test.go diff --git a/pkg/vmcp/schema/reflect.go b/pkg/vmcp/schema/reflect.go new file mode 100644 index 0000000000..a6f1ec419e --- /dev/null +++ b/pkg/vmcp/schema/reflect.go @@ -0,0 +1,173 @@ +package schema + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" +) + +// GenerateSchema generates a JSON Schema from a Go struct type using reflection. 
+// +// The function inspects struct tags to determine: +// - json: Field name in the schema (uses json tag name) +// - description: Field description (from `description` tag) +// - omitempty: Whether the field is optional (not in required array) +// +// Supported types: +// - string -> {"type": "string"} +// - int, int64, etc. -> {"type": "integer"} +// - float64, float32 -> {"type": "number"} +// - bool -> {"type": "boolean"} +// - []T -> {"type": "array", "items": {...}} +// - map[string]any -> {"type": "object"} +// - struct -> {"type": "object", "properties": {...}} +// +// Example: +// +// type FindToolInput struct { +// ToolDescription string `json:"tool_description" description:"Natural language description"` +// ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords"` +// } +// schema := GenerateSchema[FindToolInput]() +func GenerateSchema[T any]() map[string]any { + var zero T + t := reflect.TypeOf(zero) + return generateSchemaForType(t) +} + +// generateSchemaForType generates schema for a reflect.Type. 
+func generateSchemaForType(t reflect.Type) map[string]any { + // Handle pointer types + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + return generateObjectSchema(t) + case reflect.String: + return map[string]any{"type": "string"} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return map[string]any{"type": "integer"} + case reflect.Float32, reflect.Float64: + return map[string]any{"type": "number"} + case reflect.Bool: + return map[string]any{"type": "boolean"} + case reflect.Slice: + return map[string]any{ + "type": "array", + "items": generateSchemaForType(t.Elem()), + } + case reflect.Map: + // For any map kind (e.g. map[string]any), just return object type without describing keys or values + return map[string]any{"type": "object"} + case reflect.Interface: + // For any/interface{}, return a generic object schema with no properties (not an empty {} schema) + return map[string]any{"type": "object"} + default: + return map[string]any{"type": "object"} + } +} + +// generateObjectSchema generates schema for a struct type. 
+func generateObjectSchema(t reflect.Type) map[string]any { + properties := make(map[string]any) + var required []string + + for i := range t.NumField() { + field := t.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get JSON tag for field name + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + + // Parse json tag (name,omitempty) + jsonName, isOptional := parseJSONTag(jsonTag) + if jsonName == "" { + jsonName = field.Name + } + + // Generate schema for field type + fieldSchema := generateSchemaForType(field.Type) + + // Add description if present + if desc := field.Tag.Get("description"); desc != "" { + fieldSchema["description"] = desc + } + + properties[jsonName] = fieldSchema + + // Add to required if not optional + if !isOptional { + required = append(required, jsonName) + } + } + + schema := map[string]any{ + "type": "object", + "properties": properties, + } + + if len(required) > 0 { + schema["required"] = required + } + + return schema +} + +// parseJSONTag parses a json struct tag and returns the field name and whether it's optional. +func parseJSONTag(tag string) (name string, optional bool) { + if tag == "" { + return "", false + } + + parts := strings.Split(tag, ",") + name = parts[0] + + for _, part := range parts[1:] { + if part == "omitempty" { + optional = true + } + } + + return name, optional +} + +// Translate converts an untyped input (typically map[string]any from MCP request arguments) +// to a typed struct using JSON marshalling/unmarshalling. +// +// This provides a simple, reliable way to convert MCP tool arguments to typed Go structs +// without manual field-by-field extraction. 
+// +// Example: +// +// args := request.Params.Arguments // map[string]any +// input, err := Translate[FindToolInput](args) +// if err != nil { +// return nil, fmt.Errorf("invalid arguments: %w", err) +// } +func Translate[T any](input any) (T, error) { + var result T + + // Marshal to JSON + data, err := json.Marshal(input) + if err != nil { + return result, fmt.Errorf("failed to marshal input: %w", err) + } + + // Unmarshal to typed struct + if err := json.Unmarshal(data, &result); err != nil { + return result, fmt.Errorf("failed to unmarshal to %T: %w", result, err) + } + + return result, nil +} diff --git a/pkg/vmcp/schema/reflect_test.go b/pkg/vmcp/schema/reflect_test.go new file mode 100644 index 0000000000..a6acbbecf4 --- /dev/null +++ b/pkg/vmcp/schema/reflect_test.go @@ -0,0 +1,141 @@ +package schema + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" +) + +func TestGenerateSchema_FindToolInput(t *testing.T) { + t.Parallel() + + expected := map[string]any{ + "type": "object", + "properties": map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool to find", + }, + "tool_keywords": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + "description": "Optional keywords to narrow search", + }, + }, + "required": []string{"tool_description"}, + } + + actual := GenerateSchema[optimizer.FindToolInput]() + + require.Equal(t, expected, actual) +} + +func TestGenerateSchema_CallToolInput(t *testing.T) { + t.Parallel() + + expected := map[string]any{ + "type": "object", + "properties": map[string]any{ + "tool_name": map[string]any{ + "type": "string", + "description": "Name of the tool to call", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + "required": []string{"tool_name", 
"parameters"}, + } + + actual := GenerateSchema[optimizer.CallToolInput]() + + require.Equal(t, expected, actual) +} + +func TestTranslate_FindToolInput(t *testing.T) { + t.Parallel() + + input := map[string]any{ + "tool_description": "find a tool to read files", + "tool_keywords": []any{"file", "read"}, + } + + result, err := Translate[optimizer.FindToolInput](input) + require.NoError(t, err) + + require.Equal(t, "find a tool to read files", result.ToolDescription) + require.Equal(t, []string{"file", "read"}, result.ToolKeywords) +} + +func TestTranslate_CallToolInput(t *testing.T) { + t.Parallel() + + input := map[string]any{ + "tool_name": "read_file", + "parameters": map[string]any{ + "path": "/etc/hosts", + }, + } + + result, err := Translate[optimizer.CallToolInput](input) + require.NoError(t, err) + + require.Equal(t, "read_file", result.ToolName) + require.Equal(t, map[string]any{"path": "/etc/hosts"}, result.Parameters) +} + +func TestTranslate_PartialInput(t *testing.T) { + t.Parallel() + + input := map[string]any{ + "tool_description": "find a file reader", + } + + result, err := Translate[optimizer.FindToolInput](input) + require.NoError(t, err) + + require.Equal(t, "find a file reader", result.ToolDescription) + require.Nil(t, result.ToolKeywords) +} + +func TestTranslate_InvalidInput(t *testing.T) { + t.Parallel() + + input := make(chan int) + + _, err := Translate[optimizer.FindToolInput](input) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal input") +} + +func TestGenerateSchema_PrimitiveTypes(t *testing.T) { + t.Parallel() + + type TestStruct struct { + StringField string `json:"string_field"` + IntField int `json:"int_field"` + FloatField float64 `json:"float_field"` + BoolField bool `json:"bool_field"` + OptionalStr string `json:"optional_str,omitempty"` + } + + expected := map[string]any{ + "type": "object", + "properties": map[string]any{ + "string_field": map[string]any{"type": "string"}, + "int_field": 
map[string]any{"type": "integer"}, + "float_field": map[string]any{"type": "number"}, + "bool_field": map[string]any{"type": "boolean"}, + "optional_str": map[string]any{"type": "string"}, + }, + "required": []string{"string_field", "int_field", "float_field", "bool_field"}, + } + + actual := GenerateSchema[TestStruct]() + + require.Equal(t, expected, actual) +} From 963bda29a8bbdc825a6e6786c7e3625f2f678e52 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 14:06:24 -0800 Subject: [PATCH 08/10] feat(vmcp): add optimizer tool handlers and refactor DummyOptimizer Refactor DummyOptimizer to use server.ServerTool: - Constructor takes []server.ServerTool instead of separate router/client - CallTool invokes the tool handler directly, returns *mcp.CallToolResult - FindTool returns tool schemas as json.RawMessage (preserves original format) - getToolSchema helper handles RawInputSchema vs InputSchema fallback Add optimizer tool handlers in adapter package: - CreateOptimizerTools creates find_tool and call_tool SDK tools - find_tool handler uses mcp.NewToolResultStructuredOnly for output - call_tool handler delegates to optimizer and returns result directly - Schemas are pre-generated at package init for startup-time validation Uses schema.GenerateSchema and schema.Translate for type-safe argument handling and schema generation. 
--- pkg/vmcp/optimizer/dummy_optimizer.go | 93 ++++++------ pkg/vmcp/optimizer/dummy_optimizer_test.go | 136 +++++++----------- pkg/vmcp/optimizer/optimizer.go | 10 +- pkg/vmcp/server/adapter/optimizer_adapter.go | 94 ++++++++++++ .../server/adapter/optimizer_adapter_test.go | 104 ++++++++++++++ 5 files changed, 297 insertions(+), 140 deletions(-) create mode 100644 pkg/vmcp/server/adapter/optimizer_adapter.go create mode 100644 pkg/vmcp/server/adapter/optimizer_adapter_test.go diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go index db29dc835c..4365dbcc47 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer.go +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -2,11 +2,12 @@ package optimizer import ( "context" + "encoding/json" "fmt" "strings" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/router" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" ) // DummyOptimizer implements the Optimizer interface using exact string matching. @@ -17,43 +18,20 @@ import ( // For production use, see the EmbeddingOptimizer which uses semantic similarity. type DummyOptimizer struct { // tools contains all available tools indexed by name. - tools map[string]vmcp.Tool - - // router routes tool calls to backend servers. - router router.Router - - // backendClient executes tool calls on backend servers. - backendClient vmcp.BackendClient + tools map[string]server.ServerTool } // NewDummyOptimizer creates a new DummyOptimizer with the given tools. // -// The tools slice should only include backend tools (not composite tools). -// Composite tools are not supported in this initial implementation as they -// require execution through the composer, not the router/backendClient. -// TODO(jeremy): Add composite tool support. -// -// The router and backendClient are used by CallTool to route and execute -// tool invocations on backend servers. 
-// TODO: replace the dummy optimizer with a similarity search optimizer. -func NewDummyOptimizer( - tools []vmcp.Tool, - router router.Router, - backendClient vmcp.BackendClient, -) *DummyOptimizer { - toolMap := make(map[string]vmcp.Tool, len(tools)) +// The tools slice should contain all backend tools (as ServerTool with handlers). +func NewDummyOptimizer(tools []server.ServerTool) *DummyOptimizer { + toolMap := make(map[string]server.ServerTool, len(tools)) for _, tool := range tools { - // Skip composite tools (no backend) - not supported in this implementation - if tool.BackendID == "" { - continue - } - toolMap[tool.Name] = tool + toolMap[tool.Tool.Name] = tool } return &DummyOptimizer{ - tools: toolMap, - router: router, - backendClient: backendClient, + tools: toolMap, } } @@ -74,15 +52,19 @@ func (d *DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*Find var matches []ToolMatch for _, tool := range d.tools { - nameLower := strings.ToLower(tool.Name) - descLower := strings.ToLower(tool.Description) + nameLower := strings.ToLower(tool.Tool.Name) + descLower := strings.ToLower(tool.Tool.Description) // Check if search term matches name or description if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) { + schema, err := getToolSchema(tool.Tool) + if err != nil { + return nil, err + } matches = append(matches, ToolMatch{ - Name: tool.Name, - Description: tool.Description, - Parameters: tool.InputSchema, + Name: tool.Tool.Name, + Description: tool.Tool.Description, + Parameters: schema, Score: 1.0, // Exact match semantics }) } @@ -94,11 +76,11 @@ func (d *DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*Find }, nil } -// CallTool invokes a tool by name using the router and backend client. +// CallTool invokes a tool by name using its registered handler. // -// The tool is looked up by exact name match. If found, the request is -// routed to the appropriate backend and executed. 
-func (d *DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (map[string]any, error) { +// The tool is looked up by exact name match. If found, the handler +// is invoked directly with the given parameters. +func (d *DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { if input.ToolName == "" { return nil, fmt.Errorf("tool_name is required") } @@ -106,18 +88,29 @@ func (d *DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (map // Verify the tool exists tool, exists := d.tools[input.ToolName] if !exists { - return nil, fmt.Errorf("tool not found: %s", input.ToolName) + return mcp.NewToolResultError(fmt.Sprintf("tool not found: %s", input.ToolName)), nil } - // Route to the correct backend - target, err := d.router.RouteTool(ctx, tool.Name) - if err != nil { - return nil, fmt.Errorf("failed to route tool %s: %w", input.ToolName, err) - } + // Build the MCP request + request := mcp.CallToolRequest{} + request.Params.Name = input.ToolName + request.Params.Arguments = input.Parameters - // Get the backend name for this tool (handles conflict resolution renaming) - backendToolName := target.GetBackendCapabilityName(tool.Name) + // Call the tool handler directly + return tool.Handler(ctx, request) +} + +// getToolSchema returns the input schema for a tool. +// Prefers RawInputSchema if set, otherwise marshals InputSchema. 
+func getToolSchema(tool mcp.Tool) (json.RawMessage, error) { + if len(tool.RawInputSchema) > 0 { + return tool.RawInputSchema, nil + } - // Execute the tool call and return result directly - return d.backendClient.CallTool(ctx, target, backendToolName, input.Parameters) + // Fall back to InputSchema + data, err := json.Marshal(tool.InputSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) + } + return data, nil } diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go index ffe5a701ea..489f04e0e5 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer_test.go +++ b/pkg/vmcp/optimizer/dummy_optimizer_test.go @@ -4,60 +4,36 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/vmcp" - vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks" - routermocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" ) func TestDummyOptimizer_FindTool(t *testing.T) { t.Parallel() - tools := []vmcp.Tool{ + tools := []server.ServerTool{ { - Name: "fetch_url", - Description: "Fetch content from a URL", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "url": map[string]any{"type": "string"}, - }, + Tool: mcp.Tool{ + Name: "fetch_url", + Description: "Fetch content from a URL", }, - BackendID: "backend1", }, { - Name: "read_file", - Description: "Read a file from the filesystem", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{"type": "string"}, - }, + Tool: mcp.Tool{ + Name: "read_file", + Description: "Read a file from the filesystem", }, - BackendID: "backend2", }, { - Name: "write_file", - Description: "Write content to a file", - InputSchema: map[string]any{ - "type": "object", - 
"properties": map[string]any{ - "path": map[string]any{"type": "string"}, - "content": map[string]any{"type": "string"}, - }, + Tool: mcp.Tool{ + Name: "write_file", + Description: "Write content to a file", }, - BackendID: "backend2", }, } - ctrl := gomock.NewController(t) - mockRouter := routermocks.NewMockRouter(ctrl) - mockClient := vmcpmocks.NewMockBackendClient(ctrl) - - opt := NewDummyOptimizer(tools, mockRouter, mockClient) + opt := NewDummyOptimizer(tools) tests := []struct { name string @@ -110,7 +86,7 @@ func TestDummyOptimizer_FindTool(t *testing.T) { if tc.expectedError { require.Error(t, err) - assert.Contains(t, err.Error(), tc.errorContains) + require.Contains(t, err.Error(), tc.errorContains) return } @@ -123,7 +99,7 @@ func TestDummyOptimizer_FindTool(t *testing.T) { names = append(names, match.Name) } - assert.ElementsMatch(t, tc.expectedNames, names) + require.ElementsMatch(t, tc.expectedNames, names) }) } } @@ -131,53 +107,37 @@ func TestDummyOptimizer_FindTool(t *testing.T) { func TestDummyOptimizer_CallTool(t *testing.T) { t.Parallel() - tools := []vmcp.Tool{ + tools := []server.ServerTool{ { - Name: "test_tool", - Description: "A test tool", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "input": map[string]any{"type": "string"}, - }, + Tool: mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + }, + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args, _ := req.Params.Arguments.(map[string]any) + input := args["input"].(string) + return mcp.NewToolResultText("Hello, " + input + "!"), nil }, - BackendID: "backend1", }, } + opt := NewDummyOptimizer(tools) + tests := []struct { - name string - input CallToolInput - setupMocks func(*routermocks.MockRouter, *vmcpmocks.MockBackendClient) - expectedResult map[string]any - expectedError bool - errorContains string + name string + input CallToolInput + expectedText string + expectedError bool + isToolError bool 
+ errorContains string }{ { name: "successful tool call", input: CallToolInput{ ToolName: "test_tool", - Parameters: map[string]any{"input": "hello"}, - }, - setupMocks: func(r *routermocks.MockRouter, c *vmcpmocks.MockBackendClient) { - target := &vmcp.BackendTarget{ - WorkloadID: "backend1", - WorkloadName: "backend1", - BaseURL: "http://localhost:8080", - } - r.EXPECT().RouteTool(gomock.Any(), "test_tool").Return(target, nil) - c.EXPECT().CallTool(gomock.Any(), target, "test_tool", map[string]any{"input": "hello"}). - Return(map[string]any{ - "content": []any{ - map[string]any{"type": "text", "text": "Hello, World!"}, - }, - }, nil) - }, - expectedResult: map[string]any{ - "content": []any{ - map[string]any{"type": "text", "text": "Hello, World!"}, - }, + Parameters: map[string]any{"input": "World"}, }, + expectedText: "Hello, World!", }, { name: "tool not found", @@ -185,16 +145,14 @@ func TestDummyOptimizer_CallTool(t *testing.T) { ToolName: "nonexistent", Parameters: map[string]any{}, }, - setupMocks: func(_ *routermocks.MockRouter, _ *vmcpmocks.MockBackendClient) {}, - expectedError: true, - errorContains: "tool not found: nonexistent", + isToolError: true, + expectedText: "tool not found: nonexistent", }, { name: "empty tool name", input: CallToolInput{ Parameters: map[string]any{}, }, - setupMocks: func(_ *routermocks.MockRouter, _ *vmcpmocks.MockBackendClient) {}, expectedError: true, errorContains: "tool_name is required", }, @@ -204,23 +162,27 @@ func TestDummyOptimizer_CallTool(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - mockRouter := routermocks.NewMockRouter(ctrl) - mockClient := vmcpmocks.NewMockBackendClient(ctrl) - - tc.setupMocks(mockRouter, mockClient) - - opt := NewDummyOptimizer(tools, mockRouter, mockClient) result, err := opt.CallTool(context.Background(), tc.input) if tc.expectedError { require.Error(t, err) - assert.Contains(t, err.Error(), tc.errorContains) + require.Contains(t, 
err.Error(), tc.errorContains) return } require.NoError(t, err) - assert.Equal(t, tc.expectedResult, result) + require.NotNil(t, result) + + if tc.isToolError { + require.True(t, result.IsError) + } + + if tc.expectedText != "" { + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok) + require.Equal(t, tc.expectedText, textContent.Text) + } }) } } diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index befdd2bbb9..2da6f748f7 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -11,6 +11,9 @@ package optimizer import ( "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" ) // Optimizer defines the interface for intelligent tool discovery and invocation. @@ -25,8 +28,8 @@ type Optimizer interface { // CallTool invokes a tool by name with the given parameters. // Returns the tool's result or an error if the tool is not found or execution fails. - // The return type matches BackendClient.CallTool for direct passthrough. - CallTool(ctx context.Context, input CallToolInput) (map[string]any, error) + // Returns the MCP CallToolResult directly from the underlying tool handler. + CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) } // FindToolInput contains the parameters for finding tools. @@ -56,7 +59,8 @@ type ToolMatch struct { Description string `json:"description"` // Parameters is the JSON schema for the tool's input parameters. - Parameters map[string]any `json:"parameters"` + // Uses json.RawMessage to preserve the original schema format. + Parameters json.RawMessage `json:"parameters"` // Score indicates how well this tool matches the search criteria (0.0-1.0). 
Score float64 `json:"score"` diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go new file mode 100644 index 0000000000..d0abb559ec --- /dev/null +++ b/pkg/vmcp/server/adapter/optimizer_adapter.go @@ -0,0 +1,94 @@ +package adapter + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/schema" +) + +// OptimizerToolNames defines the tool names exposed when optimizer is enabled. +const ( + FindToolName = "find_tool" + CallToolName = "call_tool" +) + +// Pre-generated schemas for optimizer tools. +// Generated at package init time so any schema errors panic at startup. +var ( + findToolInputSchema = mustMarshalSchema(schema.GenerateSchema[optimizer.FindToolInput]()) + callToolInputSchema = mustMarshalSchema(schema.GenerateSchema[optimizer.CallToolInput]()) +) + +// CreateOptimizerTools creates the SDK tools for optimizer mode. +// When optimizer is enabled, only these two tools are exposed to clients +// instead of all backend tools. +func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool { + return []server.ServerTool{ + { + Tool: mcp.Tool{ + Name: FindToolName, + Description: "Search for tools by description. Returns matching tools ranked by relevance.", + RawInputSchema: findToolInputSchema, + }, + Handler: createFindToolHandler(opt), + }, + { + Tool: mcp.Tool{ + Name: CallToolName, + Description: "Call a tool by name with the given parameters.", + RawInputSchema: callToolInputSchema, + }, + Handler: createCallToolHandler(opt), + }, + } +} + +// createFindToolHandler creates a handler for the find_tool optimizer operation. 
+func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + output, err := opt.FindTool(ctx, input) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil + } + + return mcp.NewToolResultStructuredOnly(output), nil + } +} + +// createCallToolHandler creates a handler for the call_tool optimizer operation. +func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + result, err := opt.CallTool(ctx, input) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil + } + + return result, nil + } +} + +// mustMarshalSchema marshals a schema to JSON, panicking on error. +// This is safe because schemas are generated from known types at startup. 
+func mustMarshalSchema(s map[string]any) json.RawMessage { + data, err := json.Marshal(s) + if err != nil { + panic(fmt.Sprintf("failed to marshal schema: %v", err)) + } + return data +} diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go new file mode 100644 index 0000000000..f0c4db73e0 --- /dev/null +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -0,0 +1,104 @@ +package adapter + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" +) + +// mockOptimizer implements optimizer.Optimizer for testing. +type mockOptimizer struct { + findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) + callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) +} + +func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { + if m.findToolFunc != nil { + return m.findToolFunc(ctx, input) + } + return &optimizer.FindToolOutput{}, nil +} + +func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { + if m.callToolFunc != nil { + return m.callToolFunc(ctx, input) + } + return mcp.NewToolResultText("ok"), nil +} + +func TestCreateOptimizerTools(t *testing.T) { + t.Parallel() + + opt := &mockOptimizer{} + tools := CreateOptimizerTools(opt) + + require.Len(t, tools, 2) + require.Equal(t, FindToolName, tools[0].Tool.Name) + require.Equal(t, CallToolName, tools[1].Tool.Name) +} + +func TestFindToolHandler(t *testing.T) { + t.Parallel() + + opt := &mockOptimizer{ + findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { + require.Equal(t, "read files", input.ToolDescription) + return &optimizer.FindToolOutput{ + Tools: []optimizer.ToolMatch{ + { + 
Name: "read_file", + Description: "Read a file", + Score: 1.0, + }, + }, + }, nil + }, + } + + tools := CreateOptimizerTools(opt) + handler := tools[0].Handler + + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]any{ + "tool_description": "read files", + } + + result, err := handler(context.Background(), request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + require.Len(t, result.Content, 1) +} + +func TestCallToolHandler(t *testing.T) { + t.Parallel() + + opt := &mockOptimizer{ + callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { + require.Equal(t, "read_file", input.ToolName) + require.Equal(t, "/etc/hosts", input.Parameters["path"]) + return mcp.NewToolResultText("file contents here"), nil + }, + } + + tools := CreateOptimizerTools(opt) + handler := tools[1].Handler + + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]any{ + "tool_name": "read_file", + "parameters": map[string]any{ + "path": "/etc/hosts", + }, + } + + result, err := handler(context.Background(), request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + require.Len(t, result.Content, 1) +} From 9e33aae599ead8382a6ff5963b4ebe530e0c6730 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 15:02:04 -0800 Subject: [PATCH 09/10] feat(vmcp): wire up optimizer in session registration Add OptimizerFactory to server.Config to enable optimizer mode: - OptimizerFactory is a function that creates an Optimizer from tools - When set, session registration calls injectOptimizerCapabilities - injectOptimizerCapabilities wraps all backend/composite tools in optimizer - Only find_tool and call_tool are exposed to clients in optimizer mode - Resources and prompts are still injected normally Wire up in vmcp serve command: - When cfg.Optimizer is set, configure OptimizerFactory with DummyOptimizer - TODO comment for 
replacing with real optimizer implementation --- cmd/vmcp/app/commands.go | 6 +++ pkg/vmcp/optimizer/dummy_optimizer.go | 8 ++-- pkg/vmcp/server/server.go | 65 +++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index ca6060bcab..891b32c1cb 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -24,6 +24,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" ) @@ -416,6 +417,11 @@ func runServe(cmd *cobra.Command, _ []string) error { Watcher: backendWatcher, } + if cfg.Optimizer != nil { + // TODO: update this with the real optimizer. + serverCfg.OptimizerFactory = optimizer.NewDummyOptimizer + } + // Convert composite tool configurations to workflow definitions workflowDefs, err := vmcpserver.ConvertConfigToWorkflowDefinitions(cfg.CompositeTools) if err != nil { diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go index 4365dbcc47..9f6f887c81 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer.go +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -24,13 +24,13 @@ type DummyOptimizer struct { // NewDummyOptimizer creates a new DummyOptimizer with the given tools. // // The tools slice should contain all backend tools (as ServerTool with handlers). 
-func NewDummyOptimizer(tools []server.ServerTool) *DummyOptimizer { +func NewDummyOptimizer(tools []server.ServerTool) Optimizer { toolMap := make(map[string]server.ServerTool, len(tools)) for _, tool := range tools { toolMap[tool.Tool.Name] = tool } - return &DummyOptimizer{ + return DummyOptimizer{ tools: toolMap, } } @@ -43,7 +43,7 @@ func NewDummyOptimizer(tools []server.ServerTool) *DummyOptimizer { // // Returns all matching tools with a score of 1.0 (exact match semantics). // TokenMetrics are returned as zero values (not implemented in dummy). -func (d *DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { +func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { if input.ToolDescription == "" { return nil, fmt.Errorf("tool_description is required") } @@ -80,7 +80,7 @@ func (d *DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*Find // // The tool is looked up by exact name match. If found, the handler // is invoked directly with the given parameters. 
-func (d *DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { +func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { if input.ToolName == "" { return nil, fmt.Errorf("tool_name is required") } diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 2dc3eb0b2a..583f41c739 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -28,6 +28,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -119,6 +120,10 @@ type Config struct { // Only set when running in K8s with outgoingAuth.source: discovered. // Used for /readyz endpoint to gate readiness on cache sync. Watcher Watcher + + // OptimizerFactory builds an optimizer from a list of tools. + // If not set, the optimizer is disabled. + OptimizerFactory func([]server.ServerTool) optimizer.Optimizer } // Server is the Virtual MCP Server that aggregates multiple backends. @@ -699,6 +704,7 @@ func (s *Server) Ready() <-chan struct{} { // - No previous capabilities exist, so no deletion needed // - Capabilities are IMMUTABLE for the session lifetime (see limitation below) // - Discovery middleware does not re-run for subsequent requests +// - If injectOptimizerCapabilities is called, this should not be called again. // // LIMITATION: Session capabilities are fixed at creation time. // If backends change (new tools added, resources removed), existing sessions won't see updates. @@ -772,6 +778,53 @@ func (s *Server) injectCapabilities( return nil } +// injectOptimizerCapabilities injects all capabilities into the session, including optimizer tools. 
+// It must only be called when optimizer mode is enabled, where it is used instead of injectCapabilities. +// +// When optimizer mode is enabled, instead of exposing all backend tools directly, +// vMCP exposes only two meta-tools: +// - find_tool: Search for tools by description +// - call_tool: Invoke a tool by name with parameters +// +// This method: +// 1. Converts all tools (backend + composite) to SDK format with handlers +// 2. Injects the optimizer capabilities into the session +func (s *Server) injectOptimizerCapabilities( + sessionID string, + caps *aggregator.AggregatedCapabilities, +) error { + + tools := append([]vmcp.Tool{}, caps.Tools...) + tools = append(tools, caps.CompositeTools...) + + sdkTools, err := s.capabilityAdapter.ToSDKTools(tools) + if err != nil { + return fmt.Errorf("failed to convert tools to SDK format: %w", err) + } + + // Create optimizer tools (find_tool, call_tool) + optimizerTools := adapter.CreateOptimizerTools(s.config.OptimizerFactory(sdkTools)) + + logger.Debugw("created optimizer tools for session", + "session_id", sessionID, + "backend_tool_count", len(caps.Tools), + "composite_tool_count", len(caps.CompositeTools), + "total_tools_indexed", len(sdkTools)) + + // Clear tools from caps - they're now wrapped by optimizer + // Resources and prompts are preserved and handled normally + caps.Tools = nil + caps.CompositeTools = nil + + // Manually add the optimizer tools, since we don't want to bother converting + // optimizer tools into `vmcp.Tool`s as well. + if err := s.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { + return fmt.Errorf("failed to add session tools: %w", err) + } + + return s.injectCapabilities(sessionID, caps) +} + // handleSessionRegistration processes a new MCP session registration. 
// // This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which @@ -859,6 +912,18 @@ func (s *Server) handleSessionRegistration( "resource_count", len(caps.RoutingTable.Resources), "prompt_count", len(caps.RoutingTable.Prompts)) + if s.config.OptimizerFactory != nil { + err = s.injectOptimizerCapabilities(sessionID, caps) + if err != nil { + logger.Errorw("failed to create optimizer tools", + "error", err, + "session_id", sessionID) + } else { + logger.Infow("optimizer capabilities injected") + } + return + } + // Inject capabilities into SDK session if err := s.injectCapabilities(sessionID, caps); err != nil { logger.Errorw("failed to inject session capabilities", From 48340edcf0a2967e834fdcb1c0d61760c1f06984 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 20:55:21 -0800 Subject: [PATCH 10/10] test and fix' --- pkg/vmcp/optimizer/dummy_optimizer.go | 6 + pkg/vmcp/server/server.go | 9 + .../virtualmcp/virtualmcp_optimizer_test.go | 311 ++++++++++++++++++ 3 files changed, 326 insertions(+) create mode 100644 test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go index 9f6f887c81..3a8338f04d 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer.go +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -48,6 +48,12 @@ func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindT return nil, fmt.Errorf("tool_description is required") } + // Log all tools in the optimizer for debugging + fmt.Printf("[DummyOptimizer.FindTool] Searching for %q in %d tools:\n", input.ToolDescription, len(d.tools)) + for name, tool := range d.tools { + fmt.Printf(" - %q: %q\n", name, tool.Tool.Description) + } + searchTerm := strings.ToLower(input.ToolDescription) var matches []ToolMatch diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 583f41c739..d0879f3a63 100644 --- a/pkg/vmcp/server/server.go +++ 
b/pkg/vmcp/server/server.go @@ -811,10 +811,19 @@ func (s *Server) injectOptimizerCapabilities( "composite_tool_count", len(caps.CompositeTools), "total_tools_indexed", len(sdkTools)) + // Save tools before clearing - caps is shared across sessions (NOTE(review): mutating shared caps looks racy if sessions register concurrently - confirm) + savedTools := caps.Tools + savedCompositeTools := caps.CompositeTools + // Clear tools from caps - they're now wrapped by optimizer // Resources and prompts are preserved and handled normally caps.Tools = nil caps.CompositeTools = nil + defer func() { + // Restore tools on every return path, success or error + caps.Tools = savedTools + caps.CompositeTools = savedCompositeTools + }() // Manually add the optimizer tools, since we don't want to bother converting // optimizer tools into `vmcp.Tool`s as well. diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go new file mode 100644 index 0000000000..1db816db9a --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go @@ -0,0 +1,311 @@ +package virtualmcp + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/mark3labs/mcp-go/mcp" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +// callFindTool calls find_tool and returns the StructuredContent directly +func callFindTool(mcpClient *InitializedMCPClient, description string) (map[string]any, error) { + req := mcp.CallToolRequest{} + req.Params.Name = "find_tool" + req.Params.Arguments = map[string]any{"tool_description": description} + + result, err := mcpClient.Client.CallTool(mcpClient.Ctx, req) + if err != nil { + return nil, err + } + content, ok := result.StructuredContent.(map[string]any) + if !ok { + return nil, fmt.Errorf("expected map[string]any, got %T", result.StructuredContent) + } + return content, nil +} + +// getToolNames extracts tool names from find_tool structured content +func getToolNames(content map[string]any) []string { + tools, ok := content["tools"].([]any) + if !ok { + return nil + } + var names []string + for _, t := range tools { + if tool, ok := t.(map[string]any); ok { + if name, ok := tool["name"].(string); ok { + names = append(names, name) + } + } + } + return names +} + +// callToolViaOptimizer invokes a tool through call_tool +func callToolViaOptimizer(mcpClient *InitializedMCPClient, toolName string, params map[string]any) (*mcp.CallToolResult, error) { + req := mcp.CallToolRequest{} + req.Params.Name = "call_tool" + req.Params.Arguments = map[string]any{ + "tool_name": toolName, + "parameters": params, + } + return mcpClient.Client.CallTool(mcpClient.Ctx, req) +} + +var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { + var ( + testNamespace = "default" + mcpGroupName = "test-optimizer-group" + vmcpServerName = "test-vmcp-optimizer" + backendName = "backend-optimizer-fetch" + compositeToolName = "double_fetch" 
+ timeout = 3 * time.Minute + pollingInterval = 1 * time.Second + vmcpNodePort int32 + ) + + BeforeAll(func() { + By("Creating MCPGroup for optimizer test") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace, + "Test MCP Group for optimizer E2E tests", timeout, pollingInterval) + + By("Creating backend MCPServer - fetch") + CreateMCPServerAndWait(ctx, k8sClient, backendName, testNamespace, + mcpGroupName, images.GofetchServerImage, timeout, pollingInterval) + + By("Creating VirtualMCPServer with optimizer enabled and a composite tool") + // Define composite tool parameters schema + paramSchema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "url": map[string]interface{}{ + "type": "string", + "description": "URL to fetch twice", + }, + }, + "required": []string{"url"}, + } + paramSchemaBytes, err := json.Marshal(paramSchema) + Expect(err).ToNot(HaveOccurred()) + + // Define step arguments that reference the input parameter + stepArgs := map[string]interface{}{ + "url": "{{.params.url}}", + } + stepArgsBytes, err := json.Marshal(stepArgs) + Expect(err).ToNot(HaveOccurred()) + + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: vmcpServerName, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + ServiceType: "NodePort", + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ + Source: "discovered", + }, + // Define a composite tool that calls fetch twice + CompositeTools: []mcpv1alpha1.CompositeToolSpec{ + { + Name: compositeToolName, + Description: "Fetches a URL twice in sequence for verification", + Parameters: &runtime.RawExtension{ + Raw: paramSchemaBytes, + }, + Steps: []mcpv1alpha1.WorkflowStep{ + { + ID: "first_fetch", + Type: "tool", + Tool: fmt.Sprintf("%s.fetch", backendName), + Arguments: &runtime.RawExtension{ + Raw: stepArgsBytes, + }, + }, + { + ID: "second_fetch", + 
Type: "tool", + Tool: fmt.Sprintf("%s.fetch", backendName), + DependsOn: []string{"first_fetch"}, + Arguments: &runtime.RawExtension{ + Raw: stepArgsBytes, + }, + }, + }, + }, + }, + Config: vmcpconfig.Config{ + Group: mcpGroupName, + Optimizer: &vmcpconfig.OptimizerConfig{ + // EmbeddingService is required but not used by DummyOptimizer + EmbeddingService: "dummy-embedding-service", + }, + }, + }, + } + Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed()) + + By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) + + By("Getting VirtualMCPServer NodePort") + vmcpNodePort = GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) + _, _ = fmt.Fprintf(GinkgoWriter, "VirtualMCPServer is accessible at NodePort: %d\n", vmcpNodePort) + }) + + AfterAll(func() { + By("Cleaning up VirtualMCPServer") + vmcpServer := &mcpv1alpha1.VirtualMCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: vmcpServerName, + Namespace: testNamespace, + }, vmcpServer); err == nil { + _ = k8sClient.Delete(ctx, vmcpServer) + } + + By("Cleaning up backend MCPServer") + backend := &mcpv1alpha1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: testNamespace, + }, backend); err == nil { + _ = k8sClient.Delete(ctx, backend) + } + + By("Cleaning up MCPGroup") + mcpGroup := &mcpv1alpha1.MCPGroup{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: mcpGroupName, + Namespace: testNamespace, + }, mcpGroup); err == nil { + _ = k8sClient.Delete(ctx, mcpGroup) + } + }) + + It("should only expose find_tool and call_tool", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-test-client", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Listing tools from VirtualMCPServer") + listRequest := 
mcp.ListToolsRequest{} + tools, err := mcpClient.Client.ListTools(mcpClient.Ctx, listRequest) + Expect(err).ToNot(HaveOccurred()) + + By("Verifying only optimizer tools are exposed") + Expect(tools.Tools).To(HaveLen(2), "Should only have find_tool and call_tool") + + toolNames := make([]string, len(tools.Tools)) + for i, tool := range tools.Tools { + toolNames[i] = tool.Name + } + Expect(toolNames).To(ContainElements("find_tool", "call_tool")) + + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Optimizer mode correctly exposes only: %v\n", toolNames) + }) + + It("should find backend tools via find_tool", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-find-test", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Calling find_tool to search for fetch tool") + result, err := callFindTool(mcpClient, "fetch") + Expect(err).ToNot(HaveOccurred()) + + toolNames := getToolNames(result) + Expect(toolNames).ToNot(BeEmpty(), "find_tool should return matching tools") + _, _ = fmt.Fprintf(GinkgoWriter, "✓ find_tool returned tools: %v\n", toolNames) + + By("Verifying at least one tool matches 'fetch'") + var foundFetch bool + for _, name := range toolNames { + if strings.Contains(name, "fetch") { + foundFetch = true + } + } + Expect(foundFetch).To(BeTrue(), "Should find a fetch-related tool") + }) + + It("should invoke backend tools via call_tool", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-call-test", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Finding the backend fetch tool") + findResult, err := callFindTool(mcpClient, "internet") // matches backend fetch description + Expect(err).ToNot(HaveOccurred()) + + toolNames := getToolNames(findResult) + Expect(toolNames).ToNot(BeEmpty()) + + // Find the backend tool (has prefix), not the composite + var 
backendToolName string + for _, name := range toolNames { + if strings.Contains(name, backendName) { + backendToolName = name + break + } + } + Expect(backendToolName).ToNot(BeEmpty(), "Should find backend tool") + + By(fmt.Sprintf("Calling %s via call_tool", backendToolName)) + result, err := callToolViaOptimizer(mcpClient, backendToolName, map[string]any{ + "url": "https://example.com", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result).ToNot(BeNil()) + Expect(result.Content).ToNot(BeEmpty(), "call_tool should return content from backend tool") + + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Successfully called %s via call_tool\n", backendToolName) + }) + + It("should find and invoke composite tools via optimizer", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-composite-test", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Calling find_tool to search for composite tool") + result, err := callFindTool(mcpClient, "twice") // matches "Fetches a URL twice..." 
+ Expect(err).ToNot(HaveOccurred()) + + toolNames := getToolNames(result) + var foundTool string + for _, name := range toolNames { + _, _ = fmt.Fprintf(GinkgoWriter, " - Found tool: %s\n", name) + if strings.Contains(name, compositeToolName) { + foundTool = name + } + } + Expect(foundTool).ToNot(BeEmpty(), "Should find composite tool %s", compositeToolName) + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Found composite tool %s via find_tool\n", foundTool) + + By(fmt.Sprintf("Calling composite tool %s via call_tool", foundTool)) + callResult, err := callToolViaOptimizer(mcpClient, foundTool, map[string]any{ + "url": "https://example.com", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(callResult).ToNot(BeNil()) + Expect(callResult.Content).ToNot(BeEmpty(), "call_tool should return content from composite tool") + + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Successfully called composite tool %s via call_tool\n", foundTool) + }) +})