diff --git a/drpcserver/server.go b/drpcserver/server.go index 13c8186..c023088 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -51,8 +51,10 @@ type Server struct { opts Options handler drpc.Handler - mu sync.Mutex - stats map[string]*drpcstats.Stats + mu sync.Mutex + stats map[string]*drpcstats.Stats + listeners []net.Listener + conns map[net.Conn]struct{} } // New constructs a new Server. @@ -72,6 +74,7 @@ func NewWithOptions(handler drpc.Handler, opts Options) *Server { s := &Server{ opts: opts, handler: handler, + conns: make(map[net.Conn]struct{}), } if s.opts.CollectStats { @@ -168,6 +171,11 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) { lis = tls.NewListener(lis, s.opts.TLSConfig) } + // Track listeners we are serving on. + s.mu.Lock() + s.listeners = append(s.listeners, lis) + s.mu.Unlock() + tracker := drpcctx.NewTracker(ctx) defer tracker.Wait() defer tracker.Cancel() @@ -203,8 +211,11 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) { return errs.Wrap(err) } + s.addConn(conn) + // TODO(jeff): connection limits? tracker.Run(func(ctx context.Context) { + defer s.removeConn(conn) err := s.ServeOne(ctx, conn) if err != nil && s.opts.Log != nil { s.opts.Log(err) @@ -213,6 +224,37 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) { } } +// addConn registers a connection for tracking. +func (s *Server) addConn(conn net.Conn) { + s.mu.Lock() + s.conns[conn] = struct{}{} + s.mu.Unlock() +} + +// removeConn unregisters a connection from tracking. +func (s *Server) removeConn(conn net.Conn) { + s.mu.Lock() + delete(s.conns, conn) + s.mu.Unlock() +} + +// Stop closes all tracked listeners and forcefully closes all active +// connections. +func (s *Server) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + + for _, lis := range s.listeners { + _ = lis.Close() + } + s.listeners = nil + + for conn := range s.conns { + _ = conn.Close() + } + s.conns = make(map[net.Conn]struct{}) +} + // handleRPC handles the rpc that has been requested by the stream. func (s *Server) handleRPC(stream *drpcstream.Stream, rpc string) (err error) { err = s.handler.HandleRPC(stream, rpc) diff --git a/go.mod b/go.mod index a66c0f9..d16fe49 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module storj.io/drpc -go 1.23.0 +go 1.25.5 require ( github.com/stretchr/testify v1.10.0