Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions drpcserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ type Server struct {
opts Options
handler drpc.Handler

mu sync.Mutex
stats map[string]*drpcstats.Stats
mu sync.Mutex
stats map[string]*drpcstats.Stats
listeners []net.Listener
conns map[net.Conn]struct{}
}

// New constructs a new Server.
Expand All @@ -72,6 +74,7 @@ func NewWithOptions(handler drpc.Handler, opts Options) *Server {
s := &Server{
opts: opts,
handler: handler,
conns: make(map[net.Conn]struct{}),
}

if s.opts.CollectStats {
Expand Down Expand Up @@ -168,6 +171,11 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) {
lis = tls.NewListener(lis, s.opts.TLSConfig)
}

// Track listeners we are serving on.
s.mu.Lock()
s.listeners = append(s.listeners, lis)
s.mu.Unlock()

tracker := drpcctx.NewTracker(ctx)
defer tracker.Wait()
defer tracker.Cancel()
Expand Down Expand Up @@ -203,8 +211,11 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) {
return errs.Wrap(err)
}

s.addConn(conn)

// TODO(jeff): connection limits?
tracker.Run(func(ctx context.Context) {
defer s.removeConn(conn)
err := s.ServeOne(ctx, conn)
if err != nil && s.opts.Log != nil {
s.opts.Log(err)
Expand All @@ -213,6 +224,37 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) {
}
}

// addConn registers a connection for tracking.
func (s *Server) addConn(conn net.Conn) {
s.mu.Lock()
s.conns[conn] = struct{}{}
s.mu.Unlock()
}

// removeConn unregisters a connection from tracking.
func (s *Server) removeConn(conn net.Conn) {
s.mu.Lock()
delete(s.conns, conn)
s.mu.Unlock()
}

// Stop closes all tracked listeners and forcefully closes all active
// connections.
func (s *Server) Stop() {
s.mu.Lock()
defer s.mu.Unlock()

for _, lis := range s.listeners {
_ = lis.Close()
}
s.listeners = nil

for conn := range s.conns {
_ = conn.Close()
}
s.conns = make(map[net.Conn]struct{})
}

// handleRPC handles the rpc that has been requested by the stream.
func (s *Server) handleRPC(stream *drpcstream.Stream, rpc string) (err error) {
err = s.handler.HandleRPC(stream, rpc)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module storj.io/drpc

go 1.23.0
go 1.25.5

require (
github.com/stretchr/testify v1.10.0
Expand Down
Loading