Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 38 additions & 129 deletions pkg/vmcp/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package client

import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
Expand All @@ -23,6 +22,7 @@ import (
"github.com/mark3labs/mcp-go/mcp"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/versions"
"github.com/stacklok/toolhive/pkg/vmcp"
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
Expand Down Expand Up @@ -187,44 +187,47 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
identity: identity,
}

// Add size limit layer for DoS protection
sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
resp, err := baseTransport.RoundTrip(req)
if err != nil {
return nil, err
}
// Wrap response body with size limit
resp.Body = struct {
io.Reader
io.Closer
}{
Reader: io.LimitReader(resp.Body, maxResponseSize),
Closer: resp.Body,
}
return resp, nil
})

// Create HTTP client with configured transport chain
// Set timeouts to prevent long-lived connections that require continuous listening
httpClient := &http.Client{
Transport: sizeLimitedTransport,
Timeout: 30 * time.Second, // Prevent hanging on connections
}

var c *client.Client

switch target.TransportType {
case "streamable-http", "streamable":
// "streamable" is a legacy alias for "streamable-http".
//
// For streamable-HTTP each MCP call is a single bounded HTTP
// request/response pair, so a per-response body size limit is safe.
sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
resp, err := baseTransport.RoundTrip(req)
if err != nil {
return nil, err
}
resp.Body = struct {
io.Reader
io.Closer
}{
Reader: io.LimitReader(resp.Body, maxResponseSize),
Closer: resp.Body,
}
return resp, nil
})
httpClient := &http.Client{
Transport: sizeLimitedTransport,
Timeout: 30 * time.Second,
}
c, err = client.NewStreamableHttpClient(
target.BaseURL,
transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0
transport.WithHTTPTimeout(30*time.Second),
transport.WithHTTPBasicClient(httpClient),
)
if err != nil {
return nil, fmt.Errorf("failed to create streamable-http client: %w", err)
}

