Skip to content

Commit 162d013

Browse files
committed
chore: add internal/testutil/mockmcp.go with MCP test helpers
1 parent b84b1c3 commit 162d013

8 files changed

Lines changed: 348 additions & 329 deletions

apidump_integration_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"github.com/coder/aibridge/fixtures"
2323
"github.com/coder/aibridge/intercept/apidump"
2424
"github.com/coder/aibridge/internal/testutil"
25-
"github.com/coder/aibridge/mcp"
2625
"github.com/coder/aibridge/provider"
2726
"github.com/stretchr/testify/require"
2827
)
@@ -95,7 +94,7 @@ func TestAPIDump(t *testing.T) {
9594
dumpDir := t.TempDir()
9695

9796
recorderClient := &testutil.MockRecorder{}
98-
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
97+
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, testutil.NilMCPManager(), logger, nil, testTracer)
9998
require.NoError(t, err)
10099

101100
mockSrv := httptest.NewUnstartedServer(b)
@@ -230,7 +229,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
230229
recorderClient := &testutil.MockRecorder{}
231230
prov := tc.providerFunc(upstream.URL, dumpDir)
232231
provs := []aibridge.Provider{prov}
233-
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
232+
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer)
234233
require.NoError(t, err)
235234

236235
bridgeSrv := httptest.NewUnstartedServer(b)

bridge_integration_test.go

Lines changed: 80 additions & 176 deletions
Large diffs are not rendered by default.

