Skip to content

Commit d1fb1b0

Browse files
feat(http): optional X-Forwarded-Host support for host authorization (#50)
Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
1 parent 2441cdc commit d1fb1b0

5 files changed

Lines changed: 146 additions & 28 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ agentapi server --allowed-hosts 'example.com,example.org' -- claude
102102
AGENTAPI_ALLOWED_HOSTS='example.com example.org' agentapi server -- claude
103103
```
104104

105+
If you're running behind a trusted reverse proxy that sets the `X-Forwarded-Host` header, you can opt in to using that header for host authorization with `--use-x-forwarded-host` (or `AGENTAPI_USE_X_FORWARDED_HOST=true`). When enabled, the server prefers the first `X-Forwarded-Host` value, and matches it against the allowed host list. Leave this disabled unless your deployment terminates at a trusted proxy.
106+
105107
#### Allowed origins
106108

107109
By default, the server allows CORS requests from `http://localhost:3284`, `http://localhost:3000`, and `http://localhost:3001`. If you'd like to change which origins can make cross-origin requests to AgentAPI, you can change this by using the `AGENTAPI_ALLOWED_ORIGINS` environment variable or the `--allowed-origins` flag.

cmd/server/server.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
9696
}
9797
port := viper.GetInt(FlagPort)
9898
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
99-
AgentType: agentType,
100-
Process: process,
101-
Port: port,
102-
ChatBasePath: viper.GetString(FlagChatBasePath),
103-
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
104-
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
99+
AgentType: agentType,
100+
Process: process,
101+
Port: port,
102+
ChatBasePath: viper.GetString(FlagChatBasePath),
103+
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
104+
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
105+
UseXForwardedHost: viper.GetBool(FlagUseXForwardedHost),
105106
})
106107
if err != nil {
107108
return xerrors.Errorf("failed to create server: %w", err)
@@ -155,15 +156,16 @@ type flagSpec struct {
155156
}
156157

157158
const (
158-
FlagType = "type"
159-
FlagPort = "port"
160-
FlagPrintOpenAPI = "print-openapi"
161-
FlagChatBasePath = "chat-base-path"
162-
FlagTermWidth = "term-width"
163-
FlagTermHeight = "term-height"
164-
FlagAllowedHosts = "allowed-hosts"
165-
FlagAllowedOrigins = "allowed-origins"
166-
FlagExit = "exit"
159+
FlagType = "type"
160+
FlagPort = "port"
161+
FlagPrintOpenAPI = "print-openapi"
162+
FlagChatBasePath = "chat-base-path"
163+
FlagTermWidth = "term-width"
164+
FlagTermHeight = "term-height"
165+
FlagAllowedHosts = "allowed-hosts"
166+
FlagAllowedOrigins = "allowed-origins"
167+
FlagUseXForwardedHost = "use-x-forwarded-host"
168+
FlagExit = "exit"
167169
)
168170

169171
func CreateServerCmd() *cobra.Command {
@@ -197,6 +199,7 @@ func CreateServerCmd() *cobra.Command {
197199
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
198200
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
199201
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
202+
{FlagUseXForwardedHost, "", false, "Use X-Forwarded-Host header for host authorization (behind trusted proxies)", "bool"},
200203
}
201204

202205
for _, spec := range flagSpecs {

cmd/server/server_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
157157
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
158158
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost", "127.0.0.1", "[::1]"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
159159
{"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
160+
{"use-x-forwarded-host default", FlagUseXForwardedHost, false, func() any { return viper.GetBool(FlagUseXForwardedHost) }},
160161
}
161162

162163
for _, tt := range tests {
@@ -191,6 +192,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
191192
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
192193
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost example.com", []string{"localhost", "example.com"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
193194
{"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
195+
{"AGENTAPI_USE_X_FORWARDED_HOST", "AGENTAPI_USE_X_FORWARDED_HOST", "true", true, func() any { return viper.GetBool(FlagUseXForwardedHost) }},
194196
}
195197

196198
for _, tt := range tests {
@@ -268,6 +270,13 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) {
268270
[]string{"https://cli-example.com"},
269271
func() any { return viper.GetStringSlice(FlagAllowedOrigins) },
270272
},
273+
{
274+
"use-x-forwarded-host: CLI overrides env",
275+
"AGENTAPI_USE_X_FORWARDED_HOST", "false",
276+
[]string{"--use-x-forwarded-host"},
277+
true,
278+
func() any { return viper.GetBool(FlagUseXForwardedHost) },
279+
},
271280
}
272281

273282
for _, tt := range tests {

lib/httpapi/server.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ func (s *Server) GetOpenAPI() string {
6262
const snapshotInterval = 25 * time.Millisecond
6363

6464
type ServerConfig struct {
65-
AgentType mf.AgentType
66-
Process *termexec.Process
67-
Port int
68-
ChatBasePath string
69-
AllowedHosts []string
70-
AllowedOrigins []string
65+
AgentType mf.AgentType
66+
Process *termexec.Process
67+
Port int
68+
ChatBasePath string
69+
AllowedHosts []string
70+
AllowedOrigins []string
71+
UseXForwardedHost bool
7172
}
7273

7374
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
@@ -176,7 +177,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
176177
badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
177178
http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest)
178179
})
179-
router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler))
180+
router.Use(hostAuthorizationMiddleware(allowedHosts, config.UseXForwardedHost, badHostHandler))
180181

