diff --git a/cmd/thv-operator/pkg/vmcpconfig/converter.go b/cmd/thv-operator/pkg/vmcpconfig/converter.go
index 7eead1d576..10c2538e07 100644
--- a/cmd/thv-operator/pkg/vmcpconfig/converter.go
+++ b/cmd/thv-operator/pkg/vmcpconfig/converter.go
@@ -61,15 +61,23 @@ func NewConverter(oidcResolver oidc.Resolver, k8sClient client.Client) (*Convert
}, nil
}
-// Convert converts VirtualMCPServer CRD spec to vmcp Config
+// Convert converts VirtualMCPServer CRD spec to vmcp Config.
+//
+// The conversion starts with a DeepCopy of the embedded config.Config from the CRD spec.
+// This ensures that simple fields (like Optimizer, Metadata, etc.) are automatically
+// passed through without explicit mapping. Only fields that require special handling
+// (auth, aggregation, composite tools, telemetry) are explicitly converted below.
func (c *Converter) Convert(
ctx context.Context,
vmcp *mcpv1alpha1.VirtualMCPServer,
) (*vmcpconfig.Config, error) {
- config := &vmcpconfig.Config{
- Name: vmcp.Name,
- Group: vmcp.Spec.Config.Group,
- }
+ // Start with a deep copy of the embedded config for automatic field passthrough.
+ // This ensures new fields added to config.Config are automatically included
+ // without requiring explicit mapping in this converter.
+ config := vmcp.Spec.Config.DeepCopy()
+
+ // Override name with the CR name (authoritative source)
+ config.Name = vmcp.Name
// Convert IncomingAuth - required field, no defaults
if vmcp.Spec.IncomingAuth != nil {
diff --git a/cmd/vmcp/README.md b/cmd/vmcp/README.md
index 30ac862ca2..10a60bef2a 100644
--- a/cmd/vmcp/README.md
+++ b/cmd/vmcp/README.md
@@ -6,7 +6,7 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC
## Features
-### Implemented (Phase 1)
+### Implemented
- ✅ **Group-Based Backend Management**: Automatic workload discovery from ToolHive groups
- ✅ **Tool Aggregation**: Combines tools from multiple MCP servers with conflict resolution (prefix, priority, manual)
- ✅ **Resource & Prompt Aggregation**: Unified access to resources and prompts from all backends
@@ -15,12 +15,14 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC
- ✅ **Health Endpoints**: `/health` and `/ping` for service monitoring
- ✅ **Configuration Validation**: `vmcp validate` command for config verification
- ✅ **Observability**: OpenTelemetry metrics and traces for backend operations and workflow executions
+- ✅ **Composite Tools**: Multi-step workflows with elicitation support
### In Progress
- 🚧 **Incoming Authentication** (Issue #165): OIDC, local, anonymous authentication
- 🚧 **Outgoing Authentication** (Issue #160): RFC 8693 token exchange for backend API access
- 🚧 **Token Caching**: Memory and Redis cache providers
- 🚧 **Health Monitoring** (Issue #166): Circuit breakers, backend health checks
+- 🚧 **Optimizer**: MCP optimizer support in vMCP for context optimization on large toolsets
### Future (Phase 2+)
- 📋 **Authorization**: Cedar policy-based access control
diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go
index ca6060bcab..891b32c1cb 100644
--- a/cmd/vmcp/app/commands.go
+++ b/cmd/vmcp/app/commands.go
@@ -24,6 +24,7 @@ import (
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
"github.com/stacklok/toolhive/pkg/vmcp/health"
"github.com/stacklok/toolhive/pkg/vmcp/k8s"
+ "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router"
vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server"
)
@@ -416,6 +417,11 @@ func runServe(cmd *cobra.Command, _ []string) error {
Watcher: backendWatcher,
}
+ if cfg.Optimizer != nil {
+ // TODO: update this with the real optimizer.
+ serverCfg.OptimizerFactory = optimizer.NewDummyOptimizer
+ }
+
// Convert composite tool configurations to workflow definitions
workflowDefs, err := vmcpserver.ConvertConfigToWorkflowDefinitions(cfg.CompositeTools)
if err != nil {
diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
index bff67be5a8..bd70f3bffd 100644
--- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
+++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
@@ -862,6 +862,21 @@ spec:
- default
type: object
type: object
+ optimizer:
+ description: |-
+ Optimizer configures the MCP optimizer for context optimization on large toolsets.
+ When enabled, vMCP exposes only find_tool and call_tool operations to clients
+ instead of all backend tools directly. This reduces token usage by allowing
+ LLMs to discover relevant tools on demand rather than receiving all tool definitions.
+ properties:
+ embeddingService:
+ description: |-
+ EmbeddingService is the name of a Kubernetes Service that provides the embedding service
+ for semantic tool discovery. The service must implement the optimizer embedding API.
+ type: string
+ required:
+ - embeddingService
+ type: object
outgoingAuth:
description: OutgoingAuth configures how the virtual MCP server
authenticates to backends.
diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml
index 8e329f8584..b497666ce5 100644
--- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml
+++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml
@@ -865,6 +865,21 @@ spec:
- default
type: object
type: object
+ optimizer:
+ description: |-
+ Optimizer configures the MCP optimizer for context optimization on large toolsets.
+ When enabled, vMCP exposes only find_tool and call_tool operations to clients
+ instead of all backend tools directly. This reduces token usage by allowing
+ LLMs to discover relevant tools on demand rather than receiving all tool definitions.
+ properties:
+ embeddingService:
+ description: |-
+ EmbeddingService is the name of a Kubernetes Service that provides the embedding service
+ for semantic tool discovery. The service must implement the optimizer embedding API.
+ type: string
+ required:
+ - embeddingService
+ type: object
outgoingAuth:
description: OutgoingAuth configures how the virtual MCP server
authenticates to backends.
diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md
index 14e8587c6a..de89fc4a7c 100644
--- a/docs/operator/crd-api.md
+++ b/docs/operator/crd-api.md
@@ -225,6 +225,7 @@ _Appears in:_
| `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | |
| `telemetry` _[pkg.telemetry.Config](#pkgtelemetryconfig)_ | Telemetry configures OpenTelemetry-based observability for the Virtual MCP server
including distributed tracing, OTLP metrics export, and Prometheus metrics endpoint. | | |
| `audit` _[pkg.audit.Config](#pkgauditconfig)_ | Audit configures audit logging for the Virtual MCP server.
When present, audit logs include MCP protocol operations.
See audit.Config for available configuration options. | | |
+| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes only find_tool and call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | |
#### vmcp.config.ConflictResolutionConfig
@@ -343,6 +344,24 @@ _Appears in:_
| `failureHandling` _[vmcp.config.FailureHandlingConfig](#vmcpconfigfailurehandlingconfig)_ | FailureHandling configures failure handling. | | |
+#### vmcp.config.OptimizerConfig
+
+
+
+OptimizerConfig configures the MCP optimizer.
+When enabled, vMCP exposes only find_tool and call_tool operations to clients
+instead of all backend tools directly.
+
+
+
+_Appears in:_
+- [vmcp.config.Config](#vmcpconfigconfig)
+
+| Field | Description | Default | Validation |
+| --- | --- | --- | --- |
+| `embeddingService` _string_ | EmbeddingService is the name of a Kubernetes Service that provides the embedding service
for semantic tool discovery. The service must implement the optimizer embedding API. | | Required: \{\}
|
+
+
#### vmcp.config.OutgoingAuthConfig
diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go
index 9693643348..a6355cd5e8 100644
--- a/pkg/vmcp/config/config.go
+++ b/pkg/vmcp/config/config.go
@@ -108,6 +108,13 @@ type Config struct {
// See audit.Config for available configuration options.
// +optional
Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"`
+
+ // Optimizer configures the MCP optimizer for context optimization on large toolsets.
+ // When enabled, vMCP exposes only find_tool and call_tool operations to clients
+ // instead of all backend tools directly. This reduces token usage by allowing
+ // LLMs to discover relevant tools on demand rather than receiving all tool definitions.
+ // +optional
+ Optimizer *OptimizerConfig `json:"optimizer,omitempty" yaml:"optimizer,omitempty"`
}
// IncomingAuthConfig configures client authentication to the virtual MCP server.
@@ -474,6 +481,18 @@ type OutputProperty struct {
Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"`
}
+// OptimizerConfig configures the MCP optimizer.
+// When enabled, vMCP exposes only find_tool and call_tool operations to clients
+// instead of all backend tools directly.
+// +kubebuilder:object:generate=true
+// +gendoc
+type OptimizerConfig struct {
+ // EmbeddingService is the name of a Kubernetes Service that provides the embedding service
+ // for semantic tool discovery. The service must implement the optimizer embedding API.
+ // +kubebuilder:validation:Required
+ EmbeddingService string `json:"embeddingService" yaml:"embeddingService"`
+}
+
// Validator validates configuration.
type Validator interface {
// Validate checks if the configuration is valid.
diff --git a/pkg/vmcp/config/zz_generated.deepcopy.go b/pkg/vmcp/config/zz_generated.deepcopy.go
index b50fa118af..5afe9c656a 100644
--- a/pkg/vmcp/config/zz_generated.deepcopy.go
+++ b/pkg/vmcp/config/zz_generated.deepcopy.go
@@ -175,6 +175,11 @@ func (in *Config) DeepCopyInto(out *Config) {
*out = new(audit.Config)
(*in).DeepCopyInto(*out)
}
+ if in.Optimizer != nil {
+ in, out := &in.Optimizer, &out.Optimizer
+ *out = new(OptimizerConfig)
+ **out = **in
+ }
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Config.
@@ -312,6 +317,21 @@ func (in *OperationalConfig) DeepCopy() *OperationalConfig {
return out
}
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *OptimizerConfig) DeepCopyInto(out *OptimizerConfig) {
+ *out = *in
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OptimizerConfig.
+func (in *OptimizerConfig) DeepCopy() *OptimizerConfig {
+ if in == nil {
+ return nil
+ }
+ out := new(OptimizerConfig)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *OutgoingAuthConfig) DeepCopyInto(out *OutgoingAuthConfig) {
*out = *in
diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go
new file mode 100644
index 0000000000..3a8338f04d
--- /dev/null
+++ b/pkg/vmcp/optimizer/dummy_optimizer.go
@@ -0,0 +1,122 @@
+package optimizer
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+)
+
+// DummyOptimizer implements the Optimizer interface using plain substring matching.
+//
+// This implementation is intended for testing and development. It performs
+// case-insensitive substring matching on tool names and descriptions.
+//
+// For production use, see the EmbeddingOptimizer which uses semantic similarity.
+type DummyOptimizer struct {
+	// tools maps each tool's name to the tool definition and its handler.
+	tools map[string]server.ServerTool
+}
+
+// NewDummyOptimizer builds a DummyOptimizer over the supplied tools.
+//
+// Every entry is indexed by its tool name; if two entries share a name, the
+// later one replaces the earlier one in the index.
+func NewDummyOptimizer(tools []server.ServerTool) Optimizer {
+	byName := make(map[string]server.ServerTool, len(tools))
+	for i := range tools {
+		byName[tools[i].Tool.Name] = tools[i]
+	}
+
+	// The optimizer is returned by value; it holds only the name index.
+	return DummyOptimizer{tools: byName}
+}
+
+// FindTool searches for tools using exact substring matching.
+//
+// The search is case-insensitive and matches against:
+// - Tool name (substring match)
+// - Tool description (substring match)
+//
+// Returns all matching tools with a score of 1.0 (exact match semantics).
+// TokenMetrics are returned as zero values (not implemented in dummy).
+func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) {
+ if input.ToolDescription == "" {
+ return nil, fmt.Errorf("tool_description is required")
+ }
+
+ // Log all tools in the optimizer for debugging
+ fmt.Printf("[DummyOptimizer.FindTool] Searching for %q in %d tools:\n", input.ToolDescription, len(d.tools))
+ for name, tool := range d.tools {
+ fmt.Printf(" - %q: %q\n", name, tool.Tool.Description)
+ }
+
+ searchTerm := strings.ToLower(input.ToolDescription)
+
+ var matches []ToolMatch
+ for _, tool := range d.tools {
+ nameLower := strings.ToLower(tool.Tool.Name)
+ descLower := strings.ToLower(tool.Tool.Description)
+
+ // Check if search term matches name or description
+ if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) {
+ schema, err := getToolSchema(tool.Tool)
+ if err != nil {
+ return nil, err
+ }
+ matches = append(matches, ToolMatch{
+ Name: tool.Tool.Name,
+ Description: tool.Tool.Description,
+ Parameters: schema,
+ Score: 1.0, // Exact match semantics
+ })
+ }
+ }
+
+ return &FindToolOutput{
+ Tools: matches,
+ TokenMetrics: TokenMetrics{}, // Zero values for dummy
+ }, nil
+}
+
+// CallTool invokes a tool by name using its registered handler.
+//
+// An empty name is a Go error. An unknown name, or a tool registered without a
+// handler, is reported as an MCP error result (IsError set) with a nil error.
+func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) {
+	if input.ToolName == "" {
+		return nil, fmt.Errorf("tool_name is required")
+	}
+
+	tool, exists := d.tools[input.ToolName]
+	if !exists {
+		return mcp.NewToolResultError(fmt.Sprintf("tool not found: %s", input.ToolName)), nil
+	}
+	if tool.Handler == nil { // calling a nil handler would panic
+		return mcp.NewToolResultError(fmt.Sprintf("tool has no handler: %s", input.ToolName)), nil
+	}
+
+	// Build the MCP request and call the tool handler directly.
+	request := mcp.CallToolRequest{}
+	request.Params.Name = input.ToolName
+	request.Params.Arguments = input.Parameters
+	return tool.Handler(ctx, request)
+}
+
+// getToolSchema returns the JSON input schema for a tool.
+// A non-empty RawInputSchema wins; otherwise InputSchema is marshaled.
+func getToolSchema(tool mcp.Tool) (json.RawMessage, error) {
+	if raw := tool.RawInputSchema; len(raw) > 0 {
+		return raw, nil
+	}
+
+	// No raw schema was provided, so encode the structured one.
+	encoded, err := json.Marshal(tool.InputSchema)
+	if err == nil {
+		return encoded, nil
+	}
+	return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err)
+}
diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go
new file mode 100644
index 0000000000..489f04e0e5
--- /dev/null
+++ b/pkg/vmcp/optimizer/dummy_optimizer_test.go
@@ -0,0 +1,188 @@
+package optimizer
+
+import (
+ "context"
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDummyOptimizer_FindTool(t *testing.T) {
+ t.Parallel()
+
+ tools := []server.ServerTool{
+ {
+ Tool: mcp.Tool{
+ Name: "fetch_url",
+ Description: "Fetch content from a URL",
+ },
+ },
+ {
+ Tool: mcp.Tool{
+ Name: "read_file",
+ Description: "Read a file from the filesystem",
+ },
+ },
+ {
+ Tool: mcp.Tool{
+ Name: "write_file",
+ Description: "Write content to a file",
+ },
+ },
+ }
+
+ opt := NewDummyOptimizer(tools)
+
+ tests := []struct {
+ name string
+ input FindToolInput
+ expectedNames []string
+ expectedError bool
+ errorContains string
+ }{
+ {
+ name: "find by exact name",
+ input: FindToolInput{
+ ToolDescription: "fetch_url",
+ },
+ expectedNames: []string{"fetch_url"},
+ },
+ {
+ name: "find by description substring",
+ input: FindToolInput{
+ ToolDescription: "file",
+ },
+ expectedNames: []string{"read_file", "write_file"},
+ },
+ {
+ name: "case insensitive search",
+ input: FindToolInput{
+ ToolDescription: "FETCH",
+ },
+ expectedNames: []string{"fetch_url"},
+ },
+ {
+ name: "no matches",
+ input: FindToolInput{
+ ToolDescription: "nonexistent",
+ },
+ expectedNames: []string{},
+ },
+ {
+ name: "empty description",
+ input: FindToolInput{},
+ expectedError: true,
+ errorContains: "tool_description is required",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := opt.FindTool(context.Background(), tc.input)
+
+ if tc.expectedError {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tc.errorContains)
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ // Extract names from results
+ var names []string
+ for _, match := range result.Tools {
+ names = append(names, match.Name)
+ }
+
+ require.ElementsMatch(t, tc.expectedNames, names)
+ })
+ }
+}
+
+func TestDummyOptimizer_CallTool(t *testing.T) {
+ t.Parallel()
+
+ tools := []server.ServerTool{
+ {
+ Tool: mcp.Tool{
+ Name: "test_tool",
+ Description: "A test tool",
+ },
+ Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ args, _ := req.Params.Arguments.(map[string]any)
+ input := args["input"].(string)
+ return mcp.NewToolResultText("Hello, " + input + "!"), nil
+ },
+ },
+ }
+
+ opt := NewDummyOptimizer(tools)
+
+ tests := []struct {
+ name string
+ input CallToolInput
+ expectedText string
+ expectedError bool
+ isToolError bool
+ errorContains string
+ }{
+ {
+ name: "successful tool call",
+ input: CallToolInput{
+ ToolName: "test_tool",
+ Parameters: map[string]any{"input": "World"},
+ },
+ expectedText: "Hello, World!",
+ },
+ {
+ name: "tool not found",
+ input: CallToolInput{
+ ToolName: "nonexistent",
+ Parameters: map[string]any{},
+ },
+ isToolError: true,
+ expectedText: "tool not found: nonexistent",
+ },
+ {
+ name: "empty tool name",
+ input: CallToolInput{
+ Parameters: map[string]any{},
+ },
+ expectedError: true,
+ errorContains: "tool_name is required",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := opt.CallTool(context.Background(), tc.input)
+
+ if tc.expectedError {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tc.errorContains)
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ if tc.isToolError {
+ require.True(t, result.IsError)
+ }
+
+ if tc.expectedText != "" {
+ require.Len(t, result.Content, 1)
+ textContent, ok := result.Content[0].(mcp.TextContent)
+ require.True(t, ok)
+ require.Equal(t, tc.expectedText, textContent.Text)
+ }
+ })
+ }
+}
diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go
new file mode 100644
index 0000000000..2da6f748f7
--- /dev/null
+++ b/pkg/vmcp/optimizer/optimizer.go
@@ -0,0 +1,88 @@
+// Package optimizer provides the Optimizer interface for intelligent tool discovery
+// and invocation in the Virtual MCP Server.
+//
+// When the optimizer is enabled, vMCP exposes only two tools to clients:
+// - find_tool: Semantic search over available tools
+// - call_tool: Dynamic invocation of any backend tool
+//
+// This reduces token usage by avoiding the need to send all tool definitions
+// to the LLM, instead allowing it to discover relevant tools on demand.
+package optimizer
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// Optimizer defines the interface for intelligent tool discovery and invocation.
+//
+// Implementations may use various strategies for tool matching:
+//   - DummyOptimizer: Case-insensitive substring matching (for testing)
+//   - EmbeddingOptimizer: Semantic similarity via embeddings (production)
+type Optimizer interface {
+	// FindTool searches for tools matching the given description and keywords.
+	// Returns matching tools ranked by relevance score.
+	FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error)
+
+	// CallTool invokes a tool by name with the given parameters.
+	// A missing tool may be reported either as a Go error or as an MCP error result
+	// (DummyOptimizer does the latter); otherwise the handler's result is returned as-is.
+	CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error)
+}
+
+// FindToolInput contains the parameters for finding tools.
+type FindToolInput struct {
+	// ToolDescription is a natural language description of the tool to find.
+	ToolDescription string `json:"tool_description" description:"Natural language description of the tool to find"`
+
+	// ToolKeywords is an optional list of keywords to narrow the search (DummyOptimizer ignores it).
+	ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords to narrow search"`
+}
+
+// FindToolOutput contains the results of a tool search.
+type FindToolOutput struct {
+ // Tools contains the matching tools, ranked by relevance.
+ Tools []ToolMatch `json:"tools"`
+
+ // TokenMetrics provides information about token savings from using the optimizer.
+ TokenMetrics TokenMetrics `json:"token_metrics"`
+}
+
+// ToolMatch represents a tool that matched the search criteria.
+type ToolMatch struct {
+ // Name is the unique identifier of the tool.
+ Name string `json:"name"`
+
+ // Description is the human-readable description of the tool.
+ Description string `json:"description"`
+
+ // Parameters is the JSON schema for the tool's input parameters.
+ // Uses json.RawMessage to preserve the original schema format.
+ Parameters json.RawMessage `json:"parameters"`
+
+ // Score indicates how well this tool matches the search criteria (0.0-1.0).
+ Score float64 `json:"score"`
+}
+
+// TokenMetrics provides information about token usage optimization.
+type TokenMetrics struct {
+ // BaselineTokens is the estimated tokens if all tools were sent.
+ BaselineTokens int `json:"baseline_tokens"`
+
+ // ReturnedTokens is the actual tokens for the returned tools.
+ ReturnedTokens int `json:"returned_tokens"`
+
+ // SavingsPercent is the percentage of tokens saved.
+ SavingsPercent float64 `json:"savings_percent"`
+}
+
+// CallToolInput contains the parameters for calling a tool.
+type CallToolInput struct {
+	// ToolName is the name of the tool to invoke.
+	ToolName string `json:"tool_name" description:"Name of the tool to call"`
+
+	// Parameters are the arguments to pass to the tool (no omitempty, so schema generation marks it required).
+	Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"`
+}
diff --git a/pkg/vmcp/schema/reflect.go b/pkg/vmcp/schema/reflect.go
new file mode 100644
index 0000000000..a6f1ec419e
--- /dev/null
+++ b/pkg/vmcp/schema/reflect.go
@@ -0,0 +1,173 @@
+package schema
+
+import (
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// GenerateSchema generates a JSON Schema from a Go struct type using reflection.
+//
+// The function inspects struct tags to determine:
+// - json: Field name in the schema (uses json tag name)
+// - description: Field description (from `description` tag)
+// - omitempty: Whether the field is optional (not in required array)
+//
+// Supported types:
+// - string -> {"type": "string"}
+// - int, int64, etc. -> {"type": "integer"}
+// - float64, float32 -> {"type": "number"}
+// - bool -> {"type": "boolean"}
+// - []T -> {"type": "array", "items": {...}}
+// - map[string]any -> {"type": "object"}
+// - struct -> {"type": "object", "properties": {...}}
+//
+// Example:
+//
+// type FindToolInput struct {
+// ToolDescription string `json:"tool_description" description:"Natural language description"`
+// ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords"`
+// }
+// schema := GenerateSchema[FindToolInput]()
+func GenerateSchema[T any]() map[string]any {
+ var zero T
+ t := reflect.TypeOf(zero)
+ return generateSchemaForType(t)
+}
+
+// generateSchemaForType maps a reflect.Type to its JSON Schema fragment.
+// Pointers are unwrapped at every level, slices and fixed-size arrays become
+// "array" schemas, structs are expanded via generateObjectSchema, and anything
+// without a more specific mapping (maps, interfaces, channels, ...) is "object".
+func generateSchemaForType(t reflect.Type) map[string]any {
+	for t.Kind() == reflect.Ptr { // **T and deeper encode as the pointed-to value
+		t = t.Elem()
+	}
+
+	switch t.Kind() {
+	case reflect.Struct:
+		return generateObjectSchema(t)
+	case reflect.String:
+		return map[string]any{"type": "string"}
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return map[string]any{"type": "integer"}
+	case reflect.Float32, reflect.Float64:
+		return map[string]any{"type": "number"}
+	case reflect.Bool:
+		return map[string]any{"type": "boolean"}
+	case reflect.Slice, reflect.Array:
+		// Arrays encode as JSON arrays like slices do ([]byte, a base64 string, is not special-cased).
+		return map[string]any{
+			"type":  "array",
+			"items": generateSchemaForType(t.Elem()),
+		}
+	case reflect.Map, reflect.Interface:
+		// Key and value shapes are not described; the schema only says "object".
+		return map[string]any{"type": "object"}
+	default:
+		return map[string]any{"type": "object"}
+	}
+}
+
+// generateObjectSchema builds an "object" schema from the fields of struct type t.
+//
+// Unexported fields and fields tagged json:"-" are left out. A property is named
+// after its json tag, or after the Go field when the tag gives no name, and it is
+// listed under "required" unless the tag carries omitempty. The "required" key is
+// absent when no property is required.
+func generateObjectSchema(t reflect.Type) map[string]any {
+	var (
+		props         = make(map[string]any)
+		requiredNames []string
+	)
+
+	for idx := range t.NumField() {
+		sf := t.Field(idx)
+		tag := sf.Tag.Get("json")
+
+		// Fields that are invisible to JSON never reach the schema.
+		if !sf.IsExported() || tag == "-" {
+			continue
+		}
+
+		// Split the tag into the property name and the omitempty flag.
+		propName, omitEmpty := parseJSONTag(tag)
+		if propName == "" {
+			propName = sf.Name
+		}
+
+		// Describe the field's type, attaching the description tag when set.
+		propSchema := generateSchemaForType(sf.Type)
+		if text := sf.Tag.Get("description"); text != "" {
+			propSchema["description"] = text
+		}
+		props[propName] = propSchema
+
+		// Anything not marked omitempty must be supplied by the caller.
+		if omitEmpty {
+			continue
+		}
+		requiredNames = append(requiredNames, propName)
+	}
+
+	result := map[string]any{
+		"type":       "object",
+		"properties": props,
+	}
+	// Only emit "required" when there is something to list.
+	if len(requiredNames) != 0 {
+		result["required"] = requiredNames
+	}
+
+	return result
+}
+
+// parseJSONTag parses a json struct tag and returns the field name and whether it's optional.
+func parseJSONTag(tag string) (name string, optional bool) {
+ if tag == "" {
+ return "", false
+ }
+
+ parts := strings.Split(tag, ",")
+ name = parts[0]
+
+ for _, part := range parts[1:] {
+ if part == "omitempty" {
+ optional = true
+ }
+ }
+
+ return name, optional
+}
+
+// Translate converts an untyped input (typically map[string]any from MCP request arguments)
+// to a typed struct using JSON marshalling/unmarshalling.
+//
+// This provides a simple, reliable way to convert MCP tool arguments to typed Go structs
+// without manual field-by-field extraction. A nil input marshals to JSON null and yields
+// the zero value of T with no error. On error the zero value of T is returned.
+//
+// Example:
+//
+//	args := request.Params.Arguments // map[string]any
+//	input, err := Translate[FindToolInput](args)
+//	if err != nil {
+//		return nil, fmt.Errorf("invalid arguments: %w", err)
+//	}
+func Translate[T any](input any) (T, error) {
+	var zero T
+
+	data, err := json.Marshal(input)
+	if err != nil {
+		return zero, fmt.Errorf("failed to marshal input: %w", err)
+	}
+
+	// Decode into a separate value so a failed decode cannot leak partial fields.
+	var result T
+	if err := json.Unmarshal(data, &result); err != nil {
+		return zero, fmt.Errorf("failed to unmarshal to %T: %w", result, err)
+	}
+	return result, nil
+}
diff --git a/pkg/vmcp/schema/reflect_test.go b/pkg/vmcp/schema/reflect_test.go
new file mode 100644
index 0000000000..a6acbbecf4
--- /dev/null
+++ b/pkg/vmcp/schema/reflect_test.go
@@ -0,0 +1,141 @@
+package schema
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
+)
+
+func TestGenerateSchema_FindToolInput(t *testing.T) {
+ t.Parallel()
+
+ expected := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "tool_description": map[string]any{
+ "type": "string",
+ "description": "Natural language description of the tool to find",
+ },
+ "tool_keywords": map[string]any{
+ "type": "array",
+ "items": map[string]any{"type": "string"},
+ "description": "Optional keywords to narrow search",
+ },
+ },
+ "required": []string{"tool_description"},
+ }
+
+ actual := GenerateSchema[optimizer.FindToolInput]()
+
+ require.Equal(t, expected, actual)
+}
+
+func TestGenerateSchema_CallToolInput(t *testing.T) {
+ t.Parallel()
+
+ expected := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "tool_name": map[string]any{
+ "type": "string",
+ "description": "Name of the tool to call",
+ },
+ "parameters": map[string]any{
+ "type": "object",
+ "description": "Parameters to pass to the tool",
+ },
+ },
+ "required": []string{"tool_name", "parameters"},
+ }
+
+ actual := GenerateSchema[optimizer.CallToolInput]()
+
+ require.Equal(t, expected, actual)
+}
+
+func TestTranslate_FindToolInput(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]any{
+ "tool_description": "find a tool to read files",
+ "tool_keywords": []any{"file", "read"},
+ }
+
+ result, err := Translate[optimizer.FindToolInput](input)
+ require.NoError(t, err)
+
+ require.Equal(t, "find a tool to read files", result.ToolDescription)
+ require.Equal(t, []string{"file", "read"}, result.ToolKeywords)
+}
+
+func TestTranslate_CallToolInput(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]any{
+ "tool_name": "read_file",
+ "parameters": map[string]any{
+ "path": "/etc/hosts",
+ },
+ }
+
+ result, err := Translate[optimizer.CallToolInput](input)
+ require.NoError(t, err)
+
+ require.Equal(t, "read_file", result.ToolName)
+ require.Equal(t, map[string]any{"path": "/etc/hosts"}, result.Parameters)
+}
+
+func TestTranslate_PartialInput(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]any{
+ "tool_description": "find a file reader",
+ }
+
+ result, err := Translate[optimizer.FindToolInput](input)
+ require.NoError(t, err)
+
+ require.Equal(t, "find a file reader", result.ToolDescription)
+ require.Nil(t, result.ToolKeywords)
+}
+
+func TestTranslate_InvalidInput(t *testing.T) {
+ t.Parallel()
+
+ input := make(chan int)
+
+ _, err := Translate[optimizer.FindToolInput](input)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to marshal input")
+}
+
+func TestGenerateSchema_PrimitiveTypes(t *testing.T) {
+ t.Parallel()
+
+ type TestStruct struct {
+ StringField string `json:"string_field"`
+ IntField int `json:"int_field"`
+ FloatField float64 `json:"float_field"`
+ BoolField bool `json:"bool_field"`
+ OptionalStr string `json:"optional_str,omitempty"`
+ }
+
+ expected := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "string_field": map[string]any{"type": "string"},
+ "int_field": map[string]any{"type": "integer"},
+ "float_field": map[string]any{"type": "number"},
+ "bool_field": map[string]any{"type": "boolean"},
+ "optional_str": map[string]any{"type": "string"},
+ },
+ "required": []string{"string_field", "int_field", "float_field", "bool_field"},
+ }
+
+ actual := GenerateSchema[TestStruct]()
+
+ require.Equal(t, expected, actual)
+}
diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go
new file mode 100644
index 0000000000..d0abb559ec
--- /dev/null
+++ b/pkg/vmcp/server/adapter/optimizer_adapter.go
@@ -0,0 +1,94 @@
+package adapter
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+
+ "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
+ "github.com/stacklok/toolhive/pkg/vmcp/schema"
+)
+
+// Names of the two meta-tools registered by CreateOptimizerTools.
+const (
+	FindToolName = "find_tool"
+	CallToolName = "call_tool"
+)
+
+// JSON-encoded input schemas for the optimizer tools, generated from their input structs.
+// Built during package initialization, so a schema that fails to marshal panics at startup.
+var (
+	findToolInputSchema = mustMarshalSchema(schema.GenerateSchema[optimizer.FindToolInput]())
+	callToolInputSchema = mustMarshalSchema(schema.GenerateSchema[optimizer.CallToolInput]())
+)
+
+// CreateOptimizerTools builds the two SDK tools used in optimizer mode.
+// The returned slice holds find_tool followed by call_tool, each bound to a
+// handler that delegates to opt.
+func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool {
+	findTool := server.ServerTool{
+		Tool: mcp.Tool{
+			Name:           FindToolName,
+			Description:    "Search for tools by description. Returns matching tools ranked by relevance.",
+			RawInputSchema: findToolInputSchema,
+		},
+		Handler: createFindToolHandler(opt),
+	}
+	callTool := server.ServerTool{
+		Tool: mcp.Tool{
+			Name:           CallToolName,
+			Description:    "Call a tool by name with the given parameters.",
+			RawInputSchema: callToolInputSchema,
+		},
+		Handler: createCallToolHandler(opt),
+	}
+
+	return []server.ServerTool{findTool, callTool}
+}
+
+// createFindToolHandler returns the find_tool handler: it decodes the request
+// arguments, runs opt.FindTool, and reports failures as tool-level error results.
+func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		args, decodeErr := schema.Translate[optimizer.FindToolInput](req.Params.Arguments)
+		if decodeErr != nil {
+			return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", decodeErr)), nil
+		}
+
+		matches, findErr := opt.FindTool(ctx, args)
+		if findErr != nil {
+			return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", findErr)), nil
+		}
+		return mcp.NewToolResultStructuredOnly(matches), nil
+	}
+}
+
+// createCallToolHandler returns the call_tool handler: it decodes the request
+// arguments and forwards them to opt.CallTool, passing its result through unchanged.
+func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		args, decodeErr := schema.Translate[optimizer.CallToolInput](req.Params.Arguments)
+		if decodeErr != nil {
+			return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", decodeErr)), nil
+		}
+
+		res, callErr := opt.CallTool(ctx, args)
+		if callErr != nil {
+			return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", callErr)), nil
+		}
+		return res, nil
+	}
+}
+
+// mustMarshalSchema encodes a schema map as JSON and panics if encoding fails.
+// It is used for package-level initialization, where a bad schema aborts startup.
+func mustMarshalSchema(s map[string]any) json.RawMessage {
+	encoded, err := json.Marshal(s)
+	if err == nil {
+		return encoded
+	}
+	panic(fmt.Sprintf("failed to marshal schema: %v", err))
+}
diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go
new file mode 100644
index 0000000000..f0c4db73e0
--- /dev/null
+++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go
@@ -0,0 +1,104 @@
+package adapter
+
+import (
+ "context"
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
+)
+
+// mockOptimizer is a test double for optimizer.Optimizer; each method calls its func field when one is set.
+type mockOptimizer struct {
+	findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error)
+	callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error)
+}
+
+func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) {
+	if m.findToolFunc == nil {
+		return &optimizer.FindToolOutput{}, nil // default when no stub is set
+	}
+	return m.findToolFunc(ctx, input)
+}
+
+func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) {
+	if m.callToolFunc == nil {
+		return mcp.NewToolResultText("ok"), nil // default when no stub is set
+	}
+	return m.callToolFunc(ctx, input)
+}
+
+func TestCreateOptimizerTools(t *testing.T) {
+	t.Parallel()
+
+	got := CreateOptimizerTools(&mockOptimizer{})
+
+	// Order matters: find_tool is expected first, call_tool second.
+	require.Len(t, got, 2)
+	require.Equal(t, FindToolName, got[0].Tool.Name)
+	require.Equal(t, CallToolName, got[1].Tool.Name)
+}
+
+func TestFindToolHandler(t *testing.T) {
+ t.Parallel()
+
+ opt := &mockOptimizer{
+ findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) {
+ require.Equal(t, "read files", input.ToolDescription)
+ return &optimizer.FindToolOutput{
+ Tools: []optimizer.ToolMatch{
+ {
+ Name: "read_file",
+ Description: "Read a file",
+ Score: 1.0,
+ },
+ },
+ }, nil
+ },
+ }
+
+ tools := CreateOptimizerTools(opt)
+ handler := tools[0].Handler
+
+ request := mcp.CallToolRequest{}
+ request.Params.Arguments = map[string]any{
+ "tool_description": "read files",
+ }
+
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.IsError)
+ require.Len(t, result.Content, 1)
+}
+
+func TestCallToolHandler(t *testing.T) {
+ t.Parallel()
+
+ opt := &mockOptimizer{
+ callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) {
+ require.Equal(t, "read_file", input.ToolName)
+ require.Equal(t, "/etc/hosts", input.Parameters["path"])
+ return mcp.NewToolResultText("file contents here"), nil
+ },
+ }
+
+ tools := CreateOptimizerTools(opt)
+ handler := tools[1].Handler
+
+ request := mcp.CallToolRequest{}
+ request.Params.Arguments = map[string]any{
+ "tool_name": "read_file",
+ "parameters": map[string]any{
+ "path": "/etc/hosts",
+ },
+ }
+
+ result, err := handler(context.Background(), request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.IsError)
+ require.Len(t, result.Content, 1)
+}
diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go
index 477b9a1234..d0879f3a63 100644
--- a/pkg/vmcp/server/server.go
+++ b/pkg/vmcp/server/server.go
@@ -28,6 +28,7 @@ import (
"github.com/stacklok/toolhive/pkg/vmcp/composer"
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
"github.com/stacklok/toolhive/pkg/vmcp/health"
+ "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
"github.com/stacklok/toolhive/pkg/vmcp/router"
"github.com/stacklok/toolhive/pkg/vmcp/server/adapter"
vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
@@ -119,6 +120,10 @@ type Config struct {
// Only set when running in K8s with outgoingAuth.source: discovered.
// Used for /readyz endpoint to gate readiness on cache sync.
Watcher Watcher
+
+ // OptimizerFactory builds an optimizer from a list of tools.
+ // If not set, the optimizer is disabled.
+ OptimizerFactory func([]server.ServerTool) optimizer.Optimizer
}
// Server is the Virtual MCP Server that aggregates multiple backends.
@@ -344,89 +349,9 @@ func New(
}
// Register OnRegisterSession hook to inject capabilities after SDK registers session.
- // This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which
- // fires BEFORE session registration), allowing us to safely call AddSessionTools/AddSessionResources.
- //
- // The discovery middleware populates capabilities in the context, which is available here.
- // We inject them into the SDK session and store the routing table for subsequent requests.
- //
- // IMPORTANT: Session capabilities are immutable after injection.
- // - Capabilities discovered during initialize are fixed for the session lifetime
- // - Backend changes (new tools, removed resources) won't be reflected in existing sessions
- // - Clients must create new sessions to see updated capabilities
- // TODO(dynamic-capabilities): Consider implementing capability refresh mechanism when SDK supports it
+ // See handleSessionRegistration for implementation details.
hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
- sessionID := session.SessionID()
- logger.Debugw("OnRegisterSession hook called", "session_id", sessionID)
-
- // Get capabilities from context (discovered by middleware)
- caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx)
- if !ok || caps == nil {
- logger.Warnw("no discovered capabilities in context for OnRegisterSession hook",
- "session_id", sessionID)
- return
- }
-
- // Validate that routing table exists
- if caps.RoutingTable == nil {
- logger.Warnw("routing table is nil in discovered capabilities",
- "session_id", sessionID)
- return
- }
-
- // Add composite tools to capabilities
- // Composite tools are static (from configuration) and not discovered from backends
- // They are added here to be exposed alongside backend tools in the session
- if len(srv.workflowDefs) > 0 {
- compositeTools := convertWorkflowDefsToTools(srv.workflowDefs)
-
- // Validate no conflicts between composite tool names and backend tool names
- if err := validateNoToolConflicts(caps.Tools, compositeTools); err != nil {
- logger.Errorw("composite tool name conflict detected",
- "session_id", sessionID,
- "error", err)
- // Don't add composite tools if there are conflicts
- // This prevents ambiguity in routing/execution
- return
- }
-
- caps.CompositeTools = compositeTools
- logger.Debugw("added composite tools to session capabilities",
- "session_id", sessionID,
- "composite_tool_count", len(compositeTools))
- }
-
- // Store routing table in VMCPSession for subsequent requests
- // This enables the middleware to reconstruct capabilities from session
- // without re-running discovery for every request
- vmcpSess, err := vmcpsession.GetVMCPSession(sessionID, sessionManager)
- if err != nil {
- logger.Errorw("failed to get VMCPSession for routing table storage",
- "error", err,
- "session_id", sessionID)
- return
- }
-
- vmcpSess.SetRoutingTable(caps.RoutingTable)
- vmcpSess.SetTools(caps.Tools)
- logger.Debugw("routing table and tools stored in VMCPSession",
- "session_id", sessionID,
- "tool_count", len(caps.RoutingTable.Tools),
- "resource_count", len(caps.RoutingTable.Resources),
- "prompt_count", len(caps.RoutingTable.Prompts))
-
- // Inject capabilities into SDK session
- if err := srv.injectCapabilities(sessionID, caps); err != nil {
- logger.Errorw("failed to inject session capabilities",
- "error", err,
- "session_id", sessionID)
- return
- }
-
- logger.Infow("session capabilities injected",
- "session_id", sessionID,
- "tool_count", len(caps.Tools),
- "resource_count", len(caps.Resources))
+ srv.handleSessionRegistration(ctx, session, sessionManager)
})
return srv, nil
@@ -779,6 +704,7 @@ func (s *Server) Ready() <-chan struct{} {
// - No previous capabilities exist, so no deletion needed
// - Capabilities are IMMUTABLE for the session lifetime (see limitation below)
// - Discovery middleware does not re-run for subsequent requests
+// - injectOptimizerCapabilities calls this itself; do not call it again for a session it has handled.
//
// LIMITATION: Session capabilities are fixed at creation time.
// If backends change (new tools added, resources removed), existing sessions won't see updates.
@@ -852,6 +778,175 @@ func (s *Server) injectCapabilities(
return nil
}
+// injectOptimizerCapabilities injects a session's capabilities in optimizer mode.
+// Callers use it in place of injectCapabilities; it calls injectCapabilities itself for the rest.
+//
+// When optimizer mode is enabled, instead of exposing all backend tools directly,
+// vMCP exposes only two meta-tools:
+//   - find_tool: Search for tools by description
+//   - call_tool: Invoke a tool by name with parameters
+//
+// This method:
+//  1. Converts all tools (backend + composite) to SDK format with handlers
+//  2. Builds an optimizer over those tools and registers find_tool/call_tool
+//  3. Injects the remaining capabilities with both tool lists cleared
+//
+// caps is never modified: it may be shared with other sessions, so the tool
+// lists are cleared on a shallow copy instead of being cleared and restored in place.
+func (s *Server) injectOptimizerCapabilities(
+	sessionID string,
+	caps *aggregator.AggregatedCapabilities,
+) error {
+	// Fresh slice: appending must not write into the backing array of caps.Tools.
+	tools := make([]vmcp.Tool, 0, len(caps.Tools)+len(caps.CompositeTools))
+	tools = append(tools, caps.Tools...)
+	tools = append(tools, caps.CompositeTools...)
+
+	sdkTools, err := s.capabilityAdapter.ToSDKTools(tools)
+	if err != nil {
+		return fmt.Errorf("failed to convert tools to SDK format: %w", err)
+	}
+
+	// Create optimizer tools (find_tool, call_tool)
+	optimizerTools := adapter.CreateOptimizerTools(s.config.OptimizerFactory(sdkTools))
+
+	logger.Debugw("created optimizer tools for session",
+		"session_id", sessionID,
+		"backend_tool_count", len(caps.Tools),
+		"composite_tool_count", len(caps.CompositeTools),
+		"total_tools_indexed", len(sdkTools))
+
+	// Manually add the optimizer tools, since we don't want to bother converting
+	// optimizer tools into `vmcp.Tool`s as well.
+	if err := s.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil {
+		return fmt.Errorf("failed to add session tools: %w", err)
+	}
+
+	// Tools are now reachable only through the optimizer, so hide them from the
+	// regular injection path. Resources and prompts are preserved and handled
+	// normally. The copy is shallow: its slices and pointers still alias caps,
+	// but only the copy's own fields are reassigned here.
+	// NOTE(review): assumes AggregatedCapabilities holds no locks or other
+	// no-copy fields - confirm against its definition.
+	sessionCaps := *caps
+	sessionCaps.Tools = nil
+	sessionCaps.CompositeTools = nil
+	return s.injectCapabilities(sessionID, &sessionCaps)
+}
+
+// handleSessionRegistration processes a new MCP session registration.
+//
+// This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which
+// fires BEFORE session registration), allowing us to safely call AddSessionTools/AddSessionResources.
+//
+// The discovery middleware populates capabilities in the context, which is available here.
+// We inject them into the SDK session and store the routing table for subsequent requests.
+//
+// This method performs the following steps:
+//  1. Retrieves discovered capabilities from context
+//  2. Adds composite tools from configuration
+//  3. Stores routing table in VMCPSession for request routing
+//  4. Injects capabilities into the SDK session (optimizer path when OptimizerFactory is set)
+//
+// IMPORTANT: Session capabilities are immutable after injection.
+//   - Capabilities discovered during initialize are fixed for the session lifetime
+//   - Backend changes (new tools, removed resources) won't be reflected in existing sessions
+//   - Clients must create new sessions to see updated capabilities
+//
+// TODO(dynamic-capabilities): Consider implementing capability refresh mechanism when SDK supports it
+//
+// The sessionManager parameter is passed explicitly because this method is called
+// from a closure registered before the Server is fully constructed.
+func (s *Server) handleSessionRegistration(
+	ctx context.Context,
+	session server.ClientSession,
+	sessionManager *transportsession.Manager,
+) {
+	sessionID := session.SessionID()
+	logger.Debugw("OnRegisterSession hook called", "session_id", sessionID)
+
+	// Get capabilities from context (discovered by middleware)
+	caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx)
+	if !ok || caps == nil {
+		logger.Warnw("no discovered capabilities in context for OnRegisterSession hook",
+			"session_id", sessionID)
+		return
+	}
+
+	// Validate that routing table exists
+	if caps.RoutingTable == nil {
+		logger.Warnw("routing table is nil in discovered capabilities",
+			"session_id", sessionID)
+		return
+	}
+
+	// Add composite tools to capabilities
+	// Composite tools are static (from configuration) and not discovered from backends
+	// They are added here to be exposed alongside backend tools in the session
+	if len(s.workflowDefs) > 0 {
+		compositeTools := convertWorkflowDefsToTools(s.workflowDefs)
+
+		// Validate no conflicts between composite tool names and backend tool names
+		if err := validateNoToolConflicts(caps.Tools, compositeTools); err != nil {
+			logger.Errorw("composite tool name conflict detected",
+				"session_id", sessionID,
+				"error", err)
+			// NOTE(review): this aborts the whole registration, so nothing below runs and
+			// no capabilities are injected here - not just no composite tools. Confirm intent.
+			return
+		}
+
+		caps.CompositeTools = compositeTools
+		logger.Debugw("added composite tools to session capabilities",
+			"session_id", sessionID,
+			"composite_tool_count", len(compositeTools))
+	}
+
+	// Store routing table in VMCPSession for subsequent requests
+	// This enables the middleware to reconstruct capabilities from session
+	// without re-running discovery for every request
+	vmcpSess, err := vmcpsession.GetVMCPSession(sessionID, sessionManager)
+	if err != nil {
+		logger.Errorw("failed to get VMCPSession for routing table storage",
+			"error", err,
+			"session_id", sessionID)
+		return
+	}
+
+	vmcpSess.SetRoutingTable(caps.RoutingTable)
+	vmcpSess.SetTools(caps.Tools)
+	logger.Debugw("routing table and tools stored in VMCPSession",
+		"session_id", sessionID,
+		"tool_count", len(caps.RoutingTable.Tools),
+		"resource_count", len(caps.RoutingTable.Resources),
+		"prompt_count", len(caps.RoutingTable.Prompts))
+
+	if s.config.OptimizerFactory != nil {
+		// Optimizer mode: this call takes the place of the injectCapabilities call below.
+		if err := s.injectOptimizerCapabilities(sessionID, caps); err != nil {
+			logger.Errorw("failed to inject optimizer capabilities",
+				"error", err,
+				"session_id", sessionID)
+			return
+		}
+		logger.Infow("optimizer capabilities injected", "session_id", sessionID)
+		return
+	}
+
+	// Inject capabilities into SDK session
+	if err := s.injectCapabilities(sessionID, caps); err != nil {
+		logger.Errorw("failed to inject session capabilities",
+			"error", err,
+			"session_id", sessionID)
+		return
+	}
+
+	logger.Infow("session capabilities injected",
+		"session_id", sessionID,
+		"tool_count", len(caps.Tools),
+		"resource_count", len(caps.Resources))
+}
+
// validateAndCreateExecutors validates workflow definitions and creates executors.
//
// This function:
diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go
new file mode 100644
index 0000000000..1db816db9a
--- /dev/null
+++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go
@@ -0,0 +1,311 @@
+package virtualmcp
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/runtime"
+ "k8s.io/apimachinery/pkg/types"
+
+ mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
+ vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
+ "github.com/stacklok/toolhive/test/e2e/images"
+)
+
+// callFindTool calls find_tool and returns the StructuredContent directly
+func callFindTool(mcpClient *InitializedMCPClient, description string) (map[string]any, error) {
+ req := mcp.CallToolRequest{}
+ req.Params.Name = "find_tool"
+ req.Params.Arguments = map[string]any{"tool_description": description}
+
+ result, err := mcpClient.Client.CallTool(mcpClient.Ctx, req)
+ if err != nil {
+ return nil, err
+ }
+ content, ok := result.StructuredContent.(map[string]any)
+ if !ok {
+ return nil, fmt.Errorf("expected map[string]any, got %T", result.StructuredContent)
+ }
+ return content, nil
+}
+
+// getToolNames collects the string "name" of every map entry under the "tools" key.
+// Entries of any other shape are skipped; nil is returned when nothing is collected.
+func getToolNames(content map[string]any) []string {
+	var names []string
+	entries, _ := content["tools"].([]any) // nil (nothing to iterate) when absent or mistyped
+	for _, raw := range entries {
+		entry, isMap := raw.(map[string]any)
+		if !isMap {
+			continue
+		}
+		if name, isString := entry["name"].(string); isString {
+			names = append(names, name)
+		}
+	}
+	return names
+}
+
+// callToolViaOptimizer invokes a tool through call_tool
+func callToolViaOptimizer(mcpClient *InitializedMCPClient, toolName string, params map[string]any) (*mcp.CallToolResult, error) {
+ req := mcp.CallToolRequest{}
+ req.Params.Name = "call_tool"
+ req.Params.Arguments = map[string]any{
+ "tool_name": toolName,
+ "parameters": params,
+ }
+ return mcpClient.Client.CallTool(mcpClient.Ctx, req)
+}
+
+var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() {
+	var (
+		testNamespace     = "default"
+		mcpGroupName      = "test-optimizer-group"
+		vmcpServerName    = "test-vmcp-optimizer"
+		backendName       = "backend-optimizer-fetch"
+		compositeToolName = "double_fetch"
+		timeout           = 3 * time.Minute
+		pollingInterval   = 1 * time.Second
+		vmcpNodePort      int32
+	)
+
+	BeforeAll(func() {
+		By("Creating MCPGroup for optimizer test")
+		CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace,
+			"Test MCP Group for optimizer E2E tests", timeout, pollingInterval)
+
+		By("Creating backend MCPServer - fetch")
+		CreateMCPServerAndWait(ctx, k8sClient, backendName, testNamespace,
+			mcpGroupName, images.GofetchServerImage, timeout, pollingInterval)
+
+		By("Creating VirtualMCPServer with optimizer enabled and a composite tool")
+		// JSON Schema for the composite tool's input: a single required "url" string.
+		inputSchema := map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"url": map[string]any{
+					"type":        "string",
+					"description": "URL to fetch twice",
+				},
+			},
+			"required": []string{"url"},
+		}
+		inputSchemaJSON, err := json.Marshal(inputSchema)
+		Expect(err).ToNot(HaveOccurred())
+
+		// Arguments shared by both steps: a template referencing the "url" input parameter.
+		fetchArgs := map[string]any{
+			"url": "{{.params.url}}",
+		}
+		fetchArgsJSON, err := json.Marshal(fetchArgs)
+		Expect(err).ToNot(HaveOccurred())
+
+		vmcpServer := &mcpv1alpha1.VirtualMCPServer{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      vmcpServerName,
+				Namespace: testNamespace,
+			},
+			Spec: mcpv1alpha1.VirtualMCPServerSpec{
+				ServiceType: "NodePort",
+				IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{
+					Type: "anonymous",
+				},
+				OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{
+					Source: "discovered",
+				},
+				// Composite tool made of two sequential steps that call the backend's fetch tool
+				CompositeTools: []mcpv1alpha1.CompositeToolSpec{
+					{
+						Name:        compositeToolName,
+						Description: "Fetches a URL twice in sequence for verification",
+						Parameters: &runtime.RawExtension{
+							Raw: inputSchemaJSON,
+						},
+						Steps: []mcpv1alpha1.WorkflowStep{
+							{
+								ID:   "first_fetch",
+								Type: "tool",
+								Tool: backendName + ".fetch",
+								Arguments: &runtime.RawExtension{
+									Raw: fetchArgsJSON,
+								},
+							},
+							{
+								ID:        "second_fetch",
+								Type:      "tool",
+								Tool:      backendName + ".fetch",
+								DependsOn: []string{"first_fetch"},
+								Arguments: &runtime.RawExtension{
+									Raw: fetchArgsJSON,
+								},
+							},
+						},
+					},
+				},
+				Config: vmcpconfig.Config{
+					Group: mcpGroupName,
+					Optimizer: &vmcpconfig.OptimizerConfig{
+						// EmbeddingService is required but not used by DummyOptimizer
+						EmbeddingService: "dummy-embedding-service",
+					},
+				},
+			},
+		}
+		Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed())
+
+		By("Waiting for VirtualMCPServer to be ready")
+		WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval)
+
+		By("Getting VirtualMCPServer NodePort")
+		vmcpNodePort = GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval)
+		_, _ = fmt.Fprintf(GinkgoWriter, "VirtualMCPServer is accessible at NodePort: %d\n", vmcpNodePort)
+	})
+
+	AfterAll(func() {
+		By("Cleaning up VirtualMCPServer")
+		vmcpKey := types.NamespacedName{Name: vmcpServerName, Namespace: testNamespace}
+		vmcpServer := &mcpv1alpha1.VirtualMCPServer{}
+		// Best-effort: delete only if the object can be fetched; delete errors are ignored.
+		vmcpErr := k8sClient.Get(ctx, vmcpKey, vmcpServer)
+		if vmcpErr == nil {
+			_ = k8sClient.Delete(ctx, vmcpServer)
+		}
+
+		By("Cleaning up backend MCPServer")
+		backendKey := types.NamespacedName{Name: backendName, Namespace: testNamespace}
+		backend := &mcpv1alpha1.MCPServer{}
+		// Best-effort, same pattern as the VirtualMCPServer cleanup.
+		backendErr := k8sClient.Get(ctx, backendKey, backend)
+		if backendErr == nil {
+			_ = k8sClient.Delete(ctx, backend)
+		}
+
+		By("Cleaning up MCPGroup")
+		groupKey := types.NamespacedName{Name: mcpGroupName, Namespace: testNamespace}
+		mcpGroup := &mcpv1alpha1.MCPGroup{}
+		// Best-effort, same pattern as the VirtualMCPServer cleanup.
+		groupErr := k8sClient.Get(ctx, groupKey, mcpGroup)
+		if groupErr == nil {
+			_ = k8sClient.Delete(ctx, mcpGroup)
+		}
+	})
+
+	It("should only expose find_tool and call_tool", func() {
+		By("Creating and initializing MCP client")
+		mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-test-client", 30*time.Second)
+		Expect(err).ToNot(HaveOccurred())
+		defer mcpClient.Close()
+
+		By("Listing tools from VirtualMCPServer")
+		// Zero-value request: no fields are set on it.
+		listed, err := mcpClient.Client.ListTools(mcpClient.Ctx, mcp.ListToolsRequest{})
+		Expect(err).ToNot(HaveOccurred())
+
+		By("Verifying only optimizer tools are exposed")
+		Expect(listed.Tools).To(HaveLen(2), "Should only have find_tool and call_tool")
+
+		exposed := make([]string, 0, len(listed.Tools))
+		for _, tool := range listed.Tools {
+			exposed = append(exposed, tool.Name)
+		}
+		Expect(exposed).To(ContainElements("find_tool", "call_tool"))
+
+		_, _ = fmt.Fprintf(GinkgoWriter, "✓ Optimizer mode correctly exposes only: %v\n", exposed)
+	})
+
+	It("should find backend tools via find_tool", func() {
+		By("Creating and initializing MCP client")
+		mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-find-test", 30*time.Second)
+		Expect(err).ToNot(HaveOccurred())
+		defer mcpClient.Close()
+
+		By("Calling find_tool to search for fetch tool")
+		found, err := callFindTool(mcpClient, "fetch")
+		Expect(err).ToNot(HaveOccurred())
+
+		names := getToolNames(found)
+		Expect(names).ToNot(BeEmpty(), "find_tool should return matching tools")
+		_, _ = fmt.Fprintf(GinkgoWriter, "✓ find_tool returned tools: %v\n", names)
+
+		By("Verifying at least one tool matches 'fetch'")
+		// Every returned name is scanned; a single substring match is enough.
+		foundFetch := false
+		for _, name := range names {
+			// Sticky flag: once true it stays true for the rest of the scan.
+			foundFetch = foundFetch || strings.Contains(name, "fetch")
+		}
+		Expect(foundFetch).To(BeTrue(), "Should find a fetch-related tool")
+	})
+
+	It("should invoke backend tools via call_tool", func() {
+		By("Creating and initializing MCP client")
+		mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-call-test", 30*time.Second)
+		Expect(err).ToNot(HaveOccurred())
+		defer mcpClient.Close()
+
+		By("Finding the backend fetch tool")
+		findResult, err := callFindTool(mcpClient, "internet") // matches backend fetch description
+		Expect(err).ToNot(HaveOccurred())
+
+		names := getToolNames(findResult)
+		Expect(names).ToNot(BeEmpty())
+
+		// Take the first result whose name contains the backend name, i.e. the
+		// backend tool (has prefix), not the composite.
+		backendToolName := ""
+		for i := 0; i < len(names) && backendToolName == ""; i++ {
+			if strings.Contains(names[i], backendName) {
+				backendToolName = names[i]
+			}
+		}
+		Expect(backendToolName).ToNot(BeEmpty(), "Should find backend tool")
+
+		By(fmt.Sprintf("Calling %s via call_tool", backendToolName))
+		callArgs := map[string]any{"url": "https://example.com"}
+		result, err := callToolViaOptimizer(mcpClient, backendToolName, callArgs)
+		Expect(err).ToNot(HaveOccurred())
+		Expect(result).ToNot(BeNil())
+		// Content is only checked for presence, not for what was fetched.
+		Expect(result.Content).ToNot(BeEmpty(), "call_tool should return content from backend tool")
+
+		_, _ = fmt.Fprintf(GinkgoWriter, "✓ Successfully called %s via call_tool\n", backendToolName)
+	})
+
+	It("should find and invoke composite tools via optimizer", func() {
+		By("Creating and initializing MCP client")
+		mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "optimizer-composite-test", 30*time.Second)
+		Expect(err).ToNot(HaveOccurred())
+		defer mcpClient.Close()
+
+		By("Calling find_tool to search for composite tool")
+		findResult, err := callFindTool(mcpClient, "twice") // matches "Fetches a URL twice..."
+		Expect(err).ToNot(HaveOccurred())
+
+		names := getToolNames(findResult)
+		foundTool := ""
+		for _, name := range names {
+			_, _ = fmt.Fprintf(GinkgoWriter, "  - Found tool: %s\n", name)
+			if strings.Contains(name, compositeToolName) {
+				foundTool = name // no break: every name is logged and the last match wins
+			}
+		}
+		Expect(foundTool).ToNot(BeEmpty(), "Should find composite tool %s", compositeToolName)
+		_, _ = fmt.Fprintf(GinkgoWriter, "✓ Found composite tool %s via find_tool\n", foundTool)
+
+		By(fmt.Sprintf("Calling composite tool %s via call_tool", foundTool))
+		callResult, err := callToolViaOptimizer(mcpClient, foundTool, map[string]any{
+			"url": "https://example.com",
+		})
+		Expect(err).ToNot(HaveOccurred())
+		Expect(callResult).ToNot(BeNil())
+		Expect(callResult.Content).ToNot(BeEmpty(), "call_tool should return content from composite tool")
+
+		_, _ = fmt.Fprintf(GinkgoWriter, "✓ Successfully called composite tool %s via call_tool\n", foundTool)
+	})
+})