diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index 94c159af8a..c512051d86 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -8,6 +8,7 @@ import ( "fmt" "log/slog" "strings" + "time" "github.com/spf13/cobra" @@ -104,6 +105,9 @@ type RunFlags struct { // Endpoint prefix for SSE endpoint URLs EndpointPrefix string + // SessionTTL is the session inactivity timeout. Zero uses the transport default. + SessionTTL time.Duration + // Network mode Network string @@ -264,6 +268,8 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) { cmd.Flags().BoolVar(&config.Stateless, "stateless", false, "Declare the server as stateless (POST-only, no SSE). "+ "Use for MCP servers implementing streamable-HTTP stateless mode.") + cmd.Flags().DurationVar(&config.SessionTTL, "session-ttl", 0, + "Session inactivity timeout (e.g., 30m, 2h); zero uses the default (2h)") cmd.Flags().StringVar(&config.EndpointPrefix, "endpoint-prefix", "", "Path prefix to prepend to SSE endpoint URLs (e.g., /playwright)") cmd.Flags().StringVar(&config.Network, "network", "", @@ -665,6 +671,7 @@ func buildRunnerConfig( runner.WithAllowDockerGateway(runFlags.AllowDockerGateway), runner.WithTrustProxyHeaders(runFlags.TrustProxyHeaders), runner.WithStateless(runFlags.Stateless), + runner.WithSessionTTL(runFlags.SessionTTL), runner.WithEndpointPrefix(runFlags.EndpointPrefix), runner.WithNetworkMode(runFlags.Network), runner.WithK8sPodPatch(runFlags.K8sPodPatch), diff --git a/cmd/thv/app/vmcp.go b/cmd/thv/app/vmcp.go index ef3cf72e81..26d07296e8 100644 --- a/cmd/thv/app/vmcp.go +++ b/cmd/thv/app/vmcp.go @@ -5,6 +5,7 @@ package app import ( "fmt" + "time" "github.com/spf13/cobra" @@ -39,6 +40,7 @@ func newVMCPServeCommand() *cobra.Command { enableEmbedding bool embeddingModel string embeddingImage string + sessionTTL time.Duration ) cmd := &cobra.Command{ Use: "serve", @@ -64,6 +66,7 @@ configuration file is needed for the common case of aggregating a local group.`, EnableEmbedding: enableEmbedding, EmbeddingModel: embeddingModel, EmbeddingImage: embeddingImage, + SessionTTL: sessionTTL, }) }, } @@ -80,6 +83,8 @@ configuration file is needed for the common case of aggregating a local group.`, cmd.Flags().StringVar(&host, "host", "127.0.0.1", "Host address to bind to") cmd.Flags().IntVar(&port, "port", 4483, "Port to listen on") cmd.Flags().BoolVar(&enableAudit, "enable-audit", false, "Enable audit logging with default configuration") + cmd.Flags().DurationVar(&sessionTTL, "session-ttl", 0, + "Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m)") return cmd } diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index a30c6209e9..bbc7b0db57 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -7,6 +7,7 @@ package app import ( "fmt" "log/slog" + "time" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -96,12 +97,14 @@ from all configured backend MCP servers.`, host, _ := cmd.Flags().GetString("host") port, _ := cmd.Flags().GetInt("port") enableAudit, _ := cmd.Flags().GetBool("enable-audit") + sessionTTL, _ := cmd.Flags().GetDuration("session-ttl") return vmcpcli.Serve(cmd.Context(), vmcpcli.ServeConfig{ ConfigPath: configPath, Host: host, Port: port, EnableAudit: enableAudit, + SessionTTL: sessionTTL, }) }, } @@ -110,6 +113,8 @@ from all configured backend MCP servers.`, cmd.Flags().String("host", "127.0.0.1", "Host address to bind to") cmd.Flags().Int("port", 4483, "Port to listen on") cmd.Flags().Bool("enable-audit", false, "Enable audit logging with default configuration") + cmd.Flags().Duration("session-ttl", time.Duration(0), + "Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m)") return cmd } diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index d3c38dc4cb..b69d7b6cd1 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -178,6 +178,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] --runtime-add-package stringArray Add additional packages to install in the builder and runtime stages (can be repeated) --runtime-image string Override the default base image for protocol schemes (e.g., golang:1.24-alpine, node:20-alpine, python:3.11-slim) --secret stringArray Specify a secret to be fetched from the secrets manager and set as an environment variable (format: NAME,target=TARGET) + --session-ttl duration Session inactivity timeout (e.g., 30m, 2h); zero uses the default (2h) --stateless Declare the server as stateless (POST-only, no SSE). Use for MCP servers implementing streamable-HTTP stateless mode. --target-host string Host to forward traffic to (only applicable to SSE or Streamable HTTP transport) (default "127.0.0.1") --target-port int Port for the container to expose (only applicable to SSE or Streamable HTTP transport) diff --git a/docs/cli/thv_vmcp_serve.md b/docs/cli/thv_vmcp_serve.md index f20718f6a4..48243c7ee4 100644 --- a/docs/cli/thv_vmcp_serve.md +++ b/docs/cli/thv_vmcp_serve.md @@ -42,6 +42,7 @@ thv vmcp serve [flags] --optimizer Enable FTS5 keyword optimizer (Tier 1): exposes find_tool and call_tool instead of all backend tools --optimizer-embedding Enable managed TEI semantic optimizer (Tier 2); implies --optimizer --port int Port to listen on (default 4483) + --session-ttl duration Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m) ``` ### Options inherited from parent commands diff --git a/docs/server/docs.go b/docs/server/docs.go index 761041e32e..fb3b83c802 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -1321,6 +1321,10 @@ const docTemplate = `{ "type": "array", "uniqueItems": false }, + "session_ttl": { + "description": "SessionTTL is the inactivity timeout for proxy sessions.\nZero uses the transport default (2h). Negative values are rejected by the builder.", + "type": "integer" + }, "stateless": { "description": "Stateless indicates the server only supports POST (no SSE/GET).\nWhen true, the proxy returns 405 for incoming GET requests and uses a\nPOST-based health check instead of the default GET probe.\nApplies to both remote URLs and local container workloads.", "type": "boolean" diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 91b6e3b3aa..c3f189a423 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1314,6 +1314,10 @@ "type": "array", "uniqueItems": false }, + "session_ttl": { + "description": "SessionTTL is the inactivity timeout for proxy sessions.\nZero uses the transport default (2h). Negative values are rejected by the builder.", + "type": "integer" + }, "stateless": { "description": "Stateless indicates the server only supports POST (no SSE/GET).\nWhen true, the proxy returns 405 for incoming GET requests and uses a\nPOST-based health check instead of the default GET probe.\nApplies to both remote URLs and local container workloads.", "type": "boolean" diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index f1118124d0..aed459bf8a 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1266,6 +1266,11 @@ components: type: string type: array uniqueItems: false + session_ttl: + description: |- + SessionTTL is the inactivity timeout for proxy sessions. + Zero uses the transport default (2h). Negative values are rejected by the builder. + type: integer stateless: description: |- Stateless indicates the server only supports POST (no SSE/GET). diff --git a/pkg/runner/config.go b/pkg/runner/config.go index 3934a632dd..6cc7a3a8c2 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "log/slog" + "time" "github.com/stacklok/toolhive-core/permissions" v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1" @@ -203,6 +204,10 @@ type RunConfig struct { // Applies to both remote URLs and local container workloads. Stateless bool `json:"stateless,omitempty" yaml:"stateless,omitempty"` + // SessionTTL is the inactivity timeout for proxy sessions. + // Zero uses the transport default (2h). Negative values are rejected by the builder. + SessionTTL time.Duration `json:"session_ttl,omitempty" yaml:"session_ttl,omitempty" swaggertype:"primitive,integer"` + // ProxyMode is the effective HTTP protocol the proxy uses. // For stdio transports, this is the configured mode (sse or streamable-http). // For direct transports (sse/streamable-http), this matches the transport type. diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index 88c4fd1cb2..8b82b46e77 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -14,6 +14,7 @@ import ( "path/filepath" "slices" "strings" + "time" "github.com/stacklok/toolhive-core/permissions" regtypes "github.com/stacklok/toolhive-core/registry/types" @@ -362,6 +363,19 @@ func WithEndpointPrefix(prefix string) RunConfigBuilderOption { } } +// WithSessionTTL sets the inactivity timeout for proxy sessions. +// Zero is valid and means "use the transport default" (2h). +// Negative values return an error. +func WithSessionTTL(ttl time.Duration) RunConfigBuilderOption { + return func(b *runConfigBuilder) error { + if ttl < 0 { + return fmt.Errorf("session-ttl must be non-negative, got %s", ttl) + } + b.config.SessionTTL = ttl + return nil + } +} + // WithNetworkMode sets the network mode for the container. // The network mode will be applied to the permission profile after it is loaded. func WithNetworkMode(networkMode string) RunConfigBuilderOption { diff --git a/pkg/runner/config_builder_test.go b/pkg/runner/config_builder_test.go index af885a219f..4f5e38f00e 100644 --- a/pkg/runner/config_builder_test.go +++ b/pkg/runner/config_builder_test.go @@ -1489,6 +1489,57 @@ func TestWithRegistryServerName(t *testing.T) { } } +func TestWithSessionTTL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ttl time.Duration + expectErr bool + expectedTTL time.Duration + }{ + { + name: "zero is accepted and means use default", + ttl: 0, + expectErr: false, + expectedTTL: 0, + }, + { + name: "positive duration is accepted", + ttl: 45 * time.Minute, + expectErr: false, + expectedTTL: 45 * time.Minute, + }, + { + name: "large positive duration is accepted", + ttl: 24 * time.Hour, + expectErr: false, + expectedTTL: 24 * time.Hour, + }, + { + name: "negative duration returns an error", + ttl: -1 * time.Second, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + builder := &runConfigBuilder{config: NewRunConfig()} + err := WithSessionTTL(tt.ttl)(builder) + + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expectedTTL, builder.config.SessionTTL) + }) + } +} + func TestResolveRegistryServerName(t *testing.T) { t.Parallel() diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 9e296d1fc8..afb489ab3b 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -166,6 +166,14 @@ func (c *RunConfig) GetPort() int { // //nolint:gocyclo // This function is complex but manageable func (r *Runner) Run(ctx context.Context) error { + // Resolve session TTL once so both the transport proxy and Redis storage use + // the same effective value, rather than each applying their own zero-fallback + // independently. + effectiveSessionTTL := r.Config.SessionTTL + if effectiveSessionTTL <= 0 { + effectiveSessionTTL = session.DefaultSessionTTL + } + // Create transport with runtime transportConfig := types.Config{ Type: r.Config.Transport, @@ -177,6 +185,7 @@ func (r *Runner) Run(ctx context.Context) error { Debug: r.Config.Debug, TrustProxyHeaders: r.Config.TrustProxyHeaders, EndpointPrefix: r.Config.EndpointPrefix, + SessionTTL: effectiveSessionTTL, } // Set proxy mode for stdio transport @@ -368,7 +377,7 @@ func (r *Runner) Run(ctx context.Context) error { Password: os.Getenv(session.RedisPasswordEnvVar), DB: int(redisCfg.DB), KeyPrefix: keyPrefix, - }, session.DefaultSessionTTL) + }, effectiveSessionTTL) if err != nil { return fmt.Errorf("failed to create Redis session storage: %w", err) } diff --git a/pkg/transport/factory.go b/pkg/transport/factory.go index 003bab13ee..9322464836 100644 --- a/pkg/transport/factory.go +++ b/pkg/transport/factory.go @@ -55,6 +55,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er if config.SessionStorage != nil { stdio.SetSessionStorage(config.SessionStorage) } + stdio.SetSessionTTL(config.SessionTTL) tr = stdio case types.TransportTypeSSE: httpTransport := NewHTTPTransport( @@ -73,6 +74,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er config.Middlewares..., ) httpTransport.sessionStorage = config.SessionStorage + httpTransport.sessionTTL = config.SessionTTL tr = httpTransport case types.TransportTypeStreamableHTTP: httpTransport := NewHTTPTransport( @@ -91,6 +93,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er config.Middlewares..., ) httpTransport.sessionStorage = config.SessionStorage + httpTransport.sessionTTL = config.SessionTTL tr = httpTransport case types.TransportTypeInspector: // HTTP transport is not implemented yet diff --git a/pkg/transport/http.go b/pkg/transport/http.go index 53b9af6d12..decadcb96c 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -13,6 +13,7 @@ import ( "os" "strings" "sync" + "time" "golang.org/x/oauth2" @@ -81,6 +82,10 @@ type HTTPTransport struct { // Used for Redis-backed session sharing across replicas. sessionStorage session.Storage + // sessionTTL overrides the inactivity timeout for sessions managed by the + // underlying proxy. Zero uses the proxy's default. + sessionTTL time.Duration + // Transparent proxy proxy types.Proxy @@ -432,6 +437,9 @@ func (t *HTTPTransport) buildProxyOptions(remoteBasePath, remoteRawQuery string) if t.stateless { opts = append(opts, transparent.WithStateless()) } + if t.sessionTTL > 0 { + opts = append(opts, transparent.WithSessionTTL(t.sessionTTL)) + } if t.sessionStorage != nil { opts = append(opts, transparent.WithSessionStorage(t.sessionStorage)) } diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index 9890317fcc..3eb6943975 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -74,6 +74,14 @@ type HTTPSSEProxy struct { // Session manager for SSE clients sessionManager *session.Manager + // sessionTTL is the resolved inactivity timeout for the session manager. + // Defaults to session.DefaultSessionTTL; overridable via WithSessionTTL. + sessionTTL time.Duration + + // sessionStorage is the optional custom storage backend for the session manager. + // When nil, in-memory LocalStorage is used. Set via WithSessionStorage. + sessionStorage session.Storage + // liveSSESessions tracks active SSE connections local to this instance. // Keys are clientID strings; values are *session.SSESession. // This is separate from sessionManager so that distributed storage backends @@ -121,11 +129,18 @@ func WithSessionStorage(storage session.Storage) Option { if storage == nil { return } - if p.sessionManager != nil { - _ = p.sessionManager.Stop() + p.sessionStorage = storage + } +} + +// WithSessionTTL overrides the session inactivity timeout used by this proxy. +// Zero or negative values are ignored so the constructor's default is preserved. +func WithSessionTTL(ttl time.Duration) Option { + return func(p *HTTPSSEProxy) { + if ttl <= 0 { + return } - sseFactory := func(id string) session.Session { return session.NewSSESession(id) } - p.sessionManager = session.NewManagerWithStorage(session.DefaultSessionTTL, sseFactory, storage) + p.sessionTTL = ttl } } @@ -138,11 +153,6 @@ func NewHTTPSSEProxy( middlewares []types.NamedMiddleware, opts ...Option, ) *HTTPSSEProxy { - // Create a factory for SSE sessions - sseFactory := func(id string) session.Session { - return session.NewSSESession(id) - } - proxy := &HTTPSSEProxy{ middlewares: middlewares, host: host, @@ -150,7 +160,7 @@ func NewHTTPSSEProxy( trustProxyHeaders: trustProxyHeaders, shutdownCh: make(chan struct{}), messageCh: make(chan jsonrpc2.Message, 100), - sessionManager: session.NewManager(session.DefaultSessionTTL, sseFactory), + sessionTTL: session.DefaultSessionTTL, pendingMessages: []*ssecommon.PendingSSEMessage{}, prometheusHandler: prometheusHandler, } @@ -159,6 +169,14 @@ func NewHTTPSSEProxy( opt(proxy) } + // Construct the session manager once, after options have resolved sessionTTL and sessionStorage. + sseFactory := func(id string) session.Session { return session.NewSSESession(id) } + if proxy.sessionStorage != nil { + proxy.sessionManager = session.NewManagerWithStorage(proxy.sessionTTL, sseFactory, proxy.sessionStorage) + } else { + proxy.sessionManager = session.NewManager(proxy.sessionTTL, sseFactory) + } + // Create MCP pinger and health checker mcpPinger := NewMCPPinger(proxy) proxy.healthChecker = healthcheck.NewHealthChecker("stdio", mcpPinger) @@ -429,7 +447,9 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques } flusher.Flush() case <-keepAliveTicker.C: - // Send SSE comment as keep-alive + // Refresh session TTL while the SSE socket is open so the cleanup + // goroutine does not evict clients that haven't sent a POST recently. + p.sessionManager.Get(clientID) if _, err := fmt.Fprint(w, ": keep-alive\n\n"); err != nil { slog.Debug("failed to write keep-alive", "error", err) return diff --git a/pkg/transport/proxy/streamable/streamable_proxy.go b/pkg/transport/proxy/streamable/streamable_proxy.go index a51ff24b12..2a2c957034 100644 --- a/pkg/transport/proxy/streamable/streamable_proxy.go +++ b/pkg/transport/proxy/streamable/streamable_proxy.go @@ -56,9 +56,21 @@ type HTTPProxy struct { // Session manager for streamable HTTP sessions sessionManager *session.Manager - // Waiters keyed by JSON-encoded request ID -> one-shot channel for response delivery + // sessionTTL is the resolved inactivity timeout for the session manager. + // Defaults to session.DefaultSessionTTL; overridable via WithSessionTTL. + sessionTTL time.Duration + + // sessionStorage is the optional custom storage backend for the session manager. + // When nil, in-memory LocalStorage is used. Set via WithSessionStorage. + sessionStorage session.Storage + + // Waiters keyed by compositeKey(sessID, idKey) -> one-shot channel for response delivery. + // The composite key MUST be unique per concurrent request; sharing it across requests + // (e.g. with sessID="" for sessionless requests) silently overwrites entries and crosses + // response payloads between unrelated clients. See resolveSessionForRequest. waiters sync.Map // map[string]chan jsonrpc2.Message - // Map of compositeKey(sessID|idKey) -> original client JSON-RPC ID to restore before replying + // Keyed by the same compositeKey(sessID, idKey); stores the original client JSON-RPC ID + // to restore before replying. Same uniqueness requirement as `waiters`. idRestore sync.Map // map[string]jsonrpc2.ID // Health checker @@ -78,11 +90,18 @@ func WithSessionStorage(storage session.Storage) Option { if storage == nil { return } - if p.sessionManager != nil { - _ = p.sessionManager.Stop() + p.sessionStorage = storage + } +} + +// WithSessionTTL overrides the session inactivity timeout used by this proxy. +// Zero or negative values are ignored so the constructor's default is preserved. +func WithSessionTTL(ttl time.Duration) Option { + return func(p *HTTPProxy) { + if ttl <= 0 { + return } - sFactory := func(id string) session.Session { return session.NewStreamableSession(id) } - p.sessionManager = session.NewManagerWithStorage(session.DefaultSessionTTL, sFactory, storage) + p.sessionTTL = ttl } } @@ -106,13 +125,20 @@ func NewHTTPProxy( middlewares: middlewares, messageCh: make(chan jsonrpc2.Message, 100), responseCh: make(chan jsonrpc2.Message, 100), - sessionManager: session.NewManager(session.DefaultSessionTTL, sFactory), + sessionTTL: session.DefaultSessionTTL, } for _, opt := range opts { opt(proxy) } + // Construct the session manager once, after options have resolved sessionTTL and sessionStorage. + if proxy.sessionStorage != nil { + proxy.sessionManager = session.NewManagerWithStorage(proxy.sessionTTL, sFactory, proxy.sessionStorage) + } else { + proxy.sessionManager = session.NewManager(proxy.sessionTTL, sFactory) + } + // Create health checker without MCP pinger // Streamable transport doesn't support MCP ping, so health check only verifies proxy is running proxy.healthChecker = healthcheck.NewHealthChecker(string(types.TransportTypeStreamableHTTP), nil) @@ -296,7 +322,7 @@ func (p *HTTPProxy) handlePost(w http.ResponseWriter, r *http.Request) { } // Notifications or client responses are accepted and forwarded (202) - if p.handleNotificationOrClientResponse(w, msg) { + if p.handleNotificationOrClientResponse(w, r.Header.Get("Mcp-Session-Id"), msg) { return } @@ -619,17 +645,21 @@ func (p *HTTPProxy) ensureSession(id string) error { return p.sessionManager.AddWithID(id) } -// resolveSessionForBatch resolves or creates an ephemeral session for batch POSTs. +// resolveSessionForBatch resolves the session for batch POSTs. // Writes appropriate HTTP errors and returns an error when handling should stop. +// +// Sessionless POSTs receive a per-request UUID used solely as an in-process +// routing token. Sessionless routing tokens MUST be unique per request: +// sharing one (e.g. the empty string) across concurrent sessionless requests +// causes waiters/idRestore overwrites in the in-process sync.Maps, which leaks +// one client's response payload to another (with the JSON-RPC id rewritten to +// the receiver's). This is a confidentiality bug, not a performance issue -- +// do not collapse the token. The UUID is not registered with sessionManager, +// so no session object is created in any storage backend. func (p *HTTPProxy) resolveSessionForBatch(w http.ResponseWriter, r *http.Request) (string, error) { sessID := r.Header.Get("Mcp-Session-Id") if sessID == "" { - sessID = uuid.New().String() - if err := p.ensureSession(sessID); err != nil { - writeHTTPError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create session: %v", err)) - return "", err - } - return sessID, nil + return uuid.New().String(), nil } if _, ok := p.sessionManager.Get(sessID); !ok { session.WriteNotFound(w, nil) @@ -639,8 +669,18 @@ func (p *HTTPProxy) resolveSessionForBatch(w http.ResponseWriter, r *http.Reques } // resolveSessionForRequest resolves session rules for a single JSON-RPC request. -// On initialize, assigns session if missing and returns setSessionHeader=true. -// For other methods, allows optional session by creating ephemeral (no header set). +// On initialize, assigns a new session ID if none is provided and returns setSessionHeader=true. +// A provided but unknown session ID returns 404. +// +// Sessionless non-initialize requests receive a per-request UUID used solely as +// an in-process routing token (not registered with sessionManager). Sessionless +// routing tokens MUST be unique per request: sharing one (e.g. the empty string) +// across concurrent sessionless requests with the same JSON-RPC id collapses +// them onto the same compositeKey(sessID, idKey) and overwrites entries in the +// waiters / idRestore sync.Maps, leaking one client's response payload to +// another. This is a confidentiality bug, not a performance issue -- do not +// collapse the token. +// // Writes HTTP errors on failure and returns error to stop handling. func (p *HTTPProxy) resolveSessionForRequest( w http.ResponseWriter, @@ -662,17 +702,14 @@ func (p *HTTPProxy) resolveSessionForRequest( return sessID, setSessionHeader, nil } - // Non-initialize path: sessions are optional; create ephemeral if missing + // Sessionless non-initialize: generate a per-request routing token. + // setSessionHeader stays false so the client never sees this UUID and the + // next request remains sessionless. if sessID == "" { - sessID = uuid.New().String() - if err := p.ensureSession(sessID); err != nil { - writeHTTPError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to create session: %v", err)) - return "", false, err - } - return sessID, false, nil + return uuid.New().String(), false, nil } - // If session is provided, ensure it exists + // Session ID provided but not found: reject with 404. if _, ok := p.sessionManager.Get(sessID); !ok { session.WriteNotFound(w, req.ID.Raw()) return "", false, fmt.Errorf("session not found") @@ -708,8 +745,12 @@ func decodeJSONRPCMessage(w http.ResponseWriter, body []byte) (jsonrpc2.Message, return msg, true } -func (p *HTTPProxy) handleNotificationOrClientResponse(w http.ResponseWriter, msg jsonrpc2.Message) bool { +func (p *HTTPProxy) handleNotificationOrClientResponse(w http.ResponseWriter, sessID string, msg jsonrpc2.Message) bool { if isNotification(msg) || (func() bool { _, ok := msg.(*jsonrpc2.Response); return ok })() { + // Refresh TTL so a client sending only notifications doesn't get evicted. + if sessID != "" { + p.sessionManager.Get(sessID) + } if err := p.SendMessageToDestination(msg); err != nil { slog.Error("failed to send message to destination", "error", err) } diff --git a/pkg/transport/proxy/streamable/streamable_proxy_spec_test.go b/pkg/transport/proxy/streamable/streamable_proxy_spec_test.go index c13aea6896..1574462661 100644 --- a/pkg/transport/proxy/streamable/streamable_proxy_spec_test.go +++ b/pkg/transport/proxy/streamable/streamable_proxy_spec_test.go @@ -7,7 +7,9 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" + "net" "net/http" "testing" "time" @@ -320,3 +322,115 @@ func TestSingleRequestWithStaleSessionIncludesRequestID(t *testing.T) { assert.Contains(t, string(body), `"code":-32001`) assert.Contains(t, string(body), `"id":"test-42"`) } + +// pickFreePort returns a TCP port the OS reports as available. There is a small +// race window before the proxy binds it, but that is the same pattern other +// streamable tests follow and is acceptable here. +func pickFreePort(t *testing.T) int { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := l.Addr().(*net.TCPAddr).Port + require.NoError(t, l.Close()) + return port +} + +// TestSessionlessConcurrentRequestsAreNotMixed verifies that two concurrent +// sessionless POSTs sharing a JSON-RPC id each receive their own response +// payload. Regression test: a previous change collapsed every sessionless +// request onto compositeKey("", idKey), causing the in-process waiters / +// idRestore sync.Maps to silently overwrite — symptoms were response +// cross-talk (one client receiving the other's payload, with the JSON-RPC id +// rewritten back to its own) and a request-timeout for the losing client. +// +// t.Setenv requires a non-parallel test; the trade-off is acceptable for a +// single regression test that needs a short proxy timeout to fail fast. +func TestSessionlessConcurrentRequestsAreNotMixed(t *testing.T) { + // Cap per-request timeout so this test fails in seconds, not the 60s + // default, when the bug regresses. + t.Setenv(proxyRequestTimeoutEnv, "3s") + + port := pickFreePort(t) + proxy := NewHTTPProxy("127.0.0.1", port, nil, nil) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + require.NoError(t, proxy.Start(ctx)) + t.Cleanup(func() { _ = proxy.Stop(ctx) }) + + time.Sleep(50 * time.Millisecond) + + // Backend uses a synchronization barrier: collect both incoming requests + // before sending any response, so both proxy waiters are registered first + // — that's the precondition under which the bug overwrites one waiter. + // Echo req.Method back in the result so we can prove each client got its + // own payload, not the other's. + go func() { + buffered := make([]*jsonrpc2.Request, 0, 2) + for len(buffered) < 2 { + select { + case msg := <-proxy.GetMessageChannel(): + if req, ok := msg.(*jsonrpc2.Request); ok && req.ID.IsValid() { + buffered = append(buffered, req) + } + case <-ctx.Done(): + return + } + } + for _, req := range buffered { + result := map[string]any{"echoed_method": req.Method} + resp, _ := jsonrpc2.NewResponse(req.ID, result, nil) + _ = proxy.ForwardResponseToClients(ctx, resp) + } + }() + + url := fmt.Sprintf("http://127.0.0.1:%d%s", port, StreamableHTTPEndpoint) + + type result struct { + method string + status int + body map[string]any + err error + } + + // Both POSTs share the same JSON-RPC id (1) and omit Mcp-Session-Id. + fire := func(method string) result { + body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":%q,"params":{}}`, method) + client := &http.Client{Timeout: 8 * time.Second} + resp, err := client.Post(url, "application/json", bytes.NewReader([]byte(body))) + if err != nil { + return result{method: method, err: err} + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + var decoded map[string]any + _ = json.NewDecoder(resp.Body).Decode(&decoded) + return result{method: method, status: resp.StatusCode, body: decoded} + } + + resCh := make(chan result, 2) + go func() { resCh <- fire("tools/list") }() + go func() { resCh <- fire("resources/list") }() + + received := map[string]result{} + for i := 0; i < 2; i++ { + select { + case r := <-resCh: + received[r.method] = r + case <-time.After(15 * time.Second): + t.Fatalf("timeout waiting for concurrent sessionless responses; received so far: %v", received) + } + } + + for _, method := range []string{"tools/list", "resources/list"} { + r := received[method] + require.NoError(t, r.err, "client %q HTTP error", method) + require.Equal(t, http.StatusOK, r.status, "client %q HTTP status", method) + require.NotNil(t, r.body, "client %q empty body", method) + res, ok := r.body["result"].(map[string]any) + require.True(t, ok, "client %q missing result: %v", method, r.body) + assert.Equal(t, method, res["echoed_method"], + "client %q received the other client's payload (response cross-talk)", method) + } +} diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index f5d9599b9a..e17e124421 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -78,6 +78,14 @@ type TransparentProxy struct { // Sessions for tracking state sessionManager *session.Manager + // sessionTTL is the resolved inactivity timeout for the session manager. + // Defaults to session.DefaultSessionTTL; overridable via WithSessionTTL. + sessionTTL time.Duration + + // sessionStorage is the optional custom storage backend for the session manager. + // When nil, in-memory LocalStorage is used. Set via WithSessionStorage. + sessionStorage session.Storage + // If mcp server has been initialized (atomic access) isServerInitialized atomic.Bool @@ -290,14 +298,18 @@ func WithSessionStorage(storage session.Storage) Option { if storage == nil { return } - if p.sessionManager != nil { - _ = p.sessionManager.Stop() + p.sessionStorage = storage + } +} + +// WithSessionTTL overrides the session inactivity timeout used by this proxy. +// Zero or negative values are ignored so the constructor's default is preserved. +func WithSessionTTL(ttl time.Duration) Option { + return func(p *TransparentProxy) { + if ttl <= 0 { + return } - p.sessionManager = session.NewManagerWithStorage( - session.DefaultSessionTTL, - func(id string) session.Session { return session.NewProxySession(id) }, - storage, - ) + p.sessionTTL = ttl } } @@ -426,7 +438,7 @@ func NewTransparentProxyWithOptions( prometheusHandler: prometheusHandler, authInfoHandler: authInfoHandler, prefixHandlers: prefixHandlers, - sessionManager: session.NewManager(session.DefaultSessionTTL, session.NewProxySession), + sessionTTL: session.DefaultSessionTTL, isRemote: isRemote, transportType: transportType, onHealthCheckFailed: onHealthCheckFailed, @@ -445,6 +457,14 @@ func NewTransparentProxyWithOptions( opt(proxy) } + // Construct the session manager once, after options have resolved sessionTTL and sessionStorage. + proxyFactory := func(id string) session.Session { return session.NewProxySession(id) } + if proxy.sessionStorage != nil { + proxy.sessionManager = session.NewManagerWithStorage(proxy.sessionTTL, proxyFactory, proxy.sessionStorage) + } else { + proxy.sessionManager = session.NewManager(proxy.sessionTTL, proxyFactory) + } + // Create appropriate response processor based on transport type proxy.responseProcessor = createResponseProcessor( transportType, diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go index 315dbb11ab..cb92480246 100644 --- a/pkg/transport/stdio.go +++ b/pkg/transport/stdio.go @@ -63,6 +63,7 @@ type StdioTransport struct { prometheusHandler http.Handler trustProxyHeaders bool sessionStorage session.Storage + sessionTTL time.Duration // Mutex for protecting shared state mutex sync.Mutex @@ -142,6 +143,12 @@ func (t *StdioTransport) SetSessionStorage(storage session.Storage) { t.sessionStorage = storage } +// SetSessionTTL configures the inactivity timeout for sessions managed by the +// underlying proxy. Zero is valid and means "use the proxy's default". +func (t *StdioTransport) SetSessionTTL(ttl time.Duration) { + t.sessionTTL = ttl +} + // Mode returns the transport mode. func (*StdioTransport) Mode() types.TransportType { return types.TransportTypeStdio @@ -192,6 +199,9 @@ func (t *StdioTransport) Start(ctx context.Context) error { switch t.proxyMode { case types.ProxyModeStreamableHTTP: var streamableOpts []streamable.Option + if t.sessionTTL > 0 { + streamableOpts = append(streamableOpts, streamable.WithSessionTTL(t.sessionTTL)) + } if t.sessionStorage != nil { streamableOpts = append(streamableOpts, streamable.WithSessionStorage(t.sessionStorage)) } @@ -202,6 +212,9 @@ func (t *StdioTransport) Start(ctx context.Context) error { slog.Debug("streamable HTTP proxy started, processing messages") case types.ProxyModeSSE: var sseOpts []httpsse.Option + if t.sessionTTL > 0 { + sseOpts = append(sseOpts, httpsse.WithSessionTTL(t.sessionTTL)) + } if t.sessionStorage != nil { sseOpts = append(sseOpts, httpsse.WithSessionStorage(t.sessionStorage)) } diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index beec37e407..1096827531 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -11,6 +11,7 @@ import ( "context" "encoding/json" "net/http" + "time" "golang.org/x/exp/jsonrpc2" "golang.org/x/oauth2" @@ -276,6 +277,11 @@ type Config struct { // Used for Redis-backed session sharing across replicas. // When nil, transports use their default in-memory LocalStorage. SessionStorage session.Storage + + // SessionTTL is the inactivity timeout for sessions managed by this proxy. + // Sessions idle for longer than this duration are cleaned up by the session + // manager's background worker. Zero uses session.DefaultSessionTTL. + SessionTTL time.Duration } // ProxyMode represents the proxy mode for stdio transport. diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index ffb501115c..2097e39c64 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -69,6 +69,10 @@ type ServeConfig struct { // the loaded config does not already define an audit section. EnableAudit bool + // SessionTTL is the inactivity timeout for vMCP sessions. + // Zero uses the server default (30m). Negative values fail validation. + SessionTTL time.Duration + // Optimizer tier selection (Phase 4 — flag-driven). // EnableOptimizer enables Tier 1 FTS5 keyword search (find_tool / call_tool). EnableOptimizer bool @@ -113,6 +117,9 @@ func Serve(ctx context.Context, cfg ServeConfig) error { if err := cfg.validateQuickModeHost(); err != nil { return err } + if cfg.SessionTTL < 0 { + return fmt.Errorf("session-ttl must be non-negative, got %s", cfg.SessionTTL) + } // Load and validate configuration — file path takes precedence over group quick mode. vmcpCfg, err := func() (*config.Config, error) { @@ -373,6 +380,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error { GroupRef: vmcpCfg.Group, Host: cfg.Host, Port: cfg.Port, + SessionTTL: cfg.SessionTTL, AuthMiddleware: authMiddleware, AuthzMiddleware: authzMiddleware, AuthInfoHandler: authInfoHandler,