@@ -119,19 +119,24 @@ func (s *Server) Serve() error {
119119 }
120120
121121 s .wg .Add (1 )
122- go func (netConn net.Conn ) {
123- defer s .wg .Done ()
124122
125- // Check connection limit after Accept to prevent TOCTOU race
126- if s .MaxConnections > 0 && s . connCount .Add (1 ) > int64 (s .MaxConnections ) {
123+ if s . MaxConnections > 0 {
124+ if s .connCount .Add (1 ) > int64 (s .MaxConnections ) {
127125 s .connCount .Add (- 1 )
128- netConn .Close ()
129- s .Logger .Warn ("Connection limit reached, rejecting connection from %s" , netConn .RemoteAddr ())
130- return
126+ s .wg .Done ()
127+ conn .Close ()
128+ s .Logger .Warn ("Connection limit reached, rejecting connection from %s" , conn .RemoteAddr ())
129+ continue
131130 }
131+ } else {
132+ s .connCount .Add (1 )
133+ }
134+
135+ go func (netConn net.Conn ) {
136+ defer s .wg .Done ()
137+ defer s .connCount .Add (- 1 )
132138
133139 s .handleConnectionInternal (netConn )
134- s .connCount .Add (- 1 )
135140 }(conn )
136141 }
137142}
@@ -151,13 +156,18 @@ func (s *Server) Shutdown(ctx context.Context) error {
151156
152157 // Close all active connections
153158 s .mu .RLock ()
159+ conns := make ([]* Connection , 0 , len (s .activeConns ))
154160 for conn := range s .activeConns {
161+ conns = append (conns , conn )
162+ }
163+ s .mu .RUnlock ()
164+
165+ for _ , conn := range conns {
155166 err := conn .Close ()
156167 if err != nil {
157168 return err
158169 }
159170 }
160- s .mu .RUnlock ()
161171
162172 // Run shutdown hooks
163173 s .mu .Lock ()
@@ -252,9 +262,9 @@ func (s *Server) handleConnectionInternal(netConn net.Conn) {
252262
253263 s .Logger .Debug ("Command from %s: %s %v" , netConn .RemoteAddr (), cmd .Name , cmd .Args )
254264
255- s .setConnectionActive (conn )
256-
265+ conn .setState (StateProcessing )
257266 response := s .handleCommand (conn , cmd )
267+ conn .setState (StateActive )
258268
259269 if s .WriteTimeout > 0 {
260270 err := netConn .SetWriteDeadline (time .Now ().Add (s .WriteTimeout ))
@@ -377,14 +387,14 @@ func (s *Server) checkIdleConnections() {
377387
378388 currentState := ConnState (conn .state .Load ())
379389
380- if currentState == StateActive && lastUsed .Before (idleThreshold ) {
390+ if ( currentState == StateActive || currentState == StateIdle ) && lastUsed .Before (idleThreshold ) {
381391 idleConns = append (idleConns , conn )
382392 }
383393 }
384394
385395 for _ , conn := range idleConns {
386- conn .setState ( StateIdle )
387- s . Logger . Debug ( "Connection %s marked as idle" , conn .RemoteAddr () )
396+ s . Logger . Info ( "Closing idle connection %s" , conn .RemoteAddr () )
397+ conn .Close ( )
388398 }
389399}
390400
0 commit comments