diff --git a/cmd/thv-operator/pkg/vmcpconfig/converter.go b/cmd/thv-operator/pkg/vmcpconfig/converter.go index 7eead1d576..10c2538e07 100644 --- a/cmd/thv-operator/pkg/vmcpconfig/converter.go +++ b/cmd/thv-operator/pkg/vmcpconfig/converter.go @@ -61,15 +61,23 @@ func NewConverter(oidcResolver oidc.Resolver, k8sClient client.Client) (*Convert }, nil } -// Convert converts VirtualMCPServer CRD spec to vmcp Config +// Convert converts VirtualMCPServer CRD spec to vmcp Config. +// +// The conversion starts with a DeepCopy of the embedded config.Config from the CRD spec. +// This ensures that simple fields (like Optimizer, Metadata, etc.) are automatically +// passed through without explicit mapping. Only fields that require special handling +// (auth, aggregation, composite tools, telemetry) are explicitly converted below. func (c *Converter) Convert( ctx context.Context, vmcp *mcpv1alpha1.VirtualMCPServer, ) (*vmcpconfig.Config, error) { - config := &vmcpconfig.Config{ - Name: vmcp.Name, - Group: vmcp.Spec.Config.Group, - } + // Start with a deep copy of the embedded config for automatic field passthrough. + // This ensures new fields added to config.Config are automatically included + // without requiring explicit mapping in this converter. 
+ config := vmcp.Spec.Config.DeepCopy() + + // Override name with the CR name (authoritative source) + config.Name = vmcp.Name // Convert IncomingAuth - required field, no defaults if vmcp.Spec.IncomingAuth != nil { diff --git a/cmd/vmcp/README.md b/cmd/vmcp/README.md index 30ac862ca2..10a60bef2a 100644 --- a/cmd/vmcp/README.md +++ b/cmd/vmcp/README.md @@ -6,7 +6,7 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC ## Features -### Implemented (Phase 1) +### Implemented - ✅ **Group-Based Backend Management**: Automatic workload discovery from ToolHive groups - ✅ **Tool Aggregation**: Combines tools from multiple MCP servers with conflict resolution (prefix, priority, manual) - ✅ **Resource & Prompt Aggregation**: Unified access to resources and prompts from all backends @@ -15,12 +15,14 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC - ✅ **Health Endpoints**: `/health` and `/ping` for service monitoring - ✅ **Configuration Validation**: `vmcp validate` command for config verification - ✅ **Observability**: OpenTelemetry metrics and traces for backend operations and workflow executions +- ✅ **Composite Tools**: Multi-step workflows with elicitation support ### In Progress - 🚧 **Incoming Authentication** (Issue #165): OIDC, local, anonymous authentication - 🚧 **Outgoing Authentication** (Issue #160): RFC 8693 token exchange for backend API access - 🚧 **Token Caching**: Memory and Redis cache providers - 🚧 **Health Monitoring** (Issue #166): Circuit breakers, backend health checks +- 🚧 **Optimizer**: MCP optimizer support in vMCP for context optimization on large toolsets 
### Future (Phase 2+) - 📋 **Authorization**: Cedar policy-based access control diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index ca6060bcab..891b32c1cb 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -24,6 +24,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" ) @@ -416,6 +417,11 @@ func runServe(cmd *cobra.Command, _ []string) error { Watcher: backendWatcher, } + if cfg.Optimizer != nil { + // TODO: update this with the real optimizer. + serverCfg.OptimizerFactory = optimizer.NewDummyOptimizer + } + // Convert composite tool configurations to workflow definitions workflowDefs, err := vmcpserver.ConvertConfigToWorkflowDefinitions(cfg.CompositeTools) if err != nil { diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml index bff67be5a8..bd70f3bffd 100644 --- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -862,6 +862,21 @@ spec: - default type: object type: object + optimizer: + description: |- + Optimizer configures the MCP optimizer for context optimization on large toolsets. + When enabled, vMCP exposes only find_tool and call_tool operations to clients + instead of all backend tools directly. This reduces token usage by allowing + LLMs to discover relevant tools on demand rather than receiving all tool definitions. + properties: + embeddingService: + description: |- + EmbeddingService is the name of a Kubernetes Service that provides the embedding service + for semantic tool discovery. 
The service must implement the optimizer embedding API. + type: string + required: + - embeddingService + type: object outgoingAuth: description: OutgoingAuth configures how the virtual MCP server authenticates to backends. diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index 8e329f8584..b497666ce5 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -865,6 +865,21 @@ spec: - default type: object type: object + optimizer: + description: |- + Optimizer configures the MCP optimizer for context optimization on large toolsets. + When enabled, vMCP exposes only find_tool and call_tool operations to clients + instead of all backend tools directly. This reduces token usage by allowing + LLMs to discover relevant tools on demand rather than receiving all tool definitions. + properties: + embeddingService: + description: |- + EmbeddingService is the name of a Kubernetes Service that provides the embedding service + for semantic tool discovery. The service must implement the optimizer embedding API. + type: string + required: + - embeddingService + type: object outgoingAuth: description: OutgoingAuth configures how the virtual MCP server authenticates to backends. diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md index 14e8587c6a..de89fc4a7c 100644 --- a/docs/operator/crd-api.md +++ b/docs/operator/crd-api.md @@ -225,6 +225,7 @@ _Appears in:_ | `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `telemetry` _[pkg.telemetry.Config](#pkgtelemetryconfig)_ | Telemetry configures OpenTelemetry-based observability for the Virtual MCP server
including distributed tracing, OTLP metrics export, and Prometheus metrics endpoint. | | | | `audit` _[pkg.audit.Config](#pkgauditconfig)_ | Audit configures audit logging for the Virtual MCP server.
When present, audit logs include MCP protocol operations.
See audit.Config for available configuration options. | | | +| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes only find_tool and call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | | #### vmcp.config.ConflictResolutionConfig @@ -343,6 +344,24 @@ _Appears in:_ | `failureHandling` _[vmcp.config.FailureHandlingConfig](#vmcpconfigfailurehandlingconfig)_ | FailureHandling configures failure handling. | | | +#### vmcp.config.OptimizerConfig + + + +OptimizerConfig configures the MCP optimizer. +When enabled, vMCP exposes only find_tool and call_tool operations to clients +instead of all backend tools directly. + + + +_Appears in:_ +- [vmcp.config.Config](#vmcpconfigconfig) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `embeddingService` _string_ | EmbeddingService is the name of a Kubernetes Service that provides the embedding service
for semantic tool discovery. The service must implement the optimizer embedding API. | | Required: \{\}
| + + #### vmcp.config.OutgoingAuthConfig diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index 9693643348..a6355cd5e8 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -108,6 +108,13 @@ type Config struct { // See audit.Config for available configuration options. // +optional Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"` + + // Optimizer configures the MCP optimizer for context optimization on large toolsets. + // When enabled, vMCP exposes only find_tool and call_tool operations to clients + // instead of all backend tools directly. This reduces token usage by allowing + // LLMs to discover relevant tools on demand rather than receiving all tool definitions. + // +optional + Optimizer *OptimizerConfig `json:"optimizer,omitempty" yaml:"optimizer,omitempty"` } // IncomingAuthConfig configures client authentication to the virtual MCP server. @@ -474,6 +481,18 @@ type OutputProperty struct { Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"` } +// OptimizerConfig configures the MCP optimizer. +// When enabled, vMCP exposes only find_tool and call_tool operations to clients +// instead of all backend tools directly. +// +kubebuilder:object:generate=true +// +gendoc +type OptimizerConfig struct { + // EmbeddingService is the name of a Kubernetes Service that provides the embedding service + // for semantic tool discovery. The service must implement the optimizer embedding API. + // +kubebuilder:validation:Required + EmbeddingService string `json:"embeddingService" yaml:"embeddingService"` +} + // Validator validates configuration. type Validator interface { // Validate checks if the configuration is valid. 
diff --git a/pkg/vmcp/config/zz_generated.deepcopy.go b/pkg/vmcp/config/zz_generated.deepcopy.go index b50fa118af..5afe9c656a 100644 --- a/pkg/vmcp/config/zz_generated.deepcopy.go +++ b/pkg/vmcp/config/zz_generated.deepcopy.go @@ -175,6 +175,11 @@ func (in *Config) DeepCopyInto(out *Config) { *out = new(audit.Config) (*in).DeepCopyInto(*out) } + if in.Optimizer != nil { + in, out := &in.Optimizer, &out.Optimizer + *out = new(OptimizerConfig) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Config. @@ -312,6 +317,21 @@ func (in *OperationalConfig) DeepCopy() *OperationalConfig { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *OptimizerConfig) DeepCopyInto(out *OptimizerConfig) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OptimizerConfig. +func (in *OptimizerConfig) DeepCopy() *OptimizerConfig { + if in == nil { + return nil + } + out := new(OptimizerConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *OutgoingAuthConfig) DeepCopyInto(out *OutgoingAuthConfig) { *out = *in diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go new file mode 100644 index 0000000000..3a8338f04d --- /dev/null +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -0,0 +1,122 @@ +package optimizer + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// DummyOptimizer implements the Optimizer interface using exact string matching. +// +// This implementation is intended for testing and development. It performs +// case-insensitive substring matching on tool names and descriptions. 
+// +// For production use, see the EmbeddingOptimizer which uses semantic similarity. +type DummyOptimizer struct { + // tools contains all available tools indexed by name. + tools map[string]server.ServerTool +} + +// NewDummyOptimizer creates a new DummyOptimizer with the given tools. +// +// The tools slice should contain all backend tools (as ServerTool with handlers). +func NewDummyOptimizer(tools []server.ServerTool) Optimizer { + toolMap := make(map[string]server.ServerTool, len(tools)) + for _, tool := range tools { + toolMap[tool.Tool.Name] = tool + } + + return DummyOptimizer{ + tools: toolMap, + } +} + +// FindTool searches for tools using exact substring matching. +// +// The search is case-insensitive and matches against: +// - Tool name (substring match) +// - Tool description (substring match) +// +// Returns all matching tools with a score of 1.0 (exact match semantics). +// TokenMetrics are returned as zero values (not implemented in dummy). +func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { + if input.ToolDescription == "" { + return nil, fmt.Errorf("tool_description is required") + } + + // Log all tools in the optimizer for debugging + fmt.Printf("[DummyOptimizer.FindTool] Searching for %q in %d tools:\n", input.ToolDescription, len(d.tools)) + for name, tool := range d.tools { + fmt.Printf(" - %q: %q\n", name, tool.Tool.Description) + } + + searchTerm := strings.ToLower(input.ToolDescription) + + var matches []ToolMatch + for _, tool := range d.tools { + nameLower := strings.ToLower(tool.Tool.Name) + descLower := strings.ToLower(tool.Tool.Description) + + // Check if search term matches name or description + if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) { + schema, err := getToolSchema(tool.Tool) + if err != nil { + return nil, err + } + matches = append(matches, ToolMatch{ + Name: tool.Tool.Name, + Description: tool.Tool.Description, + Parameters: 
schema, + Score: 1.0, // Exact match semantics + }) + } + } + + return &FindToolOutput{ + Tools: matches, + TokenMetrics: TokenMetrics{}, // Zero values for dummy + }, nil +} + +// CallTool invokes a tool by name using its registered handler. +// +// The tool is looked up by exact name match. If found, the handler +// is invoked directly with the given parameters. +func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { + if input.ToolName == "" { + return nil, fmt.Errorf("tool_name is required") + } + + // Verify the tool exists + tool, exists := d.tools[input.ToolName] + if !exists { + return mcp.NewToolResultError(fmt.Sprintf("tool not found: %s", input.ToolName)), nil + } + + // Build the MCP request + request := mcp.CallToolRequest{} + request.Params.Name = input.ToolName + request.Params.Arguments = input.Parameters + + // Call the tool handler directly + return tool.Handler(ctx, request) +} + +// getToolSchema returns the input schema for a tool. +// Prefers RawInputSchema if set, otherwise marshals InputSchema. 
+func getToolSchema(tool mcp.Tool) (json.RawMessage, error) { + if len(tool.RawInputSchema) > 0 { + return tool.RawInputSchema, nil + } + + // Fall back to InputSchema + data, err := json.Marshal(tool.InputSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) + } + return data, nil +} diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go new file mode 100644 index 0000000000..489f04e0e5 --- /dev/null +++ b/pkg/vmcp/optimizer/dummy_optimizer_test.go @@ -0,0 +1,188 @@ +package optimizer + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" +) + +func TestDummyOptimizer_FindTool(t *testing.T) { + t.Parallel() + + tools := []server.ServerTool{ + { + Tool: mcp.Tool{ + Name: "fetch_url", + Description: "Fetch content from a URL", + }, + }, + { + Tool: mcp.Tool{ + Name: "read_file", + Description: "Read a file from the filesystem", + }, + }, + { + Tool: mcp.Tool{ + Name: "write_file", + Description: "Write content to a file", + }, + }, + } + + opt := NewDummyOptimizer(tools) + + tests := []struct { + name string + input FindToolInput + expectedNames []string + expectedError bool + errorContains string + }{ + { + name: "find by exact name", + input: FindToolInput{ + ToolDescription: "fetch_url", + }, + expectedNames: []string{"fetch_url"}, + }, + { + name: "find by description substring", + input: FindToolInput{ + ToolDescription: "file", + }, + expectedNames: []string{"read_file", "write_file"}, + }, + { + name: "case insensitive search", + input: FindToolInput{ + ToolDescription: "FETCH", + }, + expectedNames: []string{"fetch_url"}, + }, + { + name: "no matches", + input: FindToolInput{ + ToolDescription: "nonexistent", + }, + expectedNames: []string{}, + }, + { + name: "empty description", + input: FindToolInput{}, + expectedError: true, + errorContains: 
"tool_description is required", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result, err := opt.FindTool(context.Background(), tc.input) + + if tc.expectedError { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errorContains) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + + // Extract names from results + var names []string + for _, match := range result.Tools { + names = append(names, match.Name) + } + + require.ElementsMatch(t, tc.expectedNames, names) + }) + } +} + +func TestDummyOptimizer_CallTool(t *testing.T) { + t.Parallel() + + tools := []server.ServerTool{ + { + Tool: mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + }, + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args, _ := req.Params.Arguments.(map[string]any) + input := args["input"].(string) + return mcp.NewToolResultText("Hello, " + input + "!"), nil + }, + }, + } + + opt := NewDummyOptimizer(tools) + + tests := []struct { + name string + input CallToolInput + expectedText string + expectedError bool + isToolError bool + errorContains string + }{ + { + name: "successful tool call", + input: CallToolInput{ + ToolName: "test_tool", + Parameters: map[string]any{"input": "World"}, + }, + expectedText: "Hello, World!", + }, + { + name: "tool not found", + input: CallToolInput{ + ToolName: "nonexistent", + Parameters: map[string]any{}, + }, + isToolError: true, + expectedText: "tool not found: nonexistent", + }, + { + name: "empty tool name", + input: CallToolInput{ + Parameters: map[string]any{}, + }, + expectedError: true, + errorContains: "tool_name is required", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result, err := opt.CallTool(context.Background(), tc.input) + + if tc.expectedError { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errorContains) + return + } + + require.NoError(t, 
err) + require.NotNil(t, result) + + if tc.isToolError { + require.True(t, result.IsError) + } + + if tc.expectedText != "" { + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok) + require.Equal(t, tc.expectedText, textContent.Text) + } + }) + } +} diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go new file mode 100644 index 0000000000..2da6f748f7 --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer.go @@ -0,0 +1,88 @@ +// Package optimizer provides the Optimizer interface for intelligent tool discovery +// and invocation in the Virtual MCP Server. +// +// When the optimizer is enabled, vMCP exposes only two tools to clients: +// - find_tool: Semantic search over available tools +// - call_tool: Dynamic invocation of any backend tool +// +// This reduces token usage by avoiding the need to send all tool definitions +// to the LLM, instead allowing it to discover relevant tools on demand. +package optimizer + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Optimizer defines the interface for intelligent tool discovery and invocation. +// +// Implementations may use various strategies for tool matching: +// - DummyOptimizer: Exact string matching (for testing) +// - EmbeddingOptimizer: Semantic similarity via embeddings (production) +type Optimizer interface { + // FindTool searches for tools matching the given description and keywords. + // Returns matching tools ranked by relevance score. + FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error) + + // CallTool invokes a tool by name with the given parameters. + // Returns the tool's result or an error if the tool is not found or execution fails. + // Returns the MCP CallToolResult directly from the underlying tool handler. + CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) +} + +// FindToolInput contains the parameters for finding tools. 
+type FindToolInput struct { + // ToolDescription is a natural language description of the tool to find. + ToolDescription string `json:"tool_description" description:"Natural language description of the tool to find"` + + // ToolKeywords is an optional list of keywords to narrow the search. + ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords to narrow search"` +} + +// FindToolOutput contains the results of a tool search. +type FindToolOutput struct { + // Tools contains the matching tools, ranked by relevance. + Tools []ToolMatch `json:"tools"` + + // TokenMetrics provides information about token savings from using the optimizer. + TokenMetrics TokenMetrics `json:"token_metrics"` +} + +// ToolMatch represents a tool that matched the search criteria. +type ToolMatch struct { + // Name is the unique identifier of the tool. + Name string `json:"name"` + + // Description is the human-readable description of the tool. + Description string `json:"description"` + + // Parameters is the JSON schema for the tool's input parameters. + // Uses json.RawMessage to preserve the original schema format. + Parameters json.RawMessage `json:"parameters"` + + // Score indicates how well this tool matches the search criteria (0.0-1.0). + Score float64 `json:"score"` +} + +// TokenMetrics provides information about token usage optimization. +type TokenMetrics struct { + // BaselineTokens is the estimated tokens if all tools were sent. + BaselineTokens int `json:"baseline_tokens"` + + // ReturnedTokens is the actual tokens for the returned tools. + ReturnedTokens int `json:"returned_tokens"` + + // SavingsPercent is the percentage of tokens saved. + SavingsPercent float64 `json:"savings_percent"` +} + +// CallToolInput contains the parameters for calling a tool. +type CallToolInput struct { + // ToolName is the name of the tool to invoke. 
+ ToolName string `json:"tool_name" description:"Name of the tool to call"` + + // Parameters are the arguments to pass to the tool. + Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` +} diff --git a/pkg/vmcp/schema/reflect.go b/pkg/vmcp/schema/reflect.go new file mode 100644 index 0000000000..a6f1ec419e --- /dev/null +++ b/pkg/vmcp/schema/reflect.go @@ -0,0 +1,173 @@ +package schema + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" +) + +// GenerateSchema generates a JSON Schema from a Go struct type using reflection. +// +// The function inspects struct tags to determine: +// - json: Field name in the schema (uses json tag name) +// - description: Field description (from `description` tag) +// - omitempty: Whether the field is optional (not in required array) +// +// Supported types: +// - string -> {"type": "string"} +// - int, int64, etc. -> {"type": "integer"} +// - float64, float32 -> {"type": "number"} +// - bool -> {"type": "boolean"} +// - []T -> {"type": "array", "items": {...}} +// - map[string]any -> {"type": "object"} +// - struct -> {"type": "object", "properties": {...}} +// +// Example: +// +// type FindToolInput struct { +// ToolDescription string `json:"tool_description" description:"Natural language description"` +// ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords"` +// } +// schema := GenerateSchema[FindToolInput]() +func GenerateSchema[T any]() map[string]any { + var zero T + t := reflect.TypeOf(zero) + return generateSchemaForType(t) +} + +// generateSchemaForType generates schema for a reflect.Type. 
// generateSchemaForType generates a JSON Schema fragment for a reflect.Type.
//
// A nil type is described as a generic object instead of panicking; this is
// what reflect.TypeOf yields for a nil interface value, e.g. when the schema
// is requested for an interface type. Pointers are dereferenced at any depth.
// Byte slices map to "string" because encoding/json encodes []byte as a
// base64 string; other slices and fixed-size arrays map to "array".
//
// NOTE: there is no cycle detection, so self-referential struct types would
// recurse without bound.
func generateSchemaForType(t reflect.Type) map[string]any {
	if t == nil {
		return map[string]any{"type": "object"}
	}

	// Strip every level of indirection (*T, **T, ...).
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	switch t.Kind() {
	case reflect.Struct:
		return generateObjectSchema(t)
	case reflect.String:
		return map[string]any{"type": "string"}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return map[string]any{"type": "integer"}
	case reflect.Float32, reflect.Float64:
		return map[string]any{"type": "number"}
	case reflect.Bool:
		return map[string]any{"type": "boolean"}
	case reflect.Slice:
		// encoding/json transports []byte as a base64-encoded string.
		if t.Elem().Kind() == reflect.Uint8 {
			return map[string]any{"type": "string"}
		}
		return map[string]any{
			"type":  "array",
			"items": generateSchemaForType(t.Elem()),
		}
	case reflect.Array:
		return map[string]any{
			"type":  "array",
			"items": generateSchemaForType(t.Elem()),
		}
	default:
		// Maps, interfaces and anything else are described as a free-form object.
		return map[string]any{"type": "object"}
	}
}

// generateObjectSchema generates the schema for a struct type.
//
// Each exported field becomes a property named after its json tag (or the Go
// field name when the tag has no name). Fields tagged `json:"-"` are skipped.
// A `description` tag is copied into the property schema. Fields whose json
// tag lacks "omitempty" are listed, in declaration order, under "required";
// the key is omitted when no field is required.
func generateObjectSchema(t reflect.Type) map[string]any {
	properties := make(map[string]any, t.NumField())
	var required []string

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// Skip unexported fields
		if !field.IsExported() {
			continue
		}

		// Get JSON tag for field name
		jsonTag := field.Tag.Get("json")
		if jsonTag == "-" {
			continue
		}

		// Parse json tag (name,omitempty)
		jsonName, isOptional := parseJSONTag(jsonTag)
		if jsonName == "" {
			jsonName = field.Name
		}

		// Generate schema for field type
		fieldSchema := generateSchemaForType(field.Type)

		// Add description if present
		if desc := field.Tag.Get("description"); desc != "" {
			fieldSchema["description"] = desc
		}

		properties[jsonName] = fieldSchema

		// Add to required if not optional
		if !isOptional {
			required = append(required, jsonName)
		}
	}

	schema := map[string]any{
		"type":       "object",
		"properties": properties,
	}

	if len(required) > 0 {
		schema["required"] = required
	}

	return schema
}

