Skip to content

Commit 2a4e6e5

Browse files
committed
allow localhost on dev api
1 parent 2c5b6a5 commit 2a4e6e5

4 files changed

Lines changed: 49 additions & 11 deletions

File tree

internal/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type Config struct {
1616
type ServerConfig struct {
1717
Host string `env:"SERVER_HOST" env-default:"0.0.0.0"`
1818
Port string `env:"SERVER_PORT" env-default:"8080"`
19-
AllowedOrigins []string `env:"ALLOWED_ORIGINS" env-default:"http://localhost:3000,http://localhost:8080" env-separator:","`
19+
AllowedOrigins []string `env:"ALLOWED_ORIGINS" env-default:"http://localhost:3000,http://localhost:8080,http://localhost:5173,http://127.0.0.1:3000,http://127.0.0.1:8080,http://127.0.0.1:5173" env-separator:","`
2020
}
2121

2222
type DatabaseConfig struct {

internal/middleware/cors.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,40 @@
11
package middleware
22

3-
import "net/http"
3+
import (
4+
"net/http"
5+
"strings"
6+
)
47

58
// CORS adds CORS headers for cross-origin requests
6-
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
9+
func CORS(allowedOrigins []string, isDev bool) func(http.Handler) http.Handler {
710
return func(next http.Handler) http.Handler {
811
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
912
origin := r.Header.Get("Origin")
1013
allowed := false
1114

12-
// Allow if origin is in allowlist
13-
for _, o := range allowedOrigins {
14-
if o == origin {
15-
allowed = true
16-
break
15+
// In development, allow any localhost or 127.0.0.1 origin
16+
isLocal := strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1")
17+
if isDev && isLocal {
18+
allowed = true
19+
}
20+
21+
// In production, explicitly block localhost/127.0.0.1 even if in allowlist
22+
if !isDev && isLocal {
23+
allowed = false
24+
} else if !allowed {
25+
// Allow if origin is in allowlist
26+
for _, o := range allowedOrigins {
27+
if o == origin {
28+
allowed = true
29+
break
30+
}
1731
}
1832
}
1933

2034
// If no allowed origins specified, allow all (development mode)
2135
// But for credentials to work with *, strict browsers block it.
2236
// So good practice: if development, echo back origin.
23-
if len(allowedOrigins) == 0 {
37+
if !allowed && len(allowedOrigins) == 0 {
2438
allowed = true
2539
}
2640

internal/middleware/cors_test.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ func TestCORS(t *testing.T) {
1717
origin string
1818
method string
1919
setupOrigins []string
20+
isDev bool
2021
expectedOrigin string
2122
expectedCreds string
2223
}{
@@ -25,6 +26,7 @@ func TestCORS(t *testing.T) {
2526
origin: "https://app.example.com",
2627
method: "GET",
2728
setupOrigins: allowedOrigins,
29+
isDev: false,
2830
expectedOrigin: "https://app.example.com",
2931
expectedCreds: "true",
3032
},
@@ -33,6 +35,7 @@ func TestCORS(t *testing.T) {
3335
origin: "https://evil.com",
3436
method: "GET",
3537
setupOrigins: allowedOrigins,
38+
isDev: false,
3639
expectedOrigin: "",
3740
expectedCreds: "",
3841
},
@@ -41,6 +44,7 @@ func TestCORS(t *testing.T) {
4144
origin: "",
4245
method: "GET",
4346
setupOrigins: allowedOrigins,
47+
isDev: false,
4448
expectedOrigin: "",
4549
expectedCreds: "",
4650
},
@@ -49,6 +53,7 @@ func TestCORS(t *testing.T) {
4953
origin: "https://app.example.com",
5054
method: "OPTIONS",
5155
setupOrigins: allowedOrigins,
56+
isDev: false,
5257
expectedOrigin: "https://app.example.com",
5358
expectedCreds: "true",
5459
},
@@ -57,14 +62,33 @@ func TestCORS(t *testing.T) {
5762
origin: "https://random.com",
5863
method: "GET",
5964
setupOrigins: []string{}, // Empty = allow all
65+
isDev: false, // Though typically true in dev
6066
expectedOrigin: "https://random.com",
6167
expectedCreds: "true",
6268
},
69+
{
70+
name: "AllowedLocalhostInDev",
71+
origin: "http://localhost:5173",
72+
method: "GET",
73+
setupOrigins: []string{"https://app.example.com"},
74+
isDev: true,
75+
expectedOrigin: "http://localhost:5173",
76+
expectedCreds: "true",
77+
},
78+
{
79+
name: "BlockedLocalhostInProd",
80+
origin: "http://localhost:3000",
81+
method: "GET",
82+
setupOrigins: []string{"http://localhost:3000", "https://app.example.com"},
83+
isDev: false,
84+
expectedOrigin: "",
85+
expectedCreds: "",
86+
},
6387
}
6488

6589
for _, tt := range tests {
6690
t.Run(tt.name, func(t *testing.T) {
67-
handler := middleware.CORS(tt.setupOrigins)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91+
handler := middleware.CORS(tt.setupOrigins, tt.isDev)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
6892
w.WriteHeader(http.StatusOK)
6993
}))
7094

internal/router/router.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func New(h *handler.Handler, queries database.Querier, jwtSecret string, allowed
2020
}
2121
r.Use(chimiddleware.Recoverer)
2222
r.Use(chimiddleware.RequestID)
23-
r.Use(middleware.CORS(allowedOrigins))
23+
r.Use(middleware.CORS(allowedOrigins, h.Config.Env == "development"))
2424

2525
// API routes
2626
r.Route("/api", func(r chi.Router) {

0 commit comments

Comments
 (0)