circuit_breaker_integration_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"github.com/coder/aibridge"
1919
"github.com/coder/aibridge/config"
2020
"github.com/coder/aibridge/internal/testutil"
21-
"github.com/coder/aibridge/mcp"
2221
"github.com/coder/aibridge/metrics"
2322
"github.com/coder/aibridge/provider"
2423
"github.com/prometheus/client_golang/prometheus"
@@ -145,7 +144,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
145144
bridge, err := aibridge.NewRequestBridge(ctx,
146145
[]provider.Provider{prov},
147146
&testutil.MockRecorder{},
148-
mcp.NewServerProxyManager(nil, tracer),
147+
testutil.NilMCPManager(),
149148
logger,
150149
metrics,
151150
tracer,
@@ -318,7 +317,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
318317
bridge, err := aibridge.NewRequestBridge(ctx,
319318
[]provider.Provider{prov},
320319
&testutil.MockRecorder{},
321-
mcp.NewServerProxyManager(nil, tracer),
320+
testutil.NilMCPManager(),
322321
logger,
323322
metrics,
324323
tracer,
@@ -484,7 +483,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
484483
bridge, err := aibridge.NewRequestBridge(ctx,
485484
[]provider.Provider{prov},
486485
&testutil.MockRecorder{},
487-
mcp.NewServerProxyManager(nil, tracer),
486+
testutil.NilMCPManager(),
488487
logger,
489488
metrics,
490489
tracer,
@@ -622,7 +621,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
622621
bridge, err := aibridge.NewRequestBridge(ctx,
623622
[]provider.Provider{prov},
624623
&testutil.MockRecorder{},
625-
mcp.NewServerProxyManager(nil, tracer),
624+
testutil.NilMCPManager(),
626625
logger,
627626
m,
628627
tracer,

internal/testutil/mockmcp.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package testutil
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"sync"
10+
"testing"
11+
"time"
12+
13+
"cdr.dev/slog/v3"
14+
"cdr.dev/slog/v3/sloggers/slogtest"
15+
"github.com/coder/aibridge/mcp"
16+
"github.com/mark3labs/mcp-go/client/transport"
17+
mcplib "github.com/mark3labs/mcp-go/mcp"
18+
"github.com/mark3labs/mcp-go/server"
19+
"github.com/stretchr/testify/require"
20+
"go.opentelemetry.io/otel/trace"
21+
"go.opentelemetry.io/otel/trace/noop"
22+
)
23+
24+
// MockToolName is the primary mock tool name used in MCP tests.
25+
const MockToolName = "coder_list_workspaces"
26+
27+
// MockMCP wraps a real mcp.ServerProxier with test assertion helpers.
28+
// It implements mcp.ServerProxier and can be passed directly to NewRequestBridge.
29+
type MockMCP struct {
30+
mcp.ServerProxier
31+
calls *callAccumulator
32+
}
33+
34+
// GetCallsByTool returns all recorded argument sets for a given tool name.
35+
func (m *MockMCP) GetCallsByTool(name string) []any {
36+
return m.calls.getCallsByTool(name)
37+
}
38+
39+
// SetToolError configures a tool to return an error when invoked.
40+
func (m *MockMCP) SetToolError(tool, errMsg string) {
41+
m.calls.setToolError(tool, errMsg)
42+
}
43+
44+
// SetupMCPForTest creates a mock MCP server named "coder" and returns a ready-to-use *MockMCP.
45+
// Cleanup (shutdown + server close) is registered automatically via t.Cleanup.
46+
func SetupMCPForTest(t *testing.T, tracer trace.Tracer) *MockMCP {
47+
t.Helper()
48+
return SetupMCPForTestWithName(t, "coder", tracer)
49+
}
50+
51+
// SetupMCPForTestWithName creates a mock MCP server with a custom proxy name and
52+
// returns a ready-to-use *MockMCP. Cleanup is registered automatically via t.Cleanup.
53+
func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *MockMCP {
54+
t.Helper()
55+
56+
srv, acc := createMockMCPSrv(t)
57+
mcpSrv := httptest.NewServer(srv)
58+
t.Cleanup(mcpSrv.Close) // FIRST registered → runs LAST (LIFO)
59+
60+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
61+
// Use a dedicated HTTP client so MCP mocks don't use http.DefaultTransport,
62+
// which can break when httptest.Server calls CloseIdleConnections in parallel
63+
// resulting in error `init MCP client: failed to send initialized notification: failed to send request: failed to send request: Post "http://127.0.0.1:43843": net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called`
64+
// https://github.com/golang/go/blob/44ec057a3e89482cf775f5eaaf03b0b5fcab1fa4/src/net/http/httptest/server.go#L268
65+
httpClient := &http.Client{Transport: &http.Transport{}}
66+
proxy, err := mcp.NewStreamableHTTPServerProxy(name, mcpSrv.URL, nil, nil, nil, logger, tracer, transport.WithHTTPBasicClient(httpClient))
67+
require.NoError(t, err)
68+
69+
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer)
70+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
71+
t.Cleanup(cancel)
72+
require.NoError(t, mgr.Init(ctx))
73+
require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init")
74+
75+
// LAST registered → runs FIRST. Graceful shutdown before server close
76+
// prevents "CloseIdleConnections called" errors.
77+
t.Cleanup(func() {
78+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
79+
defer cancel()
80+
_ = mgr.Shutdown(ctx)
81+
})
82+
83+
return &MockMCP{ServerProxier: mgr, calls: acc}
84+
}
85+
86+
// NilMCPManager returns an empty ServerProxier for tests that don't need MCP tools.
87+
func NilMCPManager() mcp.ServerProxier {
88+
return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer(""))
89+
}
90+
91+
// callAccumulator tracks all tool invocations by name and each instance's arguments.
92+
type callAccumulator struct {
93+
calls map[string][]any
94+
callsMu sync.Mutex
95+
toolErrors map[string]string
96+
}
97+
98+
func newCallAccumulator() *callAccumulator {
99+
return &callAccumulator{
100+
calls: make(map[string][]any),
101+
toolErrors: make(map[string]string),
102+
}
103+
}
104+
105+
func (a *callAccumulator) setToolError(tool string, errMsg string) {
106+
a.callsMu.Lock()
107+
defer a.callsMu.Unlock()
108+
a.toolErrors[tool] = errMsg
109+
}
110+
111+
func (a *callAccumulator) getToolError(tool string) (string, bool) {
112+
a.callsMu.Lock()
113+
defer a.callsMu.Unlock()
114+
errMsg, ok := a.toolErrors[tool]
115+
return errMsg, ok
116+
}
117+
118+
func (a *callAccumulator) addCall(tool string, args any) {
119+
a.callsMu.Lock()
120+
defer a.callsMu.Unlock()
121+
a.calls[tool] = append(a.calls[tool], args)
122+
}
123+
124+
func (a *callAccumulator) getCallsByTool(name string) []any {
125+
a.callsMu.Lock()
126+
defer a.callsMu.Unlock()
127+
result := make([]any, len(a.calls[name]))
128+
copy(result, a.calls[name])
129+
return result
130+
}
131+
132+
func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
133+
t.Helper()
134+
135+
s := server.NewMCPServer(
136+
"Mock coder MCP server",
137+
"1.0.0",
138+
server.WithToolCapabilities(true),
139+
)
140+
141+
acc := newCallAccumulator()
142+
143+
for _, name := range []string{MockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} {
144+
tool := mcplib.NewTool(name,
145+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
146+
)
147+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
148+
acc.addCall(request.Params.Name, request.Params.Arguments)
149+
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
150+
return nil, errors.New(errMsg)
151+
}
152+
return mcplib.NewToolResultText("mock"), nil
153+
})
154+
}
155+
156+
return server.NewStreamableHTTPServer(s), acc
157+
}

