From 4118e5fe12dc96224febf7363056d054214d6f81 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Thu, 30 Oct 2025 09:08:01 +0000 Subject: [PATCH] drpcyamux: add yamux multiplexing support for concurrent RPCs This package enables multiple concurrent RPCs over a single connection using HashiCorp's yamux multiplexer. It's based on https://gitea.elara.ws/Elara6331/drpc but almost rewritten to fix critical concurrency issues and resource leaks. Key improvements over the original: - Fixed race conditions on connection state using sync.Once and channels - Eliminated goroutine leaks in stream cleanup and server shutdown paths - Proper graceful shutdown with WaitGroups throughout the server stack - Thread-safe idempotent Close() on both client and server - Simplified error handling in session Accept loop - Context-aware shutdown that properly unblocks blocking operations The package provides two main components: - Conn: Client-side drpc.Conn implementation with yamux multiplexing - Server: Server that accepts yamux sessions and handles streams concurrently --- drpcyamux/conn.go | 134 ++++++++++++++++++++++++++++++++++++++++++++ drpcyamux/server.go | 128 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 +- go.sum | 2 + 4 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 drpcyamux/conn.go create mode 100644 drpcyamux/server.go diff --git a/drpcyamux/conn.go b/drpcyamux/conn.go new file mode 100644 index 0000000..cb67d15 --- /dev/null +++ b/drpcyamux/conn.go @@ -0,0 +1,134 @@ +// Copyright (C) 2023 Elara Musayelyan +// Copyright (C) 2025 Cockroach Labs +// See LICENSE for copying information. + +package drpcyamux + +import ( + "context" + "errors" + "io" + "sync" + + "github.com/hashicorp/yamux" + "storj.io/drpc" + "storj.io/drpc/drpcconn" +) + +var ErrClosed = errors.New("connection closed") + +var _ drpc.Conn = &Conn{} + +// Conn implements drpc.Conn using the yamux multiplexer to allow concurrent +// RPCs +type Conn struct { + conn io.ReadWriteCloser + sess *yamux.Session + + closeOnce sync.Once + closeErr error + closed chan struct{} +} + +// NewConn returns a new multiplexed DRPC connection as a client +func NewConn(conn io.ReadWriteCloser) (*Conn, error) { + return NewConnWithConfig(conn, nil) +} + +// NewConnWithConfig returns a new multiplexed DRPC connection as a client +// with the given yamux configuration +func NewConnWithConfig(conn io.ReadWriteCloser, config *yamux.Config) (*Conn, error) { + sess, err := yamux.Client(conn, config) + if err != nil { + return nil, err + } + + return &Conn{ + conn: conn, + sess: sess, + closed: make(chan struct{}), + }, nil +} + +// Close closes the multiplexer session and the underlying connection. It is +// safe to call Close multiple times. +func (c *Conn) Close() error { + c.closeOnce.Do(func() { + close(c.closed) + + // Close session first to stop accepting new streams + sessErr := c.sess.Close() + + // Always close the underlying connection + connErr := c.conn.Close() + + // Return the first error encountered + if sessErr != nil { + c.closeErr = sessErr + } else { + c.closeErr = connErr + } + }) + return c.closeErr +} + +// Closed returns a channel that will be closed +// when the connection is closed +func (c *Conn) Closed() <-chan struct{} { + return c.closed +} + +// Invoke issues the rpc on the transport serializing in, waits for a response, +// and deserializes it into out. +func (c *Conn) Invoke( + ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message, +) error { + select { + case <-c.closed: + return ErrClosed + default: + } + + stream, err := c.sess.Open() + if err != nil { + return err + } + defer stream.Close() + + dconn := drpcconn.New(stream) + defer dconn.Close() + + return dconn.Invoke(ctx, rpc, enc, in, out) +} + +// NewStream begins a streaming rpc on the connection. +func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { + select { + case <-c.closed: + return nil, ErrClosed + default: + } + + stream, err := c.sess.Open() + if err != nil { + return nil, err + } + + dconn := drpcconn.New(stream) + + s, err := dconn.NewStream(ctx, rpc, enc) + if err != nil { + dconn.Close() + stream.Close() + return nil, err + } + + // Clean up the yamux stream when the drpc connection closes. + // This goroutine will exit when dconn.Closed() is signaled. + go func() { + <-dconn.Closed() + stream.Close() + }() + + return s, nil +} diff --git a/drpcyamux/server.go b/drpcyamux/server.go new file mode 100644 index 0000000..501f4b0 --- /dev/null +++ b/drpcyamux/server.go @@ -0,0 +1,128 @@ +// Copyright (C) 2023 Elara Musayelyan +// Copyright (C) 2025 Cockroach Labs +// See LICENSE for copying information. + +package drpcyamux + +import ( + "context" + "crypto/tls" + "errors" + "net" + "sync" + + "github.com/hashicorp/yamux" + "storj.io/drpc" + "storj.io/drpc/drpcctx" + "storj.io/drpc/drpcserver" +) + +// Server is a DRPC server that handles multiplexed streams +type Server struct { + srv *drpcserver.Server +} + +// NewServer creates a new multiplexing DRPC server with default options +func NewServer(handler drpc.Handler) *Server { + return &Server{srv: drpcserver.New(handler)} +} + +// NewServerWithOptions creates a new multiplexing DRPC server with custom options +func NewServerWithOptions(handler drpc.Handler, opts drpcserver.Options) *Server { + return &Server{srv: drpcserver.NewWithOptions(handler, opts)} +} + +// Serve listens on the given listener and handles all multiplexed streams. +// It blocks until the context is canceled or an unrecoverable error occurs. +func (s *Server) Serve(ctx context.Context, ln net.Listener) error { + var wg sync.WaitGroup + defer wg.Wait() + + // Context for coordinating shutdown + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for { + conn, err := ln.Accept() + if err != nil { + // Check if we're shutting down + select { + case <-ctx.Done(): + return nil + default: + } + + // If listener was closed, treat it as shutdown + var opErr *net.OpError + if errors.As(err, &opErr) && opErr.Op == "accept" { + return nil + } + + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + s.handleConn(ctx, conn) + }() + } +} + +// handleConn processes a single connection with multiplexing +func (s *Server) handleConn(ctx context.Context, conn net.Conn) { + defer conn.Close() + + if tlsConn, ok := conn.(*tls.Conn); ok { + err := tlsConn.Handshake() + if err != nil { + return + } + state := tlsConn.ConnectionState() + if len(state.PeerCertificates) > 0 { + ctx = drpcctx.WithPeerConnectionInfo( + ctx, drpcctx.PeerConnectionInfo{Certificates: state.PeerCertificates}) + } + } + + sess, err := yamux.Server(conn, nil) + if err != nil { + return + } + defer sess.Close() + + s.handleSession(ctx, sess) +} + +// handleSession accepts and serves streams from a yamux session +func (s *Server) handleSession(ctx context.Context, sess *yamux.Session) { + var wg sync.WaitGroup + defer wg.Wait() + + // Close session when context is cancelled + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + sess.Close() + case <-done: + } + }() + + for { + stream, err := sess.Accept() + if err != nil { + // Any error from Accept means the session is done + // Common errors: io.EOF (graceful close), session closed, etc. + return + } + + wg.Add(1) + go func() { + defer wg.Done() + s.srv.ServeOne(ctx, stream) + }() + } +} diff --git a/go.mod b/go.mod index bbed14e..76ed515 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,11 @@ module storj.io/drpc -go 1.19 +go 1.23.0 + +toolchain go1.24.9 require ( + github.com/hashicorp/yamux v0.1.2 github.com/stretchr/testify v1.10.0 github.com/zeebo/assert v1.3.0 github.com/zeebo/errs v1.2.2 diff --git a/go.sum b/go.sum index 4655a0d..629138f 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= +github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=