Skip to content

Commit 15577a0

Browse files
committed
chore: add internal/testutil/mockmcp.go with MCP test helpers
1 parent b84b1c3 commit 15577a0

8 files changed

Lines changed: 342 additions & 331 deletions

apidump_integration_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"github.com/coder/aibridge/fixtures"
2323
"github.com/coder/aibridge/intercept/apidump"
2424
"github.com/coder/aibridge/internal/testutil"
25-
"github.com/coder/aibridge/mcp"
2625
"github.com/coder/aibridge/provider"
2726
"github.com/stretchr/testify/require"
2827
)
@@ -95,7 +94,7 @@ func TestAPIDump(t *testing.T) {
9594
dumpDir := t.TempDir()
9695

9796
recorderClient := &testutil.MockRecorder{}
98-
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
97+
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, testutil.NilMCPManager(), logger, nil, testTracer)
9998
require.NoError(t, err)
10099

101100
mockSrv := httptest.NewUnstartedServer(b)
@@ -230,7 +229,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
230229
recorderClient := &testutil.MockRecorder{}
231230
prov := tc.providerFunc(upstream.URL, dumpDir)
232231
provs := []aibridge.Provider{prov}
233-
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
232+
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer)
234233
require.NoError(t, err)
235234

236235
bridgeSrv := httptest.NewUnstartedServer(b)

bridge_integration_test.go

Lines changed: 80 additions & 176 deletions
Large diffs are not rendered by default.