case "sse":
// For SSE the entire session is one long-lived HTTP response body.
// Applying io.LimitReader would silently terminate the stream after
// maxResponseSize cumulative bytes — not per-event — which is wrong.
// http.Client.Timeout is also omitted: it would kill the stream.
httpClient := &http.Client{Transport: baseTransport}
c, err = client.NewSSEMCPClient(
target.BaseURL,
transport.WithHTTPClient(httpClient),
Expand Down Expand Up @@ -325,7 +328,7 @@ func initializeClient(ctx context.Context, c *client.Client) (*mcp.ServerCapabil
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "toolhive-vmcp",
Version: "0.1.0",
Version: versions.Version,
},
Capabilities: mcp.ClientCapabilities{
// Virtual MCP acts as a client to backends
Expand Down Expand Up @@ -382,36 +385,6 @@ func queryPrompts(ctx context.Context, c *client.Client, supported bool, backend
return &mcp.ListPromptsResult{Prompts: []mcp.Prompt{}}, nil
}

// convertContent maps a single mcp.Content value onto the vmcp.Content shape.
//
// Text content carries over its Text field; image and audio content carry over
// their Data and MIMEType fields. The checks run in that order (text, image,
// audio) and the first match wins.
//
// Any other content kind (embedded resources, for example, are not converted
// here) is logged at warn level with its dynamic Go type and returned as a
// vmcp.Content whose Type is "unknown" and which carries no payload.
func convertContent(content mcp.Content) vmcp.Content {
	var converted vmcp.Content

	if text, isText := mcp.AsTextContent(content); isText {
		converted.Type = "text"
		converted.Text = text.Text
		return converted
	}

	if image, isImage := mcp.AsImageContent(content); isImage {
		converted.Type = "image"
		converted.Data = image.Data
		converted.MimeType = image.MIMEType
		return converted
	}

	if audio, isAudio := mcp.AsAudioContent(content); isAudio {
		converted.Type = "audio"
		converted.Data = audio.Data
		converted.MimeType = audio.MIMEType
		return converted
	}

	// Nothing matched: record the concrete type so the gap is visible in logs,
	// then hand back a placeholder instead of dropping the entry.
	unknownType := fmt.Sprintf("%T", content)
	slog.Warn("encountered unknown content type, marking as unknown content",
		"type", unknownType)

	converted.Type = "unknown"
	return converted
}

// ListCapabilities queries a backend for its MCP capabilities.
// Returns tools, resources, and prompts exposed by the backend.
// Only queries capabilities that the server advertises during initialization.
Expand Down Expand Up @@ -467,25 +440,10 @@ func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.B

// Convert tools
for i, tool := range toolsResp.Tools {
// Convert ToolInputSchema to map[string]any
// The ToolInputSchema is a struct with Type, Properties, Required fields
inputSchema := map[string]any{
"type": tool.InputSchema.Type,
}
if tool.InputSchema.Properties != nil {
inputSchema["properties"] = tool.InputSchema.Properties
}
if len(tool.InputSchema.Required) > 0 {
inputSchema["required"] = tool.InputSchema.Required
}
if tool.InputSchema.Defs != nil {
inputSchema["$defs"] = tool.InputSchema.Defs
}

capabilities.Tools[i] = vmcp.Tool{
Name: tool.Name,
Description: tool.Description,
InputSchema: inputSchema,
InputSchema: conversion.ConvertToolInputSchema(tool.InputSchema),
BackendID: target.WorkloadID,
}
}
Expand Down Expand Up @@ -608,11 +566,8 @@ func (h *httpBackendClient) CallTool(
// Continue processing - we return the result with IsError flag and metadata preserved
}

// Convert MCP content to vmcp.Content array
contentArray := make([]vmcp.Content, len(result.Content))
for i, content := range result.Content {
contentArray[i] = convertContent(content)
}
// Convert MCP content to vmcp.Content array.
contentArray := conversion.ConvertMCPContents(result.Content)

// Check for structured content first (preferred for composite tool step chaining).
// StructuredContent allows templates to access nested fields directly via {{.steps.stepID.output.field}}.
Expand Down Expand Up @@ -683,33 +638,8 @@ func (h *httpBackendClient) ReadResource(
return nil, fmt.Errorf("resource read failed on backend %s: %w", target.WorkloadID, err)
}

// Concatenate all resource contents
// MCP resources can have multiple contents (text or blob)
var data []byte
var mimeType string
for i, content := range result.Contents {
// Try to convert to TextResourceContents
if textContent, ok := mcp.AsTextResourceContents(content); ok {
data = append(data, []byte(textContent.Text)...)
if i == 0 && textContent.MIMEType != "" {
mimeType = textContent.MIMEType
}
} else if blobContent, ok := mcp.AsBlobResourceContents(content); ok {
// Blob is base64-encoded per MCP spec, decode it to bytes
decoded, err := base64.StdEncoding.DecodeString(blobContent.Blob)
if err != nil {
slog.Warn("failed to decode base64 blob from resource",
"resource", uri, "backend", target.WorkloadID, "error", err)
// Append raw blob as fallback
data = append(data, []byte(blobContent.Blob)...)
} else {
data = append(data, decoded...)
}
if i == 0 && blobContent.MIMEType != "" {
mimeType = blobContent.MIMEType
}
}
}
// Concatenate all resource content items into a single byte slice.
data, mimeType := conversion.ConcatenateResourceContents(result.Contents)

// Extract _meta field from backend response
meta := conversion.FromMCPMeta(result.Meta)
Expand Down Expand Up @@ -756,11 +686,7 @@ func (h *httpBackendClient) GetPrompt(
slog.Debug("translating prompt name", "client_name", name, "backend_name", backendPromptName)
}

// Convert map[string]any to map[string]string
stringArgs := make(map[string]string)
for k, v := range arguments {
stringArgs[k] = fmt.Sprintf("%v", v)
}
stringArgs := conversion.ConvertPromptArguments(arguments)

result, err := c.GetPrompt(ctx, mcp.GetPromptRequest{
Params: mcp.GetPromptParams{
Expand All @@ -772,26 +698,9 @@ func (h *httpBackendClient) GetPrompt(
return nil, fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err)
}

// Concatenate all prompt messages into a single string
// MCP prompts return messages with role and content (Content interface)
var prompt string
for _, msg := range result.Messages {
if msg.Role != "" {
prompt += fmt.Sprintf("[%s] ", msg.Role)
}
// Try to convert content to TextContent
if textContent, ok := mcp.AsTextContent(msg.Content); ok {
prompt += textContent.Text + "\n"
}
// TODO: Handle other content types (image, audio, resource)
}

// Extract _meta field from backend response
meta := conversion.FromMCPMeta(result.Meta)

return &vmcp.PromptGetResult{
Messages: prompt,
Messages: conversion.ConvertPromptMessages(result.Messages),
Description: result.Description,
Meta: meta,
Meta: conversion.FromMCPMeta(result.Meta),
}, nil
}
Loading
Loading