@@ -10,6 +10,12 @@ import (
1010 "github.com/gorilla/websocket"
1111)
1212
13+ // Limits to prevent resource exhaustion
14+ const (
15+ MaxRooms = 1000
16+ MaxClientsPerRoom = 50
17+ )
18+
1319type Hub struct {
1420 mu sync.RWMutex
1521 rooms map [string ]map [string ]* Client
@@ -86,11 +92,8 @@ func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
8692 log .Printf ("signal: ws connected from %s" , r .RemoteAddr )
8793 client := & Client {conn : c , send : make (chan Message , 32 )}
8894 go client .writePump ()
89- defer func () {
90- h .removeClient (client )
91- close (client .send )
92- c .Close ()
93- }()
95+
96+ // readPump: blocks until read error (disconnect / protocol error)
9497 for {
9598 var msg Message
9699 if err := c .ReadJSON (& msg ); err != nil {
@@ -115,6 +118,14 @@ func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
115118 log .Printf ("signal: unknown msg type=%s room=%s from=%s" , msg .Type , msg .Room , msg .From )
116119 }
117120 }
121+
122+ // Cleanup: order matters to avoid goroutine leak and data race.
123+ // 1. Remove from hub first (prevents new messages being sent to client.send)
124+ h .removeClient (client )
125+ // 2. Close send channel (terminates writePump's range loop)
126+ close (client .send )
127+ // 3. Close WebSocket (writePump may still be draining; conn.Close is safe to call concurrently)
128+ c .Close ()
118129}
119130
120131func (h * Hub ) addClient (c * Client ) {
@@ -125,11 +136,21 @@ func (h *Hub) addClient(c *Client) {
125136 }
126137 m , ok := h .rooms [c .room ]
127138 if ! ok {
139+ // Enforce max rooms limit
140+ if len (h .rooms ) >= MaxRooms {
141+ log .Printf ("signal: room limit reached (%d), rejecting join room=%s id=%s" , MaxRooms , c .room , c .id )
142+ return
143+ }
128144 m = make (map [string ]* Client )
129145 h .rooms [c .room ] = m
130146 }
147+ // Enforce per-room client limit
148+ if len (m ) >= MaxClientsPerRoom {
149+ log .Printf ("signal: room %s full (%d clients), rejecting id=%s" , c .room , MaxClientsPerRoom , c .id )
150+ return
151+ }
131152 m [c .id ] = c
132- log .Printf ("signal: join room=%s id=%s" , c .room , c .id )
153+ log .Printf ("signal: join room=%s id=%s (room size: %d, total rooms: %d) " , c .room , c .id , len ( m ), len ( h . rooms ) )
133154 broadcastMembers (c .room , m )
134155}
135156
@@ -197,8 +218,11 @@ func (c *Client) writePump() {
197218 for msg := range c .send {
198219 if err := c .conn .WriteJSON (msg ); err != nil {
199220 log .Printf ("signal: write message error room=%s id=%s: %v" , c .room , c .id , err )
200- c .conn .Close ()
201- break
221+ // Don't close conn here — the read goroutine owns conn lifecycle.
222+ // Drain remaining messages from send channel so close(send) doesn't block.
223+ for range c .send {
224+ }
225+ return
202226 }
203227 }
204228}
0 commit comments