mcp/proxy_streamable_http.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ type StreamableHTTPServerProxy struct {
3232
tools map[string]*Tool
3333
}
3434

35-
func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer) (*StreamableHTTPServerProxy, error) {
36-
var opts []transport.StreamableHTTPCOption
35+
func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer, opts ...transport.StreamableHTTPCOption) (*StreamableHTTPServerProxy, error) {
36+
// nit: headers should be passed in as options instead of a separate parameter. This will be a breaking change.
3737
if headers != nil {
3838
opts = append(opts, transport.WithHTTPHeaders(headers))
3939
}

metrics_integration_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
aibcontext "github.com/coder/aibridge/context"
1717
"github.com/coder/aibridge/fixtures"
1818
"github.com/coder/aibridge/internal/testutil"
19-
"github.com/coder/aibridge/mcp"
2019
"github.com/coder/aibridge/metrics"
2120
"github.com/coder/aibridge/provider"
2221
"github.com/prometheus/client_golang/prometheus"
@@ -295,11 +294,9 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
295294
provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)
296295

297296
// Setup mocked MCP server & tools.
298-
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
299-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
300-
require.NoError(t, mcpMgr.Init(ctx))
297+
mockMCP := testutil.SetupMCPForTest(t, testTracer)
301298

302-
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer)
299+
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mockMCP, logger, metrics, testTracer)
303300
require.NoError(t, err)
304301

305302
srv := httptest.NewUnstartedServer(bridge)
@@ -327,7 +324,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
327324
actualServerURL := *recorder.ToolUsages()[0].ServerURL
328325

329326
count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues(
330-
config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName))
327+
config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.MockToolName))
331328
require.Equal(t, 1.0, count)
332329
}
333330

@@ -341,7 +338,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m
341338
}
342339
wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn)
343340

344-
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer)
341+
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, testutil.NilMCPManager(), logger, metrics, tracer)
345342
require.NoError(t, err)
346343

347344
srv := httptest.NewUnstartedServer(bridge)

responses_integration_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
aibcontext "github.com/coder/aibridge/context"
2222
"github.com/coder/aibridge/fixtures"
2323
"github.com/coder/aibridge/internal/testutil"
24-
"github.com/coder/aibridge/mcp"
2524
"github.com/coder/aibridge/provider"
2625
"github.com/coder/aibridge/recorder"
2726
"github.com/openai/openai-go/v3/responses"
@@ -872,18 +871,16 @@ func TestResponsesInjectedTool(t *testing.T) {
872871
upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix))
873872

874873
// Setup MCP server proxies (with mock tools).
875-
mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer)
874+
mockMCP := testutil.SetupMCPForTest(t, testTracer)
876875
if tc.expectToolError != "" {
877-
mcpCalls.setToolError(tc.mcpToolName, tc.expectToolError)
876+
mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError)
878877
}
879-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
880-
require.NoError(t, mcpMgr.Init(ctx))
881878

882879
prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))
883880
mockRecorder := &testutil.MockRecorder{}
884881
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
885882

886-
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mcpMgr, logger, nil, testTracer)
883+
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mockMCP, logger, nil, testTracer)
887884
require.NoError(t, err)
888885

889886
srv := httptest.NewUnstartedServer(bridge)
@@ -908,7 +905,7 @@ func TestResponsesInjectedTool(t *testing.T) {
908905
}, time.Second*10, time.Millisecond*50)
909906

910907
// Verify the injected tool was invoked via MCP.
911-
invocations := mcpCalls.getCallsByTool(tc.mcpToolName)
908+
invocations := mockMCP.GetCallsByTool(tc.mcpToolName)
912909
require.Len(t, invocations, 1, "expected MCP tool to be invoked once")
913910

914911
// Verify the injected tool usage was recorded.

0 commit comments

Comments
 (0)