Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions go/plugins/mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type MCPClientOptions struct {
type ServerRef struct {
Client *client.Client
Transport transport.Interface
Error string
Error error
}

// GenkitMCPClient represents a client for interacting with MCP servers.
Expand Down Expand Up @@ -103,6 +103,10 @@ func NewGenkitMCPClient(options MCPClientOptions) (*GenkitMCPClient, error) {
return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
}

if client.server.Error != nil {
return nil, client.server.Error
}

return client, nil
}

Expand All @@ -119,20 +123,24 @@ func (c *GenkitMCPClient) connect(options MCPClientOptions) error {
// Create and configure transport
transport, err := c.createTransport(options)
if err != nil {
// no transport means no ability to create a server
c.server = &ServerRef{Error: err}
return err
}

// Start the transport
ctx := context.Background()
if err := transport.Start(ctx); err != nil {
return fmt.Errorf("failed to start transport: %w", err)
wrappedErr := fmt.Errorf("failed to start transport: %w", err)
c.server = &ServerRef{Error: wrappedErr}
return wrappedErr
}

// Create MCP client
mcpClient := client.NewClient(transport)

// Initialize the client if not disabled
var serverError string
var serverError error
if !options.Disabled {
serverError = c.initializeClient(ctx, mcpClient, options.Version)
}
Expand Down Expand Up @@ -184,7 +192,7 @@ func (c *GenkitMCPClient) createTransport(options MCPClientOptions) (transport.I
}

// initializeClient initializes the MCP client connection
func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *client.Client, version string) string {
func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *client.Client, version string) error {
initReq := mcp.InitializeRequest{
Params: struct {
ProtocolVersion string `json:"protocolVersion"`
Expand All @@ -202,10 +210,10 @@ func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *clien

_, err := mcpClient.Initialize(ctx, initReq)
if err != nil {
return err.Error()
return err
}

return ""
return nil
}

// Name returns the client name
Expand Down Expand Up @@ -239,7 +247,14 @@ func (c *GenkitMCPClient) Restart(ctx context.Context) error {
if err := c.Disconnect(); err != nil {
logger.FromContext(ctx).Warn("Error closing MCP transport during restart", "client", c.options.Name, "error", err)
}
return c.connect(c.options)

if err := c.connect(c.options); err != nil {
return err
}
if c.server.Error != nil {
return fmt.Errorf("failed to restart MCP client: %w", c.server.Error)
}
return nil
}

// Disconnect closes the connection to the MCP server
Expand Down
6 changes: 6 additions & 0 deletions go/plugins/mcp/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ func (c *GenkitMCPClient) GetPrompt(ctx context.Context, g *genkit.Genkit, promp
if !c.IsEnabled() || c.server == nil {
return nil, fmt.Errorf("MCP client is disabled or not connected")
}
if c.server.Error != nil {
return nil, fmt.Errorf("client is in error state: %w", c.server.Error)
}

// Check if prompt already exists
namespacedPromptName := c.GetPromptNameWithNamespace(promptName)
Expand Down Expand Up @@ -109,6 +112,9 @@ func (c *GenkitMCPClient) GetActivePrompts(ctx context.Context) ([]mcp.Prompt, e
if !c.IsEnabled() || c.server == nil {
return nil, nil
}
if c.server.Error != nil {
return nil, fmt.Errorf("client is in error state: %w", c.server.Error)
}

// Get all MCP prompts
return c.getPrompts(ctx)
Expand Down
10 changes: 8 additions & 2 deletions go/plugins/mcp/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ func (c *GenkitMCPClient) GetActiveResources(ctx context.Context) ([]ai.Resource
if !c.IsEnabled() || c.server == nil {
return nil, fmt.Errorf("MCP client is disabled or not connected")
}
if c.server.Error != nil {
return nil, fmt.Errorf("client is in error state: %w", c.server.Error)
}

var resources []ai.Resource

Expand Down Expand Up @@ -147,12 +150,15 @@ func (c *GenkitMCPClient) readMCPResource(ctx context.Context, uri string) (ai.R
if !c.IsEnabled() || c.server == nil {
return ai.ResourceOutput{}, fmt.Errorf("MCP client is disabled or not connected")
}
if c.server.Error != nil {
return ai.ResourceOutput{}, fmt.Errorf("client is in error state: %w", c.server.Error)
}

// Create ReadResource request
readReq := mcp.ReadResourceRequest{
Params: struct {
URI string `json:"uri"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
URI string `json:"uri"`
Arguments map[string]any `json:"arguments,omitempty"`
}{
URI: uri,
Arguments: nil,
Expand Down
4 changes: 3 additions & 1 deletion go/plugins/mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ func (c *GenkitMCPClient) GetActiveTools(ctx context.Context, g *genkit.Genkit)
if !c.IsEnabled() || c.server == nil {
return nil, nil
}
if c.server.Error != nil {
return nil, fmt.Errorf("client is in error state: %w", c.server.Error)
}

// Get all MCP tools
mcpTools, err := c.getTools(ctx)
Expand Down Expand Up @@ -224,7 +227,6 @@ func executeToolCall(ctx context.Context, client *client.Client, toolName string
}

result, err := client.CallTool(ctx, callReq)

if err != nil {
return nil, err
}
Expand Down
Loading