diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 5f38ee50d8..441dd42d66 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -9,7 +9,6 @@ package client import ( "context" - "encoding/base64" "errors" "fmt" "io" @@ -23,6 +22,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/versions" "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" @@ -187,37 +187,35 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm identity: identity, } - // Add size limit layer for DoS protection - sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { - resp, err := baseTransport.RoundTrip(req) - if err != nil { - return nil, err - } - // Wrap response body with size limit - resp.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.LimitReader(resp.Body, maxResponseSize), - Closer: resp.Body, - } - return resp, nil - }) - - // Create HTTP client with configured transport chain - // Set timeouts to prevent long-lived connections that require continuous listening - httpClient := &http.Client{ - Transport: sizeLimitedTransport, - Timeout: 30 * time.Second, // Prevent hanging on connections - } - var c *client.Client switch target.TransportType { case "streamable-http", "streamable": + // "streamable" is a legacy alias for "streamable-http". + // + // For streamable-HTTP each MCP call is a single bounded HTTP + // request/response pair, so a per-response body size limit is safe. + sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resp, err := baseTransport.RoundTrip(req) + if err != nil { + return nil, err + } + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(resp.Body, maxResponseSize), + Closer: resp.Body, + } + return resp, nil + }) + httpClient := &http.Client{ + Transport: sizeLimitedTransport, + Timeout: 30 * time.Second, + } c, err = client.NewStreamableHttpClient( target.BaseURL, - transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0 + transport.WithHTTPTimeout(30*time.Second), transport.WithHTTPBasicClient(httpClient), ) if err != nil { @@ -225,6 +223,11 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm } case "sse": + // For SSE the entire session is one long-lived HTTP response body. + // Applying io.LimitReader would silently terminate the stream after + // maxResponseSize cumulative bytes — not per-event — which is wrong. + // http.Client.Timeout is also omitted: it would kill the stream. + httpClient := &http.Client{Transport: baseTransport} c, err = client.NewSSEMCPClient( target.BaseURL, transport.WithHTTPClient(httpClient), @@ -325,7 +328,7 @@ func initializeClient(ctx context.Context, c *client.Client) (*mcp.ServerCapabil ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, ClientInfo: mcp.Implementation{ Name: "toolhive-vmcp", - Version: "0.1.0", + Version: versions.Version, }, Capabilities: mcp.ClientCapabilities{ // Virtual MCP acts as a client to backends @@ -382,36 +385,6 @@ func queryPrompts(ctx context.Context, c *client.Client, supported bool, backend return &mcp.ListPromptsResult{Prompts: []mcp.Prompt{}}, nil } -// convertContent converts mcp.Content to vmcp.Content. -// This preserves the full content structure from backend responses. -func convertContent(content mcp.Content) vmcp.Content { - if textContent, ok := mcp.AsTextContent(content); ok { - return vmcp.Content{ - Type: "text", - Text: textContent.Text, - } - } - if imageContent, ok := mcp.AsImageContent(content); ok { - return vmcp.Content{ - Type: "image", - Data: imageContent.Data, - MimeType: imageContent.MIMEType, - } - } - if audioContent, ok := mcp.AsAudioContent(content); ok { - return vmcp.Content{ - Type: "audio", - Data: audioContent.Data, - MimeType: audioContent.MIMEType, - } - } - // Handle embedded resources if needed - // Unknown content types are marked as "unknown" type with no data - slog.Warn("encountered unknown content type, marking as unknown content", - "type", fmt.Sprintf("%T", content)) - return vmcp.Content{Type: "unknown"} -} - // ListCapabilities queries a backend for its MCP capabilities. // Returns tools, resources, and prompts exposed by the backend. // Only queries capabilities that the server advertises during initialization. @@ -467,25 +440,10 @@ func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.B // Convert tools for i, tool := range toolsResp.Tools { - // Convert ToolInputSchema to map[string]any - // The ToolInputSchema is a struct with Type, Properties, Required fields - inputSchema := map[string]any{ - "type": tool.InputSchema.Type, - } - if tool.InputSchema.Properties != nil { - inputSchema["properties"] = tool.InputSchema.Properties - } - if len(tool.InputSchema.Required) > 0 { - inputSchema["required"] = tool.InputSchema.Required - } - if tool.InputSchema.Defs != nil { - inputSchema["$defs"] = tool.InputSchema.Defs - } - capabilities.Tools[i] = vmcp.Tool{ Name: tool.Name, Description: tool.Description, - InputSchema: inputSchema, + InputSchema: conversion.ConvertToolInputSchema(tool.InputSchema), BackendID: target.WorkloadID, } } @@ -608,11 +566,8 @@ func (h *httpBackendClient) CallTool( // Continue processing - we return the result with IsError flag and metadata preserved } - // Convert MCP content to vmcp.Content array - contentArray := make([]vmcp.Content, len(result.Content)) - for i, content := range result.Content { - contentArray[i] = convertContent(content) - } + // Convert MCP content to vmcp.Content array. + contentArray := conversion.ConvertMCPContents(result.Content) // Check for structured content first (preferred for composite tool step chaining). // StructuredContent allows templates to access nested fields directly via {{.steps.stepID.output.field}}. @@ -683,33 +638,8 @@ func (h *httpBackendClient) ReadResource( return nil, fmt.Errorf("resource read failed on backend %s: %w", target.WorkloadID, err) } - // Concatenate all resource contents - // MCP resources can have multiple contents (text or blob) - var data []byte - var mimeType string - for i, content := range result.Contents { - // Try to convert to TextResourceContents - if textContent, ok := mcp.AsTextResourceContents(content); ok { - data = append(data, []byte(textContent.Text)...) - if i == 0 && textContent.MIMEType != "" { - mimeType = textContent.MIMEType - } - } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { - // Blob is base64-encoded per MCP spec, decode it to bytes - decoded, err := base64.StdEncoding.DecodeString(blobContent.Blob) - if err != nil { - slog.Warn("failed to decode base64 blob from resource", - "resource", uri, "backend", target.WorkloadID, "error", err) - // Append raw blob as fallback - data = append(data, []byte(blobContent.Blob)...) - } else { - data = append(data, decoded...) - } - if i == 0 && blobContent.MIMEType != "" { - mimeType = blobContent.MIMEType - } - } - } + // Concatenate all resource content items into a single byte slice. + data, mimeType := conversion.ConcatenateResourceContents(result.Contents) // Extract _meta field from backend response meta := conversion.FromMCPMeta(result.Meta) @@ -756,11 +686,7 @@ func (h *httpBackendClient) GetPrompt( slog.Debug("translating prompt name", "client_name", name, "backend_name", backendPromptName) } - // Convert map[string]any to map[string]string - stringArgs := make(map[string]string) - for k, v := range arguments { - stringArgs[k] = fmt.Sprintf("%v", v) - } + stringArgs := conversion.ConvertPromptArguments(arguments) result, err := c.GetPrompt(ctx, mcp.GetPromptRequest{ Params: mcp.GetPromptParams{ @@ -772,26 +698,9 @@ func (h *httpBackendClient) GetPrompt( return nil, fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) } - // Concatenate all prompt messages into a single string - // MCP prompts return messages with role and content (Content interface) - var prompt string - for _, msg := range result.Messages { - if msg.Role != "" { - prompt += fmt.Sprintf("[%s] ", msg.Role) - } - // Try to convert content to TextContent - if textContent, ok := mcp.AsTextContent(msg.Content); ok { - prompt += textContent.Text + "\n" - } - // TODO: Handle other content types (image, audio, resource) - } - - // Extract _meta field from backend response - meta := conversion.FromMCPMeta(result.Meta) - return &vmcp.PromptGetResult{ - Messages: prompt, + Messages: conversion.ConvertPromptMessages(result.Messages), Description: result.Description, - Meta: meta, + Meta: conversion.FromMCPMeta(result.Meta), }, nil } diff --git a/pkg/vmcp/client/conversions_test.go b/pkg/vmcp/client/conversions_test.go deleted file mode 100644 index c0a51af0c5..0000000000 --- a/pkg/vmcp/client/conversions_test.go +++ /dev/null @@ -1,489 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package client - -import ( - "testing" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp" -) - -// These tests verify the critical type conversion logic in the backend client. -// Since we can't easily mock the mark3labs client, we test the conversion patterns -// that our code uses to transform MCP SDK types to vmcp domain types. - -func TestToolInputSchemaConversion(t *testing.T) { - t.Parallel() - - t.Run("converts basic tool schema", func(t *testing.T) { - t.Parallel() - - sdkTool := mcp.Tool{ - Name: "create_issue", - Description: "Create a GitHub issue", - InputSchema: mcp.ToolInputSchema{ - Type: "object", - Properties: map[string]any{ - "title": map[string]any{"type": "string", "description": "Issue title"}, - "body": map[string]any{"type": "string", "description": "Issue body"}, - }, - Required: []string{"title"}, - }, - } - - inputSchema := convertToolInputSchema(sdkTool.InputSchema) - - assert.Equal(t, "object", inputSchema["type"]) - assert.NotNil(t, inputSchema["properties"]) - assert.Equal(t, []string{"title"}, inputSchema["required"]) - - props := inputSchema["properties"].(map[string]any) - assert.Contains(t, props, "title") - assert.Contains(t, props, "body") - titleProp := props["title"].(map[string]any) - assert.Equal(t, "string", titleProp["type"]) - assert.Equal(t, "Issue title", titleProp["description"]) - }) - - t.Run("converts schema with $defs", func(t *testing.T) { - t.Parallel() - - sdkTool := mcp.Tool{ - Name: "complex_tool", - InputSchema: mcp.ToolInputSchema{ - Type: "object", - Properties: map[string]any{ - "config": map[string]any{"$ref": "#/$defs/Config"}, - }, - Defs: map[string]any{ - "Config": map[string]any{ - "type": "object", - "properties": map[string]any{"enabled": map[string]any{"type": "boolean"}}, - }, - }, - }, - } - - inputSchema := convertToolInputSchema(sdkTool.InputSchema) - - assert.Contains(t, inputSchema, "$defs") - defs := inputSchema["$defs"].(map[string]any) - assert.Contains(t, defs, "Config") - }) - - t.Run("handles empty required array", func(t *testing.T) { - t.Parallel() - - sdkTool := mcp.Tool{ - Name: "optional_tool", - InputSchema: mcp.ToolInputSchema{ - Type: "object", - Properties: map[string]any{"optional_param": map[string]any{"type": "string"}}, - Required: []string{}, - }, - } - - inputSchema := convertToolInputSchema(sdkTool.InputSchema) - - assert.NotContains(t, inputSchema, "required") - }) -} - -func TestContentInterfaceHandling(t *testing.T) { - t.Parallel() - - t.Run("extracts text content correctly", func(t *testing.T) { - t.Parallel() - - toolResult := &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewTextContent("First text result"), - mcp.NewTextContent("Second text result"), - }, - IsError: false, - } - - resultMap := convertContentToMap(toolResult.Content) - - assert.Equal(t, "First text result", resultMap["text"]) - assert.Equal(t, "Second text result", resultMap["text_1"]) - }) - - t.Run("extracts mixed content types", func(t *testing.T) { - t.Parallel() - - toolResult := &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewTextContent("Text content"), - mcp.NewImageContent("base64data", "image/png"), - mcp.NewTextContent("More text"), - }, - IsError: false, - } - - resultMap := convertContentToMap(toolResult.Content) - - assert.Equal(t, "Text content", resultMap["text"]) - assert.Equal(t, "More text", resultMap["text_1"]) - assert.Equal(t, "base64data", resultMap["image_0"]) - }) - - t.Run("handles error result correctly", func(t *testing.T) { - t.Parallel() - - toolResult := &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewTextContent("Error: something went wrong"), - }, - IsError: true, - } - - // Verify IsError is a boolean (not pointer) - from client.go:223 - assert.True(t, toolResult.IsError) - // Our code should check: if result.IsError { return error } - }) -} - -func TestResourceContentsHandling(t *testing.T) { - t.Parallel() - - t.Run("extracts text resource content", func(t *testing.T) { - t.Parallel() - - resourceResult := &mcp.ReadResourceResult{ - Contents: []mcp.ResourceContents{ - mcp.TextResourceContents{ - URI: "test://resource", - MIMEType: "text/plain", - Text: "Resource text content", - }, - }, - } - - data := convertResourceContents(resourceResult.Contents) - assert.Equal(t, []byte("Resource text content"), data) - }) - - t.Run("extracts blob resource content", func(t *testing.T) { - t.Parallel() - - resourceResult := &mcp.ReadResourceResult{ - Contents: []mcp.ResourceContents{ - mcp.BlobResourceContents{ - URI: "test://binary", - MIMEType: "application/octet-stream", - Blob: "YmFzZTY0ZGF0YQ==", - }, - }, - } - - data := convertResourceContents(resourceResult.Contents) - assert.Equal(t, []byte("YmFzZTY0ZGF0YQ=="), data) - }) - - t.Run("concatenates multiple resource contents", func(t *testing.T) { - t.Parallel() - - resourceResult := &mcp.ReadResourceResult{ - Contents: []mcp.ResourceContents{ - mcp.TextResourceContents{URI: "test://multi", Text: "Part 1"}, - mcp.TextResourceContents{URI: "test://multi", Text: "Part 2"}, - }, - } - - data := convertResourceContents(resourceResult.Contents) - assert.Equal(t, []byte("Part 1Part 2"), data) - }) -} - -func TestPromptMessageHandling(t *testing.T) { - t.Parallel() - - t.Run("extracts prompt with single message", func(t *testing.T) { - t.Parallel() - - promptResult := &mcp.GetPromptResult{ - Description: "Test prompt", - Messages: []mcp.PromptMessage{ - {Role: "user", Content: mcp.NewTextContent("What is the weather?")}, - }, - } - - prompt := convertPromptMessages(promptResult.Messages) - assert.Equal(t, "[user] What is the weather?\n", prompt) - }) - - t.Run("concatenates multiple prompt messages", func(t *testing.T) { - t.Parallel() - - promptResult := &mcp.GetPromptResult{ - Messages: []mcp.PromptMessage{ - {Role: "system", Content: mcp.NewTextContent("You are a helpful assistant")}, - {Role: "user", Content: mcp.NewTextContent("Hello")}, - {Role: "assistant", Content: mcp.NewTextContent("Hi there!")}, - }, - } - - prompt := convertPromptMessages(promptResult.Messages) - expected := "[system] You are a helpful assistant\n[user] Hello\n[assistant] Hi there!\n" - assert.Equal(t, expected, prompt) - }) - - t.Run("handles prompt message without role", func(t *testing.T) { - t.Parallel() - - promptResult := &mcp.GetPromptResult{ - Messages: []mcp.PromptMessage{ - {Role: "", Content: mcp.NewTextContent("Message content")}, - }, - } - - prompt := convertPromptMessages(promptResult.Messages) - assert.Equal(t, "Message content\n", prompt) - }) -} - -func TestGetPromptArgumentsConversion(t *testing.T) { - t.Parallel() - - t.Run("converts map[string]any to map[string]string", func(t *testing.T) { - t.Parallel() - - arguments := map[string]any{ - "string_arg": "value", - "int_arg": 42, - "bool_arg": true, - "float_arg": 3.14, - } - - stringArgs := convertPromptArguments(arguments) - - assert.Equal(t, "value", stringArgs["string_arg"]) - assert.Equal(t, "42", stringArgs["int_arg"]) - assert.Equal(t, "true", stringArgs["bool_arg"]) - assert.Equal(t, "3.14", stringArgs["float_arg"]) - }) - - t.Run("handles nil and empty values", func(t *testing.T) { - t.Parallel() - - arguments := map[string]any{ - "nil_arg": nil, - "empty_arg": "", - } - - stringArgs := convertPromptArguments(arguments) - - assert.Equal(t, "", stringArgs["nil_arg"]) - assert.Equal(t, "", stringArgs["empty_arg"]) - }) -} - -func TestResourceMIMETypeField(t *testing.T) { - t.Parallel() - - t.Run("uses MIMEType not MimeType", func(t *testing.T) { - t.Parallel() - - // This verifies we're using the correct field name (from client.go:167) - sdkResource := mcp.Resource{ - URI: "test://resource", - Name: "Test Resource", - Description: "A test resource", - MIMEType: "application/json", // Note: MIMEType, not MimeType - } - - vmcpResource := vmcp.Resource{ - URI: sdkResource.URI, - Name: sdkResource.Name, - Description: sdkResource.Description, - MimeType: sdkResource.MIMEType, // Our conversion uses MIMEType - BackendID: "test-backend", - } - - assert.Equal(t, "application/json", vmcpResource.MimeType) - }) -} - -func TestMultipleContentItemsHandling(t *testing.T) { - t.Parallel() - - t.Run("handles tool result with many text items", func(t *testing.T) { - t.Parallel() - - toolResult := &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewTextContent("Result 1"), - mcp.NewTextContent("Result 2"), - mcp.NewTextContent("Result 3"), - mcp.NewTextContent("Result 4"), - mcp.NewTextContent("Result 5"), - }, - IsError: false, - } - - resultMap := convertContentToMap(toolResult.Content) - - assert.Equal(t, "Result 1", resultMap["text"]) - assert.Equal(t, "Result 2", resultMap["text_1"]) - assert.Equal(t, "Result 3", resultMap["text_2"]) - assert.Equal(t, "Result 4", resultMap["text_3"]) - assert.Equal(t, "Result 5", resultMap["text_4"]) - }) - - t.Run("handles tool result with many images", func(t *testing.T) { - t.Parallel() - - toolResult := &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewImageContent("data1", "image/png"), - mcp.NewImageContent("data2", "image/jpeg"), - mcp.NewImageContent("data3", "image/gif"), - }, - IsError: false, - } - - resultMap := convertContentToMap(toolResult.Content) - - assert.Equal(t, "data1", resultMap["image_0"]) - assert.Equal(t, "data2", resultMap["image_1"]) - assert.Equal(t, "data3", resultMap["image_2"]) - }) - - t.Run("handles empty content array", func(t *testing.T) { - t.Parallel() - - emptyContent := []mcp.Content{} - resultMap := convertContentToMap(emptyContent) - - assert.Empty(t, resultMap) - }) -} - -func TestPromptArgumentConversion(t *testing.T) { - t.Parallel() - - t.Run("converts prompt arguments correctly", func(t *testing.T) { - t.Parallel() - - // From client.go:174-183 - sdkPrompt := mcp.Prompt{ - Name: "test_prompt", - Description: "A test prompt", - Arguments: []mcp.PromptArgument{ - { - Name: "required_arg", - Description: "A required argument", - Required: true, - }, - { - Name: "optional_arg", - Description: "An optional argument", - Required: false, - }, - }, - } - - // Apply our conversion - args := make([]vmcp.PromptArgument, len(sdkPrompt.Arguments)) - for j, arg := range sdkPrompt.Arguments { - args[j] = vmcp.PromptArgument{ - Name: arg.Name, - Description: arg.Description, - Required: arg.Required, - } - } - - vmcpPrompt := vmcp.Prompt{ - Name: sdkPrompt.Name, - Description: sdkPrompt.Description, - Arguments: args, - BackendID: "test-backend", - } - - // Verify conversion - require.Len(t, vmcpPrompt.Arguments, 2) - assert.Equal(t, "required_arg", vmcpPrompt.Arguments[0].Name) - assert.True(t, vmcpPrompt.Arguments[0].Required) - assert.Equal(t, "optional_arg", vmcpPrompt.Arguments[1].Name) - assert.False(t, vmcpPrompt.Arguments[1].Required) - }) -} - -func TestConvertContent(t *testing.T) { - t.Parallel() - - t.Run("converts text content", func(t *testing.T) { - t.Parallel() - - mcpContent := mcp.NewTextContent("Hello, world!") - result := convertContent(mcpContent) - - assert.Equal(t, "text", result.Type) - assert.Equal(t, "Hello, world!", result.Text) - }) - - t.Run("converts empty text content", func(t *testing.T) { - t.Parallel() - - mcpContent := mcp.NewTextContent("") - result := convertContent(mcpContent) - - assert.Equal(t, "text", result.Type) - assert.Equal(t, "", result.Text) - }) - - t.Run("converts image content", func(t *testing.T) { - t.Parallel() - - mcpContent := mcp.NewImageContent("base64-encoded-image-data", "image/png") - result := convertContent(mcpContent) - - assert.Equal(t, "image", result.Type) - assert.Equal(t, "base64-encoded-image-data", result.Data) - assert.Equal(t, "image/png", result.MimeType) - }) - - t.Run("converts image content with empty mime type", func(t *testing.T) { - t.Parallel() - - mcpContent := mcp.NewImageContent("image-data", "") - result := convertContent(mcpContent) - - assert.Equal(t, "image", result.Type) - assert.Equal(t, "image-data", result.Data) - assert.Equal(t, "", result.MimeType) - }) - - t.Run("converts audio content", func(t *testing.T) { - t.Parallel() - - mcpContent := mcp.NewAudioContent("base64-encoded-audio-data", "audio/mpeg") - result := convertContent(mcpContent) - - assert.Equal(t, "audio", result.Type) - assert.Equal(t, "base64-encoded-audio-data", result.Data) - assert.Equal(t, "audio/mpeg", result.MimeType) - }) - - t.Run("converts audio content with empty mime type", func(t *testing.T) { - t.Parallel() - - mcpContent := mcp.NewAudioContent("audio-data", "") - result := convertContent(mcpContent) - - assert.Equal(t, "audio", result.Type) - assert.Equal(t, "audio-data", result.Data) - assert.Equal(t, "", result.MimeType) - }) - - // Note: We cannot easily test "unknown" content types because mcp.Content is an interface - // with an isContent() marker method. The MCP SDK only provides Text, Image, and Audio content types. - // If the SDK adds new content types in the future (e.g., embedded resources), convertContent - // will return Type: "unknown" for those until we add explicit support. -} diff --git a/pkg/vmcp/client/testhelpers_test.go b/pkg/vmcp/client/testhelpers_test.go deleted file mode 100644 index c5c1db2f34..0000000000 --- a/pkg/vmcp/client/testhelpers_test.go +++ /dev/null @@ -1,91 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package client - -import ( - "fmt" - - "github.com/mark3labs/mcp-go/mcp" -) - -// Helper functions to encapsulate conversion logic patterns - -// convertToolInputSchema simulates the conversion logic from client.go:138-151 -func convertToolInputSchema(schema mcp.ToolInputSchema) map[string]any { - inputSchema := map[string]any{ - "type": schema.Type, - } - if schema.Properties != nil { - inputSchema["properties"] = schema.Properties - } - if len(schema.Required) > 0 { - inputSchema["required"] = schema.Required - } - if schema.Defs != nil { - inputSchema["$defs"] = schema.Defs - } - return inputSchema -} - -// convertContentToMap simulates the conversion logic from conversion.ContentArrayToMap -// This test helper converts MCP SDK content types to a map for testing. -// Audio content is intentionally ignored (not supported for template substitution). -func convertContentToMap(contents []mcp.Content) map[string]any { - resultMap := make(map[string]any) - textIndex := 0 - imageIndex := 0 - for _, content := range contents { - if textContent, ok := mcp.AsTextContent(content); ok { - key := "text" - if textIndex > 0 { - key = fmt.Sprintf("text_%d", textIndex) - } - resultMap[key] = textContent.Text - textIndex++ - } else if imageContent, ok := mcp.AsImageContent(content); ok { - key := fmt.Sprintf("image_%d", imageIndex) - resultMap[key] = imageContent.Data - imageIndex++ - } - // Audio content is ignored (matches conversion.ContentArrayToMap behavior) - // Resource content is handled separately, not in this map - } - return resultMap -} - -// convertResourceContents simulates the conversion logic from client.go:276-289 -func convertResourceContents(contents []mcp.ResourceContents) []byte { - var data []byte - for _, content := range contents { - if textContent, ok := mcp.AsTextResourceContents(content); ok { - data = append(data, []byte(textContent.Text)...) - } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { - data = append(data, []byte(blobContent.Blob)...) - } - } - return data -} - -// convertPromptMessages simulates the conversion logic from client.go:315-327 -func convertPromptMessages(messages []mcp.PromptMessage) string { - var prompt string - for _, msg := range messages { - if msg.Role != "" { - prompt += "[" + string(msg.Role) + "] " - } - if textContent, ok := mcp.AsTextContent(msg.Content); ok { - prompt += textContent.Text + "\n" - } - } - return prompt -} - -// convertPromptArguments simulates the conversion logic from client.go:306-309 -func convertPromptArguments(arguments map[string]any) map[string]string { - stringArgs := make(map[string]string) - for k, v := range arguments { - stringArgs[k] = fmt.Sprintf("%v", v) - } - return stringArgs -} diff --git a/pkg/vmcp/conversion/content.go b/pkg/vmcp/conversion/content.go index 45c9a6f2dd..b4f5a6612d 100644 --- a/pkg/vmcp/conversion/content.go +++ b/pkg/vmcp/conversion/content.go @@ -6,11 +6,120 @@ package conversion import ( + "encoding/base64" + "encoding/json" "fmt" + "log/slog" + "strings" + + "github.com/mark3labs/mcp-go/mcp" "github.com/stacklok/toolhive/pkg/vmcp" ) +// ConvertMCPContent converts a single mcp.Content item to vmcp.Content. +// Unknown content types are returned as vmcp.Content{Type: "unknown"}. +func ConvertMCPContent(content mcp.Content) vmcp.Content { + if text, ok := mcp.AsTextContent(content); ok { + return vmcp.Content{Type: "text", Text: text.Text} + } + if img, ok := mcp.AsImageContent(content); ok { + return vmcp.Content{Type: "image", Data: img.Data, MimeType: img.MIMEType} + } + if audio, ok := mcp.AsAudioContent(content); ok { + return vmcp.Content{Type: "audio", Data: audio.Data, MimeType: audio.MIMEType} + } + slog.Debug("Encountered unknown MCP content type", "type", fmt.Sprintf("%T", content)) + return vmcp.Content{Type: "unknown"} +} + +// ConvertMCPContents converts a slice of mcp.Content to []vmcp.Content. +// Returns an empty (non-nil) slice for a nil or empty input. +func ConvertMCPContents(contents []mcp.Content) []vmcp.Content { + result := make([]vmcp.Content, 0, len(contents)) + for _, c := range contents { + result = append(result, ConvertMCPContent(c)) + } + return result +} + +// ConcatenateResourceContents concatenates all MCP resource content items into a +// single byte slice and returns the MIME type of the first item. +// +// MCP resources may return multiple content chunks (text or blob). Text chunks +// are appended as UTF-8 bytes; blob chunks are base64-decoded per the MCP spec. +// If base64 decoding fails, the malformed chunk is skipped and a warning is logged +// (appending raw base64 bytes would produce corrupted binary data). +// The MIME type is taken from the first content item; subsequent items are +// expected to share the same type (the MCP spec does not define per-chunk types). +func ConcatenateResourceContents(contents []mcp.ResourceContents) (data []byte, mimeType string) { + for i, content := range contents { + if textContent, ok := mcp.AsTextResourceContents(content); ok { + data = append(data, []byte(textContent.Text)...) + if i == 0 && textContent.MIMEType != "" { + mimeType = textContent.MIMEType + } + } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { + decoded, err := base64.StdEncoding.DecodeString(blobContent.Blob) + if err != nil { + slog.Warn("Skipping malformed base64 blob resource chunk; this chunk's data is lost", + "uri", blobContent.URI, "error", err) + continue + } + data = append(data, decoded...) + if i == 0 && blobContent.MIMEType != "" { + mimeType = blobContent.MIMEType + } + } + } + return data, mimeType +} + +// ConvertToolInputSchema converts a mcp.ToolInputSchema to map[string]any via a +// JSON round-trip, capturing all fields (type, properties, required, $defs, +// additionalProperties, etc.) without enumerating them manually. Falls back to +// {type: schema.Type} if marshalling fails. +func ConvertToolInputSchema(schema mcp.ToolInputSchema) map[string]any { + result := make(map[string]any) + b, err := json.Marshal(schema) + if err != nil { + return map[string]any{"type": schema.Type} + } + if err := json.Unmarshal(b, &result); err != nil { + return map[string]any{"type": schema.Type} + } + return result +} + +// ConvertPromptMessages flattens MCP prompt messages into a single string with +// the format "[role] text\n". Messages without a role omit the prefix. Only +// text content is included; non-text content is silently discarded (Phase 1 +// limitation — vmcp.PromptGetResult carries a flat string, not structured messages). +func ConvertPromptMessages(messages []mcp.PromptMessage) string { + var sb strings.Builder + for _, msg := range messages { + if msg.Role != "" { + fmt.Fprintf(&sb, "[%s] ", msg.Role) + } + if textContent, ok := mcp.AsTextContent(msg.Content); ok { + sb.WriteString(textContent.Text) + sb.WriteByte('\n') + } + } + return sb.String() +} + +// ConvertPromptArguments converts map[string]any to map[string]string by +// formatting each value with fmt.Sprintf("%v", v). Required by the MCP +// GetPrompt API which accepts only string-typed arguments. +func ConvertPromptArguments(arguments map[string]any) map[string]string { + result := make(map[string]string, len(arguments)) + for k, v := range arguments { + result[k] = fmt.Sprintf("%v", v) + } + return result +} + // ContentArrayToMap converts a vmcp.Content array to a map for template variable substitution. // This is used by composite tool workflows and backend result handling. // diff --git a/pkg/vmcp/conversion/conversion_test.go b/pkg/vmcp/conversion/conversion_test.go index 528127e279..1275ede97a 100644 --- a/pkg/vmcp/conversion/conversion_test.go +++ b/pkg/vmcp/conversion/conversion_test.go @@ -4,15 +4,317 @@ package conversion_test import ( + "encoding/base64" "testing" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/conversion" ) +func TestConvertToolInputSchema(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + schema mcp.ToolInputSchema + checks func(t *testing.T, got map[string]any) + }{ + { + name: "captures type, properties, required", + schema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "title": map[string]any{"type": "string"}, + }, + Required: []string{"title"}, + }, + checks: func(t *testing.T, got map[string]any) { + t.Helper() + assert.Equal(t, "object", got["type"]) + assert.Contains(t, got, "properties") + required, ok := got["required"].([]any) + require.True(t, ok) + assert.Equal(t, []any{"title"}, required) + }, + }, + { + name: "captures $defs", + schema: mcp.ToolInputSchema{ + Type: "object", + Defs: map[string]any{"Config": map[string]any{"type": "object"}}, + }, + checks: func(t *testing.T, got map[string]any) { + t.Helper() + assert.Contains(t, got, "$defs") + }, + }, + { + name: "nil required emitted as empty array by mcp-go", + schema: mcp.ToolInputSchema{Type: "object", Required: nil}, + checks: func(t *testing.T, got map[string]any) { + t.Helper() + required, ok := got["required"].([]any) + require.True(t, ok) + assert.Empty(t, required) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := conversion.ConvertToolInputSchema(tt.schema) + tt.checks(t, got) + }) + } +} + +func TestConvertPromptMessages(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + messages []mcp.PromptMessage + want string + }{ + { + name: "empty messages", + messages: nil, + want: "", + }, + { + name: "single message with role", + messages: []mcp.PromptMessage{ + {Role: "user", Content: mcp.NewTextContent("Hello")}, + }, + want: "[user] Hello\n", + }, + { + name: "message without role omits prefix", + messages: []mcp.PromptMessage{ + {Role: "", Content: mcp.NewTextContent("No role")}, + }, + want: "No role\n", + }, + { + name: "multiple messages concatenated", + messages: []mcp.PromptMessage{ + {Role: "system", Content: mcp.NewTextContent("You are helpful")}, + {Role: "user", Content: mcp.NewTextContent("Hi")}, + {Role: "assistant", Content: mcp.NewTextContent("Hello!")}, + }, + want: "[system] You are helpful\n[user] Hi\n[assistant] Hello!\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, conversion.ConvertPromptMessages(tt.messages)) + }) + } +} + +func TestConvertPromptArguments(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + arguments map[string]any + want map[string]string + }{ + { + name: "nil map returns empty map", + arguments: nil, + want: map[string]string{}, + }, + { + name: "string values pass through unchanged", + arguments: map[string]any{"key": "value"}, + want: map[string]string{"key": "value"}, + }, + { + name: "non-string values are formatted", + arguments: map[string]any{ + "int": 42, + "bool": true, + "float": 3.14, + "nil": nil, + }, + want: map[string]string{ + "int": "42", + "bool": "true", + "float": "3.14", + "nil": "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, conversion.ConvertPromptArguments(tt.arguments)) + }) + } +} + +func TestConvertMCPContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input mcp.Content + want vmcp.Content + }{ + { + name: "text content", + input: mcp.NewTextContent("hello world"), + want: vmcp.Content{Type: "text", Text: "hello world"}, + }, + { + name: "image content", + input: mcp.NewImageContent("base64imgdata", "image/png"), + want: vmcp.Content{Type: "image", Data: "base64imgdata", MimeType: "image/png"}, + }, + { + name: "audio content", + input: mcp.NewAudioContent("base64audiodata", "audio/mpeg"), + want: vmcp.Content{Type: "audio", Data: "base64audiodata", MimeType: "audio/mpeg"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := conversion.ConvertMCPContent(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestConvertMCPContents(t *testing.T) { + t.Parallel() + + t.Run("nil slice returns empty slice", func(t *testing.T) { + t.Parallel() + got := conversion.ConvertMCPContents(nil) + assert.Empty(t, got) + }) + + t.Run("empty slice returns empty slice", func(t *testing.T) { + t.Parallel() + got := conversion.ConvertMCPContents([]mcp.Content{}) + assert.Empty(t, got) + }) + + t.Run("mixed content types are all converted", func(t *testing.T) { + t.Parallel() + input := []mcp.Content{ + mcp.NewTextContent("first"), + mcp.NewImageContent("imgdata", "image/jpeg"), + mcp.NewAudioContent("audiodata", "audio/ogg"), + } + want := []vmcp.Content{ + {Type: "text", Text: "first"}, + {Type: "image", Data: "imgdata", MimeType: "image/jpeg"}, + {Type: "audio", Data: "audiodata", MimeType: "audio/ogg"}, + } + got := conversion.ConvertMCPContents(input) + assert.Equal(t, want, got) + }) + + t.Run("order is preserved", func(t *testing.T) { + t.Parallel() + input := []mcp.Content{ + mcp.NewTextContent("a"), + mcp.NewTextContent("b"), + mcp.NewTextContent("c"), + } + got := conversion.ConvertMCPContents(input) + require.Len(t, got, 3) + assert.Equal(t, "a", got[0].Text) + assert.Equal(t, "b", got[1].Text) + assert.Equal(t, "c", got[2].Text) + }) +} + +func TestConcatenateResourceContents(t *testing.T) { + t.Parallel() + + rawText := "hello resource" + blobBytes := []byte("binary data") + blobEncoded := base64.StdEncoding.EncodeToString(blobBytes) + + tests := []struct { + name string + contents []mcp.ResourceContents + wantData []byte + wantMimeType string + }{ + { + name: "empty contents", + contents: nil, + wantData: nil, + }, + { + name: "single text item", + contents: []mcp.ResourceContents{ + mcp.TextResourceContents{URI: "file://a", MIMEType: "text/plain", Text: rawText}, + }, + wantData: []byte(rawText), + wantMimeType: "text/plain", + }, + { + name: "single blob item decoded", + contents: []mcp.ResourceContents{ + mcp.BlobResourceContents{URI: "file://b", MIMEType: "application/octet-stream", Blob: blobEncoded}, + }, + wantData: blobBytes, + wantMimeType: "application/octet-stream", + }, + { + name: "multiple text chunks concatenated", + contents: []mcp.ResourceContents{ + mcp.TextResourceContents{URI: "file://c", MIMEType: "text/plain", Text: "part1"}, + mcp.TextResourceContents{URI: "file://c", Text: "part2"}, + }, + wantData: []byte("part1part2"), + wantMimeType: "text/plain", + }, + { + name: "mime type taken from first item only", + contents: []mcp.ResourceContents{ + mcp.TextResourceContents{URI: "file://d", MIMEType: "text/html", Text: "a"}, + mcp.TextResourceContents{URI: "file://d", MIMEType: "text/plain", Text: "b"}, + }, + wantData: []byte("ab"), + wantMimeType: "text/html", + }, + { + name: "invalid base64 blob chunk is skipped", + contents: []mcp.ResourceContents{ + mcp.BlobResourceContents{URI: "file://e", Blob: "not-valid-base64!!!"}, + }, + // Malformed base64 is skipped entirely; appending raw bytes would produce + // corrupted binary data, so we prefer an empty result over corrupted data. + wantData: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + data, mimeType := conversion.ConcatenateResourceContents(tt.contents) + assert.Equal(t, tt.wantData, data) + assert.Equal(t, tt.wantMimeType, mimeType) + }) + } +} + func TestContentArrayToMap(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/session/connector_integration_test.go b/pkg/vmcp/session/connector_integration_test.go new file mode 100644 index 0000000000..ed1632751b --- /dev/null +++ b/pkg/vmcp/session/connector_integration_test.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + mcpmcp "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" +) + +// startInProcessMCPServer creates a real in-process MCP server over +// streamable-HTTP and returns its base URL. The server is shut down when the +// test ends via t.Cleanup. +// +// The server exposes: +// - tool "echo": returns the "input" argument as text content +// - resource "test://data": returns the static text "hello" +// - prompt "greet": returns a greeting message +func startInProcessMCPServer(t *testing.T) string { + t.Helper() + + mcpSrv := mcpserver.NewMCPServer("integration-test-backend", "1.0.0") + + mcpSrv.AddTool( + mcpmcp.NewTool("echo", + mcpmcp.WithDescription("Echoes the input back"), + mcpmcp.WithString("input", mcpmcp.Required()), + ), + func(_ context.Context, req mcpmcp.CallToolRequest) (*mcpmcp.CallToolResult, error) { + args, _ := req.Params.Arguments.(map[string]any) + input, _ := args["input"].(string) + return &mcpmcp.CallToolResult{ + Content: []mcpmcp.Content{mcpmcp.NewTextContent(input)}, + }, nil + }, + ) + + mcpSrv.AddResource( + mcpmcp.Resource{ + URI: "test://data", + Name: "Test Data", + MIMEType: "text/plain", + }, + func(_ context.Context, _ mcpmcp.ReadResourceRequest) ([]mcpmcp.ResourceContents, error) { + return []mcpmcp.ResourceContents{ + mcpmcp.TextResourceContents{URI: "test://data", MIMEType: "text/plain", Text: "hello"}, + }, nil + }, + ) + + mcpSrv.AddPrompt( + mcpmcp.NewPrompt("greet", + mcpmcp.WithPromptDescription("Returns a greeting"), + ), + func(_ context.Context, _ mcpmcp.GetPromptRequest) (*mcpmcp.GetPromptResult, error) { + return &mcpmcp.GetPromptResult{ + Messages: []mcpmcp.PromptMessage{ + {Role: "user", Content: mcpmcp.NewTextContent("Hello!")}, + }, + }, nil + }, + ) + + streamableSrv := mcpserver.NewStreamableHTTPServer(mcpSrv) + mux := http.NewServeMux() + mux.Handle("/mcp", streamableSrv) + + ts := httptest.NewServer(mux) + t.Cleanup(ts.Close) + + return ts.URL + "/mcp" +} + +// newUnauthenticatedRegistry returns a minimal OutgoingAuthRegistry that +// uses the unauthenticated (no-op) strategy — suitable for tests where the +// backend MCP server does not require auth. +func newUnauthenticatedRegistry(t *testing.T) vmcpauth.OutgoingAuthRegistry { + t.Helper() + reg := vmcpauth.NewDefaultOutgoingAuthRegistry() + require.NoError(t, reg.RegisterStrategy(authtypes.StrategyTypeUnauthenticated, strategies.NewUnauthenticatedStrategy())) + return reg +} + +// --------------------------------------------------------------------------- +// Integration tests — exercise the real HTTP connector +// --------------------------------------------------------------------------- + +func TestSessionFactory_Integration_CapabilityDiscovery(t *testing.T) { + t.Parallel() + + baseURL := startInProcessMCPServer(t) + backend := &vmcp.Backend{ + ID: "integration-backend", + Name: "integration-backend", + BaseURL: baseURL, + TransportType: "streamable-http", + } + + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + require.NotNil(t, sess) + t.Cleanup(func() { require.NoError(t, sess.Close()) }) + + // The real MCP Initialize + ListTools/Resources/Prompts handshake must + // have discovered all three capabilities. + require.Len(t, sess.Tools(), 1) + assert.Equal(t, "echo", sess.Tools()[0].Name) + + require.Len(t, sess.Resources(), 1) + assert.Equal(t, "test://data", sess.Resources()[0].URI) + + require.Len(t, sess.Prompts(), 1) + assert.Equal(t, "greet", sess.Prompts()[0].Name) +} + +func TestSessionFactory_Integration_CallTool(t *testing.T) { + t.Parallel() + + baseURL := startInProcessMCPServer(t) + backend := &vmcp.Backend{ + ID: "integration-backend", + Name: "integration-backend", + BaseURL: baseURL, + TransportType: "streamable-http", + } + + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, sess.Close()) }) + + result, err := sess.CallTool(context.Background(), "echo", map[string]any{"input": "hello world"}, nil) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + assert.Equal(t, "hello world", result.Content[0].Text) +} + +func TestSessionFactory_Integration_ReadResource(t *testing.T) { + t.Parallel() + + baseURL := startInProcessMCPServer(t) + backend := &vmcp.Backend{ + ID: "integration-backend", + Name: "integration-backend", + BaseURL: baseURL, + TransportType: "streamable-http", + } + + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, sess.Close()) }) + + result, err := sess.ReadResource(context.Background(), "test://data") + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "hello", string(result.Contents)) +} + +func TestSessionFactory_Integration_GetPrompt(t *testing.T) { + t.Parallel() + + baseURL := startInProcessMCPServer(t) + backend := &vmcp.Backend{ + ID: "integration-backend", + Name: "integration-backend", + BaseURL: baseURL, + TransportType: "streamable-http", + } + + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, sess.Close()) }) + + result, err := sess.GetPrompt(context.Background(), "greet", nil) + require.NoError(t, err) + require.NotNil(t, result) + // ConvertPromptMessages formats messages as "[role] text\n" + assert.Equal(t, "[user] Hello!\n", result.Messages) +} + +func TestSessionFactory_Integration_MultipleBackends(t *testing.T) { + t.Parallel() + + // Start two independent backends — each has its own "echo" tool. + // The factory must route each call to the correct backend after resolving + // the capability-name conflict (alphabetically-earlier backend wins). + url1 := startInProcessMCPServer(t) + url2 := startInProcessMCPServer(t) + + backends := []*vmcp.Backend{ + {ID: "backend-b", Name: "backend-b", BaseURL: url2, TransportType: "streamable-http"}, + {ID: "backend-a", Name: "backend-a", BaseURL: url1, TransportType: "streamable-http"}, + } + + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSession(context.Background(), nil, backends) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, sess.Close()) }) + + // Both backends expose "echo"; "backend-a" sorts first and must win. + require.Len(t, sess.Tools(), 1, "conflicting tool names collapse to one") + assert.Equal(t, "backend-a", sess.Tools()[0].BackendID) +} diff --git a/pkg/vmcp/session/default_session.go b/pkg/vmcp/session/default_session.go new file mode 100644 index 0000000000..f940290762 --- /dev/null +++ b/pkg/vmcp/session/default_session.go @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "errors" + "fmt" + "maps" + "sync" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" +) + +// Compile-time assertions: defaultMultiSession must implement both interfaces. +var _ MultiSession = (*defaultMultiSession)(nil) +var _ transportsession.Session = (*defaultMultiSession)(nil) + +// Sentinel errors returned by defaultMultiSession methods. +var ( + // ErrSessionClosed is returned when an operation is attempted on a closed session. + ErrSessionClosed = errors.New("session is closed") + + // ErrToolNotFound is returned when the requested tool is not in the routing table. + ErrToolNotFound = errors.New("tool not found in session routing table") + + // ErrResourceNotFound is returned when the requested resource is not in the routing table. + ErrResourceNotFound = errors.New("resource not found in session routing table") + + // ErrPromptNotFound is returned when the requested prompt is not in the routing table. + ErrPromptNotFound = errors.New("prompt not found in session routing table") + + // ErrNoBackendClient is returned when the routing table references a backend + // that has no entry in the connections map. This indicates an internal + // invariant violation: under normal operation MakeSession always populates + // both maps together, so this error should never be seen at runtime. + ErrNoBackendClient = errors.New("no client available for backend") +) + +// defaultMultiSession is the production MultiSession implementation. +// +// # Thread-safety model +// +// mu guards connections, closed, and the wg.Add call. RLock is held only +// long enough to retrieve state and atomically increment the in-flight counter +// (wg.Add); it is released before network I/O begins. +// routingTable, tools, resources, and prompts are written once during +// MakeSession and are read-only thereafter — they do not require lock protection. +// +// wg tracks in-flight operations. Close() sets closed=true under write lock, +// then waits for wg to reach zero before tearing down backend connections. +// Because wg.Add(1) always happens while the read lock is held (and before +// Close() acquires the write lock), there is no race between Close() and +// in-flight operations. +// +// # Lifecycle +// +// 1. Created by defaultMultiSessionFactory.MakeSession (Phase 1: purely additive). +// 2. CallTool / ReadResource / GetPrompt increment wg, perform I/O, decrement wg. +// 3. Close() sets closed=true, waits for wg, then closes all backend sessions. +// +// # Composite tools +// +// Composite tools (VirtualMCPCompositeToolDefinition) are out of scope for +// Phase 1. When they are introduced they will be resolved at a higher layer +// (e.g. the vMCP router or handler) and injected alongside the backend tool +// list, rather than being routed through the backend connections held here. +type defaultMultiSession struct { + transportsession.Session // embedded interface — provides ID, Type, timestamps, etc. + + connections map[string]backend.Session // backend workload ID → persistent backend session + routingTable *vmcp.RoutingTable + tools []vmcp.Tool + resources []vmcp.Resource + prompts []vmcp.Prompt + backendSessions map[string]string // backend workload ID → backend-assigned session ID + + mu sync.RWMutex + wg sync.WaitGroup + closed bool +} + +// Tools returns a snapshot copy of the tools available in this session. +func (s *defaultMultiSession) Tools() []vmcp.Tool { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]vmcp.Tool, len(s.tools)) + copy(result, s.tools) + return result +} + +// Resources returns a snapshot copy of the resources available in this session. +func (s *defaultMultiSession) Resources() []vmcp.Resource { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]vmcp.Resource, len(s.resources)) + copy(result, s.resources) + return result +} + +// Prompts returns a snapshot copy of the prompts available in this session. +func (s *defaultMultiSession) Prompts() []vmcp.Prompt { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]vmcp.Prompt, len(s.prompts)) + copy(result, s.prompts) + return result +} + +// BackendSessions returns a snapshot copy of backend-assigned session IDs. +func (s *defaultMultiSession) BackendSessions() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + result := make(map[string]string, len(s.backendSessions)) + maps.Copy(result, s.backendSessions) + return result +} + +// lookupBackend resolves capName against table and returns the live backend +// session for the backend that owns it. +// +// On success, wg.Add(1) has been called before the lock is released. The +// caller MUST call wg.Done() (typically via defer) when the I/O completes. +// On error, wg.Add was never called. +func (s *defaultMultiSession) lookupBackend( + capName string, + table map[string]*vmcp.BackendTarget, + notFoundErr error, +) (backend.Session, error) { + // Hold RLock to atomically check closed and register the in-flight + // operation. wg.Add(1) is called while the lock is held so that Close() + // cannot slip in between "check closed" and "add to wait group". + s.mu.RLock() + if s.closed { + s.mu.RUnlock() + return nil, ErrSessionClosed + } + target, ok := table[capName] + if !ok { + s.mu.RUnlock() + return nil, fmt.Errorf("%w: %q", notFoundErr, capName) + } + conn, ok := s.connections[target.WorkloadID] + if !ok { + s.mu.RUnlock() + return nil, fmt.Errorf("%w for backend %q", ErrNoBackendClient, target.WorkloadID) + } + s.wg.Add(1) // register before releasing the lock to avoid a race with Close() + s.mu.RUnlock() + return conn, nil +} + +// CallTool invokes toolName on the appropriate backend. +func (s *defaultMultiSession) CallTool( + ctx context.Context, + toolName string, + arguments map[string]any, + meta map[string]any, +) (*vmcp.ToolCallResult, error) { + conn, err := s.lookupBackend(toolName, s.routingTable.Tools, ErrToolNotFound) + if err != nil { + return nil, err + } + defer s.wg.Done() + return conn.CallTool(ctx, toolName, arguments, meta) +} + +// ReadResource retrieves the resource identified by uri. +func (s *defaultMultiSession) ReadResource(ctx context.Context, uri string) (*vmcp.ResourceReadResult, error) { + conn, err := s.lookupBackend(uri, s.routingTable.Resources, ErrResourceNotFound) + if err != nil { + return nil, err + } + defer s.wg.Done() + return conn.ReadResource(ctx, uri) +} + +// GetPrompt retrieves the named prompt from the appropriate backend. +func (s *defaultMultiSession) GetPrompt( + ctx context.Context, + name string, + arguments map[string]any, +) (*vmcp.PromptGetResult, error) { + conn, err := s.lookupBackend(name, s.routingTable.Prompts, ErrPromptNotFound) + if err != nil { + return nil, err + } + defer s.wg.Done() + return conn.GetPrompt(ctx, name, arguments) +} + +// Close releases all resources. It is idempotent: subsequent calls return nil +// without attempting to close backends again. +func (s *defaultMultiSession) Close() error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return nil + } + s.closed = true + s.mu.Unlock() + + // Wait for all in-flight operations to complete before tearing down clients. + // No new operations can start after this point because closed=true was set + // under the write lock, and callers check closed under the read lock. + s.wg.Wait() + + // s.connections is read without holding mu: closed=true prevents any new + // operation from starting, and wg.Wait() ensures all in-flight operations + // have finished. connections is only written during MakeSession (phase 1), + // so no concurrent writer exists at this point. + var errs []error + for id, conn := range s.connections { + if err := conn.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close backend %s: %w", id, err)) + } + } + return errors.Join(errs...) +} diff --git a/pkg/vmcp/session/default_session_test.go b/pkg/vmcp/session/default_session_test.go new file mode 100644 index 0000000000..05bb91472b --- /dev/null +++ b/pkg/vmcp/session/default_session_test.go @@ -0,0 +1,1076 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/auth" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" +) + +// --------------------------------------------------------------------------- +// Helpers / mocks +// --------------------------------------------------------------------------- + +// mockConnectedBackend is an in-process internalbk.Session for testing. +type mockConnectedBackend struct { + callToolFunc func(ctx context.Context, toolName string, arguments, meta map[string]any) (*vmcp.ToolCallResult, error) + readResourceFunc func(ctx context.Context, uri string) (*vmcp.ResourceReadResult, error) + getPromptFunc func(ctx context.Context, name string, arguments map[string]any) (*vmcp.PromptGetResult, error) + sessID string + closeCalled atomic.Bool + closeErr error +} + +func (m *mockConnectedBackend) CallTool(ctx context.Context, toolName string, arguments, meta map[string]any) (*vmcp.ToolCallResult, error) { + if m.callToolFunc != nil { + return m.callToolFunc(ctx, toolName, arguments, meta) + } + return &vmcp.ToolCallResult{Content: []vmcp.Content{{Type: "text", Text: "ok"}}}, nil +} + +func (m *mockConnectedBackend) ReadResource(ctx context.Context, uri string) (*vmcp.ResourceReadResult, error) { + if m.readResourceFunc != nil { + return m.readResourceFunc(ctx, uri) + } + return &vmcp.ResourceReadResult{Contents: []byte("data"), MimeType: "text/plain"}, nil +} + +func (m *mockConnectedBackend) GetPrompt(ctx context.Context, name string, arguments map[string]any) (*vmcp.PromptGetResult, error) { + if m.getPromptFunc != nil { + return m.getPromptFunc(ctx, name, arguments) + } + return &vmcp.PromptGetResult{Messages: "hello"}, nil +} + +func (m *mockConnectedBackend) SessionID() string { return m.sessID } +func (m *mockConnectedBackend) Close() error { + m.closeCalled.Store(true) + return m.closeErr +} + +// buildTestSession creates a defaultMultiSession wired with mock backends. +// +//nolint:unparam // backendID is intentionally a parameter for readability; callers consistently use "b1" +func buildTestSession( + t *testing.T, + backendID string, + conn internalbk.Session, + tools []vmcp.Tool, + resources []vmcp.Resource, + prompts []vmcp.Prompt, +) *defaultMultiSession { + t.Helper() + + target := &vmcp.BackendTarget{ + WorkloadID: backendID, + WorkloadName: backendID, + BaseURL: "http://localhost:9999", + } + + rt := &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: make(map[string]*vmcp.BackendTarget), + Prompts: make(map[string]*vmcp.BackendTarget), + } + for _, tool := range tools { + rt.Tools[tool.Name] = target + } + for _, res := range resources { + rt.Resources[res.URI] = target + } + for _, prompt := range prompts { + rt.Prompts[prompt.Name] = target + } + + return &defaultMultiSession{ + Session: transportsession.NewStreamableSession("test-session-id"), + connections: map[string]internalbk.Session{backendID: conn}, + routingTable: rt, + tools: tools, + resources: resources, + prompts: prompts, + backendSessions: map[string]string{backendID: "backend-session-abc"}, + } +} + +// --------------------------------------------------------------------------- +// Interface composition +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Tools / Resources / Prompts accessors +// --------------------------------------------------------------------------- + +func TestDefaultSession_Accessors(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{{Name: "search", BackendID: "b1"}} + resources := []vmcp.Resource{{URI: "file://readme", BackendID: "b1"}} + prompts := []vmcp.Prompt{{Name: "greet", BackendID: "b1"}} + + sess := buildTestSession(t, "b1", &mockConnectedBackend{}, tools, resources, prompts) + + assert.Equal(t, tools, sess.Tools()) + assert.Equal(t, resources, sess.Resources()) + assert.Equal(t, prompts, sess.Prompts()) + + bs := sess.BackendSessions() + assert.Equal(t, "backend-session-abc", bs["b1"]) + // Returned map is a copy — mutating it must not affect the session. + bs["b1"] = "mutated" + assert.Equal(t, "backend-session-abc", sess.BackendSessions()["b1"]) +} + +// --------------------------------------------------------------------------- +// CallTool +// --------------------------------------------------------------------------- + +func TestDefaultSession_CallTool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + toolName string + mockFn func(ctx context.Context, toolName string, arguments, meta map[string]any) (*vmcp.ToolCallResult, error) + wantErr bool + wantErrIs error + wantContent string + }{ + { + name: "successful tool call", + toolName: "search", + mockFn: func(_ context.Context, _ string, _, _ map[string]any) (*vmcp.ToolCallResult, error) { + return &vmcp.ToolCallResult{Content: []vmcp.Content{{Type: "text", Text: "result"}}}, nil + }, + wantContent: "result", + }, + { + name: "tool not in routing table", + toolName: "nonexistent", + wantErr: true, + wantErrIs: ErrToolNotFound, + }, + { + name: "backend returns error", + toolName: "search", + mockFn: func(_ context.Context, _ string, _, _ map[string]any) (*vmcp.ToolCallResult, error) { + return nil, errors.New("backend boom") + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mock := &mockConnectedBackend{callToolFunc: tt.mockFn} + sess := buildTestSession(t, "b1", mock, + []vmcp.Tool{{Name: "search", BackendID: "b1"}}, + nil, nil, + ) + + result, err := sess.CallTool(context.Background(), tt.toolName, nil, nil) + if tt.wantErr { + require.Error(t, err) + if tt.wantErrIs != nil { + assert.ErrorIs(t, err, tt.wantErrIs) + } + return + } + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.wantContent, result.Content[0].Text) + }) + } +} + +// --------------------------------------------------------------------------- +// ReadResource +// --------------------------------------------------------------------------- + +func TestDefaultSession_ReadResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + uri string + mockFn func(ctx context.Context, uri string) (*vmcp.ResourceReadResult, error) + wantErr bool + wantErrIs error + wantData string + }{ + { + name: "successful read", + uri: "file://readme", + mockFn: func(_ context.Context, _ string) (*vmcp.ResourceReadResult, error) { + return &vmcp.ResourceReadResult{Contents: []byte("hello"), MimeType: "text/plain"}, nil + }, + wantData: "hello", + }, + { + name: "resource not in routing table", + uri: "file://missing", + wantErr: true, + wantErrIs: ErrResourceNotFound, + }, + { + name: "backend returns error", + uri: "file://readme", + mockFn: func(_ context.Context, _ string) (*vmcp.ResourceReadResult, error) { + return nil, errors.New("backend boom") + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mock := &mockConnectedBackend{readResourceFunc: tt.mockFn} + sess := buildTestSession(t, "b1", mock, + nil, + []vmcp.Resource{{URI: "file://readme", BackendID: "b1"}}, + nil, + ) + + result, err := sess.ReadResource(context.Background(), tt.uri) + if tt.wantErr { + require.Error(t, err) + if tt.wantErrIs != nil { + assert.ErrorIs(t, err, tt.wantErrIs) + } + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantData, string(result.Contents)) + }) + } +} + +// --------------------------------------------------------------------------- +// GetPrompt +// --------------------------------------------------------------------------- + +func TestDefaultSession_GetPrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prompt string + mockFn func(ctx context.Context, name string, arguments map[string]any) (*vmcp.PromptGetResult, error) + wantErr bool + wantErrIs error + wantMsg string + }{ + { + name: "successful get", + prompt: "greet", + mockFn: func(_ context.Context, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return &vmcp.PromptGetResult{Messages: "hi there"}, nil + }, + wantMsg: "hi there", + }, + { + name: "prompt not in routing table", + prompt: "missing", + wantErr: true, + wantErrIs: ErrPromptNotFound, + }, + { + name: "backend error is propagated", + prompt: "greet", + mockFn: func(_ context.Context, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return nil, errors.New("backend unavailable") + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mock := &mockConnectedBackend{getPromptFunc: tt.mockFn} + sess := buildTestSession(t, "b1", mock, + nil, nil, + []vmcp.Prompt{{Name: "greet", BackendID: "b1"}}, + ) + + result, err := sess.GetPrompt(context.Background(), tt.prompt, nil) + if tt.wantErr { + require.Error(t, err) + if tt.wantErrIs != nil { + assert.ErrorIs(t, err, tt.wantErrIs) + } + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantMsg, result.Messages) + }) + } +} + +// --------------------------------------------------------------------------- +// Close +// --------------------------------------------------------------------------- + +func TestDefaultSession_Close(t *testing.T) { + t.Parallel() + + t.Run("closes all backend clients", func(t *testing.T) { + t.Parallel() + + mock := &mockConnectedBackend{} + sess := buildTestSession(t, "b1", mock, nil, nil, nil) + + require.NoError(t, sess.Close()) + assert.True(t, mock.closeCalled.Load()) + }) + + t.Run("idempotent", func(t *testing.T) { + t.Parallel() + + mock := &mockConnectedBackend{} + sess := buildTestSession(t, "b1", mock, nil, nil, nil) + + require.NoError(t, sess.Close()) + require.NoError(t, sess.Close()) // second call must not panic or error + }) + + t.Run("waits for in-flight ops before closing clients", func(t *testing.T) { + t.Parallel() + + callInProgress := make(chan struct{}) + callRelease := make(chan struct{}) + + mock := &mockConnectedBackend{ + callToolFunc: func(_ context.Context, _ string, _, _ map[string]any) (*vmcp.ToolCallResult, error) { + close(callInProgress) + <-callRelease + return &vmcp.ToolCallResult{}, nil + }, + } + sess := buildTestSession(t, "b1", mock, + []vmcp.Tool{{Name: "slow"}}, nil, nil, + ) + + var callDone atomic.Bool + go func() { + _, _ = sess.CallTool(context.Background(), "slow", nil, nil) + callDone.Store(true) + }() + + // Wait until the call is actually in progress. + <-callInProgress + + closeDone := make(chan error, 1) + go func() { + closeDone <- sess.Close() + }() + + // Close must not return until the call completes. + select { + case <-closeDone: + t.Fatal("Close returned before in-flight call finished") + case <-time.After(50 * time.Millisecond): + // Expected: Close is blocking. + } + + close(callRelease) // let the call finish + require.NoError(t, <-closeDone) + assert.True(t, callDone.Load()) + assert.True(t, mock.closeCalled.Load()) + }) + + t.Run("returns joined error when a client fails to close", func(t *testing.T) { + t.Parallel() + + closeErr := errors.New("close failed") + mock := &mockConnectedBackend{closeErr: closeErr} + sess := buildTestSession(t, "b1", mock, nil, nil, nil) + + err := sess.Close() + require.Error(t, err) + assert.ErrorContains(t, err, "close failed") + }) + + t.Run("operations after close return ErrSessionClosed", func(t *testing.T) { + t.Parallel() + + mock := &mockConnectedBackend{} + sess := buildTestSession(t, "b1", mock, + []vmcp.Tool{{Name: "search"}}, + []vmcp.Resource{{URI: "file://x"}}, + []vmcp.Prompt{{Name: "greet"}}, + ) + require.NoError(t, sess.Close()) + + _, err := sess.CallTool(context.Background(), "search", nil, nil) + assert.ErrorIs(t, err, ErrSessionClosed) + + _, err = sess.ReadResource(context.Background(), "file://x") + assert.ErrorIs(t, err, ErrSessionClosed) + + _, err = sess.GetPrompt(context.Background(), "greet", nil) + assert.ErrorIs(t, err, ErrSessionClosed) + }) +} + +func TestDefaultSession_ErrNoBackendClient(t *testing.T) { + t.Parallel() + + // Build a session where the routing table points to backend "b1" but the + // connections map has no entry for it. This exercises the ErrNoBackendClient + // path in CallTool, ReadResource, and GetPrompt. + target := &vmcp.BackendTarget{WorkloadID: "b1"} + sess := &defaultMultiSession{ + Session: transportsession.NewStreamableSession("test-no-client"), + connections: map[string]internalbk.Session{}, // deliberately empty + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{"search": target}, + Resources: map[string]*vmcp.BackendTarget{"file://readme": target}, + Prompts: map[string]*vmcp.BackendTarget{"greet": target}, + }, + tools: []vmcp.Tool{{Name: "search", BackendID: "b1"}}, + resources: []vmcp.Resource{{URI: "file://readme", BackendID: "b1"}}, + prompts: []vmcp.Prompt{{Name: "greet", BackendID: "b1"}}, + backendSessions: map[string]string{}, + } + defer func() { _ = sess.Close() }() + + _, err := sess.CallTool(context.Background(), "search", nil, nil) + require.ErrorIs(t, err, ErrNoBackendClient) + + _, err = sess.ReadResource(context.Background(), "file://readme") + require.ErrorIs(t, err, ErrNoBackendClient) + + _, err = sess.GetPrompt(context.Background(), "greet", nil) + require.ErrorIs(t, err, ErrNoBackendClient) +} + +func TestDefaultSession_Close_AllBackendsAttemptedOnError(t *testing.T) { + t.Parallel() + + // Both backends return a close error. Verify that both are called (the + // error-collection loop must not short-circuit after the first failure). + b1 := &mockConnectedBackend{closeErr: errors.New("b1 close error")} + b2 := &mockConnectedBackend{closeErr: errors.New("b2 close error")} + + sess := &defaultMultiSession{ + Session: transportsession.NewStreamableSession("test-multi-close"), + connections: map[string]internalbk.Session{ + "b1": b1, + "b2": b2, + }, + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{}, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + backendSessions: map[string]string{}, + } + + err := sess.Close() + require.Error(t, err) + assert.True(t, b1.closeCalled.Load(), "b1.close must be called even though b2 also errors") + assert.True(t, b2.closeCalled.Load(), "b2.close must be called even though b1 also errors") + assert.ErrorContains(t, err, "b1 close error") + assert.ErrorContains(t, err, "b2 close error") +} + +// --------------------------------------------------------------------------- +// SessionFactory / MakeSession +// --------------------------------------------------------------------------- + +func TestNewSessionFactory_MakeSession(t *testing.T) { + t.Parallel() + + tool := vmcp.Tool{Name: "search", BackendID: "b1"} + resource := vmcp.Resource{URI: "file://readme", BackendID: "b1"} + prompt := vmcp.Prompt{Name: "greet", BackendID: "b1"} + + backend := &vmcp.Backend{ + ID: "b1", + Name: "backend-1", + BaseURL: "http://localhost:9999", + TransportType: "streamable-http", + } + + //nolint:unparam // second return is always nil by design in the success-path connector + successConnector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return &mockConnectedBackend{sessID: "bs-1"}, &vmcp.CapabilityList{ + Tools: []vmcp.Tool{tool}, + Resources: []vmcp.Resource{resource}, + Prompts: []vmcp.Prompt{prompt}, + }, nil + } + + t.Run("creates session with backend capabilities", func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(successConnector) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + require.NotNil(t, sess) + + assert.NotEmpty(t, sess.ID()) + assert.Equal(t, transportsession.SessionTypeStreamable, sess.Type()) + assert.Len(t, sess.Tools(), 1) + assert.Len(t, sess.Resources(), 1) + assert.Len(t, sess.Prompts(), 1) + assert.Equal(t, "bs-1", sess.BackendSessions()["b1"]) + + require.NoError(t, sess.Close()) + }) + + t.Run("each session gets a unique ID", func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(successConnector) + s1, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + s2, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + + assert.NotEqual(t, s1.ID(), s2.ID()) + + require.NoError(t, s1.Close()) + require.NoError(t, s2.Close()) + }) + + t.Run("no backends produces empty session", func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(successConnector) + sess, err := factory.MakeSession(context.Background(), nil, nil) + require.NoError(t, err) + require.NotNil(t, sess) + + assert.Empty(t, sess.Tools()) + assert.Empty(t, sess.Resources()) + assert.Empty(t, sess.Prompts()) + require.NoError(t, sess.Close()) + }) + + t.Run("nil backend entries are skipped without panic", func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(successConnector) + // Mix of valid and nil entries; nil must not cause a panic. + backends := []*vmcp.Backend{nil, backend, nil} + sess, err := factory.MakeSession(context.Background(), nil, backends) + require.NoError(t, err) + require.NotNil(t, sess) + + // The one valid backend should still have been initialised. + assert.Len(t, sess.Tools(), 1) + require.NoError(t, sess.Close()) + }) +} + +func TestNewSessionFactory_PartialInitialisation(t *testing.T) { + t.Parallel() + + backends := []*vmcp.Backend{ + {ID: "ok", Name: "ok", BaseURL: "http://ok:9999", TransportType: "streamable-http"}, + {ID: "fail", Name: "fail", BaseURL: "http://fail:9999", TransportType: "streamable-http"}, + } + + connector := func(_ context.Context, target *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + if target.WorkloadID == "fail" { + return nil, nil, errors.New("backend unavailable") + } + return &mockConnectedBackend{sessID: "s-ok"}, &vmcp.CapabilityList{ + Tools: []vmcp.Tool{{Name: "tool-ok", BackendID: "ok"}}, + }, nil + } + + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSession(context.Background(), nil, backends) + require.NoError(t, err, "partial init must not return an error") + require.NotNil(t, sess) + + // Only the successful backend's capabilities are present. + assert.Len(t, sess.Tools(), 1) + assert.Equal(t, "tool-ok", sess.Tools()[0].Name) + assert.NotContains(t, sess.BackendSessions(), "fail") + + require.NoError(t, sess.Close()) +} + +func TestNewSessionFactory_ConnectorReturnsNilWithoutError(t *testing.T) { + t.Parallel() + + backend := &vmcp.Backend{ID: "b1", Name: "b1", BaseURL: "http://x:9", TransportType: "streamable-http"} + + tests := []struct { + name string + connector backendConnector + wantConnClose bool // true when the connector returns a non-nil conn that must be closed + }{ + { + name: "nil conn with nil caps", + connector: func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + }, + }, + { + name: "nil conn with non-nil caps", + connector: func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, &vmcp.CapabilityList{}, nil + }, + }, + { + name: "non-nil conn with nil caps must close conn to avoid leak", + wantConnClose: true, + connector: func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return &mockConnectedBackend{}, nil, nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Replace the connector with one that captures the mock so we can + // inspect closeCalled after MakeSession returns. + var captured *mockConnectedBackend + wrappedConnector := func(ctx context.Context, target *vmcp.BackendTarget, id *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + conn, caps, err := tt.connector(ctx, target, id) + if m, ok := conn.(*mockConnectedBackend); ok { + captured = m + } + return conn, caps, err + } + + factory := newSessionFactoryWithConnector(wrappedConnector) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err) + require.NotNil(t, sess) + assert.Empty(t, sess.Tools()) + require.NoError(t, sess.Close()) + + if tt.wantConnClose { + require.NotNil(t, captured, "expected connector to return a mock conn") + assert.True(t, captured.closeCalled.Load(), "leaked connection was not closed") + } + }) + } +} + +func TestNewSessionFactory_ConnectorReturnsConnWithError(t *testing.T) { + t.Parallel() + + // Connector returns a non-nil conn alongside an error — the conn must be + // closed to avoid a connection leak. + backend := &vmcp.Backend{ID: "b1", Name: "b1", BaseURL: "http://x:9", TransportType: "streamable-http"} + leaked := &mockConnectedBackend{} + + connector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return leaked, nil, errors.New("init failed but conn was partially opened") + } + + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err, "partial failure must not abort the session") + require.NotNil(t, sess) + assert.Empty(t, sess.Tools()) + require.NoError(t, sess.Close()) + + assert.True(t, leaked.closeCalled.Load(), "leaked connection was not closed") +} + +func TestNewSessionFactory_CapabilityNameConflictIsResolvedDeterministically(t *testing.T) { + t.Parallel() + + // Both backends advertise the same tool, resource, and prompt name. + // "alpha" sorts before "zeta" alphabetically, so "alpha" must always win. + backends := []*vmcp.Backend{ + // Intentionally listed in reverse order to prove sorting is applied. + {ID: "zeta", Name: "zeta", BaseURL: "http://zeta:9", TransportType: "streamable-http"}, + {ID: "alpha", Name: "alpha", BaseURL: "http://alpha:9", TransportType: "streamable-http"}, + } + + connector := func(_ context.Context, target *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return &mockConnectedBackend{sessID: target.WorkloadID}, &vmcp.CapabilityList{ + Tools: []vmcp.Tool{{Name: "fetch", BackendID: target.WorkloadID}}, + Resources: []vmcp.Resource{{URI: "file://data", BackendID: target.WorkloadID}}, + Prompts: []vmcp.Prompt{{Name: "greet", BackendID: target.WorkloadID}}, + }, nil + } + + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSession(context.Background(), nil, backends) + require.NoError(t, err) + require.NotNil(t, sess) + defer func() { require.NoError(t, sess.Close()) }() + + // Each capability should appear exactly once (no duplicates). + require.Len(t, sess.Tools(), 1) + require.Len(t, sess.Resources(), 1) + require.Len(t, sess.Prompts(), 1) + + // "alpha" must win because it sorts before "zeta". + assert.Equal(t, "alpha", sess.Tools()[0].BackendID) + assert.Equal(t, "alpha", sess.Resources()[0].BackendID) + assert.Equal(t, "alpha", sess.Prompts()[0].BackendID) + + // Calling the conflicted tool must reach "alpha", not "zeta". + result, err := sess.CallTool(context.Background(), "fetch", nil, nil) + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestNewSessionFactory_AllBackendsFail(t *testing.T) { + t.Parallel() + + backend := &vmcp.Backend{ID: "b1", Name: "b1", BaseURL: "http://x:9", TransportType: "streamable-http"} + connector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, errors.New("down") + } + + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err, "all-fail must still return a valid (empty) session") + require.NotNil(t, sess) + + assert.Empty(t, sess.Tools()) + require.NoError(t, sess.Close()) +} + +func TestNewSessionFactory_BackendInitTimeout(t *testing.T) { + t.Parallel() + + backend := &vmcp.Backend{ID: "slow", Name: "slow", BaseURL: "http://x:9", TransportType: "streamable-http"} + + released := make(chan struct{}) + connector := func(ctx context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + case <-released: + return &mockConnectedBackend{}, &vmcp.CapabilityList{}, nil + } + } + + factory := newSessionFactoryWithConnector(connector, WithBackendInitTimeout(50*time.Millisecond)) + sess, err := factory.MakeSession(context.Background(), nil, []*vmcp.Backend{backend}) + require.NoError(t, err, "timeout is a partial failure, not a hard error") + require.NotNil(t, sess) + + // Timed-out backend produces no capabilities. + assert.Empty(t, sess.Tools()) + close(released) // allow goroutine to unblock + require.NoError(t, sess.Close()) +} + +func TestNewSessionFactory_ParallelInit(t *testing.T) { + t.Parallel() + + const numBackends = 5 + backends := make([]*vmcp.Backend, numBackends) + for i := range backends { + backends[i] = &vmcp.Backend{ + ID: fmt.Sprintf("b%d", i), + Name: fmt.Sprintf("b%d", i), + BaseURL: "http://x:9", + TransportType: "streamable-http", + } + } + + var initCount atomic.Int32 + var mu sync.Mutex + var maxConcurrent, current int32 + + connector := func(_ context.Context, target *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + mu.Lock() + current++ + if current > maxConcurrent { + maxConcurrent = current + } + mu.Unlock() + + time.Sleep(10 * time.Millisecond) // simulate network latency + initCount.Add(1) + + mu.Lock() + current-- + mu.Unlock() + + return &mockConnectedBackend{sessID: target.WorkloadID}, &vmcp.CapabilityList{ + Tools: []vmcp.Tool{{Name: "t-" + target.WorkloadID, BackendID: target.WorkloadID}}, + }, nil + } + + factory := newSessionFactoryWithConnector(connector, WithMaxBackendInitConcurrency(3)) + sess, err := factory.MakeSession(context.Background(), nil, backends) + require.NoError(t, err) + + // All backends must have been initialised. + assert.Equal(t, int32(numBackends), initCount.Load()) + assert.Len(t, sess.Tools(), numBackends) + + // Concurrency limit must have been respected. + assert.LessOrEqual(t, maxConcurrent, int32(3)) + + require.NoError(t, sess.Close()) +} + +func TestNewSessionFactory_MakeSession_Metadata(t *testing.T) { + t.Parallel() + + backend1 := &vmcp.Backend{ID: "b1", Name: "backend-1", BaseURL: "http://localhost:9001", TransportType: "streamable-http"} + backend2 := &vmcp.Backend{ID: "b2", Name: "backend-2", BaseURL: "http://localhost:9002", TransportType: "streamable-http"} + + //nolint:unparam // error return is always nil by design in the success-path connector + successConnector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return &mockConnectedBackend{}, &vmcp.CapabilityList{}, nil + } + failConnector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, errors.New("connection refused") + } + + tests := []struct { + name string + connector backendConnector + identity *auth.Identity + backends []*vmcp.Backend + wantSubject string // non-empty → assert equal; empty → assert key absent + wantBackendIDs string // non-empty → assert equal; empty → assert key absent + }{ + { + name: "sets identity subject and backend IDs", + connector: successConnector, + identity: &auth.Identity{Subject: "user-123"}, + backends: []*vmcp.Backend{backend1}, + wantSubject: "user-123", + wantBackendIDs: "b1", + }, + { + name: "omits subject when identity is nil", + connector: successConnector, + identity: nil, + backends: []*vmcp.Backend{backend1}, + wantBackendIDs: "b1", + }, + { + name: "omits subject when subject is empty", + connector: successConnector, + identity: &auth.Identity{Subject: ""}, + backends: []*vmcp.Backend{backend1}, + wantBackendIDs: "b1", + }, + { + name: "backend IDs are sorted", + connector: successConnector, + backends: []*vmcp.Backend{backend2, backend1}, // intentionally reversed + wantBackendIDs: "b1,b2", + }, + { + name: "omits backend IDs when no backends connect", + connector: failConnector, + backends: []*vmcp.Backend{backend1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(tt.connector) + sess, err := factory.MakeSession(context.Background(), tt.identity, tt.backends) + require.NoError(t, err) + require.NotNil(t, sess) + defer func() { require.NoError(t, sess.Close()) }() + + meta := sess.GetMetadata() + + if tt.wantSubject != "" { + assert.Equal(t, tt.wantSubject, meta[MetadataKeyIdentitySubject]) + } else { + _, ok := meta[MetadataKeyIdentitySubject] + assert.False(t, ok, "identity subject key should be absent") + } + + if tt.wantBackendIDs != "" { + assert.Equal(t, tt.wantBackendIDs, meta[MetadataKeyBackendIDs]) + } else { + _, ok := meta[MetadataKeyBackendIDs] + assert.False(t, ok, "backend IDs key should be absent") + } + }) + } +} + +// --------------------------------------------------------------------------- +// buildRoutingTable +// --------------------------------------------------------------------------- + +func TestBuildRoutingTable(t *testing.T) { + t.Parallel() + + target := func(id string) *vmcp.BackendTarget { + return &vmcp.BackendTarget{WorkloadID: id, WorkloadName: id} + } + + tests := []struct { + name string + results []initResult + wantTools []string // expected tool names in order + wantResources []string // expected resource URIs in order + wantPrompts []string // expected prompt names in order + // When a capability appears in multiple backends, wantWinner[capName] is + // the expected winning WorkloadID. + wantWinner map[string]string + }{ + { + name: "empty input", + results: nil, + wantTools: nil, + wantResources: nil, + wantPrompts: nil, + }, + { + name: "single backend all capability types", + results: []initResult{ + { + target: target("a"), + caps: &vmcp.CapabilityList{ + Tools: []vmcp.Tool{{Name: "t1"}, {Name: "t2"}}, + Resources: []vmcp.Resource{{URI: "res://1"}, {URI: "res://2"}}, + Prompts: []vmcp.Prompt{{Name: "p1"}}, + }, + }, + }, + wantTools: []string{"t1", "t2"}, + wantResources: []string{"res://1", "res://2"}, + wantPrompts: []string{"p1"}, + }, + { + name: "conflict resolution: first backend in sorted order wins", + results: []initResult{ + // Pre-sorted: "alpha" before "zeta" + { + target: target("alpha"), + caps: &vmcp.CapabilityList{ + Tools: []vmcp.Tool{{Name: "shared"}}, + }, + }, + { + target: target("zeta"), + caps: &vmcp.CapabilityList{ + Tools: []vmcp.Tool{{Name: "shared"}}, + }, + }, + }, + wantTools: []string{"shared"}, + wantWinner: map[string]string{"shared": "alpha"}, + }, + { + name: "non-conflicting capabilities from two backends are merged", + results: []initResult{ + { + target: target("a"), + caps: &vmcp.CapabilityList{Tools: []vmcp.Tool{{Name: "t-a"}}}, + }, + { + target: target("b"), + caps: &vmcp.CapabilityList{Tools: []vmcp.Tool{{Name: "t-b"}}}, + }, + }, + wantTools: []string{"t-a", "t-b"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + rt, tools, resources, prompts := buildRoutingTable(tt.results) + require.NotNil(t, rt) + + // Check list lengths and names. + toolNames := make([]string, len(tools)) + for i, t := range tools { + toolNames[i] = t.Name + } + if tt.wantTools == nil { + assert.Empty(t, tools) + } else { + assert.Equal(t, tt.wantTools, toolNames) + } + + resURIs := make([]string, len(resources)) + for i, r := range resources { + resURIs[i] = r.URI + } + if tt.wantResources == nil { + assert.Empty(t, resources) + } else { + assert.Equal(t, tt.wantResources, resURIs) + } + + promptNames := make([]string, len(prompts)) + for i, p := range prompts { + promptNames[i] = p.Name + } + if tt.wantPrompts == nil { + assert.Empty(t, prompts) + } else { + assert.Equal(t, tt.wantPrompts, promptNames) + } + + // Check conflict winners. + for capName, wantBackend := range tt.wantWinner { + if got, ok := rt.Tools[capName]; ok { + assert.Equal(t, wantBackend, got.WorkloadID, "tool %q winner", capName) + } else if got, ok := rt.Resources[capName]; ok { + assert.Equal(t, wantBackend, got.WorkloadID, "resource %q winner", capName) + } else if got, ok := rt.Prompts[capName]; ok { + assert.Equal(t, wantBackend, got.WorkloadID, "prompt %q winner", capName) + } else { + t.Errorf("capability %q not found in any routing table", capName) + } + } + }) + } +} + +func TestWithMaxBackendInitConcurrency_IgnoresNonPositive(t *testing.T) { + t.Parallel() + + f := &defaultMultiSessionFactory{maxConcurrency: defaultMaxBackendInitConcurrency} + WithMaxBackendInitConcurrency(0)(f) + assert.Equal(t, defaultMaxBackendInitConcurrency, f.maxConcurrency) + + WithMaxBackendInitConcurrency(-5)(f) + assert.Equal(t, defaultMaxBackendInitConcurrency, f.maxConcurrency) +} + +func TestWithBackendInitTimeout_IgnoresNonPositive(t *testing.T) { + t.Parallel() + + f := &defaultMultiSessionFactory{backendInitTimeout: defaultBackendInitTimeout} + WithBackendInitTimeout(0)(f) + assert.Equal(t, defaultBackendInitTimeout, f.backendInitTimeout) + + WithBackendInitTimeout(-time.Second)(f) + assert.Equal(t, defaultBackendInitTimeout, f.backendInitTimeout) +} diff --git a/pkg/vmcp/session/factory.go b/pkg/vmcp/session/factory.go new file mode 100644 index 0000000000..fa60143644 --- /dev/null +++ b/pkg/vmcp/session/factory.go @@ -0,0 +1,296 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "log/slog" + "sort" + "strings" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/stacklok/toolhive/pkg/auth" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" +) + +const ( + defaultMaxBackendInitConcurrency = 10 + defaultBackendInitTimeout = 30 * time.Second + + // MetadataKeyIdentitySubject is the transport-session metadata key that + // holds the subject claim of the authenticated caller (identity.Subject). + // Set at session creation; empty for anonymous callers. + MetadataKeyIdentitySubject = "vmcp.identity.subject" + + // MetadataKeyBackendIDs is the transport-session metadata key that holds + // a comma-separated, sorted list of successfully-connected backend IDs. + // The key is omitted entirely when no backends connected. + MetadataKeyBackendIDs = "vmcp.backend.ids" +) + +// MultiSessionFactory creates new MultiSessions for connecting clients. +type MultiSessionFactory interface { + // MakeSession creates a new MultiSession for the given identity against the + // provided set of backends. Backend clients are initialised in parallel + // with bounded concurrency (see WithMaxBackendInitConcurrency). + // + // Partial initialisation: if a backend fails to initialise, a warning is + // logged and the session continues with the remaining backends. The caller + // receives a valid session as long as at least one backend succeeded. + // + // If all backends fail, MakeSession still returns a valid (empty) session + // rather than an error, allowing clients to connect even when all backends + // are temporarily unavailable. + MakeSession(ctx context.Context, identity *auth.Identity, backends []*vmcp.Backend) (MultiSession, error) +} + +// backendConnector creates a connected, initialised backend Session for use +// within a single MultiSession. It is called once per backend during MakeSession. +// +// The connector is responsible for: +// 1. Creating and starting the MCP client transport. +// 2. Running the MCP Initialize handshake. +// 3. Querying backend capabilities (tools, resources, prompts). +// +// The returned backend.Session owns the underlying transport connection and +// must be closed when the session ends. The returned CapabilityList is used +// to populate the session's routing table and capability lists. +// +// On error the factory treats the failure as a partial failure: a warning is +// logged and the backend is excluded from the session. +type backendConnector func( + ctx context.Context, + target *vmcp.BackendTarget, + identity *auth.Identity, +) (backend.Session, *vmcp.CapabilityList, error) + +// defaultMultiSessionFactory is the production MultiSessionFactory implementation. +type defaultMultiSessionFactory struct { + connector backendConnector + maxConcurrency int + backendInitTimeout time.Duration +} + +// MultiSessionFactoryOption configures a defaultMultiSessionFactory. +type MultiSessionFactoryOption func(*defaultMultiSessionFactory) + +// WithMaxBackendInitConcurrency sets the maximum number of backends that are +// initialised concurrently during MakeSession. Defaults to 10. +func WithMaxBackendInitConcurrency(n int) MultiSessionFactoryOption { + return func(f *defaultMultiSessionFactory) { + if n > 0 { + f.maxConcurrency = n + } + } +} + +// WithBackendInitTimeout sets the per-backend timeout during MakeSession. +// Defaults to 30 s. +func WithBackendInitTimeout(d time.Duration) MultiSessionFactoryOption { + return func(f *defaultMultiSessionFactory) { + if d > 0 { + f.backendInitTimeout = d + } + } +} + +// NewSessionFactory creates a MultiSessionFactory that connects to backends +// over HTTP using the given outgoing auth registry. +func NewSessionFactory(registry vmcpauth.OutgoingAuthRegistry, opts ...MultiSessionFactoryOption) MultiSessionFactory { + return newSessionFactoryWithConnector(backend.NewHTTPConnector(registry), opts...) +} + +// newSessionFactoryWithConnector creates a MultiSessionFactory backed by an +// arbitrary connector. Used by tests to inject a fake connector without +// requiring real HTTP backends. +func newSessionFactoryWithConnector(connector backendConnector, opts ...MultiSessionFactoryOption) MultiSessionFactory { + f := &defaultMultiSessionFactory{ + connector: connector, + maxConcurrency: defaultMaxBackendInitConcurrency, + backendInitTimeout: defaultBackendInitTimeout, + } + for _, opt := range opts { + opt(f) + } + return f +} + +// initResult captures the outcome of initialising a single backend. +type initResult struct { + target *vmcp.BackendTarget + conn backend.Session + caps *vmcp.CapabilityList +} + +// initOneBackend attempts to connect and initialise a single backend. +// It is called from a goroutine inside MakeSession and handles all partial- +// initialisation cases: connector errors, and nil conn/caps without an error. +// Returns a non-nil *initResult on success, nil when the backend should be +// skipped (failure already logged as a warning). +func (f *defaultMultiSessionFactory) initOneBackend( + ctx context.Context, + b *vmcp.Backend, + identity *auth.Identity, +) *initResult { + bCtx, cancel := context.WithTimeout(ctx, f.backendInitTimeout) + defer cancel() + + target := vmcp.BackendToTarget(b) + conn, caps, err := f.connector(bCtx, target, identity) + if err != nil { + if conn != nil { + _ = conn.Close() + } + slog.Warn("Failed to initialise backend for session; continuing without it", + "backendID", b.ID, + "backendName", b.Name, + "error", err, + ) + return nil + } + if conn == nil || caps == nil { + if conn != nil { + _ = conn.Close() + } + slog.Warn("Backend connector returned nil conn or caps with no error; skipping backend", + "backendID", b.ID, + "backendName", b.Name, + ) + return nil + } + return &initResult{target: target, conn: conn, caps: caps} +} + +// buildRoutingTable populates a RoutingTable and capability lists from a sorted +// slice of initResults. Results must be pre-sorted by WorkloadID so that the +// alphabetically-earlier backend wins when two backends share a capability name. +func buildRoutingTable(results []initResult) (*vmcp.RoutingTable, []vmcp.Tool, []vmcp.Resource, []vmcp.Prompt) { + rt := &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: make(map[string]*vmcp.BackendTarget), + Prompts: make(map[string]*vmcp.BackendTarget), + } + var tools []vmcp.Tool + var resources []vmcp.Resource + var prompts []vmcp.Prompt + + for _, r := range results { + for _, tool := range r.caps.Tools { + if _, ok := rt.Tools[tool.Name]; !ok { + tools = append(tools, tool) + rt.Tools[tool.Name] = r.target + } + } + for _, res := range r.caps.Resources { + if _, ok := rt.Resources[res.URI]; !ok { + resources = append(resources, res) + rt.Resources[res.URI] = r.target + } + } + for _, prompt := range r.caps.Prompts { + if _, ok := rt.Prompts[prompt.Name]; !ok { + prompts = append(prompts, prompt) + rt.Prompts[prompt.Name] = r.target + } + } + } + return rt, tools, resources, prompts +} + +// MakeSession implements MultiSessionFactory. +func (f *defaultMultiSessionFactory) MakeSession( + ctx context.Context, + identity *auth.Identity, + backends []*vmcp.Backend, +) (MultiSession, error) { + // Filter nil entries upfront so that every downstream dereference of a + // *vmcp.Backend is safe. Nil entries are logged and skipped, consistent + // with the partial-initialisation approach used for failed backends. + filtered := make([]*vmcp.Backend, 0, len(backends)) + for _, b := range backends { + if b == nil { + slog.Warn("Skipping nil backend entry during session creation") + continue + } + filtered = append(filtered, b) + } + backends = filtered + + // Initialise backends in parallel with bounded concurrency. + // Each goroutine writes to its own index so no lock on the slice is needed. + rawResults := make([]*initResult, len(backends)) + sem := make(chan struct{}, f.maxConcurrency) + var wg sync.WaitGroup + wg.Add(len(backends)) + for i, b := range backends { + go func(i int, b *vmcp.Backend) { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + rawResults[i] = f.initOneBackend(ctx, b, identity) + }(i, b) + } + wg.Wait() + + // Collect successful results; sort by WorkloadID so that capability-name + // conflicts are resolved deterministically: the alphabetically-earlier + // backend always wins. + connections := make(map[string]backend.Session, len(backends)) + backendSessions := make(map[string]string, len(backends)) + results := make([]initResult, 0, len(backends)) + for _, r := range rawResults { + if r == nil { + continue + } + connections[r.target.WorkloadID] = r.conn + backendSessions[r.target.WorkloadID] = r.conn.SessionID() + results = append(results, *r) + } + sort.Slice(results, func(i, j int) bool { + return results[i].target.WorkloadID < results[j].target.WorkloadID + }) + + if len(results) == 0 && len(backends) > 0 { + slog.Warn("All backends failed to initialise; session will have no capabilities", + "backendCount", len(backends)) + } + + // Build the routing table; first-writer (alphabetically) wins on conflicts. + routingTable, allTools, allResources, allPrompts := buildRoutingTable(results) + + sessID := uuid.New().String() + transportSess := transportsession.NewStreamableSession(sessID) + + // Populate serialisable metadata so that the embedded transport session + // carries the identity reference and connected backend list when persisted + // via transportsession.Storage. + if identity != nil && identity.Subject != "" { + transportSess.SetMetadata(MetadataKeyIdentitySubject, identity.Subject) + } + if len(results) > 0 { + // IDs are extracted from the already-sorted results slice to avoid a + // second sort of the connections map. + ids := make([]string, len(results)) + for i, r := range results { + ids[i] = r.target.WorkloadID + } + transportSess.SetMetadata(MetadataKeyBackendIDs, strings.Join(ids, ",")) + } + + return &defaultMultiSession{ + Session: transportSess, + connections: connections, + routingTable: routingTable, + tools: allTools, + resources: allResources, + prompts: allPrompts, + backendSessions: backendSessions, + }, nil +} diff --git a/pkg/vmcp/session/internal/backend/mcp_session.go b/pkg/vmcp/session/internal/backend/mcp_session.go new file mode 100644 index 0000000000..84ab63a919 --- /dev/null +++ b/pkg/vmcp/session/internal/backend/mcp_session.go @@ -0,0 +1,428 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backend + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + mcptransport "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/versions" + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" + "github.com/stacklok/toolhive/pkg/vmcp/conversion" +) + +const ( + // maxBackendResponseSize caps each HTTP response body for streamable-HTTP + // backends to prevent memory exhaustion. Not applied to SSE transports — + // see createMCPClient for the rationale. + maxBackendResponseSize = 100 * 1024 * 1024 // 100 MB + + // defaultBackendRequestTimeout is the wall-clock deadline for individual + // streamable-HTTP requests. Applied at both the http.Client and SDK layers + // (defense-in-depth). Not used for SSE, whose stream lifetime is unbounded. + defaultBackendRequestTimeout = 30 * time.Second +) + +// httpRoundTripperFunc adapts a plain function to http.RoundTripper. +type httpRoundTripperFunc func(*http.Request) (*http.Response, error) + +func (f httpRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } + +// authRoundTripper adds pre-resolved authentication to outgoing backend requests. +type authRoundTripper struct { + base http.RoundTripper + authStrategy vmcpauth.Strategy + authConfig *authtypes.BackendAuthStrategy + target *vmcp.BackendTarget +} + +func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + reqClone := req.Clone(req.Context()) + if err := a.authStrategy.Authenticate(reqClone.Context(), reqClone, a.authConfig); err != nil { + return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) + } + return a.base.RoundTrip(reqClone) +} + +// identityRoundTripper propagates the caller's identity to outgoing backend requests. +type identityRoundTripper struct { + base http.RoundTripper + identity *auth.Identity +} + +func (i *identityRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if i.identity != nil { + ctx := auth.WithIdentity(req.Context(), i.identity) + req = req.Clone(ctx) + } + return i.base.RoundTrip(req) +} + +// Compile-time assertion: mcpSession must implement Session. +var _ Session = (*mcpSession)(nil) + +// mcpSession wraps a persistent mark3labs MCP client for one backend. +// It is created once per backend during MakeSession and closed when the session ends. +// +// Phase 1 limitation — no reconnection: if the underlying transport drops +// (network error, server restart, SSE stream EOF), all subsequent operations +// on this backend will fail with the transport error. The session must be +// closed and a new one created to reconnect. This affects SSE backends more +// visibly because SSE uses a single long-lived HTTP stream; streamable-HTTP +// backends open a new connection per request and are therefore more resilient. +type mcpSession struct { + client *mcpclient.Client + target *vmcp.BackendTarget // bound at creation; used for capability name translation + backendSessionID string // backend-assigned session ID (may be empty) +} + +// SessionID returns the backend-assigned session ID. +func (c *mcpSession) SessionID() string { return c.backendSessionID } + +// Close closes the underlying MCP client transport. +func (c *mcpSession) Close() error { return c.client.Close() } + +// CallTool invokes a named tool on this backend. +func (c *mcpSession) CallTool( + ctx context.Context, + toolName string, + arguments map[string]any, + meta map[string]any, +) (*vmcp.ToolCallResult, error) { + backendName := c.target.GetBackendCapabilityName(toolName) + if backendName != toolName { + slog.Debug("Translating tool name", "clientName", toolName, "backendName", backendName) + } + + result, err := c.client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: backendName, + Arguments: arguments, + Meta: conversion.ToMCPMeta(meta), + }, + }) + if err != nil { + return nil, fmt.Errorf("tool %q call failed on backend %s: %w", toolName, c.target.WorkloadID, err) + } + + contentArray := conversion.ConvertMCPContents(result.Content) + + var structuredContent map[string]any + if result.StructuredContent != nil { + if m, ok := result.StructuredContent.(map[string]any); ok { + structuredContent = m + } + } + if structuredContent == nil { + structuredContent = conversion.ContentArrayToMap(contentArray) + } + + return &vmcp.ToolCallResult{ + Content: contentArray, + StructuredContent: structuredContent, + IsError: result.IsError, + Meta: conversion.FromMCPMeta(result.Meta), + }, nil +} + +// ReadResource reads a resource from this backend. +func (c *mcpSession) ReadResource( + ctx context.Context, + uri string, +) (*vmcp.ResourceReadResult, error) { + backendURI := c.target.GetBackendCapabilityName(uri) + if backendURI != uri { + slog.Debug("Translating resource URI", "clientURI", uri, "backendURI", backendURI) + } + + result, err := c.client.ReadResource(ctx, mcp.ReadResourceRequest{ + Params: mcp.ReadResourceParams{URI: backendURI}, + }) + if err != nil { + return nil, fmt.Errorf("resource %q read failed on backend %s: %w", uri, c.target.WorkloadID, err) + } + + data, mimeType := conversion.ConcatenateResourceContents(result.Contents) + + return &vmcp.ResourceReadResult{ + Contents: data, + MimeType: mimeType, + Meta: conversion.FromMCPMeta(result.Meta), + }, nil +} + +// GetPrompt retrieves a prompt from this backend. +func (c *mcpSession) GetPrompt( + ctx context.Context, + name string, + arguments map[string]any, +) (*vmcp.PromptGetResult, error) { + backendName := c.target.GetBackendCapabilityName(name) + if backendName != name { + slog.Debug("Translating prompt name", "clientName", name, "backendName", backendName) + } + + stringArgs := conversion.ConvertPromptArguments(arguments) + + result, err := c.client.GetPrompt(ctx, mcp.GetPromptRequest{ + Params: mcp.GetPromptParams{ + Name: backendName, + Arguments: stringArgs, + }, + }) + if err != nil { + return nil, fmt.Errorf("prompt %q get failed on backend %s: %w", name, c.target.WorkloadID, err) + } + + // NOTE: ConvertPromptMessages is lossy — non-text content (images, audio) + // is discarded. Phase 1 limitation; see vmcp.PromptGetResult. + return &vmcp.PromptGetResult{ + Messages: conversion.ConvertPromptMessages(result.Messages), + Description: result.Description, + Meta: conversion.FromMCPMeta(result.Meta), + }, nil +} + +// NewHTTPConnector returns a function that creates an HTTP-based (streamable-HTTP +// or SSE) persistent backend Session for each backend. +// +// registry provides the authentication strategy for outgoing backend requests. +// Pass a registry configured with the "unauthenticated" strategy to disable auth. +func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func( + ctx context.Context, + target *vmcp.BackendTarget, + identity *auth.Identity, +) (Session, *vmcp.CapabilityList, error) { + return func( + ctx context.Context, + target *vmcp.BackendTarget, + identity *auth.Identity, + ) (Session, *vmcp.CapabilityList, error) { + c, err := createMCPClient(target, identity, registry) + if err != nil { + return nil, nil, fmt.Errorf("failed to create MCP client for backend %s: %w", target.WorkloadID, err) + } + + caps, err := initAndQueryCapabilities(ctx, c, target) + if err != nil { + _ = c.Close() + return nil, nil, fmt.Errorf("failed to initialise backend %s: %w", target.WorkloadID, err) + } + + // Extract the backend-assigned session ID when the transport supports it. + // Streamable-HTTP servers send an Mcp-Session-Id response header during + // Initialize; the mark3labs transport captures it internally and exposes + // it via GetSessionId(). SSE transports do not assign a session ID, so + // the field remains empty for those backends. + var backendSessionID string + if sh, ok := c.GetTransport().(*mcptransport.StreamableHTTP); ok { + backendSessionID = sh.GetSessionId() + } + + return &mcpSession{client: c, target: target, backendSessionID: backendSessionID}, caps, nil + } +} + +// createMCPClient builds and starts a mark3labs MCP client for target. +// The transport is started with context.Background() so its lifetime is bound +// to client.Close(), not to any caller-supplied init context. +func createMCPClient( + target *vmcp.BackendTarget, + identity *auth.Identity, + registry vmcpauth.OutgoingAuthRegistry, +) (*mcpclient.Client, error) { + // Resolve and validate the auth strategy once at client creation time. + strategyName := authtypes.StrategyTypeUnauthenticated + if target.AuthConfig != nil { + strategyName = target.AuthConfig.Type + } + strategy, err := registry.GetStrategy(strategyName) + if err != nil { + return nil, fmt.Errorf("auth strategy %q not found: %w", strategyName, err) + } + if err := strategy.Validate(target.AuthConfig); err != nil { + return nil, fmt.Errorf("invalid auth config for backend %s: %w", target.WorkloadID, err) + } + + slog.Debug("Applied authentication strategy", "strategy", strategy.Name(), "backendID", target.WorkloadID) + + // Build shared transport chain: auth → identity propagation. + // The per-transport sections below may add a size-limiting wrapper on top. + base := http.RoundTripper(http.DefaultTransport) + base = &authRoundTripper{ + base: base, + authStrategy: strategy, + authConfig: target.AuthConfig, + target: target, + } + base = &identityRoundTripper{base: base, identity: identity} + + var c *mcpclient.Client + switch target.TransportType { + case "streamable-http", "streamable": + // "streamable" is a legacy alias for "streamable-http". + // + // For streamable-HTTP, each MCP call is a single bounded HTTP + // request/response pair, so a per-response body size limit is safe and + // correct. http.Client.Timeout provides a hard wall-clock deadline; + // WithHTTPTimeout additionally wraps each SDK request in a + // context.WithTimeout so the mark3labs transport surfaces a descriptive + // error before the stdlib deadline fires. Both are set to + // defaultBackendRequestTimeout: defense-in-depth. + sizeLimited := httpRoundTripperFunc(func(req *http.Request) (*http.Response, error) { + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(resp.Body, maxBackendResponseSize), + Closer: resp.Body, + } + return resp, nil + }) + httpClient := &http.Client{ + Transport: sizeLimited, + Timeout: defaultBackendRequestTimeout, + } + c, err = mcpclient.NewStreamableHttpClient( + target.BaseURL, + mcptransport.WithHTTPTimeout(defaultBackendRequestTimeout), + mcptransport.WithHTTPBasicClient(httpClient), + ) + case "sse": + // For SSE, the entire session is delivered as one long-lived HTTP + // response body. Applying io.LimitReader to that body would silently + // terminate the connection after maxBackendResponseSize cumulative bytes + // — not per-event — which is wrong. Individual event size is bounded by + // the backend; operation deadlines are enforced via context cancellation. + // + // http.Client.Timeout is also omitted: it caps the full round-trip + // including body reads, which would kill the stream after the timeout. + httpClient := &http.Client{Transport: base} + c, err = mcpclient.NewSSEMCPClient( + target.BaseURL, + mcptransport.WithHTTPClient(httpClient), + ) + default: + return nil, fmt.Errorf("%w: %s (supported: streamable-http, sse)", + vmcp.ErrUnsupportedTransport, target.TransportType) + } + if err != nil { + return nil, fmt.Errorf("failed to create %s client: %w", target.TransportType, err) + } + + // Start the transport with context.Background() so that the transport's + // lifetime is scoped to the session (terminated by client.Close()) rather + // than to the per-backend init timeout context. The init timeout context + // is used only for the Initialize handshake and capability queries in + // initAndQueryCapabilities, both of which have bounded duration. + // Without this, the SSE transport would tear down its persistent read + // goroutine when the init goroutine's defer-cancel fires after init completes. + if err := c.Start(context.Background()); err != nil { + return nil, fmt.Errorf("failed to start client: %w", err) + } + + return c, nil +} + +// initAndQueryCapabilities runs the MCP Initialize handshake then discovers +// all capabilities (tools, resources, prompts) from the backend. +func initAndQueryCapabilities( + ctx context.Context, + c *mcpclient.Client, + target *vmcp.BackendTarget, +) (*vmcp.CapabilityList, error) { + result, err := c.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "toolhive-vmcp", + Version: versions.Version, + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("initialize failed: %w", err) + } + + serverCaps := result.Capabilities + caps := &vmcp.CapabilityList{} + + if serverCaps.Tools != nil { + toolsResult, listErr := c.ListTools(ctx, mcp.ListToolsRequest{}) + if listErr != nil { + return nil, fmt.Errorf("list tools failed: %w", listErr) + } + for _, t := range toolsResult.Tools { + caps.Tools = append(caps.Tools, vmcp.Tool{ + Name: t.Name, + Description: t.Description, + InputSchema: conversion.ConvertToolInputSchema(t.InputSchema), + BackendID: target.WorkloadID, + }) + } + } + + if serverCaps.Resources != nil { + resResult, listErr := c.ListResources(ctx, mcp.ListResourcesRequest{}) + if listErr != nil { + return nil, fmt.Errorf("list resources failed: %w", listErr) + } + for _, r := range resResult.Resources { + caps.Resources = append(caps.Resources, vmcp.Resource{ + URI: r.URI, + Name: r.Name, + Description: r.Description, + MimeType: r.MIMEType, + BackendID: target.WorkloadID, + }) + } + } + + if serverCaps.Prompts != nil { + promptsResult, listErr := c.ListPrompts(ctx, mcp.ListPromptsRequest{}) + if listErr != nil { + return nil, fmt.Errorf("list prompts failed: %w", listErr) + } + for _, p := range promptsResult.Prompts { + args := make([]vmcp.PromptArgument, len(p.Arguments)) + for j, a := range p.Arguments { + args[j] = vmcp.PromptArgument{ + Name: a.Name, + Description: a.Description, + Required: a.Required, + } + } + caps.Prompts = append(caps.Prompts, vmcp.Prompt{ + Name: p.Name, + Description: p.Description, + Arguments: args, + BackendID: target.WorkloadID, + }) + } + } + + slog.Debug("Backend capabilities", + "backendID", target.WorkloadID, + "tools", len(caps.Tools), + "resources", len(caps.Resources), + "prompts", len(caps.Prompts), + ) + + return caps, nil +} diff --git a/pkg/vmcp/session/internal/backend/mcp_session_test.go b/pkg/vmcp/session/internal/backend/mcp_session_test.go new file mode 100644 index 0000000000..77f3345465 --- /dev/null +++ b/pkg/vmcp/session/internal/backend/mcp_session_test.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backend + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" +) + +func newTestRegistry(t *testing.T) vmcpauth.OutgoingAuthRegistry { + t.Helper() + reg := vmcpauth.NewDefaultOutgoingAuthRegistry() + require.NoError(t, reg.RegisterStrategy( + authtypes.StrategyTypeUnauthenticated, + strategies.NewUnauthenticatedStrategy(), + )) + return reg +} + +func TestCreateMCPClient_UnsupportedTransport(t *testing.T) { + t.Parallel() + + unsupportedTypes := []string{"stdio", "grpc", "", "ws"} + for _, transport := range unsupportedTypes { + t.Run(transport, func(t *testing.T) { + t.Parallel() + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "test-backend", + BaseURL: "http://localhost:9999", + TransportType: transport, + } + + _, err := createMCPClient(target, nil, newTestRegistry(t)) + require.Error(t, err) + assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport, + "transport %q should return ErrUnsupportedTransport", transport) + }) + } +} diff --git a/pkg/vmcp/session/internal/backend/session.go b/pkg/vmcp/session/internal/backend/session.go new file mode 100644 index 0000000000..ffe82d9d2c --- /dev/null +++ b/pkg/vmcp/session/internal/backend/session.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package backend defines the Session interface for a single persistent +// backend connection and provides the HTTP-based implementation used in +// production. It is internal to pkg/vmcp/session. +package backend + +import ( + "context" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// Session abstracts a persistent, initialised MCP connection to a single +// backend server. It is created once per backend during session creation and +// reused for the lifetime of the parent MultiSession. +// +// Each Session is bound to exactly one backend at creation time — callers do +// not need to pass a routing target to individual method calls. +// +// Implementations must be safe for concurrent use. +type Session interface { + // CallTool invokes a named tool on this backend. + CallTool( + ctx context.Context, + toolName string, + arguments map[string]any, + meta map[string]any, + ) (*vmcp.ToolCallResult, error) + + // ReadResource reads a resource from this backend. + ReadResource(ctx context.Context, uri string) (*vmcp.ResourceReadResult, error) + + // GetPrompt retrieves a prompt from this backend. + GetPrompt( + ctx context.Context, + name string, + arguments map[string]any, + ) (*vmcp.PromptGetResult, error) + + // SessionID returns the backend-assigned session ID (if any). + // Returns "" if the backend did not assign a session ID. + SessionID() string + + // Close closes the underlying transport connection. + Close() error +} diff --git a/pkg/vmcp/session/session.go b/pkg/vmcp/session/session.go new file mode 100644 index 0000000000..be5ad82bbe --- /dev/null +++ b/pkg/vmcp/session/session.go @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// MultiSession is the vMCP domain session interface. It extends the +// transport-layer Session with behaviour: capability access and session-scoped +// backend routing across multiple backend connections. +// +// A MultiSession is a "session of sessions": each backend contributes its own +// persistent connection (see [backend.Session] in pkg/vmcp/session/internal/backend), +// and the MultiSession aggregates them behind a single routing table. +// +// # Distributed deployment note +// +// Because MCP clients cannot be serialised, horizontal scaling requires sticky +// sessions (session affinity at the load balancer). Without sticky sessions, a +// request routed to a different vMCP instance must recreate backend clients +// (one-time cost per re-route). This is a known trade-off documented in +// RFC THV-0038: https://github.com/stacklok/toolhive-rfcs/blob/main/rfcs/THV-0038-session-scoped-client-lifecycle.md +// +// # Dual-layer storage model +// +// A MultiSession separates two layers with different lifecycles: +// +// - Metadata layer (serialisable): session ID, timestamps, identity reference, +// backend ID list. Stored via the transportsession.Storage interface and +// can persist across restarts. +// +// - Runtime layer (non-serialisable): MCP client objects, routing table, +// capabilities, backend session ID map, closed flag. Lives only in-process. +// +// All session metadata goes through the same Storage interface — no parallel +// storage path is introduced. +type MultiSession interface { + transportsession.Session + + // Tools returns the resolved tools available in this session. + // The list is built once at session creation and is read-only thereafter. + Tools() []vmcp.Tool + + // Resources returns the resolved resources available in this session. + Resources() []vmcp.Resource + + // Prompts returns the resolved prompts available in this session. + Prompts() []vmcp.Prompt + + // BackendSessions returns a snapshot of the backend-assigned session IDs, + // keyed by backend workload ID. The backend session ID is assigned by the + // backend MCP server and is used to correlate vMCP sessions with backend + // sessions for debugging and auditing. + BackendSessions() map[string]string + + // CallTool invokes toolName on the appropriate backend for this session. + // The routing table is consulted to identify the backend; the + // session-scoped client for that backend is then used, avoiding + // per-request connection overhead. + // + // arguments contains the tool input parameters. + // meta contains protocol-level metadata (_meta) forwarded from the client. + CallTool( + ctx context.Context, + toolName string, + arguments map[string]any, + meta map[string]any, + ) (*vmcp.ToolCallResult, error) + + // ReadResource retrieves the resource identified by uri from the + // appropriate backend for this session. + ReadResource(ctx context.Context, uri string) (*vmcp.ResourceReadResult, error) + + // GetPrompt retrieves the named prompt from the appropriate backend for + // this session. + // + // arguments contains the prompt input parameters. + GetPrompt( + ctx context.Context, + name string, + arguments map[string]any, + ) (*vmcp.PromptGetResult, error) + + // Close releases all resources held by this session, including all + // backend client connections. It waits for any in-flight operations to + // complete before tearing down clients. + // + // Close is idempotent: calling it multiple times returns nil. + Close() error +} diff --git a/pkg/vmcp/session/vmcp_session.go b/pkg/vmcp/session/vmcp_session.go index 4b60208781..7236173bdd 100644 --- a/pkg/vmcp/session/vmcp_session.go +++ b/pkg/vmcp/session/vmcp_session.go @@ -31,6 +31,10 @@ var _ transportsession.Session = (*VMCPSession)(nil) // 2. Routing table and tools populated in AfterInitialize hook // 3. Retrieved by middleware on subsequent requests via type assertion // 4. Cleaned up automatically by session.Manager TTL worker +// +// TODO: VMCPSession is a transitional type. Once the server layer is wired to +// use [MultiSession] (Phase 2 of RFC THV-0038), VMCPSession will be removed. +// Tracked in https://github.com/stacklok/toolhive/issues/3865 type VMCPSession struct { *transportsession.StreamableSession routingTable *vmcp.RoutingTable