diff --git a/go/plugins/mcp/client.go b/go/plugins/mcp/client.go index 732e6b9b67..1b05dfe1ec 100644 --- a/go/plugins/mcp/client.go +++ b/go/plugins/mcp/client.go @@ -75,7 +75,7 @@ type MCPClientOptions struct { type ServerRef struct { Client *client.Client Transport transport.Interface - Error string + Error error } // GenkitMCPClient represents a client for interacting with MCP servers. @@ -103,6 +103,10 @@ func NewGenkitMCPClient(options MCPClientOptions) (*GenkitMCPClient, error) { return nil, fmt.Errorf("failed to initialize MCP client: %w", err) } + if client.server.Error != nil { + return nil, client.server.Error + } + return client, nil } @@ -119,20 +123,24 @@ func (c *GenkitMCPClient) connect(options MCPClientOptions) error { // Create and configure transport transport, err := c.createTransport(options) if err != nil { + // no transport means no ability to create a server + c.server = &ServerRef{Error: err} return err } // Start the transport ctx := context.Background() if err := transport.Start(ctx); err != nil { - return fmt.Errorf("failed to start transport: %w", err) + wrappedErr := fmt.Errorf("failed to start transport: %w", err) + c.server = &ServerRef{Error: wrappedErr} + return wrappedErr } // Create MCP client mcpClient := client.NewClient(transport) // Initialize the client if not disabled - var serverError string + var serverError error if !options.Disabled { serverError = c.initializeClient(ctx, mcpClient, options.Version) } @@ -184,7 +192,7 @@ func (c *GenkitMCPClient) createTransport(options MCPClientOptions) (transport.I } // initializeClient initializes the MCP client connection -func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *client.Client, version string) string { +func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *client.Client, version string) error { initReq := mcp.InitializeRequest{ Params: struct { ProtocolVersion string `json:"protocolVersion"` @@ -202,10 +210,10 @@ func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *clien _, err := mcpClient.Initialize(ctx, initReq) if err != nil { - return err.Error() + return err } - return "" + return nil } // Name returns the client name @@ -239,7 +247,14 @@ func (c *GenkitMCPClient) Restart(ctx context.Context) error { if err := c.Disconnect(); err != nil { logger.FromContext(ctx).Warn("Error closing MCP transport during restart", "client", c.options.Name, "error", err) } - return c.connect(c.options) + + if err := c.connect(c.options); err != nil { + return err + } + if c.server.Error != nil { + return fmt.Errorf("failed to restart MCP client: %w", c.server.Error) + } + return nil } // Disconnect closes the connection to the MCP server diff --git a/go/plugins/mcp/prompts.go b/go/plugins/mcp/prompts.go index c0782aa429..594def7927 100644 --- a/go/plugins/mcp/prompts.go +++ b/go/plugins/mcp/prompts.go @@ -29,6 +29,9 @@ func (c *GenkitMCPClient) GetPrompt(ctx context.Context, g *genkit.Genkit, promp if !c.IsEnabled() || c.server == nil { return nil, fmt.Errorf("MCP client is disabled or not connected") } + if c.server.Error != nil { + return nil, fmt.Errorf("client is in error state: %w", c.server.Error) + } // Check if prompt already exists namespacedPromptName := c.GetPromptNameWithNamespace(promptName) @@ -109,6 +112,9 @@ func (c *GenkitMCPClient) GetActivePrompts(ctx context.Context) ([]mcp.Prompt, e if !c.IsEnabled() || c.server == nil { return nil, nil } + if c.server.Error != nil { + return nil, fmt.Errorf("client is in error state: %w", c.server.Error) + } // Get all MCP prompts return c.getPrompts(ctx) diff --git a/go/plugins/mcp/resources.go b/go/plugins/mcp/resources.go index 5715721b1e..7dd3661d5d 100644 --- a/go/plugins/mcp/resources.go +++ b/go/plugins/mcp/resources.go @@ -27,6 +27,9 @@ func (c *GenkitMCPClient) GetActiveResources(ctx context.Context) ([]ai.Resource if !c.IsEnabled() || c.server == nil { return nil, fmt.Errorf("MCP client is disabled or not connected") } + if c.server.Error != nil { + return nil, fmt.Errorf("client is in error state: %w", c.server.Error) + } var resources []ai.Resource @@ -147,12 +150,15 @@ func (c *GenkitMCPClient) readMCPResource(ctx context.Context, uri string) (ai.R if !c.IsEnabled() || c.server == nil { return ai.ResourceOutput{}, fmt.Errorf("MCP client is disabled or not connected") } + if c.server.Error != nil { + return ai.ResourceOutput{}, fmt.Errorf("client is in error state: %w", c.server.Error) + } // Create ReadResource request readReq := mcp.ReadResourceRequest{ Params: struct { - URI string `json:"uri"` - Arguments map[string]interface{} `json:"arguments,omitempty"` + URI string `json:"uri"` + Arguments map[string]any `json:"arguments,omitempty"` }{ URI: uri, Arguments: nil, diff --git a/go/plugins/mcp/tools.go b/go/plugins/mcp/tools.go index 931f619b0e..212b19db99 100644 --- a/go/plugins/mcp/tools.go +++ b/go/plugins/mcp/tools.go @@ -31,6 +31,9 @@ func (c *GenkitMCPClient) GetActiveTools(ctx context.Context, g *genkit.Genkit) if !c.IsEnabled() || c.server == nil { return nil, nil } + if c.server.Error != nil { + return nil, fmt.Errorf("client is in error state: %w", c.server.Error) + } // Get all MCP tools mcpTools, err := c.getTools(ctx) @@ -224,7 +227,6 @@ func executeToolCall(ctx context.Context, client *client.Client, toolName string } result, err := client.CallTool(ctx, callReq) - if err != nil { return nil, err }