From 783893f829dc32ecb153e69e05edfb47aa67dfb0 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Thu, 5 Mar 2026 04:05:04 +0000 Subject: [PATCH] *: add opt-in stream multiplexing support DRPC currently allows only one active stream per transport at a time. This commit adds multiplexing as an opt-in mode (Options.Mux) that enables concurrent streams over a single transport, while preserving the original sequential behavior as the default. Mux and non-mux paths have fundamentally different concurrency models, so a single Manager with conditionals would add branching in hot paths. Instead, two separate types are used: Manager (non-mux, unchanged) and MuxManager (new, in manager_mux.go). MuxManager runs two goroutines: manageReader routes packets to streams via a streamRegistry, and manageWriter batches frames from a sharedWriteBuf into transport writes. The Stream now accepts a StreamWriter interface instead of *Writer directly, so both paths share the same Stream code. Non-mux uses *Writer (direct transport writes), mux uses muxWriter (serializes each packet atomically into the shared buffer to prevent frame interleaving). Packet buffering is abstracted behind a packetStore interface with two implementations: syncPacketBuffer (blocking single-slot, non-mux) and queuePacketBuffer (non-blocking queue with sync.Pool recycling, mux). RawRecv and MsgRecv branch on the mux flag for buffer lifecycle: non-mux copies data then calls Done() to unblock the reader, mux takes ownership and recycles after consumption. HandlePacket now processes KindMessage before checking the term signal. This is needed for mux mode where Put must always run to return pool buffers. In non-mux mode, Put on a closed syncPacketBuffer returns immediately, so there is no behavioral change apart from a brief blocking window if terminate() has set the term signal but has not yet called pbuf.Close(). 
The reader's monotonicity check is relaxed from global to per-stream so interleaved frames from different streams are accepted. Non-mux never produces interleaved stream IDs, so this has no behavioral impact there. Conn uses a streamManager interface satisfied by both Manager types, branching once in the constructor. In mux mode, Invoke allocates a per-call marshal buffer instead of reusing a shared one to support concurrent calls. On the server side, ServeOne branches once to either serveOneNonMux (sequential handleRPC) or serveOneMux (concurrent handleRPC with WaitGroup). --- drpcconn/conn.go | 54 +++- drpcconn/conn_test.go | 5 +- drpcmanager/frame_queue_test.go | 82 +++++ drpcmanager/manager.go | 6 + drpcmanager/manager_mux.go | 405 +++++++++++++++++++++++++ drpcmanager/manager_test.go | 179 +++++++++++ drpcmanager/mux_writer.go | 110 +++++++ drpcmanager/mux_writer_test.go | 80 +++++ drpcmanager/registry.go | 89 ++++++ drpcmanager/registry_test.go | 136 +++++++++ drpcserver/server.go | 51 +++- drpcstream/pktbuf.go | 31 +- drpcstream/pktbuf_mux.go | 116 +++++++ drpcstream/stream.go | 114 ++++--- drpcwire/reader.go | 5 +- drpcwire/reader_test.go | 45 +++ drpcwire/writer.go | 9 + internal/drpcopts/stream.go | 7 + internal/integration/benchmark_test.go | 348 +++++++++++++++++++++ internal/integration/common_test.go | 90 +++++- 20 files changed, 1890 insertions(+), 72 deletions(-) create mode 100644 drpcmanager/frame_queue_test.go create mode 100644 drpcmanager/manager_mux.go create mode 100644 drpcmanager/mux_writer.go create mode 100644 drpcmanager/mux_writer_test.go create mode 100644 drpcmanager/registry.go create mode 100644 drpcmanager/registry_test.go create mode 100644 drpcstream/pktbuf_mux.go create mode 100644 internal/integration/benchmark_test.go diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 636f346..0b0c509 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -29,10 +29,20 @@ type Options struct { CollectStats bool } +// streamManager is the 
interface satisfied by both drpcmanager.Manager (non-mux) +// and drpcmanager.MuxManager (mux). +type streamManager interface { + NewClientStream(ctx context.Context, rpc string) (*drpcstream.Stream, error) + Closed() <-chan struct{} + Unblocked() <-chan struct{} + Close() error +} + // Conn is a drpc client connection. type Conn struct { tr drpc.Transport - man *drpcmanager.Manager + man streamManager + mux bool mu sync.Mutex wbuf []byte @@ -56,7 +66,12 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { c.stats = make(map[string]*drpcstats.Stats) } - c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + c.mux = opts.Manager.Mux + if c.mux { + c.man = drpcmanager.NewMuxWithOptions(tr, opts.Manager) + } else { + c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + } return c } @@ -100,8 +115,9 @@ func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } -// Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// Invoke issues the rpc on the transport serializing in, waits for a response, +// and deserializes it into out. In non-mux mode, only one Invoke or Stream may +// be open at a time. In mux mode, multiple calls may be open concurrently. func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -117,18 +133,28 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. 
- c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) - if err != nil { - return err + var data []byte + if c.mux { + // Per-call buffer allocation for concurrent access. + data, err = drpcenc.MarshalAppend(in, enc, nil) + if err != nil { + return err + } + } else { + // We have to protect c.wbuf here even though the manager only allows + // one stream at a time because the stream may async close allowing + // another concurrent call to Invoke to proceed. + c.mu.Lock() + defer c.mu.Unlock() + + c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + if err != nil { + return err + } + data = c.wbuf } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516..2550c3b 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -180,7 +180,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { }) s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) - _ = s.CloseSend() + + assert.NoError(t, s.CloseSend()) + + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/frame_queue_test.go b/drpcmanager/frame_queue_test.go new file mode 100644 index 0000000..b175e6b --- /dev/null +++ b/drpcmanager/frame_queue_test.go @@ -0,0 +1,82 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + "storj.io/drpc/drpcwire" +) + +func TestSharedWriteBuf_AppendDrain(t *testing.T) { + sw := newSharedWriteBuf() + + pkt := drpcwire.Packet{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, sw.Append(pkt)) + + // Drain should return serialized bytes. 
+ data := sw.Drain(nil) + assert.That(t, len(data) > 0) + + // Parse the frame back out to verify correctness. + _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) +} + +func TestSharedWriteBuf_CloseIdempotent(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + sw.Close() // must not panic +} + +func TestSharedWriteBuf_AppendAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + + err := sw.Append(drpcwire.Packet{}) + assert.Error(t, err) +} + +func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + data, ok := sw.WaitAndDrain(nil) + assert.That(t, ok) + assert.That(t, len(data) > 0) + }() + + // Append should wake the blocked WaitAndDrain. + assert.NoError(t, sw.Append(drpcwire.Packet{Data: []byte("a")})) + <-done +} + +func TestSharedWriteBuf_WaitAndDrainCloseEmpty(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + _, ok := sw.WaitAndDrain(nil) + assert.That(t, !ok) + }() + + // Close on empty buffer should return ok=false. + sw.Close() + <-done +} diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index d730836..0f8c768 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -60,6 +60,11 @@ type Options struct { // handling. When enabled, the server stream will decode incoming metadata // into grpc metadata in the context. GRPCMetadataCompatMode bool + + // Mux enables stream multiplexing on the transport, allowing multiple + // concurrent streams. When false (default), the manager uses the + // original single-stream-at-a-time behavior. 
+ Mux bool } // Manager handles the logic of managing a transport for a drpc client or @@ -306,6 +311,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) } + m.wr.Reset() stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) select { case m.streams <- streamInfo{ctx: ctx, stream: stream}: diff --git a/drpcmanager/manager_mux.go b/drpcmanager/manager_mux.go new file mode 100644 index 0000000..48d01a2 --- /dev/null +++ b/drpcmanager/manager_mux.go @@ -0,0 +1,405 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + grpcmetadata "google.golang.org/grpc/metadata" + + "storj.io/drpc" + "storj.io/drpc/drpcdebug" + "storj.io/drpc/drpcmetadata" + "storj.io/drpc/drpcsignal" + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" + "storj.io/drpc/internal/drpcopts" +) + +// MuxManager handles the logic of managing a transport for a drpc client or +// server with stream multiplexing enabled. Multiple streams can be active +// concurrently on a single transport. +type MuxManager struct { + tr drpc.Transport + rd *drpcwire.Reader + opts Options + + sw *sharedWriteBuf + reg *streamRegistry + streamID atomic.Uint64 + wg sync.WaitGroup + pkts chan drpcwire.Packet + pdone drpcsignal.Chan + metaMu sync.Mutex + meta map[uint64]map[string]string + + sigs struct { + term drpcsignal.Signal + write drpcsignal.Signal + read drpcsignal.Signal + tport drpcsignal.Signal + } +} + +// NewMuxWithOptions returns a new mux manager for the transport. It uses the +// provided options to manage details of how it uses it. 
+func NewMuxWithOptions(tr drpc.Transport, opts Options) *MuxManager { + m := &MuxManager{ + tr: tr, + rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), + opts: opts, + + pkts: make(chan drpcwire.Packet), + meta: make(map[uint64]map[string]string), + } + + // a buffer of size 1 allows the consumer of the packet to signal it is done + // without having to coordinate with the sender of the packet. + m.pdone.Make(1) + + // set the internal stream options + drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) + drpcopts.SetStreamMux(&m.opts.Stream.Internal, true) + + m.sw = newSharedWriteBuf() + m.reg = newStreamRegistry() + + go m.manageReader() + go m.manageWriter() + + return m +} + +// String returns a string representation of the manager. +func (m *MuxManager) String() string { return fmt.Sprintf("<man %p>", m) } + +func (m *MuxManager) log(what string, cb func() string) { + if drpcdebug.Enabled { + drpcdebug.Log(func() (_, _, _ string) { return m.String(), what, cb() }) + } +} + +// +// helpers +// + +// terminate puts the MuxManager into a terminal state and closes any resources +// that need to be closed to signal the state change. +func (m *MuxManager) terminate(err error) { + if m.sigs.term.Set(err) { + m.log("TERM", func() string { return fmt.Sprint(err) }) + m.sigs.tport.Set(m.tr.Close()) + m.sw.Close() + m.metaMu.Lock() + for id := range m.meta { + delete(m.meta, id) + } + m.metaMu.Unlock() + // Cancel all active streams so they get a clear error. 
+ m.reg.ForEach(func(_ uint64, s *drpcstream.Stream) { + cancelErr := err + if errors.Is(cancelErr, io.EOF) { + cancelErr = context.Canceled + if s.Kind() == drpc.StreamKindClient { + cancelErr = drpc.ClosedError.New("connection closed") + } + } + s.Cancel(cancelErr) + }) + m.reg.Close() + } +} + +func (m *MuxManager) putMetadata(streamID uint64, metadata map[string]string) { + m.metaMu.Lock() + defer m.metaMu.Unlock() + m.meta[streamID] = metadata +} + +func (m *MuxManager) popMetadata(streamID uint64) map[string]string { + m.metaMu.Lock() + defer m.metaMu.Unlock() + metadata := m.meta[streamID] + delete(m.meta, streamID) + return metadata +} + +// +// manage reader +// + +// manageReader is always reading a packet and dispatching it to the appropriate +// stream or queue. It sets the read signal when it exits so that one can wait +// to ensure that no one is reading on the reader. It sets the term signal if +// there is any error reading packets. +func (m *MuxManager) manageReader() { + defer m.sigs.read.Set(nil) + + var pkt drpcwire.Packet + var err error + var run int + + for !m.sigs.term.IsSet() { + // if we have a run of "small" packets, drop the buffer to release + // memory so that a burst of large packets does not cause eternally + // large heap usage. + if run > 10 { + pkt.Data = nil + run = 0 + } + + pkt, err = m.rd.ReadPacketUsing(pkt.Data[:0]) + if err != nil { + if isConnectionReset(err) { + err = drpc.ClosedError.Wrap(err) + } + m.terminate(managerClosed.Wrap(err)) + return + } + + if len(pkt.Data) < cap(pkt.Data)/4 { + run++ + } else { + run = 0 + } + + m.log("READ", pkt.String) + + stream, ok := m.reg.Get(pkt.ID.Stream) + + switch { + // if the packet is for a registered stream, deliver it. + case ok && stream != nil: + if err := stream.HandlePacket(pkt); err != nil { + m.terminate(managerClosed.Wrap(err)) + return + } + // For message packets, HandlePacket transferred ownership of + // pkt.Data to the stream's packetBuffer. 
Acquire a fresh buffer + // from the pool so the next ReadPacketUsing doesn't allocate. + if pkt.Kind == drpcwire.KindMessage { + pkt.Data = drpcstream.AcquirePacketBuf() + } + + // if any invoke sequence is being sent, forward it to be handled. + case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: + select { + case m.pkts <- pkt: + m.pdone.Recv() + case <-m.sigs.term.Signal(): + return + } + + // silently drop packet for an unregistered stream + default: + m.log("DROP", pkt.String) + } + } +} + +// manageWriter drains the shared write buffer and writes pre-serialized +// bytes directly to the transport. It blocks on the sharedWriteBuf's +// condition variable until data is available, and naturally batches +// frames that accumulate while the previous write is in flight. +func (m *MuxManager) manageWriter() { + defer m.sigs.write.Set(nil) + + var spare []byte + for { + data, ok := m.sw.WaitAndDrain(spare[:0:cap(spare)]) + if !ok { + return + } + if _, err := m.tr.Write(data); err != nil { + m.terminate(managerClosed.Wrap(err)) + return + } + spare = data + } +} + +// +// manage streams +// + +// newStream creates a stream value with the appropriate configuration for this manager. 
+func (m *MuxManager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKind, rpc string) (*drpcstream.Stream, error) { + opts := m.opts.Stream + drpcopts.SetStreamKind(&opts.Internal, kind) + drpcopts.SetStreamRPC(&opts.Internal, rpc) + if cb := drpcopts.GetManagerStatsCB(&m.opts.Internal); cb != nil { + drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) + } + + stream := drpcstream.NewWithOptions(ctx, sid, &muxWriter{sw: m.sw}, opts) + + if err := m.reg.Register(sid, stream); err != nil { + return nil, err + } + + m.wg.Add(1) + go m.manageStream(ctx, stream) + + m.log("STREAM", stream.String) + return stream, nil +} + +// manageStream watches the context and the stream and returns when the stream +// is finished, canceling the stream if the context is canceled. +func (m *MuxManager) manageStream(ctx context.Context, stream *drpcstream.Stream) { + defer m.wg.Done() + defer m.reg.Unregister(stream.ID()) + + select { + case <-m.sigs.term.Signal(): + err := m.sigs.term.Err() + if errors.Is(err, io.EOF) { + err = context.Canceled + if stream.Kind() == drpc.StreamKindClient { + err = drpc.ClosedError.New("connection closed") + } + } + stream.Cancel(err) + <-stream.Finished() + + case <-stream.Finished(): + // stream finished naturally + + case <-ctx.Done(): + m.log("CANCEL", stream.String) + + if m.opts.SoftCancel { + // Best-effort send KindCancel, never terminate connection. + if busy, err := stream.SendCancel(ctx.Err()); err != nil { + m.log("CANCEL_ERR", func() string { + return fmt.Sprintf("%s: %v", stream.String(), err) + }) + } else if busy { + m.log("CANCEL_BUSY", stream.String) + } + stream.Cancel(ctx.Err()) + <-stream.Finished() + } else { + // Hard cancel: terminate connection if stream not finished. 
+ if !stream.Cancel(ctx.Err()) { + m.log("UNFIN", stream.String) + m.terminate(ctx.Err()) + } else { + m.log("CLEAN", stream.String) + } + <-stream.Finished() + } + } +} + +// +// exported interface +// + +// Closed returns a channel that is closed once the manager is closed. +func (m *MuxManager) Closed() <-chan struct{} { + return m.sigs.term.Signal() +} + +// Unblocked returns a channel that is closed when the manager is available for +// new streams. With multiplexing enabled, the connection is never blocked, so +// this always returns an already-closed channel. +func (m *MuxManager) Unblocked() <-chan struct{} { + return closedCh +} + +// Close closes the transport the manager is using. +func (m *MuxManager) Close() error { + m.terminate(managerClosed.New("Close called")) + + m.wg.Wait() // wait for all stream goroutines + m.sigs.write.Wait() + m.sigs.read.Wait() + m.sigs.tport.Wait() + + return m.sigs.tport.Err() +} + +// NewClientStream starts a stream on the managed transport for use by a client. +func (m *MuxManager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { + if err, ok := m.sigs.term.Get(); ok { + return nil, err + } + sid := m.streamID.Add(1) + return m.newStream(ctx, sid, drpc.StreamKindClient, rpc) +} + +// NewServerStream starts a stream on the managed transport for use by a server. +// It does this by waiting for the client to issue an invoke message and +// returning the details. +func (m *MuxManager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { + if err, ok := m.sigs.term.Get(); ok { + return nil, "", err + } + + var timeoutCh <-chan time.Time + + // set up the timeout on the context if necessary. 
+ if timeout := m.opts.InactivityTimeout; timeout > 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + timeoutCh = timer.C + } + + for { + select { + case <-timeoutCh: + return nil, "", context.DeadlineExceeded + + case <-ctx.Done(): + return nil, "", ctx.Err() + + case <-m.sigs.term.Signal(): + return nil, "", m.sigs.term.Err() + + case pkt := <-m.pkts: + switch pkt.Kind { + case drpcwire.KindInvokeMetadata: + metadata, err := drpcmetadata.Decode(pkt.Data) + if err != nil { + m.pdone.Send() + return nil, "", err + } + m.putMetadata(pkt.ID.Stream, metadata) + m.pdone.Send() + + case drpcwire.KindInvoke: + rpc = string(pkt.Data) + streamCtx := ctx + + if metadata := m.popMetadata(pkt.ID.Stream); metadata != nil { + if m.opts.GRPCMetadataCompatMode { + grpcMeta := make(map[string][]string, len(metadata)) + for k, v := range metadata { + grpcMeta[k] = []string{v} + } + streamCtx = grpcmetadata.NewIncomingContext(streamCtx, grpcMeta) + } else { + streamCtx = drpcmetadata.NewIncomingContext(streamCtx, metadata) + } + } + stream, err := m.newStream(streamCtx, pkt.ID.Stream, drpc.StreamKindServer, rpc) + // Ack the invoke only after stream registration so subsequent + // message packets cannot be dropped for an unknown stream ID. + m.pdone.Send() + return stream, rpc, err + + default: + // this should never happen, but defensive. 
+ m.pdone.Send() + } + } + } +} diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 5918113..eca2ff4 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -16,6 +16,7 @@ import ( grpcmetadata "google.golang.org/grpc/metadata" "storj.io/drpc/drpcmetadata" + "storj.io/drpc/drpcstream" "storj.io/drpc/drpctest" "storj.io/drpc/drpcwire" ) @@ -161,6 +162,184 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { ctx.Wait() } +func TestMux_DrpcMetadataInterleavedAcrossStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := NewMuxWithOptions(cconn, Options{}) + defer func() { _ = cman.Close() }() + + sman := NewMuxWithOptions(sconn, Options{ + GRPCMetadataCompatMode: true, + }) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + metadata1 := map[string]string{"stream": "one"} + metadata2 := map[string]string{"stream": "two"} + + buf1, err := drpcmetadata.Encode(nil, metadata1) + assert.NoError(t, err) + buf2, err := drpcmetadata.Encode(nil, metadata2) + assert.NoError(t, err) + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvokeMetadata, buf1)) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvokeMetadata, buf2)) + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + got1, ok := grpcmetadata.FromIncomingContext(srvStream1.Context()) + assert.That(t, ok) + assert.Equal(t, 
grpcmetadata.MD{"stream": []string{"one"}}, got1) + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + got2, ok := grpcmetadata.FromIncomingContext(srvStream2.Context()) + assert.That(t, ok) + assert.Equal(t, grpcmetadata.MD{"stream": []string{"two"}}, got2) +} + +func TestMux_NewServerStreamUnreadMessageDoesNotBlockOtherStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := NewMuxWithOptions(cconn, Options{}) + defer func() { _ = cman.Close() }() + + sman := NewMuxWithOptions(sconn, Options{}) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, []byte("message-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + // Do not read the first stream's message. The manager must still be able to + // accept and register additional streams. + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + srvStream2, rpc2, err := sman.NewServerStream(timeoutCtx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() +} + +// TestMux_ConcurrentLargeMessages verifies that two streams writing messages +// larger than SplitSize concurrently do not corrupt each other's data. 
With the +// mux implementation, rawWriteLocked splits messages into multiple frames and +// each frame is appended to the shared write buffer independently. Frames from +// different streams can interleave in the buffer, and the reader resets partial +// packets when it sees a frame from a different stream, silently corrupting +// data. +func TestMux_ConcurrentLargeMessages(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + // Use a small SplitSize to force even small messages to be split into + // multiple frames, making interleaving likely. + streamOpts := drpcstream.Options{SplitSize: 5} + + cman := NewMuxWithOptions(cconn, Options{Stream: streamOpts}) + defer func() { _ = cman.Close() }() + + sman := NewMuxWithOptions(sconn, Options{Stream: streamOpts}) + defer func() { _ = sman.Close() }() + + // Create two client streams and send invoke + message concurrently. + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + msg1 := []byte("AAAAAAAAAAAAAAAAAAAA") // 20 bytes, split into 4 frames of 5 bytes + msg2 := []byte("BBBBBBBBBBBBBBBBBBBB") // 20 bytes, split into 4 frames of 5 bytes + + // Send invokes first (these are small, no splitting). + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + // Accept both server streams before sending messages, so the streams are + // registered and the reader can route packets. 
+ srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + // Write messages concurrently from both streams. With SplitSize=5, each + // 20-byte message becomes 4 frames. The frames should not interleave. + ready := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, msg1)) + }() + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream2.RawWrite(drpcwire.KindMessage, msg2)) + }() + close(ready) + wg.Wait() + + // Read from both server streams and verify correctness. + got1, err := srvStream1.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got1, msg1) + + got2, err := srvStream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got2, msg2) +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } diff --git a/drpcmanager/mux_writer.go b/drpcmanager/mux_writer.go new file mode 100644 index 0000000..b36c25e --- /dev/null +++ b/drpcmanager/mux_writer.go @@ -0,0 +1,110 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcwire" +) + +// muxWriter implements drpcwire.StreamWriter by serializing packet bytes into a +// shared write buffer. The manageWriter goroutine drains the buffer and writes +// directly to the transport. +// +// The entire packet is serialized as a single frame (via AppendFrame) under one +// mutex hold, so frames from concurrent streams never interleave on the wire. +// The packet's Data slice is consumed (copied) before WritePacket returns, so +// callers may safely reuse their buffers afterward. 
+type muxWriter struct { + sw *sharedWriteBuf +} + +func (w *muxWriter) WritePacket(pkt drpcwire.Packet) error { + return w.sw.Append(pkt) +} + +// Flush is a no-op because the manageWriter goroutine flushes to the +// transport after draining the shared buffer. +func (w *muxWriter) Flush() error { return nil } + +func (w *muxWriter) Empty() bool { return true } + +// sharedWriteBuf collects serialized frame bytes from multiple concurrent +// producers. A single consumer (manageWriter) drains the buffer and writes +// the pre-serialized bytes to the transport. +type sharedWriteBuf struct { + mu sync.Mutex + cond *sync.Cond + buf []byte + closed bool +} + +func newSharedWriteBuf() *sharedWriteBuf { + sw := &sharedWriteBuf{} + sw.cond = sync.NewCond(&sw.mu) + return sw +} + +// Append serializes pkt as a single frame into the shared buffer. The packet's +// Data slice is consumed (copied by AppendFrame) before Append returns. +func (sw *sharedWriteBuf) Append(pkt drpcwire.Packet) error { + sw.mu.Lock() + if sw.closed { + sw.mu.Unlock() + return managerClosed.New("enqueue") + } + sw.buf = drpcwire.AppendFrame(sw.buf, drpcwire.Frame{ + Data: pkt.Data, + ID: pkt.ID, + Kind: pkt.Kind, + Control: pkt.Control, + Done: true, + }) + sw.mu.Unlock() + + sw.cond.Signal() + return nil +} + +// Drain swaps out accumulated bytes, giving the caller ownership of the +// returned slice. The internal buffer is replaced with spare (reset to zero +// length) so producers can continue appending without allocation. +func (sw *sharedWriteBuf) Drain(spare []byte) []byte { + sw.mu.Lock() + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data +} + +// WaitAndDrain blocks until data is available or the buffer is closed. +// Returns the accumulated bytes and true if data was available, or nil and +// false if the buffer is closed and empty. 
+func (sw *sharedWriteBuf) WaitAndDrain(spare []byte) ([]byte, bool) { + sw.mu.Lock() + for len(sw.buf) == 0 && !sw.closed { + sw.cond.Wait() + } + if sw.closed && len(sw.buf) == 0 { + sw.mu.Unlock() + return nil, false + } + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data, true +} + +// Close marks the buffer as closed and wakes the consumer. +func (sw *sharedWriteBuf) Close() { + sw.mu.Lock() + defer sw.mu.Unlock() + + if sw.closed { + return + } + sw.closed = true + sw.cond.Broadcast() +} diff --git a/drpcmanager/mux_writer_test.go b/drpcmanager/mux_writer_test.go new file mode 100644 index 0000000..50c4e1c --- /dev/null +++ b/drpcmanager/mux_writer_test.go @@ -0,0 +1,80 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + "storj.io/drpc/drpcwire" +) + +func TestMuxWriter_WritePacket(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + pkt := drpcwire.Packet{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, w.WritePacket(pkt)) + + data := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) +} + +func TestMuxWriter_WritePacketIsolatesData(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + data := []byte("hello") + pkt := drpcwire.Packet{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, w.WritePacket(pkt)) + + // Mutate the original source buffer after WritePacket. 
+ data[0] = 'j' + + // The serialized data in the shared buffer should be unaffected because + // AppendFrame copies the bytes during serialization. + buf := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(buf) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, []byte("hello")) +} + +func TestMuxWriter_FlushNoop(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + assert.NoError(t, w.Flush()) +} + +func TestMuxWriter_Empty(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + assert.That(t, w.Empty()) +} + +func TestMuxWriter_WritePacketAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + sw.Close() + + err := w.WritePacket(drpcwire.Packet{}) + assert.Error(t, err) +} diff --git a/drpcmanager/registry.go b/drpcmanager/registry.go new file mode 100644 index 0000000..b6e90b1 --- /dev/null +++ b/drpcmanager/registry.go @@ -0,0 +1,89 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcstream" +) + +// streamRegistry is a thread-safe map of stream IDs to stream objects. +// It is used by the MuxManager to track all active streams for lifecycle +// management and packet routing. +type streamRegistry struct { + mu sync.RWMutex + streams map[uint64]*drpcstream.Stream + closed bool +} + +func newStreamRegistry() *streamRegistry { + return &streamRegistry{ + streams: make(map[uint64]*drpcstream.Stream), + } +} + +// Register adds a stream to the registry. It returns an error if the registry +// is closed or if a stream with the same ID is already registered. 
+func (r *streamRegistry) Register(id uint64, stream *drpcstream.Stream) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return managerClosed.New("register") + } + if _, ok := r.streams[id]; ok { + return managerClosed.New("duplicate stream id") + } + r.streams[id] = stream + return nil +} + +// Unregister removes a stream from the registry. It is a no-op if the stream +// is not registered or if the registry has been closed. +func (r *streamRegistry) Unregister(id uint64) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.streams != nil { + delete(r.streams, id) + } +} + +// Get returns the stream for the given ID and whether it was found. +func (r *streamRegistry) Get(id uint64) (*drpcstream.Stream, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + s, ok := r.streams[id] + return s, ok +} + +// Close marks the registry as closed, preventing future Register calls. +// It does not cancel any streams. +func (r *streamRegistry) Close() { + r.mu.Lock() + defer r.mu.Unlock() + + r.closed = true +} + +// ForEach calls fn for each registered stream. The function is called with +// the stream ID and stream pointer. The registry is read-locked during iteration. +func (r *streamRegistry) ForEach(fn func(uint64, *drpcstream.Stream)) { + r.mu.RLock() + defer r.mu.RUnlock() + + for id, s := range r.streams { + fn(id, s) + } +} + +// Len returns the number of registered streams. +func (r *streamRegistry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.streams) +} diff --git a/drpcmanager/registry_test.go b/drpcmanager/registry_test.go new file mode 100644 index 0000000..d27359b --- /dev/null +++ b/drpcmanager/registry_test.go @@ -0,0 +1,136 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "context" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcstream" +) + +func testStream(id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, &muxWriter{sw: newSharedWriteBuf()}) +} + +func TestStreamRegistry_RegisterAndGet(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s) +} + +func TestStreamRegistry_GetMissing(t *testing.T) { + reg := newStreamRegistry() + + got, ok := reg.Get(42) + assert.That(t, !ok) + assert.Nil(t, got) +} + +func TestStreamRegistry_Unregister(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + assert.Equal(t, reg.Len(), 1) + + reg.Unregister(1) + + _, ok := reg.Get(1) + assert.That(t, !ok) + assert.Equal(t, reg.Len(), 0) +} + +func TestStreamRegistry_UnregisterIdempotent(t *testing.T) { + reg := newStreamRegistry() + + // must not panic when unregistering a non-existent ID + reg.Unregister(99) +} + +func TestStreamRegistry_DuplicateRegister(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(1) + + assert.NoError(t, reg.Register(1, s1)) + assert.Error(t, reg.Register(1, s2)) + + // original stream is still registered + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s1) +} + +func TestStreamRegistry_RegisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + reg.Close() + + err := reg.Register(1, testStream(1)) + assert.Error(t, err) +} + +func TestStreamRegistry_UnregisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + assert.NoError(t, reg.Register(1, s)) + + reg.Close() + + // must not panic + reg.Unregister(1) +} + +func TestStreamRegistry_Len(t *testing.T) { + reg := newStreamRegistry() + assert.Equal(t, reg.Len(), 0) + + assert.NoError(t, reg.Register(1, testStream(1))) + 
assert.Equal(t, reg.Len(), 1) + + assert.NoError(t, reg.Register(2, testStream(2))) + assert.Equal(t, reg.Len(), 2) + + reg.Unregister(1) + assert.Equal(t, reg.Len(), 1) +} + +func TestStreamRegistry_ForEach(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(2) + s3 := testStream(3) + + assert.NoError(t, reg.Register(1, s1)) + assert.NoError(t, reg.Register(2, s2)) + assert.NoError(t, reg.Register(3, s3)) + + seen := make(map[uint64]*drpcstream.Stream) + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + seen[id] = s + }) + + assert.Equal(t, len(seen), 3) + assert.Equal(t, seen[1], s1) + assert.Equal(t, seen[2], s2) + assert.Equal(t, seen[3], s3) +} + +func TestStreamRegistry_ForEach_Empty(t *testing.T) { + reg := newStreamRegistry() + + count := 0 + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + count++ + }) + + assert.Equal(t, count, 0) +} diff --git a/drpcserver/server.go b/drpcserver/server.go index b8f95d9..9337182 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -92,6 +92,19 @@ func (s *Server) getStats(rpc string) *drpcstats.Stats { // ServeOne serves a single set of rpcs on the provided transport. func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { + ctx, err = s.setupConnection(ctx, tr) + if err != nil { + return err + } + if s.opts.Manager.Mux { + return s.serveOneMux(ctx, tr) + } + return s.serveOneNonMux(ctx, tr) +} + +// setupConnection performs TLS handshake (if applicable) and sets up the +// drpccache context. It is shared by both mux and non-mux serve paths. +func (s *Server) setupConnection(ctx context.Context, tr drpc.Transport) (context.Context, error) { // Check if the transport is a TLS connection if tlsConn, ok := tr.(*tls.Conn); ok { // Manually perform the TLS handshake to access peer certificate @@ -109,7 +122,7 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { // anyway. 
err := tlsConn.HandshakeContext(ctx) if err != nil { - return drpc.ConnectionError.New("server handshake [%q] failed: %w", tlsConn.RemoteAddr(), err) + return ctx, drpc.ConnectionError.New("server handshake [%q] failed: %w", tlsConn.RemoteAddr(), err) } state := tlsConn.ConnectionState() if len(state.PeerCertificates) > 0 { @@ -117,7 +130,11 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { ctx, drpcctx.PeerConnectionInfo{Certificates: state.PeerCertificates}) } } + return ctx, nil +} +// serveOneNonMux serves rpcs sequentially using the non-mux Manager. +func (s *Server) serveOneNonMux(ctx context.Context, tr drpc.Transport) (err error) { man := drpcmanager.NewWithOptions(tr, s.opts.Manager) defer func() { err = errs.Combine(err, man.Close()) }() @@ -137,6 +154,38 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } } +// serveOneMux serves rpcs concurrently using the MuxManager. +func (s *Server) serveOneMux(ctx context.Context, tr drpc.Transport) (err error) { + man := drpcmanager.NewMuxWithOptions(tr, s.opts.Manager) + + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() + + cache := drpccache.New() + defer cache.Clear() + + ctx = drpccache.WithContext(ctx, cache) + + for { + stream, rpc, err := man.NewServerStream(ctx) + if err != nil { + return errs.Wrap(err) + } + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() + } +} + var temporarySleep = 500 * time.Millisecond // Serve listens for connections on the listener and serves the drpc request diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go index db68864..5854d86 100644 --- a/drpcstream/pktbuf.go +++ b/drpcstream/pktbuf.go @@ -7,7 +7,22 @@ import ( "sync" ) -type packetBuffer struct { +// packetStore is the interface for packet buffer implementations. 
+// syncPacketBuffer is used for non-mux mode (blocking, single-slot), +// queuePacketBuffer is used for mux mode (non-blocking, queued). +type packetStore interface { + Put(data []byte) + Get() ([]byte, error) + Close(err error) + Done() + Recycle([]byte) +} + +// syncPacketBuffer is the original single-slot, blocking packet buffer used +// in non-mux mode. Put blocks until the previous value is consumed via +// Get+Done, and the reader (manageReader) blocks in Put until the stream +// consumer finishes processing. +type syncPacketBuffer struct { mu sync.Mutex cond sync.Cond err error @@ -16,11 +31,11 @@ type packetBuffer struct { held bool } -func (pb *packetBuffer) init() { +func (pb *syncPacketBuffer) init() { pb.cond.L = &pb.mu } -func (pb *packetBuffer) Close(err error) { +func (pb *syncPacketBuffer) Close(err error) { pb.mu.Lock() defer pb.mu.Unlock() @@ -36,7 +51,7 @@ func (pb *packetBuffer) Close(err error) { } } -func (pb *packetBuffer) Put(data []byte) { +func (pb *syncPacketBuffer) Put(data []byte) { pb.mu.Lock() defer pb.mu.Unlock() @@ -57,7 +72,7 @@ func (pb *packetBuffer) Put(data []byte) { } } -func (pb *packetBuffer) Get() ([]byte, error) { +func (pb *syncPacketBuffer) Get() ([]byte, error) { pb.mu.Lock() defer pb.mu.Unlock() @@ -74,7 +89,7 @@ func (pb *packetBuffer) Get() ([]byte, error) { return pb.data, nil } -func (pb *packetBuffer) Done() { +func (pb *syncPacketBuffer) Done() { pb.mu.Lock() defer pb.mu.Unlock() @@ -83,3 +98,7 @@ func (pb *packetBuffer) Done() { pb.held = false pb.cond.Broadcast() } + +// Recycle is a no-op for syncPacketBuffer. Buffer lifetime is managed by +// the manageReader goroutine that owns the underlying slice. +func (pb *syncPacketBuffer) Recycle([]byte) {} diff --git a/drpcstream/pktbuf_mux.go b/drpcstream/pktbuf_mux.go new file mode 100644 index 0000000..cb152cb --- /dev/null +++ b/drpcstream/pktbuf_mux.go @@ -0,0 +1,116 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcstream + +import ( + "io" + "sync" +) + +// pktBuf wraps a byte slice so the pool stores a pointer type, avoiding +// an allocation on every Put. +type pktBuf struct { + data []byte +} + +// pktBufPool recycles byte slices used on the read path to avoid per-packet +// allocations. Buffers flow: AcquirePacketBuf -> reader -> Put (zero-copy +// ownership transfer) -> Get -> consumer -> Recycle -> pool. +var pktBufPool = sync.Pool{ + New: func() interface{} { return &pktBuf{} }, +} + +func recycleToPktBufPool(data []byte) { + pb := pktBufPool.Get().(*pktBuf) + pb.data = data[:0] + pktBufPool.Put(pb) +} + +// AcquirePacketBuf returns a byte slice from the shared packet buffer pool +// for use as a read buffer. Returns nil if the pool is empty. +func AcquirePacketBuf() []byte { + pb := pktBufPool.Get().(*pktBuf) + data := pb.data[:0] + pb.data = nil + pktBufPool.Put(pb) + return data +} + +// queuePacketBuffer is a non-blocking, queue-based packet buffer used in mux +// mode. Put appends to an unbounded queue and returns immediately, allowing +// the reader goroutine to continue dispatching packets to other streams. +type queuePacketBuffer struct { + mu sync.Mutex + cond sync.Cond + err error + data [][]byte +} + +func (pb *queuePacketBuffer) init() { + pb.cond.L = &pb.mu +} + +func (pb *queuePacketBuffer) Close(err error) { + pb.mu.Lock() + defer pb.mu.Unlock() + + if pb.err == nil { + // Preserve already-queued messages on graceful close so readers can + // drain them before seeing EOF. + if err != io.EOF { + for i := range pb.data { + recycleToPktBufPool(pb.data[i]) + pb.data[i] = nil + } + pb.data = pb.data[:0] + } + pb.err = err + pb.cond.Broadcast() + } +} + +// Put takes ownership of data. The caller must not use data after calling Put. +// If the buffer is closed, data is returned to the pool. 
+func (pb *queuePacketBuffer) Put(data []byte) { + pb.mu.Lock() + defer pb.mu.Unlock() + + if pb.err != nil { + recycleToPktBufPool(data) + return + } + + pb.data = append(pb.data, data) + pb.cond.Broadcast() +} + +func (pb *queuePacketBuffer) Get() ([]byte, error) { + pb.mu.Lock() + defer pb.mu.Unlock() + + for len(pb.data) == 0 && pb.err == nil { + pb.cond.Wait() + } + if len(pb.data) == 0 { + return nil, pb.err + } + + data := pb.data[0] + n := copy(pb.data, pb.data[1:]) + pb.data[n] = nil + pb.data = pb.data[:n] + return data, nil +} + +// Done is a no-op for queuePacketBuffer. Buffer ownership is transferred to +// the caller via Get, and recycling is done explicitly via Recycle. +func (pb *queuePacketBuffer) Done() {} + +// Recycle returns a buffer obtained from Get back to the pool. +// Call this after the data has been fully consumed (e.g. after Unmarshal). +func (pb *queuePacketBuffer) Recycle(buf []byte) { + if buf != nil { + recycleToPktBufPool(buf) + } +} diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 29ccd63..378c169 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -53,8 +53,9 @@ type Stream struct { flush sync.Once id drpcwire.ID - wr *drpcwire.Writer - pbuf packetBuffer + mux bool + wr drpcwire.StreamWriter + pbuf packetStore wbuf []byte mu sync.Mutex // protects state transitions @@ -72,7 +73,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. -func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr drpcwire.StreamWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -80,7 +81,9 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. 
It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions( + ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options, +) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -89,6 +92,8 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O } } + mux := drpcopts.GetStreamMux(&opts.Internal) + s := &Stream{ ctx: streamCtx{ Context: ctx, @@ -98,12 +103,21 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O fin: drpcopts.GetStreamFin(&opts.Internal), task: task, - id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + id: drpcwire.ID{Stream: sid}, + mux: mux, + wr: wr, } - // initialize the packet buffer - s.pbuf.init() + // initialize the packet buffer based on mode + if mux { + pb := new(queuePacketBuffer) + pb.init() + s.pbuf = pb + } else { + pb := new(syncPacketBuffer) + pb.init() + s.pbuf = pb + } return s } @@ -225,17 +239,19 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) - if s.sigs.term.IsSet() { + // Put must always be called for message packets to manage buffer + // ownership. Put returns the buffer to the pool if the stream is closed. + if pkt.Kind == drpcwire.KindMessage { + s.pbuf.Put(pkt.Data) return nil } - s.log("HANDLE", pkt.String) - - if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + if s.sigs.term.IsSet() { return nil } + s.log("HANDLE", pkt.String) + s.mu.Lock() defer s.mu.Unlock() @@ -314,26 +330,28 @@ func (s *Stream) checkCancelError(err error) error { return err } -// newFrameLocked bumps the internal message id and returns a frame. 
It must be +// nextID bumps the internal message id and returns the new ID. It must be // called under a mutex. -func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { +func (s *Stream) nextID() drpcwire.ID { s.id.Message++ - return drpcwire.Frame{ID: s.id, Kind: kind} + return s.id } // sendPacketLocked sends the packet in a single write and flushes. It does not // check for any conditions to stop it from writing and is meant for internal // stream use to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { - fr := s.newFrameLocked(kind) - fr.Data = data - fr.Control = control - fr.Done = true + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + Control: control, + } drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) - s.log("SEND", fr.String) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { + if err := s.wr.WritePacket(pkt); err != nil { return errs.Wrap(err) } if err := s.wr.Flush(); err != nil { @@ -376,29 +394,26 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. 
func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { - fr := s.newFrameLocked(kind) - n := s.opts.SplitSize - - for { - switch { - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } + switch { + case s.sigs.send.IsSet(): + return s.sigs.send.Err() + case s.sigs.term.IsSet(): + return s.sigs.term.Err() + } - fr.Data, data = drpcwire.SplitData(data, n) - fr.Done = len(data) == 0 + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + } - drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(fr.Data))) - s.log("SEND", fr.String) + drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) - } else if fr.Done { - return nil - } + if err := s.wr.WritePacket(pkt); err != nil { + return s.checkCancelError(errs.Wrap(err)) } + return nil } // RawFlush flushes any buffers of data. @@ -461,6 +476,14 @@ func (s *Stream) RawRecv() (data []byte, err error) { if err != nil { return nil, err } + + if s.mux { + // In mux mode, the buffer is owned by the caller (from the pool). + return data, nil + } + + // In non-mux mode, copy the data and release the slot so the reader + // goroutine can continue. data = append([]byte(nil), data...) s.pbuf.Done() @@ -514,7 +537,14 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + + if s.mux { + // In mux mode, return the buffer to the pool after unmarshal. + s.pbuf.Recycle(data) + } else { + // In non-mux mode, release the slot so the reader can continue. 
+ s.pbuf.Done() + } return err } diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac397..ef801d3 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -143,7 +143,10 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { pkt.Control = pkt.Control || fr.Control switch { - case fr.ID.Less(r.id): + case fr.ID.Stream == 0 || fr.ID.Message == 0: + return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) + + case fr.ID.Stream == r.id.Stream && fr.ID.Less(r.id): return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) case r.id != fr.ID || pkt.ID == ID{}: diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b..ec0cd3f 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -154,6 +154,51 @@ func TestReader(t *testing.T) { Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, Error: "id monotonicity violation", }, + + { // cross-stream: frames from a lower stream after a higher stream are allowed (multiplexing) + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose, Done: true}, + }, + }, + + { // cross-stream: interleaved single-frame packets from multiple streams + Packets: []Packet{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke}, + {Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage}, + {Data: nil, ID: 
ID{Stream: 1, Message: 3}, Kind: KindCloseSend}, + {Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend}, + }, + Frames: []Frame{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke, Done: true}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke, Done: true}, + {Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 3}, Kind: KindCloseSend, Done: true}, + {Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend, Done: true}, + }, + }, + + { // cross-stream: within-stream monotonicity is still enforced + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("c"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, // replay on stream 2 + }, + Error: "id monotonicity violation", + }, } for _, tc := range cases { diff --git a/drpcwire/writer.go b/drpcwire/writer.go index fe909cf..bfcf0ac 100644 --- a/drpcwire/writer.go +++ b/drpcwire/writer.go @@ -12,6 +12,15 @@ import ( "storj.io/drpc/drpcdebug" ) +// StreamWriter is the interface used by streams for writing packets. Each call +// to WritePacket must serialize the full data atomically so that frames from +// concurrent writers on different streams do not interleave on the wire. 
+type StreamWriter interface { + WritePacket(pkt Packet) error + Flush() error + Empty() bool +} + // // Writer // diff --git a/internal/drpcopts/stream.go b/internal/drpcopts/stream.go index 6ab0511..7459347 100644 --- a/internal/drpcopts/stream.go +++ b/internal/drpcopts/stream.go @@ -15,6 +15,7 @@ type Stream struct { kind drpc.StreamKind rpc string stats *drpcstats.Stats + mux bool } // GetStreamTransport returns the drpc.Transport stored in the options. @@ -46,3 +47,9 @@ func GetStreamStats(opts *Stream) *drpcstats.Stats { return opts.stats } // SetStreamStats sets the Stats stored in the options. func SetStreamStats(opts *Stream, stats *drpcstats.Stats) { opts.stats = stats } + +// GetStreamMux returns whether the stream is in multiplexing mode. +func GetStreamMux(opts *Stream) bool { return opts.mux } + +// SetStreamMux sets whether the stream is in multiplexing mode. +func SetStreamMux(opts *Stream, mux bool) { opts.mux = mux } diff --git a/internal/integration/benchmark_test.go b/internal/integration/benchmark_test.go new file mode 100644 index 0000000..1e53184 --- /dev/null +++ b/internal/integration/benchmark_test.go @@ -0,0 +1,348 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package integration + +import ( + "context" + "errors" + "io" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "google.golang.org/protobuf/proto" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpctest" +) + +var echoServer = impl{ + Method1Fn: func(_ context.Context, in *In) (*Out, error) { + return &Out{Out: in.In, Data: in.Data}, nil + }, + + Method2Fn: func(stream DRPCService_Method2Stream) error { + var last *In + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + last = in + } + if last == nil { + last = &In{} + } + return stream.SendAndClose(&Out{Out: last.In, Data: last.Data}) + }, + + Method3Fn: func(in *In, stream DRPCService_Method3Stream) error { + out := &Out{Out: 1, Data: in.Data} + for i := int64(0); i < in.In; i++ { + if err := stream.Send(out); err != nil { + return err + } + } + return nil + }, + + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + if err := stream.Send(&Out{Out: in.In, Data: in.Data}); err != nil { + return err + } + } + }, +} + +var sizes = []struct { + name string + data []byte +}{ + {"small", nil}, + {"1KB", make([]byte, 1024)}, + {"8KB", make([]byte, 8192)}, +} + +var concurrencies = []int{1, 10, 100} +var activeStreamCounts = []int{2, 8, 32} + +const parallelActiveStreamsWriteLatency = 200 * time.Microsecond + +func BenchmarkUnary(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, c := range concurrencies { + b.Run("concurrent="+strconv.Itoa(c), func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.SetParallelism(c) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := client.Method1(context.Background(), 
in); err != nil { + b.Error(err) + } + } + }) + }) + } + }) + } +} + +func BenchmarkInputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method2(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + } + + if _, err := stream.CloseAndRecv(); err != nil { + b.Fatal(err) + } + }) + } +} + +func BenchmarkOutputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: int64(b.N), Data: sz.data} + stream, err := client.Method3(context.Background(), in) + if err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + b.ResetTimer() + + var last *Out + for i := 0; i < b.N; i++ { + last, err = stream.Recv() + if err != nil { + b.Fatal(err) + } + } + if last != nil { + b.SetBytes(int64(proto.Size(last))) + } + }) + } +} + +func BenchmarkBidiStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + if _, err := stream.Recv(); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkParallelActiveStreams(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, 0) +} + +func BenchmarkParallelActiveStreamsWriteLatency(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, 
parallelActiveStreamsWriteLatency) +} + +func benchmarkParallelActiveStreamsModes(b *testing.B, writeLatency time.Duration) { + transports := []struct { + name string + useTCP bool + }{ + {"Pipe", false}, + {"TCP", true}, + } + + modes := []struct { + name string + mux bool + oneConnPerStream bool + }{ + {"NonMux", false, true}, + {"Mux/ConnPerStream", true, true}, + {"Mux/SharedConn", true, false}, + } + + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, streamCount := range activeStreamCounts { + b.Run("streams="+strconv.Itoa(streamCount), func(b *testing.B) { + for _, transport := range transports { + for _, mode := range modes { + b.Run(transport.name+"/"+mode.name, func(b *testing.B) { + benchmarkParallelActiveStreams(b, sz.data, streamCount, mode.mux, mode.oneConnPerStream, writeLatency, transport.useTCP) + }) + } + } + }) + } + }) + } +} + +func benchmarkParallelActiveStreams( + b *testing.B, payload []byte, streamCount int, mux bool, oneConnPerStream bool, writeLatency time.Duration, useTCP bool, +) { + ctx := drpctest.NewTracker(b) + defer ctx.Close() + + type bidiClient interface { + Send(*In) error + Recv() (*Out, error) + Close() error + } + + type worker struct { + stream bidiClient + close func() + } + workers := make([]worker, 0, streamCount) + + opts := connTransportOpts{ + writeDelay: writeLatency, + useTCP: useTCP, + mux: mux, + } + + var sharedConn *drpcconn.Conn + if !oneConnPerStream { + sharedConn = createConnectionWithTransport(b, echoServer, ctx, opts) + } + + for i := 0; i < streamCount; i++ { + conn := sharedConn + if oneConnPerStream { + conn = createConnectionWithTransport(b, echoServer, ctx, opts) + } + + client := NewDRPCServiceClient(conn) + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatalf("create stream %d: %v", i, err) + } + + s := stream + c := conn + workers = append(workers, worker{ + stream: s, + close: func() { + _ = s.Close() + if oneConnPerStream { + _ = 
c.Close() + } + }, + }) + } + + input := &In{In: 1, Data: payload} + b.SetBytes(int64(proto.Size(input))) + b.ReportAllocs() + + var sent atomic.Int64 + errCh := make(chan error, 1) + start := make(chan struct{}) + var wg sync.WaitGroup + + b.ResetTimer() + for _, w := range workers { + wg.Add(1) + go func(w worker) { + defer wg.Done() + msg := &In{In: 1, Data: payload} + <-start + + for { + n := sent.Add(1) + if n > int64(b.N) { + return + } + if err := w.stream.Send(msg); err != nil { + select { + case errCh <- err: + default: + } + return + } + if _, err := w.stream.Recv(); err != nil { + select { + case errCh <- err: + default: + } + return + } + } + }(w) + } + close(start) + wg.Wait() + b.StopTimer() + + select { + case err := <-errCh: + b.Fatal(err) + default: + } + + for _, w := range workers { + w.close() + } + if sharedConn != nil { + _ = sharedConn.Close() + } +} diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a..aa786c5 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -10,6 +10,7 @@ import ( "strconv" "sync" "testing" + "time" "github.com/zeebo/assert" "google.golang.org/grpc/codes" @@ -40,18 +41,93 @@ func in(n int64) *In { return &In{In: n} } func out(n int64) *Out { return &Out{Out: n} } func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { - c1, c2 := net.Pipe() + return createRawConnectionWithWriteDelay(t, server, ctx, 0) +} + +func createRawConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, connTransportOpts{writeDelay: writeDelay}) +} + +func createTCPConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, 
connTransportOpts{writeDelay: writeDelay, useTCP: true}) +} + +type connTransportOpts struct { + writeDelay time.Duration + useTCP bool + mux bool +} + +func createConnectionWithTransport( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, opts connTransportOpts, +) *drpcconn.Conn { + var serverConn, clientConn net.Conn + if opts.useTCP { + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { lis.Close() }) + + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := lis.Accept() + ch <- result{c, err} + }() + + c, err := net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatal(err) + } + r := <-ch + if r.err != nil { + t.Fatal(r.err) + } + serverConn, clientConn = r.conn, c + t.Cleanup(func() { + serverConn.Close() + clientConn.Close() + }) + } else { + c1, c2 := net.Pipe() + serverConn, clientConn = c1, c2 + } + + if opts.writeDelay > 0 { + serverConn = &writeLatencyConn{Conn: serverConn, writeDelay: opts.writeDelay} + clientConn = &writeLatencyConn{Conn: clientConn, writeDelay: opts.writeDelay} + } + managerOpts := drpcmanager.Options{ + SoftCancel: true, + Mux: opts.mux, + } mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, server)) - srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{ - SoftCancel: true, - }, + srv := drpcserver.NewWithOptions(mux, drpcserver.Options{Manager: managerOpts}) + ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, serverConn) }) + return drpcconn.NewWithOptions(clientConn, drpcconn.Options{ + Manager: managerOpts, }) } +type writeLatencyConn struct { + net.Conn + writeDelay time.Duration +} + +func (c *writeLatencyConn) Write(p []byte) (int, error) { + time.Sleep(c.writeDelay) + return c.Conn.Write(p) +} + func createConnection(t testing.TB, server 
DRPCServiceServer) (DRPCServiceClient, func()) { ctx := drpctest.NewTracker(t) conn := createRawConnection(t, server, ctx)