Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions apidump_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/provider"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -95,7 +94,7 @@ func TestAPIDump(t *testing.T) {
dumpDir := t.TempDir()

recorderClient := &testutil.MockRecorder{}
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, testutil.NilMCPManager(), logger, nil, testTracer)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -230,7 +229,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
recorderClient := &testutil.MockRecorder{}
prov := tc.providerFunc(upstream.URL, dumpDir)
provs := []aibridge.Provider{prov}
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer)
require.NoError(t, err)

bridgeSrv := httptest.NewUnstartedServer(b)
Expand Down
256 changes: 80 additions & 176 deletions bridge_integration_test.go

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions circuit_breaker_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/coder/aibridge"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/metrics"
"github.com/coder/aibridge/provider"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -145,7 +144,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
bridge, err := aibridge.NewRequestBridge(ctx,
[]provider.Provider{prov},
&testutil.MockRecorder{},
mcp.NewServerProxyManager(nil, tracer),
testutil.NilMCPManager(),
logger,
metrics,
tracer,
Expand Down Expand Up @@ -318,7 +317,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
bridge, err := aibridge.NewRequestBridge(ctx,
[]provider.Provider{prov},
&testutil.MockRecorder{},
mcp.NewServerProxyManager(nil, tracer),
testutil.NilMCPManager(),
logger,
metrics,
tracer,
Expand Down Expand Up @@ -484,7 +483,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
bridge, err := aibridge.NewRequestBridge(ctx,
[]provider.Provider{prov},
&testutil.MockRecorder{},
mcp.NewServerProxyManager(nil, tracer),
testutil.NilMCPManager(),
logger,
metrics,
tracer,
Expand Down Expand Up @@ -622,7 +621,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
bridge, err := aibridge.NewRequestBridge(ctx,
[]provider.Provider{prov},
&testutil.MockRecorder{},
mcp.NewServerProxyManager(nil, tracer),
testutil.NilMCPManager(),
logger,
m,
tracer,
Expand Down
151 changes: 151 additions & 0 deletions internal/testutil/mockmcp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package testutil

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge/mcp"
"github.com/mark3labs/mcp-go/client/transport"
mcplib "github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
)

// MockToolName is the primary mock tool name used in MCP tests.
const MockToolName = "coder_list_workspaces"

// MockMCP wraps a real mcp.ServerProxier with test assertion helpers.
// Implements mcp.ServerProxier so it can be passed directly to NewRequestBridge.
type MockMCP struct {
mcp.ServerProxier
calls *callAccumulator
}

// GetCallsByTool returns recorded arguments for a given tool name.
func (m *MockMCP) GetCallsByTool(name string) []any {
return m.calls.getCallsByTool(name)
}

// SetToolError configures a tool to return an error when invoked.
func (m *MockMCP) SetToolError(tool, errMsg string) {
m.calls.setToolError(tool, errMsg)
}

// SetupMCPForTest creates a ready-to-use MCP server with proxy named "coder".
func SetupMCPForTest(t *testing.T, tracer trace.Tracer) *MockMCP {
t.Helper()
return SetupMCPForTestWithName(t, "coder", tracer)
}

func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *MockMCP {
t.Helper()

srv, acc := createMockMCPSrv(t)
mcpSrv := httptest.NewServer(srv)
t.Cleanup(mcpSrv.Close) // FIRST registered → runs LAST (LIFO)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
// Use a dedicated HTTP client so MCP mocks don't use http.DefaultTransport,
// which can break when httptest.Server calls CloseIdleConnections in parallel
// resulting in error `init MCP client: failed to send initialized notification: failed to send request: failed to send request: Post "http://127.0.0.1:43843": net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called`
// https://github.com/golang/go/blob/44ec057a3e89482cf775f5eaaf03b0b5fcab1fa4/src/net/http/httptest/server.go#L268
httpClient := &http.Client{Transport: &http.Transport{}}
proxy, err := mcp.NewStreamableHTTPServerProxy(name, mcpSrv.URL, nil, nil, nil, logger, tracer, transport.WithHTTPBasicClient(httpClient))
require.NoError(t, err)

mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
require.NoError(t, mgr.Shutdown(ctx))
})

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
require.NoError(t, mgr.Init(ctx))
require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init")

return &MockMCP{ServerProxier: mgr, calls: acc}
}

func NilMCPManager() mcp.ServerProxier {
return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer(""))
}

