diff --git a/README.md b/README.md index 247de18..4580937 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Flags: --tls-cert string TLS certificate path --tls-key string TLS key path --tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2") + --trust-proxy trust proxy headers such as X-Forwarded-For (use when running behind a reverse proxy) -v, --version version for rest-server ``` diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index 12fd6b6..9fc392e 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -74,6 +74,7 @@ func newRestServerApp() *restServerApp { flags.BoolVar(&rv.Server.Prometheus, "prometheus", rv.Server.Prometheus, "enable Prometheus metrics") flags.BoolVar(&rv.Server.PrometheusNoAuth, "prometheus-no-auth", rv.Server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint") flags.BoolVar(&rv.Server.GroupAccessibleRepos, "group-accessible-repos", rv.Server.GroupAccessibleRepos, "let filesystem group be able to access repo files") + flags.BoolVar(&rv.Server.TrustProxy, "trust-proxy", rv.Server.TrustProxy, "trust proxy headers such as X-Forwarded-For (use when running behind a reverse proxy)") return rv } @@ -164,6 +165,12 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error { log.Println("Group accessible repos disabled") } + if app.Server.TrustProxy { + log.Println("Trust proxy headers enabled") + } else { + log.Println("Trust proxy headers disabled") + } + enabledTLS, privateKey, publicKey, err := app.tlsSettings() if err != nil { return err diff --git a/handlers.go b/handlers.go index 5938edd..67c93ff 100644 --- a/handlers.go +++ b/handlers.go @@ -35,6 +35,7 @@ type Server struct { PanicOnError bool NoVerifyUpload bool GroupAccessibleRepos bool + TrustProxy bool htpasswdFile *HtpasswdFile quotaManager *quota.Manager diff --git a/mux.go b/mux.go index 9c604b3..934bf39 100644 --- a/mux.go +++ b/mux.go @@ -21,6 +21,10 @@ func (s *Server) debugHandler(next http.Handler) http.Handler { }) } +func (s *Server) proxyHandler(next http.Handler) http.Handler { + return handlers.ProxyHeaders(next) +} + func (s *Server) logHandler(next http.Handler) http.Handler { var accessLog io.Writer @@ -111,6 +115,9 @@ func NewHandler(server *Server) (http.Handler, error) { if server.Debug { handler = server.debugHandler(handler) } + if server.TrustProxy { + handler = server.proxyHandler(handler) + } if server.Log != "" { handler = server.logHandler(handler) }