From bfa239f04731ff25fea1f1825b872f4d26a244fd Mon Sep 17 00:00:00 2001 From: Kurt Degiorgio Date: Thu, 12 Mar 2026 15:08:16 +0000 Subject: [PATCH] Implement Go interceptor chain MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the interceptor implementation for servers built with the go-sdk. Interceptors sit between the transport and method handlers, enabling policy enforcement, data sanitization, and traffic auditing without modifying handler code. The core interceptor framework (chain engine, validators, mutators, invocations, results) is fully protocol-agnostic with zero MCP dependencies, so it can be used with any server that processes request/response pairs (gRPC, custom HTTP, etc.). The MCP-specific server integration lives in a separate mcpserver sub-package, wiring interceptors into go-sdk's middleware for all transports (stdio, HTTP). This covers the server-side chain execution model only. The SEP's protocol-level methods (interceptors/register, interceptors/update, interceptors/list, interceptors/execute) are not yet implemented. Package structure: interceptors/ chain.go — public Chain type, protocol-agnostic API chain_executor.go — chainExecutor, filtering, sorting, snapshots chain_validate.go — parallel validator dispatch chain_mutate.go — sequential mutator execution interceptor.go — types, Metadata, Validator/Mutator structs invocation.go — Invocation with audit-mode payload cloning result.go — ValidationResult, MutationResult, ChainResult doc.go — package documentation with examples interceptors/mcpserver/ server.go — Server wrapper, middleware, capability declaration events.go — MCP event name constants doc.go — sub-package documentation Also includes: - examples/validator and examples/mutator - HTTP integration tests (httptest + StreamableHTTPHandler) - doc/DESIGN.md, doc/CONFORMANCE.md, doc/PERFORMANCE.md Signed-off-by: Kurt Degiorgio --- .github/workflows/go.yml | 2 +- .gitignore | 2 + go/sdk/README.md | 56 +- go/sdk/doc/CONFORMANCE.md | 34 ++ go/sdk/doc/DESIGN.md | 195 ++++++ go/sdk/doc/PERFORMANCE.md | 51 ++ go/sdk/examples/.gitkeep | 0 go/sdk/examples/mutator/main.go | 87 +++ go/sdk/examples/validator/main.go | 86 +++ go/sdk/go.mod | 20 +- go/sdk/go.sum | 30 + go/sdk/interceptors/chain.go | 112 ++++ go/sdk/interceptors/chain_executor.go | 214 +++++++ go/sdk/interceptors/chain_mutate.go | 150 +++++ go/sdk/interceptors/chain_test.go | 137 +++++ go/sdk/interceptors/chain_validate.go | 191 ++++++ go/sdk/interceptors/doc.go | 158 +++++ go/sdk/interceptors/interceptor.go | 161 ++++- go/sdk/interceptors/interceptor_test.go | 14 - go/sdk/interceptors/invocation.go | 80 +++ go/sdk/interceptors/mcpserver/doc.go | 62 ++ go/sdk/interceptors/mcpserver/events.go | 17 + go/sdk/interceptors/mcpserver/server.go | 366 +++++++++++ .../mcpserver/server_integration_test.go | 577 ++++++++++++++++++ .../mcpserver/testharness_test.go | 191 ++++++ go/sdk/interceptors/result.go | 96 +++ go/sdk/internal/.gitkeep | 0 27 files changed, 3070 insertions(+), 19 deletions(-) create mode 100644 go/sdk/doc/CONFORMANCE.md create mode 100644 go/sdk/doc/DESIGN.md create mode 100644 go/sdk/doc/PERFORMANCE.md delete mode 100644 go/sdk/examples/.gitkeep create mode 100644 go/sdk/examples/mutator/main.go create mode 100644 go/sdk/examples/validator/main.go create mode 100644 go/sdk/go.sum create mode 100644 go/sdk/interceptors/chain.go create mode 100644 go/sdk/interceptors/chain_executor.go create mode 100644 go/sdk/interceptors/chain_mutate.go create mode 100644 go/sdk/interceptors/chain_test.go create mode 100644 go/sdk/interceptors/chain_validate.go create mode 100644 go/sdk/interceptors/doc.go delete mode 100644 go/sdk/interceptors/interceptor_test.go create mode 100644 go/sdk/interceptors/invocation.go create mode 100644 go/sdk/interceptors/mcpserver/doc.go create mode 100644 go/sdk/interceptors/mcpserver/events.go create mode 100644 go/sdk/interceptors/mcpserver/server.go create mode 100644 go/sdk/interceptors/mcpserver/server_integration_test.go create mode 100644 go/sdk/interceptors/mcpserver/testharness_test.go create mode 100644 go/sdk/interceptors/result.go delete mode 100644 go/sdk/internal/.gitkeep diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 87636b9..0d4973e 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -51,7 +51,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ["1.23", "1.24", "1.25"] + go: ["1.24", "1.25"] steps: - name: Check out code uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 diff --git a/.gitignore b/.gitignore index 4013ac4..28557d4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.dll *.so *.dylib +mutator +validator # Build results [Bb]in/ diff --git a/go/sdk/README.md b/go/sdk/README.md index a24c5a9..732f7e5 100644 --- a/go/sdk/README.md +++ b/go/sdk/README.md @@ -1,3 +1,57 @@ # MCP Interceptors - Go Implementation -This will contain the Go implementation of the MCP Interceptors based on [SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763). +Go implementation of the MCP Interceptor Extension based on +[SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763). + +## Quick Start + +```go +mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "my-server", + Version: "0.1.0", +}, nil) + +// Wrap with interceptor support. +srv := interceptors.NewServer(mcpServer, + // Optional Context Provider + interceptors.WithContextProvider( + func(_ context.Context, _ mcp.Request) *interceptors.InvocationContext { + return &interceptors.InvocationContext{ + Principal: &interceptors.Principal{Type: "user", ID: "alice"}, + } + }, + ), +) + +// Register a validator that blocks dangerous tool calls. +srv.AddInterceptor(&interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "block-dangerous", + Events: []string{interceptors.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + // validate the request... + return &interceptors.ValidationResult{Valid: true}, nil + }, +}) + +srv.Run(context.Background(), &mcp.StdioTransport{}) +``` + +See [`examples/`](examples/) for complete working examples. + +## Documentation + +- [**DESIGN.md**](doc/DESIGN.md) — architecture, execution model, integration + with the go-sdk. +- [**PERFORMANCE.md**](doc/PERFORMANCE.md) — per-request cost model, allocation + summary, and optimization notes. +- [**CONFORMANCE.md**](doc/CONFORMANCE.md) — SEP conformance status. + +Package API documentation is available via `go doc`: + +```sh +go doc github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors +``` diff --git a/go/sdk/doc/CONFORMANCE.md b/go/sdk/doc/CONFORMANCE.md new file mode 100644 index 0000000..77d8bd5 --- /dev/null +++ b/go/sdk/doc/CONFORMANCE.md @@ -0,0 +1,34 @@ +# SEP Conformance + +Status of this Go SDK implementation against the +[SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763) +interceptor proposal. + +## Implemented + +| Area | Notes | +|------|-------| +| Validation interceptors | Parallel execution, severity-based blocking, fail-open support | +| Mutation interceptors | Sequential execution, priority ordering, atomic payload updates | +| Interceptor metadata | Name, version, description, events, phase, priorityHint (polymorphic JSON), compat, configSchema, mode, failOpen, timeout | +| Event names | Constants for all standard server-side MCP methods; JSON-RPC method names used directly as event names | +| Unified result envelope | ValidationResult, MutationResult, ExecutionResult with base envelope fields | +| Chain result | ChainResult with status, results, finalPayload, validationSummary, abortedAt | +| JSON-RPC error mapping | Typed error data structs: -32602 for validation, -32603 for mutation, -32000 for timeout | +| Trust-boundary execution order | Receiving: validate (parallel) then mutate (sequential); Sending: mutate (sequential) then validate (parallel) | +| Priority ordering | Mutators sorted by `priorityHint.Resolve(phase)` ascending, alphabetical tiebreak | +| Fail-open behavior | `FailOpen: true` interceptors log errors without aborting the chain | +| Audit mode | `ModeAudit` records results without blocking or applying mutations | +| Timeout & context | Per-interceptor timeouts, chain-level context cancellation, `InvocationContext` with principal/traceId via `mcpserver.WithContextProvider` | +| Receiving direction (client → server) | All server-side method calls intercepted via `AddReceivingMiddleware` | +| Capability declaration | Interceptor metadata injected into `initialize` response via `Capabilities.Experimental` | +| First-party (in-process) deployment | Interceptors run as Go functions within the server process | +| Third-party and hybrid deployment | Handlers can call remote services; local and remote interceptors can be mixed freely. No built-in remote interceptor abstraction yet | + +## Not Implemented + +| Area | SEP expects | Notes | +|------|-------------|-------| +| Wildcard event matching | `type InterceptorEvent = ... \| "*/request" \| "*/response" \| "*"` | `matchesEvent` does exact match only; wildcard patterns are planned | +| Protocol methods | `interceptors/list`, `interceptor/invoke`, `interceptor/executeChain` | Requires upstream go-sdk changes to register custom JSON-RPC methods | +| Server → client interception | Client features as interceptable events: `"sampling/createMessage"`, `"elicitation/create"`, `"roots/list"` | Requires a `sendingMiddleware` installed via `Server.AddSendingMiddleware`. Outgoing requests run mutate → validate, incoming responses run validate → mutate (same `executeForSending`/`executeForReceiving` methods). Interceptors match by event name (no API changes needed for registration) | diff --git a/go/sdk/doc/DESIGN.md b/go/sdk/doc/DESIGN.md new file mode 100644 index 0000000..ee84a87 --- /dev/null +++ b/go/sdk/doc/DESIGN.md @@ -0,0 +1,195 @@ +# Go SDK Interceptors — Design Document + +## Integration Point: Receiving Middleware + +The go-sdk processes an incoming JSON-RPC message in this order: + +``` +Transport (SSE / stdio) + → JSON-RPC decode + → Params deserialization (json.RawMessage → typed struct) + → Receiving middleware chain ← we hook in here + → Method handler (e.g. tool handler) + → Result returned through middleware + → JSON-RPC encode + → Transport +``` + +## Capability Declaration + +During initialization, the middleware intercepts the `"initialize"` response +and injects interceptor metadata into +`Capabilities.Experimental["io.modelcontextprotocol/interceptors"]`. This +follows the same pattern as the variants extension +(`io.modelcontextprotocol/server-variants`). + +The capability payload includes: +- `supportedEvents` — deduplicated list of events with registered interceptors +- `interceptors` — full metadata array in wire format + +## Request/Response Lifecycle + +When a JSON-RPC request arrives, `receivingMiddleware` in +`mcpserver/server.go` runs the following sequence: + +``` +0. If method == "initialize" → enrich result with capability declaration +1. Assign typed params to Invocation inv.Payload = req.GetParams() +2. Run request-phase chain (validate → mutate) +3. If aborted → return error +4. Params already modified in place — no unmarshal needed +5. Call next handler next(ctx, method, req) +6. Assign result to Invocation inv.Payload = result +7. Run response-phase chain (mutate → validate) +8. If aborted → return error +9. Result already modified in place — no unmarshal needed +10. Return result +``` + +The JSON-RPC method name is used directly as the event name (e.g. `"tools/call"`). + +### Typed Payload — Zero JSON Operations + +`Invocation.Payload` is `any`, the live Go value from the go-sdk +(e.g. `*mcp.CallToolParamsRaw`). Handlers type-assert directly, the +same pattern as gRPC-Go interceptors (`req any`). No JSON marshaling +or unmarshaling occurs in the normal path. + +```go +// Validator — type-assert, inspect, return: +params, ok := inv.Payload.(*mcp.CallToolParamsRaw) +if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +} + +// Mutator — type-assert, modify in place, return: +result, ok := inv.Payload.(*mcp.CallToolResult) +if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +} +result.Content[0] = &mcp.TextContent{Text: "modified"} +return &MutationResult{Modified: true}, nil +``` + +**Audit mode:** Audit-mode mutators receive a deep-copied payload (via +`Invocation.withCopiedPayload()`) so their in-place modifications don't +affect the real struct. The deep copy uses a JSON round-trip +(`json.Marshal` → `reflect.New` → `json.Unmarshal`). Only audit-mode +mutators pay this cost. + +### Limitations + +- **Params must round-trip through JSON faithfully** for audit-mode deep + copy. All go-sdk param and result types use standard `encoding/json` + tags, so this holds in practice. +- **Type assertions require knowing the concrete type.** Interceptors must + know which type to expect for a given event (e.g. `*mcp.CallToolParamsRaw` + for `tools/call` requests). The `Events` field on `Metadata` narrows + which events reach a handler, so single-event interceptors always see + the expected type. + +--- + +## What Is and Is Not Intercepted + +### Intercepted + +All JSON-RPC **method calls** routed through the server's receiving middleware: + +| Method | Event | +|--------|-------| +| `tools/call` | `EventToolsCall` | +| `tools/list` | `EventToolsList` | +| `prompts/get` | `EventPromptsGet` | +| `prompts/list` | `EventPromptsList` | +| `resources/read` | `EventResourcesRead` | +| `resources/list` | `EventResourcesList` | +| `resources/subscribe` | `EventResourcesSubscribe` | + +Unknown methods pass through the middleware untouched. + +### Not Intercepted + +1. **Progress notifications.** During a tool call, a handler can call + `session.NotifyProgress()`. These are JSON-RPC *notifications* sent + directly over the transport — they do not flow through `MethodHandler` + middleware. Interceptors never see them. + +2. **Transport-level SSE streaming.** The Streamable HTTP transport + multiplexes multiple JSON-RPC messages over a single SSE connection. + This is connection management, not per-message streaming. Each individual + method call is still a single request → single response, which the + middleware intercepts normally. + +3. **JSON-RPC notifications** (e.g. `notifications/initialized`, + `notifications/cancelled`). The go-sdk routes notifications through a + separate handler path, not through `MethodHandler` middleware. + +Notification interception is not defined by the proposal. +If this becomes necessary, it would require a separate notification middleware +hook in the go-sdk. + +--- + +## Chain Execution Model + +The `chainExecutor` in `chain_executor.go` implements trust-boundary-aware +execution: + +**Request phase** (receiving data — untrusted → trusted): +``` +Validate (parallel) → Mutate (sequential) +``` +Validation acts as a security gate before mutations process the data. + +**Response phase** (sending data — trusted → untrusted): +``` +Mutate (sequential) → Validate (parallel) +``` +Mutations prepare/sanitize data, then validation verifies before sending. + +### Validator execution +- All matching validators run in parallel (goroutines). +- A validator returning `Valid: false` with `Severity: "error"` in enforced + mode (`Mode: ModeOn`) aborts the chain. +- `FailOpen: true` validators log errors and record an `ExecutionResult` + (with `Error` populated) for observability, but don't abort. + +### Mutator execution +- Mutators run sequentially, ordered by `PriorityHint.Resolve(phase)` + (ascending), with alphabetical name tiebreak. +- Each mutator modifies the typed payload in place via type assertion on `inv.Payload`. +- If any mutator fails (and is not `FailOpen`), the chain aborts. + `FailOpen` mutators record an `ExecutionResult` (with `Error` populated) + and continue. +- In `ModeAudit`, the mutator runs on a deep-copied payload and its result + is recorded, but the real payload is not affected. + +### Filtering +`newChainExecutor` filters the full interceptor set by: +1. `Mode != ModeOff` +2. Phase matches (or interceptor phase is `PhaseBoth`) +3. Event matches (exact match only; wildcard support is planned via `matchesEvent`) + +--- + +## File Map + +### `interceptors/` — protocol-agnostic core (zero MCP imports) + +| File | Responsibility | +|------|---------------| +| `interceptor.go` | Types (Phase, Mode, InterceptorType, Priority, Severity, Compat), Metadata struct, Interceptor interface, Validator/Mutator structs and handler types | +| `invocation.go` | Invocation (with audit-mode payload cloning), InvocationContext, Principal — the input to every handler | +| `result.go` | All outcome types: ValidationResult, MutationResult, ExecutionResult, ChainResult, AbortInfo | +| `chain.go` | `Chain` public API: `NewChain`, `Add`, `ExecuteForReceiving`, `ExecuteForSending`, `IsEmpty`, `Interceptors` | +| `chain_executor.go` | `interceptorSnapshot` (atomic snapshot with lazy chain cache), `chainExecutor` struct, `newChainExecutor` (filtering + sorting), `executeForReceiving`, `executeForSending`, `timeoutResult`, `matchesPhase`, `matchesEvent` | +| `chain_validate.go` | `validatorResult` struct, `runValidators` (parallel dispatch + N=1 fast path), `executeValidator`, `recordValidation` | +| `chain_mutate.go` | `mutatorOutcome` type + constants, `runMutators` (sequential loop + audit-mode copy), `executeMutator` | + +### `interceptors/mcpserver/` — MCP server integration + +| File | Responsibility | +|------|---------------| +| `server.go` | `Server` wrapper, `receivingMiddleware`, `WithContextProvider`, capability declaration, `NewStreamableHTTPHandler`, `abortToJSONRPCError` | +| `events.go` | Event name constants for standard MCP methods | diff --git a/go/sdk/doc/PERFORMANCE.md b/go/sdk/doc/PERFORMANCE.md new file mode 100644 index 0000000..c525110 --- /dev/null +++ b/go/sdk/doc/PERFORMANCE.md @@ -0,0 +1,51 @@ +# Go SDK Interceptors — Performance + +Analysis of per-request costs and allocation patterns. + +--- + +## Design Rationale: Typed Payload + +The interceptor chain passes Go's typed params/result objects directly +to handlers, avoiding JSON serialization entirely in the normal path. +This follows the same pattern used by gRPC-Go, where interceptors +receive the request as `any` and type-assert to the concrete type. + +`Invocation.Payload` is `any, the same pointer the go-sdk already +allocated during JSON-RPC deserialization. No wrapper struct, no +intermediate copies. Mutators modify the value in place through the +pointer (no marshal/unmarshal round-trip) + +--- + +## Per-Request Cost Model + +Every intercepted request passes through the middleware in +`mcpserver/server.go`. The cost depends on whether interceptors match +the event. + +### Fast Path (no matching interceptors) + +When no interceptors match, the middleware cost is: + +1. One atomic pointer load: `s.chain.snapshot.Load()` +2. Two `sync.Map` lookups on the chain cache: `getChain(event, PhaseRequest)` and `getChain(event, PhaseResponse)` +3. One boolean check: `ce.empty` + +Zero allocations, zero JSON operations. + +### Intercepted Path + +With interceptors active, each active phase +incurs: + +| Step | Operation | Allocations | JSON ops | +|------|-----------|-------------|----------| +| 1 | `Invocation` struct | 1 struct | 0 | +| 2 | `ChainResult` struct (with pre-allocated `Results` slice) | 1 struct + 1 slice | 0 | +| 3 | Validator execution | 0 (N=1) / goroutines (N>1) | 0 | +| 4 | Mutator execution | 0 | 0 | +| 5 | Audit-mode deep copy | 1 per audit mutator | 1 marshal + 1 unmarshal | + +Zero JSON operations in the normal path. Phases with no matching +interceptors are skipped entirely. diff --git a/go/sdk/examples/.gitkeep b/go/sdk/examples/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/go/sdk/examples/mutator/main.go b/go/sdk/examples/mutator/main.go new file mode 100644 index 0000000..b0a962b --- /dev/null +++ b/go/sdk/examples/mutator/main.go @@ -0,0 +1,87 @@ +// Example: a simple mutator interceptor that adds an "[audited]" prefix +// to every tool call result's text content. +// +// Demonstrates WithContextProvider to populate Invocation.Context with +// caller identity, which interceptor handlers can inspect. +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +func main() { + mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "example-server", + Version: "0.1.0", + }, nil) + + mcpServer.AddTool(&mcp.Tool{Name: "greet", Description: "says hello"}, + func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "hello world"}}, + }, nil + }, + ) + + // Build the mutator. + m := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "audit-tag", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + TimeoutMs: 5000, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + + // Use caller identity from the context provider if available. + tag := "[audited]" + if inv.Context != nil && inv.Context.Principal != nil { + tag = fmt.Sprintf("[audited by %s]", inv.Context.Principal.ID) + } + + for i, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = tag + " " + tc.Text + result.Content[i] = tc + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } + + // Wire it up with a context provider that populates caller identity. + // In production, extract the principal from OAuth tokens on the session + // (e.g. via req.GetSession() or transport-level auth headers). + srv := mcpserver.NewServer(mcpServer, + mcpserver.WithContextProvider( + func(_ context.Context, _ mcp.Request) *interceptors.InvocationContext { + return &interceptors.InvocationContext{ + Principal: &interceptors.Principal{ + Type: "user", + ID: "example-user", + }, + } + }, + ), + ) + srv.AddInterceptor(m) + + handler := mcpserver.NewStreamableHTTPHandler(srv, nil) + log.Println("listening on :8080") + if err := http.ListenAndServe(":8080", handler); err != nil { + log.Fatal(err) + } +} diff --git a/go/sdk/examples/validator/main.go b/go/sdk/examples/validator/main.go new file mode 100644 index 0000000..f0f87a0 --- /dev/null +++ b/go/sdk/examples/validator/main.go @@ -0,0 +1,86 @@ +// Example: a simple validator interceptor that rejects tool calls +// to a tool named "dangerous_tool". +// +// Demonstrates WithContextProvider to populate Invocation.Context with +// caller identity, which interceptor handlers can inspect. +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +func main() { + mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "example-server", + Version: "0.1.0", + }, nil) + + mcpServer.AddTool(&mcp.Tool{Name: "echo", Description: "echoes input"}, + func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("you said: %s", req.Params.Arguments)}}, + }, nil + }, + ) + + // Build the validator. + v := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "block-dangerous-tool", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + TimeoutMs: 5000, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + if params.Name == "dangerous_tool" { + reason := "dangerous_tool is not allowed" + if inv.Context != nil && inv.Context.Principal != nil { + reason = fmt.Sprintf("dangerous_tool is not allowed for %s", inv.Context.Principal.ID) + } + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: reason, Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // Wire it up with a context provider that populates caller identity. + // In production, extract the principal from OAuth tokens on the session + // (e.g. via req.GetSession() or transport-level auth headers). + srv := mcpserver.NewServer(mcpServer, + mcpserver.WithContextProvider( + func(_ context.Context, _ mcp.Request) *interceptors.InvocationContext { + return &interceptors.InvocationContext{ + Principal: &interceptors.Principal{ + Type: "user", + ID: "example-user", + }, + } + }, + ), + ) + srv.AddInterceptor(v) + + handler := mcpserver.NewStreamableHTTPHandler(srv, nil) + log.Println("listening on :8080") + if err := http.ListenAndServe(":8080", handler); err != nil { + log.Fatal(err) + } +} diff --git a/go/sdk/go.mod b/go/sdk/go.mod index 0584e8f..4303b25 100644 --- a/go/sdk/go.mod +++ b/go/sdk/go.mod @@ -1,4 +1,22 @@ module github.com/modelcontextprotocol/ext-interceptors/go/sdk -go 1.23.0 +go 1.24.0 +toolchain go1.24.3 + +require ( + github.com/modelcontextprotocol/go-sdk v1.4.0 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.3 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go/sdk/go.sum b/go/sdk/go.sum new file mode 100644 index 0000000..dacfe34 --- /dev/null +++ b/go/sdk/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/modelcontextprotocol/go-sdk v1.4.0 h1:u0kr8lbJc1oBcawK7Df+/ajNMpIDFE41OEPxdeTLOn8= +github.com/modelcontextprotocol/go-sdk v1.4.0/go.mod h1:Nxc2n+n/GdCebUaqCOhTetptS17SXXNu9IfNTaLDi1E= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w= +github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/sdk/interceptors/chain.go b/go/sdk/interceptors/chain.go new file mode 100644 index 0000000..99c2b10 --- /dev/null +++ b/go/sdk/interceptors/chain.go @@ -0,0 +1,112 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" +) + +// Chain is a protocol-agnostic interceptor chain engine. It manages a set +// of interceptors and executes them in the correct order for a given +// event and phase. It has no dependency on any specific transport and +// can be used with MCP, gRPC, custom HTTP, or any other server. +type Chain struct { + mu sync.Mutex + snapshot atomic.Pointer[interceptorSnapshot] + logger *slog.Logger +} + +// ChainOption configures a Chain. +type ChainOption func(*Chain) + +// WithChainLogger sets the logger for chain execution. +// If not set, slog.Default() is used. +func WithChainLogger(l *slog.Logger) ChainOption { + return func(c *Chain) { + c.logger = l + } +} + +// NewChain creates a new Chain with optional configuration. +func NewChain(opts ...ChainOption) *Chain { + c := &Chain{} + for _, opt := range opts { + opt(c) + } + if c.logger == nil { + c.logger = slog.Default() + } + c.snapshot.Store(&interceptorSnapshot{logger: c.logger}) + return c +} + +// Add registers an interceptor. It panics if the interceptor has a nil +// handler or is an unsupported type. It is safe to call while the chain +// is in use. Returns the receiver for chaining. +func (c *Chain) Add(i Interceptor) *Chain { + switch v := i.(type) { + case *Validator: + if v.Handler == nil { + panic("interceptors: validator " + v.Name + " has nil handler") + } + case *Mutator: + if v.Handler == nil { + panic("interceptors: mutator " + v.Name + " has nil handler") + } + default: + panic(fmt.Sprintf("interceptors: unsupported interceptor type %T", i)) + } + c.mu.Lock() + defer c.mu.Unlock() + old := c.snapshot.Load() + newAll := make([]Interceptor, len(old.all)+1) + copy(newAll, old.all) + newAll[len(old.all)] = i + c.snapshot.Store(&interceptorSnapshot{ + all: newAll, + logger: c.logger, + }) + return c +} + +// ExecuteForReceiving runs the interceptor chain for the given invocation +// using receive-side ordering (validate then mutate). Returns nil, nil if +// no interceptors match the invocation's event and phase. +func (c *Chain) ExecuteForReceiving(ctx context.Context, inv *Invocation) (*ChainResult, error) { + snap := c.snapshot.Load() + ce := snap.getChain(inv.Event, inv.Phase) + if ce.empty { + return nil, nil + } + return ce.executeForReceiving(ctx, inv) +} + +// ExecuteForSending runs the interceptor chain for the given invocation +// using send-side ordering (mutate then validate). Returns nil, nil if +// no interceptors match the invocation's event and phase. +func (c *Chain) ExecuteForSending(ctx context.Context, inv *Invocation) (*ChainResult, error) { + snap := c.snapshot.Load() + ce := snap.getChain(inv.Event, inv.Phase) + if ce.empty { + return nil, nil + } + return ce.executeForSending(ctx, inv) +} + +// IsEmpty reports whether no interceptors match the given event and phase. +func (c *Chain) IsEmpty(event string, phase Phase) bool { + snap := c.snapshot.Load() + ce := snap.getChain(event, phase) + return ce.empty +} + +// Interceptors returns the current list of registered interceptors. +func (c *Chain) Interceptors() []Interceptor { + return c.snapshot.Load().all +} diff --git a/go/sdk/interceptors/chain_executor.go b/go/sdk/interceptors/chain_executor.go new file mode 100644 index 0000000..75ef760 --- /dev/null +++ b/go/sdk/interceptors/chain_executor.go @@ -0,0 +1,214 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "fmt" + "log/slog" + "sort" + "sync" + "time" +) + +// interceptorSnapshot is an immutable snapshot of registered interceptors, +// created once per Add call. It holds the interceptor list and +// lazily-built chain executors keyed by "event|phase". Swapping the snapshot +// pointer implicitly invalidates all cached chains. +type interceptorSnapshot struct { + all []Interceptor + logger *slog.Logger + chains sync.Map // "event|phase" -> *chainExecutor +} + +// getChain returns a cached chainExecutor for the given event and phase, +// building one on first access via newChainExecutor. +func (snap *interceptorSnapshot) getChain(event string, phase Phase) *chainExecutor { + key := event + "|" + string(phase) + if v, ok := snap.chains.Load(key); ok { + return v.(*chainExecutor) + } + ce := newChainExecutor(snap.all, event, phase, snap.logger) + v, _ := snap.chains.LoadOrStore(key, ce) + return v.(*chainExecutor) +} + +// chainExecutor holds the filtered and sorted interceptors applicable to a +// specific event+phase pair, and orchestrates their execution according to +// trust-boundary-aware ordering. +// +// Validators and mutators are separated at construction time so that each +// execution method (executeForReceiving, executeForSending) can run them +// in the correct order without re-filtering. +type chainExecutor struct { + empty bool // true when no validators or mutators matched + validators []*Validator // validators matching this event+phase (run in parallel) + mutators []*Mutator // mutators matching this event+phase (run sequentially by priority) + resultsCap int // len(validators) + len(mutators); used to pre-allocate ChainResult.Results + + logger *slog.Logger +} + +// newChainExecutor builds a chainExecutor by filtering the global interceptor +// list down to those that are active (Mode != ModeOff), match the target phase, +// and match the target event. Mutators are then sorted by priority (ascending) +// with alphabetical name as a tiebreaker, which guarantees +// deterministic execution order. +func newChainExecutor( + all []Interceptor, + event string, + phase Phase, + logger *slog.Logger, +) *chainExecutor { + ce := &chainExecutor{logger: logger} + for _, i := range all { + meta := i.GetMetadata() + if meta.Mode == ModeOff || !matchesPhase(meta.Phase, phase) || !matchesEvent(meta.Events, event) { + continue + } + switch v := i.(type) { + case *Validator: + ce.validators = append(ce.validators, v) + case *Mutator: + ce.mutators = append(ce.mutators, v) + default: + logger.Warn("unknown interceptor type, skipping", + "interceptor", meta.Name, + "type", fmt.Sprintf("%T", i), + ) + } + } + // Sort mutators by priority (ascending), alphabetical tiebreak. + sort.Slice(ce.mutators, func(i, j int) bool { + pi := ce.mutators[i].PriorityHint.Resolve(phase) + pj := ce.mutators[j].PriorityHint.Resolve(phase) + if pi != pj { + return pi < pj + } + return ce.mutators[i].Name < ce.mutators[j].Name + }) + ce.empty = len(ce.validators) == 0 && len(ce.mutators) == 0 + ce.resultsCap = len(ce.validators) + len(ce.mutators) + return ce +} + +// executeForReceiving runs the chain for incoming requests (server receiving) +// using the receive-side ordering: +// +// Receive -> Validate (parallel) -> Mutate (sequential) +// +// Validators run first as a security barrier: if any enforced +// validator produces an error-severity message, the chain aborts with +// "validation_failed" before any mutators run, so the payload is +// unmodified. Only after all validators pass do mutators run +// sequentially, modifying the typed payload in place. If a mutator +// aborts, the chain returns with the abort recorded. +func (ce *chainExecutor) executeForReceiving(ctx context.Context, inv *Invocation) (*ChainResult, error) { + start := time.Now() + cr := &ChainResult{ + Event: inv.Event, + Phase: inv.Phase, + FinalPayload: inv.Payload, + Results: make([]ExecutionResult, 0, ce.resultsCap), + } + + // 1. Run validators in parallel. + ce.runValidators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainValidationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + if err := ctx.Err(); err != nil { + return ce.timeoutResult(cr, start), nil + } + + // 2. Run mutators sequentially (in-place on inv.Payload). + ce.runMutators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainMutationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + + cr.Status = ChainSuccess + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil +} + +// executeForSending runs the chain for outgoing responses (server sending) +// using the send-side ordering: +// +// Mutate (sequential) -> Validate (parallel) -> Send +// +// Mutators run first to prepare/sanitize outgoing data, then +// validators check the (now mutated) payload before it leaves the server. +// Mutators modify the typed value in place, so validators automatically +// see the post-mutation state. +func (ce *chainExecutor) executeForSending(ctx context.Context, inv *Invocation) (*ChainResult, error) { + start := time.Now() + cr := &ChainResult{ + Event: inv.Event, + Phase: inv.Phase, + FinalPayload: inv.Payload, + Results: make([]ExecutionResult, 0, ce.resultsCap), + } + + // 1. Run mutators sequentially (in-place on inv.Payload). + ce.runMutators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainMutationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + if err := ctx.Err(); err != nil { + return ce.timeoutResult(cr, start), nil + } + + // 2. Run validators in parallel. + ce.runValidators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainValidationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + + cr.Status = ChainSuccess + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil +} + +// timeoutResult sets the ChainResult to "timeout" status and appends +// an AbortInfo entry. Used when the parent context is cancelled between +// chain stages (e.g. after validators but before mutators). +func (ce *chainExecutor) timeoutResult(cr *ChainResult, start time.Time) *ChainResult { + cr.Status = ChainTimeout + cr.TotalDurationMs = time.Since(start).Milliseconds() + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Reason: "chain execution timeout exceeded", + Type: AbortTimeout, + Phase: string(cr.Phase), + }) + return cr +} + +// matchesPhase checks if an interceptor's configured phase covers the target +// phase. An interceptor with PhaseBoth matches any target phase. +func matchesPhase(interceptorPhase, targetPhase Phase) bool { + return interceptorPhase == PhaseBoth || interceptorPhase == targetPhase +} + +// matchesEvent checks if any of the interceptor's registered event patterns +// match the given event string. Currently uses exact string comparison only. +// +// TODO: support wildcard patterns (e.g. "*", "*/request", "tools/*"). +func matchesEvent(interceptorEvents []string, event string) bool { + for _, pattern := range interceptorEvents { + if pattern == event { + return true + } + } + return false +} diff --git a/go/sdk/interceptors/chain_mutate.go b/go/sdk/interceptors/chain_mutate.go new file mode 100644 index 0000000..1a743cb --- /dev/null +++ b/go/sdk/interceptors/chain_mutate.go @@ -0,0 +1,150 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "time" +) + +// mutatorOutcome describes the result of executing a single mutator, used by +// the runMutators loop to decide whether to continue, skip, or halt the chain. +type mutatorOutcome int + +const ( + mutatorOK mutatorOutcome = iota // handler succeeded, payload may have been updated + mutatorSkipped // handler failed (fail-open) or audit mode; continue chain + mutatorAborted // handler failed (fail-closed); halt chain +) + +// runMutators runs mutators sequentially in priority order. Mutators modify +// the typed payload in place via type assertion on inv.Payload. For +// audit-mode mutators, a deep-copied invocation is used so modifications +// don't affect the real payload. +func (ce *chainExecutor) runMutators(ctx context.Context, inv *Invocation, cr *ChainResult) { + if len(ce.mutators) == 0 { + return + } + + for _, m := range ce.mutators { + if ctx.Err() != nil { + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Reason: "context cancelled during mutation chain", + Type: AbortTimeout, + Phase: string(inv.Phase), + }) + return + } + + // Audit-mode mutators work on a deep copy so in-place + // modifications don't affect the real payload. + mInv := inv + if m.Mode == ModeAudit { + var err error + mInv, err = inv.withCopiedPayload() + if err != nil { + ce.logger.Warn("failed to deep-copy payload for audit mutator, skipping", + "interceptor", m.Name, + "error", err, + ) + continue + } + } + + if ce.executeMutator(ctx, m, mInv, cr) == mutatorAborted { + return + } + } +} + +// executeMutator runs a single mutator handler and manages its full lifecycle: +// +// 1. Timeout setup: wraps ctx with a per-interceptor deadline if configured. +// 2. Handler invocation: calls the mutator's Handler with the current payload. +// 3. Error handling: on failure, checks FailOpen to decide between abort and skip. +// Fail-closed records both an AbortInfo and an ExecutionResult for audit. +// 4. Audit mode: records the ExecutionResult but does not apply payload changes +// (the caller already passed a deep-copied invocation for audit mutators). +// 5. In-place mutation: the handler modifies the typed value directly via type +// assertion on inv.Payload. result.Modified is advisory (recorded in the +// ExecutionResult for observability) but not checked by the chain. +func (ce *chainExecutor) executeMutator( + ctx context.Context, m *Mutator, mInv *Invocation, + cr *ChainResult, +) mutatorOutcome { + if m.Handler == nil { + ce.logger.Warn("mutator has nil handler, skipping", + "interceptor", m.Name, + ) + return mutatorSkipped + } + + // Apply per-interceptor timeout from metadata. + handlerCtx := ctx + if m.TimeoutMs > 0 { + var cancel context.CancelFunc + handlerCtx, cancel = context.WithTimeout(ctx, time.Duration(m.TimeoutMs)*time.Millisecond) + defer cancel() + } + + mStart := time.Now() + result, err := m.Handler(handlerCtx, mInv) + dur := time.Since(mStart).Milliseconds() + + if err != nil { + // Distinguish timeout from general mutation errors for downstream consumers. + abortType := AbortMutation + if handlerCtx.Err() == context.DeadlineExceeded { + abortType = AbortTimeout + } + ce.logger.Warn("mutator error", + "interceptor", m.Name, + "error", err, + ) + // Always record the execution result for observability, + // regardless of fail-open/fail-closed. + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: m.Name, + Type: TypeMutation, + Phase: mInv.Phase, + DurationMs: dur, + Error: err.Error(), + }) + if !m.FailOpen { + // Fail-closed: record abort entry and halt chain. + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Interceptor: m.Name, + Reason: err.Error(), + Type: abortType, + Phase: string(mInv.Phase), + }) + return mutatorAborted + } + // Fail-open: execution result above is the audit trail; + // no abort, continue chain. + return mutatorSkipped + } + + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: m.Name, + Type: TypeMutation, + Phase: mInv.Phase, + DurationMs: dur, + Mutation: result, + }) + + // Guard against handlers that return (nil, nil). + if result == nil { + return mutatorSkipped + } + + // In audit mode, the mutation result is recorded for observability but + // the payload is not modified — the caller already gave us a deep copy. + if m.Mode == ModeAudit { + return mutatorSkipped + } + + return mutatorOK +} diff --git a/go/sdk/interceptors/chain_test.go b/go/sdk/interceptors/chain_test.go new file mode 100644 index 0000000..46ee6dd --- /dev/null +++ b/go/sdk/interceptors/chain_test.go @@ -0,0 +1,137 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "fmt" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubPayload is a minimal payload for chain-level tests. +type stubPayload struct{ Value string } + +func TestChain_FailOpenRecordsExecutionResult(t *testing.T) { + t.Parallel() + // Verifies that fail-open interceptors that return errors still + // have an ExecutionResult recorded in ChainResult.Results, even + // though they don't produce an AbortInfo. + + t.Run("fail-open validator error is recorded", func(t *testing.T) { + t.Parallel() + failOpenValidator := &Validator{ + Metadata: Metadata{ + Name: "fo-validator", + Events: []string{"test/event"}, + Phase: PhaseRequest, + Mode: ModeOn, + FailOpen: true, + }, + Handler: func(_ context.Context, _ *Invocation) (*ValidationResult, error) { + return nil, fmt.Errorf("transient failure") + }, + } + passingValidator := &Validator{ + Metadata: Metadata{ + Name: "passing-validator", + Events: []string{"test/event"}, + Phase: PhaseRequest, + Mode: ModeOn, + }, + Handler: func(_ context.Context, _ *Invocation) (*ValidationResult, error) { + return &ValidationResult{Valid: true}, nil + }, + } + + chain := NewChain(WithChainLogger(slog.Default())) + chain.Add(failOpenValidator).Add(passingValidator) + + inv := &Invocation{ + Event: "test/event", + Phase: PhaseRequest, + Payload: &stubPayload{Value: "hello"}, + } + + cr, err := chain.ExecuteForReceiving(context.Background(), inv) + require.NoError(t, err) + + // Chain should succeed — fail-open doesn't abort. + assert.Equal(t, ChainSuccess, cr.Status) + assert.Empty(t, cr.AbortedAt) + + // Both interceptors should have an ExecutionResult. + require.Len(t, cr.Results, 2) + var found bool + for _, r := range cr.Results { + if r.Interceptor == "fo-validator" { + found = true + assert.Equal(t, TypeValidation, r.Type) + assert.Equal(t, "transient failure", r.Error) + } + } + assert.True(t, found, "fail-open validator should have an ExecutionResult") + }) + + t.Run("fail-open mutator error is recorded", func(t *testing.T) { + t.Parallel() + failOpenMutator := &Mutator{ + Metadata: Metadata{ + Name: "fo-mutator", + Events: []string{"test/event"}, + Phase: PhaseResponse, + Mode: ModeOn, + FailOpen: true, + PriorityHint: NewPriority(10), + }, + Handler: func(_ context.Context, _ *Invocation) (*MutationResult, error) { + return nil, fmt.Errorf("transient failure") + }, + } + passingMutator := &Mutator{ + Metadata: Metadata{ + Name: "passing-mutator", + Events: []string{"test/event"}, + Phase: PhaseResponse, + Mode: ModeOn, + PriorityHint: NewPriority(20), + }, + Handler: func(_ context.Context, _ *Invocation) (*MutationResult, error) { + return &MutationResult{Modified: false}, nil + }, + } + + chain := NewChain(WithChainLogger(slog.Default())) + chain.Add(failOpenMutator).Add(passingMutator) + + inv := &Invocation{ + Event: "test/event", + Phase: PhaseResponse, + Payload: &stubPayload{Value: "hello"}, + } + + cr, err := chain.ExecuteForSending(context.Background(), inv) + require.NoError(t, err) + + // Chain should succeed — fail-open doesn't abort. + assert.Equal(t, ChainSuccess, cr.Status) + assert.Empty(t, cr.AbortedAt) + + // Both interceptors should have an ExecutionResult. + require.Len(t, cr.Results, 2) + var found bool + for _, r := range cr.Results { + if r.Interceptor == "fo-mutator" { + found = true + assert.Equal(t, TypeMutation, r.Type) + assert.Equal(t, "transient failure", r.Error) + } + } + assert.True(t, found, "fail-open mutator should have an ExecutionResult") + }) +} diff --git a/go/sdk/interceptors/chain_validate.go b/go/sdk/interceptors/chain_validate.go new file mode 100644 index 0000000..e223e90 --- /dev/null +++ b/go/sdk/interceptors/chain_validate.go @@ -0,0 +1,191 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "sync" + "time" +) + +// validatorResult holds the output of a single validator handler call, +// separating execution from result recording so the handler can run +// without holding a lock. +type validatorResult struct { + result *ValidationResult + err error + durationMs int64 + timedOut bool +} + +// runValidators runs all validators concurrently, collects their results, and +// records any abort conditions. All validators run to completion +// before any abort decision is made — this ensures the ChainResult contains a +// complete picture of all validation findings, not just the first failure. +// +// A mutex guards all writes to the shared ChainResult (Results, AbortedAt, +// ValidationSummary) since goroutines append concurrently. +// +// Error handling per validator: +// - Handler returns error + FailOpen=false (fail-closed): records an +// ExecutionResult (for timing/audit) and an AbortInfo to block the chain. +// - Handler returns error + FailOpen=true: logs and records an +// ExecutionResult (for observability) but does not abort. +// - Handler returns a result with Valid=false + Mode!=Audit (enforced): +// scans messages for the first error-severity entry and records an AbortInfo. +// - Handler returns a result with Valid=false + Mode=Audit: records normally +// without aborting (audit-only observation). +func (ce *chainExecutor) runValidators(ctx context.Context, inv *Invocation, cr *ChainResult) { + if len(ce.validators) == 0 { + return + } + + // Fast path: single validator doesn't need goroutine, WaitGroup, or Mutex. + if len(ce.validators) == 1 { + v := ce.validators[0] + vr := ce.executeValidator(ctx, v, inv) + ce.recordValidation(v, inv, cr, vr) + return + } + + var ( + mu sync.Mutex + wg sync.WaitGroup + ) + + for _, v := range ce.validators { + wg.Add(1) + go func(v *Validator) { + defer wg.Done() + + vr := ce.executeValidator(ctx, v, inv) + + mu.Lock() + ce.recordValidation(v, inv, cr, vr) + mu.Unlock() + }(v) + } + + wg.Wait() +} + +// executeValidator runs a single validator handler and returns its output. +// It does not modify ChainResult, so it is safe to call without a lock. +func (ce *chainExecutor) executeValidator(ctx context.Context, v *Validator, inv *Invocation) validatorResult { + if v.Handler == nil { + ce.logger.Warn("validator has nil handler, skipping", + "interceptor", v.Name, + ) + return validatorResult{} + } + + // Apply per-interceptor timeout from metadata. + handlerCtx := ctx + if v.TimeoutMs > 0 { + var cancel context.CancelFunc + handlerCtx, cancel = context.WithTimeout(ctx, time.Duration(v.TimeoutMs)*time.Millisecond) + defer cancel() + } + + vStart := time.Now() + result, err := v.Handler(handlerCtx, inv) + dur := time.Since(vStart).Milliseconds() + + return validatorResult{ + result: result, + err: err, + durationMs: dur, + timedOut: handlerCtx.Err() == context.DeadlineExceeded, + } +} + +// recordValidation writes a validator's execution outcome into the +// ChainResult. The caller must ensure exclusive access to cr (either +// by holding a lock or being the only writer in the N=1 fast path). +func (ce *chainExecutor) recordValidation(v *Validator, inv *Invocation, cr *ChainResult, vr validatorResult) { + if vr.err != nil { + // Distinguish timeout errors from general validation errors + // so downstream consumers can differentiate root cause. + abortType := AbortValidation + if vr.timedOut { + abortType = AbortTimeout + } + ce.logger.Warn("validator error", + "interceptor", v.Name, + "error", vr.err, + ) + // Always record the execution result for observability, + // regardless of fail-open/fail-closed. + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: v.Name, + Type: TypeValidation, + Phase: inv.Phase, + DurationMs: vr.durationMs, + Error: vr.err.Error(), + }) + // Fail-closed: record an abort entry to halt the chain. + // Fail-open: the execution result above is the audit trail; + // no abort is recorded. + if !v.FailOpen { + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Interceptor: v.Name, + Reason: vr.err.Error(), + Type: abortType, + Phase: string(inv.Phase), + }) + } + return + } + + // Guard against handlers that return (nil, nil). + if vr.result == nil { + return + } + + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: v.Name, + Type: TypeValidation, + Phase: inv.Phase, + DurationMs: vr.durationMs, + Validation: vr.result, + }) + + // Tally validation summary. + for _, msg := range vr.result.Messages { + switch msg.Severity { + case SeverityError: + cr.ValidationSummary.Errors++ + case SeverityWarn: + cr.ValidationSummary.Warnings++ + case SeverityInfo: + cr.ValidationSummary.Infos++ + } + } + + // Only error-severity messages cause an abort. A validator + // returning Valid=false with only warn/info messages does NOT + // block the chain — the findings are recorded in the + // ValidationSummary and ExecutionResult but execution continues. + // + // In audit mode (ModeAudit), even error-severity messages don't + // abort — the result is recorded for observability only. + // + // When multiple error messages exist, only the first is recorded + // as the abort reason. All messages remain visible in the + // ValidationResult attached to the ExecutionResult. + if v.Mode != ModeAudit && !vr.result.Valid { + for _, msg := range vr.result.Messages { + if msg.Severity == SeverityError { + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Interceptor: v.Name, + Reason: msg.Message, + Type: AbortValidation, + Phase: string(inv.Phase), + }) + break + } + } + } +} diff --git a/go/sdk/interceptors/doc.go b/go/sdk/interceptors/doc.go new file mode 100644 index 0000000..a7c2493 --- /dev/null +++ b/go/sdk/interceptors/doc.go @@ -0,0 +1,158 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +// Package interceptors provides a protocol-agnostic validation and +// mutation middleware framework. It defines interceptor types +// ([Validator], [Mutator]), the chain engine ([Chain]), and all +// supporting types needed to build interceptor pipelines for any +// server protocol. +// +// This package has no dependency on MCP or any specific transport. +// For MCP-specific integration (wrapping an [mcp.Server] with +// middleware), see the [interceptors/mcpserver] sub-package. +// +// # Standalone Usage (any protocol) +// +// Create a [Chain], register interceptors, and call +// [Chain.ExecuteForReceiving] or [Chain.ExecuteForSending] from your +// server's request/response pipeline: +// +// chain := interceptors.NewChain( +// interceptors.WithChainLogger(logger), +// ) +// chain.Add(myValidator) +// chain.Add(myMutator) +// +// // In your request handler: +// inv := &interceptors.Invocation{ +// Event: "tools/call", +// Phase: interceptors.PhaseRequest, +// Payload: typedParams, +// } +// cr, err := chain.ExecuteForReceiving(ctx, inv) +// if err != nil { ... } +// if cr != nil && len(cr.AbortedAt) > 0 { +// // handle abort +// } +// +// # Validators +// +// A [Validator] inspects the typed payload and decides whether the +// request or response should proceed. All validators for a given +// event run in parallel; because they share the same [Invocation] +// pointer, handlers MUST treat the Invocation and its Payload as +// read-only — mutating either is a data race. If any validator in +// enforced mode ([ModeOn]) returns an error-severity message, the +// chain aborts before any mutators run. Only error-severity messages +// cause an abort; warn and info findings are recorded in the +// [ChainResult] but do not block the chain. +// +// Type-assert the payload to its concrete type: +// +// v := &interceptors.Validator{ +// Metadata: interceptors.Metadata{ +// Name: "block-dangerous-tool", +// Events: []string{"tools/call"}, +// Phase: interceptors.PhaseRequest, +// Mode: interceptors.ModeOn, +// }, +// Handler: func(ctx context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { +// params, ok := inv.Payload.(*MyRequestParams) +// if !ok { +// return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +// } +// // inspect params ... +// return &interceptors.ValidationResult{Valid: true}, nil +// }, +// } +// +// # Mutators +// +// A [Mutator] transforms the payload in place. Mutators run sequentially +// in priority order (see [Priority]). Each mutator receives the typed +// value and can modify it directly. If any mutator fails (and is not +// configured with FailOpen), the chain aborts. FailOpen mutators +// record an [ExecutionResult] (with the error captured) for +// observability but do not block. +// +// Type-assert the payload and modify the value in place: +// +// m := &interceptors.Mutator{ +// Metadata: interceptors.Metadata{ +// Name: "redact-pii", +// Events: []string{"tools/call"}, +// Phase: interceptors.PhaseResponse, +// Mode: interceptors.ModeOn, +// }, +// Handler: func(ctx context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { +// result, ok := inv.Payload.(*MyResponseResult) +// if !ok { +// return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +// } +// // modify result in place ... +// return &interceptors.MutationResult{Modified: true}, nil +// }, +// } +// +// # Execution Order +// +// The chain execution order depends on direction: +// +// Request phase (untrusted → trusted): +// +// Validate (parallel) → Mutate (sequential) +// +// Response phase (trusted → untrusted): +// +// Mutate (sequential) → Validate (parallel) +// +// Validators act as a security gate on the trust boundary side, +// while mutators prepare or sanitize data on the other side. +// +// # Modes and FailOpen +// +// Each interceptor has a [Mode] that controls what happens with +// successful results, and a FailOpen flag that controls what happens +// when the handler returns a Go error. These are orthogonal: +// +// - [ModeOn]: fully enforced — validation failures block, mutations +// apply in place. +// - [ModeAudit]: the handler runs and results are recorded, but +// validation findings do not block and mutations run on a +// deep-copied payload so the real data is unaffected. +// - [ModeOff]: the interceptor is skipped entirely. +// +// FailOpen (default false) controls crash resilience: +// +// - FailOpen=false: a handler error aborts the chain. An +// [ExecutionResult] (with Error populated) and an [AbortInfo] +// are both recorded. +// - FailOpen=true: a handler error is logged and an +// [ExecutionResult] is recorded, but the chain continues. +// +// Note that [ModeAudit] does NOT imply FailOpen. Audit mode only +// suppresses enforcement of successful results (validation findings +// and mutations). If the handler itself returns an error and +// FailOpen is false, the chain still aborts. For truly safe +// observation-only interceptors, set both ModeAudit and FailOpen: +// +// Metadata: interceptors.Metadata{ +// Mode: interceptors.ModeAudit, +// FailOpen: true, +// } +// +// Behavior matrix for validators: +// +// Mode=On, FailOpen=false → error aborts, Valid=false+SeverityError aborts +// Mode=On, FailOpen=true → error continues, Valid=false+SeverityError aborts +// ModeAudit, FailOpen=false → error aborts, findings recorded only +// ModeAudit, FailOpen=true → error continues, findings recorded only +// +// Behavior matrix for mutators: +// +// Mode=On, FailOpen=false → error aborts, mutations applied in place +// Mode=On, FailOpen=true → error continues, mutations applied in place +// ModeAudit, FailOpen=false → error aborts, mutations recorded (deep copy) +// ModeAudit, FailOpen=true → error continues, mutations recorded (deep copy) +package interceptors diff --git a/go/sdk/interceptors/interceptor.go b/go/sdk/interceptors/interceptor.go index 3cac4e1..483753e 100644 --- a/go/sdk/interceptors/interceptor.go +++ b/go/sdk/interceptors/interceptor.go @@ -1,6 +1,163 @@ // Copyright 2025 The MCP Interceptors Authors. All rights reserved. -// Use of this source code is governed by a Apache-2.0 +// Use of this source code is governed by an Apache-2.0 // license that can be found in the LICENSE file. -// Package interceptors implements MCP Interceptors based on [SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763). package interceptors + +import ( + "context" + "encoding/json" +) + +// Phase determines when an interceptor runs. +type Phase string + +const ( + PhaseRequest Phase = "request" + PhaseResponse Phase = "response" + PhaseBoth Phase = "both" +) + +// Mode controls enforcement behavior. +type Mode string + +const ( + ModeOn Mode = "on" // Enforced: validation failures block, mutations apply + ModeAudit Mode = "audit" // Audit: log results but don't block or apply mutations + ModeOff Mode = "off" // Disabled +) + +// InterceptorType identifies the category of an interceptor. +type InterceptorType string + +const ( + TypeValidation InterceptorType = "validation" + TypeMutation InterceptorType = "mutation" +) + +// Priority represents an interceptor's ordering hint. +// Can be a single value (applies to both phases) or per-phase. +// +// JSON representation is polymorphic: a single number when both +// phases are equal, or {"request": N, "response": N} when they differ. +type Priority struct { + Request int + Response int +} + +// NewPriority creates a Priority with the same value for both phases. +func NewPriority(v int) Priority { + return Priority{Request: v, Response: v} +} + +// Resolve returns the priority for the given phase. +func (p Priority) Resolve(phase Phase) int { + if phase == PhaseResponse { + return p.Response + } + return p.Request +} + +// MarshalJSON implements polymorphic serialization: +// emits a single number when both phases are equal, or an object otherwise. +func (p Priority) MarshalJSON() ([]byte, error) { + if p.Request == p.Response { + return json.Marshal(p.Request) + } + return json.Marshal(struct { + Request int `json:"request,omitempty"` + Response int `json:"response,omitempty"` + }{p.Request, p.Response}) +} + +// UnmarshalJSON handles both number and {request, response} forms. +func (p *Priority) UnmarshalJSON(data []byte) error { + var n int + if err := json.Unmarshal(data, &n); err == nil { + p.Request = n + p.Response = n + return nil + } + var obj struct { + Request int `json:"request"` + Response int `json:"response"` + } + if err := json.Unmarshal(data, &obj); err != nil { + return err + } + p.Request = obj.Request + p.Response = obj.Response + return nil +} + +// Severity represents validation message severity. +type Severity string + +const ( + SeverityInfo Severity = "info" + SeverityWarn Severity = "warn" + SeverityError Severity = "error" // Only error blocks execution +) + +// Compat represents protocol version compatibility. +type Compat struct { + MinProtocol string `json:"minProtocol"` + MaxProtocol string `json:"maxProtocol,omitempty"` +} + +// Metadata holds all common interceptor metadata. +type Metadata struct { + Name string `json:"name"` + Version string `json:"version,omitempty"` + Description string `json:"description,omitempty"` + Events []string `json:"events"` + Phase Phase `json:"phase"` + PriorityHint Priority `json:"priorityHint,omitempty"` + Compat *Compat `json:"compat,omitempty"` + ConfigSchema json.RawMessage `json:"configSchema,omitempty"` + + // TODO: Therese are essentially deployment configs?, so we might want to move them into a separate struct + // and have the user configure them separately? + Mode Mode `json:"mode"` + FailOpen bool `json:"failOpen,omitempty"` + TimeoutMs int64 `json:"timeoutMs,omitempty"` +} + +// Interceptor is the common interface for all interceptors. It is implemented by both Validator and Mutator. +type Interceptor interface { + GetMetadata() *Metadata + GetType() InterceptorType +} + +// --- ValidatorHandler --- + +// ValidatorHandler is the function signature for validation handlers. +// +// Handlers MUST treat the Invocation and its Payload as read-only. +// Multiple validators for the same event run concurrently and share +// the same Invocation pointer, so any mutation of the Payload (or +// other Invocation fields) is a data race. +type ValidatorHandler func(ctx context.Context, inv *Invocation) (*ValidationResult, error) + +// Validator is a validation interceptor. +type Validator struct { + Metadata + Handler ValidatorHandler +} + +func (v *Validator) GetMetadata() *Metadata { return &v.Metadata } +func (v *Validator) GetType() InterceptorType { return TypeValidation } + +// --- MutatorHandler --- + +// MutatorHandler is the function signature for raw mutation handlers. +type MutatorHandler func(ctx context.Context, inv *Invocation) (*MutationResult, error) + +// Mutator is a mutation interceptor. +type Mutator struct { + Metadata + Handler MutatorHandler +} + +func (m *Mutator) GetMetadata() *Metadata { return &m.Metadata } +func (m *Mutator) GetType() InterceptorType { return TypeMutation } diff --git a/go/sdk/interceptors/interceptor_test.go b/go/sdk/interceptors/interceptor_test.go deleted file mode 100644 index fb110b9..0000000 --- a/go/sdk/interceptors/interceptor_test.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2025 The MCP Interceptors Authors. All rights reserved. -// Use of this source code is governed by a Apache-2.0 -// license that can be found in the LICENSE file. - -package interceptors - -import "testing" - -func TestPlaceholder(t *testing.T) { - // Placeholder test to ensure go test runs successfully - if false { - t.Error("This should not happen") - } -} diff --git a/go/sdk/interceptors/invocation.go b/go/sdk/interceptors/invocation.go new file mode 100644 index 0000000..6d1e7df --- /dev/null +++ b/go/sdk/interceptors/invocation.go @@ -0,0 +1,80 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "encoding/json" + "fmt" + "reflect" +) + +// Invocation is the context passed to every interceptor handler. +type Invocation struct { + Event string // e.g. "tools/call" + Phase Phase // "request" or "response" + Payload any // The typed payload (use type assertion directly) + Config map[string]any // Per-invocation config + Context *InvocationContext // Optional caller context (identity, trace, etc.) + Session any // The server session; nil for protocol-level invocations + + mutatedParams any // set via SetMutatedParams in response phase only +} + +// MutatedParams returns the request params after request-phase mutators +// have run. Only available in the response phase; returns nil in the +// request phase. +func (inv *Invocation) MutatedParams() any { + return inv.mutatedParams +} + +// SetMutatedParams sets the mutated request params on the invocation. +// This is intended for server integrations that build response-phase +// invocations; interceptor handlers should use MutatedParams() to read. +func (inv *Invocation) SetMutatedParams(p any) { + inv.mutatedParams = p +} + +// withCopiedPayload returns a shallow copy of the Invocation with a +// deep-copied Payload. Used for audit-mode mutators so their in-place +// modifications don't affect the real struct. +// +// The clone is a JSON round-trip: marshal the typed value, then +// unmarshal into a fresh instance of the same concrete type via +// reflect.New. The payload must be a pointer. +func (inv *Invocation) withCopiedPayload() (*Invocation, error) { + raw, err := json.Marshal(inv.Payload) + if err != nil { + return nil, fmt.Errorf("clone payload: marshal: %w", err) + } + cp := reflect.New(reflect.TypeOf(inv.Payload).Elem()) + if err := json.Unmarshal(raw, cp.Interface()); err != nil { + return nil, fmt.Errorf("clone payload: unmarshal: %w", err) + } + return &Invocation{ + Event: inv.Event, + Phase: inv.Phase, + Payload: cp.Interface(), + mutatedParams: inv.mutatedParams, + Config: inv.Config, + Context: inv.Context, + Session: inv.Session, + }, nil +} + +// InvocationContext holds optional context passed to interceptors. +type InvocationContext struct { + Principal *Principal `json:"principal,omitempty"` + TraceID string `json:"traceId,omitempty"` + SpanID string `json:"spanId,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + SessionID string `json:"sessionId,omitempty"` +} + +// Principal identifies the caller. +type Principal struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + Claims map[string]any `json:"claims,omitempty"` +} diff --git a/go/sdk/interceptors/mcpserver/doc.go b/go/sdk/interceptors/mcpserver/doc.go new file mode 100644 index 0000000..3e02d80 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/doc.go @@ -0,0 +1,62 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +// Package mcpserver integrates the [interceptors] chain engine with +// [mcp.Server] from the go-sdk. It installs interceptors as +// receiving middleware, so they run for all transports (stdio, HTTP) +// and compose with other extensions like variants. +// +// # Getting Started +// +// Wrap an existing [mcp.Server] with [NewServer], register one or +// more interceptors, then start the server as usual: +// +// srv := mcpserver.NewServer(mcpServer, +// mcpserver.WithLogger(logger), +// mcpserver.WithContextProvider(myProvider), +// ) +// srv.AddInterceptor(myValidator) +// srv.AddInterceptor(myMutator) +// srv.Run(ctx, transport) +// +// # Middleware Ordering +// +// [NewServer] calls [mcp.Server.AddReceivingMiddleware] to install the +// interceptor chain. The go-sdk stacks middleware so that the first +// middleware added runs outermost. This means ordering depends on +// when other middleware is registered relative to NewServer: +// +// // auth runs BEFORE interceptors (outermost) +// mcpServer.AddReceivingMiddleware(authMiddleware) +// srv := mcpserver.NewServer(mcpServer) +// // logging runs AFTER interceptors (innermost) +// mcpServer.AddReceivingMiddleware(loggingMiddleware) +// +// Execution order: auth → interceptors → logging → handler +// +// # Context Provider +// +// Use [WithContextProvider] to populate [interceptors.InvocationContext] +// for every intercepted request. This is typically used to extract +// caller identity from OAuth tokens or session metadata: +// +// mcpserver.WithContextProvider( +// func(ctx context.Context, req mcp.Request) *interceptors.InvocationContext { +// return &interceptors.InvocationContext{ +// Principal: &interceptors.Principal{Type: "user", ID: "alice"}, +// } +// }, +// ) +// +// Handlers access this via [interceptors.Invocation].Context. +// +// # Transports +// +// The interceptor server works with any transport supported by the +// go-sdk. For stdio, use [Server.Run]. For HTTP, use +// [NewStreamableHTTPHandler]: +// +// handler := mcpserver.NewStreamableHTTPHandler(srv, nil) +// http.ListenAndServe(":8080", handler) +package mcpserver diff --git a/go/sdk/interceptors/mcpserver/events.go b/go/sdk/interceptors/mcpserver/events.go new file mode 100644 index 0000000..a5d4809 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/events.go @@ -0,0 +1,17 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package mcpserver + +// Event name constants for standard MCP methods. +const ( + // Server Features + EventToolsList = "tools/list" + EventToolsCall = "tools/call" + EventPromptsList = "prompts/list" + EventPromptsGet = "prompts/get" + EventResourcesList = "resources/list" + EventResourcesRead = "resources/read" + EventResourcesSubscribe = "resources/subscribe" +) diff --git a/go/sdk/interceptors/mcpserver/server.go b/go/sdk/interceptors/mcpserver/server.go new file mode 100644 index 0000000..9c82227 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/server.go @@ -0,0 +1,366 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package mcpserver + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" +) + +// extensionID is the capability key used in Capabilities.Experimental +const extensionID = "io.modelcontextprotocol/interceptors" + +// ServerOption configures a Server. +type ServerOption func(*Server) + +// WithLogger sets the logger for interceptor chain execution. +// If not set, slog.Default() is used. +func WithLogger(l *slog.Logger) ServerOption { + return func(s *Server) { + s.logger = l + } +} + +// ContextProviderFunc extracts an InvocationContext from an incoming MCP +// request. This is called once per intercepted request and the result is +// passed to all interceptor handlers via Invocation.Context. +// +// Typical use: extract principal identity from OAuth tokens available on +// the request's session (via RequestExtra.TokenInfo). +type ContextProviderFunc func(ctx context.Context, req mcp.Request) *interceptors.InvocationContext + +// WithContextProvider sets a function that populates InvocationContext +// for every intercepted request. Without this, Invocation.Context is nil. +func WithContextProvider(f ContextProviderFunc) ServerOption { + return func(s *Server) { + s.contextProvider = f + } +} + +// Server wraps an *mcp.Server with interceptor support. Interceptors are +// installed as middleware on the inner server at construction time, so they +// work across all transports (stdio, HTTP) and compose with other extensions +// like variants. +type Server struct { + inner *mcp.Server + chain *interceptors.Chain + logger *slog.Logger + contextProvider ContextProviderFunc +} + +// NewServer creates a new interceptor-aware server wrapping the given +// mcp.Server. Middleware is installed immediately, so interceptors added +// via AddInterceptor will be active for all future requests. +func NewServer(server *mcp.Server, opts ...ServerOption) *Server { + s := &Server{ + inner: server, + } + for _, opt := range opts { + opt(s) + } + if s.logger == nil { + s.logger = slog.Default() + } + s.chain = interceptors.NewChain(interceptors.WithChainLogger(s.logger)) + server.AddReceivingMiddleware(s.receivingMiddleware()) + return s +} + +// AddInterceptor registers an interceptor. It panics if the interceptor +// has a nil handler. It is safe to call while the server is running. +// Returns the receiver for chaining. +func (s *Server) AddInterceptor(i interceptors.Interceptor) *Server { + s.chain.Add(i) + return s +} + +// MCPServer returns the underlying *mcp.Server. This is useful when +// composing with other extensions (e.g., variants): +// +// is := mcpserver.NewServer(mcpServer) +// is.AddInterceptor(myValidator) +// vs.WithVariant(variant, is.MCPServer(), 0) +func (s *Server) MCPServer() *mcp.Server { + return s.inner +} + +// Run delegates to the inner server's Run method. Convenience for standalone +// use (e.g., stdio). +func (s *Server) Run(ctx context.Context, t mcp.Transport) error { + return s.inner.Run(ctx, t) +} + +// receivingMiddleware returns mcp.Middleware that intercepts incoming +// requests and outgoing responses. +func (s *Server) receivingMiddleware() mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // When handling the "initialize" method, we enrich the result with interceptor + // capability metadata. This allows clients to discover interceptor support and + // capabilities + if method == "initialize" { + result, err := next(ctx, method, req) + if err != nil { + return nil, err + } + return s.enrichInitResult(result) + } + + event := method + + // Fast path: no interceptors match this event. + if s.chain.IsEmpty(event, interceptors.PhaseRequest) && s.chain.IsEmpty(event, interceptors.PhaseResponse) { + return next(ctx, method, req) + } + + var session *mcp.ServerSession + if sess, ok := req.GetSession().(*mcp.ServerSession); ok { + session = sess + } + + // Extract invocation context from the request using the context provider, if set. + // This allows interceptor handlers to access caller identity and other contextual info. + var invCtx *interceptors.InvocationContext + if s.contextProvider != nil { + invCtx = s.contextProvider(ctx, req) + } + + // Run request-phase interceptors. + if !s.chain.IsEmpty(event, interceptors.PhaseRequest) { + if err := s.interceptRequest(ctx, req, event, session, invCtx); err != nil { + return nil, err + } + } + + // Call the next handler. + result, err := next(ctx, method, req) + if err != nil { + return result, err + } + + // Run response-phase interceptors. + if !s.chain.IsEmpty(event, interceptors.PhaseResponse) { + return s.interceptResponse(ctx, req, result, event, session, invCtx) + } + return result, nil + } + } +} + +// interceptRequest runs the request-phase interceptor chain. Mutators +// modify the typed params in place, so no marshal/unmarshal is needed. +func (s *Server) interceptRequest( + ctx context.Context, + req mcp.Request, + event string, + session *mcp.ServerSession, + invCtx *interceptors.InvocationContext, +) error { + inv := &interceptors.Invocation{ + Event: event, + Phase: interceptors.PhaseRequest, + Payload: req.GetParams(), + Session: session, + Context: invCtx, + } + cr, err := s.chain.ExecuteForReceiving(ctx, inv) + if err != nil { + return err + } + if cr != nil && len(cr.AbortedAt) > 0 { + s.logAborts("request", event, cr.AbortedAt) + return abortToJSONRPCError(cr.AbortedAt) + } + return nil +} + +// interceptResponse runs the response-phase interceptor chain. Mutators +// modify the typed result in place, so no marshal/unmarshal is needed. +func (s *Server) interceptResponse( + ctx context.Context, + req mcp.Request, + result mcp.Result, + event string, + session *mcp.ServerSession, + invCtx *interceptors.InvocationContext, +) (mcp.Result, error) { + inv := &interceptors.Invocation{ + Event: event, + Phase: interceptors.PhaseResponse, + Payload: result, + Session: session, + Context: invCtx, + } + inv.SetMutatedParams(req.GetParams()) + cr, err := s.chain.ExecuteForSending(ctx, inv) + if err != nil { + return nil, err + } + if cr != nil && len(cr.AbortedAt) > 0 { + s.logAborts("response", event, cr.AbortedAt) + return nil, abortToJSONRPCError(cr.AbortedAt) + } + + // Result was already modified in place by mutators — no unmarshal needed. + return result, nil +} + +// logAborts logs each abort entry at warn level. +func (s *Server) logAborts(phase, event string, aborts []interceptors.AbortInfo) { + for _, a := range aborts { + s.logger.Warn("interceptors: "+phase+" aborted", + "event", event, + "interceptor", a.Interceptor, + "reason", a.Reason, + ) + } +} + +// enrichInitResult injects interceptor capability metadata into the +// InitializeResult, following the same pattern as the variants extension. +// The capability is declared under Capabilities.Experimental so it works +// without upstream go-sdk changes. +func (s *Server) enrichInitResult(result mcp.Result) (mcp.Result, error) { + initResult, ok := result.(*mcp.InitializeResult) + if !ok { + return result, nil + } + + if initResult.Capabilities == nil { + initResult.Capabilities = &mcp.ServerCapabilities{} + } + if initResult.Capabilities.Experimental == nil { + initResult.Capabilities.Experimental = make(map[string]any) + } + + // Build the interceptor list in wire format. + // wireInterceptor embeds Metadata (which carries all JSON tags) and + // adds the type field from GetType(). + type wireInterceptor struct { + *interceptors.Metadata + Type interceptors.InterceptorType `json:"type"` + } + all := s.chain.Interceptors() + interceptorInfos := make([]wireInterceptor, 0, len(all)) + supportedEvents := map[string]bool{} + for _, ri := range all { + meta := ri.GetMetadata() + interceptorInfos = append(interceptorInfos, wireInterceptor{ + Metadata: meta, + Type: ri.GetType(), + }) + for _, e := range meta.Events { + supportedEvents[e] = true + } + } + + events := make([]string, 0, len(supportedEvents)) + for e := range supportedEvents { + events = append(events, e) + } + + initResult.Capabilities.Experimental[extensionID] = map[string]any{ + "supportedEvents": events, + "interceptors": interceptorInfos, + } + + return initResult, nil +} + +// NewStreamableHTTPHandler returns a new [mcp.StreamableHTTPHandler] for +// serving multiple concurrent clients over HTTP. It mirrors +// [mcp.NewStreamableHTTPHandler]. +// +// handler := mcpserver.NewStreamableHTTPHandler(srv, nil) +// http.ListenAndServe(":8080", handler) +func NewStreamableHTTPHandler(s *Server, opts *mcp.StreamableHTTPOptions) *mcp.StreamableHTTPHandler { + if s == nil { + panic("mcpserver: nil Server") + } + srv := s.MCPServer() + return mcp.NewStreamableHTTPHandler( + func(r *http.Request) *mcp.Server { return srv }, + opts, + ) +} + +// --- abort conversion --- + +// Error data structs for abortToJSONRPCError. + +type validationErrorData struct { + ValidationErrors []interceptors.AbortInfo `json:"validationErrors"` +} + +type mutationErrorData struct { + FailedInterceptor string `json:"failedInterceptor"` + Reason string `json:"reason"` +} + +type timeoutErrorData struct { + Interceptor string `json:"interceptor"` + TimeoutMs int64 `json:"timeoutMs,omitempty"` + Phase string `json:"phase"` +} + +// codeServerError is the JSON-RPC error code for timeout. +const codeServerError = -32000 + +// abortToJSONRPCError converts a list of abort infos into a *jsonrpc.Error +// with the correct error code and structured data: +// - Validation failures: -32602 (InvalidParams) with validationErrors data +// - Mutation failures: -32603 (InternalError) with failedInterceptor data +func abortToJSONRPCError(aborts []interceptors.AbortInfo) *jsonrpc.Error { + first := aborts[0] + + // Determine the error code from the abort type. + var code int64 + switch first.Type { + case interceptors.AbortValidation: + code = jsonrpc.CodeInvalidParams // -32602 + case interceptors.AbortTimeout: + code = codeServerError // -32000 + default: + code = jsonrpc.CodeInternalError // -32603 + } + + // Build structured error data. + var data json.RawMessage + switch first.Type { + case interceptors.AbortValidation: + d, _ := json.Marshal(validationErrorData{ValidationErrors: aborts}) + data = d + case interceptors.AbortTimeout: + d, _ := json.Marshal(timeoutErrorData{ + Interceptor: first.Interceptor, + Phase: first.Phase, + }) + data = d + default: + d, _ := json.Marshal(mutationErrorData{ + FailedInterceptor: first.Interceptor, + Reason: first.Reason, + }) + data = d + } + + msg := "interceptor abort: " + first.Reason + if len(aborts) > 1 { + msg = "interceptor abort: multiple validation failures" + } + + return &jsonrpc.Error{ + Code: code, + Message: msg, + Data: data, + } +} diff --git a/go/sdk/interceptors/mcpserver/server_integration_test.go b/go/sdk/interceptors/mcpserver/server_integration_test.go new file mode 100644 index 0000000..a1283b6 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/server_integration_test.go @@ -0,0 +1,577 @@ +package mcpserver_test + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +func TestServer_ValidatorRejectsBlockedTool(t *testing.T) { + t.Parallel() + cs := setup(t, blockToolValidator("echo")) + + _, err := callEcho(t, cs) + assert.Error(t, err) +} + +func TestServer_ValidatorAllowsTool(t *testing.T) { + t.Parallel() + cs := setup(t, allowAllValidator("allow-all")) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(resultText(t, result), "echo:")) +} + +func TestServer_MutatorModifiesResponse(t *testing.T) { + t.Parallel() + cs := setup(t, prefixMutator("prefix", "[mutated]", interceptors.PhaseResponse, 0, interceptors.ModeOn)) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(resultText(t, result), "[mutated] echo:")) +} + +func TestServer_ChainedMutatorsWithAuditMode(t *testing.T) { + t.Parallel() + // 5 response mutators in priority order. Mutator 3 (priority 30) is + // audit-mode: it runs on a deep copy, so its "[AUDIT]" prefix must + // NOT appear in the final output. The other four are enforced and + // each prepend a tag. Expected final text (applied innermost-first + // by ascending priority): + // + // [M5] [M4] [M2] [M1] echo: ... + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + prefixMutator("m2", "[M2]", interceptors.PhaseResponse, 20, interceptors.ModeOn), + prefixMutator("m3-audit", "[AUDIT]", interceptors.PhaseResponse, 30, interceptors.ModeAudit), + prefixMutator("m4", "[M4]", interceptors.PhaseResponse, 40, interceptors.ModeOn), + prefixMutator("m5", "[M5]", interceptors.PhaseResponse, 50, interceptors.ModeOn), + ) + + result, err := callEcho(t, cs) + require.NoError(t, err) + + text := resultText(t, result) + assert.True(t, strings.HasPrefix(text, "[M5] [M4] [M2] [M1] echo:")) + assert.NotContains(t, text, "[AUDIT]") +} + +func TestServer_MutatorFailureAbortsChain(t *testing.T) { + t.Parallel() + // 3 response mutators. Mutator 2 (priority 20) returns an error + // and is fail-closed. The chain aborts; client receives an error. + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + failMutator("m2-fail", interceptors.PhaseResponse, 20), + prefixMutator("m3", "[M3]", interceptors.PhaseResponse, 30, interceptors.ModeOn), + ) + + _, err := callEcho(t, cs) + assert.Error(t, err) +} + +func TestServer_ValidatorWarnDoesNotBlock(t *testing.T) { + t.Parallel() + // A validator returns Valid=false with only SeverityWarn messages. + // The chain should continue — warn-only failures don't abort. + warnValidator := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "warn-only", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "something looks off", Severity: interceptors.SeverityWarn}, + }, + }, nil + }, + } + + cs := setup(t, warnValidator) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(resultText(t, result), "echo:")) +} + +func TestServer_ValidationFailurePreventsM(t *testing.T) { + t.Parallel() + // A validator rejects the request. A mutator is also registered but + // should never run because executeForReceiving returns at the + // ChainValidationFailed check before reaching runMutators. + mutatorRan := false + spy := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "spy-mutator", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + mutatorRan = true + return &interceptors.MutationResult{Modified: false}, nil + }, + } + + cs := setup(t, blockToolValidator("echo"), spy) + + _, err := callEcho(t, cs) + assert.Error(t, err) + assert.False(t, mutatorRan, "mutator should not run after validation failure") +} + +func TestServer_ValidatorTimeoutAbortsChain(t *testing.T) { + t.Parallel() + slowValidator := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "slow-validator", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + TimeoutMs: 1, + }, + Handler: func(ctx context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(100 * time.Millisecond): + return &interceptors.ValidationResult{Valid: true}, nil + } + }, + } + + cs := setup(t, slowValidator) + + _, err := callEcho(t, cs) + require.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestServer_MutatorTimeoutAbortsChain(t *testing.T) { + t.Parallel() + // 3 response mutators. Mutator 2 (priority 20) blocks longer than + // its 1ms timeout. The chain aborts; client receives an error. + slowMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "m2-slow", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + TimeoutMs: 1, + PriorityHint: interceptors.NewPriority(20), + }, + Handler: func(ctx context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(100 * time.Millisecond): + return &interceptors.MutationResult{Modified: false}, nil + } + }, + } + + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + slowMutator, + prefixMutator("m3", "[M3]", interceptors.PhaseResponse, 30, interceptors.ModeOn), + ) + + _, err := callEcho(t, cs) + require.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestServer_MutatorFailOpenContinuesChain(t *testing.T) { + t.Parallel() + // 3 response mutators. Mutator 2 (priority 20) returns an error but + // has FailOpen set. The chain continues; M1 and M3 both apply. + failOpenMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "m2-failopen", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + FailOpen: true, + PriorityHint: interceptors.NewPriority(20), + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + return nil, fmt.Errorf("simulated failure") + }, + } + + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + failOpenMutator, + prefixMutator("m3", "[M3]", interceptors.PhaseResponse, 30, interceptors.ModeOn), + ) + + result, err := callEcho(t, cs) + require.NoError(t, err) + + text := resultText(t, result) + assert.Contains(t, text, "[M1]") + assert.Contains(t, text, "[M3]") +} + +func TestServer_ResponseMutatorSeesMutatedParams(t *testing.T) { + t.Parallel() + // A request mutator injects a marker into the arguments. A response + // mutator reads MutatedParams() and copies the marker into the + // response text. This proves MutatedParams() reflects request-phase + // mutations, not the original request. + reqMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "inject-marker", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + args["marker"] = "injected-by-request-mutator" + raw, err := json.Marshal(args) + if err != nil { + return nil, err + } + params.Arguments = raw + return &interceptors.MutationResult{Modified: true}, nil + }, + } + + respMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "read-marker", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + params, ok := inv.MutatedParams().(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected MutatedParams type %T", inv.MutatedParams()) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + marker, _ := args["marker"].(string) + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = fmt.Sprintf("[marker=%s] %s", marker, tc.Text) + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } + + cs := setup(t, reqMutator, respMutator) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.Contains(t, resultText(t, result), "[marker=injected-by-request-mutator]") +} + +func TestServer_CombinedValidatorsAndMutators(t *testing.T) { + t.Parallel() + // Exercises the full chain with 3 validators and 3 mutators on each + // side, verifying: + // + // Request: Validate (parallel) → Mutate (sequential by priority) + // Response: Mutate (sequential by priority) → Validate (parallel) + // + // Request mutators inject fields into arguments (reflected in echo output). + // Response mutators prepend tags to text (visible in final result). + // Atomic counters confirm every interceptor actually ran. + + var ( + reqValCount atomic.Int32 + reqMutCount atomic.Int32 + respMutCount atomic.Int32 + respValCount atomic.Int32 + ) + + // --- Request validators (3, parallel) --- + + // req-v1: passes if tool name is "echo". + reqV1 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "req-v1-tool-check", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + reqValCount.Add(1) + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + if params.Name != "echo" { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "only echo allowed", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // req-v2: passes if arguments contain "text". + reqV2 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "req-v2-args-check", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + reqValCount.Add(1) + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + if _, ok := args["text"]; !ok { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "missing text argument", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // req-v3: always passes with a warning (warn doesn't block). + reqV3 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "req-v3-warn", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + reqValCount.Add(1) + return &interceptors.ValidationResult{ + Valid: true, + Messages: []interceptors.ValidationMessage{ + {Message: "request noted", Severity: interceptors.SeverityWarn}, + }, + }, nil + }, + } + + // --- Request mutators (3, sequential by priority 10 → 20 → 30) --- + // Each adds a field to the arguments. The echo handler serializes + // the arguments map, so all three fields appear in the output. + + requestArgMutator := func(name, key, value string, priority int) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + reqMutCount.Add(1) + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + args[key] = value + raw, err := json.Marshal(args) + if err != nil { + return nil, err + } + params.Arguments = raw + return &interceptors.MutationResult{Modified: true}, nil + }, + } + } + + reqM1 := requestArgMutator("req-m1", "step1", "done", 10) + reqM2 := requestArgMutator("req-m2", "step2", "done", 20) + reqM3 := requestArgMutator("req-m3", "step3", "done", 30) + + // --- Response mutators (3, sequential by priority 10 → 20 → 30) --- + // Each prepends a tag. Applied innermost-first, so final text is: + // [RM3] [RM2] [RM1] echo: ... + + respPrefixMutator := func(name, tag string, priority int) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + respMutCount.Add(1) + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = tag + " " + tc.Text + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } + } + + respM1 := respPrefixMutator("resp-m1", "[RM1]", 10) + respM2 := respPrefixMutator("resp-m2", "[RM2]", 20) + respM3 := respPrefixMutator("resp-m3", "[RM3]", 30) + + // --- Response validators (3, parallel) --- + // Run after response mutators. resp-v2 verifies it sees the mutated + // text (tags applied by response mutators), which proves the + // Mutate → Validate ordering on the response side. + + // resp-v1: passes if response has content. + respV1 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "resp-v1-has-content", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + respValCount.Add(1) + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + if len(result.Content) == 0 { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "empty response", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // resp-v2: rejects if response text doesn't contain "[RM1]". + // This proves the response-side ordering (Mutate → Validate): if + // this validator ran before mutators, it wouldn't find the tag and + // would abort the chain, failing the test. + respV2 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "resp-v2-sees-mutation", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + respValCount.Add(1) + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + if strings.Contains(tc.Text, "[RM1]") { + return &interceptors.ValidationResult{Valid: true}, nil + } + } + } + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "response missing mutator tag [RM1]", Severity: interceptors.SeverityError}, + }, + }, nil + }, + } + + // resp-v3: always passes with an info message. + respV3 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "resp-v3-info", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + respValCount.Add(1) + return &interceptors.ValidationResult{ + Valid: true, + Messages: []interceptors.ValidationMessage{ + {Message: "response looks good", Severity: interceptors.SeverityInfo}, + }, + }, nil + }, + } + + cs := setup(t, + reqV1, reqV2, reqV3, + reqM1, reqM2, reqM3, + respM1, respM2, respM3, + respV1, respV2, respV3, + ) + + result, err := callEcho(t, cs) + require.NoError(t, err) + + text := resultText(t, result) + + // Response mutators applied in priority order (outermost = highest priority). + assert.True(t, strings.HasPrefix(text, "[RM3] [RM2] [RM1] echo:"), + "expected response mutator tags in priority order, got: %s", text) + + // Request mutators injected fields into arguments (visible in echo output). + assert.Contains(t, text, "step1") + assert.Contains(t, text, "step2") + assert.Contains(t, text, "step3") + + // Every interceptor ran exactly once. + assert.Equal(t, int32(3), reqValCount.Load(), "expected 3 request validators to run") + assert.Equal(t, int32(3), reqMutCount.Load(), "expected 3 request mutators to run") + assert.Equal(t, int32(3), respMutCount.Load(), "expected 3 response mutators to run") + assert.Equal(t, int32(3), respValCount.Load(), "expected 3 response validators to run") +} diff --git a/go/sdk/interceptors/mcpserver/testharness_test.go b/go/sdk/interceptors/mcpserver/testharness_test.go new file mode 100644 index 0000000..42200ea --- /dev/null +++ b/go/sdk/interceptors/mcpserver/testharness_test.go @@ -0,0 +1,191 @@ +package mcpserver_test + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +// --- Server setup --- + +// setup creates an HTTP test server with interceptor support and a connected +// client session. The server has a single "echo" tool registered for +// EventToolsCall. Cleanup is registered via t.Cleanup. +func setup(t *testing.T, is ...interceptors.Interceptor) *mcp.ClientSession { + t.Helper() + return setupWithTools(t, defaultTools(), is...) +} + +// setupWithTools creates an HTTP test server with the given tools and +// interceptors. Cleanup is registered via t.Cleanup. +func setupWithTools(t *testing.T, tools []testTool, is ...interceptors.Interceptor) *mcp.ClientSession { + t.Helper() + + mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "test-server", + Version: "0.1.0", + }, nil) + + for _, tool := range tools { + mcpServer.AddTool(tool.tool, tool.handler) + } + + srv := mcpserver.NewServer(mcpServer) + for _, i := range is { + srv.AddInterceptor(i) + } + + handler := mcpserver.NewStreamableHTTPHandler(srv, nil) + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, nil) + cs, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{ + Endpoint: httpServer.URL, + }, nil) + require.NoError(t, err) + t.Cleanup(func() { cs.Close() }) + + return cs +} + +// --- Tool definitions --- + +// testTool pairs a tool definition with its handler for use with setupWithTools. +type testTool struct { + tool *mcp.Tool + handler mcp.ToolHandler +} + +// defaultTools returns the standard "echo" tool used by most tests. +func defaultTools() []testTool { + return []testTool{ + { + tool: &mcp.Tool{ + Name: "echo", + Description: "echoes input", + InputSchema: map[string]any{"type": "object"}, + }, + handler: func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("echo: %s", req.Params.Arguments)}}, + }, nil + }, + }, + } +} + +// --- Call helpers --- + +// callEcho calls the "echo" tool and returns the result. +func callEcho(t *testing.T, cs *mcp.ClientSession) (*mcp.CallToolResult, error) { + t.Helper() + return cs.CallTool(context.Background(), &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"text": "hello"}, + }) +} + +// --- Assertion helpers --- + +// resultText extracts the first TextContent string from a CallToolResult. +// Fails the test if the result has no content or the first item isn't TextContent. +func resultText(t *testing.T, result *mcp.CallToolResult) string { + t.Helper() + require.NotEmpty(t, result.Content) + tc, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok, "expected TextContent, got %T", result.Content[0]) + return tc.Text +} + +// --- Interceptor builders --- + +// prefixMutator creates a mutator that prepends tag to each TextContent. +func prefixMutator(name, tag string, phase interceptors.Phase, priority int, mode interceptors.Mode) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: phase, + Mode: mode, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = tag + " " + tc.Text + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } +} + +// failMutator creates a mutator that always returns an error. +func failMutator(name string, phase interceptors.Phase, priority int) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: phase, + Mode: interceptors.ModeOn, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + return nil, fmt.Errorf("simulated failure") + }, + } +} + +// blockToolValidator creates a request validator that rejects calls to the named tool. +func blockToolValidator(toolName string) *interceptors.Validator { + return &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "block-" + toolName, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + if params.Name == toolName { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: toolName + " is blocked", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } +} + +// allowAllValidator creates a request validator that allows all calls. +func allowAllValidator(name string) *interceptors.Validator { + return &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + return &interceptors.ValidationResult{Valid: true}, nil + }, + } +} diff --git a/go/sdk/interceptors/result.go b/go/sdk/interceptors/result.go new file mode 100644 index 0000000..a84d6a0 --- /dev/null +++ b/go/sdk/interceptors/result.go @@ -0,0 +1,96 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +// --- Interceptor result types --- + +// ValidationMessage is a single validation finding. +type ValidationMessage struct { + Path string `json:"path,omitempty"` + Message string `json:"message"` + Severity Severity `json:"severity"` +} + +// ValidationSuggestion is an optional suggested correction. +type ValidationSuggestion struct { + Path string `json:"path"` + Value any `json:"value"` +} + +// ValidationResult is returned by validation interceptors. +type ValidationResult struct { + Valid bool `json:"valid"` + Severity Severity `json:"severity,omitempty"` + Messages []ValidationMessage `json:"messages,omitempty"` + Suggestions []ValidationSuggestion `json:"suggestions,omitempty"` +} + +// MutationResult is returned by mutation interceptors. +type MutationResult struct { + Modified bool `json:"modified"` + Info map[string]any `json:"info,omitempty"` +} + +// --- Chain result --- + +// ChainStatus describes the outcome of a full interceptor chain execution. +type ChainStatus string + +const ( + ChainSuccess ChainStatus = "success" + ChainValidationFailed ChainStatus = "validation_failed" + ChainMutationFailed ChainStatus = "mutation_failed" + ChainTimeout ChainStatus = "timeout" +) + +// AbortType classifies the reason an interceptor chain was aborted. +type AbortType string + +const ( + AbortValidation AbortType = "validation" + AbortMutation AbortType = "mutation" + AbortTimeout AbortType = "timeout" +) + +// ChainResult aggregates results from executing the full interceptor chain. +type ChainResult struct { + Status ChainStatus `json:"status"` + Event string `json:"event"` + Phase Phase `json:"phase"` + Results []ExecutionResult `json:"results"` + FinalPayload any `json:"finalPayload,omitempty"` + ValidationSummary ValidationSummary `json:"validationSummary"` + TotalDurationMs int64 `json:"totalDurationMs"` + AbortedAt []AbortInfo `json:"abortedAt,omitempty"` +} + +// ExecutionResult tracks a single interceptor's execution result. +type ExecutionResult struct { + Interceptor string `json:"interceptor"` // Name of the interceptor + Type InterceptorType `json:"type"` + Phase Phase `json:"phase"` + DurationMs int64 `json:"durationMs,omitempty"` + Error string `json:"error,omitempty"` // Non-empty when the handler returned an error + Info map[string]any `json:"info,omitempty"` + + // One of these is set depending on type: + Validation *ValidationResult `json:"validation,omitempty"` + Mutation *MutationResult `json:"mutation,omitempty"` +} + +// ValidationSummary counts validation outcomes. +type ValidationSummary struct { + Errors int `json:"errors"` + Warnings int `json:"warnings"` + Infos int `json:"infos"` +} + +// AbortInfo describes where and why a chain was aborted. +type AbortInfo struct { + Interceptor string `json:"interceptor"` + Reason string `json:"reason"` + Type AbortType `json:"type"` + Phase string `json:"phase,omitempty"` // phase at which the abort occurred +} diff --git a/go/sdk/internal/.gitkeep b/go/sdk/internal/.gitkeep deleted file mode 100644 index e69de29..0000000