Skip to content

Commit 86e8d18

Browse files
authored
Add HTTP to HTTPS redirect and HSTS headers (#255)
1 parent e2c9b90 commit 86e8d18

4 files changed

Lines changed: 176 additions & 0 deletions

File tree

apiserver/internal/utils/middleware/middleware.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package middleware
22

33
import (
4+
"fmt"
45
"net/http"
56
"strconv"
7+
"strings"
68

79
"dkhalife.com/tasks/core/config"
810
"dkhalife.com/tasks/core/internal/services/logging"
@@ -43,6 +45,47 @@ func RateLimitMiddleware(limiter *limiter.Limiter) gin.HandlerFunc {
4345
}
4446
}
4547

48+
func effectiveScheme(c *gin.Context) string {
49+
if forwarded := c.GetHeader("X-Forwarded-Proto"); forwarded != "" {
50+
scheme := forwarded
51+
if i := strings.IndexByte(scheme, ','); i >= 0 {
52+
scheme = scheme[:i]
53+
}
54+
return strings.ToLower(strings.TrimSpace(scheme))
55+
}
56+
57+
if c.Request.TLS != nil {
58+
return "https"
59+
}
60+
61+
return "http"
62+
}
63+
64+
func SecurityHeaders(cfg *config.Config) gin.HandlerFunc {
65+
hostName := cfg.Server.HostName
66+
port := cfg.Server.Port
67+
68+
return func(c *gin.Context) {
69+
scheme := effectiveScheme(c)
70+
if scheme == "http" {
71+
target := fmt.Sprintf("https://%s", hostName)
72+
if port != 443 {
73+
target = fmt.Sprintf("%s:%d", target, port)
74+
}
75+
target = fmt.Sprintf("%s%s", target, c.Request.URL.RequestURI())
76+
c.Redirect(http.StatusMovedPermanently, target)
77+
c.Abort()
78+
return
79+
}
80+
81+
if scheme == "https" {
82+
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
83+
}
84+
85+
c.Next()
86+
}
87+
}
88+
4689
func RequestLogger() gin.HandlerFunc {
4790
return func(c *gin.Context) {
4891
c.Next()

apiserver/internal/utils/middleware/middleware_test.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"context"
5+
"crypto/tls"
56
"errors"
67
"net/http"
78
"net/http/httptest"
@@ -91,3 +92,128 @@ func (s *MiddlewareTestSuite) TestRateLimitMiddlewareStoreFailure() {
9192
s.router.ServeHTTP(w, req)
9293
s.Equal(http.StatusInternalServerError, w.Code)
9394
}
95+
96+
func (s *MiddlewareTestSuite) TestSecurityHeadersAddsHSTS() {
97+
cfg := &config.Config{
98+
Server: config.ServerConfig{
99+
HostName: "example.com",
100+
Port: 443,
101+
},
102+
}
103+
104+
s.router.Use(SecurityHeaders(cfg))
105+
s.router.GET("/", func(c *gin.Context) {
106+
c.String(http.StatusOK, "OK")
107+
})
108+
109+
w := httptest.NewRecorder()
110+
req, _ := http.NewRequest("GET", "/", nil)
111+
req.Header.Set("X-Forwarded-Proto", "https")
112+
s.router.ServeHTTP(w, req)
113+
s.Equal(http.StatusOK, w.Code)
114+
s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security"))
115+
}
116+
117+
func (s *MiddlewareTestSuite) TestSecurityHeadersNoHSTSForPlainHTTP() {
118+
cfg := &config.Config{
119+
Server: config.ServerConfig{
120+
HostName: "example.com",
121+
Port: 443,
122+
},
123+
}
124+
125+
s.router.Use(SecurityHeaders(cfg))
126+
s.router.GET("/", func(c *gin.Context) {
127+
c.String(http.StatusOK, "OK")
128+
})
129+
130+
w := httptest.NewRecorder()
131+
req, _ := http.NewRequest("GET", "/", nil)
132+
s.router.ServeHTTP(w, req)
133+
s.Equal(http.StatusMovedPermanently, w.Code)
134+
s.Empty(w.Header().Get("Strict-Transport-Security"))
135+
}
136+
137+
func (s *MiddlewareTestSuite) TestSecurityHeadersHSTSWithDirectTLS() {
138+
cfg := &config.Config{
139+
Server: config.ServerConfig{
140+
HostName: "example.com",
141+
Port: 443,
142+
},
143+
}
144+
145+
s.router.Use(SecurityHeaders(cfg))
146+
s.router.GET("/", func(c *gin.Context) {
147+
c.String(http.StatusOK, "OK")
148+
})
149+
150+
w := httptest.NewRecorder()
151+
req, _ := http.NewRequest("GET", "/", nil)
152+
req.TLS = &tls.ConnectionState{}
153+
s.router.ServeHTTP(w, req)
154+
s.Equal(http.StatusOK, w.Code)
155+
s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security"))
156+
}
157+
158+
func (s *MiddlewareTestSuite) TestSecurityHeadersRedirectsHTTP() {
159+
cfg := &config.Config{
160+
Server: config.ServerConfig{
161+
HostName: "example.com",
162+
Port: 443,
163+
},
164+
}
165+
166+
s.router.Use(SecurityHeaders(cfg))
167+
s.router.GET("/path", func(c *gin.Context) {
168+
c.String(http.StatusOK, "OK")
169+
})
170+
171+
w := httptest.NewRecorder()
172+
req, _ := http.NewRequest("GET", "/path?q=1", nil)
173+
req.Header.Set("X-Forwarded-Proto", "http")
174+
s.router.ServeHTTP(w, req)
175+
s.Equal(http.StatusMovedPermanently, w.Code)
176+
s.Equal("https://example.com/path?q=1", w.Header().Get("Location"))
177+
}
178+
179+
func (s *MiddlewareTestSuite) TestSecurityHeadersRedirectsHTTPNonStandardPort() {
180+
cfg := &config.Config{
181+
Server: config.ServerConfig{
182+
HostName: "example.com",
183+
Port: 8443,
184+
},
185+
}
186+
187+
s.router.Use(SecurityHeaders(cfg))
188+
s.router.GET("/", func(c *gin.Context) {
189+
c.String(http.StatusOK, "OK")
190+
})
191+
192+
w := httptest.NewRecorder()
193+
req, _ := http.NewRequest("GET", "/", nil)
194+
req.Header.Set("X-Forwarded-Proto", "http")
195+
s.router.ServeHTTP(w, req)
196+
s.Equal(http.StatusMovedPermanently, w.Code)
197+
s.Equal("https://example.com:8443/", w.Header().Get("Location"))
198+
}
199+
200+
func (s *MiddlewareTestSuite) TestSecurityHeadersNoRedirectForHTTPS() {
201+
cfg := &config.Config{
202+
Server: config.ServerConfig{
203+
HostName: "example.com",
204+
Port: 443,
205+
},
206+
}
207+
208+
s.router.Use(SecurityHeaders(cfg))
209+
s.router.GET("/", func(c *gin.Context) {
210+
c.String(http.StatusOK, "OK")
211+
})
212+
213+
w := httptest.NewRecorder()
214+
req, _ := http.NewRequest("GET", "/", nil)
215+
req.Header.Set("X-Forwarded-Proto", "https")
216+
s.router.ServeHTTP(w, req)
217+
s.Equal(http.StatusOK, w.Code)
218+
s.Equal("max-age=31536000; includeSubDomains; preload", w.Header().Get("Strict-Transport-Security"))
219+
}

apiserver/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func newServer(lc fx.Lifecycle, cfg *config.Config, db *gorm.DB, bgScheduler *sc
136136
corsCfg.AddAllowHeaders("Authorization")
137137
r.Use(cors.New(corsCfg))
138138
}
139+
r.Use(utils.SecurityHeaders(cfg))
139140
r.Use(utils.RequestLogger())
140141

141142
lc.Append(fx.Hook{

mcpserver/Program.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@
8080

8181
var app = builder.Build();
8282

83+
if (!app.Environment.IsDevelopment())
84+
{
85+
app.UseHsts();
86+
app.UseHttpsRedirection();
87+
}
88+
8389
app.UseAuthentication();
8490
app.UseAuthorization();
8591
app.MapMcp().RequireAuthorization();

0 commit comments

Comments
 (0)