Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions drpcyamux/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (C) 2023 Elara Musayelyan
// Copyright (C) 2025 Cockroach Labs
// See LICENSE for copying information.

package drpcyamux

import (
"context"
"errors"
"io"
"sync"

"github.com/hashicorp/yamux"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
)

var ErrClosed = errors.New("connection closed")

var _ drpc.Conn = &Conn{}

// Conn implements drpc.Conn using the yamux multiplexer to allow concurrent
// RPCs
type Conn struct {
conn io.ReadWriteCloser
sess *yamux.Session

closeOnce sync.Once
closeErr error
closed chan struct{}
}

// NewConn returns a new multiplexed DRPC connection as a client
func NewConn(conn io.ReadWriteCloser) (*Conn, error) {
return NewConnWithConfig(conn, nil)
}

// NewConnWithConfig returns a new multiplexed DRPC connection as a client
// with the given yamux configuration
func NewConnWithConfig(conn io.ReadWriteCloser, config *yamux.Config) (*Conn, error) {
sess, err := yamux.Client(conn, config)
if err != nil {
return nil, err
}

return &Conn{
conn: conn,
sess: sess,
closed: make(chan struct{}),
}, nil
}

// Close closes the multiplexer session and the underlying connection. It is
// safe to call Close multiple times.
func (c *Conn) Close() error {
c.closeOnce.Do(func() {
close(c.closed)

// Close session first to stop accepting new streams
sessErr := c.sess.Close()

// Always close the underlying connection
connErr := c.conn.Close()

// Return the first error encountered
if sessErr != nil {
c.closeErr = sessErr
} else {
c.closeErr = connErr
}
})
return c.closeErr
}

// Closed returns a channel that will be closed
// when the connection is closed
func (c *Conn) Closed() <-chan struct{} {
return c.closed
}

// Invoke issues the rpc on the transport serializing in, waits for a response,
// and deserializes it into out.
func (c *Conn) Invoke(
ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message,
) error {
select {
case <-c.closed:
return ErrClosed
default:
}

stream, err := c.sess.Open()
if err != nil {
return err
}
defer stream.Close()

dconn := drpcconn.New(stream)
defer dconn.Close()

return dconn.Invoke(ctx, rpc, enc, in, out)
}

// NewStream begins a streaming rpc on the connection.
func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
select {
case <-c.closed:
return nil, ErrClosed
default:
}

stream, err := c.sess.Open()
if err != nil {
return nil, err
}

dconn := drpcconn.New(stream)

s, err := dconn.NewStream(ctx, rpc, enc)
if err != nil {
dconn.Close()
stream.Close()
return nil, err
}

// Clean up the yamux stream when the drpc connection closes.
// This goroutine will exit when dconn.Closed() is signaled.
go func() {
<-dconn.Closed()
stream.Close()
}()

return s, nil
}
128 changes: 128 additions & 0 deletions drpcyamux/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright (C) 2023 Elara Musayelyan
// Copyright (C) 2025 Cockroach Labs
// See LICENSE for copying information.

package drpcyamux

import (
"context"
"crypto/tls"
"errors"
"net"
"sync"

"github.com/hashicorp/yamux"
"storj.io/drpc"
"storj.io/drpc/drpcctx"
"storj.io/drpc/drpcserver"
)

// Server is a DRPC server that handles multiplexed streams
type Server struct {
srv *drpcserver.Server
}

// NewServer creates a new multiplexing DRPC server with default options
func NewServer(handler drpc.Handler) *Server {
return &Server{srv: drpcserver.New(handler)}
}

// NewServerWithOptions creates a new multiplexing DRPC server with custom options
func NewServerWithOptions(handler drpc.Handler, opts drpcserver.Options) *Server {
return &Server{srv: drpcserver.NewWithOptions(handler, opts)}
}

// Serve listens on the given listener and handles all multiplexed streams.
// It blocks until the context is canceled or an unrecoverable error occurs.
func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
var wg sync.WaitGroup
defer wg.Wait()

// Context for coordinating shutdown
ctx, cancel := context.WithCancel(ctx)
defer cancel()

for {
conn, err := ln.Accept()
if err != nil {
// Check if we're shutting down
select {
case <-ctx.Done():
return nil
default:
}

// If listener was closed, treat it as shutdown
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "accept" {
return nil
}

return err
}

wg.Add(1)
go func() {
defer wg.Done()
s.handleConn(ctx, conn)
}()
}
}

// handleConn processes a single connection with multiplexing
func (s *Server) handleConn(ctx context.Context, conn net.Conn) {
defer conn.Close()

if tlsConn, ok := conn.(*tls.Conn); ok {
err := tlsConn.Handshake()
if err != nil {
return
}
state := tlsConn.ConnectionState()
if len(state.PeerCertificates) > 0 {
ctx = drpcctx.WithPeerConnectionInfo(
ctx, drpcctx.PeerConnectionInfo{Certificates: state.PeerCertificates})
}
}

sess, err := yamux.Server(conn, nil)
if err != nil {
return
}
defer sess.Close()

s.handleSession(ctx, sess)
}

// handleSession accepts and serves streams from a yamux session
func (s *Server) handleSession(ctx context.Context, sess *yamux.Session) {
var wg sync.WaitGroup
defer wg.Wait()

// Close session when context is cancelled
done := make(chan struct{})
defer close(done)

go func() {
select {
case <-ctx.Done():
sess.Close()
case <-done:
}
}()

for {
stream, err := sess.Accept()
if err != nil {
// Any error from Accept means the session is done
// Common errors: io.EOF (graceful close), session closed, etc.
return
}

wg.Add(1)
go func() {
defer wg.Done()
s.srv.ServeOne(ctx, stream)
}()
}
}
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module storj.io/drpc

go 1.19
go 1.23.0
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update this once and for all in a separate PR.


toolchain go1.24.9

require (
github.com/hashicorp/yamux v0.1.2
github.com/stretchr/testify v1.10.0
github.com/zeebo/assert v1.3.0
github.com/zeebo/errs v1.2.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
Expand Down