Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions router/router_server_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package router
import (
"context"
"encoding/json"
"net/http"
"time"

"emperror.dev/errors"
"github.com/gin-gonic/gin"
ws "github.com/gorilla/websocket"
"github.com/pterodactyl/wings/router/middleware"
"github.com/pterodactyl/wings/router/websocket"
"github.com/pterodactyl/wings/server"
"golang.org/x/time/rate"
)

var expectedCloseCodes = []int{
Expand All @@ -25,6 +28,27 @@ func getServerWebsocket(c *gin.Context) {
manager := middleware.ExtractManager(c)
s, _ := manager.Get(c.Param("server"))

// Limit the total number of websockets that can be opened at any one time for
// a server instance. This applies across all users connected to the server, and
// is not applied on a per-user basis.
//
// todo: it would be great to make this per-user instead, but we need to modify
// how we even request this endpoint in order for that to be possible. Some type
// of signed identifier in the URL that is verified on this end and set by the
// panel using a shared secret is likely the easiest option. The benefit of that
// is that we can both scope things to the user before authentication, and also
// verify that the JWT provided by the panel is assigned to the same user.
if s.Websockets().Len() >= 30 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": "Too many open websocket connections.",
})

return
}

c.Header("Content-Security-Policy", "default-src 'self'")
c.Header("X-Frame-Options", "DENY")

// Create a context that can be canceled when the user disconnects from this
// socket that will also cancel listeners running in separate threads. If the
// connection itself is terminated listeners using this context will also be
Expand Down Expand Up @@ -81,20 +105,46 @@ func getServerWebsocket(c *gin.Context) {
return
}

for {
j := websocket.Message{}
// There is a separate rate limiter that applies to individual message types
// within the actual websocket logic handler. _This_ rate limiter just exists
// to avoid enormous floods of data through the socket since we need to parse
// JSON each time. This rate limit realistically should never be hit since this
// would require sending 50+ messages a second over the websocket (no more than
// 10 per 200ms).
var throttled bool
rl := rate.NewLimiter(rate.Every(time.Millisecond*200), 10)

_, p, err := handler.Connection.ReadMessage()
for {
t, p, err := handler.Connection.ReadMessage()
if err != nil {
if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
handler.Logger().WithField("error", err).Warn("error handling websocket message for server")
}
break
}

if !rl.Allow() {
if !throttled {
throttled = true
_ = handler.Connection.WriteJSON(websocket.Message{Event: websocket.ThrottledEvent, Args: []string{"global"}})
}
continue
}

throttled = false

// If the message isn't a format we expect, or the length of the message is far larger
// than we'd ever expect, drop it. The websocket upgrader logic does enforce a maximum
// _compressed_ message size of 4Kb but that could decompress to a much larger amount
// of data.
if t != ws.TextMessage || len(p) > 32_768 {
continue
}

// Discard and JSON parse errors into the void and don't continue processing this
// specific socket request. If we did a break here the client would get disconnected
// from the socket, which is NOT what we want to do.
var j websocket.Message
if err := json.Unmarshal(p, &j); err != nil {
continue
}
Expand Down
91 changes: 91 additions & 0 deletions router/websocket/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package websocket

import (
"sync"
"time"

"golang.org/x/time/rate"
)

type LimiterBucket struct {
mu sync.RWMutex
limits map[Event]*rate.Limiter
throttles map[Event]bool
}

func (h *Handler) IsThrottled(e Event) bool {
l := h.limiter.For(e)

h.limiter.mu.Lock()
defer h.limiter.mu.Unlock()

if l.Allow() {
h.limiter.throttles[e] = false

return false
}

// If not allowed, track the throttling and send an event over the wire
// if one wasn't already sent in the same throttling period.
if v, ok := h.limiter.throttles[e]; !v || !ok {
h.limiter.throttles[e] = true
h.Logger().WithField("event", e).Debug("throttling websocket due to event volume")

_ = h.unsafeSendJson(&Message{Event: ThrottledEvent, Args: []string{string(e)}})
}

return true
}

func NewLimiter() *LimiterBucket {
return &LimiterBucket{
limits: make(map[Event]*rate.Limiter, 4),
throttles: make(map[Event]bool, 4),
}
}

// For returns the internal rate limiter for the given event type. In most
// cases this is a shared rate limiter for events, but certain "heavy" or low-frequency
// events implement their own limiters.
func (l *LimiterBucket) For(e Event) *rate.Limiter {
name := limiterName(e)

l.mu.RLock()
if v, ok := l.limits[name]; ok {
l.mu.RUnlock()
return v
}

l.mu.RUnlock()
l.mu.Lock()
defer l.mu.Unlock()

limit, burst := limitValuesFor(e)
l.limits[name] = rate.NewLimiter(limit, burst)

return l.limits[name]
}

// limitValuesFor returns the underlying limit and burst value for the given event.
func limitValuesFor(e Event) (rate.Limit, int) {
// Twice every five seconds.
if e == AuthenticationEvent || e == SendServerLogsEvent {
return rate.Every(time.Second * 5), 2
}

// 10 per second.
if e == SendCommandEvent {
return rate.Every(time.Second), 10
}

// 4 per second.
return rate.Every(time.Second), 4
}

func limiterName(e Event) Event {
if e == AuthenticationEvent || e == SendServerLogsEvent || e == SendCommandEvent {
return e
}

return "_default"
}
4 changes: 2 additions & 2 deletions router/websocket/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error {
continue
}
var sendErr error
message := Message{Event: e.Topic}
message := Message{Event: Event(e.Topic)}
if str, ok := e.Data.(string); ok {
message.Args = []string{str}
} else if b, ok := e.Data.([]byte); ok {
Expand All @@ -149,7 +149,7 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error {
continue
}
}
onError(message.Event, sendErr)
onError(string(message.Event), sendErr)
}
break
}
Expand Down
25 changes: 14 additions & 11 deletions router/websocket/message.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
package websocket

