From d8635fb54f5abab0e20cce550854b27cfca784e3 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 30 Apr 2026 12:52:47 +0000 Subject: [PATCH] feat(proxy): inject session ID and sequence number headers on matching requests When session correlation is enabled and the outgoing request matches a configured inject target, set X-Coder-Agent-Firewall-Session-Id and X-Coder-Agent-Firewall-Sequence-Number headers on the forwarded request. Any values the jailed client may have set are overwritten so the upstream always sees boundary's authoritative session ID and sequence number. The sequence number is pre-allocated before the audit event so both the audit log and the injected header carry the same value. audit.Request gains a SequenceNumber pointer field; when non-nil the socket auditor uses it instead of calling its own counter. New proxy.Config fields: SessionCorrelation, SessionID, SequenceCounter. New Server method: shouldInjectHeaders (domain + optional path glob matching). Tests cover matched domain, unmatched domain, disabled injection, client-supplied header overwrite, path glob matching, and sequence number incrementing. --- audit/request.go | 7 + audit/socket_auditor.go | 9 +- proxy/proxy.go | 102 +++++++-- proxy/proxy_framework_test.go | 59 +++-- proxy/proxy_session_correlation_test.go | 273 ++++++++++++++++++++++++ 5 files changed, 414 insertions(+), 36 deletions(-) create mode 100644 proxy/proxy_session_correlation_test.go diff --git a/audit/request.go b/audit/request.go index c6ef1b37..e23f774e 100644 --- a/audit/request.go +++ b/audit/request.go @@ -11,4 +11,11 @@ type Request struct { Host string Allowed bool Rule string // The rule that matched (if any) + + // SequenceNumber is a pre-allocated sequence number for this + // audit event. When non-nil the auditor must use this value + // instead of generating its own so that the audit log and + // any injected HTTP header carry the same number. When nil + // the auditor falls back to its internal SequenceCounter. + SequenceNumber *uint64 } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index afb50a34..06fce950 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -81,10 +81,17 @@ func (s *SocketAuditor) AuditRequest(req Request) { httpReq.MatchedRule = req.Rule } + var seqNum uint64 + if req.SequenceNumber != nil { + seqNum = *req.SequenceNumber + } else { + seqNum = s.seq.Next() + } + log := &agentproto.BoundaryLog{ Allowed: req.Allowed, Time: timestamppb.Now(), - SequenceNumber: s.seq.Next(), + SequenceNumber: seqNum, Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq}, } diff --git a/proxy/proxy.go b/proxy/proxy.go index 154a6bcc..0fb5dfa7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,23 +13,28 @@ import ( "net/http" _ "net/http/pprof" "net/url" + "path" "strconv" "strings" "sync/atomic" "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" ) // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rulesengine.Engine - auditor audit.Auditor - logger *slog.Logger - tlsConfig *tls.Config - httpPort int - started atomic.Bool + ruleEngine rulesengine.Engine + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + started atomic.Bool + sessionCorrelation config.SessionCorrelationConfig + sessionID string + seqCounter *audit.SequenceCounter listener net.Listener pprofServer *http.Server @@ -46,18 +51,30 @@ type Config struct { TLSConfig *tls.Config PprofEnabled bool PprofPort int + // SessionCorrelation controls header injection for AI Bridge + // correlation. See config.SessionCorrelationConfig for details. + SessionCorrelation config.SessionCorrelationConfig + // SessionID is the boundary session UUID injected as a header + // on matching requests. + SessionID string + // SequenceCounter provides monotonically increasing sequence + // numbers shared with the auditor so both carry the same value. + SequenceCounter *audit.SequenceCounter } // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, - pprofEnabled: config.PprofEnabled, - pprofPort: config.PprofPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, + sessionCorrelation: config.SessionCorrelation, + sessionID: config.SessionID, + seqCounter: config.SequenceCounter, } } @@ -276,12 +293,21 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool result := p.ruleEngine.Evaluate(req.Method, fullURL) + // Pre-allocate a sequence number so the audit event and any + // injected header carry the same value. + var seqNum *uint64 + if p.seqCounter != nil { + n := p.seqCounter.Next() + seqNum = &n + } + p.auditor.AuditRequest(audit.Request{ - Method: req.Method, - URL: fullURL, - Host: req.Host, - Allowed: result.Allowed, - Rule: result.Rule, + Method: req.Method, + URL: fullURL, + Host: req.Host, + Allowed: result.Allowed, + Rule: result.Rule, + SequenceNumber: seqNum, }) if !result.Allowed { @@ -290,10 +316,36 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool } // Forward request to destination - p.forwardRequest(conn, req, https) + p.forwardRequest(conn, req, https, seqNum) +} + +// shouldInjectHeaders reports whether the request to the given host +// and path matches any configured inject target. When session +// correlation is disabled or no targets match it returns false. +func (p *Server) shouldInjectHeaders(host, reqPath string) bool { + if !p.sessionCorrelation.Enabled { + return false + } + // Strip port from host for matching (e.g. "example.com:443" -> "example.com"). + h := host + if i := strings.LastIndex(h, ":"); i != -1 { + h = h[:i] + } + for _, target := range p.sessionCorrelation.InjectTargets { + if !strings.EqualFold(target.Domain, h) { + continue + } + if target.Path == "" { + return true + } + if matched, _ := path.Match(target.Path, reqPath); matched { + return true + } + } + return false } -func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { +func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum *uint64) { // Create HTTP client client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -338,6 +390,16 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } } + // Stamp session correlation headers on matching requests, + // overwriting any value the jailed client may have set so the + // upstream always sees boundary's ID. + if p.shouldInjectHeaders(req.Host, req.URL.Path) { + newReq.Header.Set(p.sessionCorrelation.SessionIDHeaderName, p.sessionID) + if seqNum != nil { + newReq.Header.Set(p.sessionCorrelation.SequenceNumberHeaderName, strconv.FormatUint(*seqNum, 10)) + } + } + // Make request to destination resp, err := client.Do(newReq) if err != nil { diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index 36a332bc..97c9975f 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" boundary_tls "github.com/coder/boundary/tls" "github.com/stretchr/testify/require" @@ -32,16 +33,19 @@ func (m *mockAuditor) AuditRequest(req audit.Request) { // ProxyTest is a high-level test framework for proxy tests type ProxyTest struct { - t *testing.T - server *Server - client *http.Client - proxyClient *http.Client - port int - useCertManager bool - configDir string - startupDelay time.Duration - allowedRules []string - auditor audit.Auditor + t *testing.T + server *Server + client *http.Client + proxyClient *http.Client + port int + useCertManager bool + configDir string + startupDelay time.Duration + allowedRules []string + auditor audit.Auditor + sessionCorrelation config.SessionCorrelationConfig + sessionID string + seqCounter *audit.SequenceCounter } // ProxyTestOption is a function that configures ProxyTest @@ -109,6 +113,28 @@ func WithAuditor(auditor audit.Auditor) ProxyTestOption { } } +// WithSessionCorrelation sets the session correlation config for the +// proxy under test. +func WithSessionCorrelation(sc config.SessionCorrelationConfig) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionCorrelation = sc + } +} + +// WithSessionID sets the boundary session ID for the proxy under test. +func WithSessionID(id string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionID = id + } +} + +// WithSequenceCounter sets the sequence counter for the proxy under test. +func WithSequenceCounter(seq *audit.SequenceCounter) ProxyTestOption { + return func(pt *ProxyTest) { + pt.seqCounter = seq + } +} + // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -153,11 +179,14 @@ func (pt *ProxyTest) Start() *ProxyTest { } pt.server = NewProxyServer(Config{ - HTTPPort: pt.port, - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, + HTTPPort: pt.port, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + SessionCorrelation: pt.sessionCorrelation, + SessionID: pt.sessionID, + SequenceCounter: pt.seqCounter, }) err = pt.server.Start() diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go new file mode 100644 index 00000000..40a419e7 --- /dev/null +++ b/proxy/proxy_session_correlation_test.go @@ -0,0 +1,273 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// headerCapturingBackend spins up an httptest.Server that records the +// headers it receives. Call receivedHeaders after the request to inspect +// them. +type headerCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + headers http.Header +} + +func newHeaderCapturingBackend() *headerCapturingBackend { + hcb := &headerCapturingBackend{} + hcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hcb.mu.Lock() + hcb.headers = r.Header.Clone() + hcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return hcb +} + +func (h *headerCapturingBackend) close() { h.server.Close() } + +func (h *headerCapturingBackend) receivedHeaders() http.Header { + h.mu.Lock() + defer h.mu.Unlock() + return h.headers.Clone() +} + +func TestSessionCorrelation_MatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id-1234", got.Get(config.DefaultSessionIDHeaderName), + "session ID header must be injected on matching domain") + assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must start at 0") +} + +func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "session ID header must not be injected on unmatched domain") + assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must not be injected on unmatched domain") +} + +func TestSessionCorrelation_Disabled(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "session ID header must not be injected when correlation is disabled") + assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must not be injected when correlation is disabled") +} + +func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("real-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + // Send a request with client-supplied session correlation headers + // that should be overwritten by the proxy. + req, err := http.NewRequest(http.MethodGet, backend.server.URL+"/api/v2", nil) + require.NoError(t, err) + req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session-id") + req.Header.Set(config.DefaultSequenceNumberHeaderName, "99999") + + resp, err := pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "real-session-id", got.Get(config.DefaultSessionIDHeaderName), + "proxy must overwrite client-supplied session ID header") + assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + "proxy must overwrite client-supplied sequence number header") +} + +func TestSessionCorrelation_PathMatching(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: backendURL.Hostname(), + Path: "/api/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + t.Run("matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id", got.Get(config.DefaultSessionIDHeaderName), + "header must be injected when path matches") + }) + + t.Run("non-matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/other/path") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "header must not be injected when path does not match") + }) +} + +func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + for i, expected := range []string{"0", "1", "2"} { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, expected, got.Get(config.DefaultSequenceNumberHeaderName), + "request %d: sequence number must be %s", i, expected) + } +}