From 15577a09671bfca1c378de159d3215928ddf7a2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 4 Mar 2026 14:31:37 +0000 Subject: [PATCH] chore: add internal/testutil/mockmcp.go with MCP test helpers --- apidump_integration_test.go | 5 +- bridge_integration_test.go | 256 +++++++++------------------- circuit_breaker_integration_test.go | 9 +- internal/testutil/mockmcp.go | 151 ++++++++++++++++ mcp/proxy_streamable_http.go | 4 +- metrics_integration_test.go | 11 +- responses_integration_test.go | 11 +- trace_integration_test.go | 226 +++++++++++------------- 8 files changed, 342 insertions(+), 331 deletions(-) create mode 100644 internal/testutil/mockmcp.go diff --git a/apidump_integration_test.go b/apidump_integration_test.go index fd17ead..f6e3b14 100644 --- a/apidump_integration_test.go +++ b/apidump_integration_test.go @@ -22,7 +22,6 @@ import ( "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/stretchr/testify/require" ) @@ -95,7 +94,7 @@ func TestAPIDump(t *testing.T) { dumpDir := t.TempDir() recorderClient := &testutil.MockRecorder{} - b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -230,7 +229,7 @@ func TestAPIDumpPassthrough(t *testing.T) { recorderClient := &testutil.MockRecorder{} prov := tc.providerFunc(upstream.URL, dumpDir) provs := []aibridge.Provider{prov} - b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) bridgeSrv := httptest.NewUnstartedServer(b) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index bf83264..ced1450 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -4,14 +4,12 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net" "net/http" "net/http/httptest" "strings" - "sync" "testing" "time" @@ -30,8 +28,6 @@ import ( "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/google/uuid" - mcplib "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" "github.com/openai/openai-go/v3" oaissestream "github.com/openai/openai-go/v3/packages/ssestream" "github.com/stretchr/testify/assert" @@ -50,6 +46,23 @@ const ( userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" ) +type ( + providerFunc func(addr string) aibridge.Provider + createRequestFunc func(*testing.T, string, []byte) *http.Request +) + +func newAnthropicProvider(addr string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) +} + +func newOpenAIProvider(addr string) aibridge.Provider { + return provider.NewOpenAI(openaiCfg(addr, apiKey)) +} + +func newBedrockProvider(addr string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) +} + func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } @@ -93,7 +106,7 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -176,7 +189,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + }, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -223,7 +236,7 @@ func TestAWSBedrockIntegration(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( ctx, []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, - recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) mockBridgeSrv := httptest.NewUnstartedServer(b) @@ -311,7 +324,7 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -403,13 +416,11 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &testutil.MockRecorder{} // Setup MCP proxies with the tool from the fixture - mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer) - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) + mockMCP := testutil.SetupMCPForTest(t, testTracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcpMgr, logger, nil, testTracer) + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mockMCP, logger, nil, testTracer) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -440,7 +451,7 @@ func TestOpenAIChatCompletions(t *testing.T) { resp.Body.Close() // Verify the MCP tool was actually invoked - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mockMCP.GetCallsByTool(testutil.MockToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked") // Verify tool was invoked with the expected args (if specified) @@ -455,7 +466,7 @@ func TestOpenAIChatCompletions(t *testing.T) { // Verify tool usage was recorded toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - assert.Equal(t, mockToolName, toolUsages[0].Tool) + assert.Equal(t, testutil.MockToolName, toolUsages[0].Tool) recorderClient.VerifyAllInterceptionsEnded(t) }) @@ -529,14 +540,14 @@ func TestSimple(t *testing.T) { t.Helper() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NilMCPManager(), logger, nil, testTracer) } configureOpenAI := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { t.Helper() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NilMCPManager(), logger, nil, testTracer) } testCases := []struct { @@ -702,7 +713,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -717,7 +728,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -732,7 +743,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -747,7 +758,7 @@ func TestFallthrough(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) return provider, bridge }, @@ -799,34 +810,6 @@ func TestFallthrough(t *testing.T) { } } -// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) { - t.Helper() - - // Setup Coder MCP integration - srv, acc := createMockMCPSrv(t) - mcpSrv := httptest.NewServer(srv) - t.Cleanup(mcpSrv.Close) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) - require.NoError(t, err) - - // Initialize MCP client, fetch tools, and inject into bridge - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - require.NoError(t, proxy.Init(ctx)) - tools := proxy.ListTools() - require.NotEmpty(t, tools) - - return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc -} - -type ( - configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) - createRequestFunc func(*testing.T, string, []byte) *http.Request -) - func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() @@ -834,25 +817,19 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - } - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq, anthropicToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, testTracer, userID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) + require.Equal(t, testutil.MockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mockMCP.GetCallsByTool(testutil.MockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -924,25 +901,19 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - } - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, testTracer, userID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) + require.Equal(t, testutil.MockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mockMCP.GetCallsByTool(testutil.MockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1095,8 +1066,17 @@ func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { } } -// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request, toolRequestValidatorFn func(*http.Request, []byte)) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { +// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. +func setupInjectedToolTest( + t *testing.T, + fixture []byte, + streaming bool, + providerFn providerFunc, + tracer trace.Tracer, + userID string, + createRequestFn func(*testing.T, string, []byte) *http.Request, + toolRequestValidatorFn func(*http.Request, []byte), +) (*testutil.MockRecorder, *testutil.MockMCP, *http.Response) { t.Helper() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -1113,13 +1093,17 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &testutil.MockRecorder{} - // Setup MCP mcpProxiers. - mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer) - - // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) - b, err := configureFn(upstream.URL, recorderClient, mcpMgr) + mockMCP := testutil.SetupMCPForTest(t, tracer) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + b, err := aibridge.NewRequestBridge( + t.Context(), + []aibridge.Provider{providerFn(upstream.URL)}, + recorderClient, + mockMCP, + logger, + nil, + tracer, + ) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1148,7 +1132,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, acc, mcpProxiers, resp + return recorderClient, mockMCP, resp } func TestErrorHandling(t *testing.T) { @@ -1160,14 +1144,14 @@ func TestErrorHandling(t *testing.T) { name string fixture []byte createRequestFunc createRequestFunc - configureFunc configureFunc + configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) responseHandlerFn func(resp *http.Response) }{ { name: config.ProviderAnthropic, fixture: fixtures.AntNonStreamError, createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) @@ -1185,7 +1169,7 @@ func TestErrorHandling(t *testing.T) { name: config.ProviderOpenAI, fixture: fixtures.OaiChatNonStreamError, createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) @@ -1219,7 +1203,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, testutil.NilMCPManager()) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1253,14 +1237,14 @@ func TestErrorHandling(t *testing.T) { name string fixture []byte createRequestFunc createRequestFunc - configureFunc configureFunc + configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) responseHandlerFn func(resp *http.Response) }{ { name: config.ProviderAnthropic, fixture: fixtures.AntMidStreamError, createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) @@ -1279,7 +1263,7 @@ func TestErrorHandling(t *testing.T) { name: config.ProviderOpenAI, fixture: fixtures.OaiChatMidStreamError, createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) @@ -1315,7 +1299,7 @@ func TestErrorHandling(t *testing.T) { recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(upstream.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) + b, err := tc.configureFunc(upstream.URL, recorderClient, testutil.NilMCPManager()) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1352,13 +1336,13 @@ func TestStableRequestEncoding(t *testing.T) { name string fixture []byte createRequestFunc createRequestFunc - configureFunc configureFunc + configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) }{ { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, @@ -1367,7 +1351,7 @@ func TestStableRequestEncoding(t *testing.T) { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) }, @@ -1382,11 +1366,7 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - - // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) + mockMCP := testutil.SetupMCPForTest(t, testTracer) fix := fixtures.Parse(t, tc.fixture) @@ -1399,7 +1379,7 @@ func TestStableRequestEncoding(t *testing.T) { upstream := testutil.NewMockUpstream(t, ctx, responses...) recorder := &testutil.MockRecorder{} - bridge, err := tc.configureFunc(upstream.URL, recorder, mcpMgr) + bridge, err := tc.configureFunc(upstream.URL, recorder, mockMCP) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1512,14 +1492,12 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools conditionally. - var mcpMgr *mcp.ServerProxyManager + var mcpMgr mcp.ServerProxier if tc.withInjectedTools { - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - mcpMgr = mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMgr = testutil.SetupMCPForTest(t, testTracer) } else { - mcpMgr = mcp.NewServerProxyManager(nil, testTracer) + mcpMgr = testutil.NilMCPManager() } - require.NoError(t, mcpMgr.Init(ctx)) fix := fixtures.Parse(t, fixtures.AntSimple) upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) @@ -1597,7 +1575,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - bridge, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err) bridgeSrv := httptest.NewUnstartedServer(bridge) @@ -1647,7 +1625,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NilMCPManager(), logger, nil, testTracer) }, createRequest: createAnthropicMessagesReq, envVars: map[string]string{ @@ -1661,7 +1639,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NilMCPManager(), logger, nil, testTracer) }, createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ @@ -1817,7 +1795,7 @@ func TestActorHeaders(t *testing.T) { provider := tc.createProviderFn(srv.URL, apiKey, send) logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, testutil.NilMCPManager(), logger, nil, testTracer) require.NoError(t, err, "failed to create handler") mockSrv := httptest.NewUnstartedServer(b) @@ -1899,80 +1877,6 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) return req } -const mockToolName = "coder_list_workspaces" - -// callAccumulator tracks all tool invocations by name and each instance's arguments. -type callAccumulator struct { - calls map[string][]any - callsMu sync.Mutex - toolErrors map[string]string -} - -func newCallAccumulator() *callAccumulator { - return &callAccumulator{ - calls: make(map[string][]any), - toolErrors: make(map[string]string), - } -} - -func (a *callAccumulator) setToolError(tool string, errMsg string) { - a.callsMu.Lock() - defer a.callsMu.Unlock() - a.toolErrors[tool] = errMsg -} - -func (a *callAccumulator) getToolError(tool string) (string, bool) { - a.callsMu.Lock() - defer a.callsMu.Unlock() - errMsg, ok := a.toolErrors[tool] - return errMsg, ok -} - -func (a *callAccumulator) addCall(tool string, args any) { - a.callsMu.Lock() - defer a.callsMu.Unlock() - - a.calls[tool] = append(a.calls[tool], args) -} - -func (a *callAccumulator) getCallsByTool(name string) []any { - a.callsMu.Lock() - defer a.callsMu.Unlock() - - // Protect against concurrent access of the slice. - result := make([]any, len(a.calls[name])) - copy(result, a.calls[name]) - return result -} - -func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { - t.Helper() - - s := server.NewMCPServer( - "Mock coder MCP server", - "1.0.0", - server.WithToolCapabilities(true), - ) - - // Accumulate tool calls & their arguments. - acc := newCallAccumulator() - - for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} { - tool := mcplib.NewTool(name, - mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), - ) - s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { - acc.addCall(request.Params.Name, request.Params.Arguments) - if errMsg, ok := acc.getToolError(request.Params.Name); ok { - return nil, errors.New(errMsg) - } - return mcplib.NewToolResultText("mock"), nil - }) - } - - return server.NewStreamableHTTPServer(s), acc -} - func openaiCfg(url, key string) config.OpenAI { return config.OpenAI{ BaseURL: url, diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go index 6b03970..9d0d4a7 100644 --- a/circuit_breaker_integration_test.go +++ b/circuit_breaker_integration_test.go @@ -18,7 +18,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" @@ -145,7 +144,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { bridge, err := aibridge.NewRequestBridge(ctx, []provider.Provider{prov}, &testutil.MockRecorder{}, - mcp.NewServerProxyManager(nil, tracer), + testutil.NilMCPManager(), logger, metrics, tracer, @@ -318,7 +317,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { bridge, err := aibridge.NewRequestBridge(ctx, []provider.Provider{prov}, &testutil.MockRecorder{}, - mcp.NewServerProxyManager(nil, tracer), + testutil.NilMCPManager(), logger, metrics, tracer, @@ -484,7 +483,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { bridge, err := aibridge.NewRequestBridge(ctx, []provider.Provider{prov}, &testutil.MockRecorder{}, - mcp.NewServerProxyManager(nil, tracer), + testutil.NilMCPManager(), logger, metrics, tracer, @@ -622,7 +621,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { bridge, err := aibridge.NewRequestBridge(ctx, []provider.Provider{prov}, &testutil.MockRecorder{}, - mcp.NewServerProxyManager(nil, tracer), + testutil.NilMCPManager(), logger, m, tracer, diff --git a/internal/testutil/mockmcp.go b/internal/testutil/mockmcp.go new file mode 100644 index 0000000..cf07332 --- /dev/null +++ b/internal/testutil/mockmcp.go @@ -0,0 +1,151 @@ +package testutil + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/mcp" + "github.com/mark3labs/mcp-go/client/transport" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// MockToolName is the primary mock tool name used in MCP tests. +const MockToolName = "coder_list_workspaces" + +// MockMCP wraps a real mcp.ServerProxier with test assertion helpers. +// Implements mcp.ServerProxier so it can be passed directly to NewRequestBridge. +type MockMCP struct { + mcp.ServerProxier + calls *callAccumulator +} + +// GetCallsByTool returns recorded arguments for a given tool name. +func (m *MockMCP) GetCallsByTool(name string) []any { + return m.calls.getCallsByTool(name) +} + +// SetToolError configures a tool to return an error when invoked. +func (m *MockMCP) SetToolError(tool, errMsg string) { + m.calls.setToolError(tool, errMsg) +} + +// SetupMCPForTest creates a ready-to-use MCP server with proxy named "coder". +func SetupMCPForTest(t *testing.T, tracer trace.Tracer) *MockMCP { + t.Helper() + return SetupMCPForTestWithName(t, "coder", tracer) +} + +func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *MockMCP { + t.Helper() + + srv, acc := createMockMCPSrv(t) + mcpSrv := httptest.NewServer(srv) + t.Cleanup(mcpSrv.Close) // FIRST registered → runs LAST (LIFO) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + // Use a dedicated HTTP client so MCP mocks don't use http.DefaultTransport, + // which can break when httptest.Server calls CloseIdleConnections in parallel + // resulting in error `init MCP client: failed to send initialized notification: failed to send request: failed to send request: Post "http://127.0.0.1:43843": net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called` + // https://github.com/golang/go/blob/44ec057a3e89482cf775f5eaaf03b0b5fcab1fa4/src/net/http/httptest/server.go#L268 + httpClient := &http.Client{Transport: &http.Transport{}} + proxy, err := mcp.NewStreamableHTTPServerProxy(name, mcpSrv.URL, nil, nil, nil, logger, tracer, transport.WithHTTPBasicClient(httpClient)) + require.NoError(t, err) + + mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer) + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, mgr.Shutdown(ctx)) + }) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + require.NoError(t, mgr.Init(ctx)) + require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init") + + return &MockMCP{ServerProxier: mgr, calls: acc} +} + +func NilMCPManager() mcp.ServerProxier { + return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer("")) +} + +// callAccumulator tracks all tool invocations by name and each instance's arguments. +type callAccumulator struct { + calls map[string][]any + callsMu sync.Mutex + toolErrors map[string]string +} + +func newCallAccumulator() *callAccumulator { + return &callAccumulator{ + calls: make(map[string][]any), + toolErrors: make(map[string]string), + } +} + +func (a *callAccumulator) setToolError(tool string, errMsg string) { + a.callsMu.Lock() + defer a.callsMu.Unlock() + a.toolErrors[tool] = errMsg +} + +func (a *callAccumulator) getToolError(tool string) (string, bool) { + a.callsMu.Lock() + defer a.callsMu.Unlock() + errMsg, ok := a.toolErrors[tool] + return errMsg, ok +} + +func (a *callAccumulator) addCall(tool string, args any) { + a.callsMu.Lock() + defer a.callsMu.Unlock() + a.calls[tool] = append(a.calls[tool], args) +} + +func (a *callAccumulator) getCallsByTool(name string) []any { + a.callsMu.Lock() + defer a.callsMu.Unlock() + result := make([]any, len(a.calls[name])) + copy(result, a.calls[name]) + return result +} + +func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { + t.Helper() + + s := server.NewMCPServer( + "Mock coder MCP server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + acc := newCallAccumulator() + + for _, name := range []string{MockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} { + tool := mcplib.NewTool(name, + mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), + ) + s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + acc.addCall(request.Params.Name, request.Params.Arguments) + if errMsg, ok := acc.getToolError(request.Params.Name); ok { + return nil, errors.New(errMsg) + } + return mcplib.NewToolResultText("mock"), nil + }) + } + + return server.NewStreamableHTTPServer(s), acc +} diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index 18f693c..795225d 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -32,8 +32,8 @@ type StreamableHTTPServerProxy struct { tools map[string]*Tool } -func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer) (*StreamableHTTPServerProxy, error) { - var opts []transport.StreamableHTTPCOption +func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer, opts ...transport.StreamableHTTPCOption) (*StreamableHTTPServerProxy, error) { + // nit: headers should be passed in as options instead of a separate parameter. This will be a breaking change. if headers != nil { opts = append(opts, transport.WithHTTPHeaders(headers)) } diff --git a/metrics_integration_test.go b/metrics_integration_test.go index d4c0586..6d4e678 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -16,7 +16,6 @@ import ( aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" @@ -295,11 +294,9 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) // Setup mocked MCP server & tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) + mockMCP := testutil.SetupMCPForTest(t, testTracer) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mockMCP, logger, metrics, testTracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -327,7 +324,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { actualServerURL := *recorder.ToolUsages()[0].ServerURL count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues( - config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) + config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.MockToolName)) require.Equal(t, 1.0, count) } @@ -341,7 +338,7 @@ func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, m } wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, mcp.NewServerProxyManager(nil, testTracer), logger, metrics, tracer) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, testutil.NilMCPManager(), logger, metrics, tracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) diff --git a/responses_integration_test.go b/responses_integration_test.go index 4b82bfb..ccd6a97 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -21,7 +21,6 @@ import ( aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3/responses" @@ -872,18 +871,16 @@ func TestResponsesInjectedTool(t *testing.T) { upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) // Setup MCP server proxies (with mock tools). - mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer) + mockMCP := testutil.SetupMCPForTest(t, testTracer) if tc.expectToolError != "" { - mcpCalls.setToolError(tc.mcpToolName, tc.expectToolError) + mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError) } - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) - require.NoError(t, mcpMgr.Init(ctx)) prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) mockRecorder := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mcpMgr, logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mockMCP, logger, nil, testTracer) require.NoError(t, err) srv := httptest.NewUnstartedServer(bridge) @@ -908,7 +905,7 @@ func TestResponsesInjectedTool(t *testing.T) { }, time.Second*10, time.Millisecond*50) // Verify the injected tool was invoked via MCP. - invocations := mcpCalls.getCallsByTool(tc.mcpToolName) + invocations := mockMCP.GetCallsByTool(tc.mcpToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked once") // Verify the injected tool usage was recorded. diff --git a/trace_integration_test.go b/trace_integration_test.go index 608a7ac..a62e58e 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -4,19 +4,14 @@ import ( "context" "fmt" "net/http" - "net/http/httptest" "slices" "strings" "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" @@ -274,33 +269,81 @@ func TestTraceAnthropicErr(t *testing.T) { } } -func TestAnthropicInjectedToolsTrace(t *testing.T) { +func TestInjectedToolsTrace(t *testing.T) { t.Parallel() tests := []struct { - name string - streaming bool - bedrock bool + name string + streaming bool + bedrock bool + fixture []byte + providerFn providerFunc + createReqFn func(*testing.T, string, []byte) *http.Request + expectModel string + expectPath string + expectProvider string }{ { - name: "anthr_blocking", - streaming: false, - bedrock: false, + name: "anthr_blocking", + streaming: false, + fixture: fixtures.AntSingleInjectedTool, + providerFn: newAnthropicProvider, + createReqFn: createAnthropicMessagesReq, + expectModel: "claude-sonnet-4-20250514", + expectPath: "/anthropic/v1/messages", + expectProvider: config.ProviderAnthropic, }, { - name: "anthr_streaming", - streaming: true, - bedrock: false, + name: "anthr_streaming", + streaming: true, + fixture: fixtures.AntSingleInjectedTool, + providerFn: newAnthropicProvider, + createReqFn: createAnthropicMessagesReq, + expectModel: "claude-sonnet-4-20250514", + expectPath: "/anthropic/v1/messages", + expectProvider: config.ProviderAnthropic, }, { - name: "bedrock_blocking", - streaming: false, - bedrock: true, + name: "bedrock_blocking", + streaming: false, + bedrock: true, + fixture: fixtures.AntSingleInjectedTool, + providerFn: newBedrockProvider, + createReqFn: createAnthropicMessagesReq, + expectModel: "beddel", + expectPath: "/anthropic/v1/messages", + expectProvider: config.ProviderAnthropic, }, { - name: "bedrock_streaming", - streaming: true, - bedrock: true, + name: "bedrock_streaming", + streaming: true, + bedrock: true, + fixture: fixtures.AntSingleInjectedTool, + providerFn: newBedrockProvider, + createReqFn: createAnthropicMessagesReq, + expectModel: "beddel", + expectPath: "/anthropic/v1/messages", + expectProvider: config.ProviderAnthropic, + }, + { + name: "openai_blocking", + streaming: false, + fixture: fixtures.OaiChatSingleInjectedTool, + providerFn: newOpenAIProvider, + createReqFn: createOpenAIChatCompletionsReq, + expectModel: "gpt-4.1", + expectPath: "/openai/v1/chat/completions", + expectProvider: config.ProviderOpenAI, + }, + { + name: "openai_streaming", + streaming: true, + fixture: fixtures.OaiChatSingleInjectedTool, + providerFn: newOpenAIProvider, + createReqFn: createOpenAIChatCompletionsReq, + expectModel: "gpt-4.1", + expectPath: "/openai/v1/chat/completions", + expectProvider: config.ProviderOpenAI, }, } @@ -313,54 +356,41 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - var bedrockCfg *config.AWSBedrock - if tc.bedrock { - bedrockCfg = testBedrockCfg(addr) - } - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), bedrockCfg)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) - } - - var reqBody string - reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { - reqBody = string(input) - return createAnthropicMessagesReq(t, baseURL, input) + var validatorFn func(*http.Request, []byte) + if tc.expectProvider == config.ProviderAnthropic { + validatorFn = anthropicToolResultValidator(t) + } else { + validatorFn = openaiChatToolResultValidator(t) } - // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, tc.streaming, configureFn, reqFunc, anthropicToolResultValidator(t)) - + recorderClient, mockMCP, resp := setupInjectedToolTest( + t, tc.fixture, tc.streaming, tc.providerFn, tracer, userID, + tc.createReqFn, validatorFn, + ) defer resp.Body.Close() require.Len(t, recorderClient.RecordedInterceptions(), 1) intcID := recorderClient.RecordedInterceptions()[0].ID - model := gjson.Get(string(reqBody), "model").Str - if tc.bedrock { - model = "beddel" - } + tool := mockMCP.ListTools()[0] - for _, proxy := range proxies { - require.NotEmpty(t, proxy.ListTools()) - tool := proxy.ListTools()[0] - - attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, "/anthropic/v1/messages"), - attribute.String(tracing.InterceptionID, intcID), - attribute.String(tracing.Provider, config.ProviderAnthropic), - attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, userID), - attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), - attribute.String(tracing.MCPToolName, "coder_list_workspaces"), - attribute.String(tracing.MCPServerName, tool.ServerName), - attribute.String(tracing.MCPServerURL, tool.ServerURL), - attribute.Bool(tracing.Streaming, tc.streaming), - attribute.Bool(tracing.IsBedrock, tc.bedrock), - } - verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) + attrs := []attribute.KeyValue{ + attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.InterceptionID, intcID), + attribute.String(tracing.Provider, tc.expectProvider), + attribute.String(tracing.Model, tc.expectModel), + attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.MCPInput, `{"owner":"admin"}`), + attribute.String(tracing.MCPToolName, "coder_list_workspaces"), + attribute.String(tracing.MCPServerName, tool.ServerName), + attribute.String(tracing.MCPServerURL, tool.ServerURL), + attribute.Bool(tracing.Streaming, tc.streaming), } + if tc.expectProvider == config.ProviderAnthropic { + attrs = append(attrs, attribute.Bool(tracing.IsBedrock, tc.bedrock)) + } + + verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) }) } } @@ -659,60 +689,6 @@ func TestTraceOpenAIErr(t *testing.T) { } } -func TestOpenAIInjectedToolsTrace(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() - - configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, tracer) - } - - var reqBody string - reqFunc := func(t *testing.T, baseURL string, input []byte) *http.Request { - reqBody = string(input) - return createOpenAIChatCompletionsReq(t, baseURL, input) - } - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, reqFunc, openaiChatToolResultValidator(t)) - - defer resp.Body.Close() - - require.Len(t, recorderClient.RecordedInterceptions(), 1) - intcID := recorderClient.RecordedInterceptions()[0].ID - - for _, proxy := range proxies { - require.NotEmpty(t, proxy.ListTools()) - tool := proxy.ListTools()[0] - - attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, "/openai/v1/chat/completions"), - attribute.String(tracing.InterceptionID, intcID), - attribute.String(tracing.Provider, config.ProviderOpenAI), - attribute.String(tracing.Model, gjson.Get(reqBody, "model").Str), - attribute.String(tracing.InitiatorID, userID), - attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"), - attribute.String(tracing.MCPToolName, "coder_list_workspaces"), - attribute.String(tracing.MCPServerName, tool.ServerName), - attribute.String(tracing.MCPServerURL, tool.ServerURL), - attribute.Bool(tracing.Streaming, streaming), - } - verifyTraces(t, sr, []expectTrace{{"Intercept.ProcessRequest.ToolCall", 1, codes.Unset}}, attrs) - } - }) - } -} - func TestTracePassthrough(t *testing.T) { t.Parallel() @@ -751,38 +727,26 @@ func TestTracePassthrough(t *testing.T) { } func TestNewServerProxyManagerTraces(t *testing.T) { - ctx := t.Context() - sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() serverName := "serverName" - srv, _ := createMockMCPSrv(t) - mcpSrv := httptest.NewServer(srv) - t.Cleanup(mcpSrv.Close) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy(serverName, mcpSrv.URL, nil, nil, nil, logger, tracer) - require.NoError(t, err) - tools := map[string]mcp.ServerProxier{"unusedValue": proxy} - - mcpMgr := mcp.NewServerProxyManager(tools, tracer) - err = mcpMgr.Init(ctx) - require.NoError(t, err) + mockMCP := testutil.SetupMCPForTestWithName(t, serverName, tracer) + tool := mockMCP.ListTools()[0] require.Len(t, sr.Ended(), 3) verifyTraces(t, sr, []expectTrace{{"ServerProxyManager.Init", 1, codes.Unset}}, []attribute.KeyValue{}) attrs := []attribute.KeyValue{ - attribute.String(tracing.MCPProxyName, proxy.Name()), - attribute.String(tracing.MCPServerURL, mcpSrv.URL), + attribute.String(tracing.MCPProxyName, serverName), + attribute.String(tracing.MCPServerURL, tool.ServerURL), attribute.String(tracing.MCPServerName, serverName), } verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init", 1, codes.Unset}}, attrs) - attrs = append(attrs, attribute.Int(tracing.MCPToolCount, len(proxy.ListTools()))) + attrs = append(attrs, attribute.Int(tracing.MCPToolCount, len(mockMCP.ListTools()))) verifyTraces(t, sr, []expectTrace{{"StreamableHTTPServerProxy.Init.fetchTools", 1, codes.Unset}}, attrs) }