diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 87636b9..0d4973e 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -51,7 +51,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ["1.23", "1.24", "1.25"] + go: ["1.24", "1.25"] steps: - name: Check out code uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 diff --git a/.gitignore b/.gitignore index 4013ac4..28557d4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.dll *.so *.dylib +mutator +validator # Build results [Bb]in/ diff --git a/go/sdk/README.md b/go/sdk/README.md index a24c5a9..732f7e5 100644 --- a/go/sdk/README.md +++ b/go/sdk/README.md @@ -1,3 +1,57 @@ # MCP Interceptors - Go Implementation -This will contain the Go implementation of the MCP Interceptors based on [SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763). +Go implementation of the MCP Interceptor Extension based on +[SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763). + +## Quick Start + +```go +mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "my-server", + Version: "0.1.0", +}, nil) + +// Wrap with interceptor support. +srv := interceptors.NewServer(mcpServer, + // Optional Context Provider + interceptors.WithContextProvider( + func(_ context.Context, _ mcp.Request) *interceptors.InvocationContext { + return &interceptors.InvocationContext{ + Principal: &interceptors.Principal{Type: "user", ID: "alice"}, + } + }, + ), +) + +// Register a validator that blocks dangerous tool calls. +srv.AddInterceptor(&interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "block-dangerous", + Events: []string{interceptors.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + // validate the request... + return &interceptors.ValidationResult{Valid: true}, nil + }, +}) + +srv.Run(context.Background(), &mcp.StdioTransport{}) +``` + +See [`examples/`](examples/) for complete working examples. + +## Documentation + +- [**DESIGN.md**](doc/DESIGN.md) — architecture, execution model, integration + with the go-sdk. +- [**PERFORMANCE.md**](doc/PERFORMANCE.md) — per-request cost model, allocation + summary, and optimization notes. +- [**CONFORMANCE.md**](doc/CONFORMANCE.md) — SEP conformance status. + +Package API documentation is available via `go doc`: + +```sh +go doc github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors +``` diff --git a/go/sdk/doc/CONFORMANCE.md b/go/sdk/doc/CONFORMANCE.md new file mode 100644 index 0000000..77d8bd5 --- /dev/null +++ b/go/sdk/doc/CONFORMANCE.md @@ -0,0 +1,34 @@ +# SEP Conformance + +Status of this Go SDK implementation against the +[SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763) +interceptor proposal. + +## Implemented + +| Area | Notes | +|------|-------| +| Validation interceptors | Parallel execution, severity-based blocking, fail-open support | +| Mutation interceptors | Sequential execution, priority ordering, atomic payload updates | +| Interceptor metadata | Name, version, description, events, phase, priorityHint (polymorphic JSON), compat, configSchema, mode, failOpen, timeout | +| Event names | Constants for all standard server-side MCP methods; JSON-RPC method names used directly as event names | +| Unified result envelope | ValidationResult, MutationResult, ExecutionResult with base envelope fields | +| Chain result | ChainResult with status, results, finalPayload, validationSummary, abortedAt | +| JSON-RPC error mapping | Typed error data structs: -32602 for validation, -32603 for mutation, -32000 for timeout | +| Trust-boundary execution order | Receiving: validate (parallel) then mutate (sequential); Sending: mutate (sequential) then validate (parallel) | +| Priority ordering | Mutators sorted by `priorityHint.Resolve(phase)` ascending, alphabetical tiebreak | +| Fail-open behavior | `FailOpen: true` interceptors log errors without aborting the chain | +| Audit mode | `ModeAudit` records results without blocking or applying mutations | +| Timeout & context | Per-interceptor timeouts, chain-level context cancellation, `InvocationContext` with principal/traceId via `mcpserver.WithContextProvider` | +| Receiving direction (client → server) | All server-side method calls intercepted via `AddReceivingMiddleware` | +| Capability declaration | Interceptor metadata injected into `initialize` response via `Capabilities.Experimental` | +| First-party (in-process) deployment | Interceptors run as Go functions within the server process | +| Third-party and hybrid deployment | Handlers can call remote services; local and remote interceptors can be mixed freely. No built-in remote interceptor abstraction yet | + +## Not Implemented + +| Area | SEP expects | Notes | +|------|-------------|-------| +| Wildcard event matching | `type InterceptorEvent = ... \| "*/request" \| "*/response" \| "*"` | `matchesEvent` does exact match only; wildcard patterns are planned | +| Protocol methods | `interceptors/list`, `interceptor/invoke`, `interceptor/executeChain` | Requires upstream go-sdk changes to register custom JSON-RPC methods | +| Server → client interception | Client features as interceptable events: `"sampling/createMessage"`, `"elicitation/create"`, `"roots/list"` | Requires a `sendingMiddleware` installed via `Server.AddSendingMiddleware`. Outgoing requests run mutate → validate, incoming responses run validate → mutate (same `executeForSending`/`executeForReceiving` methods). Interceptors match by event name (no API changes needed for registration) | diff --git a/go/sdk/doc/DESIGN.md b/go/sdk/doc/DESIGN.md new file mode 100644 index 0000000..ee84a87 --- /dev/null +++ b/go/sdk/doc/DESIGN.md @@ -0,0 +1,195 @@ +# Go SDK Interceptors — Design Document + +## Integration Point: Receiving Middleware + +The go-sdk processes an incoming JSON-RPC message in this order: + +``` +Transport (SSE / stdio) + → JSON-RPC decode + → Params deserialization (json.RawMessage → typed struct) + → Receiving middleware chain ← we hook in here + → Method handler (e.g. tool handler) + → Result returned through middleware + → JSON-RPC encode + → Transport +``` + +## Capability Declaration + +During initialization, the middleware intercepts the `"initialize"` response +and injects interceptor metadata into +`Capabilities.Experimental["io.modelcontextprotocol/interceptors"]`. This +follows the same pattern as the variants extension +(`io.modelcontextprotocol/server-variants`). + +The capability payload includes: +- `supportedEvents` — deduplicated list of events with registered interceptors +- `interceptors` — full metadata array in wire format + +## Request/Response Lifecycle + +When a JSON-RPC request arrives, `receivingMiddleware` in +`mcpserver/server.go` runs the following sequence: + +``` +0. If method == "initialize" → enrich result with capability declaration +1. Assign typed params to Invocation inv.Payload = req.GetParams() +2. Run request-phase chain (validate → mutate) +3. If aborted → return error +4. Params already modified in place — no unmarshal needed +5. Call next handler next(ctx, method, req) +6. Assign result to Invocation inv.Payload = result +7. Run response-phase chain (mutate → validate) +8. If aborted → return error +9. Result already modified in place — no unmarshal needed +10. Return result +``` + +The JSON-RPC method name is used directly as the event name (e.g. `"tools/call"`). + +### Typed Payload — Zero JSON Operations + +`Invocation.Payload` is `any`, the live Go value from the go-sdk +(e.g. `*mcp.CallToolParamsRaw`). Handlers type-assert directly, the +same pattern as gRPC-Go interceptors (`req any`). No JSON marshaling +or unmarshaling occurs in the normal path. + +```go +// Validator — type-assert, inspect, return: +params, ok := inv.Payload.(*mcp.CallToolParamsRaw) +if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +} + +// Mutator — type-assert, modify in place, return: +result, ok := inv.Payload.(*mcp.CallToolResult) +if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +} +result.Content[0] = &mcp.TextContent{Text: "modified"} +return &MutationResult{Modified: true}, nil +``` + +**Audit mode:** Audit-mode mutators receive a deep-copied payload (via +`Invocation.withCopiedPayload()`) so their in-place modifications don't +affect the real struct. The deep copy uses a JSON round-trip +(`json.Marshal` → `reflect.New` → `json.Unmarshal`). Only audit-mode +mutators pay this cost. + +### Limitations + +- **Params must round-trip through JSON faithfully** for audit-mode deep + copy. All go-sdk param and result types use standard `encoding/json` + tags, so this holds in practice. +- **Type assertions require knowing the concrete type.** Interceptors must + know which type to expect for a given event (e.g. `*mcp.CallToolParamsRaw` + for `tools/call` requests). The `Events` field on `Metadata` narrows + which events reach a handler, so single-event interceptors always see + the expected type. + +--- + +## What Is and Is Not Intercepted + +### Intercepted + +All JSON-RPC **method calls** routed through the server's receiving middleware: + +| Method | Event | +|--------|-------| +| `tools/call` | `EventToolsCall` | +| `tools/list` | `EventToolsList` | +| `prompts/get` | `EventPromptsGet` | +| `prompts/list` | `EventPromptsList` | +| `resources/read` | `EventResourcesRead` | +| `resources/list` | `EventResourcesList` | +| `resources/subscribe` | `EventResourcesSubscribe` | + +Unknown methods pass through the middleware untouched. + +### Not Intercepted + +1. **Progress notifications.** During a tool call, a handler can call + `session.NotifyProgress()`. These are JSON-RPC *notifications* sent + directly over the transport — they do not flow through `MethodHandler` + middleware. Interceptors never see them. + +2. **Transport-level SSE streaming.** The Streamable HTTP transport + multiplexes multiple JSON-RPC messages over a single SSE connection. + This is connection management, not per-message streaming. Each individual + method call is still a single request → single response, which the + middleware intercepts normally. + +3. **JSON-RPC notifications** (e.g. `notifications/initialized`, + `notifications/cancelled`). The go-sdk routes notifications through a + separate handler path, not through `MethodHandler` middleware. + +Notification interception is not defined by the proposal. +If this becomes necessary, it would require a separate notification middleware +hook in the go-sdk. + +--- + +## Chain Execution Model + +The `chainExecutor` in `chain_executor.go` implements trust-boundary-aware +execution: + +**Request phase** (receiving data — untrusted → trusted): +``` +Validate (parallel) → Mutate (sequential) +``` +Validation acts as a security gate before mutations process the data. + +**Response phase** (sending data — trusted → untrusted): +``` +Mutate (sequential) → Validate (parallel) +``` +Mutations prepare/sanitize data, then validation verifies before sending. + +### Validator execution +- All matching validators run in parallel (goroutines). +- A validator returning `Valid: false` with `Severity: "error"` in enforced + mode (`Mode: ModeOn`) aborts the chain. +- `FailOpen: true` validators log errors and record an `ExecutionResult` + (with `Error` populated) for observability, but don't abort. + +### Mutator execution +- Mutators run sequentially, ordered by `PriorityHint.Resolve(phase)` + (ascending), with alphabetical name tiebreak. +- Each mutator modifies the typed payload in place via type assertion on `inv.Payload`. +- If any mutator fails (and is not `FailOpen`), the chain aborts. + `FailOpen` mutators record an `ExecutionResult` (with `Error` populated) + and continue. +- In `ModeAudit`, the mutator runs on a deep-copied payload and its result + is recorded, but the real payload is not affected. + +### Filtering +`newChainExecutor` filters the full interceptor set by: +1. `Mode != ModeOff` +2. Phase matches (or interceptor phase is `PhaseBoth`) +3. Event matches (exact match only; wildcard support is planned via `matchesEvent`) + +--- + +## File Map + +### `interceptors/` — protocol-agnostic core (zero MCP imports) + +| File | Responsibility | +|------|---------------| +| `interceptor.go` | Types (Phase, Mode, InterceptorType, Priority, Severity, Compat), Metadata struct, Interceptor interface, Validator/Mutator structs and handler types | +| `invocation.go` | Invocation (with audit-mode payload cloning), InvocationContext, Principal — the input to every handler | +| `result.go` | All outcome types: ValidationResult, MutationResult, ExecutionResult, ChainResult, AbortInfo | +| `chain.go` | `Chain` public API: `NewChain`, `Add`, `ExecuteForReceiving`, `ExecuteForSending`, `IsEmpty`, `Interceptors` | +| `chain_executor.go` | `interceptorSnapshot` (atomic snapshot with lazy chain cache), `chainExecutor` struct, `newChainExecutor` (filtering + sorting), `executeForReceiving`, `executeForSending`, `timeoutResult`, `matchesPhase`, `matchesEvent` | +| `chain_validate.go` | `validatorResult` struct, `runValidators` (parallel dispatch + N=1 fast path), `executeValidator`, `recordValidation` | +| `chain_mutate.go` | `mutatorOutcome` type + constants, `runMutators` (sequential loop + audit-mode copy), `executeMutator` | + +### `interceptors/mcpserver/` — MCP server integration + +| File | Responsibility | +|------|---------------| +| `server.go` | `Server` wrapper, `receivingMiddleware`, `WithContextProvider`, capability declaration, `NewStreamableHTTPHandler`, `abortToJSONRPCError` | +| `events.go` | Event name constants for standard MCP methods | diff --git a/go/sdk/doc/PERFORMANCE.md b/go/sdk/doc/PERFORMANCE.md new file mode 100644 index 0000000..c525110 --- /dev/null +++ b/go/sdk/doc/PERFORMANCE.md @@ -0,0 +1,51 @@ +# Go SDK Interceptors — Performance + +Analysis of per-request costs and allocation patterns. + +--- + +## Design Rationale: Typed Payload + +The interceptor chain passes Go's typed params/result objects directly +to handlers, avoiding JSON serialization entirely in the normal path. +This follows the same pattern used by gRPC-Go, where interceptors +receive the request as `any` and type-assert to the concrete type. + +`Invocation.Payload` is `any, the same pointer the go-sdk already +allocated during JSON-RPC deserialization. No wrapper struct, no +intermediate copies. Mutators modify the value in place through the +pointer (no marshal/unmarshal round-trip) + +--- + +## Per-Request Cost Model + +Every intercepted request passes through the middleware in +`mcpserver/server.go`. The cost depends on whether interceptors match +the event. + +### Fast Path (no matching interceptors) + +When no interceptors match, the middleware cost is: + +1. One atomic pointer load: `s.chain.snapshot.Load()` +2. Two `sync.Map` lookups on the chain cache: `getChain(event, PhaseRequest)` and `getChain(event, PhaseResponse)` +3. One boolean check: `ce.empty` + +Zero allocations, zero JSON operations. + +### Intercepted Path + +With interceptors active, each active phase +incurs: + +| Step | Operation | Allocations | JSON ops | +|------|-----------|-------------|----------| +| 1 | `Invocation` struct | 1 struct | 0 | +| 2 | `ChainResult` struct (with pre-allocated `Results` slice) | 1 struct + 1 slice | 0 | +| 3 | Validator execution | 0 (N=1) / goroutines (N>1) | 0 | +| 4 | Mutator execution | 0 | 0 | +| 5 | Audit-mode deep copy | 1 per audit mutator | 1 marshal + 1 unmarshal | + +Zero JSON operations in the normal path. Phases with no matching +interceptors are skipped entirely. diff --git a/go/sdk/examples/.gitkeep b/go/sdk/examples/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/go/sdk/examples/mutator/main.go b/go/sdk/examples/mutator/main.go new file mode 100644 index 0000000..b0a962b --- /dev/null +++ b/go/sdk/examples/mutator/main.go @@ -0,0 +1,87 @@ +// Example: a simple mutator interceptor that adds an "[audited]" prefix +// to every tool call result's text content. +// +// Demonstrates WithContextProvider to populate Invocation.Context with +// caller identity, which interceptor handlers can inspect. +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +func main() { + mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "example-server", + Version: "0.1.0", + }, nil) + + mcpServer.AddTool(&mcp.Tool{Name: "greet", Description: "says hello"}, + func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "hello world"}}, + }, nil + }, + ) + + // Build the mutator. + m := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "audit-tag", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + TimeoutMs: 5000, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + + // Use caller identity from the context provider if available. + tag := "[audited]" + if inv.Context != nil && inv.Context.Principal != nil { + tag = fmt.Sprintf("[audited by %s]", inv.Context.Principal.ID) + } + + for i, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = tag + " " + tc.Text + result.Content[i] = tc + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } + + // Wire it up with a context provider that populates caller identity. + // In production, extract the principal from OAuth tokens on the session + // (e.g. via req.GetSession() or transport-level auth headers). + srv := mcpserver.NewServer(mcpServer, + mcpserver.WithContextProvider( + func(_ context.Context, _ mcp.Request) *interceptors.InvocationContext { + return &interceptors.InvocationContext{ + Principal: &interceptors.Principal{ + Type: "user", + ID: "example-user", + }, + } + }, + ), + ) + srv.AddInterceptor(m) + + handler := mcpserver.NewStreamableHTTPHandler(srv, nil) + log.Println("listening on :8080") + if err := http.ListenAndServe(":8080", handler); err != nil { + log.Fatal(err) + } +} diff --git a/go/sdk/examples/validator/main.go b/go/sdk/examples/validator/main.go new file mode 100644 index 0000000..f0f87a0 --- /dev/null +++ b/go/sdk/examples/validator/main.go @@ -0,0 +1,86 @@ +// Example: a simple validator interceptor that rejects tool calls +// to a tool named "dangerous_tool". +// +// Demonstrates WithContextProvider to populate Invocation.Context with +// caller identity, which interceptor handlers can inspect. +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +func main() { + mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "example-server", + Version: "0.1.0", + }, nil) + + mcpServer.AddTool(&mcp.Tool{Name: "echo", Description: "echoes input"}, + func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("you said: %s", req.Params.Arguments)}}, + }, nil + }, + ) + + // Build the validator. + v := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "block-dangerous-tool", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + TimeoutMs: 5000, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + if params.Name == "dangerous_tool" { + reason := "dangerous_tool is not allowed" + if inv.Context != nil && inv.Context.Principal != nil { + reason = fmt.Sprintf("dangerous_tool is not allowed for %s", inv.Context.Principal.ID) + } + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: reason, Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // Wire it up with a context provider that populates caller identity. + // In production, extract the principal from OAuth tokens on the session + // (e.g. via req.GetSession() or transport-level auth headers). + srv := mcpserver.NewServer(mcpServer, + mcpserver.WithContextProvider( + func(_ context.Context, _ mcp.Request) *interceptors.InvocationContext { + return &interceptors.InvocationContext{ + Principal: &interceptors.Principal{ + Type: "user", + ID: "example-user", + }, + } + }, + ), + ) + srv.AddInterceptor(v) + + handler := mcpserver.NewStreamableHTTPHandler(srv, nil) + log.Println("listening on :8080") + if err := http.ListenAndServe(":8080", handler); err != nil { + log.Fatal(err) + } +} diff --git a/go/sdk/go.mod b/go/sdk/go.mod index 0584e8f..4303b25 100644 --- a/go/sdk/go.mod +++ b/go/sdk/go.mod @@ -1,4 +1,22 @@ module github.com/modelcontextprotocol/ext-interceptors/go/sdk -go 1.23.0 +go 1.24.0 +toolchain go1.24.3 + +require ( + github.com/modelcontextprotocol/go-sdk v1.4.0 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.3 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go/sdk/go.sum b/go/sdk/go.sum new file mode 100644 index 0000000..dacfe34 --- /dev/null +++ b/go/sdk/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/modelcontextprotocol/go-sdk v1.4.0 h1:u0kr8lbJc1oBcawK7Df+/ajNMpIDFE41OEPxdeTLOn8= +github.com/modelcontextprotocol/go-sdk v1.4.0/go.mod h1:Nxc2n+n/GdCebUaqCOhTetptS17SXXNu9IfNTaLDi1E= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w= +github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/sdk/interceptors/chain.go b/go/sdk/interceptors/chain.go new file mode 100644 index 0000000..99c2b10 --- /dev/null +++ b/go/sdk/interceptors/chain.go @@ -0,0 +1,112 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" +) + +// Chain is a protocol-agnostic interceptor chain engine. It manages a set +// of interceptors and executes them in the correct order for a given +// event and phase. It has no dependency on any specific transport and +// can be used with MCP, gRPC, custom HTTP, or any other server. +type Chain struct { + mu sync.Mutex + snapshot atomic.Pointer[interceptorSnapshot] + logger *slog.Logger +} + +// ChainOption configures a Chain. +type ChainOption func(*Chain) + +// WithChainLogger sets the logger for chain execution. +// If not set, slog.Default() is used. +func WithChainLogger(l *slog.Logger) ChainOption { + return func(c *Chain) { + c.logger = l + } +} + +// NewChain creates a new Chain with optional configuration. +func NewChain(opts ...ChainOption) *Chain { + c := &Chain{} + for _, opt := range opts { + opt(c) + } + if c.logger == nil { + c.logger = slog.Default() + } + c.snapshot.Store(&interceptorSnapshot{logger: c.logger}) + return c +} + +// Add registers an interceptor. It panics if the interceptor has a nil +// handler or is an unsupported type. It is safe to call while the chain +// is in use. Returns the receiver for chaining. +func (c *Chain) Add(i Interceptor) *Chain { + switch v := i.(type) { + case *Validator: + if v.Handler == nil { + panic("interceptors: validator " + v.Name + " has nil handler") + } + case *Mutator: + if v.Handler == nil { + panic("interceptors: mutator " + v.Name + " has nil handler") + } + default: + panic(fmt.Sprintf("interceptors: unsupported interceptor type %T", i)) + } + c.mu.Lock() + defer c.mu.Unlock() + old := c.snapshot.Load() + newAll := make([]Interceptor, len(old.all)+1) + copy(newAll, old.all) + newAll[len(old.all)] = i + c.snapshot.Store(&interceptorSnapshot{ + all: newAll, + logger: c.logger, + }) + return c +} + +// ExecuteForReceiving runs the interceptor chain for the given invocation +// using receive-side ordering (validate then mutate). Returns nil, nil if +// no interceptors match the invocation's event and phase. +func (c *Chain) ExecuteForReceiving(ctx context.Context, inv *Invocation) (*ChainResult, error) { + snap := c.snapshot.Load() + ce := snap.getChain(inv.Event, inv.Phase) + if ce.empty { + return nil, nil + } + return ce.executeForReceiving(ctx, inv) +} + +// ExecuteForSending runs the interceptor chain for the given invocation +// using send-side ordering (mutate then validate). Returns nil, nil if +// no interceptors match the invocation's event and phase. +func (c *Chain) ExecuteForSending(ctx context.Context, inv *Invocation) (*ChainResult, error) { + snap := c.snapshot.Load() + ce := snap.getChain(inv.Event, inv.Phase) + if ce.empty { + return nil, nil + } + return ce.executeForSending(ctx, inv) +} + +// IsEmpty reports whether no interceptors match the given event and phase. +func (c *Chain) IsEmpty(event string, phase Phase) bool { + snap := c.snapshot.Load() + ce := snap.getChain(event, phase) + return ce.empty +} + +// Interceptors returns the current list of registered interceptors. +func (c *Chain) Interceptors() []Interceptor { + return c.snapshot.Load().all +} diff --git a/go/sdk/interceptors/chain_executor.go b/go/sdk/interceptors/chain_executor.go new file mode 100644 index 0000000..75ef760 --- /dev/null +++ b/go/sdk/interceptors/chain_executor.go @@ -0,0 +1,214 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "fmt" + "log/slog" + "sort" + "sync" + "time" +) + +// interceptorSnapshot is an immutable snapshot of registered interceptors, +// created once per Add call. It holds the interceptor list and +// lazily-built chain executors keyed by "event|phase". Swapping the snapshot +// pointer implicitly invalidates all cached chains. +type interceptorSnapshot struct { + all []Interceptor + logger *slog.Logger + chains sync.Map // "event|phase" -> *chainExecutor +} + +// getChain returns a cached chainExecutor for the given event and phase, +// building one on first access via newChainExecutor. +func (snap *interceptorSnapshot) getChain(event string, phase Phase) *chainExecutor { + key := event + "|" + string(phase) + if v, ok := snap.chains.Load(key); ok { + return v.(*chainExecutor) + } + ce := newChainExecutor(snap.all, event, phase, snap.logger) + v, _ := snap.chains.LoadOrStore(key, ce) + return v.(*chainExecutor) +} + +// chainExecutor holds the filtered and sorted interceptors applicable to a +// specific event+phase pair, and orchestrates their execution according to +// trust-boundary-aware ordering. +// +// Validators and mutators are separated at construction time so that each +// execution method (executeForReceiving, executeForSending) can run them +// in the correct order without re-filtering. +type chainExecutor struct { + empty bool // true when no validators or mutators matched + validators []*Validator // validators matching this event+phase (run in parallel) + mutators []*Mutator // mutators matching this event+phase (run sequentially by priority) + resultsCap int // len(validators) + len(mutators); used to pre-allocate ChainResult.Results + + logger *slog.Logger +} + +// newChainExecutor builds a chainExecutor by filtering the global interceptor +// list down to those that are active (Mode != ModeOff), match the target phase, +// and match the target event. Mutators are then sorted by priority (ascending) +// with alphabetical name as a tiebreaker, which guarantees +// deterministic execution order. +func newChainExecutor( + all []Interceptor, + event string, + phase Phase, + logger *slog.Logger, +) *chainExecutor { + ce := &chainExecutor{logger: logger} + for _, i := range all { + meta := i.GetMetadata() + if meta.Mode == ModeOff || !matchesPhase(meta.Phase, phase) || !matchesEvent(meta.Events, event) { + continue + } + switch v := i.(type) { + case *Validator: + ce.validators = append(ce.validators, v) + case *Mutator: + ce.mutators = append(ce.mutators, v) + default: + logger.Warn("unknown interceptor type, skipping", + "interceptor", meta.Name, + "type", fmt.Sprintf("%T", i), + ) + } + } + // Sort mutators by priority (ascending), alphabetical tiebreak. + sort.Slice(ce.mutators, func(i, j int) bool { + pi := ce.mutators[i].PriorityHint.Resolve(phase) + pj := ce.mutators[j].PriorityHint.Resolve(phase) + if pi != pj { + return pi < pj + } + return ce.mutators[i].Name < ce.mutators[j].Name + }) + ce.empty = len(ce.validators) == 0 && len(ce.mutators) == 0 + ce.resultsCap = len(ce.validators) + len(ce.mutators) + return ce +} + +// executeForReceiving runs the chain for incoming requests (server receiving) +// using the receive-side ordering: +// +// Receive -> Validate (parallel) -> Mutate (sequential) +// +// Validators run first as a security barrier: if any enforced +// validator produces an error-severity message, the chain aborts with +// "validation_failed" before any mutators run, so the payload is +// unmodified. Only after all validators pass do mutators run +// sequentially, modifying the typed payload in place. If a mutator +// aborts, the chain returns with the abort recorded. +func (ce *chainExecutor) executeForReceiving(ctx context.Context, inv *Invocation) (*ChainResult, error) { + start := time.Now() + cr := &ChainResult{ + Event: inv.Event, + Phase: inv.Phase, + FinalPayload: inv.Payload, + Results: make([]ExecutionResult, 0, ce.resultsCap), + } + + // 1. Run validators in parallel. + ce.runValidators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainValidationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + if err := ctx.Err(); err != nil { + return ce.timeoutResult(cr, start), nil + } + + // 2. Run mutators sequentially (in-place on inv.Payload). + ce.runMutators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainMutationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + + cr.Status = ChainSuccess + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil +} + +// executeForSending runs the chain for outgoing responses (server sending) +// using the send-side ordering: +// +// Mutate (sequential) -> Validate (parallel) -> Send +// +// Mutators run first to prepare/sanitize outgoing data, then +// validators check the (now mutated) payload before it leaves the server. +// Mutators modify the typed value in place, so validators automatically +// see the post-mutation state. +func (ce *chainExecutor) executeForSending(ctx context.Context, inv *Invocation) (*ChainResult, error) { + start := time.Now() + cr := &ChainResult{ + Event: inv.Event, + Phase: inv.Phase, + FinalPayload: inv.Payload, + Results: make([]ExecutionResult, 0, ce.resultsCap), + } + + // 1. Run mutators sequentially (in-place on inv.Payload). + ce.runMutators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainMutationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + if err := ctx.Err(); err != nil { + return ce.timeoutResult(cr, start), nil + } + + // 2. Run validators in parallel. + ce.runValidators(ctx, inv, cr) + if len(cr.AbortedAt) > 0 { + cr.Status = ChainValidationFailed + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil + } + + cr.Status = ChainSuccess + cr.TotalDurationMs = time.Since(start).Milliseconds() + return cr, nil +} + +// timeoutResult sets the ChainResult to "timeout" status and appends +// an AbortInfo entry. Used when the parent context is cancelled between +// chain stages (e.g. after validators but before mutators). +func (ce *chainExecutor) timeoutResult(cr *ChainResult, start time.Time) *ChainResult { + cr.Status = ChainTimeout + cr.TotalDurationMs = time.Since(start).Milliseconds() + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Reason: "chain execution timeout exceeded", + Type: AbortTimeout, + Phase: string(cr.Phase), + }) + return cr +} + +// matchesPhase checks if an interceptor's configured phase covers the target +// phase. An interceptor with PhaseBoth matches any target phase. +func matchesPhase(interceptorPhase, targetPhase Phase) bool { + return interceptorPhase == PhaseBoth || interceptorPhase == targetPhase +} + +// matchesEvent checks if any of the interceptor's registered event patterns +// match the given event string. Currently uses exact string comparison only. +// +// TODO: support wildcard patterns (e.g. "*", "*/request", "tools/*"). +func matchesEvent(interceptorEvents []string, event string) bool { + for _, pattern := range interceptorEvents { + if pattern == event { + return true + } + } + return false +} diff --git a/go/sdk/interceptors/chain_mutate.go b/go/sdk/interceptors/chain_mutate.go new file mode 100644 index 0000000..1a743cb --- /dev/null +++ b/go/sdk/interceptors/chain_mutate.go @@ -0,0 +1,150 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "time" +) + +// mutatorOutcome describes the result of executing a single mutator, used by +// the runMutators loop to decide whether to continue, skip, or halt the chain. +type mutatorOutcome int + +const ( + mutatorOK mutatorOutcome = iota // handler succeeded, payload may have been updated + mutatorSkipped // handler failed (fail-open) or audit mode; continue chain + mutatorAborted // handler failed (fail-closed); halt chain +) + +// runMutators runs mutators sequentially in priority order. Mutators modify +// the typed payload in place via type assertion on inv.Payload. For +// audit-mode mutators, a deep-copied invocation is used so modifications +// don't affect the real payload. +func (ce *chainExecutor) runMutators(ctx context.Context, inv *Invocation, cr *ChainResult) { + if len(ce.mutators) == 0 { + return + } + + for _, m := range ce.mutators { + if ctx.Err() != nil { + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Reason: "context cancelled during mutation chain", + Type: AbortTimeout, + Phase: string(inv.Phase), + }) + return + } + + // Audit-mode mutators work on a deep copy so in-place + // modifications don't affect the real payload. + mInv := inv + if m.Mode == ModeAudit { + var err error + mInv, err = inv.withCopiedPayload() + if err != nil { + ce.logger.Warn("failed to deep-copy payload for audit mutator, skipping", + "interceptor", m.Name, + "error", err, + ) + continue + } + } + + if ce.executeMutator(ctx, m, mInv, cr) == mutatorAborted { + return + } + } +} + +// executeMutator runs a single mutator handler and manages its full lifecycle: +// +// 1. Timeout setup: wraps ctx with a per-interceptor deadline if configured. +// 2. Handler invocation: calls the mutator's Handler with the current payload. +// 3. Error handling: on failure, checks FailOpen to decide between abort and skip. +// Fail-closed records both an AbortInfo and an ExecutionResult for audit. +// 4. Audit mode: records the ExecutionResult but does not apply payload changes +// (the caller already passed a deep-copied invocation for audit mutators). +// 5. In-place mutation: the handler modifies the typed value directly via type +// assertion on inv.Payload. result.Modified is advisory (recorded in the +// ExecutionResult for observability) but not checked by the chain. +func (ce *chainExecutor) executeMutator( + ctx context.Context, m *Mutator, mInv *Invocation, + cr *ChainResult, +) mutatorOutcome { + if m.Handler == nil { + ce.logger.Warn("mutator has nil handler, skipping", + "interceptor", m.Name, + ) + return mutatorSkipped + } + + // Apply per-interceptor timeout from metadata. + handlerCtx := ctx + if m.TimeoutMs > 0 { + var cancel context.CancelFunc + handlerCtx, cancel = context.WithTimeout(ctx, time.Duration(m.TimeoutMs)*time.Millisecond) + defer cancel() + } + + mStart := time.Now() + result, err := m.Handler(handlerCtx, mInv) + dur := time.Since(mStart).Milliseconds() + + if err != nil { + // Distinguish timeout from general mutation errors for downstream consumers. + abortType := AbortMutation + if handlerCtx.Err() == context.DeadlineExceeded { + abortType = AbortTimeout + } + ce.logger.Warn("mutator error", + "interceptor", m.Name, + "error", err, + ) + // Always record the execution result for observability, + // regardless of fail-open/fail-closed. + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: m.Name, + Type: TypeMutation, + Phase: mInv.Phase, + DurationMs: dur, + Error: err.Error(), + }) + if !m.FailOpen { + // Fail-closed: record abort entry and halt chain. + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Interceptor: m.Name, + Reason: err.Error(), + Type: abortType, + Phase: string(mInv.Phase), + }) + return mutatorAborted + } + // Fail-open: execution result above is the audit trail; + // no abort, continue chain. + return mutatorSkipped + } + + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: m.Name, + Type: TypeMutation, + Phase: mInv.Phase, + DurationMs: dur, + Mutation: result, + }) + + // Guard against handlers that return (nil, nil). + if result == nil { + return mutatorSkipped + } + + // In audit mode, the mutation result is recorded for observability but + // the payload is not modified — the caller already gave us a deep copy. + if m.Mode == ModeAudit { + return mutatorSkipped + } + + return mutatorOK +} diff --git a/go/sdk/interceptors/chain_test.go b/go/sdk/interceptors/chain_test.go new file mode 100644 index 0000000..46ee6dd --- /dev/null +++ b/go/sdk/interceptors/chain_test.go @@ -0,0 +1,137 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "fmt" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubPayload is a minimal payload for chain-level tests. +type stubPayload struct{ Value string } + +func TestChain_FailOpenRecordsExecutionResult(t *testing.T) { + t.Parallel() + // Verifies that fail-open interceptors that return errors still + // have an ExecutionResult recorded in ChainResult.Results, even + // though they don't produce an AbortInfo. + + t.Run("fail-open validator error is recorded", func(t *testing.T) { + t.Parallel() + failOpenValidator := &Validator{ + Metadata: Metadata{ + Name: "fo-validator", + Events: []string{"test/event"}, + Phase: PhaseRequest, + Mode: ModeOn, + FailOpen: true, + }, + Handler: func(_ context.Context, _ *Invocation) (*ValidationResult, error) { + return nil, fmt.Errorf("transient failure") + }, + } + passingValidator := &Validator{ + Metadata: Metadata{ + Name: "passing-validator", + Events: []string{"test/event"}, + Phase: PhaseRequest, + Mode: ModeOn, + }, + Handler: func(_ context.Context, _ *Invocation) (*ValidationResult, error) { + return &ValidationResult{Valid: true}, nil + }, + } + + chain := NewChain(WithChainLogger(slog.Default())) + chain.Add(failOpenValidator).Add(passingValidator) + + inv := &Invocation{ + Event: "test/event", + Phase: PhaseRequest, + Payload: &stubPayload{Value: "hello"}, + } + + cr, err := chain.ExecuteForReceiving(context.Background(), inv) + require.NoError(t, err) + + // Chain should succeed — fail-open doesn't abort. + assert.Equal(t, ChainSuccess, cr.Status) + assert.Empty(t, cr.AbortedAt) + + // Both interceptors should have an ExecutionResult. + require.Len(t, cr.Results, 2) + var found bool + for _, r := range cr.Results { + if r.Interceptor == "fo-validator" { + found = true + assert.Equal(t, TypeValidation, r.Type) + assert.Equal(t, "transient failure", r.Error) + } + } + assert.True(t, found, "fail-open validator should have an ExecutionResult") + }) + + t.Run("fail-open mutator error is recorded", func(t *testing.T) { + t.Parallel() + failOpenMutator := &Mutator{ + Metadata: Metadata{ + Name: "fo-mutator", + Events: []string{"test/event"}, + Phase: PhaseResponse, + Mode: ModeOn, + FailOpen: true, + PriorityHint: NewPriority(10), + }, + Handler: func(_ context.Context, _ *Invocation) (*MutationResult, error) { + return nil, fmt.Errorf("transient failure") + }, + } + passingMutator := &Mutator{ + Metadata: Metadata{ + Name: "passing-mutator", + Events: []string{"test/event"}, + Phase: PhaseResponse, + Mode: ModeOn, + PriorityHint: NewPriority(20), + }, + Handler: func(_ context.Context, _ *Invocation) (*MutationResult, error) { + return &MutationResult{Modified: false}, nil + }, + } + + chain := NewChain(WithChainLogger(slog.Default())) + chain.Add(failOpenMutator).Add(passingMutator) + + inv := &Invocation{ + Event: "test/event", + Phase: PhaseResponse, + Payload: &stubPayload{Value: "hello"}, + } + + cr, err := chain.ExecuteForSending(context.Background(), inv) + require.NoError(t, err) + + // Chain should succeed — fail-open doesn't abort. + assert.Equal(t, ChainSuccess, cr.Status) + assert.Empty(t, cr.AbortedAt) + + // Both interceptors should have an ExecutionResult. + require.Len(t, cr.Results, 2) + var found bool + for _, r := range cr.Results { + if r.Interceptor == "fo-mutator" { + found = true + assert.Equal(t, TypeMutation, r.Type) + assert.Equal(t, "transient failure", r.Error) + } + } + assert.True(t, found, "fail-open mutator should have an ExecutionResult") + }) +} diff --git a/go/sdk/interceptors/chain_validate.go b/go/sdk/interceptors/chain_validate.go new file mode 100644 index 0000000..e223e90 --- /dev/null +++ b/go/sdk/interceptors/chain_validate.go @@ -0,0 +1,191 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "context" + "sync" + "time" +) + +// validatorResult holds the output of a single validator handler call, +// separating execution from result recording so the handler can run +// without holding a lock. +type validatorResult struct { + result *ValidationResult + err error + durationMs int64 + timedOut bool +} + +// runValidators runs all validators concurrently, collects their results, and +// records any abort conditions. All validators run to completion +// before any abort decision is made — this ensures the ChainResult contains a +// complete picture of all validation findings, not just the first failure. +// +// A mutex guards all writes to the shared ChainResult (Results, AbortedAt, +// ValidationSummary) since goroutines append concurrently. +// +// Error handling per validator: +// - Handler returns error + FailOpen=false (fail-closed): records an +// ExecutionResult (for timing/audit) and an AbortInfo to block the chain. +// - Handler returns error + FailOpen=true: logs and records an +// ExecutionResult (for observability) but does not abort. +// - Handler returns a result with Valid=false + Mode!=Audit (enforced): +// scans messages for the first error-severity entry and records an AbortInfo. +// - Handler returns a result with Valid=false + Mode=Audit: records normally +// without aborting (audit-only observation). +func (ce *chainExecutor) runValidators(ctx context.Context, inv *Invocation, cr *ChainResult) { + if len(ce.validators) == 0 { + return + } + + // Fast path: single validator doesn't need goroutine, WaitGroup, or Mutex. + if len(ce.validators) == 1 { + v := ce.validators[0] + vr := ce.executeValidator(ctx, v, inv) + ce.recordValidation(v, inv, cr, vr) + return + } + + var ( + mu sync.Mutex + wg sync.WaitGroup + ) + + for _, v := range ce.validators { + wg.Add(1) + go func(v *Validator) { + defer wg.Done() + + vr := ce.executeValidator(ctx, v, inv) + + mu.Lock() + ce.recordValidation(v, inv, cr, vr) + mu.Unlock() + }(v) + } + + wg.Wait() +} + +// executeValidator runs a single validator handler and returns its output. +// It does not modify ChainResult, so it is safe to call without a lock. +func (ce *chainExecutor) executeValidator(ctx context.Context, v *Validator, inv *Invocation) validatorResult { + if v.Handler == nil { + ce.logger.Warn("validator has nil handler, skipping", + "interceptor", v.Name, + ) + return validatorResult{} + } + + // Apply per-interceptor timeout from metadata. + handlerCtx := ctx + if v.TimeoutMs > 0 { + var cancel context.CancelFunc + handlerCtx, cancel = context.WithTimeout(ctx, time.Duration(v.TimeoutMs)*time.Millisecond) + defer cancel() + } + + vStart := time.Now() + result, err := v.Handler(handlerCtx, inv) + dur := time.Since(vStart).Milliseconds() + + return validatorResult{ + result: result, + err: err, + durationMs: dur, + timedOut: handlerCtx.Err() == context.DeadlineExceeded, + } +} + +// recordValidation writes a validator's execution outcome into the +// ChainResult. The caller must ensure exclusive access to cr (either +// by holding a lock or being the only writer in the N=1 fast path). +func (ce *chainExecutor) recordValidation(v *Validator, inv *Invocation, cr *ChainResult, vr validatorResult) { + if vr.err != nil { + // Distinguish timeout errors from general validation errors + // so downstream consumers can differentiate root cause. + abortType := AbortValidation + if vr.timedOut { + abortType = AbortTimeout + } + ce.logger.Warn("validator error", + "interceptor", v.Name, + "error", vr.err, + ) + // Always record the execution result for observability, + // regardless of fail-open/fail-closed. + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: v.Name, + Type: TypeValidation, + Phase: inv.Phase, + DurationMs: vr.durationMs, + Error: vr.err.Error(), + }) + // Fail-closed: record an abort entry to halt the chain. + // Fail-open: the execution result above is the audit trail; + // no abort is recorded. + if !v.FailOpen { + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Interceptor: v.Name, + Reason: vr.err.Error(), + Type: abortType, + Phase: string(inv.Phase), + }) + } + return + } + + // Guard against handlers that return (nil, nil). + if vr.result == nil { + return + } + + cr.Results = append(cr.Results, ExecutionResult{ + Interceptor: v.Name, + Type: TypeValidation, + Phase: inv.Phase, + DurationMs: vr.durationMs, + Validation: vr.result, + }) + + // Tally validation summary. + for _, msg := range vr.result.Messages { + switch msg.Severity { + case SeverityError: + cr.ValidationSummary.Errors++ + case SeverityWarn: + cr.ValidationSummary.Warnings++ + case SeverityInfo: + cr.ValidationSummary.Infos++ + } + } + + // Only error-severity messages cause an abort. A validator + // returning Valid=false with only warn/info messages does NOT + // block the chain — the findings are recorded in the + // ValidationSummary and ExecutionResult but execution continues. + // + // In audit mode (ModeAudit), even error-severity messages don't + // abort — the result is recorded for observability only. + // + // When multiple error messages exist, only the first is recorded + // as the abort reason. All messages remain visible in the + // ValidationResult attached to the ExecutionResult. + if v.Mode != ModeAudit && !vr.result.Valid { + for _, msg := range vr.result.Messages { + if msg.Severity == SeverityError { + cr.AbortedAt = append(cr.AbortedAt, AbortInfo{ + Interceptor: v.Name, + Reason: msg.Message, + Type: AbortValidation, + Phase: string(inv.Phase), + }) + break + } + } + } +} diff --git a/go/sdk/interceptors/doc.go b/go/sdk/interceptors/doc.go new file mode 100644 index 0000000..a7c2493 --- /dev/null +++ b/go/sdk/interceptors/doc.go @@ -0,0 +1,158 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +// Package interceptors provides a protocol-agnostic validation and +// mutation middleware framework. It defines interceptor types +// ([Validator], [Mutator]), the chain engine ([Chain]), and all +// supporting types needed to build interceptor pipelines for any +// server protocol. +// +// This package has no dependency on MCP or any specific transport. +// For MCP-specific integration (wrapping an [mcp.Server] with +// middleware), see the [interceptors/mcpserver] sub-package. +// +// # Standalone Usage (any protocol) +// +// Create a [Chain], register interceptors, and call +// [Chain.ExecuteForReceiving] or [Chain.ExecuteForSending] from your +// server's request/response pipeline: +// +// chain := interceptors.NewChain( +// interceptors.WithChainLogger(logger), +// ) +// chain.Add(myValidator) +// chain.Add(myMutator) +// +// // In your request handler: +// inv := &interceptors.Invocation{ +// Event: "tools/call", +// Phase: interceptors.PhaseRequest, +// Payload: typedParams, +// } +// cr, err := chain.ExecuteForReceiving(ctx, inv) +// if err != nil { ... } +// if cr != nil && len(cr.AbortedAt) > 0 { +// // handle abort +// } +// +// # Validators +// +// A [Validator] inspects the typed payload and decides whether the +// request or response should proceed. All validators for a given +// event run in parallel; because they share the same [Invocation] +// pointer, handlers MUST treat the Invocation and its Payload as +// read-only — mutating either is a data race. If any validator in +// enforced mode ([ModeOn]) returns an error-severity message, the +// chain aborts before any mutators run. Only error-severity messages +// cause an abort; warn and info findings are recorded in the +// [ChainResult] but do not block the chain. +// +// Type-assert the payload to its concrete type: +// +// v := &interceptors.Validator{ +// Metadata: interceptors.Metadata{ +// Name: "block-dangerous-tool", +// Events: []string{"tools/call"}, +// Phase: interceptors.PhaseRequest, +// Mode: interceptors.ModeOn, +// }, +// Handler: func(ctx context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { +// params, ok := inv.Payload.(*MyRequestParams) +// if !ok { +// return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +// } +// // inspect params ... +// return &interceptors.ValidationResult{Valid: true}, nil +// }, +// } +// +// # Mutators +// +// A [Mutator] transforms the payload in place. Mutators run sequentially +// in priority order (see [Priority]). Each mutator receives the typed +// value and can modify it directly. If any mutator fails (and is not +// configured with FailOpen), the chain aborts. FailOpen mutators +// record an [ExecutionResult] (with the error captured) for +// observability but do not block. +// +// Type-assert the payload and modify the value in place: +// +// m := &interceptors.Mutator{ +// Metadata: interceptors.Metadata{ +// Name: "redact-pii", +// Events: []string{"tools/call"}, +// Phase: interceptors.PhaseResponse, +// Mode: interceptors.ModeOn, +// }, +// Handler: func(ctx context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { +// result, ok := inv.Payload.(*MyResponseResult) +// if !ok { +// return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) +// } +// // modify result in place ... +// return &interceptors.MutationResult{Modified: true}, nil +// }, +// } +// +// # Execution Order +// +// The chain execution order depends on direction: +// +// Request phase (untrusted → trusted): +// +// Validate (parallel) → Mutate (sequential) +// +// Response phase (trusted → untrusted): +// +// Mutate (sequential) → Validate (parallel) +// +// Validators act as a security gate on the trust boundary side, +// while mutators prepare or sanitize data on the other side. +// +// # Modes and FailOpen +// +// Each interceptor has a [Mode] that controls what happens with +// successful results, and a FailOpen flag that controls what happens +// when the handler returns a Go error. These are orthogonal: +// +// - [ModeOn]: fully enforced — validation failures block, mutations +// apply in place. +// - [ModeAudit]: the handler runs and results are recorded, but +// validation findings do not block and mutations run on a +// deep-copied payload so the real data is unaffected. +// - [ModeOff]: the interceptor is skipped entirely. +// +// FailOpen (default false) controls crash resilience: +// +// - FailOpen=false: a handler error aborts the chain. An +// [ExecutionResult] (with Error populated) and an [AbortInfo] +// are both recorded. +// - FailOpen=true: a handler error is logged and an +// [ExecutionResult] is recorded, but the chain continues. +// +// Note that [ModeAudit] does NOT imply FailOpen. Audit mode only +// suppresses enforcement of successful results (validation findings +// and mutations). If the handler itself returns an error and +// FailOpen is false, the chain still aborts. For truly safe +// observation-only interceptors, set both ModeAudit and FailOpen: +// +// Metadata: interceptors.Metadata{ +// Mode: interceptors.ModeAudit, +// FailOpen: true, +// } +// +// Behavior matrix for validators: +// +// Mode=On, FailOpen=false → error aborts, Valid=false+SeverityError aborts +// Mode=On, FailOpen=true → error continues, Valid=false+SeverityError aborts +// ModeAudit, FailOpen=false → error aborts, findings recorded only +// ModeAudit, FailOpen=true → error continues, findings recorded only +// +// Behavior matrix for mutators: +// +// Mode=On, FailOpen=false → error aborts, mutations applied in place +// Mode=On, FailOpen=true → error continues, mutations applied in place +// ModeAudit, FailOpen=false → error aborts, mutations recorded (deep copy) +// ModeAudit, FailOpen=true → error continues, mutations recorded (deep copy) +package interceptors diff --git a/go/sdk/interceptors/interceptor.go b/go/sdk/interceptors/interceptor.go index 3cac4e1..483753e 100644 --- a/go/sdk/interceptors/interceptor.go +++ b/go/sdk/interceptors/interceptor.go @@ -1,6 +1,163 @@ // Copyright 2025 The MCP Interceptors Authors. All rights reserved. -// Use of this source code is governed by a Apache-2.0 +// Use of this source code is governed by an Apache-2.0 // license that can be found in the LICENSE file. -// Package interceptors implements MCP Interceptors based on [SEP-1763](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1763). package interceptors + +import ( + "context" + "encoding/json" +) + +// Phase determines when an interceptor runs. +type Phase string + +const ( + PhaseRequest Phase = "request" + PhaseResponse Phase = "response" + PhaseBoth Phase = "both" +) + +// Mode controls enforcement behavior. +type Mode string + +const ( + ModeOn Mode = "on" // Enforced: validation failures block, mutations apply + ModeAudit Mode = "audit" // Audit: log results but don't block or apply mutations + ModeOff Mode = "off" // Disabled +) + +// InterceptorType identifies the category of an interceptor. +type InterceptorType string + +const ( + TypeValidation InterceptorType = "validation" + TypeMutation InterceptorType = "mutation" +) + +// Priority represents an interceptor's ordering hint. +// Can be a single value (applies to both phases) or per-phase. +// +// JSON representation is polymorphic: a single number when both +// phases are equal, or {"request": N, "response": N} when they differ. +type Priority struct { + Request int + Response int +} + +// NewPriority creates a Priority with the same value for both phases. +func NewPriority(v int) Priority { + return Priority{Request: v, Response: v} +} + +// Resolve returns the priority for the given phase. +func (p Priority) Resolve(phase Phase) int { + if phase == PhaseResponse { + return p.Response + } + return p.Request +} + +// MarshalJSON implements polymorphic serialization: +// emits a single number when both phases are equal, or an object otherwise. +func (p Priority) MarshalJSON() ([]byte, error) { + if p.Request == p.Response { + return json.Marshal(p.Request) + } + return json.Marshal(struct { + Request int `json:"request,omitempty"` + Response int `json:"response,omitempty"` + }{p.Request, p.Response}) +} + +// UnmarshalJSON handles both number and {request, response} forms. +func (p *Priority) UnmarshalJSON(data []byte) error { + var n int + if err := json.Unmarshal(data, &n); err == nil { + p.Request = n + p.Response = n + return nil + } + var obj struct { + Request int `json:"request"` + Response int `json:"response"` + } + if err := json.Unmarshal(data, &obj); err != nil { + return err + } + p.Request = obj.Request + p.Response = obj.Response + return nil +} + +// Severity represents validation message severity. +type Severity string + +const ( + SeverityInfo Severity = "info" + SeverityWarn Severity = "warn" + SeverityError Severity = "error" // Only error blocks execution +) + +// Compat represents protocol version compatibility. +type Compat struct { + MinProtocol string `json:"minProtocol"` + MaxProtocol string `json:"maxProtocol,omitempty"` +} + +// Metadata holds all common interceptor metadata. +type Metadata struct { + Name string `json:"name"` + Version string `json:"version,omitempty"` + Description string `json:"description,omitempty"` + Events []string `json:"events"` + Phase Phase `json:"phase"` + PriorityHint Priority `json:"priorityHint,omitempty"` + Compat *Compat `json:"compat,omitempty"` + ConfigSchema json.RawMessage `json:"configSchema,omitempty"` + + // TODO: Therese are essentially deployment configs?, so we might want to move them into a separate struct + // and have the user configure them separately? + Mode Mode `json:"mode"` + FailOpen bool `json:"failOpen,omitempty"` + TimeoutMs int64 `json:"timeoutMs,omitempty"` +} + +// Interceptor is the common interface for all interceptors. It is implemented by both Validator and Mutator. +type Interceptor interface { + GetMetadata() *Metadata + GetType() InterceptorType +} + +// --- ValidatorHandler --- + +// ValidatorHandler is the function signature for validation handlers. +// +// Handlers MUST treat the Invocation and its Payload as read-only. +// Multiple validators for the same event run concurrently and share +// the same Invocation pointer, so any mutation of the Payload (or +// other Invocation fields) is a data race. +type ValidatorHandler func(ctx context.Context, inv *Invocation) (*ValidationResult, error) + +// Validator is a validation interceptor. +type Validator struct { + Metadata + Handler ValidatorHandler +} + +func (v *Validator) GetMetadata() *Metadata { return &v.Metadata } +func (v *Validator) GetType() InterceptorType { return TypeValidation } + +// --- MutatorHandler --- + +// MutatorHandler is the function signature for raw mutation handlers. +type MutatorHandler func(ctx context.Context, inv *Invocation) (*MutationResult, error) + +// Mutator is a mutation interceptor. +type Mutator struct { + Metadata + Handler MutatorHandler +} + +func (m *Mutator) GetMetadata() *Metadata { return &m.Metadata } +func (m *Mutator) GetType() InterceptorType { return TypeMutation } diff --git a/go/sdk/interceptors/interceptor_test.go b/go/sdk/interceptors/interceptor_test.go deleted file mode 100644 index fb110b9..0000000 --- a/go/sdk/interceptors/interceptor_test.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2025 The MCP Interceptors Authors. All rights reserved. -// Use of this source code is governed by a Apache-2.0 -// license that can be found in the LICENSE file. - -package interceptors - -import "testing" - -func TestPlaceholder(t *testing.T) { - // Placeholder test to ensure go test runs successfully - if false { - t.Error("This should not happen") - } -} diff --git a/go/sdk/interceptors/invocation.go b/go/sdk/interceptors/invocation.go new file mode 100644 index 0000000..6d1e7df --- /dev/null +++ b/go/sdk/interceptors/invocation.go @@ -0,0 +1,80 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +import ( + "encoding/json" + "fmt" + "reflect" +) + +// Invocation is the context passed to every interceptor handler. +type Invocation struct { + Event string // e.g. "tools/call" + Phase Phase // "request" or "response" + Payload any // The typed payload (use type assertion directly) + Config map[string]any // Per-invocation config + Context *InvocationContext // Optional caller context (identity, trace, etc.) + Session any // The server session; nil for protocol-level invocations + + mutatedParams any // set via SetMutatedParams in response phase only +} + +// MutatedParams returns the request params after request-phase mutators +// have run. Only available in the response phase; returns nil in the +// request phase. +func (inv *Invocation) MutatedParams() any { + return inv.mutatedParams +} + +// SetMutatedParams sets the mutated request params on the invocation. +// This is intended for server integrations that build response-phase +// invocations; interceptor handlers should use MutatedParams() to read. +func (inv *Invocation) SetMutatedParams(p any) { + inv.mutatedParams = p +} + +// withCopiedPayload returns a shallow copy of the Invocation with a +// deep-copied Payload. Used for audit-mode mutators so their in-place +// modifications don't affect the real struct. +// +// The clone is a JSON round-trip: marshal the typed value, then +// unmarshal into a fresh instance of the same concrete type via +// reflect.New. The payload must be a pointer. +func (inv *Invocation) withCopiedPayload() (*Invocation, error) { + raw, err := json.Marshal(inv.Payload) + if err != nil { + return nil, fmt.Errorf("clone payload: marshal: %w", err) + } + cp := reflect.New(reflect.TypeOf(inv.Payload).Elem()) + if err := json.Unmarshal(raw, cp.Interface()); err != nil { + return nil, fmt.Errorf("clone payload: unmarshal: %w", err) + } + return &Invocation{ + Event: inv.Event, + Phase: inv.Phase, + Payload: cp.Interface(), + mutatedParams: inv.mutatedParams, + Config: inv.Config, + Context: inv.Context, + Session: inv.Session, + }, nil +} + +// InvocationContext holds optional context passed to interceptors. +type InvocationContext struct { + Principal *Principal `json:"principal,omitempty"` + TraceID string `json:"traceId,omitempty"` + SpanID string `json:"spanId,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + SessionID string `json:"sessionId,omitempty"` +} + +// Principal identifies the caller. +type Principal struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + Claims map[string]any `json:"claims,omitempty"` +} diff --git a/go/sdk/interceptors/mcpserver/doc.go b/go/sdk/interceptors/mcpserver/doc.go new file mode 100644 index 0000000..3e02d80 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/doc.go @@ -0,0 +1,62 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +// Package mcpserver integrates the [interceptors] chain engine with +// [mcp.Server] from the go-sdk. It installs interceptors as +// receiving middleware, so they run for all transports (stdio, HTTP) +// and compose with other extensions like variants. +// +// # Getting Started +// +// Wrap an existing [mcp.Server] with [NewServer], register one or +// more interceptors, then start the server as usual: +// +// srv := mcpserver.NewServer(mcpServer, +// mcpserver.WithLogger(logger), +// mcpserver.WithContextProvider(myProvider), +// ) +// srv.AddInterceptor(myValidator) +// srv.AddInterceptor(myMutator) +// srv.Run(ctx, transport) +// +// # Middleware Ordering +// +// [NewServer] calls [mcp.Server.AddReceivingMiddleware] to install the +// interceptor chain. The go-sdk stacks middleware so that the first +// middleware added runs outermost. This means ordering depends on +// when other middleware is registered relative to NewServer: +// +// // auth runs BEFORE interceptors (outermost) +// mcpServer.AddReceivingMiddleware(authMiddleware) +// srv := mcpserver.NewServer(mcpServer) +// // logging runs AFTER interceptors (innermost) +// mcpServer.AddReceivingMiddleware(loggingMiddleware) +// +// Execution order: auth → interceptors → logging → handler +// +// # Context Provider +// +// Use [WithContextProvider] to populate [interceptors.InvocationContext] +// for every intercepted request. This is typically used to extract +// caller identity from OAuth tokens or session metadata: +// +// mcpserver.WithContextProvider( +// func(ctx context.Context, req mcp.Request) *interceptors.InvocationContext { +// return &interceptors.InvocationContext{ +// Principal: &interceptors.Principal{Type: "user", ID: "alice"}, +// } +// }, +// ) +// +// Handlers access this via [interceptors.Invocation].Context. +// +// # Transports +// +// The interceptor server works with any transport supported by the +// go-sdk. For stdio, use [Server.Run]. For HTTP, use +// [NewStreamableHTTPHandler]: +// +// handler := mcpserver.NewStreamableHTTPHandler(srv, nil) +// http.ListenAndServe(":8080", handler) +package mcpserver diff --git a/go/sdk/interceptors/mcpserver/events.go b/go/sdk/interceptors/mcpserver/events.go new file mode 100644 index 0000000..a5d4809 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/events.go @@ -0,0 +1,17 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package mcpserver + +// Event name constants for standard MCP methods. +const ( + // Server Features + EventToolsList = "tools/list" + EventToolsCall = "tools/call" + EventPromptsList = "prompts/list" + EventPromptsGet = "prompts/get" + EventResourcesList = "resources/list" + EventResourcesRead = "resources/read" + EventResourcesSubscribe = "resources/subscribe" +) diff --git a/go/sdk/interceptors/mcpserver/server.go b/go/sdk/interceptors/mcpserver/server.go new file mode 100644 index 0000000..9c82227 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/server.go @@ -0,0 +1,366 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package mcpserver + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" +) + +// extensionID is the capability key used in Capabilities.Experimental +const extensionID = "io.modelcontextprotocol/interceptors" + +// ServerOption configures a Server. +type ServerOption func(*Server) + +// WithLogger sets the logger for interceptor chain execution. +// If not set, slog.Default() is used. +func WithLogger(l *slog.Logger) ServerOption { + return func(s *Server) { + s.logger = l + } +} + +// ContextProviderFunc extracts an InvocationContext from an incoming MCP +// request. This is called once per intercepted request and the result is +// passed to all interceptor handlers via Invocation.Context. +// +// Typical use: extract principal identity from OAuth tokens available on +// the request's session (via RequestExtra.TokenInfo). +type ContextProviderFunc func(ctx context.Context, req mcp.Request) *interceptors.InvocationContext + +// WithContextProvider sets a function that populates InvocationContext +// for every intercepted request. Without this, Invocation.Context is nil. +func WithContextProvider(f ContextProviderFunc) ServerOption { + return func(s *Server) { + s.contextProvider = f + } +} + +// Server wraps an *mcp.Server with interceptor support. Interceptors are +// installed as middleware on the inner server at construction time, so they +// work across all transports (stdio, HTTP) and compose with other extensions +// like variants. +type Server struct { + inner *mcp.Server + chain *interceptors.Chain + logger *slog.Logger + contextProvider ContextProviderFunc +} + +// NewServer creates a new interceptor-aware server wrapping the given +// mcp.Server. Middleware is installed immediately, so interceptors added +// via AddInterceptor will be active for all future requests. +func NewServer(server *mcp.Server, opts ...ServerOption) *Server { + s := &Server{ + inner: server, + } + for _, opt := range opts { + opt(s) + } + if s.logger == nil { + s.logger = slog.Default() + } + s.chain = interceptors.NewChain(interceptors.WithChainLogger(s.logger)) + server.AddReceivingMiddleware(s.receivingMiddleware()) + return s +} + +// AddInterceptor registers an interceptor. It panics if the interceptor +// has a nil handler. It is safe to call while the server is running. +// Returns the receiver for chaining. +func (s *Server) AddInterceptor(i interceptors.Interceptor) *Server { + s.chain.Add(i) + return s +} + +// MCPServer returns the underlying *mcp.Server. This is useful when +// composing with other extensions (e.g., variants): +// +// is := mcpserver.NewServer(mcpServer) +// is.AddInterceptor(myValidator) +// vs.WithVariant(variant, is.MCPServer(), 0) +func (s *Server) MCPServer() *mcp.Server { + return s.inner +} + +// Run delegates to the inner server's Run method. Convenience for standalone +// use (e.g., stdio). +func (s *Server) Run(ctx context.Context, t mcp.Transport) error { + return s.inner.Run(ctx, t) +} + +// receivingMiddleware returns mcp.Middleware that intercepts incoming +// requests and outgoing responses. +func (s *Server) receivingMiddleware() mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // When handling the "initialize" method, we enrich the result with interceptor + // capability metadata. This allows clients to discover interceptor support and + // capabilities + if method == "initialize" { + result, err := next(ctx, method, req) + if err != nil { + return nil, err + } + return s.enrichInitResult(result) + } + + event := method + + // Fast path: no interceptors match this event. + if s.chain.IsEmpty(event, interceptors.PhaseRequest) && s.chain.IsEmpty(event, interceptors.PhaseResponse) { + return next(ctx, method, req) + } + + var session *mcp.ServerSession + if sess, ok := req.GetSession().(*mcp.ServerSession); ok { + session = sess + } + + // Extract invocation context from the request using the context provider, if set. + // This allows interceptor handlers to access caller identity and other contextual info. + var invCtx *interceptors.InvocationContext + if s.contextProvider != nil { + invCtx = s.contextProvider(ctx, req) + } + + // Run request-phase interceptors. + if !s.chain.IsEmpty(event, interceptors.PhaseRequest) { + if err := s.interceptRequest(ctx, req, event, session, invCtx); err != nil { + return nil, err + } + } + + // Call the next handler. + result, err := next(ctx, method, req) + if err != nil { + return result, err + } + + // Run response-phase interceptors. + if !s.chain.IsEmpty(event, interceptors.PhaseResponse) { + return s.interceptResponse(ctx, req, result, event, session, invCtx) + } + return result, nil + } + } +} + +// interceptRequest runs the request-phase interceptor chain. Mutators +// modify the typed params in place, so no marshal/unmarshal is needed. +func (s *Server) interceptRequest( + ctx context.Context, + req mcp.Request, + event string, + session *mcp.ServerSession, + invCtx *interceptors.InvocationContext, +) error { + inv := &interceptors.Invocation{ + Event: event, + Phase: interceptors.PhaseRequest, + Payload: req.GetParams(), + Session: session, + Context: invCtx, + } + cr, err := s.chain.ExecuteForReceiving(ctx, inv) + if err != nil { + return err + } + if cr != nil && len(cr.AbortedAt) > 0 { + s.logAborts("request", event, cr.AbortedAt) + return abortToJSONRPCError(cr.AbortedAt) + } + return nil +} + +// interceptResponse runs the response-phase interceptor chain. Mutators +// modify the typed result in place, so no marshal/unmarshal is needed. +func (s *Server) interceptResponse( + ctx context.Context, + req mcp.Request, + result mcp.Result, + event string, + session *mcp.ServerSession, + invCtx *interceptors.InvocationContext, +) (mcp.Result, error) { + inv := &interceptors.Invocation{ + Event: event, + Phase: interceptors.PhaseResponse, + Payload: result, + Session: session, + Context: invCtx, + } + inv.SetMutatedParams(req.GetParams()) + cr, err := s.chain.ExecuteForSending(ctx, inv) + if err != nil { + return nil, err + } + if cr != nil && len(cr.AbortedAt) > 0 { + s.logAborts("response", event, cr.AbortedAt) + return nil, abortToJSONRPCError(cr.AbortedAt) + } + + // Result was already modified in place by mutators — no unmarshal needed. + return result, nil +} + +// logAborts logs each abort entry at warn level. +func (s *Server) logAborts(phase, event string, aborts []interceptors.AbortInfo) { + for _, a := range aborts { + s.logger.Warn("interceptors: "+phase+" aborted", + "event", event, + "interceptor", a.Interceptor, + "reason", a.Reason, + ) + } +} + +// enrichInitResult injects interceptor capability metadata into the +// InitializeResult, following the same pattern as the variants extension. +// The capability is declared under Capabilities.Experimental so it works +// without upstream go-sdk changes. +func (s *Server) enrichInitResult(result mcp.Result) (mcp.Result, error) { + initResult, ok := result.(*mcp.InitializeResult) + if !ok { + return result, nil + } + + if initResult.Capabilities == nil { + initResult.Capabilities = &mcp.ServerCapabilities{} + } + if initResult.Capabilities.Experimental == nil { + initResult.Capabilities.Experimental = make(map[string]any) + } + + // Build the interceptor list in wire format. + // wireInterceptor embeds Metadata (which carries all JSON tags) and + // adds the type field from GetType(). + type wireInterceptor struct { + *interceptors.Metadata + Type interceptors.InterceptorType `json:"type"` + } + all := s.chain.Interceptors() + interceptorInfos := make([]wireInterceptor, 0, len(all)) + supportedEvents := map[string]bool{} + for _, ri := range all { + meta := ri.GetMetadata() + interceptorInfos = append(interceptorInfos, wireInterceptor{ + Metadata: meta, + Type: ri.GetType(), + }) + for _, e := range meta.Events { + supportedEvents[e] = true + } + } + + events := make([]string, 0, len(supportedEvents)) + for e := range supportedEvents { + events = append(events, e) + } + + initResult.Capabilities.Experimental[extensionID] = map[string]any{ + "supportedEvents": events, + "interceptors": interceptorInfos, + } + + return initResult, nil +} + +// NewStreamableHTTPHandler returns a new [mcp.StreamableHTTPHandler] for +// serving multiple concurrent clients over HTTP. It mirrors +// [mcp.NewStreamableHTTPHandler]. +// +// handler := mcpserver.NewStreamableHTTPHandler(srv, nil) +// http.ListenAndServe(":8080", handler) +func NewStreamableHTTPHandler(s *Server, opts *mcp.StreamableHTTPOptions) *mcp.StreamableHTTPHandler { + if s == nil { + panic("mcpserver: nil Server") + } + srv := s.MCPServer() + return mcp.NewStreamableHTTPHandler( + func(r *http.Request) *mcp.Server { return srv }, + opts, + ) +} + +// --- abort conversion --- + +// Error data structs for abortToJSONRPCError. + +type validationErrorData struct { + ValidationErrors []interceptors.AbortInfo `json:"validationErrors"` +} + +type mutationErrorData struct { + FailedInterceptor string `json:"failedInterceptor"` + Reason string `json:"reason"` +} + +type timeoutErrorData struct { + Interceptor string `json:"interceptor"` + TimeoutMs int64 `json:"timeoutMs,omitempty"` + Phase string `json:"phase"` +} + +// codeServerError is the JSON-RPC error code for timeout. +const codeServerError = -32000 + +// abortToJSONRPCError converts a list of abort infos into a *jsonrpc.Error +// with the correct error code and structured data: +// - Validation failures: -32602 (InvalidParams) with validationErrors data +// - Mutation failures: -32603 (InternalError) with failedInterceptor data +func abortToJSONRPCError(aborts []interceptors.AbortInfo) *jsonrpc.Error { + first := aborts[0] + + // Determine the error code from the abort type. + var code int64 + switch first.Type { + case interceptors.AbortValidation: + code = jsonrpc.CodeInvalidParams // -32602 + case interceptors.AbortTimeout: + code = codeServerError // -32000 + default: + code = jsonrpc.CodeInternalError // -32603 + } + + // Build structured error data. + var data json.RawMessage + switch first.Type { + case interceptors.AbortValidation: + d, _ := json.Marshal(validationErrorData{ValidationErrors: aborts}) + data = d + case interceptors.AbortTimeout: + d, _ := json.Marshal(timeoutErrorData{ + Interceptor: first.Interceptor, + Phase: first.Phase, + }) + data = d + default: + d, _ := json.Marshal(mutationErrorData{ + FailedInterceptor: first.Interceptor, + Reason: first.Reason, + }) + data = d + } + + msg := "interceptor abort: " + first.Reason + if len(aborts) > 1 { + msg = "interceptor abort: multiple validation failures" + } + + return &jsonrpc.Error{ + Code: code, + Message: msg, + Data: data, + } +} diff --git a/go/sdk/interceptors/mcpserver/server_integration_test.go b/go/sdk/interceptors/mcpserver/server_integration_test.go new file mode 100644 index 0000000..a1283b6 --- /dev/null +++ b/go/sdk/interceptors/mcpserver/server_integration_test.go @@ -0,0 +1,577 @@ +package mcpserver_test + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +func TestServer_ValidatorRejectsBlockedTool(t *testing.T) { + t.Parallel() + cs := setup(t, blockToolValidator("echo")) + + _, err := callEcho(t, cs) + assert.Error(t, err) +} + +func TestServer_ValidatorAllowsTool(t *testing.T) { + t.Parallel() + cs := setup(t, allowAllValidator("allow-all")) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(resultText(t, result), "echo:")) +} + +func TestServer_MutatorModifiesResponse(t *testing.T) { + t.Parallel() + cs := setup(t, prefixMutator("prefix", "[mutated]", interceptors.PhaseResponse, 0, interceptors.ModeOn)) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(resultText(t, result), "[mutated] echo:")) +} + +func TestServer_ChainedMutatorsWithAuditMode(t *testing.T) { + t.Parallel() + // 5 response mutators in priority order. Mutator 3 (priority 30) is + // audit-mode: it runs on a deep copy, so its "[AUDIT]" prefix must + // NOT appear in the final output. The other four are enforced and + // each prepend a tag. Expected final text (applied innermost-first + // by ascending priority): + // + // [M5] [M4] [M2] [M1] echo: ... + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + prefixMutator("m2", "[M2]", interceptors.PhaseResponse, 20, interceptors.ModeOn), + prefixMutator("m3-audit", "[AUDIT]", interceptors.PhaseResponse, 30, interceptors.ModeAudit), + prefixMutator("m4", "[M4]", interceptors.PhaseResponse, 40, interceptors.ModeOn), + prefixMutator("m5", "[M5]", interceptors.PhaseResponse, 50, interceptors.ModeOn), + ) + + result, err := callEcho(t, cs) + require.NoError(t, err) + + text := resultText(t, result) + assert.True(t, strings.HasPrefix(text, "[M5] [M4] [M2] [M1] echo:")) + assert.NotContains(t, text, "[AUDIT]") +} + +func TestServer_MutatorFailureAbortsChain(t *testing.T) { + t.Parallel() + // 3 response mutators. Mutator 2 (priority 20) returns an error + // and is fail-closed. The chain aborts; client receives an error. + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + failMutator("m2-fail", interceptors.PhaseResponse, 20), + prefixMutator("m3", "[M3]", interceptors.PhaseResponse, 30, interceptors.ModeOn), + ) + + _, err := callEcho(t, cs) + assert.Error(t, err) +} + +func TestServer_ValidatorWarnDoesNotBlock(t *testing.T) { + t.Parallel() + // A validator returns Valid=false with only SeverityWarn messages. + // The chain should continue — warn-only failures don't abort. + warnValidator := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "warn-only", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "something looks off", Severity: interceptors.SeverityWarn}, + }, + }, nil + }, + } + + cs := setup(t, warnValidator) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(resultText(t, result), "echo:")) +} + +func TestServer_ValidationFailurePreventsM(t *testing.T) { + t.Parallel() + // A validator rejects the request. A mutator is also registered but + // should never run because executeForReceiving returns at the + // ChainValidationFailed check before reaching runMutators. + mutatorRan := false + spy := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "spy-mutator", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + mutatorRan = true + return &interceptors.MutationResult{Modified: false}, nil + }, + } + + cs := setup(t, blockToolValidator("echo"), spy) + + _, err := callEcho(t, cs) + assert.Error(t, err) + assert.False(t, mutatorRan, "mutator should not run after validation failure") +} + +func TestServer_ValidatorTimeoutAbortsChain(t *testing.T) { + t.Parallel() + slowValidator := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "slow-validator", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + TimeoutMs: 1, + }, + Handler: func(ctx context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(100 * time.Millisecond): + return &interceptors.ValidationResult{Valid: true}, nil + } + }, + } + + cs := setup(t, slowValidator) + + _, err := callEcho(t, cs) + require.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestServer_MutatorTimeoutAbortsChain(t *testing.T) { + t.Parallel() + // 3 response mutators. Mutator 2 (priority 20) blocks longer than + // its 1ms timeout. The chain aborts; client receives an error. + slowMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "m2-slow", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + TimeoutMs: 1, + PriorityHint: interceptors.NewPriority(20), + }, + Handler: func(ctx context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(100 * time.Millisecond): + return &interceptors.MutationResult{Modified: false}, nil + } + }, + } + + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + slowMutator, + prefixMutator("m3", "[M3]", interceptors.PhaseResponse, 30, interceptors.ModeOn), + ) + + _, err := callEcho(t, cs) + require.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestServer_MutatorFailOpenContinuesChain(t *testing.T) { + t.Parallel() + // 3 response mutators. Mutator 2 (priority 20) returns an error but + // has FailOpen set. The chain continues; M1 and M3 both apply. + failOpenMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "m2-failopen", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + FailOpen: true, + PriorityHint: interceptors.NewPriority(20), + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + return nil, fmt.Errorf("simulated failure") + }, + } + + cs := setup(t, + prefixMutator("m1", "[M1]", interceptors.PhaseResponse, 10, interceptors.ModeOn), + failOpenMutator, + prefixMutator("m3", "[M3]", interceptors.PhaseResponse, 30, interceptors.ModeOn), + ) + + result, err := callEcho(t, cs) + require.NoError(t, err) + + text := resultText(t, result) + assert.Contains(t, text, "[M1]") + assert.Contains(t, text, "[M3]") +} + +func TestServer_ResponseMutatorSeesMutatedParams(t *testing.T) { + t.Parallel() + // A request mutator injects a marker into the arguments. A response + // mutator reads MutatedParams() and copies the marker into the + // response text. This proves MutatedParams() reflects request-phase + // mutations, not the original request. + reqMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "inject-marker", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + args["marker"] = "injected-by-request-mutator" + raw, err := json.Marshal(args) + if err != nil { + return nil, err + } + params.Arguments = raw + return &interceptors.MutationResult{Modified: true}, nil + }, + } + + respMutator := &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: "read-marker", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + params, ok := inv.MutatedParams().(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected MutatedParams type %T", inv.MutatedParams()) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + marker, _ := args["marker"].(string) + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = fmt.Sprintf("[marker=%s] %s", marker, tc.Text) + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } + + cs := setup(t, reqMutator, respMutator) + + result, err := callEcho(t, cs) + require.NoError(t, err) + assert.Contains(t, resultText(t, result), "[marker=injected-by-request-mutator]") +} + +func TestServer_CombinedValidatorsAndMutators(t *testing.T) { + t.Parallel() + // Exercises the full chain with 3 validators and 3 mutators on each + // side, verifying: + // + // Request: Validate (parallel) → Mutate (sequential by priority) + // Response: Mutate (sequential by priority) → Validate (parallel) + // + // Request mutators inject fields into arguments (reflected in echo output). + // Response mutators prepend tags to text (visible in final result). + // Atomic counters confirm every interceptor actually ran. + + var ( + reqValCount atomic.Int32 + reqMutCount atomic.Int32 + respMutCount atomic.Int32 + respValCount atomic.Int32 + ) + + // --- Request validators (3, parallel) --- + + // req-v1: passes if tool name is "echo". + reqV1 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "req-v1-tool-check", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + reqValCount.Add(1) + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + if params.Name != "echo" { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "only echo allowed", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // req-v2: passes if arguments contain "text". + reqV2 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "req-v2-args-check", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + reqValCount.Add(1) + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + if _, ok := args["text"]; !ok { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "missing text argument", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // req-v3: always passes with a warning (warn doesn't block). + reqV3 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "req-v3-warn", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + reqValCount.Add(1) + return &interceptors.ValidationResult{ + Valid: true, + Messages: []interceptors.ValidationMessage{ + {Message: "request noted", Severity: interceptors.SeverityWarn}, + }, + }, nil + }, + } + + // --- Request mutators (3, sequential by priority 10 → 20 → 30) --- + // Each adds a field to the arguments. The echo handler serializes + // the arguments map, so all three fields appear in the output. + + requestArgMutator := func(name, key, value string, priority int) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + reqMutCount.Add(1) + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + return nil, err + } + args[key] = value + raw, err := json.Marshal(args) + if err != nil { + return nil, err + } + params.Arguments = raw + return &interceptors.MutationResult{Modified: true}, nil + }, + } + } + + reqM1 := requestArgMutator("req-m1", "step1", "done", 10) + reqM2 := requestArgMutator("req-m2", "step2", "done", 20) + reqM3 := requestArgMutator("req-m3", "step3", "done", 30) + + // --- Response mutators (3, sequential by priority 10 → 20 → 30) --- + // Each prepends a tag. Applied innermost-first, so final text is: + // [RM3] [RM2] [RM1] echo: ... + + respPrefixMutator := func(name, tag string, priority int) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + respMutCount.Add(1) + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = tag + " " + tc.Text + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } + } + + respM1 := respPrefixMutator("resp-m1", "[RM1]", 10) + respM2 := respPrefixMutator("resp-m2", "[RM2]", 20) + respM3 := respPrefixMutator("resp-m3", "[RM3]", 30) + + // --- Response validators (3, parallel) --- + // Run after response mutators. resp-v2 verifies it sees the mutated + // text (tags applied by response mutators), which proves the + // Mutate → Validate ordering on the response side. + + // resp-v1: passes if response has content. + respV1 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "resp-v1-has-content", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + respValCount.Add(1) + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + if len(result.Content) == 0 { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "empty response", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } + + // resp-v2: rejects if response text doesn't contain "[RM1]". + // This proves the response-side ordering (Mutate → Validate): if + // this validator ran before mutators, it wouldn't find the tag and + // would abort the chain, failing the test. + respV2 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "resp-v2-sees-mutation", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + respValCount.Add(1) + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected type %T", inv.Payload) + } + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + if strings.Contains(tc.Text, "[RM1]") { + return &interceptors.ValidationResult{Valid: true}, nil + } + } + } + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: "response missing mutator tag [RM1]", Severity: interceptors.SeverityError}, + }, + }, nil + }, + } + + // resp-v3: always passes with an info message. + respV3 := &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "resp-v3-info", + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseResponse, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + respValCount.Add(1) + return &interceptors.ValidationResult{ + Valid: true, + Messages: []interceptors.ValidationMessage{ + {Message: "response looks good", Severity: interceptors.SeverityInfo}, + }, + }, nil + }, + } + + cs := setup(t, + reqV1, reqV2, reqV3, + reqM1, reqM2, reqM3, + respM1, respM2, respM3, + respV1, respV2, respV3, + ) + + result, err := callEcho(t, cs) + require.NoError(t, err) + + text := resultText(t, result) + + // Response mutators applied in priority order (outermost = highest priority). + assert.True(t, strings.HasPrefix(text, "[RM3] [RM2] [RM1] echo:"), + "expected response mutator tags in priority order, got: %s", text) + + // Request mutators injected fields into arguments (visible in echo output). + assert.Contains(t, text, "step1") + assert.Contains(t, text, "step2") + assert.Contains(t, text, "step3") + + // Every interceptor ran exactly once. + assert.Equal(t, int32(3), reqValCount.Load(), "expected 3 request validators to run") + assert.Equal(t, int32(3), reqMutCount.Load(), "expected 3 request mutators to run") + assert.Equal(t, int32(3), respMutCount.Load(), "expected 3 response mutators to run") + assert.Equal(t, int32(3), respValCount.Load(), "expected 3 response validators to run") +} diff --git a/go/sdk/interceptors/mcpserver/testharness_test.go b/go/sdk/interceptors/mcpserver/testharness_test.go new file mode 100644 index 0000000..42200ea --- /dev/null +++ b/go/sdk/interceptors/mcpserver/testharness_test.go @@ -0,0 +1,191 @@ +package mcpserver_test + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" + + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors" + "github.com/modelcontextprotocol/ext-interceptors/go/sdk/interceptors/mcpserver" +) + +// --- Server setup --- + +// setup creates an HTTP test server with interceptor support and a connected +// client session. The server has a single "echo" tool registered for +// EventToolsCall. Cleanup is registered via t.Cleanup. +func setup(t *testing.T, is ...interceptors.Interceptor) *mcp.ClientSession { + t.Helper() + return setupWithTools(t, defaultTools(), is...) +} + +// setupWithTools creates an HTTP test server with the given tools and +// interceptors. Cleanup is registered via t.Cleanup. +func setupWithTools(t *testing.T, tools []testTool, is ...interceptors.Interceptor) *mcp.ClientSession { + t.Helper() + + mcpServer := mcp.NewServer(&mcp.Implementation{ + Name: "test-server", + Version: "0.1.0", + }, nil) + + for _, tool := range tools { + mcpServer.AddTool(tool.tool, tool.handler) + } + + srv := mcpserver.NewServer(mcpServer) + for _, i := range is { + srv.AddInterceptor(i) + } + + handler := mcpserver.NewStreamableHTTPHandler(srv, nil) + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, nil) + cs, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{ + Endpoint: httpServer.URL, + }, nil) + require.NoError(t, err) + t.Cleanup(func() { cs.Close() }) + + return cs +} + +// --- Tool definitions --- + +// testTool pairs a tool definition with its handler for use with setupWithTools. +type testTool struct { + tool *mcp.Tool + handler mcp.ToolHandler +} + +// defaultTools returns the standard "echo" tool used by most tests. +func defaultTools() []testTool { + return []testTool{ + { + tool: &mcp.Tool{ + Name: "echo", + Description: "echoes input", + InputSchema: map[string]any{"type": "object"}, + }, + handler: func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("echo: %s", req.Params.Arguments)}}, + }, nil + }, + }, + } +} + +// --- Call helpers --- + +// callEcho calls the "echo" tool and returns the result. +func callEcho(t *testing.T, cs *mcp.ClientSession) (*mcp.CallToolResult, error) { + t.Helper() + return cs.CallTool(context.Background(), &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"text": "hello"}, + }) +} + +// --- Assertion helpers --- + +// resultText extracts the first TextContent string from a CallToolResult. +// Fails the test if the result has no content or the first item isn't TextContent. +func resultText(t *testing.T, result *mcp.CallToolResult) string { + t.Helper() + require.NotEmpty(t, result.Content) + tc, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok, "expected TextContent, got %T", result.Content[0]) + return tc.Text +} + +// --- Interceptor builders --- + +// prefixMutator creates a mutator that prepends tag to each TextContent. +func prefixMutator(name, tag string, phase interceptors.Phase, priority int, mode interceptors.Mode) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: phase, + Mode: mode, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.MutationResult, error) { + result, ok := inv.Payload.(*mcp.CallToolResult) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + for _, c := range result.Content { + if tc, ok := c.(*mcp.TextContent); ok { + tc.Text = tag + " " + tc.Text + } + } + return &interceptors.MutationResult{Modified: true}, nil + }, + } +} + +// failMutator creates a mutator that always returns an error. +func failMutator(name string, phase interceptors.Phase, priority int) *interceptors.Mutator { + return &interceptors.Mutator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: phase, + Mode: interceptors.ModeOn, + PriorityHint: interceptors.NewPriority(priority), + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.MutationResult, error) { + return nil, fmt.Errorf("simulated failure") + }, + } +} + +// blockToolValidator creates a request validator that rejects calls to the named tool. +func blockToolValidator(toolName string) *interceptors.Validator { + return &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: "block-" + toolName, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, inv *interceptors.Invocation) (*interceptors.ValidationResult, error) { + params, ok := inv.Payload.(*mcp.CallToolParamsRaw) + if !ok { + return nil, fmt.Errorf("unexpected payload type %T", inv.Payload) + } + if params.Name == toolName { + return &interceptors.ValidationResult{ + Valid: false, + Messages: []interceptors.ValidationMessage{ + {Message: toolName + " is blocked", Severity: interceptors.SeverityError}, + }, + }, nil + } + return &interceptors.ValidationResult{Valid: true}, nil + }, + } +} + +// allowAllValidator creates a request validator that allows all calls. +func allowAllValidator(name string) *interceptors.Validator { + return &interceptors.Validator{ + Metadata: interceptors.Metadata{ + Name: name, + Events: []string{mcpserver.EventToolsCall}, + Phase: interceptors.PhaseRequest, + Mode: interceptors.ModeOn, + }, + Handler: func(_ context.Context, _ *interceptors.Invocation) (*interceptors.ValidationResult, error) { + return &interceptors.ValidationResult{Valid: true}, nil + }, + } +} diff --git a/go/sdk/interceptors/result.go b/go/sdk/interceptors/result.go new file mode 100644 index 0000000..a84d6a0 --- /dev/null +++ b/go/sdk/interceptors/result.go @@ -0,0 +1,96 @@ +// Copyright 2025 The MCP Interceptors Authors. All rights reserved. +// Use of this source code is governed by an Apache-2.0 +// license that can be found in the LICENSE file. + +package interceptors + +// --- Interceptor result types --- + +// ValidationMessage is a single validation finding. +type ValidationMessage struct { + Path string `json:"path,omitempty"` + Message string `json:"message"` + Severity Severity `json:"severity"` +} + +// ValidationSuggestion is an optional suggested correction. +type ValidationSuggestion struct { + Path string `json:"path"` + Value any `json:"value"` +} + +// ValidationResult is returned by validation interceptors. +type ValidationResult struct { + Valid bool `json:"valid"` + Severity Severity `json:"severity,omitempty"` + Messages []ValidationMessage `json:"messages,omitempty"` + Suggestions []ValidationSuggestion `json:"suggestions,omitempty"` +} + +// MutationResult is returned by mutation interceptors. +type MutationResult struct { + Modified bool `json:"modified"` + Info map[string]any `json:"info,omitempty"` +} + +// --- Chain result --- + +// ChainStatus describes the outcome of a full interceptor chain execution. +type ChainStatus string + +const ( + ChainSuccess ChainStatus = "success" + ChainValidationFailed ChainStatus = "validation_failed" + ChainMutationFailed ChainStatus = "mutation_failed" + ChainTimeout ChainStatus = "timeout" +) + +// AbortType classifies the reason an interceptor chain was aborted. +type AbortType string + +const ( + AbortValidation AbortType = "validation" + AbortMutation AbortType = "mutation" + AbortTimeout AbortType = "timeout" +) + +// ChainResult aggregates results from executing the full interceptor chain. +type ChainResult struct { + Status ChainStatus `json:"status"` + Event string `json:"event"` + Phase Phase `json:"phase"` + Results []ExecutionResult `json:"results"` + FinalPayload any `json:"finalPayload,omitempty"` + ValidationSummary ValidationSummary `json:"validationSummary"` + TotalDurationMs int64 `json:"totalDurationMs"` + AbortedAt []AbortInfo `json:"abortedAt,omitempty"` +} + +// ExecutionResult tracks a single interceptor's execution result. +type ExecutionResult struct { + Interceptor string `json:"interceptor"` // Name of the interceptor + Type InterceptorType `json:"type"` + Phase Phase `json:"phase"` + DurationMs int64 `json:"durationMs,omitempty"` + Error string `json:"error,omitempty"` // Non-empty when the handler returned an error + Info map[string]any `json:"info,omitempty"` + + // One of these is set depending on type: + Validation *ValidationResult `json:"validation,omitempty"` + Mutation *MutationResult `json:"mutation,omitempty"` +} + +// ValidationSummary counts validation outcomes. +type ValidationSummary struct { + Errors int `json:"errors"` + Warnings int `json:"warnings"` + Infos int `json:"infos"` +} + +// AbortInfo describes where and why a chain was aborted. +type AbortInfo struct { + Interceptor string `json:"interceptor"` + Reason string `json:"reason"` + Type AbortType `json:"type"` + Phase string `json:"phase,omitempty"` // phase at which the abort occurred +} diff --git a/go/sdk/internal/.gitkeep b/go/sdk/internal/.gitkeep deleted file mode 100644 index e69de29..0000000