// parseJSONTag parses a json struct tag and returns the field name and whether it's optional.
//
// The name is everything before the first comma (possibly empty); the field
// is optional when any of the remaining comma-separated options is "omitempty".
func parseJSONTag(tag string) (name string, optional bool) {
	name, options, _ := strings.Cut(tag, ",")
	for _, option := range strings.Split(options, ",") {
		if option == "omitempty" {
			optional = true
		}
	}
	return name, optional
}
// Translate turns an untyped value (typically the map[string]any carried by
// MCP request arguments) into a value of type T by encoding it to JSON and
// decoding that JSON into T.
//
// The zero T is returned with an error when the input cannot be encoded.
// When decoding fails, an error is returned together with whatever
// json.Unmarshal left in the result.
//
// Example:
//
//	args := request.Params.Arguments // map[string]any
//	input, err := Translate[FindToolInput](args)
//	if err != nil {
//	    return nil, fmt.Errorf("invalid arguments: %w", err)
//	}
func Translate[T any](input any) (T, error) {
	var out T

	// Step 1: untyped value -> JSON bytes.
	encoded, marshalErr := json.Marshal(input)
	if marshalErr != nil {
		return out, fmt.Errorf("failed to marshal input: %w", marshalErr)
	}

	// Step 2: JSON bytes -> typed value.
	decodeErr := json.Unmarshal(encoded, &out)
	if decodeErr != nil {
		return out, fmt.Errorf("failed to unmarshal to %T: %w", out, decodeErr)
	}

	return out, nil
}
"parameters"}, + } + + actual := GenerateSchema[optimizer.CallToolInput]() + + require.Equal(t, expected, actual) +} + +func TestTranslate_FindToolInput(t *testing.T) { + t.Parallel() + + input := map[string]any{ + "tool_description": "find a tool to read files", + "tool_keywords": []any{"file", "read"}, + } + + result, err := Translate[optimizer.FindToolInput](input) + require.NoError(t, err) + + require.Equal(t, "find a tool to read files", result.ToolDescription) + require.Equal(t, []string{"file", "read"}, result.ToolKeywords) +} + +func TestTranslate_CallToolInput(t *testing.T) { + t.Parallel() + + input := map[string]any{ + "tool_name": "read_file", + "parameters": map[string]any{ + "path": "/etc/hosts", + }, + } + + result, err := Translate[optimizer.CallToolInput](input) + require.NoError(t, err) + + require.Equal(t, "read_file", result.ToolName) + require.Equal(t, map[string]any{"path": "/etc/hosts"}, result.Parameters) +} + +func TestTranslate_PartialInput(t *testing.T) { + t.Parallel() + + input := map[string]any{ + "tool_description": "find a file reader", + } + + result, err := Translate[optimizer.FindToolInput](input) + require.NoError(t, err) + + require.Equal(t, "find a file reader", result.ToolDescription) + require.Nil(t, result.ToolKeywords) +} + +func TestTranslate_InvalidInput(t *testing.T) { + t.Parallel() + + input := make(chan int) + + _, err := Translate[optimizer.FindToolInput](input) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal input") +} + +func TestGenerateSchema_PrimitiveTypes(t *testing.T) { + t.Parallel() + + type TestStruct struct { + StringField string `json:"string_field"` + IntField int `json:"int_field"` + FloatField float64 `json:"float_field"` + BoolField bool `json:"bool_field"` + OptionalStr string `json:"optional_str,omitempty"` + } + + expected := map[string]any{ + "type": "object", + "properties": map[string]any{ + "string_field": map[string]any{"type": "string"}, + "int_field": 
map[string]any{"type": "integer"}, + "float_field": map[string]any{"type": "number"}, + "bool_field": map[string]any{"type": "boolean"}, + "optional_str": map[string]any{"type": "string"}, + }, + "required": []string{"string_field", "int_field", "float_field", "bool_field"}, + } + + actual := GenerateSchema[TestStruct]() + + require.Equal(t, expected, actual) +} diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go new file mode 100644 index 0000000000..d0abb559ec --- /dev/null +++ b/pkg/vmcp/server/adapter/optimizer_adapter.go @@ -0,0 +1,94 @@ +package adapter + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/schema" +) + +// OptimizerToolNames defines the tool names exposed when optimizer is enabled. +const ( + FindToolName = "find_tool" + CallToolName = "call_tool" +) + +// Pre-generated schemas for optimizer tools. +// Generated at package init time so any schema errors panic at startup. +var ( + findToolInputSchema = mustMarshalSchema(schema.GenerateSchema[optimizer.FindToolInput]()) + callToolInputSchema = mustMarshalSchema(schema.GenerateSchema[optimizer.CallToolInput]()) +) + +// CreateOptimizerTools creates the SDK tools for optimizer mode. +// When optimizer is enabled, only these two tools are exposed to clients +// instead of all backend tools. +func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool { + return []server.ServerTool{ + { + Tool: mcp.Tool{ + Name: FindToolName, + Description: "Search for tools by description. 
Returns matching tools ranked by relevance.", + RawInputSchema: findToolInputSchema, + }, + Handler: createFindToolHandler(opt), + }, + { + Tool: mcp.Tool{ + Name: CallToolName, + Description: "Call a tool by name with the given parameters.", + RawInputSchema: callToolInputSchema, + }, + Handler: createCallToolHandler(opt), + }, + } +} + +// createFindToolHandler creates a handler for the find_tool optimizer operation. +func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + output, err := opt.FindTool(ctx, input) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil + } + + return mcp.NewToolResultStructuredOnly(output), nil + } +} + +// createCallToolHandler creates a handler for the call_tool optimizer operation. +func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + result, err := opt.CallTool(ctx, input) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil + } + + return result, nil + } +} + +// mustMarshalSchema marshals a schema to JSON, panicking on error. +// This is safe because schemas are generated from known types at startup. 
+func mustMarshalSchema(s map[string]any) json.RawMessage { + data, err := json.Marshal(s) + if err != nil { + panic(fmt.Sprintf("failed to marshal schema: %v", err)) + } + return data +} diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go new file mode 100644 index 0000000000..f0c4db73e0 --- /dev/null +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -0,0 +1,104 @@ +package adapter + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" +) + +// mockOptimizer implements optimizer.Optimizer for testing. +type mockOptimizer struct { + findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) + callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) +} + +func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { + if m.findToolFunc != nil { + return m.findToolFunc(ctx, input) + } + return &optimizer.FindToolOutput{}, nil +} + +func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { + if m.callToolFunc != nil { + return m.callToolFunc(ctx, input) + } + return mcp.NewToolResultText("ok"), nil +} + +func TestCreateOptimizerTools(t *testing.T) { + t.Parallel() + + opt := &mockOptimizer{} + tools := CreateOptimizerTools(opt) + + require.Len(t, tools, 2) + require.Equal(t, FindToolName, tools[0].Tool.Name) + require.Equal(t, CallToolName, tools[1].Tool.Name) +} + +func TestFindToolHandler(t *testing.T) { + t.Parallel() + + opt := &mockOptimizer{ + findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { + require.Equal(t, "read files", input.ToolDescription) + return &optimizer.FindToolOutput{ + Tools: []optimizer.ToolMatch{ + { + 
Name: "read_file", + Description: "Read a file", + Score: 1.0, + }, + }, + }, nil + }, + } + + tools := CreateOptimizerTools(opt) + handler := tools[0].Handler + + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]any{ + "tool_description": "read files", + } + + result, err := handler(context.Background(), request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + require.Len(t, result.Content, 1) +} + +func TestCallToolHandler(t *testing.T) { + t.Parallel() + + opt := &mockOptimizer{ + callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { + require.Equal(t, "read_file", input.ToolName) + require.Equal(t, "/etc/hosts", input.Parameters["path"]) + return mcp.NewToolResultText("file contents here"), nil + }, + } + + tools := CreateOptimizerTools(opt) + handler := tools[1].Handler + + request := mcp.CallToolRequest{} + request.Params.Arguments = map[string]any{ + "tool_name": "read_file", + "parameters": map[string]any{ + "path": "/etc/hosts", + }, + } + + result, err := handler(context.Background(), request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + require.Len(t, result.Content, 1) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 477b9a1234..d0879f3a63 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -28,6 +28,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -119,6 +120,10 @@ type Config struct { // Only set when running in K8s with outgoingAuth.source: discovered. // Used for /readyz endpoint to gate readiness on cache sync. 
Watcher Watcher + + // OptimizerFactory builds an optimizer from a list of tools. + // If not set, the optimizer is disabled. + OptimizerFactory func([]server.ServerTool) optimizer.Optimizer } // Server is the Virtual MCP Server that aggregates multiple backends. @@ -344,89 +349,9 @@ func New( } // Register OnRegisterSession hook to inject capabilities after SDK registers session. - // This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which - // fires BEFORE session registration), allowing us to safely call AddSessionTools/AddSessionResources. - // - // The discovery middleware populates capabilities in the context, which is available here. - // We inject them into the SDK session and store the routing table for subsequent requests. - // - // IMPORTANT: Session capabilities are immutable after injection. - // - Capabilities discovered during initialize are fixed for the session lifetime - // - Backend changes (new tools, removed resources) won't be reflected in existing sessions - // - Clients must create new sessions to see updated capabilities - // TODO(dynamic-capabilities): Consider implementing capability refresh mechanism when SDK supports it + // See handleSessionRegistration for implementation details. 
hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) { - sessionID := session.SessionID() - logger.Debugw("OnRegisterSession hook called", "session_id", sessionID) - - // Get capabilities from context (discovered by middleware) - caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) - if !ok || caps == nil { - logger.Warnw("no discovered capabilities in context for OnRegisterSession hook", - "session_id", sessionID) - return - } - - // Validate that routing table exists - if caps.RoutingTable == nil { - logger.Warnw("routing table is nil in discovered capabilities", - "session_id", sessionID) - return - } - - // Add composite tools to capabilities - // Composite tools are static (from configuration) and not discovered from backends - // They are added here to be exposed alongside backend tools in the session - if len(srv.workflowDefs) > 0 { - compositeTools := convertWorkflowDefsToTools(srv.workflowDefs) - - // Validate no conflicts between composite tool names and backend tool names - if err := validateNoToolConflicts(caps.Tools, compositeTools); err != nil { - logger.Errorw("composite tool name conflict detected", - "session_id", sessionID, - "error", err) - // Don't add composite tools if there are conflicts - // This prevents ambiguity in routing/execution - return - } - - caps.CompositeTools = compositeTools - logger.Debugw("added composite tools to session capabilities", - "session_id", sessionID, - "composite_tool_count", len(compositeTools)) - } - - // Store routing table in VMCPSession for subsequent requests - // This enables the middleware to reconstruct capabilities from session - // without re-running discovery for every request - vmcpSess, err := vmcpsession.GetVMCPSession(sessionID, sessionManager) - if err != nil { - logger.Errorw("failed to get VMCPSession for routing table storage", - "error", err, - "session_id", sessionID) - return - } - - vmcpSess.SetRoutingTable(caps.RoutingTable) - 
vmcpSess.SetTools(caps.Tools) - logger.Debugw("routing table and tools stored in VMCPSession", - "session_id", sessionID, - "tool_count", len(caps.RoutingTable.Tools), - "resource_count", len(caps.RoutingTable.Resources), - "prompt_count", len(caps.RoutingTable.Prompts)) - - // Inject capabilities into SDK session - if err := srv.injectCapabilities(sessionID, caps); err != nil { - logger.Errorw("failed to inject session capabilities", - "error", err, - "session_id", sessionID) - return - } - - logger.Infow("session capabilities injected", - "session_id", sessionID, - "tool_count", len(caps.Tools), - "resource_count", len(caps.Resources)) + srv.handleSessionRegistration(ctx, session, sessionManager) }) return srv, nil @@ -779,6 +704,7 @@ func (s *Server) Ready() <-chan struct{} { // - No previous capabilities exist, so no deletion needed // - Capabilities are IMMUTABLE for the session lifetime (see limitation below) // - Discovery middleware does not re-run for subsequent requests +// - If injectOptimizerCapabilities is called, this should not be called again. // // LIMITATION: Session capabilities are fixed at creation time. // If backends change (new tools added, resources removed), existing sessions won't see updates. @@ -852,6 +778,175 @@ func (s *Server) injectCapabilities( return nil } +// injectOptimizerCapabilities injects all capabilities into the session, including optimizer tools. +// It must only be called when optimizer mode is enabled, in which case it is used in place of a direct injectCapabilities call. +// +// When optimizer mode is enabled, instead of exposing all backend tools directly, +// vMCP exposes only two meta-tools: +// - find_tool: Search for tools by description +// - call_tool: Invoke a tool by name with parameters +// +// This method: +// 1. Converts all tools (backend + composite) to SDK format with handlers +// 2. 
Injects the optimizer capabilities into the session +func (s *Server) injectOptimizerCapabilities( + sessionID string, + caps *aggregator.AggregatedCapabilities, +) error { + + tools := append([]vmcp.Tool{}, caps.Tools...) + tools = append(tools, caps.CompositeTools...) + + sdkTools, err := s.capabilityAdapter.ToSDKTools(tools) + if err != nil { + return fmt.Errorf("failed to convert tools to SDK format: %w", err) + } + + // Create optimizer tools (find_tool, call_tool) + optimizerTools := adapter.CreateOptimizerTools(s.config.OptimizerFactory(sdkTools)) + + logger.Debugw("created optimizer tools for session", + "session_id", sessionID, + "backend_tool_count", len(caps.Tools), + "composite_tool_count", len(caps.CompositeTools), + "total_tools_indexed", len(sdkTools)) + + // Save tools before clearing - caps is shared across sessions + savedTools := caps.Tools + savedCompositeTools := caps.CompositeTools + + // Clear tools from caps - they're now wrapped by optimizer + // Resources and prompts are preserved and handled normally + caps.Tools = nil + caps.CompositeTools = nil + defer func() { + // Restore tools on every return path (success or error), since caps is shared + caps.Tools = savedTools + caps.CompositeTools = savedCompositeTools + }() + + // Manually add the optimizer tools, since we don't want to bother converting + // optimizer tools into `vmcp.Tool`s as well. + if err := s.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { + return fmt.Errorf("failed to add session tools: %w", err) + } + + return s.injectCapabilities(sessionID, caps) +} + +// handleSessionRegistration processes a new MCP session registration. +// +// This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which +// fires BEFORE session registration), allowing us to safely call AddSessionTools/AddSessionResources. +// +// The discovery middleware populates capabilities in the context, which is available here. 
+// We inject them into the SDK session and store the routing table for subsequent requests. +// +// This method performs the following steps: +// 1. Retrieves discovered capabilities from context +// 2. Adds composite tools from configuration +// 3. Stores routing table in VMCPSession for request routing +// 4. Injects capabilities into the SDK session +// +// IMPORTANT: Session capabilities are immutable after injection. +// - Capabilities discovered during initialize are fixed for the session lifetime +// - Backend changes (new tools, removed resources) won't be reflected in existing sessions +// - Clients must create new sessions to see updated capabilities +// +// TODO(dynamic-capabilities): Consider implementing capability refresh mechanism when SDK supports it +// +// The sessionManager parameter is passed explicitly because this method is called +// from a closure registered before the Server is fully constructed. +func (s *Server) handleSessionRegistration( + ctx context.Context, + session server.ClientSession, + sessionManager *transportsession.Manager, +) { + sessionID := session.SessionID() + logger.Debugw("OnRegisterSession hook called", "session_id", sessionID) + + // Get capabilities from context (discovered by middleware) + caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) + if !ok || caps == nil { + logger.Warnw("no discovered capabilities in context for OnRegisterSession hook", + "session_id", sessionID) + return + } + + // Validate that routing table exists + if caps.RoutingTable == nil { + logger.Warnw("routing table is nil in discovered capabilities", + "session_id", sessionID) + return + } + + // Add composite tools to capabilities + // Composite tools are static (from configuration) and not discovered from backends + // They are added here to be exposed alongside backend tools in the session + if len(s.workflowDefs) > 0 { + compositeTools := convertWorkflowDefsToTools(s.workflowDefs) + + // Validate no conflicts between composite 
tool names and backend tool names + if err := validateNoToolConflicts(caps.Tools, compositeTools); err != nil { + logger.Errorw("composite tool name conflict detected", + "session_id", sessionID, + "error", err) + // Don't add composite tools if there are conflicts + // This prevents ambiguity in routing/execution + return + } + + caps.CompositeTools = compositeTools + logger.Debugw("added composite tools to session capabilities", + "session_id", sessionID, + "composite_tool_count", len(compositeTools)) + } + + // Store routing table in VMCPSession for subsequent requests + // This enables the middleware to reconstruct capabilities from session + // without re-running discovery for every request + vmcpSess, err := vmcpsession.GetVMCPSession(sessionID, sessionManager) + if err != nil { + logger.Errorw("failed to get VMCPSession for routing table storage", + "error", err, + "session_id", sessionID) + return + } + + vmcpSess.SetRoutingTable(caps.RoutingTable) + vmcpSess.SetTools(caps.Tools) + logger.Debugw("routing table and tools stored in VMCPSession", + "session_id", sessionID, + "tool_count", len(caps.RoutingTable.Tools), + "resource_count", len(caps.RoutingTable.Resources), + "prompt_count", len(caps.RoutingTable.Prompts)) + + if s.config.OptimizerFactory != nil { + err = s.injectOptimizerCapabilities(sessionID, caps) + if err != nil { + logger.Errorw("failed to create optimizer tools", + "error", err, + "session_id", sessionID) + } else { + logger.Infow("optimizer capabilities injected") + } + return + } + + // Inject capabilities into SDK session + if err := s.injectCapabilities(sessionID, caps); err != nil { + logger.Errorw("failed to inject session capabilities", + "error", err, + "session_id", sessionID) + return + } + + logger.Infow("session capabilities injected", + "session_id", sessionID, + "tool_count", len(caps.Tools), + "resource_count", len(caps.Resources)) +} + // validateAndCreateExecutors validates workflow definitions and creates executors. 
// // This function: diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go new file mode 100644 index 0000000000..1db816db9a --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go @@ -0,0 +1,311 @@ +package virtualmcp + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/mark3labs/mcp-go/mcp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +// callFindTool calls find_tool and returns the StructuredContent directly +func callFindTool(mcpClient *InitializedMCPClient, description string) (map[string]any, error) { + req := mcp.CallToolRequest{} + req.Params.Name = "find_tool" + req.Params.Arguments = map[string]any{"tool_description": description} + + result, err := mcpClient.Client.CallTool(mcpClient.Ctx, req) + if err != nil { + return nil, err + } + content, ok := result.StructuredContent.(map[string]any) + if !ok { + return nil, fmt.Errorf("expected map[string]any, got %T", result.StructuredContent) + } + return content, nil +} + +// getToolNames extracts tool names from find_tool structured content +func getToolNames(content map[string]any) []string { + tools, ok := content["tools"].([]any) + if !ok { + return nil + } + var names []string + for _, t := range tools { + if tool, ok := t.(map[string]any); ok { + if name, ok := tool["name"].(string); ok { + names = append(names, name) + } + } + } + return names +} + +// callToolViaOptimizer invokes a tool through call_tool +func callToolViaOptimizer(mcpClient *InitializedMCPClient, toolName string, params map[string]any) (*mcp.CallToolResult, error) { + req := 
mcp.CallToolRequest{} + req.Params.Name = "call_tool" + req.Params.Arguments = map[string]any{ + "tool_name": toolName, + "parameters": params, + } + return mcpClient.Client.CallTool(mcpClient.Ctx, req) +} + +var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { + var ( + testNamespace = "default" + mcpGroupName = "test-optimizer-group" + vmcpServerName = "test-vmcp-optimizer" + backendName = "backend-optimizer-fetch" + compositeToolName = "double_fetch" + timeout = 3 * time.Minute + pollingInterval = 1 * time.Second + vmcpNodePort int32 + ) + + BeforeAll(func() { + By("Creating MCPGroup for optimizer test") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace, + "Test MCP Group for optimizer E2E tests", timeout, pollingInterval) + + By("Creating backend MCPServer - fetch") + CreateMCPServerAndWait(ctx, k8sClient, backendName, testNamespace, + mcpGroupName, images.GofetchServerImage, timeout, pollingInterval) + + By("Creating VirtualMCPServer with optimizer enabled and a composite tool") + // Define composite tool parameters schema + paramSchema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "url": map[string]interface{}{ + "type": "string", + "description": "URL to fetch twice", + }, + }, + "required": []string{"url"}, + } + paramSchemaBytes, err := json.Marshal(paramSchema) + Expect(err).ToNot(HaveOccurred()) + + // Define step arguments that reference the input parameter + stepArgs := map[string]interface{}{ + "url": "{{.params.url}}", + } + stepArgsBytes, err := json.Marshal(stepArgs) + Expect(err).ToNot(HaveOccurred()) + + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: vmcpServerName, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + ServiceType: "NodePort", + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ + Source: "discovered", + }, + // Define a 
composite tool that calls fetch twice + CompositeTools: []mcpv1alpha1.CompositeToolSpec{ + { + Name: compositeToolName, + Description: "Fetches a URL twice in sequence for verification", + Parameters: &runtime.RawExtension{ + Raw: paramSchemaBytes, + }, + Steps: []mcpv1alpha1.WorkflowStep{ + { + ID: "first_fetch", + Type: "tool", + Tool: fmt.Sprintf("%s.fetch", backendName), + Arguments: &runtime.RawExtension{ + Raw: stepArgsBytes, + }, + }, + { + ID: "second_fetch", + Type: "tool", + Tool: fmt.Sprintf("%s.fetch", backendName), + DependsOn: []string{"first_fetch"}, + Arguments: &runtime.RawExtension{ + Raw: stepArgsBytes, + }, + }, + }, + }, + }, + Config: vmcpconfig.Config{ + Group: mcpGroupName, + Optimizer: &vmcpconfig.OptimizerConfig{ + // EmbeddingService is required but not used by DummyOptimizer + EmbeddingService: "dummy-embedding-service", + }, + }, + }, + } + Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed()) + + By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) + + By("Getting VirtualMCPServer NodePort") + vmcpNodePort = GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) + _, _ = fmt.Fprintf(GinkgoWriter, "VirtualMCPServer is accessible at NodePort: %d\n", vmcpNodePort) + }) + + AfterAll(func() { + By("Cleaning up VirtualMCPServer") + vmcpServer := &mcpv1alpha1.VirtualMCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: vmcpServerName, + Namespace: testNamespace, + }, vmcpServer); err == nil { + _ = k8sClient.Delete(ctx, vmcpServer) + } + + By("Cleaning up backend MCPServer") + backend := &mcpv1alpha1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: testNamespace, + }, backend); err == nil { + _ = k8sClient.Delete(ctx, backend) + } + + By("Cleaning up MCPGroup") + mcpGroup := &mcpv1alpha1.MCPGroup{} + if err := k8sClient.Get(ctx, 
types.NamespacedName{ + Name: mcpGroupName, + Namespace: testNamespace, + }, mcpGroup); err == nil { + _ = k8sClient.Delete(ctx, mcpGroup) + } + }) + + It("should only expose find_tool and call_tool", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-test-client", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Listing tools from VirtualMCPServer") + listRequest := mcp.ListToolsRequest{} + tools, err := mcpClient.Client.ListTools(mcpClient.Ctx, listRequest) + Expect(err).ToNot(HaveOccurred()) + + By("Verifying only optimizer tools are exposed") + Expect(tools.Tools).To(HaveLen(2), "Should only have find_tool and call_tool") + + toolNames := make([]string, len(tools.Tools)) + for i, tool := range tools.Tools { + toolNames[i] = tool.Name + } + Expect(toolNames).To(ContainElements("find_tool", "call_tool")) + + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Optimizer mode correctly exposes only: %v\n", toolNames) + }) + + It("should find backend tools via find_tool", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-find-test", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Calling find_tool to search for fetch tool") + result, err := callFindTool(mcpClient, "fetch") + Expect(err).ToNot(HaveOccurred()) + + toolNames := getToolNames(result) + Expect(toolNames).ToNot(BeEmpty(), "find_tool should return matching tools") + _, _ = fmt.Fprintf(GinkgoWriter, "✓ find_tool returned tools: %v\n", toolNames) + + By("Verifying at least one tool matches 'fetch'") + var foundFetch bool + for _, name := range toolNames { + if strings.Contains(name, "fetch") { + foundFetch = true + } + } + Expect(foundFetch).To(BeTrue(), "Should find a fetch-related tool") + }) + + It("should invoke backend tools via call_tool", func() { + By("Creating and initializing MCP client") + 
mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-call-test", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Finding the backend fetch tool") + findResult, err := callFindTool(mcpClient, "internet") // matches backend fetch description + Expect(err).ToNot(HaveOccurred()) + + toolNames := getToolNames(findResult) + Expect(toolNames).ToNot(BeEmpty()) + + // Find the backend tool (has prefix), not the composite + var backendToolName string + for _, name := range toolNames { + if strings.Contains(name, backendName) { + backendToolName = name + break + } + } + Expect(backendToolName).ToNot(BeEmpty(), "Should find backend tool") + + By(fmt.Sprintf("Calling %s via call_tool", backendToolName)) + result, err := callToolViaOptimizer(mcpClient, backendToolName, map[string]any{ + "url": "https://example.com", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result).ToNot(BeNil()) + Expect(result.Content).ToNot(BeEmpty(), "call_tool should return content from backend tool") + + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Successfully called %s via call_tool\n", backendToolName) + }) + + It("should find and invoke composite tools via optimizer", func() { + By("Creating and initializing MCP client") + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-composite-test", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + By("Calling find_tool to search for composite tool") + result, err := callFindTool(mcpClient, "twice") // matches "Fetches a URL twice..." 
+ Expect(err).ToNot(HaveOccurred()) + + toolNames := getToolNames(result) + var foundTool string + for _, name := range toolNames { + _, _ = fmt.Fprintf(GinkgoWriter, " - Found tool: %s\n", name) + if strings.Contains(name, compositeToolName) { + foundTool = name + } + } + Expect(foundTool).ToNot(BeEmpty(), "Should find composite tool %s", compositeToolName) + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Found composite tool %s via find_tool\n", foundTool) + + By(fmt.Sprintf("Calling composite tool %s via call_tool", foundTool)) + callResult, err := callToolViaOptimizer(mcpClient, foundTool, map[string]any{ + "url": "https://example.com", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(callResult).ToNot(BeNil()) + Expect(callResult.Content).ToNot(BeEmpty(), "call_tool should return content from composite tool") + + _, _ = fmt.Fprintf(GinkgoWriter, "✓ Successfully called composite tool %s via call_tool\n", foundTool) + }) +})