From 4fe1daa1cc23467316552806372bdeb4cfe846b5 Mon Sep 17 00:00:00 2001 From: Dany Khalife Date: Sat, 21 Mar 2026 17:22:30 -0700 Subject: [PATCH 1/2] redirect http and add hsts header --- .../internal/utils/middleware/middleware.go | 23 +++++ .../utils/middleware/middleware_test.go | 83 +++++++++++++++++++ apiserver/main.go | 1 + mcpserver/Program.cs | 2 + 4 files changed, 109 insertions(+) diff --git a/apiserver/internal/utils/middleware/middleware.go b/apiserver/internal/utils/middleware/middleware.go index 0d738e2d..24772cbc 100644 --- a/apiserver/internal/utils/middleware/middleware.go +++ b/apiserver/internal/utils/middleware/middleware.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "strconv" @@ -43,6 +44,28 @@ func RateLimitMiddleware(limiter *limiter.Limiter) gin.HandlerFunc { } } +func SecurityHeaders(cfg *config.Config) gin.HandlerFunc { + hostName := cfg.Server.HostName + port := cfg.Server.Port + + return func(c *gin.Context) { + proto := c.GetHeader("X-Forwarded-Proto") + if proto == "http" { + target := fmt.Sprintf("https://%s", hostName) + if port != 443 { + target = fmt.Sprintf("%s:%d", target, port) + } + target = fmt.Sprintf("%s%s", target, c.Request.URL.RequestURI()) + c.Redirect(http.StatusMovedPermanently, target) + c.Abort() + return + } + + c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + c.Next() + } +} + func RequestLogger() gin.HandlerFunc { return func(c *gin.Context) { c.Next() diff --git a/apiserver/internal/utils/middleware/middleware_test.go b/apiserver/internal/utils/middleware/middleware_test.go index dd61b210..0bad93ae 100644 --- a/apiserver/internal/utils/middleware/middleware_test.go +++ b/apiserver/internal/utils/middleware/middleware_test.go @@ -91,3 +91,86 @@ func (s *MiddlewareTestSuite) TestRateLimitMiddlewareStoreFailure() { s.router.ServeHTTP(w, req) s.Equal(http.StatusInternalServerError, w.Code) } + +func (s *MiddlewareTestSuite) TestSecurityHeadersAddsHSTS() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + s.router.ServeHTTP(w, req) + s.Equal(http.StatusOK, w.Code) + s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersRedirectsHTTP() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/path", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/path?q=1", nil) + req.Header.Set("X-Forwarded-Proto", "http") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusMovedPermanently, w.Code) + s.Equal("https://example.com/path?q=1", w.Header().Get("Location")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersRedirectsHTTPNonStandardPort() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 8443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", "http") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusMovedPermanently, w.Code) + s.Equal("https://example.com:8443/", w.Header().Get("Location")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersNoRedirectForHTTPS() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", "https") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusOK, w.Code) + s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) +} diff --git a/apiserver/main.go b/apiserver/main.go index d7482af3..28c44a92 100644 --- a/apiserver/main.go +++ b/apiserver/main.go @@ -136,6 +136,7 @@ func newServer(lc fx.Lifecycle, cfg *config.Config, db *gorm.DB, bgScheduler *sc corsCfg.AddAllowHeaders("Authorization") r.Use(cors.New(corsCfg)) } + r.Use(utils.SecurityHeaders(cfg)) r.Use(utils.RequestLogger()) lc.Append(fx.Hook{ diff --git a/mcpserver/Program.cs b/mcpserver/Program.cs index 27df8cc1..30aa4b14 100644 --- a/mcpserver/Program.cs +++ b/mcpserver/Program.cs @@ -80,6 +80,8 @@ var app = builder.Build(); +app.UseHsts(); +app.UseHttpsRedirection(); app.UseAuthentication(); app.UseAuthorization(); app.MapMcp().RequireAuthorization(); From a682482073cf03c3d74310903032cc1c14ca118b Mon Sep 17 00:00:00 2001 From: Dany Khalife Date: Sat, 21 Mar 2026 17:47:29 -0700 Subject: [PATCH 2/2] copilot feedback --- .../internal/utils/middleware/middleware.go | 26 +++++++++-- .../utils/middleware/middleware_test.go | 43 +++++++++++++++++++ mcpserver/Program.cs | 8 +++- 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/apiserver/internal/utils/middleware/middleware.go b/apiserver/internal/utils/middleware/middleware.go index 24772cbc..f95ef26d 100644 --- a/apiserver/internal/utils/middleware/middleware.go +++ b/apiserver/internal/utils/middleware/middleware.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "strconv" + "strings" "dkhalife.com/tasks/core/config" "dkhalife.com/tasks/core/internal/services/logging" @@ -44,13 +45,29 @@ func RateLimitMiddleware(limiter *limiter.Limiter) gin.HandlerFunc { } } +func effectiveScheme(c *gin.Context) string { + if forwarded := c.GetHeader("X-Forwarded-Proto"); forwarded != "" { + scheme := forwarded + if i := strings.IndexByte(scheme, ','); i >= 0 { + scheme = scheme[:i] + } + return strings.ToLower(strings.TrimSpace(scheme)) + } + + if c.Request.TLS != nil { + return "https" + } + + return "http" +} + func SecurityHeaders(cfg *config.Config) gin.HandlerFunc { hostName := cfg.Server.HostName port := cfg.Server.Port return func(c *gin.Context) { - proto := c.GetHeader("X-Forwarded-Proto") - if proto == "http" { + scheme := effectiveScheme(c) + if scheme == "http" { target := fmt.Sprintf("https://%s", hostName) if port != 443 { target = fmt.Sprintf("%s:%d", target, port) @@ -61,7 +78,10 @@ func SecurityHeaders(cfg *config.Config) gin.HandlerFunc { return } - c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + if scheme == "https" { + c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + } + c.Next() } } diff --git a/apiserver/internal/utils/middleware/middleware_test.go b/apiserver/internal/utils/middleware/middleware_test.go index 0bad93ae..dd7ce941 100644 --- a/apiserver/internal/utils/middleware/middleware_test.go +++ b/apiserver/internal/utils/middleware/middleware_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "crypto/tls" "errors" "net/http" "net/http/httptest" @@ -107,6 +108,48 @@ func (s *MiddlewareTestSuite) TestSecurityHeadersAddsHSTS() { w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", "https") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusOK, w.Code) + s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersNoHSTSForPlainHTTP() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + s.router.ServeHTTP(w, req) + s.Equal(http.StatusMovedPermanently, w.Code) + s.Empty(w.Header().Get("Strict-Transport-Security")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersHSTSWithDirectTLS() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.TLS = &tls.ConnectionState{} s.router.ServeHTTP(w, req) s.Equal(http.StatusOK, w.Code) s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) diff --git a/mcpserver/Program.cs b/mcpserver/Program.cs index 30aa4b14..2478f204 100644 --- a/mcpserver/Program.cs +++ b/mcpserver/Program.cs @@ -80,8 +80,12 @@ var app = builder.Build(); -app.UseHsts(); -app.UseHttpsRedirection(); +if (!app.Environment.IsDevelopment()) +{ + app.UseHsts(); + app.UseHttpsRedirection(); +} + app.UseAuthentication(); app.UseAuthorization(); app.MapMcp().RequireAuthorization();