From c24a0fd6790056c8bad86fefc62282c1cfa3d88c Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 30 Apr 2026 09:55:30 +0000 Subject: [PATCH 1/3] feat(config): session correlation header injection configuration Add YAML and CLI configuration surface for session correlation header injection per the Bridge/Boundaries Correlation RFC (FR 2). New configuration options: - --enable-session-correlation / session_correlation_enabled: top-level toggle to disable injection entirely for deployments without AI Bridge in front. - --inject-session-id-on / session_id_inject_targets (YAML): repeatable list of inject targets in "domain=... [path=...]" format. - --session-id-header-name / session_id_header_name: configurable header name (default X-Coder-Agent-Firewall-Session-Id). - --sequence-number-header-name / sequence_number_header_name: configurable header name (default X-Coder-Agent-Firewall-Sequence-Number). Config validation ensures that when correlation is enabled at least one inject target is present and header names are non-empty. Parsing validates the domain=... path=... key-value format and rejects unknown keys. This commit adds config and validation only; runtime injection is wired in a follow-up PR. --- cli/cli.go | 37 ++++ config/config.go | 58 ++++++ config/session_correlation.go | 105 ++++++++++ config/session_correlation_test.go | 314 +++++++++++++++++++++++++++++ 4 files changed, 514 insertions(+) create mode 100644 config/session_correlation.go create mode 100644 config/session_correlation_test.go diff --git a/cli/cli.go b/cli/cli.go index 7d1567a..a3f2de8 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -169,6 +169,43 @@ func BaseCommand(version string) *serpent.Command { Value: &showVersion, YAML: "", // CLI only }, + // Session correlation header injection options. + { + Flag: "enable-session-correlation", + Env: "BOUNDARY_SESSION_CORRELATION_ENABLED", + Description: "Enable session correlation header injection. 
Disable for deployments without AI Bridge in front.", + Value: &cliConfig.SessionCorrelationEnabled, + YAML: "session_correlation_enabled", + }, + { + Flag: "inject-session-id-on", + Env: "BOUNDARY_INJECT_SESSION_ID_ON", + Description: `Inject target (repeatable). Requests matching these targets receive session correlation headers. Format: "domain= [path=]".`, + Value: &cliConfig.InjectSessionIDOn, + YAML: "", // CLI only, YAML uses session_id_inject_targets. + }, + { + Flag: "", // No CLI flag, YAML only. + Description: "Inject targets from config file (YAML only).", + Value: &cliConfig.InjectSessionIDOnYAML, + YAML: "session_id_inject_targets", + }, + { + Flag: "session-id-header-name", + Env: "BOUNDARY_SESSION_ID_HEADER_NAME", + Description: "HTTP header name for the boundary session ID.", + Default: config.DefaultSessionIDHeaderName, + Value: &cliConfig.SessionIDHeaderName, + YAML: "session_id_header_name", + }, + { + Flag: "sequence-number-header-name", + Env: "BOUNDARY_SEQUENCE_NUMBER_HEADER_NAME", + Description: "HTTP header name for the boundary sequence number.", + Default: config.DefaultSequenceNumberHeaderName, + Value: &cliConfig.SequenceNumberHeaderName, + YAML: "sequence_number_header_name", + }, }, Handler: func(inv *serpent.Invocation) error { // Handle --version flag early diff --git a/config/config.go b/config/config.go index 229cf35..6098ece 100644 --- a/config/config.go +++ b/config/config.go @@ -69,6 +69,13 @@ type CliConfig struct { NoUserNamespace serpent.Bool `yaml:"no_user_namespace"` DisableAuditLogs serpent.Bool `yaml:"disable_audit_logs"` LogProxySocketPath serpent.String `yaml:"log_proxy_socket_path"` + + // Session correlation header injection. 
+ SessionCorrelationEnabled serpent.Bool `yaml:"session_correlation_enabled"` + InjectSessionIDOn AllowStringsArray `yaml:"inject_session_id_on"` + InjectSessionIDOnYAML serpent.StringArray `yaml:"session_id_inject_targets"` + SessionIDHeaderName serpent.String `yaml:"session_id_header_name"` + SequenceNumberHeaderName serpent.String `yaml:"sequence_number_header_name"` } type AppConfig struct { @@ -86,6 +93,10 @@ type AppConfig struct { DisableAuditLogs bool LogProxySocketPath string + // SessionCorrelation controls header injection for AI Bridge + // correlation. See SessionCorrelationConfig for details. + SessionCorrelation SessionCorrelationConfig + // SessionID is a UUIDv4 generated at process startup. It groups // all audit events produced by this boundary invocation into a // single session. Set by Run, not by configuration. @@ -107,6 +118,12 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, er userInfo := GetUserInfo() + // Build session correlation config from CLI and YAML sources. + sc, err := buildSessionCorrelation(cfg) + if err != nil { + return AppConfig{}, fmt.Errorf("session correlation config: %w", err) + } + return AppConfig{ AllowRules: allAllowStrings, LogLevel: cfg.LogLevel.Value(), @@ -121,5 +138,46 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, er UserInfo: userInfo, DisableAuditLogs: cfg.DisableAuditLogs.Value(), LogProxySocketPath: cfg.LogProxySocketPath.Value(), + SessionCorrelation: sc, }, nil } + +// buildSessionCorrelation merges CLI and YAML inject target sources, +// parses each target string, applies header name defaults, and +// validates the resulting configuration. +func buildSessionCorrelation(cfg CliConfig) (SessionCorrelationConfig, error) { + // Merge YAML targets with CLI targets. + rawTargets := append(cfg.InjectSessionIDOnYAML.Value(), cfg.InjectSessionIDOn.Value()...) 
+ + var targets []InjectTarget + for _, raw := range rawTargets { + t, err := ParseInjectTarget(raw) + if err != nil { + return SessionCorrelationConfig{}, err + } + targets = append(targets, t) + } + + // Apply defaults for header names. + sessionIDHeader := cfg.SessionIDHeaderName.Value() + if sessionIDHeader == "" { + sessionIDHeader = DefaultSessionIDHeaderName + } + seqHeader := cfg.SequenceNumberHeaderName.Value() + if seqHeader == "" { + seqHeader = DefaultSequenceNumberHeaderName + } + + sc := SessionCorrelationConfig{ + Enabled: cfg.SessionCorrelationEnabled.Value(), + InjectTargets: targets, + SessionIDHeaderName: sessionIDHeader, + SequenceNumberHeaderName: seqHeader, + } + + if err := ValidateSessionCorrelation(sc); err != nil { + return SessionCorrelationConfig{}, err + } + + return sc, nil +} diff --git a/config/session_correlation.go b/config/session_correlation.go new file mode 100644 index 0000000..a67b02e --- /dev/null +++ b/config/session_correlation.go @@ -0,0 +1,105 @@ +package config + +import ( + "fmt" + "strings" +) + +// Default header names for session correlation. +const ( + DefaultSessionIDHeaderName = "X-Coder-Agent-Firewall-Session-Id" + DefaultSequenceNumberHeaderName = "X-Coder-Agent-Firewall-Sequence-Number" +) + +// InjectTarget represents a parsed target for session correlation header +// injection. Requests matching the domain (and optional path glob) will +// receive the session ID and sequence number headers. +type InjectTarget struct { + Domain string + Path string +} + +// SessionCorrelationConfig holds configuration for session correlation +// header injection. When enabled, boundary injects its session ID and +// sequence number as custom headers on matching outbound requests so +// that an upstream AI Bridge can correlate the request back to the +// boundary audit event stream. +type SessionCorrelationConfig struct { + // Enabled controls whether session correlation headers are injected. 
+ // Deployments without AI Bridge in front should set this to false. + Enabled bool + + // InjectTargets is the list of domain/path patterns that should + // receive session correlation headers. + InjectTargets []InjectTarget + + // SessionIDHeaderName is the HTTP header name used to carry the + // boundary session ID. Defaults to DefaultSessionIDHeaderName. + SessionIDHeaderName string + + // SequenceNumberHeaderName is the HTTP header name used to carry + // the boundary sequence number. Defaults to + // DefaultSequenceNumberHeaderName. + SequenceNumberHeaderName string +} + +// ParseInjectTarget parses a string of the form "domain=... path=..." +// into an InjectTarget. The domain key is required; path is optional. +func ParseInjectTarget(raw string) (InjectTarget, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return InjectTarget{}, fmt.Errorf("inject target must not be empty") + } + + var target InjectTarget + for _, part := range strings.Fields(raw) { + key, value, ok := strings.Cut(part, "=") + if !ok { + return InjectTarget{}, fmt.Errorf( + "inject target: malformed key-value pair %q, expected key=value", part, + ) + } + switch key { + case "domain": + if value == "" { + return InjectTarget{}, fmt.Errorf("inject target: domain must not be empty") + } + target.Domain = value + case "path": + target.Path = value + default: + return InjectTarget{}, fmt.Errorf("inject target: unknown key %q", key) + } + } + + if target.Domain == "" { + return InjectTarget{}, fmt.Errorf("inject target: domain is required") + } + + return target, nil +} + +// ValidateSessionCorrelation checks that the session correlation config +// is internally consistent. It returns an error describing the first +// problem found, or nil if the config is valid. 
+func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { + if !cfg.Enabled { + return nil + } + + if len(cfg.InjectTargets) == 0 { + return fmt.Errorf( + "session correlation is enabled but no inject targets are configured", + ) + } + + if cfg.SessionIDHeaderName == "" { + return fmt.Errorf("session-id-header-name must not be empty when session correlation is enabled") + } + + if cfg.SequenceNumberHeaderName == "" { + return fmt.Errorf("sequence-number-header-name must not be empty when session correlation is enabled") + } + + return nil +} diff --git a/config/session_correlation_test.go b/config/session_correlation_test.go new file mode 100644 index 0000000..eb0605a --- /dev/null +++ b/config/session_correlation_test.go @@ -0,0 +1,314 @@ +package config + +import ( + "testing" +) + +func TestParseInjectTarget(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want InjectTarget + wantErr bool + }{ + { + name: "domain only", + input: "domain=dev.coder.com", + want: InjectTarget{Domain: "dev.coder.com"}, + }, + { + name: "domain and path", + input: "domain=dev.coder.com path=/api/v2/aibridge/*", + want: InjectTarget{Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, + }, + { + name: "leading and trailing whitespace", + input: " domain=dev.coder.com path=/api/* ", + want: InjectTarget{Domain: "dev.coder.com", Path: "/api/*"}, + }, + { + name: "empty string", + input: "", + wantErr: true, + }, + { + name: "whitespace only", + input: " ", + wantErr: true, + }, + { + name: "missing domain", + input: "path=/api/*", + wantErr: true, + }, + { + name: "empty domain value", + input: "domain=", + wantErr: true, + }, + { + name: "malformed pair no equals", + input: "domain", + wantErr: true, + }, + { + name: "unknown key", + input: "domain=example.com port=443", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := ParseInjectTarget(tc.input) + if 
tc.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Domain != tc.want.Domain { + t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) + } + if got.Path != tc.want.Path { + t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) + } + }) + } +} + +func TestValidateSessionCorrelation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg SessionCorrelationConfig + wantErr bool + }{ + { + name: "disabled is always valid", + cfg: SessionCorrelationConfig{ + Enabled: false, + }, + }, + { + name: "disabled with empty targets is valid", + cfg: SessionCorrelationConfig{ + Enabled: false, + InjectTargets: nil, + }, + }, + { + name: "enabled with targets and default headers", + cfg: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []InjectTarget{{Domain: "dev.coder.com"}}, + SessionIDHeaderName: DefaultSessionIDHeaderName, + SequenceNumberHeaderName: DefaultSequenceNumberHeaderName, + }, + }, + { + name: "enabled with custom headers", + cfg: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []InjectTarget{{Domain: "example.com", Path: "/api/*"}}, + SessionIDHeaderName: "X-Custom-Session", + SequenceNumberHeaderName: "X-Custom-Seq", + }, + }, + { + name: "enabled with no targets", + cfg: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: nil, + SessionIDHeaderName: DefaultSessionIDHeaderName, + SequenceNumberHeaderName: DefaultSequenceNumberHeaderName, + }, + wantErr: true, + }, + { + name: "enabled with empty targets slice", + cfg: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []InjectTarget{}, + SessionIDHeaderName: DefaultSessionIDHeaderName, + SequenceNumberHeaderName: DefaultSequenceNumberHeaderName, + }, + wantErr: true, + }, + { + name: "enabled with empty session id header", + cfg: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []InjectTarget{{Domain: "example.com"}}, + 
SessionIDHeaderName: "", + SequenceNumberHeaderName: DefaultSequenceNumberHeaderName, + }, + wantErr: true, + }, + { + name: "enabled with empty sequence number header", + cfg: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []InjectTarget{{Domain: "example.com"}}, + SessionIDHeaderName: DefaultSessionIDHeaderName, + SequenceNumberHeaderName: "", + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := ValidateSessionCorrelation(tc.cfg) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cli CliConfig + want SessionCorrelationConfig + wantErr bool + }{ + { + name: "defaults when not configured", + cli: baseCliConfig(), + want: SessionCorrelationConfig{ + Enabled: false, + InjectTargets: nil, + SessionIDHeaderName: DefaultSessionIDHeaderName, + SequenceNumberHeaderName: DefaultSequenceNumberHeaderName, + }, + }, + { + name: "enabled with inject targets", + cli: func() CliConfig { + c := baseCliConfig() + c.SessionCorrelationEnabled.Set("true") + _ = c.InjectSessionIDOn.Set("domain=dev.coder.com path=/api/v2/aibridge/*") + return c + }(), + want: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []InjectTarget{ + {Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, + }, + SessionIDHeaderName: DefaultSessionIDHeaderName, + SequenceNumberHeaderName: DefaultSequenceNumberHeaderName, + }, + }, + { + name: "custom header names", + cli: func() CliConfig { + c := baseCliConfig() + c.SessionCorrelationEnabled.Set("true") + _ = c.InjectSessionIDOn.Set("domain=example.com") + c.SessionIDHeaderName.Set("X-My-Session") + c.SequenceNumberHeaderName.Set("X-My-Seq") + return c + }(), + want: SessionCorrelationConfig{ + Enabled: true, + InjectTargets: 
[]InjectTarget{{Domain: "example.com"}}, + SessionIDHeaderName: "X-My-Session", + SequenceNumberHeaderName: "X-My-Seq", + }, + }, + { + name: "enabled with no targets fails validation", + cli: func() CliConfig { + c := baseCliConfig() + c.SessionCorrelationEnabled.Set("true") + return c + }(), + wantErr: true, + }, + { + name: "invalid inject target", + cli: func() CliConfig { + c := baseCliConfig() + c.SessionCorrelationEnabled.Set("true") + _ = c.InjectSessionIDOn.Set("notakey") + return c + }(), + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := NewAppConfigFromCliConfig(tc.cli, []string{"echo", "hello"}) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + sc := got.SessionCorrelation + if sc.Enabled != tc.want.Enabled { + t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.want.Enabled) + } + if sc.SessionIDHeaderName != tc.want.SessionIDHeaderName { + t.Errorf("SessionIDHeaderName: got %q, want %q", + sc.SessionIDHeaderName, tc.want.SessionIDHeaderName) + } + if sc.SequenceNumberHeaderName != tc.want.SequenceNumberHeaderName { + t.Errorf("SequenceNumberHeaderName: got %q, want %q", + sc.SequenceNumberHeaderName, tc.want.SequenceNumberHeaderName) + } + if len(sc.InjectTargets) != len(tc.want.InjectTargets) { + t.Fatalf("InjectTargets len: got %d, want %d", + len(sc.InjectTargets), len(tc.want.InjectTargets)) + } + for i := range sc.InjectTargets { + if sc.InjectTargets[i].Domain != tc.want.InjectTargets[i].Domain { + t.Errorf("InjectTargets[%d].Domain: got %q, want %q", + i, sc.InjectTargets[i].Domain, tc.want.InjectTargets[i].Domain) + } + if sc.InjectTargets[i].Path != tc.want.InjectTargets[i].Path { + t.Errorf("InjectTargets[%d].Path: got %q, want %q", + i, sc.InjectTargets[i].Path, tc.want.InjectTargets[i].Path) + } + } + }) + } +} + +// baseCliConfig returns a CliConfig with 
valid defaults for fields that +// NewAppConfigFromCliConfig requires, so tests can focus on the session +// correlation fields without tripping over unrelated validation. +func baseCliConfig() CliConfig { + c := CliConfig{} + c.JailType.Set("nsjail") + c.SessionIDHeaderName.Set(DefaultSessionIDHeaderName) + c.SequenceNumberHeaderName.Set(DefaultSequenceNumberHeaderName) + return c +} From 6e61348ba125268a156a36bad254dba6e1121a12 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 30 Apr 2026 12:52:47 +0000 Subject: [PATCH 2/3] feat(proxy): inject session ID and sequence number headers on matching requests When session correlation is enabled and the outgoing request matches a configured inject target, set X-Coder-Agent-Firewall-Session-Id and X-Coder-Agent-Firewall-Sequence-Number headers on the forwarded request. Any values the jailed client may have set are overwritten so the upstream always sees boundary's authoritative session ID and sequence number. The sequence number is pre-allocated before the audit event so both the audit log and the injected header carry the same value. audit.Request gains a SequenceNumber pointer field; when non-nil the socket auditor uses it instead of calling its own counter. New proxy.Config fields: SessionCorrelation, SessionID, SequenceCounter. New Server method: shouldInjectHeaders (domain + optional path glob matching). Tests cover matched domain, unmatched domain, disabled injection, client-supplied header overwrite, path glob matching, and sequence number incrementing. 
--- audit/request.go | 7 + audit/socket_auditor.go | 9 +- proxy/proxy.go | 102 +++++++-- proxy/proxy_framework_test.go | 59 +++-- proxy/proxy_session_correlation_test.go | 273 ++++++++++++++++++++++++ 5 files changed, 414 insertions(+), 36 deletions(-) create mode 100644 proxy/proxy_session_correlation_test.go diff --git a/audit/request.go b/audit/request.go index c6ef1b3..e23f774 100644 --- a/audit/request.go +++ b/audit/request.go @@ -11,4 +11,11 @@ type Request struct { Host string Allowed bool Rule string // The rule that matched (if any) + + // SequenceNumber is a pre-allocated sequence number for this + // audit event. When non-nil the auditor must use this value + // instead of generating its own so that the audit log and + // any injected HTTP header carry the same number. When nil + // the auditor falls back to its internal SequenceCounter. + SequenceNumber *uint64 } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index 9a613fc..c7c2602 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -80,10 +80,17 @@ func (s *SocketAuditor) AuditRequest(req Request) { httpReq.MatchedRule = req.Rule } + var seqNum uint64 + if req.SequenceNumber != nil { + seqNum = *req.SequenceNumber + } else { + seqNum = s.seq.Next() + } + log := &agentproto.BoundaryLog{ Allowed: req.Allowed, Time: timestamppb.Now(), - SequenceNumber: s.seq.Next(), + SequenceNumber: seqNum, Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq}, } diff --git a/proxy/proxy.go b/proxy/proxy.go index 154a6bc..0fb5dfa 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,23 +13,28 @@ import ( "net/http" _ "net/http/pprof" "net/url" + "path" "strconv" "strings" "sync/atomic" "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" ) // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rulesengine.Engine - auditor audit.Auditor - logger *slog.Logger - 
tlsConfig *tls.Config - httpPort int - started atomic.Bool + ruleEngine rulesengine.Engine + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + started atomic.Bool + sessionCorrelation config.SessionCorrelationConfig + sessionID string + seqCounter *audit.SequenceCounter listener net.Listener pprofServer *http.Server @@ -46,18 +51,30 @@ type Config struct { TLSConfig *tls.Config PprofEnabled bool PprofPort int + // SessionCorrelation controls header injection for AI Bridge + // correlation. See config.SessionCorrelationConfig for details. + SessionCorrelation config.SessionCorrelationConfig + // SessionID is the boundary session UUID injected as a header + // on matching requests. + SessionID string + // SequenceCounter provides monotonically increasing sequence + // numbers shared with the auditor so both carry the same value. + SequenceCounter *audit.SequenceCounter } // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, - pprofEnabled: config.PprofEnabled, - pprofPort: config.PprofPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, + sessionCorrelation: config.SessionCorrelation, + sessionID: config.SessionID, + seqCounter: config.SequenceCounter, } } @@ -276,12 +293,21 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool result := p.ruleEngine.Evaluate(req.Method, fullURL) + // Pre-allocate a sequence number so the audit event and any + // injected header carry the same value. 
+ var seqNum *uint64 + if p.seqCounter != nil { + n := p.seqCounter.Next() + seqNum = &n + } + p.auditor.AuditRequest(audit.Request{ - Method: req.Method, - URL: fullURL, - Host: req.Host, - Allowed: result.Allowed, - Rule: result.Rule, + Method: req.Method, + URL: fullURL, + Host: req.Host, + Allowed: result.Allowed, + Rule: result.Rule, + SequenceNumber: seqNum, }) if !result.Allowed { @@ -290,10 +316,36 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool } // Forward request to destination - p.forwardRequest(conn, req, https) + p.forwardRequest(conn, req, https, seqNum) +} + +// shouldInjectHeaders reports whether the request to the given host +// and path matches any configured inject target. When session +// correlation is disabled or no targets match it returns false. +func (p *Server) shouldInjectHeaders(host, reqPath string) bool { + if !p.sessionCorrelation.Enabled { + return false + } + // Strip port from host for matching (e.g. "example.com:443" -> "example.com"). + h := host + if i := strings.LastIndex(h, ":"); i != -1 { + h = h[:i] + } + for _, target := range p.sessionCorrelation.InjectTargets { + if !strings.EqualFold(target.Domain, h) { + continue + } + if target.Path == "" { + return true + } + if matched, _ := path.Match(target.Path, reqPath); matched { + return true + } + } + return false } -func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { +func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum *uint64) { // Create HTTP client client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -338,6 +390,16 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } } + // Stamp session correlation headers on matching requests, + // overwriting any value the jailed client may have set so the + // upstream always sees boundary's ID. 
+ if p.shouldInjectHeaders(req.Host, req.URL.Path) { + newReq.Header.Set(p.sessionCorrelation.SessionIDHeaderName, p.sessionID) + if seqNum != nil { + newReq.Header.Set(p.sessionCorrelation.SequenceNumberHeaderName, strconv.FormatUint(*seqNum, 10)) + } + } + // Make request to destination resp, err := client.Do(newReq) if err != nil { diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index 36a332b..97c9975 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" boundary_tls "github.com/coder/boundary/tls" "github.com/stretchr/testify/require" @@ -32,16 +33,19 @@ func (m *mockAuditor) AuditRequest(req audit.Request) { // ProxyTest is a high-level test framework for proxy tests type ProxyTest struct { - t *testing.T - server *Server - client *http.Client - proxyClient *http.Client - port int - useCertManager bool - configDir string - startupDelay time.Duration - allowedRules []string - auditor audit.Auditor + t *testing.T + server *Server + client *http.Client + proxyClient *http.Client + port int + useCertManager bool + configDir string + startupDelay time.Duration + allowedRules []string + auditor audit.Auditor + sessionCorrelation config.SessionCorrelationConfig + sessionID string + seqCounter *audit.SequenceCounter } // ProxyTestOption is a function that configures ProxyTest @@ -109,6 +113,28 @@ func WithAuditor(auditor audit.Auditor) ProxyTestOption { } } +// WithSessionCorrelation sets the session correlation config for the +// proxy under test. +func WithSessionCorrelation(sc config.SessionCorrelationConfig) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionCorrelation = sc + } +} + +// WithSessionID sets the boundary session ID for the proxy under test. 
+func WithSessionID(id string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionID = id + } +} + +// WithSequenceCounter sets the sequence counter for the proxy under test. +func WithSequenceCounter(seq *audit.SequenceCounter) ProxyTestOption { + return func(pt *ProxyTest) { + pt.seqCounter = seq + } +} + // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -153,11 +179,14 @@ func (pt *ProxyTest) Start() *ProxyTest { } pt.server = NewProxyServer(Config{ - HTTPPort: pt.port, - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, + HTTPPort: pt.port, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + SessionCorrelation: pt.sessionCorrelation, + SessionID: pt.sessionID, + SequenceCounter: pt.seqCounter, }) err = pt.server.Start() diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go new file mode 100644 index 0000000..40a419e --- /dev/null +++ b/proxy/proxy_session_correlation_test.go @@ -0,0 +1,273 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// headerCapturingBackend spins up an httptest.Server that records the +// headers it receives. Call receivedHeaders after the request to inspect +// them. 
+type headerCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + headers http.Header +} + +func newHeaderCapturingBackend() *headerCapturingBackend { + hcb := &headerCapturingBackend{} + hcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hcb.mu.Lock() + hcb.headers = r.Header.Clone() + hcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return hcb +} + +func (h *headerCapturingBackend) close() { h.server.Close() } + +func (h *headerCapturingBackend) receivedHeaders() http.Header { + h.mu.Lock() + defer h.mu.Unlock() + return h.headers.Clone() +} + +func TestSessionCorrelation_MatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id-1234", got.Get(config.DefaultSessionIDHeaderName), + "session ID header must be injected on matching domain") + assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must start at 0") +} + +func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := 
url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "session ID header must not be injected on unmatched domain") + assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must not be injected on unmatched domain") +} + +func TestSessionCorrelation_Disabled(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := 
backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "session ID header must not be injected when correlation is disabled") + assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must not be injected when correlation is disabled") +} + +func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("real-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + // Send a request with client-supplied session correlation headers + // that should be overwritten by the proxy. 
+ req, err := http.NewRequest(http.MethodGet, backend.server.URL+"/api/v2", nil) + require.NoError(t, err) + req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session-id") + req.Header.Set(config.DefaultSequenceNumberHeaderName, "99999") + + resp, err := pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "real-session-id", got.Get(config.DefaultSessionIDHeaderName), + "proxy must overwrite client-supplied session ID header") + assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + "proxy must overwrite client-supplied sequence number header") +} + +func TestSessionCorrelation_PathMatching(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: backendURL.Hostname(), + Path: "/api/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + t.Run("matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id", got.Get(config.DefaultSessionIDHeaderName), + "header must be injected when path matches") + }) + + t.Run("non-matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/other/path") + 
require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "header must not be injected when path does not match") + }) +} + +func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + for i, expected := range []string{"0", "1", "2"} { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, expected, got.Get(config.DefaultSequenceNumberHeaderName), + "request %d: sequence number must be %s", i, expected) + } +} From ae0177608673125ac3c90f42a504de767683cf7d Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Fri, 1 May 2026 12:36:05 +0000 Subject: [PATCH 3/3] test(proxy): integration tests for session correlation audit and header agreement Add integration tests that verify the core invariants of session correlation across the proxy, auditor, and forwarded request headers working together. These tests fill the gap identified during review of the session correlation PR stack (#196, #197, #198) where unit tests verified each component in isolation but did not verify them in concert. 
New test file: proxy/proxy_session_correlation_integration_test.go Tests added: - LLMRequestAuditAndHeadersAgree: audit sequence number matches the forwarded header value on inject-target requests. - NonLLMRequestAuditedWithoutHeaders: allowed non-inject-target requests are audited but carry no correlation headers. - DeniedRequestAuditedNeverForwarded: denied requests consume a sequence number but are never forwarded. - MixedRequestsSequenceOrdering: interleaved LLM, non-LLM, and denied requests all advance the counter monotonically. - SequenceGapRevealsAgenticLoop: gap between two LLM sequence numbers precisely equals intermediate tool-use requests. - SpoofedHeadersOverwrittenWithCorrectSequence: client-supplied headers are replaced and the audit event still agrees. - DisabledCorrelationNoHeadersNoPreallocatedSequence: disabled correlation means no headers and no pre-allocated sequence. - ConcurrentRequestsUniqueSequenceNumbers: concurrent requests each get a unique, dense sequence number. --- ...xy_session_correlation_integration_test.go | 578 ++++++++++++++++++ 1 file changed, 578 insertions(+) create mode 100644 proxy/proxy_session_correlation_integration_test.go diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go new file mode 100644 index 0000000..b8035b7 --- /dev/null +++ b/proxy/proxy_session_correlation_integration_test.go @@ -0,0 +1,578 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// multiRequestCapturingBackend records the headers from every request it +// receives, not just the last one. This is needed by integration tests +// that send multiple requests to the same backend and want to verify +// each one independently. 
+type multiRequestCapturingBackend struct {
+	server *httptest.Server
+	mu     sync.Mutex    // guards all
+	all    []http.Header // one cloned header set per request, in arrival order
+}
+
+// newMultiRequestCapturingBackend starts an httptest server that stores
+// a copy of each incoming request's headers and replies 200 OK.
+func newMultiRequestCapturingBackend() *multiRequestCapturingBackend {
+	backend := new(multiRequestCapturingBackend)
+	record := func(w http.ResponseWriter, r *http.Request) {
+		snapshot := r.Header.Clone()
+		backend.mu.Lock()
+		backend.all = append(backend.all, snapshot)
+		backend.mu.Unlock()
+		w.WriteHeader(http.StatusOK)
+	}
+	backend.server = httptest.NewServer(http.HandlerFunc(record))
+	return backend
+}
+
+// close shuts down the underlying httptest server.
+func (m *multiRequestCapturingBackend) close() {
+	m.server.Close()
+}
+
+// requestCount reports how many requests have been recorded so far.
+func (m *multiRequestCapturingBackend) requestCount() int {
+	m.mu.Lock()
+	n := len(m.all)
+	m.mu.Unlock()
+	return n
+}
+
+// headersAt returns a copy of the headers of the i-th recorded request.
+// It panics if i is out of range, so callers check requestCount first.
+func (m *multiRequestCapturingBackend) headersAt(i int) http.Header {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	recorded := m.all[i]
+	return recorded.Clone()
+}
+
+// sessionCorrelationIntegrationSetup bundles everything an integration
+// test needs to inspect: the proxy under test, the auditor, the shared
+// sequence counter, and the backend(s). Tests obtain one from
+// newSessionCorrelationIntegrationSetup and release it with stop.
+type sessionCorrelationIntegrationSetup struct {
+	pt           *ProxyTest
+	auditor      *capturingAuditor
+	seq          *audit.SequenceCounter
+	llmBackend   *multiRequestCapturingBackend
+	otherBackend *multiRequestCapturingBackend
+}
+
+// stop halts the proxy first and then closes whichever backends are set.
+func (s *sessionCorrelationIntegrationSetup) stop() {
+	s.pt.Stop()
+	for _, b := range []*multiRequestCapturingBackend{s.llmBackend, s.otherBackend} {
+		if b == nil {
+			continue
+		}
+		b.close()
+	}
+}
+
+// newSessionCorrelationIntegrationSetup builds a proxy that allows
+// traffic to two httptest backends: one that matches an inject target
+// (simulating an LLM provider) and one that does not (simulating a
+// generic allowed domain like github.com). Both backends capture all
+// received request headers. A capturingAuditor records every audit
+// event for later inspection.
+func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup {
+	t.Helper()
+
+	setup := &sessionCorrelationIntegrationSetup{
+		auditor:      &capturingAuditor{},
+		seq:          &audit.SequenceCounter{},
+		llmBackend:   newMultiRequestCapturingBackend(),
+		otherBackend: newMultiRequestCapturingBackend(),
+	}
+
+	llmURL, err := url.Parse(setup.llmBackend.server.URL)
+	require.NoError(t, err)
+	llmHost := llmURL.Hostname()
+
+	otherURL, err := url.Parse(setup.otherBackend.server.URL)
+	require.NoError(t, err)
+	otherHost := otherURL.Hostname()
+
+	// The two httptest backends share the host 127.0.0.1, so matching
+	// on domain alone would inject into both. Restricting the target to
+	// the /v1/* path glob keeps injection on the LLM-style requests.
+	correlation := config.SessionCorrelationConfig{
+		Enabled: true,
+		InjectTargets: []config.InjectTarget{{
+			Domain: llmHost,
+			Path:   "/v1/*",
+		}},
+		SessionIDHeaderName:      config.DefaultSessionIDHeaderName,
+		SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName,
+	}
+
+	setup.pt = NewProxyTest(t,
+		WithCertManager(t.TempDir()),
+		// Both backends are on the allow list.
+		WithAllowedDomain(llmHost),
+		WithAllowedDomain(otherHost),
+		// Headers go only to requests that match the inject target.
+		WithSessionCorrelation(correlation),
+		WithSessionID(sessionID),
+		WithSequenceCounter(setup.seq),
+		WithAuditor(setup.auditor),
+	).Start()
+
+	return setup
+}
+
+// ---------- Integration Tests ----------
+
+// TestIntegration_LLMRequestAuditAndHeadersAgree verifies the core
+// correlation invariant: when an allowed request hits an inject target,
+// the sequence number in the audit event equals the sequence number in
+// the forwarded header.
+func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) {
+	const sessionID = "e5f6a7b8-0000-0000-0000-000000000000"
+	s := newSessionCorrelationIntegrationSetup(t, sessionID)
+	defer s.stop()
+
+	resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages")
+	require.NoError(t, err)
+	defer resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Audit event.
+	events := s.auditor.getRequests()
+	require.Len(t, events, 1)
+	require.True(t, events[0].Allowed)
+	require.NotNil(t, events[0].SequenceNumber)
+	assert.Equal(t, uint64(0), *events[0].SequenceNumber)
+
+	// Forwarded headers.
+	require.Equal(t, 1, s.llmBackend.requestCount())
+	hdr := s.llmBackend.headersAt(0)
+	assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName))
+	assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName))
+
+	// The two must agree.
+	assert.Equal(t,
+		strconv.FormatUint(*events[0].SequenceNumber, 10),
+		hdr.Get(config.DefaultSequenceNumberHeaderName),
+		"audit event and forwarded header must carry the same sequence number",
+	)
+}
+
+// TestIntegration_NonLLMRequestAuditedWithoutHeaders verifies that an
+// allowed request to a domain that is NOT an inject target still gets
+// audited (with a sequence number) but does NOT receive correlation
+// headers.
+func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) {
+	s := newSessionCorrelationIntegrationSetup(t, "test-session")
+	defer s.stop()
+
+	resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls")
+	require.NoError(t, err)
+	defer resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Audit event recorded.
+	events := s.auditor.getRequests()
+	require.Len(t, events, 1)
+	require.True(t, events[0].Allowed)
+	require.NotNil(t, events[0].SequenceNumber)
+	assert.Equal(t, uint64(0), *events[0].SequenceNumber)
+
+	// No correlation headers on the backend.
+	require.Equal(t, 1, s.otherBackend.requestCount())
+	hdr := s.otherBackend.headersAt(0)
+	assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName),
+		"non-inject-target requests must not carry session ID header")
+	assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName),
+		"non-inject-target requests must not carry sequence number header")
+}
+
+// TestIntegration_DeniedRequestAuditedNeverForwarded verifies that a
+// request denied by the rules engine is audited (consuming a sequence
+// number) but is never forwarded to any backend.
+func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) {
+	// Create a setup with a custom deny-all proxy, but keep the same
+	// pattern of shared sequence counter and auditor.
+	llm := newMultiRequestCapturingBackend()
+	defer llm.close()
+
+	aud := &capturingAuditor{}
+	seq := &audit.SequenceCounter{}
+
+	pt := NewProxyTest(t,
+		WithCertManager(t.TempDir()),
+		// No allowed domains: deny everything.
+		WithSessionCorrelation(config.SessionCorrelationConfig{
+			Enabled:                  true,
+			InjectTargets:            []config.InjectTarget{{Domain: "anything.example.com"}},
+			SessionIDHeaderName:      config.DefaultSessionIDHeaderName,
+			SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName,
+		}),
+		WithSessionID("test-session"),
+		WithSequenceCounter(seq),
+		WithAuditor(aud),
+	).Start()
+	defer pt.Stop()
+
+	resp, err := pt.proxyClient.Get(llm.server.URL + "/exfil")
+	require.NoError(t, err)
+	defer resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusForbidden, resp.StatusCode)
+
+	// Audit event recorded.
+	events := aud.getRequests()
+	require.Len(t, events, 1)
+	require.False(t, events[0].Allowed)
+	require.NotNil(t, events[0].SequenceNumber)
+	assert.Equal(t, uint64(0), *events[0].SequenceNumber)
+
+	// Backend never hit.
+	assert.Equal(t, 0, llm.requestCount(),
+		"denied requests must not be forwarded to the backend")
+}
+
+// TestIntegration_MixedRequestsSequenceOrdering sends a realistic
+// sequence of LLM, non-LLM, and denied requests, then verifies:
+//  1. Sequence numbers increase monotonically across all request types.
+//  2. Only inject-target requests carry correlation headers.
+//  3. The sequence numbers in headers match the audit events.
+//  4. The gap between two LLM requests' sequence numbers reveals the
+//     intermediate non-LLM and denied activity.
+func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) {
+	const sessionID = "mixed-test-session"
+
+	// The shared setup allows two backends (LLM and "github") and
+	// makes only /v1/* on the LLM backend an inject target.
+	// evil.example.com is not on the allow list, so it is denied.
+	s := newSessionCorrelationIntegrationSetup(t, sessionID)
+	defer s.stop()
+
+	// Request 0: LLM (allowed, inject target).
+	resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages")
+	require.NoError(t, err)
+	resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Request 1: non-LLM (allowed, no inject).
+	resp, err = s.pt.proxyClient.Get(s.otherBackend.server.URL + "/coder/coder")
+	require.NoError(t, err)
+	resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Request 2: denied (nothing is allowed for evil.example.com).
+	resp, err = s.pt.proxyClient.Get("http://evil.example.com/exfil")
+	require.NoError(t, err)
+	resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusForbidden, resp.StatusCode)
+
+	// Request 3: LLM again.
+	resp, err = s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages")
+	require.NoError(t, err)
+	resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// -- Verify audit events --
+	events := s.auditor.getRequests()
+	require.Len(t, events, 4, "expected exactly four audit events")
+
+	expectedSeq := []uint64{0, 1, 2, 3}
+	expectedAllowed := []bool{true, true, false, true}
+	for i, ev := range events {
+		require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i)
+		assert.Equal(t, expectedSeq[i], *ev.SequenceNumber,
+			"event %d: wrong sequence number", i)
+		assert.Equal(t, expectedAllowed[i], ev.Allowed,
+			"event %d: wrong allowed flag", i)
+	}
+
+	// -- Verify LLM backend headers --
+	require.Equal(t, 2, s.llmBackend.requestCount(),
+		"LLM backend should have received exactly two requests")
+
+	firstLLMHdr := s.llmBackend.headersAt(0)
+	assert.Equal(t, sessionID, firstLLMHdr.Get(config.DefaultSessionIDHeaderName))
+	assert.Equal(t, "0", firstLLMHdr.Get(config.DefaultSequenceNumberHeaderName),
+		"first LLM request must have sequence 0")
+
+	secondLLMHdr := s.llmBackend.headersAt(1)
+	assert.Equal(t, sessionID, secondLLMHdr.Get(config.DefaultSessionIDHeaderName))
+	assert.Equal(t, "3", secondLLMHdr.Get(config.DefaultSequenceNumberHeaderName),
+		"second LLM request must have sequence 3")
+
+	// -- Verify non-LLM backend has no correlation headers --
+	require.Equal(t, 1, s.otherBackend.requestCount())
+	otherHdr := s.otherBackend.headersAt(0)
+	assert.Empty(t, otherHdr.Get(config.DefaultSessionIDHeaderName))
+	assert.Empty(t, otherHdr.Get(config.DefaultSequenceNumberHeaderName))
+
+	// -- Verify the gap reveals intermediate activity --
+	// The gap between the two LLM sequence numbers (0 and 3) means
+	// that sequence numbers 1 and 2 were consumed by non-LLM
+	// activity, matching audit events 1 (non-LLM allowed) and 2
+	// (denied). Both pointers were checked non-nil in the loop above.
+	firstLLMSeq := *events[0].SequenceNumber
+	secondLLMSeq := *events[3].SequenceNumber
+	gap := secondLLMSeq - firstLLMSeq - 1
+	assert.Equal(t, uint64(2), gap,
+		"gap between LLM requests should reveal 2 intermediate events")
+}
+
+// TestIntegration_SequenceGapRevealsAgenticLoop sends two LLM requests
+// with several non-LLM requests in between, simulating an agentic loop
+// where the model triggers tool-use HTTP calls between prompts. The
+// test verifies that the gap in LLM sequence numbers precisely
+// reflects the count of intermediate boundary events.
+func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) {
+	const sessionID = "agentic-loop-session"
+
+	s := newSessionCorrelationIntegrationSetup(t, sessionID)
+	defer s.stop()
+
+	toolPaths := []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"}
+
+	// First LLM prompt (seq 0).
+	resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages")
+	require.NoError(t, err)
+	resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Agentic loop: three tool-use HTTP calls.
+	for _, p := range toolPaths {
+		resp, err = s.pt.proxyClient.Get(s.otherBackend.server.URL + p)
+		require.NoError(t, err)
+		resp.Body.Close() //nolint:errcheck
+		require.Equal(t, http.StatusOK, resp.StatusCode,
+			"tool-use request %s must be allowed", p)
+	}
+
+	// Second LLM prompt (seq 4).
+	resp, err = s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages")
+	require.NoError(t, err)
+	resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Verify LLM sequence headers.
+	require.Equal(t, 2, s.llmBackend.requestCount())
+	assert.Equal(t, "0", s.llmBackend.headersAt(0).Get(config.DefaultSequenceNumberHeaderName))
+	assert.Equal(t, "4", s.llmBackend.headersAt(1).Get(config.DefaultSequenceNumberHeaderName))
+
+	// The gap between sequence numbers 0 and 4 is 3, matching the
+	// three tool-use requests in between.
+	events := s.auditor.getRequests()
+	require.Len(t, events, 5)
+
+	// Check the pointers before dereferencing them, so a missing
+	// sequence number fails this test instead of panicking the binary.
+	require.NotNil(t, events[0].SequenceNumber,
+		"first LLM event: sequence number must be set")
+	require.NotNil(t, events[4].SequenceNumber,
+		"second LLM event: sequence number must be set")
+
+	firstLLMSeq := *events[0].SequenceNumber
+	secondLLMSeq := *events[4].SequenceNumber
+	gap := secondLLMSeq - firstLLMSeq - 1
+	assert.Equal(t, uint64(len(toolPaths)), gap,
+		"gap between prompts should equal number of tool-use requests")
+
+	// Verify the intermediate events are the tool-use requests.
+	for i := 1; i <= len(toolPaths); i++ {
+		require.NotNil(t, events[i].SequenceNumber)
+		assert.Equal(t, uint64(i), *events[i].SequenceNumber)
+		assert.True(t, events[i].Allowed)
+	}
+}
+
+// TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence
+// verifies that when a jailed client sets its own correlation headers,
+// the proxy replaces them with the real session ID and the real
+// sequence number, and the audit event still agrees with the header.
+func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) {
+	const sessionID = "real-session-uuid"
+	s := newSessionCorrelationIntegrationSetup(t, sessionID)
+	defer s.stop()
+
+	// Build a request that already carries forged correlation headers.
+	target := s.llmBackend.server.URL + "/v1/messages"
+	req, err := http.NewRequest(http.MethodPost, target, nil)
+	require.NoError(t, err)
+	req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session")
+	req.Header.Set(config.DefaultSequenceNumberHeaderName, "9999")
+
+	resp, err := s.pt.proxyClient.Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// The backend must see the proxy's values, not the forged ones.
+	require.Equal(t, 1, s.llmBackend.requestCount())
+	forwarded := s.llmBackend.headersAt(0)
+	forwardedSeq := forwarded.Get(config.DefaultSequenceNumberHeaderName)
+	assert.Equal(t, sessionID, forwarded.Get(config.DefaultSessionIDHeaderName))
+	assert.Equal(t, "0", forwardedSeq)
+
+	// The audit event must carry the same sequence number as the header.
+	events := s.auditor.getRequests()
+	require.Len(t, events, 1)
+	require.NotNil(t, events[0].SequenceNumber)
+	auditedSeq := strconv.FormatUint(*events[0].SequenceNumber, 10)
+	assert.Equal(t, auditedSeq, forwardedSeq)
+}
+
+// TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence
+// verifies that when session correlation is disabled, the proxy does
+// not inject headers and does not pre-allocate sequence numbers (the
+// auditor falls back to its own counter instead).
+func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) {
+	llm := newMultiRequestCapturingBackend()
+	defer llm.close()
+
+	llmURL, err := url.Parse(llm.server.URL)
+	require.NoError(t, err)
+	llmHost := llmURL.Hostname()
+
+	aud := &capturingAuditor{}
+
+	// The inject target would match, but Enabled is false.
+	disabled := config.SessionCorrelationConfig{
+		Enabled:                  false,
+		InjectTargets:            []config.InjectTarget{{Domain: llmHost}},
+		SessionIDHeaderName:      config.DefaultSessionIDHeaderName,
+		SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName,
+	}
+
+	// WithSequenceCounter is deliberately omitted here.
+	pt := NewProxyTest(t,
+		WithCertManager(t.TempDir()),
+		WithAllowedDomain(llmHost),
+		WithSessionCorrelation(disabled),
+		WithSessionID("should-not-appear"),
+		WithAuditor(aud),
+	).Start()
+	defer pt.Stop()
+
+	resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages")
+	require.NoError(t, err)
+	defer resp.Body.Close() //nolint:errcheck
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// The forwarded request must carry neither correlation header.
+	require.Equal(t, 1, llm.requestCount())
+	forwarded := llm.headersAt(0)
+	assert.Empty(t, forwarded.Get(config.DefaultSessionIDHeaderName))
+	assert.Empty(t, forwarded.Get(config.DefaultSequenceNumberHeaderName))
+
+	// The request is still audited, but the event carries no
+	// sequence number because no counter was supplied.
+	events := aud.getRequests()
+	require.Len(t, events, 1)
+	assert.Nil(t, events[0].SequenceNumber,
+		"no sequence counter means no pre-allocated sequence number")
+}
+
+// TestIntegration_ConcurrentRequestsUniqueSequenceNumbers sends
+// multiple requests concurrently and verifies that every request
+// receives a unique sequence number, and that the set of numbers is
+// dense (no gaps, no duplicates).
+func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) {
+	const sessionID = "concurrent-session"
+	const numRequests = 10
+
+	s := newSessionCorrelationIntegrationSetup(t, sessionID)
+	defer s.stop()
+
+	target := s.llmBackend.server.URL + "/v1/messages"
+
+	// Fire all requests at once. assert (not require) is used inside
+	// the goroutines so a failure is recorded without FailNow being
+	// called off the test goroutine.
+	var wg sync.WaitGroup
+	wg.Add(numRequests)
+	for i := 0; i < numRequests; i++ {
+		go func() {
+			defer wg.Done()
+			resp, err := s.pt.proxyClient.Get(target)
+			assert.NoError(t, err)
+			if resp == nil {
+				return
+			}
+			resp.Body.Close() //nolint:errcheck
+		}()
+	}
+	wg.Wait()
+
+	// One audit event per request.
+	events := s.auditor.getRequests()
+	require.Len(t, events, numRequests)
+
+	// No sequence number may appear twice.
+	seen := make(map[uint64]bool, numRequests)
+	for i, ev := range events {
+		require.NotNil(t, ev.SequenceNumber,
+			"event %d: sequence number must not be nil", i)
+		n := *ev.SequenceNumber
+		assert.False(t, seen[n],
+			"event %d: duplicate sequence number %d", i, n)
+		seen[n] = true
+	}
+
+	// The set should be exactly {0, 1, ..., numRequests-1}.
+	for want := uint64(0); want < numRequests; want++ {
+		assert.True(t, seen[want],
+			"sequence number %d is missing from the set", want)
+	}
+
+	// The forwarded headers must cover the same dense set.
+	require.Equal(t, numRequests, s.llmBackend.requestCount())
+	headerSeqs := make(map[string]bool, numRequests)
+	for i := 0; i < numRequests; i++ {
+		seqStr := s.llmBackend.headersAt(i).Get(config.DefaultSequenceNumberHeaderName)
+		assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i)
+		headerSeqs[seqStr] = true
+	}
+	for want := uint64(0); want < numRequests; want++ {
+		assert.True(t, headerSeqs[fmt.Sprintf("%d", want)],
+			"header sequence number %d is missing", want)
+	}
+}