type Event string

const (
AuthenticationSuccessEvent = "auth success"
TokenExpiringEvent = "token expiring"
TokenExpiredEvent = "token expired"
AuthenticationEvent = "auth"
SetStateEvent = "set state"
SendServerLogsEvent = "send logs"
SendCommandEvent = "send command"
SendStatsEvent = "send stats"
ErrorEvent = "daemon error"
JwtErrorEvent = "jwt error"
AuthenticationSuccessEvent = Event("auth success")
TokenExpiringEvent = Event("token expiring")
TokenExpiredEvent = Event("token expired")
AuthenticationEvent = Event("auth")
SetStateEvent = Event("set state")
SendServerLogsEvent = Event("send logs")
SendCommandEvent = Event("send command")
SendStatsEvent = Event("send stats")
ErrorEvent = Event("daemon error")
JwtErrorEvent = Event("jwt error")
ThrottledEvent = Event("throttled")
)

type Message struct {
// The event to perform.
Event string `json:"event"`
Event Event `json:"event"`

// The data to pass along, only used by power/command currently. Other requests
// should either omit the field or pass an empty value as it is ignored.
Expand Down
23 changes: 16 additions & 7 deletions router/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ import (
"sync"
"time"

"github.com/pterodactyl/wings/internal/models"

"emperror.dev/errors"
"github.com/apex/log"
"github.com/gbrlsnchs/jwt/v3"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pterodactyl/wings/internal/models"

"github.com/pterodactyl/wings/system"

Expand Down Expand Up @@ -46,6 +45,7 @@ type Handler struct {
server *server.Server
ra server.RequestActivity
uuid uuid.UUID
limiter *LimiterBucket
}

var (
Expand Down Expand Up @@ -84,6 +84,7 @@ func NewTokenPayload(token []byte) (*tokens.WebsocketPayload, error) {
// GetHandler returns a new websocket handler using the context provided.
func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request, c *gin.Context) (*Handler, error) {
upgrader := websocket.Upgrader{
EnableCompression: true,
// Ensure that the websocket request is originating from the Panel itself,
// and not some other location.
CheckOrigin: func(r *http.Request) bool {
Expand All @@ -110,12 +111,16 @@ func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request, c *gin
return nil, err
}

conn.SetReadLimit(4096)
_ = conn.SetCompressionLevel(5)

return &Handler{
Connection: conn,
jwt: nil,
server: s,
ra: s.NewRequestActivity("", c.ClientIP()),
uuid: u,
limiter: NewLimiter(),
}, nil
}

Expand Down Expand Up @@ -150,7 +155,7 @@ func (h *Handler) SendJson(v Message) error {

// If the user does not have permission to see backup events, do not emit
// them over the socket.
if strings.HasPrefix(v.Event, server.BackupCompletedEvent) {
if strings.HasPrefix(string(v.Event), server.BackupCompletedEvent) {
if !j.HasPermission(PermissionReceiveBackups) {
return nil
}
Expand Down Expand Up @@ -277,6 +282,14 @@ func (h *Handler) setJwt(token *tokens.WebsocketPayload) {

// HandleInbound handles an inbound socket request and route it to the proper action.
func (h *Handler) HandleInbound(ctx context.Context, m Message) error {
if h.server.IsSuspended() {
return server.ErrSuspended
}

if h.IsThrottled(m.Event) {
return nil
}

if m.Event != AuthenticationEvent {
if err := h.TokenValid(); err != nil {
h.unsafeSendJson(Message{
Expand All @@ -287,10 +300,6 @@ func (h *Handler) HandleInbound(ctx context.Context, m Message) error {
}
}

if h.server.IsSuspended() {
return server.ErrSuspended
}

switch m.Event {
case AuthenticationEvent:
{
Expand Down
7 changes: 7 additions & 0 deletions server/websockets.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ func (s *Server) Websockets() *WebsocketBag {
return s.wsBag
}

func (w *WebsocketBag) Len() int {
w.mu.Lock()
defer w.mu.Unlock()

return len(w.conns)
}

// Push adds a new websocket connection to the end of the stack.
func (w *WebsocketBag) Push(u uuid.UUID, cancel *context.CancelFunc) {
w.mu.Lock()
Expand Down