diff --git a/internal/credentials/config_test.go b/internal/credentials/config_test.go index a3378d7..eb0980f 100644 --- a/internal/credentials/config_test.go +++ b/internal/credentials/config_test.go @@ -188,7 +188,7 @@ func TestBuildFromConfig_Bearer(t *testing.T) { URL: &url.URL{Host: "api.openai.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req) { + if matched, injected := store.InjectCredentials(req); !matched || !injected { t.Error("should match api.openai.com") } if got := req.Header.Get("Authorization"); got != "Bearer sk-test-123" { @@ -216,7 +216,7 @@ func TestBuildFromConfig_APIKey(t *testing.T) { URL: &url.URL{Host: "api.anthropic.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req) { + if matched, injected := store.InjectCredentials(req); !matched || !injected { t.Error("should match api.anthropic.com") } if got := req.Header.Get("x-api-key"); got != "sk-ant-test" { @@ -244,7 +244,7 @@ func TestBuildFromConfig_GitHubBearer(t *testing.T) { URL: &url.URL{Host: "github.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req) { + if matched, injected := store.InjectCredentials(req); !matched || !injected { t.Error("should match github.com") } if got := req.Header.Get("Authorization"); got != "Bearer ghp_test" { @@ -256,7 +256,7 @@ func TestBuildFromConfig_GitHubBearer(t *testing.T) { URL: &url.URL{Host: "raw.githubusercontent.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req2) { + if matched, injected := store.InjectCredentials(req2); !matched || !injected { t.Error("should match raw.githubusercontent.com") } if got := req2.Header.Get("Authorization"); got != "Bearer ghp_test" { @@ -289,7 +289,7 @@ func TestBuildFromConfig_MissingEnvVarSkipped(t *testing.T) { URL: &url.URL{Host: "api.example.com"}, Header: make(http.Header), } - if store.InjectCredentials(req) { + if matched, _ := store.InjectCredentials(req); matched { t.Error("should not match when env var is unset") } } @@ -324,7 +324,7 @@ func TestBuildFromConfig_MultipleEntries(t *testing.T) { URL: &url.URL{Host: "api.a.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req1) { + if matched, injected := store.InjectCredentials(req1); !matched || !injected { t.Error("should match api.a.com") } if got := req1.Header.Get("Authorization"); got != "Bearer key-a" { @@ -336,7 +336,7 @@ func TestBuildFromConfig_MultipleEntries(t *testing.T) { URL: &url.URL{Host: "api.b.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req2) { + if matched, injected := store.InjectCredentials(req2); !matched || !injected { t.Error("should match api.b.com") } if got := req2.Header.Get("Authorization"); got != "Bearer key-b" { @@ -372,7 +372,7 @@ func TestBuildFromConfig_GCloudFromJSON(t *testing.T) { URL: &url.URL{Host: "storage.googleapis.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req) { + if matched, _ := store.InjectCredentials(req); !matched { t.Error("should match storage.googleapis.com with gcloud injector from JSON") } } @@ -403,7 +403,7 @@ func TestBuildFromConfig_GCloudJSONPreferredOverFile(t *testing.T) { URL: &url.URL{Host: "storage.googleapis.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req) { + if matched, _ := store.InjectCredentials(req); !matched { t.Error("should match storage.googleapis.com with gcloud injector from JSON") } } @@ -473,7 +473,7 @@ func TestBuildFromConfig_ExactAndSuffixDomains(t *testing.T) { URL: &url.URL{Host: "exact.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req1) { + if matched, injected := store.InjectCredentials(req1); !matched || !injected { t.Error("should match exact.com") } @@ -482,7 +482,7 @@ func TestBuildFromConfig_ExactAndSuffixDomains(t *testing.T) { URL: &url.URL{Host: "sub.suffix.com"}, Header: make(http.Header), } - if !store.InjectCredentials(req2) { + if matched, injected := store.InjectCredentials(req2); !matched || !injected { t.Error("should match sub.suffix.com") } @@ -491,7 +491,7 @@ func TestBuildFromConfig_ExactAndSuffixDomains(t *testing.T) { URL: &url.URL{Host: "other.com"}, Header: make(http.Header), } - if store.InjectCredentials(req3) { + if matched, _ := store.InjectCredentials(req3); matched { t.Error("should not match other.com") } } diff --git a/internal/credentials/gcloud.go b/internal/credentials/gcloud.go index e07a98a..79913c1 100644 --- a/internal/credentials/gcloud.go +++ b/internal/credentials/gcloud.go @@ -72,24 +72,25 @@ func (g *GCloudInjector) init() error { // Inject sets the Authorization: Bearer header with a fresh OAuth2 token. // Always overrides — the agent may have a token from a dummy ADC file. -func (g *GCloudInjector) Inject(req *http.Request) { +func (g *GCloudInjector) Inject(req *http.Request) bool { if err := g.init(); err != nil { log.Printf("ERROR gcloud credential init failed: %v", err) - return + return false } token, err := g.credentials.TokenSource.Token() if err != nil { log.Printf("ERROR gcloud token refresh failed: %v", err) - return + return false } if !token.Valid() { log.Printf("WARN gcloud token is invalid after refresh") - return + return false } req.Header.Set("Authorization", "Bearer "+token.AccessToken) + return true } // Available returns true if ADC credentials can be loaded (from JSON or file). diff --git a/internal/credentials/static.go b/internal/credentials/static.go index 205367b..493f673 100644 --- a/internal/credentials/static.go +++ b/internal/credentials/static.go @@ -14,8 +14,9 @@ type HeaderInjector struct { Value string } -func (h *HeaderInjector) Inject(req *http.Request) { +func (h *HeaderInjector) Inject(req *http.Request) bool { req.Header.Set(h.Header, h.Value) + return true } // BearerInjector injects an Authorization: Bearer header with a static token. @@ -24,8 +25,9 @@ type BearerInjector struct { Token string } -func (b *BearerInjector) Inject(req *http.Request) { +func (b *BearerInjector) Inject(req *http.Request) bool { req.Header.Set("Authorization", "Bearer "+b.Token) + return true } // APIKeyInjector injects a key into a custom header (e.g., x-api-key). @@ -35,6 +37,7 @@ type APIKeyInjector struct { Key string } -func (a *APIKeyInjector) Inject(req *http.Request) { +func (a *APIKeyInjector) Inject(req *http.Request) bool { req.Header.Set(a.HeaderName, a.Key) + return true } diff --git a/internal/credentials/store.go b/internal/credentials/store.go index e3e9a5d..3cf5ba3 100644 --- a/internal/credentials/store.go +++ b/internal/credentials/store.go @@ -10,8 +10,8 @@ import ( // Injector can inject credentials into an HTTP request. type Injector interface { // Inject adds credential headers to the request. - // It should only inject if the relevant header is not already present. - Inject(req *http.Request) + // Returns true if the credential was successfully set, false on error. + Inject(req *http.Request) bool } // Route maps a domain pattern to a credential injector. @@ -48,8 +48,10 @@ func (s *Store) AddRoute(route Route) { } // InjectCredentials finds the first matching route for the request's -// host and injects credentials. Returns true if credentials were injected. -func (s *Store) InjectCredentials(req *http.Request) bool { +// host and injects credentials. Returns (matched, injected) where matched +// indicates a route was found and injected indicates the credential was +// successfully set. +func (s *Store) InjectCredentials(req *http.Request) (bool, bool) { s.mu.RLock() defer s.mu.RUnlock() @@ -72,11 +74,15 @@ func (s *Store) InjectCredentials(req *http.Request) bool { } if matched { - route.Injector.Inject(req) - log.Printf("CREDENTIAL_INJECT host=%s pattern=%s method=%s path=%s", host, matchedPattern, req.Method, req.URL.Path) - return true + ok := route.Injector.Inject(req) + if ok { + log.Printf("CREDENTIAL_INJECT host=%s pattern=%s method=%s path=%s", host, matchedPattern, req.Method, req.URL.Path) + } else { + log.Printf("CREDENTIAL_INJECT_FAILED host=%s pattern=%s method=%s path=%s", host, matchedPattern, req.Method, req.URL.Path) + } + return true, ok } } - return false + return false, false } diff --git a/internal/credentials/store_test.go b/internal/credentials/store_test.go index f50053c..cf249dc 100644 --- a/internal/credentials/store_test.go +++ b/internal/credentials/store_test.go @@ -18,7 +18,8 @@ func TestStore_InjectCredentials_ExactDomain(t *testing.T) { Header: make(http.Header), } - if !store.InjectCredentials(req) { + matched, injected := store.InjectCredentials(req) + if !matched || !injected { t.Error("should match github.com") } if got := req.Header.Get("Authorization"); got != "Bearer ghp_test123" { @@ -38,7 +39,8 @@ func TestStore_InjectCredentials_DomainSuffix(t *testing.T) { Header: make(http.Header), } - if !store.InjectCredentials(req) { + matched, injected := store.InjectCredentials(req) + if !matched || !injected { t.Error("should match api.openai.com via suffix .openai.com") } if got := req.Header.Get("Authorization"); got != "Bearer sk-test" { @@ -58,7 +60,8 @@ func TestStore_InjectCredentials_NoMatch(t *testing.T) { Header: make(http.Header), } - if store.InjectCredentials(req) { + matched, _ := store.InjectCredentials(req) + if matched { t.Error("should not match evil.com") } if got := req.Header.Get("Authorization"); got != "" { @@ -80,7 +83,10 @@ func TestStore_InjectCredentials_AlwaysOverrides(t *testing.T) { // Agent sets a dummy/placeholder token req.Header.Set("Authorization", "Bearer paude-proxy-managed") - store.InjectCredentials(req) + matched, injected := store.InjectCredentials(req) + if !matched || !injected { + t.Error("should match and inject for api.openai.com") + } if got := req.Header.Get("Authorization"); got != "Bearer proxy-token" { t.Errorf("proxy should override agent's dummy token: got %q, want %q", got, "Bearer proxy-token") } @@ -90,7 +96,9 @@ func TestAPIKeyInjector(t *testing.T) { inj := &APIKeyInjector{HeaderName: "x-api-key", Key: "sk-ant-test"} req := &http.Request{Header: make(http.Header)} - inj.Inject(req) + if !inj.Inject(req) { + t.Error("Inject should return true") + } if got := req.Header.Get("x-api-key"); got != "sk-ant-test" { t.Errorf("x-api-key = %q, want %q", got, "sk-ant-test") } @@ -98,7 +106,9 @@ func TestAPIKeyInjector(t *testing.T) { // Should override existing (agent may have a dummy placeholder) req2 := &http.Request{Header: make(http.Header)} req2.Header.Set("x-api-key", "paude-proxy-managed") - inj.Inject(req2) + if !inj.Inject(req2) { + t.Error("Inject should return true") + } if got := req2.Header.Get("x-api-key"); got != "sk-ant-test" { t.Errorf("should override dummy key: got %q, want %q", got, "sk-ant-test") } @@ -120,8 +130,42 @@ func TestStore_FirstMatchWins(t *testing.T) { Header: make(http.Header), } - store.InjectCredentials(req) + matched, injected := store.InjectCredentials(req) + if !matched || !injected { + t.Error("should match and inject for api.openai.com") + } if got := req.Header.Get("Authorization"); got != "Bearer exact-token" { t.Errorf("first match should win: got %q", got) } } + +// failingInjector is a mock that always fails injection. +type failingInjector struct{} + +func (f *failingInjector) Inject(req *http.Request) bool { + return false +} + +func TestStore_InjectCredentials_InjectorFails(t *testing.T) { + store := NewStore() + store.AddRoute(Route{ + ExactDomain: "example.com", + Injector: &failingInjector{}, + }) + + req := &http.Request{ + URL: &url.URL{Host: "example.com"}, + Header: make(http.Header), + } + + matched, injected := store.InjectCredentials(req) + if !matched { + t.Error("should match example.com") + } + if injected { + t.Error("should not report injection as successful") + } + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization should be empty, got %q", got) + } +} diff --git a/internal/proxy/integration_test.go b/internal/proxy/integration_test.go index dbb622f..60fd96d 100644 --- a/internal/proxy/integration_test.go +++ b/internal/proxy/integration_test.go @@ -691,3 +691,59 @@ func TestIntegration_UntrustedUpstreamCert(t *testing.T) { // Connection error is also acceptable — proxy rejected the upstream t.Logf("got error (expected — proxy rejected untrusted upstream cert): %v", err) } + +// failingInjector always fails to inject credentials. +type failingInjector struct{} + +func (f *failingInjector) Inject(req *http.Request) bool { + return false +} + +func TestIntegration_CredentialInjectionFailure_Returns502(t *testing.T) { + skipIntegration(t) + + ca, err := GenerateCA() + if err != nil { + t.Fatalf("generate CA: %v", err) + } + + upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + upstreamURL, _ := url.Parse(upstream.URL) + upstreamHostname := upstreamURL.Hostname() + + df := filter.NewDomainFilter(upstreamHostname) + + store := credentials.NewStore() + store.AddRoute(credentials.Route{ + ExactDomain: upstreamHostname, + Injector: &failingInjector{}, + }) + + upstreamCAs := upstreamCertPool(t, upstream) + upstreamCert := upstream.TLS.Certificates[0] + upstreamCA, _ := x509.ParseCertificate(upstreamCert.Certificate[0]) + + proxyAddr, cleanup := startTestProxy(t, ca, df, store, nil, upstreamCAs) + defer cleanup() + + client := httpClientViaProxy(t, proxyAddr, ca.Certificate, upstreamCA) + + resp, err := client.Get(upstream.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Errorf("expected 502 Bad Gateway, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + if got := string(body); got != "Proxy credential injection failed" { + t.Errorf("expected injection failure message, got %q", got) + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 3680b03..7063639 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -311,7 +311,15 @@ func New(cfg Config) *http.Server { // Inject credentials for API requests if cfg.CredStore != nil { - cfg.CredStore.InjectCredentials(req) + matched, injected := cfg.CredStore.InjectCredentials(req) + if matched && !injected { + log.Printf("CREDENTIAL_INJECT_FAILED_502 method=%s host=%s path=%s", req.Method, req.URL.Host, req.URL.Path) + return req, goproxy.NewResponse(req, + goproxy.ContentTypeText, + http.StatusBadGateway, + "Proxy credential injection failed", + ) + } } // Suppress proxy identity headers