From c4966d89d236fa54616e7c2a032e4cfb0656fb50 Mon Sep 17 00:00:00 2001 From: Peter Dedene Date: Sat, 21 Feb 2026 14:35:54 +0100 Subject: [PATCH 1/2] feat(credentials): add HashiCorp Vault secret backend Spawn `vault kv get -format=json` via CLI, matching existing backend pattern (op, bw, pass). Zero new Go dependencies. - KV-v2 and KV-v1 auto-detection (checks for both data + metadata) - jq pipe support: `vault:secret/app/creds | .password` - Config fields: vault_binary, vault_addr, vault_skip_verify, vault_cacert, vault_namespace, vault_token_file - Hot-reload of all vault settings via fsnotify - Wired through daemon, executor, and HTTP proxy - Secret backends table added to README --- README.md | 14 + docs/CONFIG.md | 48 +++- internal/config/config.go | 67 ++++- internal/config/config_test.go | 117 +++++++++ internal/credentials/cache.go | 2 +- internal/credentials/credentials.go | 15 ++ internal/credentials/credentials_test.go | 8 + internal/credentials/parser.go | 6 + internal/credentials/parser_test.go | 20 ++ internal/credentials/vault.go | 159 ++++++++++++ internal/credentials/vault_test.go | 312 +++++++++++++++++++++++ internal/daemon/daemon.go | 17 ++ internal/daemon/daemon_test.go | 1 - internal/daemon/executor.go | 2 + internal/daemon/executor_env_test.go | 8 +- internal/httpproxy/proxy.go | 8 + 16 files changed, 795 insertions(+), 9 deletions(-) create mode 100644 internal/credentials/vault.go create mode 100644 internal/credentials/vault_test.go diff --git a/README.md b/README.md index 2e3f14d..95cad28 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,20 @@ claw-wrap supports two approaches for credential injection: - You want route-based credential injection by host/path - Multiple tools need the same API credentials +## Secret Backends + +| Backend | Prefix | Example | Notes | +| ------- | ------ | ------- | ----- | +| [pass](https://www.passwordstore.org/) | `pass:` | `pass:cli/github/token` | Default when no prefix given | +| Environment | `env:` | `env:MY_TOKEN` | Reads from daemon environment | +| [1Password](https://1password.com/) | `op://` | `op://Vault/Item/field` | Requires `op` CLI, session auth | +| [Bitwarden](https://bitwarden.com/) | `bw:` | `bw:item-uuid` | Requires `bw` CLI, session managed | +| [macOS Keychain](https://support.apple.com/guide/keychain-access/) | `keychain:` | `keychain:service-name` | macOS only | +| [age](https://age-encryption.org/) | `age:` | `age:/path/to/file.age` | File-level encryption | +| [HashiCorp Vault](https://www.vaultproject.io/) | `vault:` | `vault:secret/myapp/key` | KV-v1 & KV-v2, external auth | + +All backends except `env:` support jq extraction: `vault:secret/app/creds \| .password` + ## Quick Start This example sets up `gh` (GitHub CLI) as a proxied tool. diff --git a/docs/CONFIG.md b/docs/CONFIG.md index 8d3160a..aa889e4 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -257,7 +257,7 @@ Optional in-memory TTL cache for credential fetch results. - Default: `0` (disabled) - Format: Go duration (`30s`, `2m`, `1h`) -- Scope: only `op://` (1Password) and `bw:` (Bitwarden) credential sources +- Scope: `op://` (1Password), `bw:` (Bitwarden), and `vault:` (HashiCorp Vault) credential sources - `claw-wrap check` always bypasses this cache and fetches credentials live Use this to reduce repeated upstream secret-store latency for frequently-invoked tools. @@ -685,6 +685,52 @@ If `bw_binary` is unset, claw-wrap only auto-detects `bw` in trusted directories - Session token passed via environment variable, not command line - Session cleaned up on daemon shutdown +### HashiCorp Vault (`vault:`) + +```yaml +credentials: + api-key: + source: vault:secret/myapp/api-key + + # With jq extraction from secret JSON + db-password: + source: vault:secret/myapp/database | .password +``` + +Fetches secrets from HashiCorp Vault using the `vault` CLI. Supports both KV-v2 (default) and KV-v1 engines. + +Use natural paths (e.g., `secret/myapp/key`) — the `vault kv get` command handles the KV-v2 `/data/` path prefix internally. + +Optional CLI and connection overrides: + +```yaml +proxy: + vault_binary: /usr/bin/vault + vault_addr: https://127.0.0.1:8200 + vault_skip_verify: false + vault_cacert: /etc/vault/ca.pem + vault_namespace: "" + vault_token_file: /home/bot/.vault-token +``` + +If `vault_binary` is unset, claw-wrap only auto-detects `vault` in trusted directories: +`/usr/bin`, `/usr/local/bin`, `/opt/homebrew/bin`, `/home/linuxbrew/.linuxbrew/bin`. + +Connection settings (`vault_addr`, `vault_skip_verify`, `vault_cacert`, `vault_namespace`) override the corresponding `VAULT_ADDR`, `VAULT_SKIP_VERIFY`, `VAULT_CACERT`, and `VAULT_NAMESPACE` environment variables when set. + +**Authentication model:** + +claw-wrap does **not** authenticate with Vault itself. The user (or operator) must run `vault login` externally, which stores a token at `~/.vault-token`. The `vault` CLI reads this token automatically. This supports time-scoped access: configure TTL on the Vault user so tokens expire after a set window (15 minutes, 1 hour, etc.). + +Use `vault_token_file` to point to a non-default token file location (requires Vault CLI 1.10+). + +**Security:** + +- Secrets never stored in plaintext config — fetched on-demand via CLI +- Token managed externally; claw-wrap cannot refresh or extend access +- Expired tokens produce a generic "vault read failed" error +- Supports self-signed certs via `vault_cacert` or `vault_skip_verify` + ### jq Extraction All backends support jq extraction using the pipe syntax: diff --git a/internal/config/config.go b/internal/config/config.go index ea4cd35..fe250e1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -59,6 +59,12 @@ type ProxyConfig struct { ReplayCacheTTL string `yaml:"replay_cache_ttl"` // e.g., "2m" ReplayCacheMax int `yaml:"replay_cache_max_entries"` CredentialCacheTTL string `yaml:"credential_cache_ttl"` // e.g., "30s" (0/empty disables) + VaultBinary string `yaml:"vault_binary"` // e.g., "/usr/bin/vault" + VaultAddr string `yaml:"vault_addr"` // e.g., "https://127.0.0.1:8200" + VaultSkipVerify bool `yaml:"vault_skip_verify"` // skip TLS verification + VaultCACert string `yaml:"vault_cacert"` // e.g., "/path/to/ca.pem" + VaultNamespace string `yaml:"vault_namespace"` // enterprise namespace + VaultTokenFile string `yaml:"vault_token_file"` // override default ~/.vault-token } // SecurityConfig holds security policy flags. @@ -185,8 +191,8 @@ type ToolDef struct { AllowedArgs []BlockedArg `yaml:"allowed_args,omitempty"` RedactOutput []ToolRedactRule `yaml:"redact_output,omitempty"` ConfigFile *ConfigFileDef `yaml:"config_file,omitempty"` - UseProxy bool `yaml:"use_proxy,omitempty"` // Enable HTTP proxy for this tool - UsePTY *bool `yaml:"use_pty,omitempty"` // PTY mode: nil=default on, false=opt out + UseProxy bool `yaml:"use_proxy,omitempty"` // Enable HTTP proxy for this tool + UsePTY *bool `yaml:"use_pty,omitempty"` // PTY mode: nil=default on, false=opt out } // GetUsePTY returns whether PTY mode is enabled for this tool. @@ -826,6 +832,63 @@ func (c *Config) GetBWBinary() string { return "" } +// GetVaultBinary returns the configured Vault CLI binary path or empty for trusted-directory lookup. +func (c *Config) GetVaultBinary() string { + if c.Proxy != nil && c.Proxy.VaultBinary != "" { + if !filepath.IsAbs(c.Proxy.VaultBinary) { + log.Printf("[WARN] vault_binary %q is not absolute, using trusted-directory lookup", c.Proxy.VaultBinary) + return "" + } + return c.Proxy.VaultBinary + } + return "" +} + +// GetVaultAddr returns the configured Vault server address (empty = use VAULT_ADDR env). +func (c *Config) GetVaultAddr() string { + if c.Proxy != nil { + return c.Proxy.VaultAddr + } + return "" +} + +// GetVaultSkipVerify returns whether to skip Vault TLS verification. +func (c *Config) GetVaultSkipVerify() bool { + return c.Proxy != nil && c.Proxy.VaultSkipVerify +} + +// GetVaultCACert returns the Vault CA cert path (empty = use VAULT_CACERT env). +func (c *Config) GetVaultCACert() string { + if c.Proxy != nil && c.Proxy.VaultCACert != "" { + if !filepath.IsAbs(c.Proxy.VaultCACert) { + log.Printf("[WARN] vault_cacert %q is not absolute, ignoring", c.Proxy.VaultCACert) + return "" + } + return c.Proxy.VaultCACert + } + return "" +} + +// GetVaultNamespace returns the Vault enterprise namespace (empty = none). +func (c *Config) GetVaultNamespace() string { + if c.Proxy != nil { + return c.Proxy.VaultNamespace + } + return "" +} + +// GetVaultTokenFile returns the Vault token file path (empty = default ~/.vault-token). +func (c *Config) GetVaultTokenFile() string { + if c.Proxy != nil && c.Proxy.VaultTokenFile != "" { + if !filepath.IsAbs(c.Proxy.VaultTokenFile) { + log.Printf("[WARN] vault_token_file %q is not absolute, ignoring", c.Proxy.VaultTokenFile) + return "" + } + return c.Proxy.VaultTokenFile + } + return "" +} + // GetOPTokenFile returns the configured 1Password token file path or the default. // The returned path is always absolute. func (c *Config) GetOPTokenFile() string { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index e4e50c8..5330e90 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -662,6 +662,123 @@ func TestGetBWBinary(t *testing.T) { } } +func TestGetVaultBinary(t *testing.T) { + tests := []struct { + name string + cfg Config + want string + }{ + {"nil proxy returns empty", Config{}, ""}, + {"empty vault_binary returns empty", Config{Proxy: &ProxyConfig{}}, ""}, + {"configured absolute path returned", Config{Proxy: &ProxyConfig{VaultBinary: "/usr/bin/vault"}}, "/usr/bin/vault"}, + {"relative path rejected", Config{Proxy: &ProxyConfig{VaultBinary: "vault"}}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetVaultBinary(); got != tt.want { + t.Errorf("GetVaultBinary() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetVaultAddr(t *testing.T) { + tests := []struct { + name string + cfg Config + want string + }{ + {"nil proxy", Config{}, ""}, + {"empty", Config{Proxy: &ProxyConfig{}}, ""}, + {"configured", Config{Proxy: &ProxyConfig{VaultAddr: "https://vault:8200"}}, "https://vault:8200"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetVaultAddr(); got != tt.want { + t.Errorf("GetVaultAddr() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetVaultSkipVerify(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + {"nil proxy", Config{}, false}, + {"default false", Config{Proxy: &ProxyConfig{}}, false}, + {"configured true", Config{Proxy: &ProxyConfig{VaultSkipVerify: true}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetVaultSkipVerify(); got != tt.want { + t.Errorf("GetVaultSkipVerify() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetVaultCACert(t *testing.T) { + tests := []struct { + name string + cfg Config + want string + }{ + {"nil proxy", Config{}, ""}, + {"empty", Config{Proxy: &ProxyConfig{}}, ""}, + {"absolute path", Config{Proxy: &ProxyConfig{VaultCACert: "/etc/vault/ca.pem"}}, "/etc/vault/ca.pem"}, + {"relative path rejected", Config{Proxy: &ProxyConfig{VaultCACert: "ca.pem"}}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetVaultCACert(); got != tt.want { + t.Errorf("GetVaultCACert() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetVaultNamespace(t *testing.T) { + tests := []struct { + name string + cfg Config + want string + }{ + {"nil proxy", Config{}, ""}, + {"empty", Config{Proxy: &ProxyConfig{}}, ""}, + {"configured", Config{Proxy: &ProxyConfig{VaultNamespace: "team-a"}}, "team-a"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetVaultNamespace(); got != tt.want { + t.Errorf("GetVaultNamespace() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetVaultTokenFile(t *testing.T) { + tests := []struct { + name string + cfg Config + want string + }{ + {"nil proxy", Config{}, ""}, + {"empty returns empty", Config{Proxy: &ProxyConfig{}}, ""}, + {"absolute path", Config{Proxy: &ProxyConfig{VaultTokenFile: "/home/bot/.vault-token"}}, "/home/bot/.vault-token"}, + {"relative path rejected", Config{Proxy: &ProxyConfig{VaultTokenFile: ".vault-token"}}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.GetVaultTokenFile(); got != tt.want { + t.Errorf("GetVaultTokenFile() = %q, want %q", got, tt.want) + } + }) + } +} + func TestLoad_PassBinaryFromYAML(t *testing.T) { tmpDir := t.TempDir() diff --git a/internal/credentials/cache.go b/internal/credentials/cache.go index 007d708..e3afe85 100644 --- a/internal/credentials/cache.go +++ b/internal/credentials/cache.go @@ -191,7 +191,7 @@ func sweepInterval(ttl time.Duration) time.Duration { } func isCredentialCacheableBackend(backend Backend) bool { - return backend == Backend1Password || backend == BackendBitwarden + return backend == Backend1Password || backend == BackendBitwarden || backend == BackendVault } func credentialCacheKey(parsed *ParsedSource) string { diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go index 1b50982..4909882 100644 --- a/internal/credentials/credentials.go +++ b/internal/credentials/credentials.go @@ -25,6 +25,7 @@ type FetchOptions struct { PassBinary string OPBinary string BWBinary string + VaultBinary string BypassCache bool } @@ -52,6 +53,13 @@ func WithBWBinary(path string) FetchOption { } } +// WithVaultBinary sets the path to the HashiCorp Vault CLI binary. +func WithVaultBinary(path string) FetchOption { + return func(o *FetchOptions) { + o.VaultBinary = path + } +} + // WithBypassCache forces live credential fetches and bypasses result caching. func WithBypassCache() FetchOption { return func(o *FetchOptions) { @@ -67,6 +75,7 @@ func WithBypassCache() FetchOption { // - age:/path/to/file.age - decrypt age-encrypted file // - keychain:service-name - fetch from macOS Keychain // - bw:item-uuid - fetch from Bitwarden +// - vault:secret/path - fetch from HashiCorp Vault // - path/in/store - legacy format, assumed to be pass // // All sources optionally support jq extraction: "source | .jq_expr" @@ -145,6 +154,12 @@ func Fetch(source string, opts ...FetchOption) (string, error) { return "", err } + case BackendVault: + result, err = fetchFromVault(ctx, parsed, options.VaultBinary) + if err != nil { + return "", err + } + default: return "", fmt.Errorf("unknown credential backend") } diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go index 0e63950..5d0ffe7 100644 --- a/internal/credentials/credentials_test.go +++ b/internal/credentials/credentials_test.go @@ -59,6 +59,14 @@ func TestWithBWBinary(t *testing.T) { } } +func TestWithVaultBinary(t *testing.T) { + opts := &FetchOptions{} + WithVaultBinary("/custom/vault")(opts) + if opts.VaultBinary != "/custom/vault" { + t.Errorf("VaultBinary = %q, want %q", opts.VaultBinary, "/custom/vault") + } +} + func TestWithBypassCache(t *testing.T) { opts := &FetchOptions{} WithBypassCache()(opts) diff --git a/internal/credentials/parser.go b/internal/credentials/parser.go index 4d81c11..4c098b3 100644 --- a/internal/credentials/parser.go +++ b/internal/credentials/parser.go @@ -16,6 +16,7 @@ const ( BackendAge Backend = "age" BackendKeychain Backend = "keychain" BackendBitwarden Backend = "bw" + BackendVault Backend = "vault" ) // ParsedSource represents a parsed credential source URI. @@ -38,6 +39,8 @@ type ParsedSource struct { // - keychain:service-name | .jq_expr // - age:/path/to/file.age // - age:/path/to/file.age | .jq_expr +// - vault:secret/myapp/api-key +// - vault:secret/myapp/api-key | .password // - path/in/store (legacy, assumed pass) func ParseSource(source string) (*ParsedSource, error) { if source == "" { @@ -79,6 +82,9 @@ func ParseSource(source string) (*ParsedSource, error) { case strings.HasPrefix(source, "env:"): backend = BackendEnv path = strings.TrimPrefix(source, "env:") + case strings.HasPrefix(source, "vault:"): + backend = BackendVault + path = strings.TrimPrefix(source, "vault:") default: // Legacy format: assume pass backend = BackendPass diff --git a/internal/credentials/parser_test.go b/internal/credentials/parser_test.go index 9285b51..a351a2b 100644 --- a/internal/credentials/parser_test.go +++ b/internal/credentials/parser_test.go @@ -95,6 +95,26 @@ func TestParseSource(t *testing.T) { wantJQ: ".credentials.token", }, + // Vault backend + { + name: "vault simple", + source: "vault:secret/myapp/api-key", + wantBackend: BackendVault, + wantPath: "secret/myapp/api-key", + }, + { + name: "vault with jq", + source: "vault:secret/myapp/creds | .password", + wantBackend: BackendVault, + wantPath: "secret/myapp/creds", + wantJQ: ".password", + }, + { + name: "vault empty path", + source: "vault:", + wantErr: true, + }, + // Complex jq expressions { name: "complex jq filter", diff --git a/internal/credentials/vault.go b/internal/credentials/vault.go new file mode 100644 index 0000000..e1ca6a8 --- /dev/null +++ b/internal/credentials/vault.go @@ -0,0 +1,159 @@ +package credentials + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "strings" + "sync" + "time" +) + +const vaultCommandTimeout = 30 * time.Second + +// vaultConfig holds Vault CLI environment overrides. +// Protected by vaultMu for concurrent access. +var ( + vaultMu sync.RWMutex + vaultAddr string + vaultSkipVerify bool + vaultCACert string + vaultNamespace string + vaultTokenFile string +) + +// SetVaultAddr configures the Vault server address. +func SetVaultAddr(addr string) { + vaultMu.Lock() + defer vaultMu.Unlock() + vaultAddr = addr +} + +// SetVaultSkipVerify configures TLS verification skip. +func SetVaultSkipVerify(skip bool) { + vaultMu.Lock() + defer vaultMu.Unlock() + vaultSkipVerify = skip +} + +// SetVaultCACert configures the CA cert path. +func SetVaultCACert(path string) { + vaultMu.Lock() + defer vaultMu.Unlock() + vaultCACert = path +} + +// SetVaultNamespace configures the Vault namespace (enterprise). +func SetVaultNamespace(ns string) { + vaultMu.Lock() + defer vaultMu.Unlock() + vaultNamespace = ns +} + +// SetVaultTokenFile configures a custom token file path. +func SetVaultTokenFile(path string) { + vaultMu.Lock() + defer vaultMu.Unlock() + vaultTokenFile = path +} + +func getVaultSettings() (addr string, skipVerify bool, caCert, namespace, tokenFile string) { + vaultMu.RLock() + defer vaultMu.RUnlock() + return vaultAddr, vaultSkipVerify, vaultCACert, vaultNamespace, vaultTokenFile +} + +// fetchFromVault retrieves a credential from HashiCorp Vault using the vault CLI. +func fetchFromVault(ctx context.Context, parsed *ParsedSource, vaultBinaryOverride string) (string, error) { + vaultBinary, err := resolveVaultBinary(vaultBinaryOverride) + if err != nil { + return "", fmt.Errorf("vault CLI not found in trusted locations") + } + + ctx, cancel := context.WithTimeout(ctx, vaultCommandTimeout) + defer cancel() + + cmd := exec.CommandContext(ctx, vaultBinary, "kv", "get", "-format=json", parsed.Path) + cmd.Env = vaultEnv() + + output, err := cmd.Output() + if err != nil { + log.Printf("[DEBUG] vault kv get failed: %v", err) + if exitErr, ok := err.(*exec.ExitError); ok && len(exitErr.Stderr) > 0 { + log.Printf("[DEBUG] vault stderr: %s", string(exitErr.Stderr)) + } + return "", fmt.Errorf("vault read failed") + } + + result, err := extractVaultData(output) + if err != nil { + return "", err + } + + if parsed.HasJQ() { + return ApplyJQ(ctx, []byte(result), parsed.JQExpr) + } + + return result, nil +} + +// extractVaultData parses the vault kv get JSON response. +// KV-v2 format: {"data":{"data":{...},"metadata":{...}}} +// KV-v1 format: {"data":{...}} +func extractVaultData(output []byte) (string, error) { + var response struct { + Data json.RawMessage `json:"data"` + } + if err := json.Unmarshal(output, &response); err != nil { + log.Printf("[DEBUG] vault response parse failed: %v", err) + return "", fmt.Errorf("vault read failed") + } + + // Try KV-v2: look for nested .data.data AND .data.metadata. + // Both must be present to confirm KV-v2 — a KV-v1 secret with a key + // named "data" would otherwise be misidentified. + var kvV2 struct { + Data json.RawMessage `json:"data"` + Metadata json.RawMessage `json:"metadata"` + } + if err := json.Unmarshal(response.Data, &kvV2); err == nil && kvV2.Data != nil && kvV2.Metadata != nil { + return strings.TrimSpace(string(kvV2.Data)), nil + } + + // Fall back to KV-v1: .data is the secret itself + return strings.TrimSpace(string(response.Data)), nil +} + +func resolveVaultBinary(vaultBinaryOverride string) (string, error) { + if vaultBinaryOverride != "" { + return vaultBinaryOverride, nil + } + return findTrustedBinaryFunc("vault") +} + +// vaultEnv returns environment variables for the vault CLI. +func vaultEnv() []string { + env := os.Environ() + addr, skipVerify, caCert, namespace, tokenFile := getVaultSettings() + + if addr != "" { + env = append(env, "VAULT_ADDR="+addr) + } + if skipVerify { + env = append(env, "VAULT_SKIP_VERIFY=1") + } + if caCert != "" { + env = append(env, "VAULT_CACERT="+caCert) + } + if namespace != "" { + env = append(env, "VAULT_NAMESPACE="+namespace) + } + if tokenFile != "" { + env = append(env, "VAULT_TOKEN_FILE="+tokenFile) + } + + return env +} diff --git a/internal/credentials/vault_test.go b/internal/credentials/vault_test.go new file mode 100644 index 0000000..40df17e --- /dev/null +++ b/internal/credentials/vault_test.go @@ -0,0 +1,312 @@ +package credentials + +import ( + "context" + "os" + "path/filepath" + "strconv" + "testing" +) + +func writeMockVaultScript(t *testing.T, dir string, output string, exitCode int) string { + t.Helper() + scriptPath := filepath.Join(dir, "vault") + var script string + if exitCode == 0 { + script = "#!/bin/sh\nprintf '%s' '" + output + "'\n" + } else { + script = "#!/bin/sh\necho 'Error: permission denied' >&2\nexit " + strconv.Itoa(exitCode) + "\n" + } + if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil { + t.Fatalf("write mock script: %v", err) + } + return scriptPath +} + +func TestResolveVaultBinary(t *testing.T) { + t.Run("override returns override", func(t *testing.T) { + got, err := resolveVaultBinary("/custom/vault") + if err != nil { + t.Errorf("resolveVaultBinary() error = %v", err) + } + if got != "/custom/vault" { + t.Errorf("resolveVaultBinary() = %q, want %q", got, "/custom/vault") + } + }) + + t.Run("empty falls through to findTrustedBinary", func(t *testing.T) { + orig := findTrustedBinaryFunc + findTrustedBinaryFunc = func(name string) (string, error) { + if name != "vault" { + t.Errorf("findTrustedBinaryFunc called with %q, want %q", name, "vault") + } + return "/usr/bin/vault", nil + } + defer func() { findTrustedBinaryFunc = orig }() + + got, err := resolveVaultBinary("") + if err != nil { + t.Errorf("resolveVaultBinary() error = %v", err) + } + if got != "/usr/bin/vault" { + t.Errorf("resolveVaultBinary() = %q, want %q", got, "/usr/bin/vault") + } + }) +} + +func TestFetchFromVault_NoBinary(t *testing.T) { + orig := findTrustedBinaryFunc + findTrustedBinaryFunc = func(name string) (string, error) { + return "", os.ErrNotExist + } + defer func() { findTrustedBinaryFunc = orig }() + + parsed := &ParsedSource{ + Backend: BackendVault, + Path: "secret/myapp/key", + } + + _, err := fetchFromVault(context.Background(), parsed, "") + if err == nil { + t.Error("fetchFromVault() should error when vault binary not found") + } +} + +func TestFetchFromVault_Success(t *testing.T) { + tmpDir := t.TempDir() + kvV2Response := `{"data":{"data":{"password":"s3cret","username":"admin"},"metadata":{"version":1}}}` + scriptPath := writeMockVaultScript(t, tmpDir, kvV2Response, 0) + + parsed := &ParsedSource{ + Backend: BackendVault, + Path: "secret/myapp/creds", + Original: "vault:secret/myapp/creds", + } + + result, err := fetchFromVault(context.Background(), parsed, scriptPath) + if err != nil { + t.Fatalf("fetchFromVault() error = %v", err) + } + + // Should return the inner .data.data JSON + want := `{"password":"s3cret","username":"admin"}` + if result != want { + t.Errorf("result = %q, want %q", result, want) + } +} + +func TestFetchFromVault_WithJQ(t *testing.T) { + tmpDir := t.TempDir() + kvV2Response := `{"data":{"data":{"password":"s3cret","username":"admin"},"metadata":{"version":1}}}` + scriptPath := writeMockVaultScript(t, tmpDir, kvV2Response, 0) + + parsed := &ParsedSource{ + Backend: BackendVault, + Path: "secret/myapp/creds", + JQExpr: ".password", + Original: "vault:secret/myapp/creds | .password", + } + + result, err := fetchFromVault(context.Background(), parsed, scriptPath) + if err != nil { + t.Fatalf("fetchFromVault() error = %v", err) + } + + if result != "s3cret" { + t.Errorf("result = %q, want %q", result, "s3cret") + } +} + +func TestFetchFromVault_KVv1Fallback(t *testing.T) { + tmpDir := t.TempDir() + // KV-v1 response has data directly at .data (no nested .data.data) + kvV1Response := `{"data":{"password":"v1secret","username":"admin"}}` + scriptPath := writeMockVaultScript(t, tmpDir, kvV1Response, 0) + + parsed := &ParsedSource{ + Backend: BackendVault, + Path: "secret/myapp/creds", + JQExpr: ".password", + Original: "vault:secret/myapp/creds | .password", + } + + result, err := fetchFromVault(context.Background(), parsed, scriptPath) + if err != nil { + t.Fatalf("fetchFromVault() error = %v", err) + } + + if result != "v1secret" { + t.Errorf("result = %q, want %q", result, "v1secret") + } +} + +func TestFetchFromVault_FailingBinary(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := writeMockVaultScript(t, tmpDir, "", 2) + + parsed := &ParsedSource{ + Backend: BackendVault, + Path: "secret/myapp/creds", + Original: "vault:secret/myapp/creds", + } + + _, err := fetchFromVault(context.Background(), parsed, scriptPath) + if err == nil { + t.Error("fetchFromVault() should error on non-zero exit") + } + if err.Error() != "vault read failed" { + t.Errorf("error = %q, want generic %q", err.Error(), "vault read failed") + } +} + +func TestFetchFromVault_InvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + scriptPath := writeMockVaultScript(t, tmpDir, "not-json", 0) + + parsed := &ParsedSource{ + Backend: BackendVault, + Path: "secret/myapp/creds", + Original: "vault:secret/myapp/creds", + } + + _, err := fetchFromVault(context.Background(), parsed, scriptPath) + if err == nil { + t.Error("fetchFromVault() should error on invalid JSON") + } + if err.Error() != "vault read failed" { + t.Errorf("error = %q, want generic %q", err.Error(), "vault read failed") + } +} + +func TestVaultEnv(t *testing.T) { + // Save originals + origAddr, origSkip, origCACert, origNs, origTokenFile := getVaultSettings() + defer func() { + SetVaultAddr(origAddr) + SetVaultSkipVerify(origSkip) + SetVaultCACert(origCACert) + SetVaultNamespace(origNs) + SetVaultTokenFile(origTokenFile) + }() + + SetVaultAddr("https://vault.example.com:8200") + SetVaultSkipVerify(true) + SetVaultCACert("/etc/vault/ca.pem") + SetVaultNamespace("team-a") + SetVaultTokenFile("/home/bot/.vault-token") + + env := vaultEnv() + + checks := map[string]string{ + "VAULT_ADDR": "https://vault.example.com:8200", + "VAULT_SKIP_VERIFY": "1", + "VAULT_CACERT": "/etc/vault/ca.pem", + "VAULT_NAMESPACE": "team-a", + "VAULT_TOKEN_FILE": "/home/bot/.vault-token", + } + + for key, want := range checks { + found := false + for _, e := range env { + if e == key+"="+want { + found = true + break + } + } + if !found { + t.Errorf("vaultEnv() missing %s=%s", key, want) + } + } +} + +func TestVaultEnv_Empty(t *testing.T) { + origAddr, origSkip, origCACert, origNs, origTokenFile := getVaultSettings() + defer func() { + SetVaultAddr(origAddr) + SetVaultSkipVerify(origSkip) + SetVaultCACert(origCACert) + SetVaultNamespace(origNs) + SetVaultTokenFile(origTokenFile) + }() + + SetVaultAddr("") + SetVaultSkipVerify(false) + SetVaultCACert("") + SetVaultNamespace("") + SetVaultTokenFile("") + + env := vaultEnv() + + // Count how many times each vault key appears — our setters should add zero. + // Ambient env may already contain VAULT_* vars (e.g. from a real Vault session), + // so we compare counts before and after to isolate what vaultEnv() appended. + baseEnv := os.Environ() + baseCount := make(map[string]int) + vaultKeys := []string{"VAULT_ADDR", "VAULT_SKIP_VERIFY", "VAULT_CACERT", "VAULT_NAMESPACE", "VAULT_TOKEN_FILE"} + for _, e := range baseEnv { + for _, key := range vaultKeys { + if len(e) > len(key) && e[:len(key)+1] == key+"=" { + baseCount[key]++ + } + } + } + + envCount := make(map[string]int) + for _, e := range env { + for _, key := range vaultKeys { + if len(e) > len(key) && e[:len(key)+1] == key+"=" { + envCount[key]++ + } + } + } + + for _, key := range vaultKeys { + if envCount[key] > baseCount[key] { + t.Errorf("vaultEnv() should not append %s when empty (base=%d, got=%d)", key, baseCount[key], envCount[key]) + } + } +} + +func TestExtractVaultData_KVv2(t *testing.T) { + input := []byte(`{"data":{"data":{"password":"s3cret"},"metadata":{"version":1}}}`) + result, err := extractVaultData(input) + if err != nil { + t.Fatalf("extractVaultData() error = %v", err) + } + if result != `{"password":"s3cret"}` { + t.Errorf("result = %q, want %q", result, `{"password":"s3cret"}`) + } +} + +func TestExtractVaultData_KVv1(t *testing.T) { + input := []byte(`{"data":{"password":"v1secret"}}`) + result, err := extractVaultData(input) + if err != nil { + t.Fatalf("extractVaultData() error = %v", err) + } + if result != `{"password":"v1secret"}` { + t.Errorf("result = %q, want %q", result, `{"password":"v1secret"}`) + } +} + +func TestExtractVaultData_KVv1WithDataKey(t *testing.T) { + // KV-v1 secret that happens to have a key named "data" — must NOT be + // misidentified as KV-v2. The fix checks for both "data" AND "metadata". + input := []byte(`{"data":{"data":"some-value","other":"field"}}`) + result, err := extractVaultData(input) + if err != nil { + t.Fatalf("extractVaultData() error = %v", err) + } + // Should return the full .data object (KV-v1 fallback), not just "some-value" + want := `{"data":"some-value","other":"field"}` + if result != want { + t.Errorf("result = %q, want %q", result, want) + } +} + +func TestExtractVaultData_InvalidJSON(t *testing.T) { + _, err := extractVaultData([]byte("not-json")) + if err == nil { + t.Error("extractVaultData() should error on invalid JSON") + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 74db917..6e71899 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -75,6 +75,11 @@ var ( setAgeIdentityFileFunc = credentials.SetAgeIdentityFile setOPTokenFileFunc = credentials.SetOPTokenFile setCredentialCacheTTLFunc = credentials.SetCredentialCacheTTL + setVaultAddrFunc = credentials.SetVaultAddr + setVaultSkipVerifyFunc = credentials.SetVaultSkipVerify + setVaultCACertFunc = credentials.SetVaultCACert + setVaultNamespaceFunc = credentials.SetVaultNamespace + setVaultTokenFileFunc = credentials.SetVaultTokenFile fetchCredentialFunc = credentials.Fetch cleanupBWSessionFunc = credentials.CleanupBWSession ) @@ -156,6 +161,11 @@ func (d *Daemon) Run() error { setAgeIdentityFileFunc(cfg.GetAgeIdentityFile()) setOPTokenFileFunc(cfg.GetOPTokenFile()) setCredentialCacheTTLFunc(cfg.GetCredentialCacheTTL()) + setVaultAddrFunc(cfg.GetVaultAddr()) + setVaultSkipVerifyFunc(cfg.GetVaultSkipVerify()) + setVaultCACertFunc(cfg.GetVaultCACert()) + setVaultNamespaceFunc(cfg.GetVaultNamespace()) + setVaultTokenFileFunc(cfg.GetVaultTokenFile()) defer cleanupBWSessionFunc() auditLogger, err := audit.New(cfg.GetAuditConfig()) @@ -350,6 +360,11 @@ func (d *Daemon) reloadConfig() error { setAgeIdentityFileFunc(newCfg.GetAgeIdentityFile()) setOPTokenFileFunc(newCfg.GetOPTokenFile()) setCredentialCacheTTLFunc(newCfg.GetCredentialCacheTTL()) + setVaultAddrFunc(newCfg.GetVaultAddr()) + setVaultSkipVerifyFunc(newCfg.GetVaultSkipVerify()) + setVaultCACertFunc(newCfg.GetVaultCACert()) + setVaultNamespaceFunc(newCfg.GetVaultNamespace()) + setVaultTokenFileFunc(newCfg.GetVaultTokenFile()) return nil } @@ -483,6 +498,7 @@ func (d *Daemon) startHTTPProxy(cfg *config.Config) error { httpproxy.WithPassBinary(cfg.GetPassBinary()), httpproxy.WithOPBinary(cfg.GetOPBinary()), httpproxy.WithBWBinary(cfg.GetBWBinary()), + httpproxy.WithVaultBinary(cfg.GetVaultBinary()), httpproxy.WithAuthToken(d.proxyAuthToken), httpproxy.WithRequireAuth(requireAuth), ) @@ -684,6 +700,7 @@ func (d *Daemon) handleAdminRequest(conn net.Conn, data []byte, cfg *config.Conf credentials.WithPassBinary(cfg.GetPassBinary()), credentials.WithOPBinary(cfg.GetOPBinary()), credentials.WithBWBinary(cfg.GetBWBinary()), + credentials.WithVaultBinary(cfg.GetVaultBinary()), credentials.WithBypassCache(), ) if err != nil || value == "" { diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 229a714..e1b22e7 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -676,4 +676,3 @@ tools: {} t.Fatalf("proxyAuthToken changed across disable/enable: got %q, want %q", d.proxyAuthToken, firstToken) } } - diff --git a/internal/daemon/executor.go b/internal/daemon/executor.go index 3cb9a6f..c60f6c8 100644 --- a/internal/daemon/executor.go +++ b/internal/daemon/executor.go @@ -280,6 +280,7 @@ func (e *ToolExecutor) buildEnvironment() ([]string, error) { credentials.WithPassBinary(e.cfg.GetPassBinary()), credentials.WithOPBinary(e.cfg.GetOPBinary()), credentials.WithBWBinary(e.cfg.GetBWBinary()), + credentials.WithVaultBinary(e.cfg.GetVaultBinary()), ) if err != nil { return "", err @@ -437,6 +438,7 @@ func (e *ToolExecutor) setupConfigFile() error { credentials.WithPassBinary(e.cfg.GetPassBinary()), credentials.WithOPBinary(e.cfg.GetOPBinary()), credentials.WithBWBinary(e.cfg.GetBWBinary()), + credentials.WithVaultBinary(e.cfg.GetVaultBinary()), ) if err != nil { return fmt.Errorf("fetch credential %s: %w", credName, err) diff --git a/internal/daemon/executor_env_test.go b/internal/daemon/executor_env_test.go index 203726c..a0e28a7 100644 --- a/internal/daemon/executor_env_test.go +++ b/internal/daemon/executor_env_test.go @@ -504,7 +504,7 @@ func TestBuildEnvironment_PTYFallbackTerm(t *testing.T) { UsePTY: true, }, tool: &config.ToolDef{}, - cfg: &config.Config{}, + cfg: &config.Config{}, } env, err := executor.buildEnvironment() @@ -529,7 +529,7 @@ func TestBuildEnvironment_NonPTYFallbackTerm(t *testing.T) { UsePTY: false, }, tool: &config.ToolDef{}, - cfg: &config.Config{}, + cfg: &config.Config{}, } env, err := executor.buildEnvironment() @@ -554,7 +554,7 @@ func TestBuildEnvironment_UpgradesInheritedDumbInPTY(t *testing.T) { UsePTY: true, }, tool: &config.ToolDef{}, - cfg: &config.Config{}, + cfg: &config.Config{}, } env, err := executor.buildEnvironment() @@ -582,7 +582,7 @@ func TestBuildEnvironment_ReqTermPreservedInPTY(t *testing.T) { }, }, tool: &config.ToolDef{}, - cfg: &config.Config{}, + cfg: &config.Config{}, } env, err := executor.buildEnvironment() diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 0bc748c..6ba383c 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -79,6 +79,13 @@ func WithBWBinary(path string) Option { } } +// WithVaultBinary sets the HashiCorp Vault CLI binary path. +func WithVaultBinary(path string) Option { + return func(p *Proxy) { + p.credOpts = append(p.credOpts, credentials.WithVaultBinary(path)) + } +} + // WithAuthToken sets the required proxy auth token. func WithAuthToken(token string) Option { return func(p *Proxy) { @@ -594,6 +601,7 @@ func SetupFromConfig(cfg *config.Config) (*Proxy, error) { WithPassBinary(cfg.GetPassBinary()), WithOPBinary(cfg.GetOPBinary()), WithBWBinary(cfg.GetBWBinary()), + WithVaultBinary(cfg.GetVaultBinary()), ) return proxy, nil From 4e9362a5bcb0ad0f92178de0028bf5b0c47275b4 Mon Sep 17 00:00:00 2001 From: Peter Dedene Date: Sat, 21 Feb 2026 15:04:12 +0100 Subject: [PATCH 2/2] =?UTF-8?q?fix(vault):=20address=20PR=20review=20?= =?UTF-8?q?=E2=80=94=20option=20injection,=20env=20override,=20tri-state?= =?UTF-8?q?=20skip=5Fverify?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `--` end-of-options guard before path argument (prevents option injection if path starts with `-`) - Rewrite vaultEnv() to strip managed VAULT_* keys before appending configured overrides, so config values actually take precedence - Change vault_skip_verify to *bool tri-state: nil inherits ambient env, explicit true/false overrides VAULT_SKIP_VERIFY --- internal/config/config.go | 10 ++-- internal/config/config_test.go | 44 ++++++++++------- internal/credentials/vault.go | 48 ++++++++++++++++--- internal/credentials/vault_test.go | 77 ++++++++++++++++++++++++++---- 4 files changed, 145 insertions(+), 34 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index fe250e1..5bfcbd4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -61,7 +61,7 @@ type ProxyConfig struct { CredentialCacheTTL string `yaml:"credential_cache_ttl"` // e.g., "30s" (0/empty disables) VaultBinary string `yaml:"vault_binary"` // e.g., "/usr/bin/vault" VaultAddr string `yaml:"vault_addr"` // e.g., "https://127.0.0.1:8200" - VaultSkipVerify bool `yaml:"vault_skip_verify"` // skip TLS verification + VaultSkipVerify *bool `yaml:"vault_skip_verify"` // skip TLS verification (nil = inherit env) VaultCACert string `yaml:"vault_cacert"` // e.g., "/path/to/ca.pem" VaultNamespace string `yaml:"vault_namespace"` // enterprise namespace VaultTokenFile string `yaml:"vault_token_file"` // override default ~/.vault-token @@ -853,8 +853,12 @@ func (c *Config) GetVaultAddr() string { } // GetVaultSkipVerify returns whether to skip Vault TLS verification. -func (c *Config) GetVaultSkipVerify() bool { - return c.Proxy != nil && c.Proxy.VaultSkipVerify +// nil = not configured (inherit ambient env), non-nil = explicit override. +func (c *Config) GetVaultSkipVerify() *bool { + if c.Proxy != nil { + return c.Proxy.VaultSkipVerify + } + return nil } // GetVaultCACert returns the Vault CA cert path (empty = use VAULT_CACERT env). diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 5330e90..2fdb554 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -702,22 +702,34 @@ func TestGetVaultAddr(t *testing.T) { } func TestGetVaultSkipVerify(t *testing.T) { - tests := []struct { - name string - cfg Config - want bool - }{ - {"nil proxy", Config{}, false}, - {"default false", Config{Proxy: &ProxyConfig{}}, false}, - {"configured true", Config{Proxy: &ProxyConfig{VaultSkipVerify: true}}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.cfg.GetVaultSkipVerify(); got != tt.want { - t.Errorf("GetVaultSkipVerify() = %v, want %v", got, tt.want) - } - }) - } + boolPtr := func(v bool) *bool { return &v } + + t.Run("nil proxy", func(t *testing.T) { + cfg := Config{} + if got := cfg.GetVaultSkipVerify(); got != nil { + t.Errorf("GetVaultSkipVerify() = %v, want nil", *got) + } + }) + t.Run("unset returns nil", func(t *testing.T) { + cfg := Config{Proxy: &ProxyConfig{}} + if got := cfg.GetVaultSkipVerify(); got != nil { + t.Errorf("GetVaultSkipVerify() = %v, want nil", *got) + } + }) + t.Run("explicit true", func(t *testing.T) { + cfg := Config{Proxy: &ProxyConfig{VaultSkipVerify: boolPtr(true)}} + got := cfg.GetVaultSkipVerify() + if got == nil || !*got { + t.Errorf("GetVaultSkipVerify() = %v, want *true", got) + } + }) + t.Run("explicit false", func(t *testing.T) { + cfg := Config{Proxy: &ProxyConfig{VaultSkipVerify: boolPtr(false)}} + got := cfg.GetVaultSkipVerify() + if got == nil || *got { + t.Errorf("GetVaultSkipVerify() = %v, want *false", got) + } + }) } func TestGetVaultCACert(t *testing.T) { diff --git a/internal/credentials/vault.go b/internal/credentials/vault.go index e1ca6a8..149e0af 100644 --- a/internal/credentials/vault.go +++ b/internal/credentials/vault.go @@ -19,7 +19,7 @@ const vaultCommandTimeout = 30 * time.Second var ( vaultMu sync.RWMutex vaultAddr string - vaultSkipVerify bool + vaultSkipVerify *bool vaultCACert string vaultNamespace string vaultTokenFile string @@ -33,7 +33,8 @@ func SetVaultAddr(addr string) { } // SetVaultSkipVerify configures TLS verification skip. -func SetVaultSkipVerify(skip bool) { +// nil = inherit ambient env, non-nil = override. +func SetVaultSkipVerify(skip *bool) { vaultMu.Lock() defer vaultMu.Unlock() vaultSkipVerify = skip @@ -60,7 +61,7 @@ func SetVaultTokenFile(path string) { vaultTokenFile = path } -func getVaultSettings() (addr string, skipVerify bool, caCert, namespace, tokenFile string) { +func getVaultSettings() (addr string, skipVerify *bool, caCert, namespace, tokenFile string) { vaultMu.RLock() defer vaultMu.RUnlock() return vaultAddr, vaultSkipVerify, vaultCACert, vaultNamespace, vaultTokenFile @@ -76,7 +77,7 @@ func fetchFromVault(ctx context.Context, parsed *ParsedSource, vaultBinaryOverri ctx, cancel := context.WithTimeout(ctx, vaultCommandTimeout) defer cancel() - cmd := exec.CommandContext(ctx, vaultBinary, "kv", "get", "-format=json", parsed.Path) + cmd := exec.CommandContext(ctx, vaultBinary, "kv", "get", "-format=json", "--", parsed.Path) cmd.Env = vaultEnv() output, err := cmd.Output() @@ -135,15 +136,48 @@ func resolveVaultBinary(vaultBinaryOverride string) (string, error) { } // vaultEnv returns environment variables for the vault CLI. +// Configured values override any ambient VAULT_* env vars; unconfigured +// fields (empty string / nil) inherit from the process environment. func vaultEnv() []string { - env := os.Environ() addr, skipVerify, caCert, namespace, tokenFile := getVaultSettings() + // Collect keys that will be overridden so we can strip them. + strip := make(map[string]bool) + if addr != "" { + strip["VAULT_ADDR"] = true + } + if skipVerify != nil { + strip["VAULT_SKIP_VERIFY"] = true + } + if caCert != "" { + strip["VAULT_CACERT"] = true + } + if namespace != "" { + strip["VAULT_NAMESPACE"] = true + } + if tokenFile != "" { + strip["VAULT_TOKEN_FILE"] = true + } + + // Copy ambient env, filtering out keys we're about to set. + env := make([]string, 0, len(os.Environ())+len(strip)) + for _, e := range os.Environ() { + key, _, _ := strings.Cut(e, "=") + if !strip[key] { + env = append(env, e) + } + } + + // Append configured overrides. if addr != "" { env = append(env, "VAULT_ADDR="+addr) } - if skipVerify { - env = append(env, "VAULT_SKIP_VERIFY=1") + if skipVerify != nil { + if *skipVerify { + env = append(env, "VAULT_SKIP_VERIFY=1") + } else { + env = append(env, "VAULT_SKIP_VERIFY=0") + } } if caCert != "" { env = append(env, "VAULT_CACERT="+caCert) diff --git a/internal/credentials/vault_test.go b/internal/credentials/vault_test.go index 40df17e..c80af90 100644 --- a/internal/credentials/vault_test.go +++ b/internal/credentials/vault_test.go @@ -178,6 +178,8 @@ func TestFetchFromVault_InvalidJSON(t *testing.T) { } } +func boolPtr(v bool) *bool { return &v } + func TestVaultEnv(t *testing.T) { // Save originals origAddr, origSkip, origCACert, origNs, origTokenFile := getVaultSettings() @@ -190,7 +192,7 @@ func TestVaultEnv(t *testing.T) { }() SetVaultAddr("https://vault.example.com:8200") - SetVaultSkipVerify(true) + SetVaultSkipVerify(boolPtr(true)) SetVaultCACert("/etc/vault/ca.pem") SetVaultNamespace("team-a") SetVaultTokenFile("/home/bot/.vault-token") @@ -219,7 +221,32 @@ func TestVaultEnv(t *testing.T) { } } -func TestVaultEnv_Empty(t *testing.T) { +func TestVaultEnv_ExplicitFalse(t *testing.T) { + origAddr, origSkip, origCACert, origNs, origTokenFile := getVaultSettings() + defer func() { + SetVaultAddr(origAddr) + SetVaultSkipVerify(origSkip) + SetVaultCACert(origCACert) + SetVaultNamespace(origNs) + SetVaultTokenFile(origTokenFile) + }() + + SetVaultSkipVerify(boolPtr(false)) + + env := vaultEnv() + found := false + for _, e := range env { + if e == "VAULT_SKIP_VERIFY=0" { + found = true + break + } + } + if !found { + t.Error("vaultEnv() should set VAULT_SKIP_VERIFY=0 when explicitly false") + } +} + +func TestVaultEnv_NilSkipVerify(t *testing.T) { origAddr, origSkip, origCACert, origNs, origTokenFile := getVaultSettings() defer func() { SetVaultAddr(origAddr) @@ -230,16 +257,15 @@ func TestVaultEnv_Empty(t *testing.T) { }() SetVaultAddr("") - SetVaultSkipVerify(false) + SetVaultSkipVerify(nil) SetVaultCACert("") SetVaultNamespace("") SetVaultTokenFile("") env := vaultEnv() - // Count how many times each vault key appears — our setters should add zero. - // Ambient env may already contain VAULT_* vars (e.g. from a real Vault session), - // so we compare counts before and after to isolate what vaultEnv() appended. + // With nil skip and empty strings, vaultEnv() should not add any VAULT_* keys. + // Ambient env may already have them, but vaultEnv() shouldn't strip or add. baseEnv := os.Environ() baseCount := make(map[string]int) vaultKeys := []string{"VAULT_ADDR", "VAULT_SKIP_VERIFY", "VAULT_CACERT", "VAULT_NAMESPACE", "VAULT_TOKEN_FILE"} @@ -261,9 +287,44 @@ func TestVaultEnv_Empty(t *testing.T) { } for _, key := range vaultKeys { - if envCount[key] > baseCount[key] { - t.Errorf("vaultEnv() should not append %s when empty (base=%d, got=%d)", key, baseCount[key], envCount[key]) + if envCount[key] != baseCount[key] { + t.Errorf("vaultEnv() should not change %s count when unconfigured (base=%d, got=%d)", key, baseCount[key], envCount[key]) + } + } +} + +func TestVaultEnv_OverridesAmbient(t *testing.T) { + origAddr, origSkip, origCACert, origNs, origTokenFile := getVaultSettings() + defer func() { + SetVaultAddr(origAddr) + SetVaultSkipVerify(origSkip) + SetVaultCACert(origCACert) + SetVaultNamespace(origNs) + SetVaultTokenFile(origTokenFile) + }() + + // Set ambient VAULT_ADDR, then override via config + t.Setenv("VAULT_ADDR", "https://old.example.com") + SetVaultAddr("https://new.example.com") + SetVaultSkipVerify(nil) + SetVaultCACert("") + SetVaultNamespace("") + SetVaultTokenFile("") + + env := vaultEnv() + + // Should have exactly one VAULT_ADDR with the new value + count := 0 + for _, e := range env { + if e == "VAULT_ADDR=https://new.example.com" { + count++ } + if e == "VAULT_ADDR=https://old.example.com" { + t.Error("vaultEnv() should have stripped ambient VAULT_ADDR") + } + } + if count != 1 { + t.Errorf("expected 1 VAULT_ADDR entry, got %d", count) } }