From 804df32ca1480a4e35635151ebcf74151e06f951 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Mon, 20 Apr 2026 09:05:33 +0300 Subject: [PATCH] Allow embedders to inject custom session DataStorage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The vMCP server currently only supports the "memory" and "redis" providers encoded in SessionStorageConfig. Embedders that run on Postgres, DynamoDB, Spanner, etc. would have to fork server.go or fall back to in-memory sessions (no persistence across pod restarts). Add an optional Config.DataStorage field of type transportsession.DataStorage. When non-nil, the server uses the caller-supplied store directly and builds nothing from SessionStorage. Setting DataStorage together with a non-empty SessionStorage.Provider is rejected at New() so misconfiguration surfaces loudly instead of silently favouring one. Caller owns the lifecycle: the server never calls Close() on a caller-supplied store, matching the existing convention for every other caller-supplied dependency on Config (TelemetryProvider, StatusReporter, Watcher). The server-built path is unchanged — buildSessionDataStorage now returns a closer so that lifecycle is tracked explicitly rather than via an ownership bool. Closes #4928 Co-Authored-By: Claude Opus 4.7 (1M context) --- pkg/vmcp/server/datastorage_injection_test.go | 162 ++++++++++++++++++ pkg/vmcp/server/server.go | 125 ++++++++++++-- 2 files changed, 269 insertions(+), 18 deletions(-) create mode 100644 pkg/vmcp/server/datastorage_injection_test.go diff --git a/pkg/vmcp/server/datastorage_injection_test.go b/pkg/vmcp/server/datastorage_injection_test.go new file mode 100644 index 0000000000..f50b57c9d9 --- /dev/null +++ b/pkg/vmcp/server/datastorage_injection_test.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package server_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" + routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/server" +) + +// countingDataStorage wraps a real LocalSessionDataStorage and counts how +// many times Close has been invoked. Used to assert that Server.Stop does +// not close a caller-supplied DataStorage. +type countingDataStorage struct { + transportsession.DataStorage + closeCalls atomic.Int32 +} + +func (c *countingDataStorage) Close() error { + c.closeCalls.Add(1) + return c.DataStorage.Close() +} + +func newCountingDataStorage(t *testing.T) *countingDataStorage { + t.Helper() + inner, err := transportsession.NewLocalSessionDataStorage(5 * time.Minute) + require.NoError(t, err) + return &countingDataStorage{DataStorage: inner} +} + +func TestNew_CallerOwnedDataStorageNotClosedOnStop(t *testing.T) { + t.Parallel() + + spy := newCountingDataStorage(t) + // The spy is caller-owned; close the inner LocalSessionDataStorage + // directly at the end of the test so the counter is not ticked by + // cleanup — the post-Stop assertion below must reflect only the server's + // behaviour. Err ignored: closing an already-closed local store is a + // no-op in this implementation. 
+ t.Cleanup(func() { + _ = spy.DataStorage.Close() + }) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routerMocks.NewMockRouter(ctrl) + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().Times(1) + + srv, err := server.New( + t.Context(), + &server.Config{ + Host: "127.0.0.1", + Port: 0, + SessionFactory: newNoopMockFactory(t), + DataStorage: spy, + }, + mockRouter, + mockBackendClient, + mockDiscoveryMgr, + vmcp.NewImmutableRegistry([]vmcp.Backend{}), + nil, + ) + require.NoError(t, err) + + err = srv.Stop(t.Context()) + require.NoError(t, err) + + assert.Equal(t, int32(0), spy.closeCalls.Load(), + "server must not close a caller-supplied DataStorage") +} + +func TestNew_BothSessionStorageAndDataStorageErrors(t *testing.T) { + t.Parallel() + + spy := newCountingDataStorage(t) + // Err ignored: closing an already-closed local store is a no-op. + t.Cleanup(func() { + _ = spy.DataStorage.Close() + }) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routerMocks.NewMockRouter(ctrl) + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + + _, err := server.New( + t.Context(), + &server.Config{ + Host: "127.0.0.1", + Port: 0, + SessionFactory: newNoopMockFactory(t), + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: "127.0.0.1:6379", + }, + DataStorage: spy, + }, + mockRouter, + mockBackendClient, + mockDiscoveryMgr, + vmcp.NewImmutableRegistry([]vmcp.Backend{}), + nil, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "DataStorage") + assert.Contains(t, err.Error(), "SessionStorage") + assert.Equal(t, int32(0), spy.closeCalls.Load(), + "server must not close a caller-supplied DataStorage on misconfiguration") +} + +func TestNew_ServerBuiltDataStorageStopSucceeds(t *testing.T) { + // Guards against accidental 
regression of the server-owned close path + // when Close moved from an inline Stop() block onto sessionDataStorageCloser. + // Stop() must still complete without error when the server built the store. + // This is a smoke test — it cannot observe Close on the internal + // LocalSessionDataStorage because that type is constructed inside New(). + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routerMocks.NewMockRouter(ctrl) + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().Times(1) + + srv, err := server.New( + t.Context(), + &server.Config{ + Host: "127.0.0.1", + Port: 0, + SessionFactory: newNoopMockFactory(t), + SessionStorage: &vmcpconfig.SessionStorageConfig{Provider: "memory"}, + }, + mockRouter, + mockBackendClient, + mockDiscoveryMgr, + vmcp.NewImmutableRegistry([]vmcp.Backend{}), + nil, + ) + require.NoError(t, err) + + require.NoError(t, srv.Stop(t.Context())) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index d8eb991562..2a04a92c1b 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -180,7 +180,38 @@ type Config struct { // When provider is "redis", a Redis-backed store is created for cross-pod // session persistence; the Redis password is read from the // THV_SESSION_REDIS_PASSWORD environment variable. + // + // Mutually exclusive with DataStorage: setting both is rejected at New(). SessionStorage *vmcpconfig.SessionStorageConfig + + // DataStorage optionally injects a caller-supplied session metadata store, + // bypassing the built-in memory/redis providers. When non-nil, the server + // uses this store as-is and builds nothing from SessionStorage. Setting + // both DataStorage and a non-empty SessionStorage.Provider is rejected at + // New() as ambiguous configuration; a SessionStorage whose Provider is + // empty is tolerated and is not used to build a store. + // + // Lifecycle: the caller owns it. 
The server does NOT call Close() on a + // caller-supplied DataStorage, even on error paths in New() or during + // Stop(). The caller is responsible for invoking Close() exactly once + // after Server.Stop() returns (not before — the session manager may issue + // final Update calls during Stop). The caller is likewise responsible for + // configuring the store's TTL; cfg.SessionTTL applies only to the + // transport-level session manager, not to the caller-supplied DataStorage. + // + // Sensitive material: the store holds HMAC-hashed token material and + // other session metadata. Embedders should treat the backing datastore as + // sensitive (dedicated credentials, encryption at rest, restricted read + // access). Implementations must not include credentials in Close() error + // messages — those errors are surfaced through Server.Stop(). + // + // This seam lets embedders satisfy transportsession.DataStorage against + // datastores other than the built-in providers (e.g. Postgres, DynamoDB) + // without forking the server. It enables cross-replica session metadata + // sharing when backed by a shared store; it does NOT solve cross-replica + // message delivery — callers still need session affinity at the load + // balancer for streaming responses. + DataStorage transportsession.DataStorage } // Server is the Virtual MCP Server that aggregates multiple backends. @@ -223,10 +254,16 @@ type Server struct { sessionManager *transportsession.Manager // sessionDataStorage is the pluggable key-value backend for session metadata. - // Currently always LocalSessionDataStorage (in-memory, single-process). - // Redis-backed storage for multi-pod deployments is not yet wired. + // It may be LocalSessionDataStorage (in-memory, single-process), a Redis-backed + // store, or a caller-supplied implementation injected via Config.DataStorage. sessionDataStorage transportsession.DataStorage + // sessionDataStorageCloser closes sessionDataStorage on shutdown. 
It is + // set only when the server built the store itself (memory or redis + // providers). When Config.DataStorage was supplied by the caller, this is + // nil and the caller is responsible for closing the store. + sessionDataStorageCloser func(context.Context) error + // Capability adapter for converting aggregator types to SDK types capabilityAdapter *adapter.CapabilityAdapter @@ -256,21 +293,51 @@ type Server struct { } // buildSessionDataStorage constructs the DataStorage backend from cfg. -// When cfg.SessionStorage is nil or provider is "memory" (or empty), local in-process -// storage is used. When provider is "redis", a Redis-backed store is created -// using the address, DB, and key prefix from cfg.SessionStorage; the password -// is read from the THV_SESSION_REDIS_PASSWORD environment variable. -// Any other provider value is a misconfiguration and returns an error. -func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession.DataStorage, error) { +// +// Resolution order: +// +// 1. cfg.DataStorage (caller-supplied) takes precedence. When non-nil, the +// store is returned as-is with a nil closer — the caller owns the +// lifecycle. Setting both cfg.DataStorage and a non-empty +// cfg.SessionStorage.Provider is rejected as ambiguous. +// 2. cfg.SessionStorage.Provider "memory" (or empty, or nil SessionStorage): +// local in-process storage is created. +// 3. cfg.SessionStorage.Provider "redis": a Redis-backed store is created +// using the address, DB, and key prefix from cfg.SessionStorage. The +// password is read from the THV_SESSION_REDIS_PASSWORD environment +// variable. +// 4. Any other provider value is a misconfiguration and returns an error. +// +// For cases 2 and 3 (server-built stores), the returned closer wraps the +// store's Close method. For case 1 (caller-supplied), the closer is nil. 
+// New() routes the returned closer through Server.sessionDataStorageCloser +// so Close is invoked on shutdown (and on New() error after this point) — +// but only for server-built stores. +func buildSessionDataStorage( + ctx context.Context, + cfg *Config, +) (transportsession.DataStorage, func(context.Context) error, error) { + if cfg.DataStorage != nil { + if cfg.SessionStorage != nil && cfg.SessionStorage.Provider != "" { + return nil, nil, fmt.Errorf( + "cannot set both Config.DataStorage and Config.SessionStorage.Provider (%q); pick one", + cfg.SessionStorage.Provider) + } + return cfg.DataStorage, nil, nil + } // Default to in-process storage when session storage is not configured, // or when the provider is explicitly "memory" or left empty. if cfg.SessionStorage == nil || cfg.SessionStorage.Provider == "" || strings.EqualFold(cfg.SessionStorage.Provider, "memory") { - return transportsession.NewLocalSessionDataStorage(cfg.SessionTTL) + store, err := transportsession.NewLocalSessionDataStorage(cfg.SessionTTL) + if err != nil { + return nil, nil, err + } + return store, closerFor(store), nil } if cfg.SessionStorage.Provider != "redis" { - return nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")", + return nil, nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")", cfg.SessionStorage.Provider) } keyPrefix := cfg.SessionStorage.KeyPrefix @@ -288,7 +355,19 @@ func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession "db", cfg.SessionStorage.DB, "key_prefix", keyPrefix, ) - return transportsession.NewRedisSessionDataStorage(ctx, redisCfg, cfg.SessionTTL) + store, err := transportsession.NewRedisSessionDataStorage(ctx, redisCfg, cfg.SessionTTL) + if err != nil { + return nil, nil, err + } + return store, closerFor(store), nil +} + +// closerFor adapts DataStorage.Close (no context) to the +// func(context.Context) error signature used by 
Server.sessionDataStorageCloser. +func closerFor(store transportsession.DataStorage) func(context.Context) error { + return func(context.Context) error { + return store.Close() + } } // New creates a new Virtual MCP Server instance. @@ -412,16 +491,18 @@ func New( // keyed by the same session ID. sessionManager := transportsession.NewManager(cfg.SessionTTL, transportsession.NewStreamableSession) - sessionDataStorage, err := buildSessionDataStorage(ctx, cfg) + sessionDataStorage, sessionDataStorageCloser, err := buildSessionDataStorage(ctx, cfg) if err != nil { return nil, fmt.Errorf("failed to create session data storage: %w", err) } - // Close sessionDataStorage if New() returns an error after this point so the - // background cleanup goroutine does not leak. - closeStorageOnErr := true + // If we built the store ourselves, close it when New() returns an error + // after this point so the background cleanup goroutine does not leak. + // For a caller-supplied store (sessionDataStorageCloser == nil), the + // caller owns the lifecycle and we leave it untouched on every path. + closeStorageOnErr := sessionDataStorageCloser != nil defer func() { if closeStorageOnErr { - _ = sessionDataStorage.Close() + _ = sessionDataStorageCloser(ctx) } }() @@ -486,6 +567,12 @@ func New( srv.shutdownFuncs = append(srv.shutdownFuncs, optimizerCleanup) } + // Store the session data storage closer on the Server so Stop() can invoke + // it last (after session manager and discovery manager have stopped). For + // a caller-supplied store this is nil and Stop() leaves it alone — the + // caller owns the lifecycle. + srv.sessionDataStorageCloser = sessionDataStorageCloser + // Register OnRegisterSession hook to inject capabilities after SDK registers session. // See handleSessionRegistration for implementation details. 
hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) { @@ -848,8 +935,10 @@ func (s *Server) Stop(ctx context.Context) error { // Close session data storage last: HTTP server is down (no new in-flight requests), // all other components have stopped (no further restore or liveness checks). - if s.sessionDataStorage != nil { - if err := s.sessionDataStorage.Close(); err != nil { + // Only invoked when the server built the store itself; caller-supplied stores + // (Config.DataStorage) are left for the caller to close. + if s.sessionDataStorageCloser != nil { + if err := s.sessionDataStorageCloser(ctx); err != nil { errs = append(errs, fmt.Errorf("failed to close session data storage: %w", err)) } }