circuit_breaker_integration_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"github.com/coder/aibridge"
1919
"github.com/coder/aibridge/config"
2020
"github.com/coder/aibridge/internal/testutil"
21-
"github.com/coder/aibridge/mcp"
2221
"github.com/coder/aibridge/metrics"
2322
"github.com/coder/aibridge/provider"
2423
"github.com/prometheus/client_golang/prometheus"
@@ -145,7 +144,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
145144
bridge, err := aibridge.NewRequestBridge(ctx,
146145
[]provider.Provider{prov},
147146
&testutil.MockRecorder{},
148-
mcp.NewServerProxyManager(nil, tracer),
147+
testutil.NilMCPManager(),
149148
logger,
150149
metrics,
151150
tracer,
@@ -318,7 +317,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
318317
bridge, err := aibridge.NewRequestBridge(ctx,
319318
[]provider.Provider{prov},
320319
&testutil.MockRecorder{},
321-
mcp.NewServerProxyManager(nil, tracer),
320+
testutil.NilMCPManager(),
322321
logger,
323322
metrics,
324323
tracer,
@@ -484,7 +483,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
484483
bridge, err := aibridge.NewRequestBridge(ctx,
485484
[]provider.Provider{prov},
486485
&testutil.MockRecorder{},
487-
mcp.NewServerProxyManager(nil, tracer),
486+
testutil.NilMCPManager(),
488487
logger,
489488
metrics,
490489
tracer,
@@ -622,7 +621,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
622621
bridge, err := aibridge.NewRequestBridge(ctx,
623622
[]provider.Provider{prov},
624623
&testutil.MockRecorder{},
625-
mcp.NewServerProxyManager(nil, tracer),
624+
testutil.NilMCPManager(),
626625
logger,
627626
m,
628627
tracer,

internal/testutil/mockmcp.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package testutil
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"sync"
10+
"testing"
11+
"time"
12+
13+
"cdr.dev/slog/v3"
14+
"cdr.dev/slog/v3/sloggers/slogtest"
15+
"github.com/coder/aibridge/mcp"
16+
"github.com/mark3labs/mcp-go/client/transport"
17+
mcplib "github.com/mark3labs/mcp-go/mcp"
18+
"github.com/mark3labs/mcp-go/server"
19+
"github.com/stretchr/testify/require"
20+
"go.opentelemetry.io/otel/trace"
21+
"go.opentelemetry.io/otel/trace/noop"
22+
)
23+
24+
// MockToolName is the primary mock tool name used in MCP tests.
25+
const MockToolName = "coder_list_workspaces"
26+
27+
// MockMCP wraps a real mcp.ServerProxier with test assertion helpers.
28+
// Implements mcp.ServerProxier so it can be passed directly to NewRequestBridge.
29+
type MockMCP struct {
30+
mcp.ServerProxier
31+
calls *callAccumulator
32+
}
33+
34+
// GetCallsByTool returns recorded arguments for a given tool name.
35+
func (m *MockMCP) GetCallsByTool(name string) []any {
36+
return m.calls.getCallsByTool(name)
37+
}
38+
39+
// SetToolError configures a tool to return an error when invoked.
40+
func (m *MockMCP) SetToolError(tool, errMsg string) {
41+
m.calls.setToolError(tool, errMsg)
42+
}
43+
44+
// SetupMCPForTest creates a ready-to-use MCP server with proxy named "coder".
45+
func SetupMCPForTest(t *testing.T, tracer trace.Tracer) *MockMCP {
46+
t.Helper()
47+
return SetupMCPForTestWithName(t, "coder", tracer)
48+
}
49+
50+
func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *MockMCP {
51+
t.Helper()
52+
53+
srv, acc := createMockMCPSrv(t)
54+
mcpSrv := httptest.NewServer(srv)
55+
t.Cleanup(mcpSrv.Close) // FIRST registered → runs LAST (LIFO)
56+
57+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
58+
// Use a dedicated HTTP client so MCP mocks don't use http.DefaultTransport,
59+
// which can break when httptest.Server calls CloseIdleConnections in parallel
60+
// resulting in error `init MCP client: failed to send initialized notification: failed to send request: failed to send request: Post "http://127.0.0.1:43843": net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called`
61+
// https://github.com/golang/go/blob/44ec057a3e89482cf775f5eaaf03b0b5fcab1fa4/src/net/http/httptest/server.go#L268
62+
httpClient := &http.Client{Transport: &http.Transport{}}
63+
proxy, err := mcp.NewStreamableHTTPServerProxy(name, mcpSrv.URL, nil, nil, nil, logger, tracer, transport.WithHTTPBasicClient(httpClient))
64+
require.NoError(t, err)
65+
66+
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer)
67+
t.Cleanup(func() {
68+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
69+
defer cancel()
70+
require.NoError(t, mgr.Shutdown(ctx))
71+
})
72+
73+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
74+
t.Cleanup(cancel)
75+
require.NoError(t, mgr.Init(ctx))
76+
require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init")
77+
78+
return &MockMCP{ServerProxier: mgr, calls: acc}
79+
}
80+
81+
func NilMCPManager() mcp.ServerProxier {
82+
return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer(""))
83+
}
84+
85+
// callAccumulator tracks all tool invocations by name and each instance's arguments.
86+
type callAccumulator struct {
87+
calls map[string][]any
88+
callsMu sync.Mutex
89+
toolErrors map[string]string
90+
}
91+
92+
func newCallAccumulator() *callAccumulator {
93+
return &callAccumulator{
94+
calls: make(map[string][]any),
95+
toolErrors: make(map[string]string),
96+
}
97+
}
98+
99+
func (a *callAccumulator) setToolError(tool string, errMsg string) {
100+
a.callsMu.Lock()
101+
defer a.callsMu.Unlock()
102+
a.toolErrors[tool] = errMsg
103+
}
104+
105+
func (a *callAccumulator) getToolError(tool string) (string, bool) {
106+
a.callsMu.Lock()
107+
defer a.callsMu.Unlock()
108+
errMsg, ok := a.toolErrors[tool]
109+
return errMsg, ok
110+
}
111+
112+
func (a *callAccumulator) addCall(tool string, args any) {
113+
a.callsMu.Lock()
114+
defer a.callsMu.Unlock()
115+
a.calls[tool] = append(a.calls[tool], args)
116+
}
117+
118+
func (a *callAccumulator) getCallsByTool(name string) []any {
119+
a.callsMu.Lock()
120+
defer a.callsMu.Unlock()
121+
result := make([]any, len(a.calls[name]))
122+
copy(result, a.calls[name])
123+
return result
124+
}
125+
126+
func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
127+
t.Helper()
128+
129+
s := server.NewMCPServer(
130+
"Mock coder MCP server",
131+
"1.0.0",
132+
server.WithToolCapabilities(true),
133+
)
134+
135+
acc := newCallAccumulator()
136+
137+
for _, name := range []string{MockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} {
138+
tool := mcplib.NewTool(name,
139+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
140+
)
141+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
142+
acc.addCall(request.Params.Name, request.Params.Arguments)
143+
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
144+
return nil, errors.New(errMsg)
145+
}
146+
return mcplib.NewToolResultText("mock"), nil
147+
})
148+
}
149+
150+
return server.NewStreamableHTTPServer(s), acc
151+
}

