Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions actions/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,38 @@ func NewApp(conf Config) (*App, error) {
SSLRedirect: true,
SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"},
}).Handler)
router.Use(cors.Handler(cors.Options{
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
AllowedOrigins: []string{"https://*", "http://*"},
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },

// Configure CORS FIRST so headers are present even when CSRF protection blocks requests
corsOptions := cors.Options{
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: false,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
AllowCredentials: true, // Required for session cookies in cross-origin requests
MaxAge: 300, // Maximum value not ignored by any of major browsers
}

if conf.DeployEnv.IsProduction() {
// In production, only allow specific trusted origins
corsOptions.AllowedOrigins = []string{conf.SiteURL}
} else {
// In development, allow any origin (needed for testing CSRF protection)
corsOptions.AllowOriginFunc = func(r *http.Request, origin string) bool { return true }
}

router.Use(cors.Handler(corsOptions))

// Configure cross-origin protection (CSRF defense) AFTER CORS
crossOriginProtection := http.NewCrossOriginProtection()
if conf.DeployEnv.IsProduction() {
// In production, only trust requests from SITE_URL
if err := crossOriginProtection.AddTrustedOrigin(conf.SiteURL); err != nil {
return nil, fmt.Errorf("could not add trusted origin: %w", err)
}
}
// In development, the zero-value CrossOriginProtection allows all origins
router.Use(func(next http.Handler) http.Handler {
return crossOriginProtection.Handler(next)
})
router.Use(kbsession.NewMiddleware(sessionStore))

if err := app.defineRoutes(router); err != nil {
Expand Down
19 changes: 13 additions & 6 deletions actions/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ func TestMain(m *testing.M) {
}

type Fixture struct {
t *testing.T
App *App
Client *kbhttp.Client
t *testing.T
App *App
Client *kbhttp.Client
BaseURL string
}

// NewFixture starts a local test server and returns it along with a cleanup function that should be deferred.
Expand All @@ -69,12 +70,18 @@ func NewFixture(t *testing.T) *Fixture {
require.Nil(t, err)

return &Fixture{
t: t,
App: app,
Client: kbhttp.NewClient(kbhttp.ClientConfig{BaseURL: baseURL}),
t: t,
App: app,
Client: kbhttp.NewClient(kbhttp.ClientConfig{BaseURL: baseURL}),
BaseURL: baseURL.String(),
}
}

func (f *Fixture) Cleanup() {
assert.Nil(f.t, f.App.Stop(context.Background()))
}

// URL returns the full URL for the given path.
func (f *Fixture) URL(path string) string {
return f.BaseURL + path
}
201 changes: 201 additions & 0 deletions actions/csrf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package actions

import (
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCSRFProtection_BlocksUntrustedOrigins(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

// Try to POST from an untrusted origin
body := strings.NewReader("name=Evil+User&email=evil@example.com")
req, err := http.NewRequest("POST", f.URL("/users"), body)
require.NoError(t, err)

req.Header.Set("Origin", "https://evil.com")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// In development mode, CrossOriginProtection allows all origins (zero-value behavior)
// In production mode with trusted origins configured, this should be blocked
if f.App.conf.DeployEnv.IsProduction() {
assert.Equal(t, http.StatusForbidden, resp.StatusCode,
"Should block requests from untrusted origins in production")
}
}

func TestCSRFProtection_AllowsTrustedOrigins(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

// POST with the trusted origin (SITE_URL)
body := strings.NewReader("name=Test+User&email=test@example.com")
req, err := http.NewRequest("POST", f.URL("/users"), body)
require.NoError(t, err)

req.Header.Set("Origin", conf.SiteURL)
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reference to undefined variable 'conf'. This should likely be 'fix.App.conf.SiteURL' to access the configuration from the test fixture.

Copilot uses AI. Check for mistakes.
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Should either succeed (200/201) or fail for auth reasons, but NOT forbidden
assert.NotEqual(t, http.StatusForbidden, resp.StatusCode,
"Should allow requests from trusted origins")
}

func TestCSRFProtection_AllowsRefererHeader(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

// POST with valid Referer header (used when Origin is not present)
body := strings.NewReader("name=Test+User&email=test@example.com")
req, err := http.NewRequest("POST", f.URL("/users"), body)
require.NoError(t, err)

req.Header.Set("Referer", conf.SiteURL+"/users/new")
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reference to undefined variable 'conf'. This should likely be 'fix.App.conf.SiteURL' to access the configuration from the test fixture.

Copilot uses AI. Check for mistakes.
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.NotEqual(t, http.StatusForbidden, resp.StatusCode,
"Should allow requests with valid Referer header")
}

func TestCSRFProtection_BlocksInvalidReferer(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

// POST with invalid Referer header
body := strings.NewReader("name=Evil+User&email=evil@example.com")
req, err := http.NewRequest("POST", f.URL("/users"), body)
require.NoError(t, err)

req.Header.Set("Referer", "https://evil.com/attack")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

if f.App.conf.DeployEnv.IsProduction() {
assert.Equal(t, http.StatusForbidden, resp.StatusCode,
"Should block requests with invalid Referer in production")
}
}

func TestCSRFProtection_AllowsSafeMethodsWithoutOrigin(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

safeMethods := []string{"GET", "HEAD", "OPTIONS"}

for _, method := range safeMethods {
t.Run(method, func(t *testing.T) {
// Safe methods should work without Origin or Referer headers
req, err := http.NewRequest(method, f.URL("/"), nil)
require.NoError(t, err)

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Should not be blocked by CSRF protection
assert.NotEqual(t, http.StatusForbidden, resp.StatusCode,
"%s requests should not be blocked by CSRF protection", method)
})
}
}

func TestCSRFProtection_ProtectsAllMutatingEndpoints(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

tests := []struct {
method string
path string
}{
{"POST", "/users"},
{"PUT", "/users/1"},
{"DELETE", "/users/1"},
{"POST", "/users/1/update"},
{"POST", "/users/1/delete"},
}

for _, tt := range tests {
t.Run(tt.method+" "+tt.path, func(t *testing.T) {
// Try to perform state-changing operation from untrusted origin
body := strings.NewReader("name=Hacker&email=hacker@example.com")
req, err := http.NewRequest(tt.method, f.URL(tt.path), body)
require.NoError(t, err)

req.Header.Set("Origin", "https://attacker.com")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

if f.App.conf.DeployEnv.IsProduction() {
assert.Equal(t, http.StatusForbidden, resp.StatusCode,
"%s %s should be protected from cross-origin requests in production",
tt.method, tt.path)
}
})
}
}

func TestCSRFProtection_AllowsSameOriginRequests(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

// Use the same origin as the test server
body := strings.NewReader("name=Same+Origin+User&email=same@example.com")
req, err := http.NewRequest("POST", f.URL("/users"), body)
require.NoError(t, err)

req.Header.Set("Origin", f.BaseURL)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Should not be blocked (though may fail auth check)
assert.NotEqual(t, http.StatusForbidden, resp.StatusCode,
"Same-origin requests should not be blocked by CSRF protection")
}

func TestCSRFProtection_BlocksMissingOriginAndReferer(t *testing.T) {
f := NewFixture(t)
defer f.Cleanup()

// State-changing request without Origin or Referer headers
body := strings.NewReader("name=No+Origin&email=noorigin@example.com")
req, err := http.NewRequest("POST", f.URL("/users"), body)
require.NoError(t, err)

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Explicitly NOT setting Origin or Referer

resp, err := f.Client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Behavior depends on environment and CrossOriginProtection configuration
// In strict production mode, this might be blocked
// Document the actual behavior observed
t.Logf("Status code for request without Origin/Referer: %d", resp.StatusCode)
}
8 changes: 1 addition & 7 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module github.com/katabole/kbexample

go 1.24.0

toolchain go1.24.1
go 1.25.4

require (
github.com/elnormous/contenttype v1.0.4
Expand All @@ -29,7 +27,6 @@ require (
cloud.google.com/go/compute/metadata v0.8.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/golang/protobuf v1.3.1 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
Expand All @@ -39,13 +36,10 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.42.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/oauth2 v0.31.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/text v0.29.0 // indirect
google.golang.org/appengine v1.6.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading