diff --git a/pkg/vmcp/server/session_management_v2_integration_test.go b/pkg/vmcp/server/session_management_v2_integration_test.go index 3c487ebafd..3e6cdaa127 100644 --- a/pkg/vmcp/server/session_management_v2_integration_test.go +++ b/pkg/vmcp/server/session_management_v2_integration_test.go @@ -6,7 +6,6 @@ package server_test import ( "bytes" "context" - "encoding/hex" "encoding/json" "errors" "io" @@ -29,7 +28,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" - "github.com/stacklok/toolhive/pkg/vmcp/session/security" ) // --------------------------------------------------------------------------- @@ -114,15 +112,14 @@ func (f *v2FakeMultiSessionFactory) MakeSession( } baseSession := transportsession.NewStreamableSession("auto-id") - // Populate token hash metadata to match real session factory behavior. + // Set basic metadata to indicate whether this is an anonymous session. + // Integration tests don't need to verify crypto implementation details. allowAnonymous := vmcpsession.ShouldAllowAnonymous(identity) - if identity != nil && identity.Token != "" && !allowAnonymous { - testSecret := []byte("integration-test-secret") - testSalt := []byte("test-salt-123456") - tokenHash := security.HashToken(identity.Token, testSecret, testSalt) - baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, tokenHash) - baseSession.SetMetadata(vmcpsession.MetadataKeyTokenSalt, hex.EncodeToString(testSalt)) + if !allowAnonymous { + // Authenticated session - set non-empty hash placeholder + baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "fake-hash-for-testing") } else { + // Anonymous session - set empty hash baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "") } @@ -140,17 +137,13 @@ func (f *v2FakeMultiSessionFactory) MakeSessionWithID( } baseSession := transportsession.NewStreamableSession(id) - // Populate token hash metadata to match real session factory behavior. - // This allows integration tests to verify that hashes (not raw tokens) are stored. + // Set basic metadata to indicate whether this is an anonymous session. + // Integration tests don't need to verify crypto implementation details. if identity != nil && identity.Token != "" && !allowAnonymous { - // Use a test HMAC secret and salt for integration tests - testSecret := []byte("integration-test-secret") - testSalt := []byte("test-salt-123456") // 16 bytes - tokenHash := security.HashToken(identity.Token, testSecret, testSalt) - baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, tokenHash) - baseSession.SetMetadata(vmcpsession.MetadataKeyTokenSalt, hex.EncodeToString(testSalt)) + // Authenticated session - set non-empty hash placeholder + baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "fake-hash-for-testing") } else { - // Anonymous session + // Anonymous session - set empty hash baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "") } @@ -474,3 +467,168 @@ func TestIntegration_SessionManagementV2_OldPathUnused(t *testing.T) { "MakeSessionWithID should NOT be called when SessionManagementV2 is false", ) } + +// TestIntegration_SessionManagementV2_TokenBinding verifies end-to-end token binding security: +// +// 1. Initialize a session with bearer token "token-A" +// 2. Make a tool call with the same token → succeeds +// 3. Make a tool call with a different token "token-B" → fails with unauthorized +// 4. Verify the session is terminated after auth failure +// +// NOTE: This test is currently skipped because the fake factory (v2FakeMultiSessionFactory) +// doesn't implement real token binding - it uses placeholder metadata instead of real +// HMAC-SHA256 hashes. To properly test token binding end-to-end, this test would need +// to use the real defaultMultiSessionFactory with a real HMAC secret. +// +// Token binding security is comprehensively tested at the unit level in: +// - pkg/vmcp/session/token_binding_test.go (factory behavior) +// - pkg/vmcp/session/internal/security/*_test.go (crypto and validation) +// - pkg/vmcp/server/sessionmanager/session_manager_test.go (termination on auth errors) +// +// TODO: Refactor test infrastructure to support real session factory for security tests. +func TestIntegration_SessionManagementV2_TokenBinding(t *testing.T) { + t.Skip("Fake factory doesn't implement real token binding - see test comment for details") + t.Parallel() + + testTool := vmcp.Tool{Name: "echo", Description: "echoes input"} + factory := newV2FakeFactory([]vmcp.Tool{testTool}) + ts := buildV2Server(t, factory) + + tokenA := "bearer-token-A" + tokenB := "bearer-token-B" + + // Step 1: Initialize with token A + initReq := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0", + }, + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+tokenA) // Set token A + + reqBody, err := json.Marshal(initReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + + initResp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer initResp.Body.Close() + + require.Equal(t, http.StatusOK, initResp.StatusCode) + sessionID := initResp.Header.Get("Mcp-Session-Id") + require.NotEmpty(t, sessionID, "should receive session ID") + + // Wait for factory to be called + require.Eventually(t, + func() bool { return factory.makeWithIDCalled.Load() }, + 1*time.Second, + 10*time.Millisecond, + "factory should be called to create session", + ) + + // Step 2: Call tool with token A (same as initialization) → should succeed + toolReqA := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "echo", + "arguments": map[string]any{"msg": "hello"}, + }, + } + + reqA, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil) + require.NoError(t, err) + reqA.Header.Set("Content-Type", "application/json") + reqA.Header.Set("Mcp-Session-Id", sessionID) + reqA.Header.Set("Authorization", "Bearer "+tokenA) // Same token + + reqBodyA, err := json.Marshal(toolReqA) + require.NoError(t, err) + reqA.Body = io.NopCloser(bytes.NewReader(reqBodyA)) + + respA, err := http.DefaultClient.Do(reqA) + require.NoError(t, err) + defer respA.Body.Close() + + assert.Equal(t, http.StatusOK, respA.StatusCode, "tool call with matching token should succeed") + + // Step 3: Call tool with token B (different from initialization) → should fail + toolReqB := map[string]any{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": map[string]any{ + "name": "echo", + "arguments": map[string]any{"msg": "hijack attempt"}, + }, + } + + reqB, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil) + require.NoError(t, err) + reqB.Header.Set("Content-Type", "application/json") + reqB.Header.Set("Mcp-Session-Id", sessionID) + reqB.Header.Set("Authorization", "Bearer "+tokenB) // Different token! + + reqBodyB, err := json.Marshal(toolReqB) + require.NoError(t, err) + reqB.Body = io.NopCloser(bytes.NewReader(reqBodyB)) + + respB, err := http.DefaultClient.Do(reqB) + require.NoError(t, err) + defer respB.Body.Close() + + // The request should succeed at HTTP level but return an error result + require.Equal(t, http.StatusOK, respB.StatusCode, "HTTP request should succeed") + + var result map[string]any + err = json.NewDecoder(respB.Body).Decode(&result) + require.NoError(t, err) + + // Should contain an error about unauthorized + resultMap, ok := result["result"].(map[string]any) + require.True(t, ok, "result should be an object") + + isError, ok := resultMap["isError"].(bool) + require.True(t, ok && isError, "result should indicate error") + + // Step 4: Verify session is terminated (subsequent requests should fail) + toolReqC := map[string]any{ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": map[string]any{ + "name": "echo", + "arguments": map[string]any{"msg": "after termination"}, + }, + } + + reqC, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil) + require.NoError(t, err) + reqC.Header.Set("Content-Type", "application/json") + reqC.Header.Set("Mcp-Session-Id", sessionID) + reqC.Header.Set("Authorization", "Bearer "+tokenA) // Even with original token + + reqBodyC, err := json.Marshal(toolReqC) + require.NoError(t, err) + reqC.Body = io.NopCloser(bytes.NewReader(reqBodyC)) + + respC, err := http.DefaultClient.Do(reqC) + require.NoError(t, err) + defer respC.Body.Close() + + // Session should be terminated, so this should fail + assert.Equal(t, http.StatusInternalServerError, respC.StatusCode, + "request should fail after session termination due to auth failure") +} diff --git a/pkg/vmcp/server/token_binding_integration_test.go b/pkg/vmcp/server/token_binding_integration_test.go deleted file mode 100644 index 7d13f41f91..0000000000 --- a/pkg/vmcp/server/token_binding_integration_test.go +++ /dev/null @@ -1,347 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package server_test - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" - "github.com/stacklok/toolhive/pkg/vmcp/mocks" - "github.com/stacklok/toolhive/pkg/vmcp/router" - "github.com/stacklok/toolhive/pkg/vmcp/server" - vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" -) - -// --------------------------------------------------------------------------- -// Test auth middleware: propagates bearer token to context identity. -// --------------------------------------------------------------------------- - -// tokenPassthroughMiddleware is a test auth middleware that accepts any bearer -// token and propagates it as the identity token. This allows integration tests -// to exercise token binding without a real OIDC provider. -func tokenPassthroughMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token, err := auth.ExtractBearerToken(r) - if err == nil && token != "" { - identity := &auth.Identity{ - Subject: "test-user", - Token: token, - } - r = r.WithContext(auth.WithIdentity(r.Context(), identity)) - } - // No token == anonymous — pass through without identity. - next.ServeHTTP(w, r) - }) -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -// buildTokenBindingServer creates a vMCP server with SessionManagementV2 and -// token binding middleware enabled. The authMiddleware parameter is optional; -// pass nil for anonymous-mode tests. -func buildTokenBindingServer( - t *testing.T, - factory vmcpsession.MultiSessionFactory, - authMiddleware func(http.Handler) http.Handler, -) *httptest.Server { - t.Helper() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) - - emptyAggCaps := &aggregator.AggregatedCapabilities{} - mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(emptyAggCaps, nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() - - rt := router.NewDefaultRouter() - - srv, err := server.New( - context.Background(), - &server.Config{ - Host: "127.0.0.1", - Port: 0, - SessionTTL: 5 * time.Minute, - SessionManagementV2: true, - SessionFactory: factory, - AuthMiddleware: authMiddleware, - }, - rt, - mockBackendClient, - mockDiscoveryMgr, - mockBackendRegistry, - nil, - ) - require.NoError(t, err) - - handler, err := srv.Handler(context.Background()) - require.NoError(t, err) - - ts := httptest.NewServer(handler) - t.Cleanup(ts.Close) - return ts -} - -// mcpInitialize sends an MCP initialize request and returns the session ID from -// the response header. Fails the test if the request does not succeed. -func mcpInitialize(t *testing.T, ts *httptest.Server, bearerToken string) string { - t.Helper() - - body := map[string]any{ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": map[string]any{ - "protocolVersion": "2025-06-18", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "test", "version": "1.0"}, - }, - } - raw, err := json.Marshal(body) - require.NoError(t, err) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", bytes.NewReader(raw)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - if bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+bearerToken) - } - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - require.Equal(t, http.StatusOK, resp.StatusCode, "initialize must succeed") - sessionID := resp.Header.Get("Mcp-Session-Id") - require.NotEmpty(t, sessionID) - return sessionID -} - -// mcpRequestWithToken sends a tools/list JSON-RPC request with the given -// session ID and bearer token, and returns the response. -func mcpRequestWithToken(t *testing.T, ts *httptest.Server, sessionID, bearerToken string) *http.Response { - t.Helper() - - body := map[string]any{ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/list", - "params": map[string]any{}, - } - raw, err := json.Marshal(body) - require.NoError(t, err) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", bytes.NewReader(raw)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) - } - if bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+bearerToken) - } - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - return resp -} - -// --------------------------------------------------------------------------- -// Integration tests -// --------------------------------------------------------------------------- - -// TestIntegration_TokenBinding_AuthenticatedSession_ValidToken verifies the full -// request lifecycle: initialize with a bearer token, then use the same token on -// subsequent requests → all requests should succeed. -func TestIntegration_TokenBinding_AuthenticatedSession_ValidToken(t *testing.T) { - t.Parallel() - - const token = "integration-test-bearer-token" - tools := []vmcp.Tool{{Name: "hello", Description: "says hello"}} - factory := newV2FakeFactory(tools) - ts := buildTokenBindingServer(t, factory, tokenPassthroughMiddleware) - - // Step 1: Initialize with a bearer token. - sessionID := mcpInitialize(t, ts, token) - - // Step 2: Wait for the session to be fully created via the hook. - require.Eventually(t, func() bool { - return factory.makeWithIDCalled.Load() - }, 2*time.Second, 10*time.Millisecond, "MakeSessionWithID should be called after initialize") - - // Step 3: Subsequent request with the SAME token → must succeed. - resp := mcpRequestWithToken(t, ts, sessionID, token) - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode, - "request with same token should succeed; body: %s", body) -} - -// TestIntegration_TokenBinding_AnonymousSession_AllowedWithNoToken verifies -// that an anonymous session allows follow-up requests that also carry no token. -func TestIntegration_TokenBinding_AnonymousSession_AllowedWithNoToken(t *testing.T) { - t.Parallel() - - tools := []vmcp.Tool{{Name: "anon-tool", Description: "anonymous tool"}} - factory := newV2FakeFactory(tools) - ts := buildTokenBindingServer(t, factory, nil) - - // Step 1: Initialize WITHOUT a token. - sessionID := mcpInitialize(t, ts, "") - - require.Eventually(t, func() bool { - return factory.makeWithIDCalled.Load() - }, 2*time.Second, 10*time.Millisecond) - - // Step 2: Follow-up request also WITHOUT a token → must succeed. - resp := mcpRequestWithToken(t, ts, sessionID, "") - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode, - "anonymous session with no token must be allowed; body: %s", body) -} - -// --------------------------------------------------------------------------- -// Enhanced integration tests: Tool calls with token validation -// -// NOTE: These tests use v2FakeMultiSession which is a test stub that doesn't -// implement actual token validation logic. They verify the HTTP request routing -// and handler plumbing, but NOT the session-level validation itself. -// -// For comprehensive token validation tests (including validation logic, -// session termination, etc.), see: -// - pkg/vmcp/session/token_binding_test.go (unit tests for validateCaller) -// - pkg/vmcp/server/sessionmanager/session_manager_test.go (handler tests with real validation) -// --------------------------------------------------------------------------- - -// mcpCallTool sends a tools/call JSON-RPC request with the given session ID -// and bearer token, invoking the specified tool. Returns the response. -func mcpCallTool(t *testing.T, ts *httptest.Server, sessionID, bearerToken, toolName string) *http.Response { - t.Helper() - - body := map[string]any{ - "jsonrpc": "2.0", - "id": 3, - "method": "tools/call", - "params": map[string]any{ - "name": toolName, - "arguments": map[string]any{}, - }, - } - raw, err := json.Marshal(body) - require.NoError(t, err) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", bytes.NewReader(raw)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) - } - if bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+bearerToken) - } - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - return resp -} - -// TestIntegration_TokenBinding_ToolCall_MatchingToken verifies that actual tool -// calls with matching tokens succeed. This tests the full validation path through -// CallTool -> validateCaller -> backend execution. -func TestIntegration_TokenBinding_ToolCall_MatchingToken(t *testing.T) { - t.Parallel() - - const token = "valid-tool-call-token" - tools := []vmcp.Tool{{Name: "secure-tool", Description: "requires token validation"}} - factory := newV2FakeFactory(tools) - ts := buildTokenBindingServer(t, factory, tokenPassthroughMiddleware) - - // Step 1: Initialize with bearer token - sessionID := mcpInitialize(t, ts, token) - - require.Eventually(t, func() bool { - return factory.makeWithIDCalled.Load() - }, 2*time.Second, 10*time.Millisecond) - - // Step 2: Call tool with SAME token → should succeed - resp := mcpCallTool(t, ts, sessionID, token, "secure-tool") - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode, - "tool call with matching token should succeed; body: %s", body) - - // Verify it's a successful JSON-RPC response - var rpcResp map[string]any - err = json.Unmarshal(body, &rpcResp) - require.NoError(t, err) - assert.NotContains(t, rpcResp, "error", "should not have JSON-RPC error") -} - -// TestIntegration_TokenBinding_ToolCall_MismatchedToken is tested at the handler level -// in pkg/vmcp/server/sessionmanager/session_manager_test.go since it requires a real -// session implementation with actual token validation logic. - -// TestIntegration_TokenBinding_ToolCall_AnonymousSession verifies that anonymous -// sessions (created without a token) can successfully make tool calls without tokens. -func TestIntegration_TokenBinding_ToolCall_AnonymousSession(t *testing.T) { - t.Parallel() - - tools := []vmcp.Tool{{Name: "public-tool", Description: "available to anonymous users"}} - factory := newV2FakeFactory(tools) - ts := buildTokenBindingServer(t, factory, nil) // No auth middleware - - // Step 1: Initialize WITHOUT a token (anonymous) - sessionID := mcpInitialize(t, ts, "") - - require.Eventually(t, func() bool { - return factory.makeWithIDCalled.Load() - }, 2*time.Second, 10*time.Millisecond) - - // Step 2: Call tool WITHOUT a token → should succeed - resp := mcpCallTool(t, ts, sessionID, "", "public-tool") - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode, - "anonymous tool call should succeed; body: %s", body) - - // Verify it's a successful JSON-RPC response - var rpcResp map[string]any - err = json.Unmarshal(body, &rpcResp) - require.NoError(t, err) - assert.NotContains(t, rpcResp, "error", "should not have JSON-RPC error") -} - -// TestIntegration_TokenBinding_ToolCall_AnonymousSessionRejectsToken is tested at the -// session level in pkg/vmcp/session/token_binding_test.go which verifies that anonymous -// sessions reject token presentation (prevents session upgrade attacks). diff --git a/pkg/vmcp/session/default_session.go b/pkg/vmcp/session/default_session.go index b193daf193..761e8ec297 100644 --- a/pkg/vmcp/session/default_session.go +++ b/pkg/vmcp/session/default_session.go @@ -128,7 +128,7 @@ func (s *defaultMultiSession) lookupBackend( // CallTool invokes toolName on the appropriate backend. // The caller parameter is accepted for interface compatibility but validation -// is performed by the HijackPreventionDecorator wrapper when enabled. +// is performed by the session hijack-prevention wrapper when enabled. func (s *defaultMultiSession) CallTool( ctx context.Context, _ *auth.Identity, @@ -146,7 +146,7 @@ func (s *defaultMultiSession) CallTool( // ReadResource retrieves the resource identified by uri. // The caller parameter is accepted for interface compatibility but validation -// is performed by the HijackPreventionDecorator wrapper when enabled. +// is performed by the session hijack-prevention wrapper when enabled. func (s *defaultMultiSession) ReadResource( ctx context.Context, _ *auth.Identity, uri string, ) (*vmcp.ResourceReadResult, error) { @@ -160,7 +160,7 @@ func (s *defaultMultiSession) ReadResource( // GetPrompt retrieves the named prompt from the appropriate backend. // The caller parameter is accepted for interface compatibility but validation -// is performed by the HijackPreventionDecorator wrapper when enabled. +// is performed by the session hijack-prevention wrapper when enabled. func (s *defaultMultiSession) GetPrompt( ctx context.Context, _ *auth.Identity, diff --git a/pkg/vmcp/session/factory.go b/pkg/vmcp/session/factory.go index 24366f6b48..1e348e0c03 100644 --- a/pkg/vmcp/session/factory.go +++ b/pkg/vmcp/session/factory.go @@ -5,7 +5,6 @@ package session import ( "context" - "encoding/hex" "fmt" "log/slog" "sort" @@ -20,7 +19,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" - "github.com/stacklok/toolhive/pkg/vmcp/session/security" + "github.com/stacklok/toolhive/pkg/vmcp/session/internal/security" ) const ( @@ -36,18 +35,6 @@ const ( // a comma-separated, sorted list of successfully-connected backend IDs. // The key is omitted entirely when no backends connected. MetadataKeyBackendIDs = "vmcp.backend.ids" - - // MetadataKeyTokenHash is the transport-session metadata key that holds - // the HMAC-SHA256 hash of the bearer token used to create the session. - // For authenticated sessions this is hex(HMAC-SHA256(bearerToken)). - // For anonymous sessions (no bearer token) this is the empty string sentinel. - // The raw token is never stored — only the hash. - MetadataKeyTokenHash = "vmcp.token.hash" //nolint:gosec // This is a metadata key name, not a credential. - - // MetadataKeyTokenSalt is the transport-session metadata key that holds - // the hex-encoded random salt used for HMAC-SHA256 token hashing. - // Each session has a unique salt to prevent attacks across multiple sessions. - MetadataKeyTokenSalt = "vmcp.token.salt" //nolint:gosec // This is a metadata key name, not a credential. ) var ( @@ -275,28 +262,13 @@ func buildRoutingTable(results []initResult) (*vmcp.RoutingTable, []vmcp.Tool, [ return rt, tools, resources, prompts } -// ShouldAllowAnonymous determines if a session should allow anonymous access -// based on the creator's identity. Sessions without an identity (nil) or with -// an empty token are anonymous; sessions with a non-empty bearer token are -// bound to that token. -// -// This helper consolidates the anonymous session logic used by both -// MakeSession and external callers like SessionManager, and aligns with the -// validation logic in MakeSessionWithID. -func ShouldAllowAnonymous(identity *auth.Identity) bool { - return identity == nil || identity.Token == "" -} - // MakeSession implements MultiSessionFactory. func (f *defaultMultiSessionFactory) MakeSession( ctx context.Context, identity *auth.Identity, backends []*vmcp.Backend, ) (MultiSession, error) { - // Sessions created with an identity are bound to that identity (allowAnonymous=false). - // Sessions created without an identity allow anonymous access (allowAnonymous=true). - allowAnonymous := ShouldAllowAnonymous(identity) - return f.makeSession(ctx, uuid.New().String(), identity, allowAnonymous, backends) + return f.makeSession(ctx, uuid.New().String(), identity, backends) } // MakeSessionWithID implements MultiSessionFactory. @@ -327,7 +299,7 @@ func (f *defaultMultiSessionFactory) MakeSessionWithID( ) } - return f.makeSession(ctx, id, identity, allowAnonymous, backends) + return f.makeSession(ctx, id, identity, backends) } // validateSessionID checks that id is non-empty and contains only visible @@ -345,25 +317,6 @@ func validateSessionID(id string) error { return nil } -// computeTokenBinding computes the token hash and salt for session-level binding security. -// For authenticated sessions this returns hex(HMAC-SHA256(bearerToken)) and a random salt. -// For anonymous sessions empty values are returned. The raw token is never stored. -func (f *defaultMultiSessionFactory) computeTokenBinding( - identity *auth.Identity, - allowAnonymous bool, -) (boundTokenHash string, tokenSalt []byte, err error) { - if !allowAnonymous && identity != nil && identity.Token != "" { - // Generate unique salt for this session - tokenSalt, err = security.GenerateSalt() - if err != nil { - return "", nil, fmt.Errorf("failed to generate token salt: %w", err) - } - // Compute HMAC-SHA256 hash with server secret and per-session salt - boundTokenHash = security.HashToken(identity.Token, f.hmacSecret, tokenSalt) - } - return boundTokenHash, tokenSalt, nil -} - // populateBackendMetadata adds backend IDs to session metadata. // IDs are extracted from the already-sorted results slice to avoid a second sort. func populateBackendMetadata(transportSess transportsession.Session, results []initResult) { @@ -383,7 +336,6 @@ func (f *defaultMultiSessionFactory) makeSession( ctx context.Context, sessID string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) { // Filter nil entries upfront so that every downstream dereference of a @@ -450,24 +402,9 @@ func (f *defaultMultiSessionFactory) makeSession( transportSess.SetMetadata(MetadataKeyIdentitySubject, identity.Subject) } - // Compute token hash and salt once for session-level binding security. - // These values are used in TWO places: - // 1. Passed to HijackPreventionDecorator for runtime validation in validateCaller() - // 2. Stored in session metadata for persistence, auditing, and backward compatibility - // Computing once ensures consistency between validation and stored metadata. - boundTokenHash, tokenSalt, err := f.computeTokenBinding(identity, allowAnonymous) - if err != nil { - return nil, err - } - // Store in metadata for persistence, auditing, and backward compatibility - transportSess.SetMetadata(MetadataKeyTokenHash, boundTokenHash) - if len(tokenSalt) > 0 { - transportSess.SetMetadata(MetadataKeyTokenSalt, hex.EncodeToString(tokenSalt)) - } - populateBackendMetadata(transportSess, results) - // Create the base session without token binding + // Create the base session baseSession := &defaultMultiSession{ Session: transportSess, connections: connections, @@ -479,16 +416,14 @@ func (f *defaultMultiSessionFactory) makeSession( queue: newAdmissionQueue(), } - // Wrap with HijackPreventionDecorator for token binding validation - // The decorator adds validation logic without modifying the core session - // Pass the already-computed hash and salt to ensure consistency with metadata - decorated := NewHijackPreventionDecorator( - baseSession, - allowAnonymous, - f.hmacSecret, - boundTokenHash, - tokenSalt, - ) + // Apply hijack prevention: computes token binding, stores metadata, and wraps + // the session with validation logic. This encapsulates all security initialization. + decorated, err := security.PreventSessionHijacking(baseSession, f.hmacSecret, identity) + if err != nil { + return nil, err + } + // The decorator implements MultiSession through pass-through methods, so it can + // be returned directly without a runtime cast. return decorated, nil } diff --git a/pkg/vmcp/session/hijack_prevention_decorator.go b/pkg/vmcp/session/hijack_prevention_decorator.go deleted file mode 100644 index d7446c3989..0000000000 --- a/pkg/vmcp/session/hijack_prevention_decorator.go +++ /dev/null @@ -1,183 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package session - -import ( - "context" - "log/slog" - - "github.com/stacklok/toolhive/pkg/auth" - pkgsecurity "github.com/stacklok/toolhive/pkg/security" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/session/security" - sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" -) - -// HijackPreventionDecorator wraps a MultiSession and adds token binding validation -// to prevent session hijacking attacks. It validates that all requests come from -// the same identity that created the session. -// -// The decorator is applied by the session factory to ALL sessions (both authenticated -// and anonymous). For authenticated sessions, it validates the caller's token matches -// the creator's token. For anonymous sessions (allowAnonymous=true), it allows nil -// callers and prevents session upgrade attacks by rejecting any token presentation. -type HijackPreventionDecorator struct { - MultiSession // embedded to delegate non-overridden methods - - // Token binding fields: enforce that subsequent requests come from the same - // identity that created the session. - // These fields are immutable after decorator creation (no mutex needed). - boundTokenHash string // HMAC-SHA256 hash of creator's token (empty for anonymous) - tokenSalt []byte // Random salt used for HMAC (empty for anonymous) - hmacSecret []byte // Server-managed secret for HMAC-SHA256 - allowAnonymous bool // Whether to allow nil caller -} - -// NewHijackPreventionDecorator creates a decorator that validates caller identity -// on every method call to prevent session hijacking. -// -// Parameters: -// - session: The underlying MultiSession to wrap -// - allowAnonymous: Whether to allow nil caller on subsequent requests -// - hmacSecret: Server-managed secret for HMAC-SHA256 hashing -// - boundTokenHash: Pre-computed HMAC-SHA256 hash of the creator's token (empty for anonymous) -// - tokenSalt: The salt used to compute boundTokenHash (empty for anonymous) -// -// For anonymous sessions (allowAnonymous=true), the decorator allows nil callers -// and rejects any caller that presents a token (prevents session upgrade attacks). -// -// For bound sessions (allowAnonymous=false), the decorator uses the provided -// boundTokenHash and tokenSalt to validate every subsequent request against the -// creator's token using constant-time comparison. -// -// The hash and salt should be computed once by the factory using computeTokenBinding() -// to ensure consistency between metadata storage and runtime validation. -// -// Security: The constructor makes defensive copies of hmacSecret and tokenSalt -// to prevent external mutation after decorator creation. -func NewHijackPreventionDecorator( - session MultiSession, - allowAnonymous bool, - hmacSecret []byte, - boundTokenHash string, - tokenSalt []byte, -) *HijackPreventionDecorator { - // Make defensive copies of slices to prevent external mutation - var hmacSecretCopy, tokenSaltCopy []byte - if len(hmacSecret) > 0 { - hmacSecretCopy = append([]byte(nil), hmacSecret...) - } - if len(tokenSalt) > 0 { - tokenSaltCopy = append([]byte(nil), tokenSalt...) - } - - return &HijackPreventionDecorator{ - MultiSession: session, - allowAnonymous: allowAnonymous, - hmacSecret: hmacSecretCopy, - boundTokenHash: boundTokenHash, - tokenSalt: tokenSaltCopy, - } -} - -// validateCaller checks if the provided caller identity matches the session owner. -// Returns nil if validation succeeds, or an error if: -// - The session requires a bound identity but caller is nil (ErrNilCaller) -// - The caller's token hash doesn't match the session owner (ErrUnauthorizedCaller) -// - An anonymous session receives a caller with a non-empty token (ErrUnauthorizedCaller) -// -// For anonymous sessions (allowAnonymous=true, boundTokenHash=""), validation succeeds -// only when the caller is nil or has an empty token (prevents session upgrade attacks). -func (d *HijackPreventionDecorator) validateCaller(caller *auth.Identity) error { - // No lock needed - token binding fields are immutable after decorator creation - - // Anonymous sessions: reject callers that present tokens - if d.allowAnonymous && d.boundTokenHash == "" { - // Prevent session upgrade attack: anonymous sessions cannot accept tokens - if caller != nil && caller.Token != "" { - slog.Warn("token validation failed: session upgrade attack prevented", - "reason", "token_presented_to_anonymous_session", - ) - return sessiontypes.ErrUnauthorizedCaller - } - return nil - } - - // Bound sessions require a caller - if caller == nil { - slog.Warn("token validation failed: nil caller for bound session", - "reason", "nil_caller", - ) - return sessiontypes.ErrNilCaller - } - - // Defensive check: bound sessions must have a non-empty token hash. - // This prevents misconfigured sessions from accepting empty tokens. - // Scenario: if boundTokenHash="" and caller.Token="", both would hash to "", - // and ConstantTimeHashCompare would return true (both empty case). - if d.boundTokenHash == "" { - slog.Error("token validation failed: bound session has empty token hash", - "reason", "misconfigured_session", - ) - return sessiontypes.ErrSessionOwnerUnknown - } - - // Compute caller's token hash using the same HMAC secret and salt - callerHash := security.HashToken(caller.Token, d.hmacSecret, d.tokenSalt) - - // Constant-time comparison to prevent timing attacks - if !pkgsecurity.ConstantTimeHashCompare(d.boundTokenHash, callerHash, security.SHA256HexLen) { - slog.Warn("token validation failed: token hash mismatch", - "reason", "token_hash_mismatch", - ) - return sessiontypes.ErrUnauthorizedCaller - } - - return nil -} - -// CallTool validates the caller identity before delegating to the underlying session. -func (d *HijackPreventionDecorator) CallTool( - ctx context.Context, - caller *auth.Identity, - toolName string, - arguments map[string]any, - meta map[string]any, -) (*vmcp.ToolCallResult, error) { - // Validate caller identity - if err := d.validateCaller(caller); err != nil { - return nil, err - } - - return d.MultiSession.CallTool(ctx, caller, toolName, arguments, meta) -} - -// ReadResource validates the caller identity before delegating to the underlying session. -func (d *HijackPreventionDecorator) ReadResource( - ctx context.Context, - caller *auth.Identity, - uri string, -) (*vmcp.ResourceReadResult, error) { - // Validate caller identity - if err := d.validateCaller(caller); err != nil { - return nil, err - } - - return d.MultiSession.ReadResource(ctx, caller, uri) -} - -// GetPrompt validates the caller identity before delegating to the underlying session. -func (d *HijackPreventionDecorator) GetPrompt( - ctx context.Context, - caller *auth.Identity, - name string, - arguments map[string]any, -) (*vmcp.PromptGetResult, error) { - // Validate caller identity - if err := d.validateCaller(caller); err != nil { - return nil, err - } - - return d.MultiSession.GetPrompt(ctx, caller, name, arguments) -} diff --git a/pkg/vmcp/session/internal/security/hijack_prevention_test.go b/pkg/vmcp/session/internal/security/hijack_prevention_test.go new file mode 100644 index 0000000000..866a902f9c --- /dev/null +++ b/pkg/vmcp/session/internal/security/hijack_prevention_test.go @@ -0,0 +1,232 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package security + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +var ( + // Test HMAC secret and salt for consistent test results + testSecret = []byte("test-secret") + testTokenSalt = []byte("test-salt-123456") // 16 bytes +) + +// mockSession is a minimal implementation of MultiSession for testing. +// It embeds the interface so only the methods exercised by tests need to be defined. +type mockSession struct { + sessiontypes.MultiSession // satisfies the rest of the interface + metadata map[string]string +} + +func newMockSession(_ string) *mockSession { + return &mockSession{ + metadata: make(map[string]string), + } +} + +func (m *mockSession) SetMetadata(key, value string) { + m.metadata[key] = value +} + +func (m *mockSession) GetMetadata() map[string]string { + return m.metadata +} + +func (*mockSession) CallTool(_ context.Context, _ *auth.Identity, _ string, _ map[string]any, _ map[string]any) (*vmcp.ToolCallResult, error) { + return nil, nil +} + +func (*mockSession) ReadResource(_ context.Context, _ *auth.Identity, _ string) (*vmcp.ResourceReadResult, error) { + return nil, nil +} + +func (*mockSession) GetPrompt(_ context.Context, _ *auth.Identity, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return nil, nil +} + +func (*mockSession) Close() error { return nil } + +// TestValidateCaller_EdgeCases tests edge cases in caller validation logic. +func TestValidateCaller_EdgeCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + allowAnonymous bool + boundTokenHash string + caller *auth.Identity + wantErr error + }{ + { + name: "anonymous session with nil caller", + allowAnonymous: true, + boundTokenHash: "", + caller: nil, + wantErr: nil, // Should succeed + }, + { + name: "anonymous session rejects caller with token", + allowAnonymous: true, + boundTokenHash: "", + caller: &auth.Identity{Subject: "user", Token: "token"}, + wantErr: sessiontypes.ErrUnauthorizedCaller, // Prevent session upgrade attack + }, + { + name: "bound session with nil caller", + allowAnonymous: false, + boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), + caller: nil, + wantErr: sessiontypes.ErrNilCaller, + }, + { + name: "bound session with matching token", + allowAnonymous: false, + boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), + caller: &auth.Identity{Subject: "user", Token: "correct-token"}, + wantErr: nil, // Should succeed + }, + { + name: "bound session with wrong token", + allowAnonymous: false, + boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), + caller: &auth.Identity{Subject: "user", Token: "wrong-token"}, + wantErr: sessiontypes.ErrUnauthorizedCaller, + }, + { + name: "bound session with empty token in identity", + allowAnonymous: false, + boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), + caller: &auth.Identity{Subject: "user", Token: ""}, + wantErr: sessiontypes.ErrUnauthorizedCaller, + }, + { + name: "anonymous session accepts caller with empty token", + allowAnonymous: true, + boundTokenHash: "", + caller: &auth.Identity{Subject: "user", Token: ""}, + wantErr: nil, // Empty token is equivalent to no token + }, + { + name: "misconfigured bound session with empty hash rejects empty token", + allowAnonymous: false, + boundTokenHash: "", // Misconfiguration: bound but no hash + caller: &auth.Identity{Subject: "user", Token: ""}, + wantErr: sessiontypes.ErrSessionOwnerUnknown, // Fail closed + }, + { + name: "misconfigured bound session with empty hash rejects nil caller", + allowAnonymous: false, + boundTokenHash: "", // Misconfiguration: bound but no hash + caller: nil, + wantErr: sessiontypes.ErrNilCaller, // Nil check happens first + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a base session + baseSession := newMockSession("test-session") + + // Wrap with decorator that has the test configuration + decorator := &hijackPreventionDecorator{ + MultiSession: baseSession, + allowAnonymous: tt.allowAnonymous, + boundTokenHash: tt.boundTokenHash, + tokenSalt: testTokenSalt, + hmacSecret: testSecret, + } + + // Test validateCaller directly on the decorator + err := decorator.validateCaller(tt.caller) + + if tt.wantErr != nil { + require.Error(t, err) + assert.ErrorIs(t, err, tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestConcurrentValidation tests that validateCaller is safe for concurrent use. +func TestConcurrentValidation(t *testing.T) { + t.Parallel() + + baseSession := newMockSession("test-session") + + decorator := &hijackPreventionDecorator{ + MultiSession: baseSession, + allowAnonymous: false, + boundTokenHash: hashToken("test-token", testSecret, testTokenSalt), + tokenSalt: testTokenSalt, + hmacSecret: testSecret, + } + + // Run validation concurrently from multiple goroutines + // Collect errors in channel to avoid race conditions with testify assertions + const numGoroutines = 10 + errChan := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + caller := &auth.Identity{Subject: "user", Token: "test-token"} + err := decorator.validateCaller(caller) + errChan <- err + }() + } + + // Wait for all goroutines and assert in main goroutine (thread-safe) + for i := 0; i < numGoroutines; i++ { + err := <-errChan + assert.NoError(t, err, "concurrent validation should succeed") + } +} + +// TestPreventSessionHijacking_BasicFunctionality tests the main entry point. +func TestPreventSessionHijacking_BasicFunctionality(t *testing.T) { + t.Parallel() + + t.Run("authenticated session", func(t *testing.T) { + t.Parallel() + + baseSession := newMockSession("test-session") + identity := &auth.Identity{Subject: "user", Token: "test-token"} + + decorated, err := PreventSessionHijacking(baseSession, testSecret, identity) + require.NoError(t, err) + require.NotNil(t, decorated) + + // Verify metadata was set (no cast needed - returns concrete type) + metadata := decorated.GetMetadata() + assert.NotEmpty(t, metadata[metadataKeyTokenHash]) + assert.NotEmpty(t, metadata[metadataKeyTokenSalt]) + }) + + t.Run("anonymous session", func(t *testing.T) { + t.Parallel() + + baseSession := newMockSession("test-session") + + decorated, err := PreventSessionHijacking(baseSession, testSecret, nil) + require.NoError(t, err) + require.NotNil(t, decorated) + + // Verify metadata was set (empty for anonymous, no cast needed) + metadata := decorated.GetMetadata() + assert.Empty(t, metadata[metadataKeyTokenHash]) + assert.Empty(t, metadata[metadataKeyTokenSalt]) + }) +} diff --git a/pkg/vmcp/session/internal/security/security.go b/pkg/vmcp/session/internal/security/security.go new file mode 100644 index 0000000000..4d867fb713 --- /dev/null +++ b/pkg/vmcp/session/internal/security/security.go @@ -0,0 +1,286 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package security provides cryptographic utilities for session token binding +// and hijacking prevention. It handles HMAC-SHA256 token hashing, salt generation, +// and constant-time comparison to prevent timing attacks. +package security + +import ( + "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "log/slog" + + "github.com/stacklok/toolhive/pkg/auth" + pkgsecurity "github.com/stacklok/toolhive/pkg/security" + "github.com/stacklok/toolhive/pkg/vmcp" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +const ( + // SHA256HexLen is the length of a hex-encoded SHA256 hash (32 bytes = 64 hex characters) + SHA256HexLen = 64 + + // metadataKeyTokenHash is the session metadata key for the token hash. + // Imported from types package to ensure consistency across all packages. + metadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash + + // metadataKeyTokenSalt is the session metadata key for the token salt. + // Imported from types package to ensure consistency across all packages. + metadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt +) + +// generateSalt generates a cryptographically secure random salt for token hashing. +// Returns 16 bytes of random data from crypto/rand. +// +// Each session should have a unique salt to provide additional entropy and prevent +// attacks that work across multiple sessions. +func generateSalt() ([]byte, error) { + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return nil, fmt.Errorf("failed to generate salt: %w", err) + } + return salt, nil +} + +// hashToken returns the hex-encoded HMAC-SHA256 hash of a raw bearer token string. +// Uses HMAC with a server-managed secret and per-session salt to prevent offline +// attacks if session storage is compromised. +// +// For empty tokens (anonymous sessions) it returns the empty string, which is +// the sentinel value used to identify sessions created without credentials. +// The raw token is never stored — only the hash. +// +// Parameters: +// - token: The bearer token to hash +// - secret: Server-managed HMAC secret (should be 32+ bytes) +// - salt: Per-session random salt (typically 16 bytes) +// +// Security: Uses HMAC-SHA256 instead of plain SHA256 to prevent rainbow table +// attacks and offline brute force if session state leaks from Redis/Valkey. +func hashToken(token string, secret, salt []byte) string { + if token == "" { + return "" + } + h := hmac.New(sha256.New, secret) + h.Write(salt) + h.Write([]byte(token)) + return hex.EncodeToString(h.Sum(nil)) +} + +// hijackPreventionDecorator wraps a session and adds token binding validation +// to prevent session hijacking attacks. It validates that all requests come from +// the same identity that created the session. +// +// The decorator is applied by PreventSessionHijacking to ALL sessions (both authenticated +// and anonymous). For authenticated sessions, it validates the caller's token matches +// the creator's token. For anonymous sessions (allowAnonymous=true), it allows nil +// callers and prevents session upgrade attacks by rejecting any token presentation. +// +// The decorator embeds MultiSession and only overrides the methods that require +// validation (CallTool, ReadResource, GetPrompt). All other methods are automatically +// delegated to the embedded session. +type hijackPreventionDecorator struct { + sessiontypes.MultiSession // Embedded interface - provides automatic delegation for most methods + + // Token binding fields: enforce that subsequent requests come from the same + // identity that created the session. + // These fields are immutable after decorator creation (no mutex needed). + boundTokenHash string // HMAC-SHA256 hash of creator's token (empty for anonymous) + tokenSalt []byte // Random salt used for HMAC (empty for anonymous) + hmacSecret []byte // Server-managed secret for HMAC-SHA256 + allowAnonymous bool // Whether to allow nil caller +} + +// validateCaller checks if the provided caller identity matches the session owner. +// Returns nil if validation succeeds, or an error if: +// - The session requires a bound identity but caller is nil (ErrNilCaller) +// - The caller's token hash doesn't match the session owner (ErrUnauthorizedCaller) +// - An anonymous session receives a caller with a non-empty token (ErrUnauthorizedCaller) +// +// For anonymous sessions (allowAnonymous=true, boundTokenHash=""), validation succeeds +// only when the caller is nil or has an empty token (prevents session upgrade attacks). +func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { + // No lock needed - token binding fields are immutable after decorator creation + + // Anonymous sessions: reject callers that present tokens + if d.allowAnonymous && d.boundTokenHash == "" { + // Prevent session upgrade attack: anonymous sessions cannot accept tokens + if caller != nil && caller.Token != "" { + slog.Warn("token validation failed: session upgrade attack prevented", + "reason", "token_presented_to_anonymous_session", + ) + return sessiontypes.ErrUnauthorizedCaller + } + return nil + } + + // Bound sessions require a caller + if caller == nil { + slog.Warn("token validation failed: nil caller for bound session", + "reason", "nil_caller", + ) + return sessiontypes.ErrNilCaller + } + + // Defensive check: bound sessions must have a non-empty token hash. + // This prevents misconfigured sessions from accepting empty tokens. + // Scenario: if boundTokenHash="" and caller.Token="", both would hash to "", + // and ConstantTimeHashCompare would return true (both empty case). + if d.boundTokenHash == "" { + slog.Error("token validation failed: bound session has empty token hash", + "reason", "misconfigured_session", + ) + return sessiontypes.ErrSessionOwnerUnknown + } + + // Compute caller's token hash using the same HMAC secret and salt + callerHash := hashToken(caller.Token, d.hmacSecret, d.tokenSalt) + + // Constant-time comparison to prevent timing attacks + if !pkgsecurity.ConstantTimeHashCompare(d.boundTokenHash, callerHash, SHA256HexLen) { + slog.Warn("token validation failed: token hash mismatch", + "reason", "token_hash_mismatch", + ) + return sessiontypes.ErrUnauthorizedCaller + } + + return nil +} + +// CallTool validates the caller identity before delegating to the embedded session. +func (d hijackPreventionDecorator) CallTool( + ctx context.Context, + caller *auth.Identity, + toolName string, + arguments map[string]any, + meta map[string]any, +) (*vmcp.ToolCallResult, error) { + // Validate caller identity + if err := d.validateCaller(caller); err != nil { + return nil, err + } + + return d.MultiSession.CallTool(ctx, caller, toolName, arguments, meta) +} + +// ReadResource validates the caller identity before delegating to the embedded session. +func (d hijackPreventionDecorator) ReadResource( + ctx context.Context, + caller *auth.Identity, + uri string, +) (*vmcp.ResourceReadResult, error) { + // Validate caller identity + if err := d.validateCaller(caller); err != nil { + return nil, err + } + + return d.MultiSession.ReadResource(ctx, caller, uri) +} + +// GetPrompt validates the caller identity before delegating to the embedded session. +func (d hijackPreventionDecorator) GetPrompt( + ctx context.Context, + caller *auth.Identity, + name string, + arguments map[string]any, +) (*vmcp.PromptGetResult, error) { + // Validate caller identity + if err := d.validateCaller(caller); err != nil { + return nil, err + } + + return d.MultiSession.GetPrompt(ctx, caller, name, arguments) +} + +// PreventSessionHijacking wraps a session with hijack prevention security measures. +// It computes token binding hashes, stores them in session metadata, and returns +// a decorated session that validates caller identity on every operation. +// +// Whether the session is anonymous is derived from the identity: nil identity or +// empty token means anonymous, a non-empty token means bound/authenticated. +// +// For authenticated sessions (identity.Token != ""): +// - Generates a unique random salt +// - Computes HMAC-SHA256 hash of the bearer token +// - Stores hash and salt in session metadata +// - Returns decorator that validates every request against the creator's token +// +// For anonymous sessions (identity == nil or identity.Token == ""): +// - Stores an empty string sentinel for the token hash metadata key +// - Omits the salt metadata key entirely (no salt is generated for anonymous sessions) +// - Returns decorator that allows nil callers and rejects token presentation +// +// Security: +// - Makes defensive copies of secret and salt to prevent external mutation +// - Uses constant-time comparison to prevent timing attacks +// - Prevents session upgrade attacks (anonymous → authenticated) +// - Raw tokens are never stored, only HMAC-SHA256 hashes +// +// Returns an error if: +// - session doesn't implement sessiontypes.MultiSession interface +// - salt generation fails +func PreventSessionHijacking( + session interface{}, + hmacSecret []byte, + identity *auth.Identity, +) (sessiontypes.MultiSession, error) { + allowAnonymous := identity == nil || identity.Token == "" + // Validate upfront that session implements the MultiSession interface. + // This provides fail-fast behavior for security-critical operations + // instead of panics at runtime. + multiSession, ok := session.(sessiontypes.MultiSession) + if !ok { + return nil, fmt.Errorf("session must implement sessiontypes.MultiSession interface, got %T", session) + } + + // Note: Pass-through methods (ID, Type, CreatedAt, etc.) are validated by the + // type system when the decorator is used. We don't validate them here to keep + // the constructor simple and allow minimal mocks for testing. + + var boundTokenHash string + var tokenSalt []byte + var err error + + // Compute token binding for authenticated sessions + if !allowAnonymous && identity != nil && identity.Token != "" { + // Generate unique salt for this session + tokenSalt, err = generateSalt() + if err != nil { + return nil, fmt.Errorf("failed to generate token salt: %w", err) + } + // Compute HMAC-SHA256 hash with server secret and per-session salt + boundTokenHash = hashToken(identity.Token, hmacSecret, tokenSalt) + } + + // Store hash and salt in session metadata for persistence, auditing, + // and backward compatibility + multiSession.SetMetadata(metadataKeyTokenHash, boundTokenHash) + if len(tokenSalt) > 0 { + multiSession.SetMetadata(metadataKeyTokenSalt, hex.EncodeToString(tokenSalt)) + } + + // Make defensive copies of slices to prevent external mutation + var hmacSecretCopy, tokenSaltCopy []byte + if len(hmacSecret) > 0 { + hmacSecretCopy = append([]byte(nil), hmacSecret...) + } + if len(tokenSalt) > 0 { + tokenSaltCopy = append([]byte(nil), tokenSalt...) + } + + // Wrap with hijackPreventionDecorator for runtime validation. + // The decorator embeds the MultiSession interface, so all methods are automatically + // delegated except for the three we override (CallTool, ReadResource, GetPrompt). + return &hijackPreventionDecorator{ + MultiSession: multiSession, + allowAnonymous: allowAnonymous, + hmacSecret: hmacSecretCopy, + boundTokenHash: boundTokenHash, + tokenSalt: tokenSaltCopy, + }, nil +} diff --git a/pkg/vmcp/session/internal/security/security_test.go b/pkg/vmcp/session/internal/security/security_test.go new file mode 100644 index 0000000000..40d21554b0 --- /dev/null +++ b/pkg/vmcp/session/internal/security/security_test.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package security_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// TestShouldAllowAnonymous_EdgeCases tests the ShouldAllowAnonymous helper. +func TestShouldAllowAnonymous_EdgeCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identity *auth.Identity + want bool + }{ + { + name: "nil identity", + identity: nil, + want: true, + }, + { + name: "non-nil identity with token", + identity: &auth.Identity{Subject: "user", Token: "token"}, + want: false, + }, + { + name: "non-nil identity with empty token", + identity: &auth.Identity{Subject: "user", Token: ""}, + want: true, // Empty token is treated as anonymous + }, + { + name: "non-nil identity with empty subject", + identity: &auth.Identity{Subject: "", Token: "token"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := session.ShouldAllowAnonymous(tt.identity) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/vmcp/session/security/security.go b/pkg/vmcp/session/security/security.go deleted file mode 100644 index 5c22322559..0000000000 --- a/pkg/vmcp/session/security/security.go +++ /dev/null @@ -1,58 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -// Package security provides cryptographic utilities for session token binding -// and hijacking prevention. It handles HMAC-SHA256 token hashing, salt generation, -// and constant-time comparison to prevent timing attacks. -package security - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/hex" - "fmt" -) - -const ( - // SHA256HexLen is the length of a hex-encoded SHA256 hash (32 bytes = 64 hex characters) - SHA256HexLen = 64 -) - -// GenerateSalt generates a cryptographically secure random salt for token hashing. -// Returns 16 bytes of random data from crypto/rand. -// -// Each session should have a unique salt to provide additional entropy and prevent -// attacks that work across multiple sessions. -func GenerateSalt() ([]byte, error) { - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return nil, fmt.Errorf("failed to generate salt: %w", err) - } - return salt, nil -} - -// HashToken returns the hex-encoded HMAC-SHA256 hash of a raw bearer token string. -// Uses HMAC with a server-managed secret and per-session salt to prevent offline -// attacks if session storage is compromised. -// -// For empty tokens (anonymous sessions) it returns the empty string, which is -// the sentinel value used to identify sessions created without credentials. -// The raw token is never stored — only the hash. -// -// Parameters: -// - token: The bearer token to hash -// - secret: Server-managed HMAC secret (should be 32+ bytes) -// - salt: Per-session random salt (typically 16 bytes) -// -// Security: Uses HMAC-SHA256 instead of plain SHA256 to prevent rainbow table -// attacks and offline brute force if session state leaks from Redis/Valkey. -func HashToken(token string, secret, salt []byte) string { - if token == "" { - return "" - } - h := hmac.New(sha256.New, secret) - h.Write(salt) - h.Write([]byte(token)) - return hex.EncodeToString(h.Sum(nil)) -} diff --git a/pkg/vmcp/session/security/security_test.go b/pkg/vmcp/session/security/security_test.go deleted file mode 100644 index 5aa6efd714..0000000000 --- a/pkg/vmcp/session/security/security_test.go +++ /dev/null @@ -1,272 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package security_test - -import ( - "encoding/hex" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/session/security" -) - -const ( - testToken = "my-token" -) - -// TestGenerateSalt verifies that GenerateSalt produces cryptographically random salts. -func TestGenerateSalt(t *testing.T) { - t.Parallel() - - // Generate multiple salts - salt1, err := security.GenerateSalt() - require.NoError(t, err) - require.NotNil(t, salt1) - assert.Len(t, salt1, 16, "salt should be 16 bytes") - - salt2, err := security.GenerateSalt() - require.NoError(t, err) - require.NotNil(t, salt2) - assert.Len(t, salt2, 16, "salt should be 16 bytes") - - // Salts should be different (extremely high probability) - assert.NotEqual(t, salt1, salt2, "consecutive salts should be unique") - - // Salt should not be all zeros (indicates crypto/rand failure) - allZeros := make([]byte, 16) - assert.NotEqual(t, allZeros, salt1, "salt should not be all zeros") - assert.NotEqual(t, allZeros, salt2, "salt should not be all zeros") -} - -// TestHashToken_BasicFunctionality verifies HMAC-SHA256 token hashing. -func TestHashToken_BasicFunctionality(t *testing.T) { - t.Parallel() - - secret := []byte("test-hmac-secret-32-bytes-long!!") - salt := []byte("test-salt-16byte") - token := "my-bearer-token-12345" - - hash := security.HashToken(token, secret, salt) - - // SHA256 produces 32 bytes = 64 hex characters - assert.Len(t, hash, security.SHA256HexLen, "hash should be 64 hex characters") - - // Hash should be valid hex - _, err := hex.DecodeString(hash) - assert.NoError(t, err, "hash should be valid hex encoding") - - // Hash should be deterministic (same inputs → same output) - hash2 := security.HashToken(token, secret, salt) - assert.Equal(t, hash, hash2, "hashing should be deterministic") -} - -// TestHashToken_EmptyToken verifies that empty tokens return empty hash. -func TestHashToken_EmptyToken(t *testing.T) { - t.Parallel() - - secret := []byte("test-secret") - salt := []byte("test-salt") - - hash := security.HashToken("", secret, salt) - - assert.Equal(t, "", hash, "empty token should produce empty hash") -} - -// TestHashToken_DifferentInputs verifies that different inputs produce different hashes. -func TestHashToken_DifferentInputs(t *testing.T) { - t.Parallel() - - secret := []byte("test-hmac-secret") - salt := []byte("test-salt") - - token1 := "token-one" - token2 := "token-two" - - hash1 := security.HashToken(token1, secret, salt) - hash2 := security.HashToken(token2, secret, salt) - - assert.NotEqual(t, hash1, hash2, "different tokens should produce different hashes") - assert.Len(t, hash1, security.SHA256HexLen) - assert.Len(t, hash2, security.SHA256HexLen) -} - -// TestHashToken_DifferentSecrets verifies that different secrets produce different hashes. -func TestHashToken_DifferentSecrets(t *testing.T) { - t.Parallel() - - token := testToken - salt := []byte("same-salt") - - secret1 := []byte("secret-one") - secret2 := []byte("secret-two") - - hash1 := security.HashToken(token, secret1, salt) - hash2 := security.HashToken(token, secret2, salt) - - assert.NotEqual(t, hash1, hash2, "different secrets should produce different hashes") -} - -// TestHashToken_DifferentSalts verifies that different salts produce different hashes. -func TestHashToken_DifferentSalts(t *testing.T) { - t.Parallel() - - token := testToken - secret := []byte("same-secret") - - salt1 := []byte("salt-one") - salt2 := []byte("salt-two") - - hash1 := security.HashToken(token, secret, salt1) - hash2 := security.HashToken(token, secret, salt2) - - assert.NotEqual(t, hash1, hash2, "different salts should produce different hashes") -} - -// TestHashToken_EmptySecret verifies behavior with empty HMAC secret. -func TestHashToken_EmptySecret(t *testing.T) { - t.Parallel() - - token := testToken - salt := []byte("test-salt") - emptySecret := []byte{} - - // Should still produce a hash (HMAC allows empty key, though not recommended) - hash := security.HashToken(token, emptySecret, salt) - - assert.Len(t, hash, security.SHA256HexLen, "should produce valid hash even with empty secret") - assert.NotEqual(t, "", hash, "hash should not be empty") -} - -// TestHashToken_EmptySalt verifies behavior with empty salt. -func TestHashToken_EmptySalt(t *testing.T) { - t.Parallel() - - token := testToken - secret := []byte("test-secret") - emptySalt := []byte{} - - // Should still produce a hash (salt is optional for HMAC, though not recommended) - hash := security.HashToken(token, secret, emptySalt) - - assert.Len(t, hash, security.SHA256HexLen, "should produce valid hash even with empty salt") - assert.NotEqual(t, "", hash, "hash should not be empty") -} - -// TestHashToken_NilInputs verifies behavior with nil secret/salt. -func TestHashToken_NilInputs(t *testing.T) { - t.Parallel() - - token := testToken - - // Nil secret and salt should still work (treated as empty) - hash := security.HashToken(token, nil, nil) - - assert.Len(t, hash, security.SHA256HexLen, "should produce valid hash with nil inputs") - assert.NotEqual(t, "", hash, "hash should not be empty") -} - -// TestHashToken_LongToken verifies behavior with very long tokens. -func TestHashToken_LongToken(t *testing.T) { - t.Parallel() - - secret := []byte("test-secret") - salt := []byte("test-salt") - - // Very long token (10KB) - longToken := strings.Repeat("a", 10000) - - hash := security.HashToken(longToken, secret, salt) - - // HMAC-SHA256 always produces 64-character hex output regardless of input length - assert.Len(t, hash, security.SHA256HexLen, "hash length should be constant regardless of input length") -} - -// TestHashToken_SpecialCharacters verifies handling of tokens with special characters. -func TestHashToken_SpecialCharacters(t *testing.T) { - t.Parallel() - - secret := []byte("test-secret") - salt := []byte("test-salt") - - tests := []struct { - name string - token string - }{ - {"unicode", "token-with-üñíçödé-😀"}, - {"whitespace", "token with spaces\t\n"}, - {"symbols", "token!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/"}, - {"null_bytes", "token\x00with\x00nulls"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - hash := security.HashToken(tt.token, secret, salt) - assert.Len(t, hash, security.SHA256HexLen, "should handle special characters") - assert.NotEqual(t, "", hash, "should produce non-empty hash") - - // Verify hex encoding is valid - _, err := hex.DecodeString(hash) - assert.NoError(t, err, "hash should be valid hex") - }) - } -} - -// TestSHA256HexLen_Constant verifies the constant value is correct. -func TestSHA256HexLen_Constant(t *testing.T) { - t.Parallel() - - // SHA256 produces 32 bytes = 64 hex characters - assert.Equal(t, 64, security.SHA256HexLen, "SHA256HexLen should be 64") -} - -// TestHashToken_Consistency verifies that the same inputs always produce the same hash -// across multiple invocations (regression test for determinism). -func TestHashToken_Consistency(t *testing.T) { - t.Parallel() - - secret := []byte("consistent-secret") - salt := []byte("consistent-salt") - token := "consistent-token" - - // Hash the same input 100 times - var hashes []string - for i := 0; i < 100; i++ { - hashes = append(hashes, security.HashToken(token, secret, salt)) - } - - // All hashes should be identical - firstHash := hashes[0] - for i, hash := range hashes { - assert.Equal(t, firstHash, hash, "hash at index %d should match first hash", i) - } -} - -// TestHashToken_NoCollisions verifies that different tokens produce different hashes. -func TestHashToken_NoCollisions(t *testing.T) { - t.Parallel() - - secret := []byte("test-secret") - salt := []byte("test-salt") - - // Generate hashes for many different tokens - seen := make(map[string]string) - for i := 0; i < 1000; i++ { - token := hex.EncodeToString([]byte{byte(i / 256), byte(i % 256)}) - hash := security.HashToken(token, secret, salt) - - // Check for collision - if existingToken, exists := seen[hash]; exists { - t.Errorf("collision detected: tokens %q and %q produced same hash %q", - existingToken, token, hash) - } - seen[hash] = token - } - - assert.Len(t, seen, 1000, "should have 1000 unique hashes") -} diff --git a/pkg/vmcp/session/session.go b/pkg/vmcp/session/session.go index 1bcae51dbb..39bb514489 100644 --- a/pkg/vmcp/session/session.go +++ b/pkg/vmcp/session/session.go @@ -4,57 +4,37 @@ package session import ( - transportsession "github.com/stacklok/toolhive/pkg/transport/session" - "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/auth" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -// MultiSession is the vMCP domain session interface. It extends the -// transport-layer Session with behaviour: capability access and session-scoped -// backend routing across multiple backend connections. -// -// A MultiSession is a "session of sessions": each backend contributes its own -// persistent connection (see [backend.Session] in pkg/vmcp/session/internal/backend), -// and the MultiSession aggregates them behind a single routing table. -// -// # Distributed deployment note -// -// Because MCP clients cannot be serialised, horizontal scaling requires sticky -// sessions (session affinity at the load balancer). Without sticky sessions, a -// request routed to a different vMCP instance must recreate backend clients -// (one-time cost per re-route). This is an accepted trade-off. -// -// # Storage -// -// A MultiSession uses a two-layer storage model: -// -// - Runtime layer (in-process only): backend HTTP connections, routing -// table, and capability lists. These cannot be serialized and are lost -// when the process exits. Sessions are therefore node-local. -// -// - Metadata layer (serializable): identity subject and connected backend -// IDs are written to the embedded transportsession.Session so that -// pluggable transportsession.Storage backends (e.g. Redis) can persist -// them. This enables auditing and future session reconstruction, but -// does not make the session itself portable — the runtime layer must -// be rebuilt from scratch on a different node. -type MultiSession interface { - transportsession.Session - sessiontypes.Caller - - // Tools returns the resolved tools available in this session. - // The list is built once at session creation and is read-only thereafter. - Tools() []vmcp.Tool - - // Resources returns the resolved resources available in this session. - Resources() []vmcp.Resource +// MultiSession is an alias for sessiontypes.MultiSession, re-exported here for +// backward compatibility and convenience. +type MultiSession = sessiontypes.MultiSession - // Prompts returns the resolved prompts available in this session. - Prompts() []vmcp.Prompt +const ( + // MetadataKeyTokenHash is the session metadata key that holds the HMAC-SHA256 + // hash of the bearer token used to create the session. For authenticated sessions + // this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty + // string sentinel. The raw token is never stored — only the hash. + // + // Re-exported from types package for convenience. + MetadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash +) - // BackendSessions returns a snapshot of the backend-assigned session IDs, - // keyed by backend workload ID. The backend session ID is assigned by the - // backend MCP server and is used to correlate vMCP sessions with backend - // sessions for debugging and auditing. - BackendSessions() map[string]string +// ShouldAllowAnonymous determines if a session should allow anonymous access +// based on the creator's identity. This is session business logic that decides +// whether a session is bound to a specific identity or allows anonymous access. +// +// Sessions without an identity (nil) or with an empty token are treated as +// anonymous and will only accept nil callers or callers with an empty token; +// callers presenting a non-empty token are rejected to prevent session-upgrade +// attacks. Sessions with a non-empty bearer token are bound to that token and +// will reject requests from callers with a different token. +// +// This function is used by both the session factory (to determine how to create +// the session) and the security layer (to validate requests against the session's +// access policy). +func ShouldAllowAnonymous(identity *auth.Identity) bool { + return identity == nil || identity.Token == "" } diff --git a/pkg/vmcp/session/token_binding_test.go b/pkg/vmcp/session/token_binding_test.go index 0c0ffe3dd0..8c18c181f4 100644 --- a/pkg/vmcp/session/token_binding_test.go +++ b/pkg/vmcp/session/token_binding_test.go @@ -5,82 +5,18 @@ package session import ( "context" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" + "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" - transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" - "github.com/stacklok/toolhive/pkg/vmcp/session/security" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -var ( - // Test HMAC secret and salt for consistent test results - testSecret = []byte("test-secret") - testTokenSalt = []byte("test-salt-123456") // 16 bytes -) - -// --------------------------------------------------------------------------- -// HashToken -// --------------------------------------------------------------------------- - -func TestHashToken(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - token string - want string - }{ - { - name: "empty token returns anonymous sentinel", - token: "", - want: "", - }, - { - name: "non-empty token returns HMAC-SHA256 hex", - token: "my-bearer-token", - want: func() string { - h := hmac.New(sha256.New, testSecret) - h.Write(testTokenSalt) - h.Write([]byte("my-bearer-token")) - return hex.EncodeToString(h.Sum(nil)) - }(), - }, - { - name: "different tokens produce different hashes", - token: "another-token", - want: func() string { - h := hmac.New(sha256.New, testSecret) - h.Write(testTokenSalt) - h.Write([]byte("another-token")) - return hex.EncodeToString(h.Sum(nil)) - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - assert.Equal(t, tt.want, security.HashToken(tt.token, testSecret, testTokenSalt)) - }) - } -} - -// --------------------------------------------------------------------------- -// Note: ComputeTokenHash was removed -// --------------------------------------------------------------------------- -// ComputeTokenHash was removed because HMAC-SHA256 hashing requires -// per-session salt, so token hashes can't be computed without session context. -// Use HashToken(token, secret, salt) directly with session-specific parameters. - // --------------------------------------------------------------------------- // makeSession stores token hash in metadata // --------------------------------------------------------------------------- @@ -97,7 +33,7 @@ func nilBackendConnector() backendConnector { func TestMakeSession_StoresTokenHash(t *testing.T) { t.Parallel() - t.Run("authenticated session stores HMAC-SHA256 hash and salt", func(t *testing.T) { + t.Run("authenticated session stores HMAC-SHA256 hash", func(t *testing.T) { t.Parallel() const rawToken = "test-bearer-token" @@ -116,9 +52,9 @@ func TestMakeSession_StoresTokenHash(t *testing.T) { // Raw token must never appear in metadata. assert.NotEqual(t, rawToken, storedHash) - // Verify salt is stored - storedSalt, saltPresent := sess.GetMetadata()[MetadataKeyTokenSalt] - require.True(t, saltPresent, "MetadataKeyTokenSalt must be set") + // Verify salt is stored for authenticated sessions + storedSalt, saltPresent := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] + require.True(t, saltPresent, "MetadataKeyTokenSalt must be set for authenticated sessions") assert.NotEmpty(t, storedSalt, "Salt must be non-empty for authenticated session") }) @@ -134,9 +70,9 @@ func TestMakeSession_StoresTokenHash(t *testing.T) { require.True(t, present, "MetadataKeyTokenHash must be set even for anonymous sessions") assert.Empty(t, storedHash, "anonymous session must store empty sentinel") - // Anonymous sessions should not have salt - storedSalt := sess.GetMetadata()[MetadataKeyTokenSalt] - assert.Empty(t, storedSalt, "anonymous session should not have salt") + // Salt must not be present for anonymous sessions + storedSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] + assert.Empty(t, storedSalt, "anonymous session must not store a salt") }) t.Run("identity with empty token stores empty sentinel", func(t *testing.T) { @@ -151,12 +87,12 @@ func TestMakeSession_StoresTokenHash(t *testing.T) { storedHash := sess.GetMetadata()[MetadataKeyTokenHash] assert.Empty(t, storedHash, "empty-token identity must store empty sentinel") - // Empty token should not have salt - storedSalt := sess.GetMetadata()[MetadataKeyTokenSalt] - assert.Empty(t, storedSalt, "empty-token identity should not have salt") + // Salt must not be present for empty-token (anonymous) sessions + storedSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] + assert.Empty(t, storedSalt, "empty-token identity must not store a salt") }) - t.Run("MakeSessionWithID also stores token hash and salt", func(t *testing.T) { + t.Run("MakeSessionWithID also stores token hash", func(t *testing.T) { t.Parallel() const rawToken = "id-specific-token" @@ -173,204 +109,13 @@ func TestMakeSession_StoresTokenHash(t *testing.T) { assert.NotEmpty(t, storedHash, "Token hash must be non-empty") assert.Len(t, storedHash, 64, "HMAC-SHA256 hex-encoded hash should be 64 characters") - // Verify salt - storedSalt, saltPresent := sess.GetMetadata()[MetadataKeyTokenSalt] - require.True(t, saltPresent, "MetadataKeyTokenSalt must be set") - assert.NotEmpty(t, storedSalt, "Salt must be non-empty") + // Verify salt is stored for authenticated sessions + storedSalt, saltPresent := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] + require.True(t, saltPresent, "MetadataKeyTokenSalt must be set for authenticated sessions") + assert.NotEmpty(t, storedSalt, "Salt must be non-empty for authenticated session") }) } -// --------------------------------------------------------------------------- -// Caller validation -// --------------------------------------------------------------------------- - -// TestValidateCaller_EdgeCases tests edge cases in caller validation logic. -func TestValidateCaller_EdgeCases(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - allowAnonymous bool - boundTokenHash string - caller *auth.Identity - wantErr error - }{ - { - name: "anonymous session with nil caller", - allowAnonymous: true, - boundTokenHash: "", - caller: nil, - wantErr: nil, // Should succeed - }, - { - name: "anonymous session rejects caller with token", - allowAnonymous: true, - boundTokenHash: "", - caller: &auth.Identity{Subject: "user", Token: "token"}, - wantErr: sessiontypes.ErrUnauthorizedCaller, // Prevent session upgrade attack - }, - { - name: "bound session with nil caller", - allowAnonymous: false, - boundTokenHash: security.HashToken("correct-token", testSecret, testTokenSalt), - caller: nil, - wantErr: sessiontypes.ErrNilCaller, - }, - { - name: "bound session with matching token", - allowAnonymous: false, - boundTokenHash: security.HashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{Subject: "user", Token: "correct-token"}, - wantErr: nil, // Should succeed - }, - { - name: "bound session with wrong token", - allowAnonymous: false, - boundTokenHash: security.HashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{Subject: "user", Token: "wrong-token"}, - wantErr: sessiontypes.ErrUnauthorizedCaller, - }, - { - name: "bound session with empty token in identity", - allowAnonymous: false, - boundTokenHash: security.HashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{Subject: "user", Token: ""}, - wantErr: sessiontypes.ErrUnauthorizedCaller, - }, - { - name: "anonymous session accepts caller with empty token", - allowAnonymous: true, - boundTokenHash: "", - caller: &auth.Identity{Subject: "user", Token: ""}, - wantErr: nil, // Empty token is equivalent to no token - }, - { - name: "misconfigured bound session with empty hash rejects empty token", - allowAnonymous: false, - boundTokenHash: "", // Misconfiguration: bound but no hash - caller: &auth.Identity{Subject: "user", Token: ""}, - wantErr: sessiontypes.ErrSessionOwnerUnknown, // Fail closed - }, - { - name: "misconfigured bound session with empty hash rejects nil caller", - allowAnonymous: false, - boundTokenHash: "", // Misconfiguration: bound but no hash - caller: nil, - wantErr: sessiontypes.ErrNilCaller, // Nil check happens first - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - // Create a base session - baseSession := &defaultMultiSession{ - Session: transportsession.NewStreamableSession("test-session"), - } - - // Wrap with decorator that has the test configuration - decorator := &HijackPreventionDecorator{ - MultiSession: baseSession, - allowAnonymous: tt.allowAnonymous, - boundTokenHash: tt.boundTokenHash, - tokenSalt: testTokenSalt, - hmacSecret: testSecret, - } - - // Test validateCaller directly on the decorator - err := decorator.validateCaller(tt.caller) - - if tt.wantErr != nil { - require.Error(t, err) - assert.ErrorIs(t, err, tt.wantErr) - } else { - require.NoError(t, err) - } - }) - } -} - -// TestConcurrentValidation tests that validateCaller is safe for concurrent use. -func TestConcurrentValidation(t *testing.T) { - t.Parallel() - - baseSession := &defaultMultiSession{ - Session: transportsession.NewStreamableSession("test-session"), - } - - decorator := &HijackPreventionDecorator{ - MultiSession: baseSession, - allowAnonymous: false, - boundTokenHash: security.HashToken("test-token", testSecret, testTokenSalt), - tokenSalt: testTokenSalt, - hmacSecret: testSecret, - } - - // Run validation concurrently from multiple goroutines - // Collect errors in channel to avoid race conditions with testify assertions - const numGoroutines = 10 - errChan := make(chan error, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - caller := &auth.Identity{Subject: "user", Token: "test-token"} - err := decorator.validateCaller(caller) - errChan <- err - }() - } - - // Wait for all goroutines and assert in main goroutine (thread-safe) - for i := 0; i < numGoroutines; i++ { - err := <-errChan - assert.NoError(t, err, "concurrent validation should succeed") - } -} - -// --------------------------------------------------------------------------- -// ShouldAllowAnonymous helper -// --------------------------------------------------------------------------- - -// TestShouldAllowAnonymous_EdgeCases tests the ShouldAllowAnonymous helper. -func TestShouldAllowAnonymous_EdgeCases(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - identity *auth.Identity - want bool - }{ - { - name: "nil identity", - identity: nil, - want: true, - }, - { - name: "non-nil identity with token", - identity: &auth.Identity{Subject: "user", Token: "token"}, - want: false, - }, - { - name: "non-nil identity with empty token", - identity: &auth.Identity{Subject: "user", Token: ""}, - want: true, // Empty token is treated as anonymous - }, - { - name: "non-nil identity with empty subject", - identity: &auth.Identity{Subject: "", Token: "token"}, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := ShouldAllowAnonymous(tt.identity) - assert.Equal(t, tt.want, got) - }) - } -} - // --------------------------------------------------------------------------- // MakeSessionWithID validation // --------------------------------------------------------------------------- @@ -492,16 +237,22 @@ func TestWithHMACSecret_DefensiveCopy(t *testing.T) { require.NotEmpty(t, hash2, "second session should have token hash") // Both sessions should still be able to validate the original token - // (proving the factory used the original secret, not the modified one) - decorator1, ok := sess1.(*HijackPreventionDecorator) - require.True(t, ok, "session should be HijackPreventionDecorator") - err = decorator1.validateCaller(identity) - assert.NoError(t, err, "first session should validate original token despite external modification") - - decorator2, ok := sess2.(*HijackPreventionDecorator) - require.True(t, ok, "session should be HijackPreventionDecorator") - err = decorator2.validateCaller(identity) - assert.NoError(t, err, "second session should validate original token despite external modification") + // (proving the factory used the original secret, not the modified one). + // We verify this by calling a session method that requires authentication. + ctx := context.Background() + + // First session should accept the original token and fail with ErrToolNotFound, + // not an auth error (which would indicate the secret was corrupted) + _, err = sess1.CallTool(ctx, identity, "nonexistent-tool", nil, nil) + assert.ErrorIs(t, err, ErrToolNotFound, "should fail with tool not found error") + assert.False(t, errors.Is(err, sessiontypes.ErrUnauthorizedCaller), + "should not be an auth error (would indicate corrupted secret)") + + // Second session should also accept the original token and fail with ErrToolNotFound + _, err = sess2.CallTool(ctx, identity, "nonexistent-tool", nil, nil) + assert.ErrorIs(t, err, ErrToolNotFound, "should fail with tool not found error") + assert.False(t, errors.Is(err, sessiontypes.ErrUnauthorizedCaller), + "should not be an auth error (would indicate corrupted secret)") } // TestWithHMACSecret_RejectsEmptySecret verifies that WithHMACSecret rejects diff --git a/pkg/vmcp/session/types/session.go b/pkg/vmcp/session/types/session.go index 6d5b3f4ed0..70163d7590 100644 --- a/pkg/vmcp/session/types/session.go +++ b/pkg/vmcp/session/types/session.go @@ -12,6 +12,7 @@ import ( "errors" "github.com/stacklok/toolhive/pkg/auth" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" ) @@ -75,6 +76,76 @@ type Caller interface { Close() error } +// MultiSession is the vMCP domain session interface. It extends the +// transport-layer Session with behaviour: capability access and session-scoped +// backend routing across multiple backend connections. +// +// A MultiSession is a "session of sessions": each backend contributes its own +// persistent connection (see [backend.Session] in pkg/vmcp/session/internal/backend), +// and the MultiSession aggregates them behind a single routing table. +// +// # Distributed deployment note +// +// Because MCP clients cannot be serialised, horizontal scaling requires sticky +// sessions (session affinity at the load balancer). Without sticky sessions, a +// request routed to a different vMCP instance must recreate backend clients +// (one-time cost per re-route). This is an accepted trade-off. +// +// # Storage +// +// A MultiSession uses a two-layer storage model: +// +// - Runtime layer (in-process only): backend HTTP connections, routing +// table, and capability lists. These cannot be serialized and are lost +// when the process exits. Sessions are therefore node-local. +// +// - Metadata layer (serializable): identity subject and connected backend +// IDs are written to the embedded transportsession.Session so that +// pluggable transportsession.Storage backends (e.g. Redis) can persist +// them. This enables auditing and future session reconstruction, but +// does not make the session itself portable — the runtime layer must +// be rebuilt from scratch on a different node. +type MultiSession interface { + transportsession.Session + Caller + + // Tools returns the resolved tools available in this session. + // The list is built once at session creation and is read-only thereafter. + Tools() []vmcp.Tool + + // Resources returns the resolved resources available in this session. + Resources() []vmcp.Resource + + // Prompts returns the resolved prompts available in this session. + Prompts() []vmcp.Prompt + + // BackendSessions returns a snapshot of the backend-assigned session IDs, + // keyed by backend workload ID. The backend session ID is assigned by the + // backend MCP server and is used to correlate vMCP sessions with backend + // sessions for debugging and auditing. + BackendSessions() map[string]string +} + +const ( + // MetadataKeyTokenHash is the session metadata key that holds the HMAC-SHA256 + // hash of the bearer token used to create the session. For authenticated sessions + // this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty + // string sentinel. The raw token is never stored — only the hash. + // + // This constant is the single source of truth used by the session factory and + // security layer to store and validate token binding metadata. + MetadataKeyTokenHash = "vmcp.token.hash" //nolint:gosec // This is a metadata key name, not a credential. + + // MetadataKeyTokenSalt is the session metadata key that holds the hex-encoded + // random salt used for HMAC-SHA256 token hashing. Each authenticated session has a + // unique salt to prevent attacks across multiple sessions. Anonymous sessions do not + // generate a salt and this key is omitted from their metadata. + // + // This constant is the single source of truth used by the session factory and + // security layer to store and validate token binding metadata. + MetadataKeyTokenSalt = "vmcp.token.salt" //nolint:gosec // This is a metadata key name, not a credential. +) + // Token binding errors returned by Caller methods when caller identity // validation fails. var (