181182
corsMiddleware := cors.New(cors.Options{
182183
AllowedOrigins: allowedOrigins,
@@ -229,8 +230,9 @@ func (s *Server) Handler() http.Handler {
229230

230231
// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed
231232
// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed.
232-
// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6).
233-
func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler {
233+
// If useXForwardedHost is true and the X-Forwarded-Host header is present, that header is used
234+
// as the source of host. Hostname is extracted via url.Parse to handle IPv6 and strip ports.
235+
func hostAuthorizationMiddleware(allowedHosts []string, useXForwardedHost bool, badHostHandler http.Handler) func(next http.Handler) http.Handler {
234236
// Copy for safety; also build a map for O(1) lookups with case-insensitive keys.
235237
allowed := make(map[string]struct{}, len(allowedHosts))
236238
for _, h := range allowedHosts {
@@ -243,13 +245,24 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand
243245
next.ServeHTTP(w, r)
244246
return
245247
}
246-
// Extract hostname from the Host header using url.Parse; ignore any port.
247-
hostHeader := r.Host
248-
if hostHeader == "" {
248+
// Choose header source
249+
rawHost := r.Host
250+
if useXForwardedHost {
251+
if xfhs := r.Header.Values("X-Forwarded-Host"); len(xfhs) > 0 {
252+
// Use the first value and trim anything after a comma
253+
h := xfhs[0]
254+
if idx := strings.IndexByte(h, ','); idx >= 0 {
255+
h = h[:idx]
256+
}
257+
rawHost = strings.TrimSpace(h)
258+
}
259+
}
260+
if rawHost == "" {
249261
badHostHandler.ServeHTTP(w, r)
250262
return
251263
}
252-
if u, err := url.Parse("http://" + hostHeader); err == nil {
264+
// Extract hostname via url.Parse; ignore any port.
265+
if u, err := url.Parse("http://" + rawHost); err == nil {
253266
hostname := u.Hostname()
254267
if _, ok := allowed[strings.ToLower(hostname)]; ok {
255268
next.ServeHTTP(w, r)

lib/httpapi/server_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,97 @@ func TestServer_AllowedHosts(t *testing.T) {
306306
}
307307
}
308308

309+
func TestServer_UseXForwardedHost(t *testing.T) {
310+
cases := []struct {
311+
name string
312+
allowedHosts []string
313+
useXForwardedHost bool
314+
hostHeader string
315+
xForwardedHostHeader string
316+
expectedStatusCode int
317+
expectedErrorMsg string
318+
}{
319+
{
320+
name: "disabled flag ignores X-Forwarded-Host",
321+
allowedHosts: []string{"app.example.com"},
322+
useXForwardedHost: false,
323+
hostHeader: "malicious.com",
324+
xForwardedHostHeader: "app.example.com",
325+
expectedStatusCode: http.StatusBadRequest,
326+
expectedErrorMsg: "Invalid host header. Allowed hosts: app.example.com",
327+
},
328+
{
329+
name: "enabled flag uses X-Forwarded-Host",
330+
allowedHosts: []string{"app.example.com"},
331+
useXForwardedHost: true,
332+
hostHeader: "malicious.com",
333+
xForwardedHostHeader: "app.example.com",
334+
expectedStatusCode: http.StatusOK,
335+
},
336+
{
337+
name: "enabled with port in X-Forwarded-Host",
338+
allowedHosts: []string{"app.example.com"},
339+
useXForwardedHost: true,
340+
hostHeader: "malicious.com",
341+
xForwardedHostHeader: "app.example.com:443",
342+
expectedStatusCode: http.StatusOK,
343+
},
344+
{
345+
name: "enabled with IPv6 literal in X-Forwarded-Host",
346+
allowedHosts: []string{"2001:db8::1"},
347+
useXForwardedHost: true,
348+
hostHeader: "malicious.com",
349+
xForwardedHostHeader: "[2001:db8::1]:8443",
350+
expectedStatusCode: http.StatusOK,
351+
},
352+
{
353+
name: "enabled with comma-separated X-Forwarded-Host takes first",
354+
allowedHosts: []string{"first.example.com"},
355+
useXForwardedHost: true,
356+
hostHeader: "malicious.com",
357+
xForwardedHostHeader: "first.example.com, other.example.com",
358+
expectedStatusCode: http.StatusOK,
359+
},
360+
}
361+
362+
for _, tc := range cases {
363+
t.Run(tc.name, func(t *testing.T) {
364+
t.Parallel()
365+
ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil)))
366+
s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
367+
AgentType: msgfmt.AgentTypeClaude,
368+
Process: nil,
369+
Port: 0,
370+
ChatBasePath: "/chat",
371+
AllowedHosts: tc.allowedHosts,
372+
AllowedOrigins: []string{"https://example.com"}, // isolate
373+
UseXForwardedHost: tc.useXForwardedHost,
374+
})
375+
require.NoError(t, err)
376+
tsServer := httptest.NewServer(s.Handler())
377+
t.Cleanup(tsServer.Close)
378+
379+
req, err := http.NewRequest("GET", tsServer.URL+"/status", nil)
380+
require.NoError(t, err)
381+
if tc.hostHeader != "" {
382+
req.Host = tc.hostHeader
383+
}
384+
if tc.xForwardedHostHeader != "" {
385+
req.Header.Set("X-Forwarded-Host", tc.xForwardedHostHeader)
386+
}
387+
388+
resp, err := (&http.Client{}).Do(req)
389+
require.NoError(t, err)
390+
t.Cleanup(func() { _ = resp.Body.Close() })
391+
require.Equal(t, tc.expectedStatusCode, resp.StatusCode)
392+
if tc.expectedErrorMsg != "" {
393+
b, _ := io.ReadAll(resp.Body)
394+
require.Contains(t, string(b), tc.expectedErrorMsg)
395+
}
396+
})
397+
}
398+
}
399+
309400
func TestServer_CORSPreflightWithHosts(t *testing.T) {
310401
cases := []struct {
311402
name string

0 commit comments

Comments
 (0)