|
1 | 1 | package signal |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "net/http" |
4 | 5 | "testing" |
5 | 6 |
|
6 | 7 | "github.com/gorilla/websocket" |
7 | 8 | ) |
8 | 9 |
|
| 10 | +func TestIsOriginAllowed(t *testing.T) { |
| 11 | + tests := []struct { |
| 12 | + name string |
| 13 | + opts Options |
| 14 | + host string |
| 15 | + origin string |
| 16 | + allowed bool |
| 17 | + }{ |
| 18 | + {"no origin localhost", Options{}, "localhost:8080", "", true}, |
| 19 | + {"no origin 127.0.0.1", Options{}, "127.0.0.1:8080", "", true}, |
| 20 | + {"no origin external", Options{}, "example.com:8080", "", false}, |
| 21 | + {"origin localhost", Options{}, "localhost:8080", "http://localhost:8080", true}, |
| 22 | + {"origin 127.0.0.1", Options{}, "127.0.0.1:8080", "http://127.0.0.1:8080", true}, |
| 23 | + {"origin external blocked", Options{}, "example.com", "https://example.com", false}, |
| 24 | + {"allow all", Options{AllowAllOrigins: true}, "example.com", "https://evil.com", true}, |
| 25 | + {"whitelist match", Options{AllowedOrigins: []string{"https://example.com"}}, "", "https://example.com", true}, |
| 26 | + {"whitelist miss", Options{AllowedOrigins: []string{"https://example.com"}}, "", "https://evil.com", false}, |
| 27 | + } |
| 28 | + for _, tt := range tests { |
| 29 | + t.Run(tt.name, func(t *testing.T) { |
| 30 | + h := NewHubWithOptions(tt.opts) |
| 31 | + r := &http.Request{Host: tt.host, Header: http.Header{}} |
| 32 | + if tt.origin != "" { |
| 33 | + r.Header.Set("Origin", tt.origin) |
| 34 | + } |
| 35 | + if got := h.isOriginAllowed(r); got != tt.allowed { |
| 36 | + t.Errorf("isOriginAllowed() = %v, want %v", got, tt.allowed) |
| 37 | + } |
| 38 | + }) |
| 39 | + } |
| 40 | +} |
| 41 | + |
9 | 42 | func drainMessages(ch <-chan Message) { |
10 | 43 | for { |
11 | 44 | select { |
|
0 commit comments