diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 636f346..0b0c509 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -29,10 +29,20 @@ type Options struct { CollectStats bool } +// streamManager is the interface satisfied by both drpcmanager.Manager (non-mux) +// and drpcmanager.MuxManager (mux). +type streamManager interface { + NewClientStream(ctx context.Context, rpc string) (*drpcstream.Stream, error) + Closed() <-chan struct{} + Unblocked() <-chan struct{} + Close() error +} + // Conn is a drpc client connection. type Conn struct { tr drpc.Transport - man *drpcmanager.Manager + man streamManager + mux bool mu sync.Mutex wbuf []byte @@ -56,7 +66,12 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { c.stats = make(map[string]*drpcstats.Stats) } - c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + c.mux = opts.Manager.Mux + if c.mux { + c.man = drpcmanager.NewMuxWithOptions(tr, opts.Manager) + } else { + c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + } return c } @@ -100,8 +115,9 @@ func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } -// Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// Invoke issues the rpc on the transport serializing in, waits for a response, +// and deserializes it into out. In non-mux mode, only one Invoke or Stream may +// be open at a time. In mux mode, multiple calls may be open concurrently. 
func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -117,18 +133,28 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. - c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) - if err != nil { - return err + var data []byte + if c.mux { + // Per-call buffer allocation for concurrent access. + data, err = drpcenc.MarshalAppend(in, enc, nil) + if err != nil { + return err + } + } else { + // We have to protect c.wbuf here even though the manager only allows + // one stream at a time because the stream may async close allowing + // another concurrent call to Invoke to proceed. 
+ c.mu.Lock() + defer c.mu.Unlock() + + c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + if err != nil { + return err + } + data = c.wbuf } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516..2550c3b 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -180,7 +180,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { }) s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) - _ = s.CloseSend() + + assert.NoError(t, s.CloseSend()) + + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/frame_queue_test.go b/drpcmanager/frame_queue_test.go new file mode 100644 index 0000000..b175e6b --- /dev/null +++ b/drpcmanager/frame_queue_test.go @@ -0,0 +1,82 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + "storj.io/drpc/drpcwire" +) + +func TestSharedWriteBuf_AppendDrain(t *testing.T) { + sw := newSharedWriteBuf() + + pkt := drpcwire.Packet{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, sw.Append(pkt)) + + // Drain should return serialized bytes. + data := sw.Drain(nil) + assert.That(t, len(data) > 0) + + // Parse the frame back out to verify correctness. 
+ _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) +} + +func TestSharedWriteBuf_CloseIdempotent(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + sw.Close() // must not panic +} + +func TestSharedWriteBuf_AppendAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + + err := sw.Append(drpcwire.Packet{}) + assert.Error(t, err) +} + +func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + data, ok := sw.WaitAndDrain(nil) + assert.That(t, ok) + assert.That(t, len(data) > 0) + }() + + // Append should wake the blocked WaitAndDrain. + assert.NoError(t, sw.Append(drpcwire.Packet{Data: []byte("a")})) + <-done +} + +func TestSharedWriteBuf_WaitAndDrainCloseEmpty(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + _, ok := sw.WaitAndDrain(nil) + assert.That(t, !ok) + }() + + // Close on empty buffer should return ok=false. + sw.Close() + <-done +} diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index d730836..0f8c768 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -60,6 +60,11 @@ type Options struct { // handling. When enabled, the server stream will decode incoming metadata // into grpc metadata in the context. GRPCMetadataCompatMode bool + + // Mux enables stream multiplexing on the transport, allowing multiple + // concurrent streams. When false (default), the manager uses the + // original single-stream-at-a-time behavior. 
+ Mux bool } // Manager handles the logic of managing a transport for a drpc client or @@ -306,6 +311,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) } + m.wr.Reset() stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) select { case m.streams <- streamInfo{ctx: ctx, stream: stream}: diff --git a/drpcmanager/manager_mux.go b/drpcmanager/manager_mux.go new file mode 100644 index 0000000..48d01a2 --- /dev/null +++ b/drpcmanager/manager_mux.go @@ -0,0 +1,405 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + grpcmetadata "google.golang.org/grpc/metadata" + + "storj.io/drpc" + "storj.io/drpc/drpcdebug" + "storj.io/drpc/drpcmetadata" + "storj.io/drpc/drpcsignal" + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" + "storj.io/drpc/internal/drpcopts" +) + +// MuxManager handles the logic of managing a transport for a drpc client or +// server with stream multiplexing enabled. Multiple streams can be active +// concurrently on a single transport. +type MuxManager struct { + tr drpc.Transport + rd *drpcwire.Reader + opts Options + + sw *sharedWriteBuf + reg *streamRegistry + streamID atomic.Uint64 + wg sync.WaitGroup + pkts chan drpcwire.Packet + pdone drpcsignal.Chan + metaMu sync.Mutex + meta map[uint64]map[string]string + + sigs struct { + term drpcsignal.Signal + write drpcsignal.Signal + read drpcsignal.Signal + tport drpcsignal.Signal + } +} + +// NewMuxWithOptions returns a new mux manager for the transport. It uses the +// provided options to manage details of how it uses it. 
+func NewMuxWithOptions(tr drpc.Transport, opts Options) *MuxManager { + m := &MuxManager{ + tr: tr, + rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), + opts: opts, + + pkts: make(chan drpcwire.Packet), + meta: make(map[uint64]map[string]string), + } + + // a buffer of size 1 allows the consumer of the packet to signal it is done + // without having to coordinate with the sender of the packet. + m.pdone.Make(1) + + // set the internal stream options + drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) + drpcopts.SetStreamMux(&m.opts.Stream.Internal, true) + + m.sw = newSharedWriteBuf() + m.reg = newStreamRegistry() + + go m.manageReader() + go m.manageWriter() + + return m +} + +// String returns a string representation of the manager. +func (m *MuxManager) String() string { return fmt.Sprintf("", m) } + +func (m *MuxManager) log(what string, cb func() string) { + if drpcdebug.Enabled { + drpcdebug.Log(func() (_, _, _ string) { return m.String(), what, cb() }) + } +} + +// +// helpers +// + +// terminate puts the MuxManager into a terminal state and closes any resources +// that need to be closed to signal the state change. +func (m *MuxManager) terminate(err error) { + if m.sigs.term.Set(err) { + m.log("TERM", func() string { return fmt.Sprint(err) }) + m.sigs.tport.Set(m.tr.Close()) + m.sw.Close() + m.metaMu.Lock() + for id := range m.meta { + delete(m.meta, id) + } + m.metaMu.Unlock() + // Cancel all active streams so they get a clear error. 
+ m.reg.ForEach(func(_ uint64, s *drpcstream.Stream) { + cancelErr := err + if errors.Is(cancelErr, io.EOF) { + cancelErr = context.Canceled + if s.Kind() == drpc.StreamKindClient { + cancelErr = drpc.ClosedError.New("connection closed") + } + } + s.Cancel(cancelErr) + }) + m.reg.Close() + } +} + +func (m *MuxManager) putMetadata(streamID uint64, metadata map[string]string) { + m.metaMu.Lock() + defer m.metaMu.Unlock() + m.meta[streamID] = metadata +} + +func (m *MuxManager) popMetadata(streamID uint64) map[string]string { + m.metaMu.Lock() + defer m.metaMu.Unlock() + metadata := m.meta[streamID] + delete(m.meta, streamID) + return metadata +} + +// +// manage reader +// + +// manageReader is always reading a packet and dispatching it to the appropriate +// stream or queue. It sets the read signal when it exits so that one can wait +// to ensure that no one is reading on the reader. It sets the term signal if +// there is any error reading packets. +func (m *MuxManager) manageReader() { + defer m.sigs.read.Set(nil) + + var pkt drpcwire.Packet + var err error + var run int + + for !m.sigs.term.IsSet() { + // if we have a run of "small" packets, drop the buffer to release + // memory so that a burst of large packets does not cause eternally + // large heap usage. + if run > 10 { + pkt.Data = nil + run = 0 + } + + pkt, err = m.rd.ReadPacketUsing(pkt.Data[:0]) + if err != nil { + if isConnectionReset(err) { + err = drpc.ClosedError.Wrap(err) + } + m.terminate(managerClosed.Wrap(err)) + return + } + + if len(pkt.Data) < cap(pkt.Data)/4 { + run++ + } else { + run = 0 + } + + m.log("READ", pkt.String) + + stream, ok := m.reg.Get(pkt.ID.Stream) + + switch { + // if the packet is for a registered stream, deliver it. + case ok && stream != nil: + if err := stream.HandlePacket(pkt); err != nil { + m.terminate(managerClosed.Wrap(err)) + return + } + // For message packets, HandlePacket transferred ownership of + // pkt.Data to the stream's packetBuffer. 
Acquire a fresh buffer + // from the pool so the next ReadPacketUsing doesn't allocate. + if pkt.Kind == drpcwire.KindMessage { + pkt.Data = drpcstream.AcquirePacketBuf() + } + + // if any invoke sequence is being sent, forward it to be handled. + case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: + select { + case m.pkts <- pkt: + m.pdone.Recv() + case <-m.sigs.term.Signal(): + return + } + + // silently drop packet for an unregistered stream + default: + m.log("DROP", pkt.String) + } + } +} + +// manageWriter drains the shared write buffer and writes pre-serialized +// bytes directly to the transport. It blocks on the sharedWriteBuf's +// condition variable until data is available, and naturally batches +// frames that accumulate while the previous write is in flight. +func (m *MuxManager) manageWriter() { + defer m.sigs.write.Set(nil) + + var spare []byte + for { + data, ok := m.sw.WaitAndDrain(spare[:0:cap(spare)]) + if !ok { + return + } + if _, err := m.tr.Write(data); err != nil { + m.terminate(managerClosed.Wrap(err)) + return + } + spare = data + } +} + +// +// manage streams +// + +// newStream creates a stream value with the appropriate configuration for this manager. 
+func (m *MuxManager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKind, rpc string) (*drpcstream.Stream, error) { + opts := m.opts.Stream + drpcopts.SetStreamKind(&opts.Internal, kind) + drpcopts.SetStreamRPC(&opts.Internal, rpc) + if cb := drpcopts.GetManagerStatsCB(&m.opts.Internal); cb != nil { + drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) + } + + stream := drpcstream.NewWithOptions(ctx, sid, &muxWriter{sw: m.sw}, opts) + + if err := m.reg.Register(sid, stream); err != nil { + return nil, err + } + + m.wg.Add(1) + go m.manageStream(ctx, stream) + + m.log("STREAM", stream.String) + return stream, nil +} + +// manageStream watches the context and the stream and returns when the stream +// is finished, canceling the stream if the context is canceled. +func (m *MuxManager) manageStream(ctx context.Context, stream *drpcstream.Stream) { + defer m.wg.Done() + defer m.reg.Unregister(stream.ID()) + + select { + case <-m.sigs.term.Signal(): + err := m.sigs.term.Err() + if errors.Is(err, io.EOF) { + err = context.Canceled + if stream.Kind() == drpc.StreamKindClient { + err = drpc.ClosedError.New("connection closed") + } + } + stream.Cancel(err) + <-stream.Finished() + + case <-stream.Finished(): + // stream finished naturally + + case <-ctx.Done(): + m.log("CANCEL", stream.String) + + if m.opts.SoftCancel { + // Best-effort send KindCancel, never terminate connection. + if busy, err := stream.SendCancel(ctx.Err()); err != nil { + m.log("CANCEL_ERR", func() string { + return fmt.Sprintf("%s: %v", stream.String(), err) + }) + } else if busy { + m.log("CANCEL_BUSY", stream.String) + } + stream.Cancel(ctx.Err()) + <-stream.Finished() + } else { + // Hard cancel: terminate connection if stream not finished. 
+ if !stream.Cancel(ctx.Err()) { + m.log("UNFIN", stream.String) + m.terminate(ctx.Err()) + } else { + m.log("CLEAN", stream.String) + } + <-stream.Finished() + } + } +} + +// +// exported interface +// + +// Closed returns a channel that is closed once the manager is closed. +func (m *MuxManager) Closed() <-chan struct{} { + return m.sigs.term.Signal() +} + +// Unblocked returns a channel that is closed when the manager is available for +// new streams. With multiplexing enabled, the connection is never blocked, so +// this always returns an already-closed channel. +func (m *MuxManager) Unblocked() <-chan struct{} { + return closedCh +} + +// Close closes the transport the manager is using. +func (m *MuxManager) Close() error { + m.terminate(managerClosed.New("Close called")) + + m.wg.Wait() // wait for all stream goroutines + m.sigs.write.Wait() + m.sigs.read.Wait() + m.sigs.tport.Wait() + + return m.sigs.tport.Err() +} + +// NewClientStream starts a stream on the managed transport for use by a client. +func (m *MuxManager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { + if err, ok := m.sigs.term.Get(); ok { + return nil, err + } + sid := m.streamID.Add(1) + return m.newStream(ctx, sid, drpc.StreamKindClient, rpc) +} + +// NewServerStream starts a stream on the managed transport for use by a server. +// It does this by waiting for the client to issue an invoke message and +// returning the details. +func (m *MuxManager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { + if err, ok := m.sigs.term.Get(); ok { + return nil, "", err + } + + var timeoutCh <-chan time.Time + + // set up the timeout on the context if necessary. 
+ if timeout := m.opts.InactivityTimeout; timeout > 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + timeoutCh = timer.C + } + + for { + select { + case <-timeoutCh: + return nil, "", context.DeadlineExceeded + + case <-ctx.Done(): + return nil, "", ctx.Err() + + case <-m.sigs.term.Signal(): + return nil, "", m.sigs.term.Err() + + case pkt := <-m.pkts: + switch pkt.Kind { + case drpcwire.KindInvokeMetadata: + metadata, err := drpcmetadata.Decode(pkt.Data) + if err != nil { + m.pdone.Send() + return nil, "", err + } + m.putMetadata(pkt.ID.Stream, metadata) + m.pdone.Send() + + case drpcwire.KindInvoke: + rpc = string(pkt.Data) + streamCtx := ctx + + if metadata := m.popMetadata(pkt.ID.Stream); metadata != nil { + if m.opts.GRPCMetadataCompatMode { + grpcMeta := make(map[string][]string, len(metadata)) + for k, v := range metadata { + grpcMeta[k] = []string{v} + } + streamCtx = grpcmetadata.NewIncomingContext(streamCtx, grpcMeta) + } else { + streamCtx = drpcmetadata.NewIncomingContext(streamCtx, metadata) + } + } + stream, err := m.newStream(streamCtx, pkt.ID.Stream, drpc.StreamKindServer, rpc) + // Ack the invoke only after stream registration so subsequent + // message packets cannot be dropped for an unknown stream ID. + m.pdone.Send() + return stream, rpc, err + + default: + // this should never happen, but defensive. 
+ m.pdone.Send() + } + } + } +} diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 5918113..eca2ff4 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -16,6 +16,7 @@ import ( grpcmetadata "google.golang.org/grpc/metadata" "storj.io/drpc/drpcmetadata" + "storj.io/drpc/drpcstream" "storj.io/drpc/drpctest" "storj.io/drpc/drpcwire" ) @@ -161,6 +162,184 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { ctx.Wait() } +func TestMux_DrpcMetadataInterleavedAcrossStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := NewMuxWithOptions(cconn, Options{}) + defer func() { _ = cman.Close() }() + + sman := NewMuxWithOptions(sconn, Options{ + GRPCMetadataCompatMode: true, + }) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + metadata1 := map[string]string{"stream": "one"} + metadata2 := map[string]string{"stream": "two"} + + buf1, err := drpcmetadata.Encode(nil, metadata1) + assert.NoError(t, err) + buf2, err := drpcmetadata.Encode(nil, metadata2) + assert.NoError(t, err) + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvokeMetadata, buf1)) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvokeMetadata, buf2)) + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + got1, ok := grpcmetadata.FromIncomingContext(srvStream1.Context()) + assert.That(t, ok) + assert.Equal(t, 
grpcmetadata.MD{"stream": []string{"one"}}, got1) + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + got2, ok := grpcmetadata.FromIncomingContext(srvStream2.Context()) + assert.That(t, ok) + assert.Equal(t, grpcmetadata.MD{"stream": []string{"two"}}, got2) +} + +func TestMux_NewServerStreamUnreadMessageDoesNotBlockOtherStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := NewMuxWithOptions(cconn, Options{}) + defer func() { _ = cman.Close() }() + + sman := NewMuxWithOptions(sconn, Options{}) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, []byte("message-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + // Do not read the first stream's message. The manager must still be able to + // accept and register additional streams. + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + srvStream2, rpc2, err := sman.NewServerStream(timeoutCtx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() +} + +// TestMux_ConcurrentLargeMessages verifies that two streams writing messages +// larger than SplitSize concurrently do not corrupt each other's data. 
With the +// mux implementation, rawWriteLocked splits messages into multiple frames and +// each frame is appended to the shared write buffer independently. Frames from +// different streams can interleave in the buffer, and the reader resets partial +// packets when it sees a frame from a different stream, silently corrupting +// data. +func TestMux_ConcurrentLargeMessages(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + // Use a small SplitSize to force even small messages to be split into + // multiple frames, making interleaving likely. + streamOpts := drpcstream.Options{SplitSize: 5} + + cman := NewMuxWithOptions(cconn, Options{Stream: streamOpts}) + defer func() { _ = cman.Close() }() + + sman := NewMuxWithOptions(sconn, Options{Stream: streamOpts}) + defer func() { _ = sman.Close() }() + + // Create two client streams and send invoke + message concurrently. + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + msg1 := []byte("AAAAAAAAAAAAAAAAAAAA") // 20 bytes, split into 4 frames of 5 bytes + msg2 := []byte("BBBBBBBBBBBBBBBBBBBB") // 20 bytes, split into 4 frames of 5 bytes + + // Send invokes first (these are small, no splitting). + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + // Accept both server streams before sending messages, so the streams are + // registered and the reader can route packets. 
+ srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + // Write messages concurrently from both streams. With SplitSize=5, each + // 20-byte message becomes 4 frames. The frames should not interleave. + ready := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, msg1)) + }() + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream2.RawWrite(drpcwire.KindMessage, msg2)) + }() + close(ready) + wg.Wait() + + // Read from both server streams and verify correctness. + got1, err := srvStream1.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got1, msg1) + + got2, err := srvStream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got2, msg2) +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } diff --git a/drpcmanager/mux_writer.go b/drpcmanager/mux_writer.go new file mode 100644 index 0000000..b36c25e --- /dev/null +++ b/drpcmanager/mux_writer.go @@ -0,0 +1,110 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcwire" +) + +// muxWriter implements drpcwire.StreamWriter by serializing packet bytes into a +// shared write buffer. The manageWriter goroutine drains the buffer and writes +// directly to the transport. +// +// The entire packet is serialized as a single frame (via AppendFrame) under one +// mutex hold, so frames from concurrent streams never interleave on the wire. +// The packet's Data slice is consumed (copied) before WritePacket returns, so +// callers may safely reuse their buffers afterward. 
+type muxWriter struct { + sw *sharedWriteBuf +} + +func (w *muxWriter) WritePacket(pkt drpcwire.Packet) error { + return w.sw.Append(pkt) +} + +// Flush is a no-op because the manageWriter goroutine flushes to the +// transport after draining the shared buffer. +func (w *muxWriter) Flush() error { return nil } + +func (w *muxWriter) Empty() bool { return true } + +// sharedWriteBuf collects serialized frame bytes from multiple concurrent +// producers. A single consumer (manageWriter) drains the buffer and writes +// the pre-serialized bytes to the transport. +type sharedWriteBuf struct { + mu sync.Mutex + cond *sync.Cond + buf []byte + closed bool +} + +func newSharedWriteBuf() *sharedWriteBuf { + sw := &sharedWriteBuf{} + sw.cond = sync.NewCond(&sw.mu) + return sw +} + +// Append serializes pkt as a single frame into the shared buffer. The packet's +// Data slice is consumed (copied by AppendFrame) before Append returns. +func (sw *sharedWriteBuf) Append(pkt drpcwire.Packet) error { + sw.mu.Lock() + if sw.closed { + sw.mu.Unlock() + return managerClosed.New("enqueue") + } + sw.buf = drpcwire.AppendFrame(sw.buf, drpcwire.Frame{ + Data: pkt.Data, + ID: pkt.ID, + Kind: pkt.Kind, + Control: pkt.Control, + Done: true, + }) + sw.mu.Unlock() + + sw.cond.Signal() + return nil +} + +// Drain swaps out accumulated bytes, giving the caller ownership of the +// returned slice. The internal buffer is replaced with spare (reset to zero +// length) so producers can continue appending without allocation. +func (sw *sharedWriteBuf) Drain(spare []byte) []byte { + sw.mu.Lock() + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data +} + +// WaitAndDrain blocks until data is available or the buffer is closed. +// Returns the accumulated bytes and true if data was available, or nil and +// false if the buffer is closed and empty. 
+func (sw *sharedWriteBuf) WaitAndDrain(spare []byte) ([]byte, bool) { + sw.mu.Lock() + for len(sw.buf) == 0 && !sw.closed { + sw.cond.Wait() + } + if sw.closed && len(sw.buf) == 0 { + sw.mu.Unlock() + return nil, false + } + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data, true +} + +// Close marks the buffer as closed and wakes the consumer. +func (sw *sharedWriteBuf) Close() { + sw.mu.Lock() + defer sw.mu.Unlock() + + if sw.closed { + return + } + sw.closed = true + sw.cond.Broadcast() +} diff --git a/drpcmanager/mux_writer_test.go b/drpcmanager/mux_writer_test.go new file mode 100644 index 0000000..50c4e1c --- /dev/null +++ b/drpcmanager/mux_writer_test.go @@ -0,0 +1,80 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + "storj.io/drpc/drpcwire" +) + +func TestMuxWriter_WritePacket(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + pkt := drpcwire.Packet{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, w.WritePacket(pkt)) + + data := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) +} + +func TestMuxWriter_WritePacketIsolatesData(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + data := []byte("hello") + pkt := drpcwire.Packet{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, w.WritePacket(pkt)) + + // Mutate the original source buffer after WritePacket. 
+ data[0] = 'j' + + // The serialized data in the shared buffer should be unaffected because + // AppendFrame copies the bytes during serialization. + buf := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(buf) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, []byte("hello")) +} + +func TestMuxWriter_FlushNoop(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + assert.NoError(t, w.Flush()) +} + +func TestMuxWriter_Empty(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + assert.That(t, w.Empty()) +} + +func TestMuxWriter_WritePacketAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + sw.Close() + + err := w.WritePacket(drpcwire.Packet{}) + assert.Error(t, err) +} diff --git a/drpcmanager/registry.go b/drpcmanager/registry.go new file mode 100644 index 0000000..b6e90b1 --- /dev/null +++ b/drpcmanager/registry.go @@ -0,0 +1,89 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcstream" +) + +// streamRegistry is a thread-safe map of stream IDs to stream objects. +// It is used by the MuxManager to track all active streams for lifecycle +// management and packet routing. +type streamRegistry struct { + mu sync.RWMutex + streams map[uint64]*drpcstream.Stream + closed bool +} + +func newStreamRegistry() *streamRegistry { + return &streamRegistry{ + streams: make(map[uint64]*drpcstream.Stream), + } +} + +// Register adds a stream to the registry. It returns an error if the registry +// is closed or if a stream with the same ID is already registered. 
+func (r *streamRegistry) Register(id uint64, stream *drpcstream.Stream) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return managerClosed.New("register") + } + if _, ok := r.streams[id]; ok { + return managerClosed.New("duplicate stream id") + } + r.streams[id] = stream + return nil +} + +// Unregister removes a stream from the registry. It is a no-op if the stream +// is not registered or if the registry has been closed. +func (r *streamRegistry) Unregister(id uint64) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.streams != nil { + delete(r.streams, id) + } +} + +// Get returns the stream for the given ID and whether it was found. +func (r *streamRegistry) Get(id uint64) (*drpcstream.Stream, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + s, ok := r.streams[id] + return s, ok +} + +// Close marks the registry as closed, preventing future Register calls. +// It does not cancel any streams. +func (r *streamRegistry) Close() { + r.mu.Lock() + defer r.mu.Unlock() + + r.closed = true +} + +// ForEach calls fn for each registered stream. The function is called with +// the stream ID and stream pointer. The registry is read-locked during iteration. +func (r *streamRegistry) ForEach(fn func(uint64, *drpcstream.Stream)) { + r.mu.RLock() + defer r.mu.RUnlock() + + for id, s := range r.streams { + fn(id, s) + } +} + +// Len returns the number of registered streams. +func (r *streamRegistry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.streams) +} diff --git a/drpcmanager/registry_test.go b/drpcmanager/registry_test.go new file mode 100644 index 0000000..d27359b --- /dev/null +++ b/drpcmanager/registry_test.go @@ -0,0 +1,136 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "context" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcstream" +) + +func testStream(id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, &muxWriter{sw: newSharedWriteBuf()}) +} + +func TestStreamRegistry_RegisterAndGet(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s) +} + +func TestStreamRegistry_GetMissing(t *testing.T) { + reg := newStreamRegistry() + + got, ok := reg.Get(42) + assert.That(t, !ok) + assert.Nil(t, got) +} + +func TestStreamRegistry_Unregister(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + assert.Equal(t, reg.Len(), 1) + + reg.Unregister(1) + + _, ok := reg.Get(1) + assert.That(t, !ok) + assert.Equal(t, reg.Len(), 0) +} + +func TestStreamRegistry_UnregisterIdempotent(t *testing.T) { + reg := newStreamRegistry() + + // must not panic when unregistering a non-existent ID + reg.Unregister(99) +} + +func TestStreamRegistry_DuplicateRegister(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(1) + + assert.NoError(t, reg.Register(1, s1)) + assert.Error(t, reg.Register(1, s2)) + + // original stream is still registered + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s1) +} + +func TestStreamRegistry_RegisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + reg.Close() + + err := reg.Register(1, testStream(1)) + assert.Error(t, err) +} + +func TestStreamRegistry_UnregisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + assert.NoError(t, reg.Register(1, s)) + + reg.Close() + + // must not panic + reg.Unregister(1) +} + +func TestStreamRegistry_Len(t *testing.T) { + reg := newStreamRegistry() + assert.Equal(t, reg.Len(), 0) + + assert.NoError(t, reg.Register(1, testStream(1))) + 
assert.Equal(t, reg.Len(), 1) + + assert.NoError(t, reg.Register(2, testStream(2))) + assert.Equal(t, reg.Len(), 2) + + reg.Unregister(1) + assert.Equal(t, reg.Len(), 1) +} + +func TestStreamRegistry_ForEach(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(2) + s3 := testStream(3) + + assert.NoError(t, reg.Register(1, s1)) + assert.NoError(t, reg.Register(2, s2)) + assert.NoError(t, reg.Register(3, s3)) + + seen := make(map[uint64]*drpcstream.Stream) + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + seen[id] = s + }) + + assert.Equal(t, len(seen), 3) + assert.Equal(t, seen[1], s1) + assert.Equal(t, seen[2], s2) + assert.Equal(t, seen[3], s3) +} + +func TestStreamRegistry_ForEach_Empty(t *testing.T) { + reg := newStreamRegistry() + + count := 0 + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + count++ + }) + + assert.Equal(t, count, 0) +} diff --git a/drpcserver/server.go b/drpcserver/server.go index b8f95d9..9337182 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -92,6 +92,19 @@ func (s *Server) getStats(rpc string) *drpcstats.Stats { // ServeOne serves a single set of rpcs on the provided transport. func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { + ctx, err = s.setupConnection(ctx, tr) + if err != nil { + return err + } + if s.opts.Manager.Mux { + return s.serveOneMux(ctx, tr) + } + return s.serveOneNonMux(ctx, tr) +} + +// setupConnection performs TLS handshake (if applicable) and sets up the +// drpccache context. It is shared by both mux and non-mux serve paths. +func (s *Server) setupConnection(ctx context.Context, tr drpc.Transport) (context.Context, error) { // Check if the transport is a TLS connection if tlsConn, ok := tr.(*tls.Conn); ok { // Manually perform the TLS handshake to access peer certificate @@ -109,7 +122,7 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { // anyway. 
err := tlsConn.HandshakeContext(ctx) if err != nil { - return drpc.ConnectionError.New("server handshake [%q] failed: %w", tlsConn.RemoteAddr(), err) + return ctx, drpc.ConnectionError.New("server handshake [%q] failed: %w", tlsConn.RemoteAddr(), err) } state := tlsConn.ConnectionState() if len(state.PeerCertificates) > 0 { @@ -117,7 +130,11 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { ctx, drpcctx.PeerConnectionInfo{Certificates: state.PeerCertificates}) } } + return ctx, nil +} +// serveOneNonMux serves rpcs sequentially using the non-mux Manager. +func (s *Server) serveOneNonMux(ctx context.Context, tr drpc.Transport) (err error) { man := drpcmanager.NewWithOptions(tr, s.opts.Manager) defer func() { err = errs.Combine(err, man.Close()) }() @@ -137,6 +154,38 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } } +// serveOneMux serves rpcs concurrently using the MuxManager. +func (s *Server) serveOneMux(ctx context.Context, tr drpc.Transport) (err error) { + man := drpcmanager.NewMuxWithOptions(tr, s.opts.Manager) + + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() + + cache := drpccache.New() + defer cache.Clear() + + ctx = drpccache.WithContext(ctx, cache) + + for { + stream, rpc, err := man.NewServerStream(ctx) + if err != nil { + return errs.Wrap(err) + } + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() + } +} + var temporarySleep = 500 * time.Millisecond // Serve listens for connections on the listener and serves the drpc request diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go index db68864..5854d86 100644 --- a/drpcstream/pktbuf.go +++ b/drpcstream/pktbuf.go @@ -7,7 +7,22 @@ import ( "sync" ) -type packetBuffer struct { +// packetStore is the interface for packet buffer implementations. 
+// syncPacketBuffer is used for non-mux mode (blocking, single-slot), +// queuePacketBuffer is used for mux mode (non-blocking, queued). +type packetStore interface { + Put(data []byte) + Get() ([]byte, error) + Close(err error) + Done() + Recycle([]byte) +} + +// syncPacketBuffer is the original single-slot, blocking packet buffer used +// in non-mux mode. Put blocks until the previous value is consumed via +// Get+Done, and the reader (manageReader) blocks in Put until the stream +// consumer finishes processing. +type syncPacketBuffer struct { mu sync.Mutex cond sync.Cond err error @@ -16,11 +31,11 @@ type packetBuffer struct { held bool } -func (pb *packetBuffer) init() { +func (pb *syncPacketBuffer) init() { pb.cond.L = &pb.mu } -func (pb *packetBuffer) Close(err error) { +func (pb *syncPacketBuffer) Close(err error) { pb.mu.Lock() defer pb.mu.Unlock() @@ -36,7 +51,7 @@ func (pb *packetBuffer) Close(err error) { } } -func (pb *packetBuffer) Put(data []byte) { +func (pb *syncPacketBuffer) Put(data []byte) { pb.mu.Lock() defer pb.mu.Unlock() @@ -57,7 +72,7 @@ func (pb *packetBuffer) Put(data []byte) { } } -func (pb *packetBuffer) Get() ([]byte, error) { +func (pb *syncPacketBuffer) Get() ([]byte, error) { pb.mu.Lock() defer pb.mu.Unlock() @@ -74,7 +89,7 @@ func (pb *packetBuffer) Get() ([]byte, error) { return pb.data, nil } -func (pb *packetBuffer) Done() { +func (pb *syncPacketBuffer) Done() { pb.mu.Lock() defer pb.mu.Unlock() @@ -83,3 +98,7 @@ func (pb *packetBuffer) Done() { pb.held = false pb.cond.Broadcast() } + +// Recycle is a no-op for syncPacketBuffer. Buffer lifetime is managed by +// the manageReader goroutine that owns the underlying slice. +func (pb *syncPacketBuffer) Recycle([]byte) {} diff --git a/drpcstream/pktbuf_mux.go b/drpcstream/pktbuf_mux.go new file mode 100644 index 0000000..cb152cb --- /dev/null +++ b/drpcstream/pktbuf_mux.go @@ -0,0 +1,116 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcstream + +import ( + "io" + "sync" +) + +// pktBuf wraps a byte slice so the pool stores a pointer type, avoiding +// an allocation on every Put. +type pktBuf struct { + data []byte +} + +// pktBufPool recycles byte slices used on the read path to avoid per-packet +// allocations. Buffers flow: AcquirePacketBuf -> reader -> Put (zero-copy +// ownership transfer) -> Get -> consumer -> Recycle -> pool. +var pktBufPool = sync.Pool{ + New: func() interface{} { return &pktBuf{} }, +} + +func recycleToPktBufPool(data []byte) { + pb := pktBufPool.Get().(*pktBuf) + pb.data = data[:0] + pktBufPool.Put(pb) +} + +// AcquirePacketBuf returns a byte slice from the shared packet buffer pool +// for use as a read buffer. Returns nil if the pool is empty. +func AcquirePacketBuf() []byte { + pb := pktBufPool.Get().(*pktBuf) + data := pb.data[:0] + pb.data = nil + pktBufPool.Put(pb) + return data +} + +// queuePacketBuffer is a non-blocking, queue-based packet buffer used in mux +// mode. Put appends to an unbounded queue and returns immediately, allowing +// the reader goroutine to continue dispatching packets to other streams. +type queuePacketBuffer struct { + mu sync.Mutex + cond sync.Cond + err error + data [][]byte +} + +func (pb *queuePacketBuffer) init() { + pb.cond.L = &pb.mu +} + +func (pb *queuePacketBuffer) Close(err error) { + pb.mu.Lock() + defer pb.mu.Unlock() + + if pb.err == nil { + // Preserve already-queued messages on graceful close so readers can + // drain them before seeing EOF. + if err != io.EOF { + for i := range pb.data { + recycleToPktBufPool(pb.data[i]) + pb.data[i] = nil + } + pb.data = pb.data[:0] + } + pb.err = err + pb.cond.Broadcast() + } +} + +// Put takes ownership of data. The caller must not use data after calling Put. +// If the buffer is closed, data is returned to the pool. 
+func (pb *queuePacketBuffer) Put(data []byte) { + pb.mu.Lock() + defer pb.mu.Unlock() + + if pb.err != nil { + recycleToPktBufPool(data) + return + } + + pb.data = append(pb.data, data) + pb.cond.Broadcast() +} + +func (pb *queuePacketBuffer) Get() ([]byte, error) { + pb.mu.Lock() + defer pb.mu.Unlock() + + for len(pb.data) == 0 && pb.err == nil { + pb.cond.Wait() + } + if len(pb.data) == 0 { + return nil, pb.err + } + + data := pb.data[0] + n := copy(pb.data, pb.data[1:]) + pb.data[n] = nil + pb.data = pb.data[:n] + return data, nil +} + +// Done is a no-op for queuePacketBuffer. Buffer ownership is transferred to +// the caller via Get, and recycling is done explicitly via Recycle. +func (pb *queuePacketBuffer) Done() {} + +// Recycle returns a buffer obtained from Get back to the pool. +// Call this after the data has been fully consumed (e.g. after Unmarshal). +func (pb *queuePacketBuffer) Recycle(buf []byte) { + if buf != nil { + recycleToPktBufPool(buf) + } +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 29ccd63..378c169 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -53,8 +53,9 @@ type Stream struct { flush sync.Once id drpcwire.ID - wr *drpcwire.Writer - pbuf packetBuffer + mux bool + wr drpcwire.StreamWriter + pbuf packetStore wbuf []byte mu sync.Mutex // protects state transitions @@ -72,7 +73,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr drpcwire.StreamWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -80,7 +81,9 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. 
It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions( + ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options, +) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -89,6 +92,8 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O } } + mux := drpcopts.GetStreamMux(&opts.Internal) + s := &Stream{ ctx: streamCtx{ Context: ctx, @@ -98,12 +103,21 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O fin: drpcopts.GetStreamFin(&opts.Internal), task: task, - id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + id: drpcwire.ID{Stream: sid}, + mux: mux, + wr: wr, } - // initialize the packet buffer - s.pbuf.init() + // initialize the packet buffer based on mode + if mux { + pb := new(queuePacketBuffer) + pb.init() + s.pbuf = pb + } else { + pb := new(syncPacketBuffer) + pb.init() + s.pbuf = pb + } return s } @@ -225,17 +239,19 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) - if s.sigs.term.IsSet() { + // Put must always be called for message packets to manage buffer + // ownership. Put returns the buffer to the pool if the stream is closed. + if pkt.Kind == drpcwire.KindMessage { + s.pbuf.Put(pkt.Data) return nil } - s.log("HANDLE", pkt.String) - - if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + if s.sigs.term.IsSet() { return nil } + s.log("HANDLE", pkt.String) + s.mu.Lock() defer s.mu.Unlock() @@ -314,26 +330,28 @@ func (s *Stream) checkCancelError(err error) error { return err } -// newFrameLocked bumps the internal message id and returns a frame. 
It must be +// nextID bumps the internal message id and returns the new ID. It must be // called under a mutex. -func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { +func (s *Stream) nextID() drpcwire.ID { s.id.Message++ - return drpcwire.Frame{ID: s.id, Kind: kind} + return s.id } // sendPacketLocked sends the packet in a single write and flushes. It does not // check for any conditions to stop it from writing and is meant for internal // stream use to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { - fr := s.newFrameLocked(kind) - fr.Data = data - fr.Control = control - fr.Done = true + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + Control: control, + } drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) - s.log("SEND", fr.String) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { + if err := s.wr.WritePacket(pkt); err != nil { return errs.Wrap(err) } if err := s.wr.Flush(); err != nil { @@ -376,29 +394,26 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. 
func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { - fr := s.newFrameLocked(kind) - n := s.opts.SplitSize - - for { - switch { - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } + switch { + case s.sigs.send.IsSet(): + return s.sigs.send.Err() + case s.sigs.term.IsSet(): + return s.sigs.term.Err() + } - fr.Data, data = drpcwire.SplitData(data, n) - fr.Done = len(data) == 0 + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + } - drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(fr.Data))) - s.log("SEND", fr.String) + drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) - } else if fr.Done { - return nil - } + if err := s.wr.WritePacket(pkt); err != nil { + return s.checkCancelError(errs.Wrap(err)) } + return nil } // RawFlush flushes any buffers of data. @@ -461,6 +476,14 @@ func (s *Stream) RawRecv() (data []byte, err error) { if err != nil { return nil, err } + + if s.mux { + // In mux mode, the buffer is owned by the caller (from the pool). + return data, nil + } + + // In non-mux mode, copy the data and release the slot so the reader + // goroutine can continue. data = append([]byte(nil), data...) s.pbuf.Done() @@ -514,7 +537,14 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + + if s.mux { + // In mux mode, return the buffer to the pool after unmarshal. + s.pbuf.Recycle(data) + } else { + // In non-mux mode, release the slot so the reader can continue. 
+ s.pbuf.Done() + } return err } diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac397..ef801d3 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -143,7 +143,10 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { pkt.Control = pkt.Control || fr.Control switch { - case fr.ID.Less(r.id): + case fr.ID.Stream == 0 || fr.ID.Message == 0: + return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) + + case fr.ID.Stream == r.id.Stream && fr.ID.Less(r.id): return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) case r.id != fr.ID || pkt.ID == ID{}: diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b..ec0cd3f 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -154,6 +154,51 @@ func TestReader(t *testing.T) { Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, Error: "id monotonicity violation", }, + + { // cross-stream: frames from a lower stream after a higher stream are allowed (multiplexing) + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose, Done: true}, + }, + }, + + { // cross-stream: interleaved single-frame packets from multiple streams + Packets: []Packet{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke}, + {Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage}, + {Data: nil, ID: 
ID{Stream: 1, Message: 3}, Kind: KindCloseSend},
+				{Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend},
+			},
+			Frames: []Frame{
+				{Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke, Done: true},
+				{Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke, Done: true},
+				{Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true},
+				{Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage, Done: true},
+				{Data: nil, ID: ID{Stream: 1, Message: 3}, Kind: KindCloseSend, Done: true},
+				{Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend, Done: true},
+			},
+		},
+
+		{ // cross-stream: within-stream monotonicity is still enforced
+			Packets: []Packet{
+				{Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage},
+				{Data: []byte("b"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage},
+			},
+			Frames: []Frame{
+				{Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true},
+				{Data: []byte("b"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage, Done: true},
+				{Data: []byte("c"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, // replay of an older id on stream 2: strictly lower, so it trips the Less check
+			},
+			Error: "id monotonicity violation",
+		},
 	}
 
 	for _, tc := range cases {
diff --git a/drpcwire/writer.go b/drpcwire/writer.go
index fe909cf..bfcf0ac 100644
--- a/drpcwire/writer.go
+++ b/drpcwire/writer.go
@@ -12,6 +12,15 @@ import (
 	"storj.io/drpc/drpcdebug"
 )
 
+// StreamWriter is the interface used by streams for writing packets. Each call
+// to WritePacket must serialize the full data atomically so that frames from
+// concurrent writers on different streams do not interleave on the wire.
+type StreamWriter interface { + WritePacket(pkt Packet) error + Flush() error + Empty() bool +} + // // Writer // diff --git a/internal/drpcopts/stream.go b/internal/drpcopts/stream.go index 6ab0511..7459347 100644 --- a/internal/drpcopts/stream.go +++ b/internal/drpcopts/stream.go @@ -15,6 +15,7 @@ type Stream struct { kind drpc.StreamKind rpc string stats *drpcstats.Stats + mux bool } // GetStreamTransport returns the drpc.Transport stored in the options. @@ -46,3 +47,9 @@ func GetStreamStats(opts *Stream) *drpcstats.Stats { return opts.stats } // SetStreamStats sets the Stats stored in the options. func SetStreamStats(opts *Stream, stats *drpcstats.Stats) { opts.stats = stats } + +// GetStreamMux returns whether the stream is in multiplexing mode. +func GetStreamMux(opts *Stream) bool { return opts.mux } + +// SetStreamMux sets whether the stream is in multiplexing mode. +func SetStreamMux(opts *Stream, mux bool) { opts.mux = mux } diff --git a/internal/integration/benchmark_test.go b/internal/integration/benchmark_test.go new file mode 100644 index 0000000..1e53184 --- /dev/null +++ b/internal/integration/benchmark_test.go @@ -0,0 +1,348 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package integration + +import ( + "context" + "errors" + "io" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "google.golang.org/protobuf/proto" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpctest" +) + +var echoServer = impl{ + Method1Fn: func(_ context.Context, in *In) (*Out, error) { + return &Out{Out: in.In, Data: in.Data}, nil + }, + + Method2Fn: func(stream DRPCService_Method2Stream) error { + var last *In + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + last = in + } + if last == nil { + last = &In{} + } + return stream.SendAndClose(&Out{Out: last.In, Data: last.Data}) + }, + + Method3Fn: func(in *In, stream DRPCService_Method3Stream) error { + out := &Out{Out: 1, Data: in.Data} + for i := int64(0); i < in.In; i++ { + if err := stream.Send(out); err != nil { + return err + } + } + return nil + }, + + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + if err := stream.Send(&Out{Out: in.In, Data: in.Data}); err != nil { + return err + } + } + }, +} + +var sizes = []struct { + name string + data []byte +}{ + {"small", nil}, + {"1KB", make([]byte, 1024)}, + {"8KB", make([]byte, 8192)}, +} + +var concurrencies = []int{1, 10, 100} +var activeStreamCounts = []int{2, 8, 32} + +const parallelActiveStreamsWriteLatency = 200 * time.Microsecond + +func BenchmarkUnary(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, c := range concurrencies { + b.Run("concurrent="+strconv.Itoa(c), func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.SetParallelism(c) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := client.Method1(context.Background(), 
in); err != nil { + b.Error(err) + } + } + }) + }) + } + }) + } +} + +func BenchmarkInputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method2(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + } + + if _, err := stream.CloseAndRecv(); err != nil { + b.Fatal(err) + } + }) + } +} + +func BenchmarkOutputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: int64(b.N), Data: sz.data} + stream, err := client.Method3(context.Background(), in) + if err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + b.ResetTimer() + + var last *Out + for i := 0; i < b.N; i++ { + last, err = stream.Recv() + if err != nil { + b.Fatal(err) + } + } + if last != nil { + b.SetBytes(int64(proto.Size(last))) + } + }) + } +} + +func BenchmarkBidiStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + if _, err := stream.Recv(); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkParallelActiveStreams(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, 0) +} + +func BenchmarkParallelActiveStreamsWriteLatency(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, 
parallelActiveStreamsWriteLatency) +} + +func benchmarkParallelActiveStreamsModes(b *testing.B, writeLatency time.Duration) { + transports := []struct { + name string + useTCP bool + }{ + {"Pipe", false}, + {"TCP", true}, + } + + modes := []struct { + name string + mux bool + oneConnPerStream bool + }{ + {"NonMux", false, true}, + {"Mux/ConnPerStream", true, true}, + {"Mux/SharedConn", true, false}, + } + + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, streamCount := range activeStreamCounts { + b.Run("streams="+strconv.Itoa(streamCount), func(b *testing.B) { + for _, transport := range transports { + for _, mode := range modes { + b.Run(transport.name+"/"+mode.name, func(b *testing.B) { + benchmarkParallelActiveStreams(b, sz.data, streamCount, mode.mux, mode.oneConnPerStream, writeLatency, transport.useTCP) + }) + } + } + }) + } + }) + } +} + +func benchmarkParallelActiveStreams( + b *testing.B, payload []byte, streamCount int, mux bool, oneConnPerStream bool, writeLatency time.Duration, useTCP bool, +) { + ctx := drpctest.NewTracker(b) + defer ctx.Close() + + type bidiClient interface { + Send(*In) error + Recv() (*Out, error) + Close() error + } + + type worker struct { + stream bidiClient + close func() + } + workers := make([]worker, 0, streamCount) + + opts := connTransportOpts{ + writeDelay: writeLatency, + useTCP: useTCP, + mux: mux, + } + + var sharedConn *drpcconn.Conn + if !oneConnPerStream { + sharedConn = createConnectionWithTransport(b, echoServer, ctx, opts) + } + + for i := 0; i < streamCount; i++ { + conn := sharedConn + if oneConnPerStream { + conn = createConnectionWithTransport(b, echoServer, ctx, opts) + } + + client := NewDRPCServiceClient(conn) + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatalf("create stream %d: %v", i, err) + } + + s := stream + c := conn + workers = append(workers, worker{ + stream: s, + close: func() { + _ = s.Close() + if oneConnPerStream { + _ = 
c.Close() + } + }, + }) + } + + input := &In{In: 1, Data: payload} + b.SetBytes(int64(proto.Size(input))) + b.ReportAllocs() + + var sent atomic.Int64 + errCh := make(chan error, 1) + start := make(chan struct{}) + var wg sync.WaitGroup + + b.ResetTimer() + for _, w := range workers { + wg.Add(1) + go func(w worker) { + defer wg.Done() + msg := &In{In: 1, Data: payload} + <-start + + for { + n := sent.Add(1) + if n > int64(b.N) { + return + } + if err := w.stream.Send(msg); err != nil { + select { + case errCh <- err: + default: + } + return + } + if _, err := w.stream.Recv(); err != nil { + select { + case errCh <- err: + default: + } + return + } + } + }(w) + } + close(start) + wg.Wait() + b.StopTimer() + + select { + case err := <-errCh: + b.Fatal(err) + default: + } + + for _, w := range workers { + w.close() + } + if sharedConn != nil { + _ = sharedConn.Close() + } +} diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a..aa786c5 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -10,6 +10,7 @@ import ( "strconv" "sync" "testing" + "time" "github.com/zeebo/assert" "google.golang.org/grpc/codes" @@ -40,18 +41,93 @@ func in(n int64) *In { return &In{In: n} } func out(n int64) *Out { return &Out{Out: n} } func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { - c1, c2 := net.Pipe() + return createRawConnectionWithWriteDelay(t, server, ctx, 0) +} + +func createRawConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, connTransportOpts{writeDelay: writeDelay}) +} + +func createTCPConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, 
connTransportOpts{writeDelay: writeDelay, useTCP: true}) +} + +type connTransportOpts struct { + writeDelay time.Duration + useTCP bool + mux bool +} + +func createConnectionWithTransport( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, opts connTransportOpts, +) *drpcconn.Conn { + var serverConn, clientConn net.Conn + if opts.useTCP { + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { lis.Close() }) + + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := lis.Accept() + ch <- result{c, err} + }() + + c, err := net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatal(err) + } + r := <-ch + if r.err != nil { + t.Fatal(r.err) + } + serverConn, clientConn = r.conn, c + t.Cleanup(func() { + serverConn.Close() + clientConn.Close() + }) + } else { + c1, c2 := net.Pipe() + serverConn, clientConn = c1, c2 + } + + if opts.writeDelay > 0 { + serverConn = &writeLatencyConn{Conn: serverConn, writeDelay: opts.writeDelay} + clientConn = &writeLatencyConn{Conn: clientConn, writeDelay: opts.writeDelay} + } + managerOpts := drpcmanager.Options{ + SoftCancel: true, + Mux: opts.mux, + } mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, server)) - srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{ - SoftCancel: true, - }, + srv := drpcserver.NewWithOptions(mux, drpcserver.Options{Manager: managerOpts}) + ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, serverConn) }) + return drpcconn.NewWithOptions(clientConn, drpcconn.Options{ + Manager: managerOpts, }) } +type writeLatencyConn struct { + net.Conn + writeDelay time.Duration +} + +func (c *writeLatencyConn) Write(p []byte) (int, error) { + time.Sleep(c.writeDelay) + return c.Conn.Write(p) +} + func createConnection(t testing.TB, server 
DRPCServiceServer) (DRPCServiceClient, func()) { ctx := drpctest.NewTracker(t) conn := createRawConnection(t, server, ctx)