diff --git a/apiserver/internal/utils/middleware/middleware.go b/apiserver/internal/utils/middleware/middleware.go index 0d738e2d..f95ef26d 100644 --- a/apiserver/internal/utils/middleware/middleware.go +++ b/apiserver/internal/utils/middleware/middleware.go @@ -1,8 +1,10 @@ package middleware import ( + "fmt" "net/http" "strconv" + "strings" "dkhalife.com/tasks/core/config" "dkhalife.com/tasks/core/internal/services/logging" @@ -43,6 +45,47 @@ func RateLimitMiddleware(limiter *limiter.Limiter) gin.HandlerFunc { } } +func effectiveScheme(c *gin.Context) string { + if forwarded := c.GetHeader("X-Forwarded-Proto"); forwarded != "" { + scheme := forwarded + if i := strings.IndexByte(scheme, ','); i >= 0 { + scheme = scheme[:i] + } + return strings.ToLower(strings.TrimSpace(scheme)) + } + + if c.Request.TLS != nil { + return "https" + } + + return "http" +} + +func SecurityHeaders(cfg *config.Config) gin.HandlerFunc { + hostName := cfg.Server.HostName + port := cfg.Server.Port + + return func(c *gin.Context) { + scheme := effectiveScheme(c) + if scheme == "http" { + target := fmt.Sprintf("https://%s", hostName) + if port != 443 { + target = fmt.Sprintf("%s:%d", target, port) + } + target = fmt.Sprintf("%s%s", target, c.Request.URL.RequestURI()) + c.Redirect(http.StatusMovedPermanently, target) + c.Abort() + return + } + + if scheme == "https" { + c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + } + + c.Next() + } +} + func RequestLogger() gin.HandlerFunc { return func(c *gin.Context) { c.Next() diff --git a/apiserver/internal/utils/middleware/middleware_test.go b/apiserver/internal/utils/middleware/middleware_test.go index dd61b210..dd7ce941 100644 --- a/apiserver/internal/utils/middleware/middleware_test.go +++ b/apiserver/internal/utils/middleware/middleware_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "crypto/tls" "errors" "net/http" "net/http/httptest" @@ -91,3 +92,128 @@ func (s *MiddlewareTestSuite) TestRateLimitMiddlewareStoreFailure() { s.router.ServeHTTP(w, req) s.Equal(http.StatusInternalServerError, w.Code) } + +func (s *MiddlewareTestSuite) TestSecurityHeadersAddsHSTS() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", "https") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusOK, w.Code) + s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersNoHSTSForPlainHTTP() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + s.router.ServeHTTP(w, req) + s.Equal(http.StatusMovedPermanently, w.Code) + s.Empty(w.Header().Get("Strict-Transport-Security")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersHSTSWithDirectTLS() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.TLS = &tls.ConnectionState{} + s.router.ServeHTTP(w, req) + s.Equal(http.StatusOK, w.Code) + s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersRedirectsHTTP() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/path", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/path?q=1", nil) + req.Header.Set("X-Forwarded-Proto", "http") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusMovedPermanently, w.Code) + s.Equal("https://example.com/path?q=1", w.Header().Get("Location")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersRedirectsHTTPNonStandardPort() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 8443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", "http") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusMovedPermanently, w.Code) + s.Equal("https://example.com:8443/", w.Header().Get("Location")) +} + +func (s *MiddlewareTestSuite) TestSecurityHeadersNoRedirectForHTTPS() { + cfg := &config.Config{ + Server: config.ServerConfig{ + HostName: "example.com", + Port: 443, + }, + } + + s.router.Use(SecurityHeaders(cfg)) + s.router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "OK") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", "https") + s.router.ServeHTTP(w, req) + s.Equal(http.StatusOK, w.Code) + s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security")) +} diff --git a/apiserver/main.go b/apiserver/main.go index d7482af3..28c44a92 100644 --- a/apiserver/main.go +++ b/apiserver/main.go @@ -136,6 +136,7 @@ func newServer(lc fx.Lifecycle, cfg *config.Config, db *gorm.DB, bgScheduler *sc corsCfg.AddAllowHeaders("Authorization") r.Use(cors.New(corsCfg)) } + r.Use(utils.SecurityHeaders(cfg)) r.Use(utils.RequestLogger()) lc.Append(fx.Hook{ diff --git a/mcpserver/Program.cs b/mcpserver/Program.cs index 27df8cc1..2478f204 100644 --- a/mcpserver/Program.cs +++ b/mcpserver/Program.cs @@ -80,6 +80,12 @@ var app = builder.Build(); +if (!app.Environment.IsDevelopment()) +{ + app.UseHsts(); + app.UseHttpsRedirection(); +} + app.UseAuthentication(); app.UseAuthorization(); app.MapMcp().RequireAuthorization();