mcp/proxy_streamable_http.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ type StreamableHTTPServerProxy struct {
3232
tools map[string]*Tool
3333
}
3434

35-
func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer) (*StreamableHTTPServerProxy, error) {
36-
var opts []transport.StreamableHTTPCOption
35+
func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer, opts ...transport.StreamableHTTPCOption) (*StreamableHTTPServerProxy, error) {
36+
// nit: headers should be passed in as options instead of a separate parameter. This will be a breaking change.
3737
if headers != nil {
3838
opts = append(opts, transport.WithHTTPHeaders(headers))
3939
}

metrics_integration_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
aibcontext "github.com/coder/aibridge/context"
1717
"github.com/coder/aibridge/fixtures"
1818
"github.com/coder/aibridge/internal/testutil"
19-
"github.com/coder/aibridge/mcp"
2019
"github.com/coder/aibridge/metrics"
2120
"github.com/coder/aibridge/provider"
2221
"github.com/prometheus/client_golang/prometheus"
@@ -295,11 +294,9 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
295294
provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)
296295

297296
// Setup mocked MCP server & tools.
298-
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
299-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
300-
require.NoError(t, mcpMgr.Init(ctx))
297+
mockMCP := testutil.SetupMCPForTest(t, testTracer)
301298

302-
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer)
299+
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mockMCP, logger, metrics, testTracer)
303300
require.NoError(t, err)
304301

305302
srv := httptest.NewUnstartedServer(bridge)
@@ -327,7 +324,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
327324
actualServerURL := *recorder.ToolUsages()[0].ServerURL
328325

329326
count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues(
330-
config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName))
327+
config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.MockToolName))
331328
require.Equal(t, 1.0, count)
332329
}
333330

@@ -341,7 +338,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m
341338
}
342339
wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn)
343340

344-
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer)
341+
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, testutil.NilMCPManager(), logger, metrics, tracer)
345342
require.NoError(t, err)
346343

347344
srv := httptest.NewUnstartedServer(bridge)

responses_integration_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
aibcontext "github.com/coder/aibridge/context"
2222
"github.com/coder/aibridge/fixtures"
2323
"github.com/coder/aibridge/internal/testutil"
24-
"github.com/coder/aibridge/mcp"
2524
"github.com/coder/aibridge/provider"
2625
"github.com/coder/aibridge/recorder"
2726
"github.com/openai/openai-go/v3/responses"
@@ -872,18 +871,16 @@ func TestResponsesInjectedTool(t *testing.T) {
872871
upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix))
873872

874873
// Setup MCP server proxies (with mock tools).
875-
mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer)
874+
mockMCP := testutil.SetupMCPForTest(t, testTracer)
876875
if tc.expectToolError != "" {
877-
mcpCalls.setToolError(tc.mcpToolName, tc.expectToolError)
876+
mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError)
878877
}
879-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
880-
require.NoError(t, mcpMgr.Init(ctx))
881878

882879
prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))
883880
mockRecorder := &testutil.MockRecorder{}
884881
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
885882

886-
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mcpMgr, logger, nil, testTracer)
883+
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mockMCP, logger, nil, testTracer)
887884
require.NoError(t, err)
888885

889886
srv := httptest.NewUnstartedServer(bridge)
@@ -908,7 +905,7 @@ func TestResponsesInjectedTool(t *testing.T) {
908905
}, time.Second*10, time.Millisecond*50)
909906

910907
// Verify the injected tool was invoked via MCP.
911-
invocations := mcpCalls.getCallsByTool(tc.mcpToolName)
908+
invocations := mockMCP.GetCallsByTool(tc.mcpToolName)
912909
require.Len(t, invocations, 1, "expected MCP tool to be invoked once")
913910

914911
// Verify the injected tool usage was recorded.

0 commit comments

Comments
 (0)