// callAccumulator tracks all tool invocations by name and each instance's arguments.
type callAccumulator struct {
calls map[string][]any
callsMu sync.Mutex
toolErrors map[string]string
}

func newCallAccumulator() *callAccumulator {
return &callAccumulator{
calls: make(map[string][]any),
toolErrors: make(map[string]string),
}
}

func (a *callAccumulator) setToolError(tool string, errMsg string) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
a.toolErrors[tool] = errMsg
}

func (a *callAccumulator) getToolError(tool string) (string, bool) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
errMsg, ok := a.toolErrors[tool]
return errMsg, ok
}

func (a *callAccumulator) addCall(tool string, args any) {
a.callsMu.Lock()
defer a.callsMu.Unlock()
a.calls[tool] = append(a.calls[tool], args)
}

func (a *callAccumulator) getCallsByTool(name string) []any {
a.callsMu.Lock()
defer a.callsMu.Unlock()
result := make([]any, len(a.calls[name]))
copy(result, a.calls[name])
return result
}

func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
t.Helper()

s := server.NewMCPServer(
"Mock coder MCP server",
"1.0.0",
server.WithToolCapabilities(true),
)

acc := newCallAccumulator()

for _, name := range []string{MockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
return nil, errors.New(errMsg)
}
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s), acc
}
4 changes: 2 additions & 2 deletions mcp/proxy_streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type StreamableHTTPServerProxy struct {
tools map[string]*Tool
}

func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer) (*StreamableHTTPServerProxy, error) {
var opts []transport.StreamableHTTPCOption
func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer, opts ...transport.StreamableHTTPCOption) (*StreamableHTTPServerProxy, error) {
// nit: headers should be passed in as options instead of a separate parameter. This will be a breaking change.
if headers != nil {
opts = append(opts, transport.WithHTTPHeaders(headers))
}
Expand Down
11 changes: 4 additions & 7 deletions metrics_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/metrics"
"github.com/coder/aibridge/provider"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -295,11 +294,9 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)

// Setup mocked MCP server & tools.
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
require.NoError(t, mcpMgr.Init(ctx))
mockMCP := testutil.SetupMCPForTest(t, testTracer)

bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer)
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mockMCP, logger, metrics, testTracer)
require.NoError(t, err)

srv := httptest.NewUnstartedServer(bridge)
Expand Down Expand Up @@ -327,7 +324,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
actualServerURL := *recorder.ToolUsages()[0].ServerURL

count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues(
config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName))
config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.MockToolName))
require.Equal(t, 1.0, count)
}

Expand All @@ -341,7 +338,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m
}
wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn)

bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer)
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, testutil.NilMCPManager(), logger, metrics, tracer)
require.NoError(t, err)

srv := httptest.NewUnstartedServer(bridge)
Expand Down
11 changes: 4 additions & 7 deletions responses_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/provider"
"github.com/coder/aibridge/recorder"
"github.com/openai/openai-go/v3/responses"
Expand Down Expand Up @@ -872,18 +871,16 @@ func TestResponsesInjectedTool(t *testing.T) {
upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix))

// Setup MCP server proxies (with mock tools).
mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer)
mockMCP := testutil.SetupMCPForTest(t, testTracer)
if tc.expectToolError != "" {
mcpCalls.setToolError(tc.mcpToolName, tc.expectToolError)
mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError)
}
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
require.NoError(t, mcpMgr.Init(ctx))

prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))
mockRecorder := &testutil.MockRecorder{}
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)

bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mcpMgr, logger, nil, testTracer)
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mockMCP, logger, nil, testTracer)
require.NoError(t, err)

srv := httptest.NewUnstartedServer(bridge)
Expand All @@ -908,7 +905,7 @@ func TestResponsesInjectedTool(t *testing.T) {
}, time.Second*10, time.Millisecond*50)

// Verify the injected tool was invoked via MCP.
invocations := mcpCalls.getCallsByTool(tc.mcpToolName)
invocations := mockMCP.GetCallsByTool(tc.mcpToolName)
require.Len(t, invocations, 1, "expected MCP tool to be invoked once")

// Verify the injected tool usage was recorded.
Expand Down
Loading