From ffc17bf33fb4d8c2b36f5bfc72c3143dddf390dd Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Wed, 18 Feb 2026 11:10:19 +0530 Subject: [PATCH 1/2] drpc: add stream multiplexing for concurrent RPCs over a single connection The DRPC manager previously allowed only one active stream at a time, enforced by a semaphore. Clients had to wait for one RPC to finish before starting another. This commit replaces that with multiplexing, where multiple streams share a single transport concurrently. To remove the single-stream restriction, the semaphore and its associated waitForPreviousStream logic are removed. In their place, a streamRegistry tracks all active streams by ID so the reader can route incoming packets to the right stream. With multiple streams writing concurrently, the manager's single drpcwire.Writer, previously used by only one stream at a time, is no longer safe. A sharedWriteBuf collects serialized frame bytes from all streams under a short-held mutex, and a dedicated manageWriter goroutine drains the buffer and writes to the transport. This also naturally batches frames that accumulate during writes. The packetBuffer previously used a single-slot design where the producer blocked until the consumer finished. With multiplexing, the reader must deliver packets to any stream without waiting, so packetBuffer is reworked into an unbounded queue with sync.Pool-based buffer recycling. The wire reader's monotonicity check previously rejected any frame with a lower stream ID than the last seen. This is relaxed to per-stream scope so interleaved packets from different streams are accepted. On the server side, ServeOne now dispatches RPCs concurrently via goroutines instead of handling them sequentially. Invoke metadata is tracked per stream ID so interleaved metadata/invoke sequences from different streams are correctly associated. 
Also adds a `--generate-adapters` flag to protoc-gen-go-drpc for optionally skipping gRPC/DRPC adapter codegen, and integration benchmarks comparing multiplexed vs one-connection-per-stream performance. --- .gitignore | 1 + AGENTS.md | 1 + Makefile | 4 + cmd/protoc-gen-go-drpc/main.go | 32 +- drpcconn/conn.go | 30 +- drpcconn/conn_test.go | 5 +- drpcmanager/frame_queue_test.go | 84 +++++ drpcmanager/manager.go | 317 +++++++--------- drpcmanager/manager_test.go | 119 +++++- drpcmanager/mux_writer.go | 107 ++++++ drpcmanager/mux_writer_test.go | 83 +++++ drpcmanager/registry.go | 89 +++++ drpcmanager/registry_test.go | 136 +++++++ drpcserver/server.go | 19 +- drpcstream/pktbuf.go | 75 ++-- drpcstream/stream.go | 25 +- drpcwire/reader.go | 5 +- drpcwire/reader_test.go | 45 +++ drpcwire/writer.go | 7 + go.mod | 5 +- go.sum | 26 +- internal/integration/Makefile | 68 ++++ internal/integration/benchmark_test.go | 338 ++++++++++++++++++ internal/integration/common_test.go | 75 +++- .../integration/customservice/service.pb.go | 82 ++--- .../customservice/service_drpc.pb.go | 10 +- internal/integration/doc.go | 6 +- internal/integration/go.mod | 10 + internal/integration/go.sum | 22 ++ .../gogoservice/service_drpc.pb.go | 10 +- internal/integration/service/service.pb.go | 82 ++--- .../integration/service/service_drpc.pb.go | 10 +- 32 files changed, 1504 insertions(+), 424 deletions(-) create mode 100644 AGENTS.md create mode 100644 drpcmanager/frame_queue_test.go create mode 100644 drpcmanager/mux_writer.go create mode 100644 drpcmanager/mux_writer_test.go create mode 100644 drpcmanager/registry.go create mode 100644 drpcmanager/registry_test.go create mode 100644 internal/integration/Makefile create mode 100644 internal/integration/benchmark_test.go diff --git a/.gitignore b/.gitignore index 00856f8..b9cc35c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ cmd/protoc-gen-go-drpc/protoc-gen-go-drpc /WORKSPACE BUILD.bazel MODULE.bazel* +debug/* \ No newline at end of 
file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..f2c1b09 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +This fork of DRPC library is maintained and used by [CockroachDB](https://github.com/cockroachdb/cockroach) and is customized for CockroachDB's needs. \ No newline at end of file diff --git a/Makefile b/Makefile index b5f49ad..de2214f 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,10 @@ lint: staticcheck $(PKG) golangci-lint run +.PHONY: install-protoc-plugin +install-protoc-plugin: + $(GO) install ./cmd/protoc-gen-go-drpc/ + .PHONY: gen-bazel gen-bazel: @echo "Generating WORKSPACE" diff --git a/cmd/protoc-gen-go-drpc/main.go b/cmd/protoc-gen-go-drpc/main.go index 39ddfbb..3cd2769 100644 --- a/cmd/protoc-gen-go-drpc/main.go +++ b/cmd/protoc-gen-go-drpc/main.go @@ -16,8 +16,9 @@ import ( ) type config struct { - protolib string - json bool + protolib string + json bool + generateAdapters bool } func main() { @@ -25,6 +26,7 @@ func main() { var conf config flags.StringVar(&conf.protolib, "protolib", "google.golang.org/protobuf", "which protobuf library to use for encoding") flags.BoolVar(&conf.json, "json", true, "generate encoders with json support") + flags.BoolVar(&conf.generateAdapters, "generate-adapters", true, "generate gRPC/DRPC adapter and RPC interface code") protogen.Options{ ParamFunc: flags.Set, @@ -55,7 +57,7 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File, conf config) { d.generateEncoding(conf) for _, service := range file.Services { - d.generateService(service) + d.generateService(service, conf) } } @@ -267,7 +269,7 @@ func (d *drpc) generateEncoding(conf config) { // service generation // -func (d *drpc) generateService(service *protogen.Service) { +func (d *drpc) generateService(service *protogen.Service, conf config) { // Client interface d.P("type ", d.ClientIface(service), " interface {") d.P("DRPCConn() ", d.Ident("storj.io/drpc", "Conn")) @@ -294,7 +296,7 @@ func (d *drpc) generateService(service 
*protogen.Service) { d.P("func (c *", d.ClientImpl(service), ") DRPCConn() ", d.Ident("storj.io/drpc", "Conn"), "{ return c.cc }") d.P() for _, method := range service.Methods { - d.generateClientMethod(method) + d.generateClientMethod(method, conf) } // Server interface @@ -339,11 +341,13 @@ func (d *drpc) generateService(service *protogen.Service) { // Server methods for _, method := range service.Methods { - d.generateServerMethod(method) + d.generateServerMethod(method, conf) } - d.generateServiceRPCInterfaces(service) - d.generateServiceAdapters(service) + if conf.generateAdapters { + d.generateServiceRPCInterfaces(service) + d.generateServiceAdapters(service) + } } // @@ -362,7 +366,7 @@ func (d *drpc) generateClientSignature(method *protogen.Method) string { return fmt.Sprintf("%s(ctx %s%s) (%s, error)", method.GoName, d.Ident("context", "Context"), reqArg, respName) } -func (d *drpc) generateClientMethod(method *protogen.Method) { +func (d *drpc) generateClientMethod(method *protogen.Method, conf config) { recvType := d.ClientImpl(method.Parent) outType := d.OutputType(method) inType := d.InputType(method) @@ -408,7 +412,9 @@ func (d *drpc) generateClientMethod(method *protogen.Method) { d.P("}") d.P() - d.generateRPCClientInterface(method) + if conf.generateAdapters { + d.generateRPCClientInterface(method) + } d.P("type ", d.ClientStreamImpl(method), " struct {") d.P(d.Ident("storj.io/drpc", "Stream")) @@ -510,7 +516,7 @@ func (d *drpc) generateServerReceiver(method *protogen.Method) { d.P(")") } -func (d *drpc) generateServerMethod(method *protogen.Method) { +func (d *drpc) generateServerMethod(method *protogen.Method, conf config) { genSend := method.Desc.IsStreamingServer() genSendAndClose := !method.Desc.IsStreamingServer() genRecv := method.Desc.IsStreamingClient() @@ -531,7 +537,9 @@ func (d *drpc) generateServerMethod(method *protogen.Method) { d.P("}") d.P() - d.generateRPCServerInterface(method) + if conf.generateAdapters { + 
d.generateRPCServerInterface(method) + } d.P("type ", d.ServerStreamImpl(method), " struct {") d.P(d.Ident("storj.io/drpc", "Stream")) diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 636f346..33a50a4 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -31,10 +31,9 @@ type Options struct { // Conn is a drpc client connection. type Conn struct { - tr drpc.Transport - man *drpcmanager.Manager - mu sync.Mutex - wbuf []byte + tr drpc.Transport + man *drpcmanager.Manager + mu sync.Mutex // protects stats stats map[string]*drpcstats.Stats } @@ -92,16 +91,16 @@ func (c *Conn) Transport() drpc.Transport { return c.tr } // Closed returns a channel that is closed once the connection is closed. func (c *Conn) Closed() <-chan struct{} { return c.man.Closed() } -// Unblocked returns a channel that is closed once the connection is no longer -// blocked by a previously canceled Invoke or NewStream call. It should not -// be called concurrently with Invoke or NewStream. +// Unblocked returns a channel that is closed when the connection is available +// for new streams. With multiplexing enabled, this always returns an +// already-closed channel. func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } // Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// deserializes it into out. Multiple Invoke or Stream calls may be open concurrently. 
func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -117,18 +116,13 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. - c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + // Per-call buffer allocation for concurrent access. + data, err := drpcenc.MarshalAppend(in, enc, nil) if err != nil { return err } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil @@ -155,8 +149,8 @@ func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string return nil } -// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may -// be open at a time. +// NewStream begins a streaming rpc on the connection. Multiple Invoke or Stream calls +// may be open concurrently. 
func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) { defer func() { err = drpc.ToRPCErr(err) }() diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516..2550c3b 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -180,7 +180,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { }) s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) - _ = s.CloseSend() + + assert.NoError(t, s.CloseSend()) + + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/frame_queue_test.go b/drpcmanager/frame_queue_test.go new file mode 100644 index 0000000..dafd604 --- /dev/null +++ b/drpcmanager/frame_queue_test.go @@ -0,0 +1,84 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcwire" +) + +func TestSharedWriteBuf_AppendDrain(t *testing.T) { + sw := newSharedWriteBuf() + + fr := drpcwire.Frame{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + Done: true, + } + + assert.NoError(t, sw.Append(fr)) + + // Drain should return serialized bytes. + data := sw.Drain(nil) + assert.That(t, len(data) > 0) + + // Parse the frame back out to verify correctness. 
+ _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, fr.Data) + assert.Equal(t, got.ID.Stream, fr.ID.Stream) + assert.Equal(t, got.ID.Message, fr.ID.Message) + assert.Equal(t, got.Kind, fr.Kind) + assert.Equal(t, got.Done, fr.Done) +} + +func TestSharedWriteBuf_CloseIdempotent(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + sw.Close() // must not panic +} + +func TestSharedWriteBuf_AppendAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + + err := sw.Append(drpcwire.Frame{}) + assert.Error(t, err) +} + +func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + data, ok := sw.WaitAndDrain(nil) + assert.That(t, ok) + assert.That(t, len(data) > 0) + }() + + // Append should wake the blocked WaitAndDrain. + assert.NoError(t, sw.Append(drpcwire.Frame{Data: []byte("a")})) + <-done +} + +func TestSharedWriteBuf_WaitAndDrainCloseEmpty(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + _, ok := sw.WaitAndDrain(nil) + assert.That(t, !ok) + }() + + // Close on empty buffer should return ok=false. + sw.Close() + <-done +} diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 1680a65..8a26476 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -10,6 +10,8 @@ import ( "io" "net" "strings" + "sync" + "sync/atomic" "syscall" "time" @@ -68,30 +70,26 @@ type Options struct { // to the appropriate stream. 
type Manager struct { tr drpc.Transport - wr *drpcwire.Writer rd *drpcwire.Reader opts Options - sem drpcsignal.Chan // held by the active stream - sbuf streamBuffer // largest stream id created - pkts chan drpcwire.Packet // channel for invoke packets - pdone drpcsignal.Chan // signals when a packets buffers can be reused - sfin chan struct{} // shared signal for stream finished - streams chan streamInfo // channel to signal that a stream should start + sw *sharedWriteBuf // shared write buffer for the writer goroutine + reg *streamRegistry // tracks all active streams by ID + streamID atomic.Uint64 // next stream ID for client streams + wg sync.WaitGroup // tracks active manageStream goroutines + pkts chan drpcwire.Packet // channel for invoke packets + pdone drpcsignal.Chan // signals when packet buffers can be reused + metaMu sync.Mutex + meta map[uint64]map[string]string // invoke metadata buffered by stream ID sigs struct { - term drpcsignal.Signal // set when the manager should start terminating - stream drpcsignal.Signal // set when the manage streams goroutine is done - read drpcsignal.Signal // set after the goroutine reading from the transport is done - tport drpcsignal.Signal // set after the transport has been closed + term drpcsignal.Signal // set when the manager should start terminating + write drpcsignal.Signal // set when the writer goroutine is done + read drpcsignal.Signal // set after the goroutine reading from the transport is done + tport drpcsignal.Signal // set after the transport has been closed } } -type streamInfo struct { - ctx context.Context - stream *drpcstream.Stream -} - // New returns a new Manager for the transport. 
func New(tr drpc.Transport) *Manager { return NewWithOptions(tr, Options{}) @@ -102,31 +100,25 @@ func New(tr drpc.Transport) *Manager { func NewWithOptions(tr drpc.Transport, opts Options) *Manager { m := &Manager{ tr: tr, - wr: drpcwire.NewWriter(tr, opts.WriterBufferSize), rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), opts: opts, - pkts: make(chan drpcwire.Packet), - sfin: make(chan struct{}, 1), - streams: make(chan streamInfo), + pkts: make(chan drpcwire.Packet), + meta: make(map[uint64]map[string]string), } - // initialize the stream buffer - m.sbuf.init() - - // this semaphore controls the number of concurrent streams. it MUST be 1. - m.sem.Make(1) - // a buffer of size 1 allows the consumer of the packet to signal it is done // without having to coordinate with the sender of the packet. m.pdone.Make(1) // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) - drpcopts.SetStreamFin(&m.opts.Stream.Internal, m.sfin) + + m.sw = newSharedWriteBuf() + m.reg = newStreamRegistry() go m.manageReader() - go m.manageStreams() + go m.manageWriter() return m } @@ -144,71 +136,44 @@ func (m *Manager) log(what string, cb func() string) { // helpers // -// acquireSemaphore attempts to acquire the semaphore protecting streams. If the -// context is canceled or the manager is terminated, it returns an error. 
-func (m *Manager) acquireSemaphore(ctx context.Context) error { - if err, ok := m.sigs.term.Get(); ok { - return err - } else if err := ctx.Err(); err != nil { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case m.sem.Get() <- struct{}{}: - if err := m.waitForPreviousStream(ctx); err != nil { - m.sem.Recv() - return err - } - return nil - } -} - -// waitForPreviousStream will, if there was a previous stream, ensure it is -// Closed and then wait until it is in the Finished state, where it will no -// longer make any reads or writes on the transport. It exits early if the -// context is canceled or the manager is terminated. -func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) { - prev := m.sbuf.Get() - if prev == nil { - return nil - } - - // if the stream is not finished yet, we need to wait for it to be - // finished before letting the next stream to start. - if prev.IsFinished() { - return nil - } - - m.log("WAIT", prev.String) - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case <-prev.Finished(): - return nil - } -} - // terminate puts the Manager into a terminal state and closes any resources // that need to be closed to signal the state change. func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) m.sigs.tport.Set(m.tr.Close()) - m.sbuf.Close() + m.sw.Close() + m.metaMu.Lock() + for id := range m.meta { + delete(m.meta, id) + } + m.metaMu.Unlock() + // Cancel all active streams so they get a clear error. 
+ cancelErr := err + if errors.Is(cancelErr, io.EOF) { + cancelErr = context.Canceled + } + m.reg.ForEach(func(_ uint64, s *drpcstream.Stream) { + s.Cancel(cancelErr) + }) + m.reg.Close() } } +func (m *Manager) putMetadata(streamID uint64, metadata map[string]string) { + m.metaMu.Lock() + defer m.metaMu.Unlock() + m.meta[streamID] = metadata +} + +func (m *Manager) popMetadata(streamID uint64) map[string]string { + m.metaMu.Lock() + defer m.metaMu.Unlock() + metadata := m.meta[streamID] + delete(m.meta, streamID) + return metadata +} + // // manage reader // @@ -250,46 +215,56 @@ func (m *Manager) manageReader() { m.log("READ", pkt.String) - again: - switch curr := m.sbuf.Get(); { - // if the packet is for the current stream, deliver it. - case curr != nil && pkt.ID.Stream == curr.ID(): - if err := curr.HandlePacket(pkt); err != nil { + stream, ok := m.reg.Get(pkt.ID.Stream) + + switch { + // if the packet is for a registered stream, deliver it. + case ok && stream != nil: + if err := stream.HandlePacket(pkt); err != nil { m.terminate(managerClosed.Wrap(err)) return } - - // if an old message has been sent, just ignore it. - case curr != nil && pkt.ID.Stream < curr.ID(): - - // if any invoke sequence is being sent, close any old unterminated - // stream and forward it to be handled. - case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) + // For message packets, HandlePacket transferred ownership of + // pkt.Data to the stream's packetBuffer. Acquire a fresh buffer + // from the pool so the next ReadPacketUsing doesn't allocate. + if pkt.Kind == drpcwire.KindMessage { + pkt.Data = drpcstream.AcquirePacketBuf() } + // if any invoke sequence is being sent, forward it to be handled. 
+ case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: select { case m.pkts <- pkt: m.pdone.Recv() - case <-m.sigs.term.Signal(): return } - // a non-invoke packet should be delivered to some stream so we wait for - // a new stream to be created and try again. like an invoke, we - // implicitly close any previous stream. + // silently drop packet for an unregistered stream default: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) - } + m.log("DROP", pkt.String) + } + } +} - if !m.sbuf.Wait(curr.ID()) { - return - } - goto again +// manageWriter drains the shared write buffer and writes pre-serialized +// bytes directly to the transport. It blocks on the sharedWriteBuf's +// condition variable until data is available, and naturally batches +// frames that accumulate while the previous write is in flight. +func (m *Manager) manageWriter() { + defer m.sigs.write.Set(nil) + + var spare []byte + for { + data, ok := m.sw.WaitAndDrain(spare[:0:cap(spare)]) + if !ok { + return } + if _, err := m.tr.Write(data); err != nil { + m.terminate(managerClosed.Wrap(err)) + return + } + spare = data } } @@ -306,37 +281,25 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind, rpc string) ( drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) } - stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) - select { - case m.streams <- streamInfo{ctx: ctx, stream: stream}: - m.sbuf.Set(stream) - m.log("STREAM", stream.String) - return stream, nil + stream := drpcstream.NewWithOptions(ctx, sid, &muxWriter{sw: m.sw}, opts) - case <-m.sigs.term.Signal(): - return nil, m.sigs.term.Err() + if err := m.reg.Register(sid, stream); err != nil { + return nil, err } -} -// manageStreams reads from the streams channel for stream infos and runs the -// manageStream function on them. 
-func (m *Manager) manageStreams() { - defer m.sigs.stream.Set(nil) + m.wg.Add(1) + go m.manageStream(ctx, stream) - for { - select { - case si := <-m.streams: - m.manageStream(si.ctx, si.stream) - - case <-m.sigs.term.Signal(): - return - } - } + m.log("STREAM", stream.String) + return stream, nil } // manageStream watches the context and the stream and returns when the stream // is finished, canceling the stream if the context is canceled. func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { + defer m.wg.Done() + defer m.reg.Unregister(stream.ID()) + select { case <-m.sigs.term.Signal(): err := m.sigs.term.Err() @@ -344,47 +307,34 @@ func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { err = context.Canceled } stream.Cancel(err) - <-m.sfin - m.sem.Recv() + <-stream.Finished() - case <-m.sfin: - m.sem.Recv() + case <-stream.Finished(): + // stream finished naturally case <-ctx.Done(): m.log("CANCEL", stream.String) if m.opts.SoftCancel { - // allow a new stream to begin. - m.sem.Recv() - - // attempt to send the soft cancel. if it fails or if the stream is - // busy sending something else, then we have to hard cancel. + // Best-effort send KindCancel, never terminate connection. if busy, err := stream.SendCancel(ctx.Err()); err != nil { - m.terminate(err) + m.log("CANCEL_ERR", func() string { + return fmt.Sprintf("%s: %v", stream.String(), err) + }) } else if busy { - m.log("BUSY", stream.String) - m.terminate(ctx.Err()) + m.log("CANCEL_BUSY", stream.String) } stream.Cancel(ctx.Err()) - - // wait for the stream to signal that it is finished. - <-m.sfin + <-stream.Finished() } else { - // If the stream isn't already finished, we have to terminate the - // transport to do an active cancel. If it is already finished, - // there is no need. + // Hard cancel: terminate connection if stream not finished. 
if !stream.Cancel(ctx.Err()) { m.log("UNFIN", stream.String) m.terminate(ctx.Err()) } else { m.log("CLEAN", stream.String) } - - // wait for the stream to signal that it is finished. - <-m.sfin - - // allow a new stream to begin. - m.sem.Recv() + <-stream.Finished() } } } @@ -398,15 +348,10 @@ func (m *Manager) Closed() <-chan struct{} { return m.sigs.term.Signal() } -// Unblocked returns a channel that is closed when the manager is no longer -// blocked from creating a new stream due to a previous stream's soft cancel. It -// should not be called concurrently with NewClientStream or NewServerStream and -// the return result is only valid until the next call to NewClientStream or -// NewServerStream. +// Unblocked returns a channel that is closed when the manager is available for +// new streams. With multiplexing enabled, the connection is never blocked, so +// this always returns an already-closed channel. func (m *Manager) Unblocked() <-chan struct{} { - if prev := m.sbuf.Get(); prev != nil { - return prev.Context().Done() - } return closedCh } @@ -414,7 +359,8 @@ func (m *Manager) Unblocked() <-chan struct{} { func (m *Manager) Close() error { m.terminate(managerClosed.New("Close called")) - m.sigs.stream.Wait() + m.wg.Wait() // wait for all stream goroutines + m.sigs.write.Wait() m.sigs.read.Wait() m.sigs.tport.Wait() @@ -423,28 +369,21 @@ func (m *Manager) Close() error { // NewClientStream starts a stream on the managed transport for use by a client. func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { - if err := m.acquireSemaphore(ctx); err != nil { + if err, ok := m.sigs.term.Get(); ok { return nil, err } - - return m.newStream(ctx, m.sbuf.Get().ID()+1, "cli", rpc) + sid := m.streamID.Add(1) + return m.newStream(ctx, sid, "cli", rpc) } // NewServerStream starts a stream on the managed transport for use by a server. 
// It does this by waiting for the client to issue an invoke message and // returning the details. func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { - if err := m.acquireSemaphore(ctx); err != nil { + if err, ok := m.sigs.term.Get(); ok { return nil, "", err } - defer func() { - if err != nil { - m.sem.Recv() - } - }() - var meta map[string]string - var metaID uint64 var timeoutCh <-chan time.Time // set up the timeout on the context if necessary. @@ -467,38 +406,40 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea case pkt := <-m.pkts: switch pkt.Kind { - // keep track of any metadata being sent before an invoke so that we - // can include it if the stream id matches the eventual invoke. case drpcwire.KindInvokeMetadata: - meta, err = drpcmetadata.Decode(pkt.Data) - m.pdone.Send() - + metadata, err := drpcmetadata.Decode(pkt.Data) if err != nil { + m.pdone.Send() return nil, "", err } - metaID = pkt.ID.Stream + m.putMetadata(pkt.ID.Stream, metadata) + m.pdone.Send() case drpcwire.KindInvoke: rpc = string(pkt.Data) - m.pdone.Send() + streamCtx := ctx - if metaID == pkt.ID.Stream { + if metadata := m.popMetadata(pkt.ID.Stream); metadata != nil { if m.opts.GRPCMetadataCompatMode { // Populate incoming metadata as grpc metadata in the // context. This is a short-term fix that will enable us // to send and receive grpc metadata when DRPC is enabled, // without any changes in the calling code. - grpcMeta := make(map[string][]string, len(meta)) - for k, v := range meta { + grpcMeta := make(map[string][]string, len(metadata)) + for k, v := range metadata { grpcMeta[k] = []string{v} } - ctx = grpcmetadata.NewIncomingContext(ctx, grpcMeta) + streamCtx = grpcmetadata.NewIncomingContext(streamCtx, grpcMeta) } else { // Add metadata to the incoming context. 
- ctx = drpcmetadata.NewIncomingContext(ctx, meta) + streamCtx = drpcmetadata.NewIncomingContext(streamCtx, metadata) } } - stream, err := m.newStream(ctx, pkt.ID.Stream, "srv", rpc) + stream, err := m.newStream(streamCtx, pkt.ID.Stream, "srv", rpc) + // Ack the invoke only after stream registration so subsequent + // message packets cannot be dropped for an unknown stream ID. + // Always ack, even on error, so the reader goroutine does not block. + m.pdone.Send() return stream, rpc, err default: diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 5918113..cbca79e 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -20,15 +20,6 @@ import ( "storj.io/drpc/drpcwire" ) -func closed(ch <-chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - func TestTimeout(t *testing.T) { tr := make(blockingTransport) man := NewWithOptions(tr, Options{ @@ -69,10 +60,8 @@ func TestDrpcMetadata(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -129,10 +118,8 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -161,6 +148,103 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { ctx.Wait() } +func TestDrpcMetadataInterleavedAcrossStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, 
sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := New(cconn) + defer func() { _ = cman.Close() }() + + sman := NewWithOptions(sconn, Options{ + GRPCMetadataCompatMode: false, + }) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + metadata1 := map[string]string{"stream": "one"} + metadata2 := map[string]string{"stream": "two"} + + buf1, err := drpcmetadata.Encode(nil, metadata1) + assert.NoError(t, err) + buf2, err := drpcmetadata.Encode(nil, metadata2) + assert.NoError(t, err) + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvokeMetadata, buf1)) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvokeMetadata, buf2)) + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + got1, ok := drpcmetadata.GetFromIncomingContext(srvStream1.Context()) + assert.That(t, ok) + assert.Equal(t, metadata1, got1) + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + got2, ok := drpcmetadata.GetFromIncomingContext(srvStream2.Context()) + assert.That(t, ok) + assert.Equal(t, metadata2, got2) +} + +func TestNewServerStreamUnreadMessageDoesNotBlockOtherStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := New(cconn) + defer func() { _ = cman.Close() }() + + sman := New(sconn) + defer func() { _ = 
sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, []byte("message-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + // Do not read the first stream's message. The manager must still be able to + // accept and register additional streams. + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + srvStream2, rpc2, err := sman.NewServerStream(timeoutCtx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } @@ -189,10 +273,8 @@ func TestUnblocked_NoCancel(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -230,17 +312,20 @@ func TestUnblocked_SoftCancel(t *testing.T) { if softCancel { assert.NoError(t, err) } else if i > 0 { + // Hard cancel terminates the connection, so subsequent streams fail. 
assert.Error(t, err) return + } else { + assert.NoError(t, err) } defer func() { _ = stream.Close() }() - assert.That(t, !closed(man.Unblocked())) cancel() // temporary unblock writing to allow the stream to finish soft cancel tr.setWriteOpen(true) - <-man.Unblocked() + // With multiplexing, we wait for the stream to finish instead of Unblocked(). + <-stream.Finished() tr.setWriteOpen(false) }() } diff --git a/drpcmanager/mux_writer.go b/drpcmanager/mux_writer.go new file mode 100644 index 0000000..9794779 --- /dev/null +++ b/drpcmanager/mux_writer.go @@ -0,0 +1,107 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcwire" +) + +// muxWriter implements drpcwire.StreamWriter by serializing frame bytes into a +// shared write buffer. The manageWriter goroutine drains the buffer and writes +// directly to the transport. +// +// Compared to the previous frameQueue approach, this avoids: +// - copying frame payload into an intermediate queue slot, +// - drpcwire.Writer mutex overhead in the writer goroutine. +// +// Frames are serialized (via AppendFrame) into the shared buffer under a +// short-held mutex. The frame's Data slice is consumed before WriteFrame +// returns, so callers may safely reuse their buffers afterward. +type muxWriter struct { + sw *sharedWriteBuf +} + +func (w *muxWriter) WriteFrame(fr drpcwire.Frame) error { + return w.sw.Append(fr) +} + +// Flush is a no-op because the manageWriter goroutine flushes to the +// transport after draining the shared buffer. +func (w *muxWriter) Flush() error { return nil } + +func (w *muxWriter) Empty() bool { return true } + +// sharedWriteBuf collects serialized frame bytes from multiple concurrent +// producers. A single consumer (manageWriter) drains the buffer and writes +// the pre-serialized bytes to the transport. 
+type sharedWriteBuf struct { + mu sync.Mutex + cond *sync.Cond + buf []byte + closed bool +} + +func newSharedWriteBuf() *sharedWriteBuf { + sw := &sharedWriteBuf{} + sw.cond = sync.NewCond(&sw.mu) + return sw +} + +// Append serializes fr into the shared buffer. The frame's Data slice is +// consumed (copied by AppendFrame) before Append returns. +func (sw *sharedWriteBuf) Append(fr drpcwire.Frame) error { + sw.mu.Lock() + if sw.closed { + sw.mu.Unlock() + return managerClosed.New("enqueue") + } + sw.buf = drpcwire.AppendFrame(sw.buf, fr) + sw.mu.Unlock() + + sw.cond.Signal() + return nil +} + +// Drain swaps out accumulated bytes, giving the caller ownership of the +// returned slice. The internal buffer is replaced with spare (reset to zero +// length) so producers can continue appending without allocation. +func (sw *sharedWriteBuf) Drain(spare []byte) []byte { + sw.mu.Lock() + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data +} + +// WaitAndDrain blocks until data is available or the buffer is closed. +// Returns the accumulated bytes and true if data was available, or nil and +// false if the buffer is closed and empty. +func (sw *sharedWriteBuf) WaitAndDrain(spare []byte) ([]byte, bool) { + sw.mu.Lock() + for len(sw.buf) == 0 && !sw.closed { + sw.cond.Wait() + } + if sw.closed && len(sw.buf) == 0 { + sw.mu.Unlock() + return nil, false + } + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data, true +} + +// Close marks the buffer as closed and wakes the consumer. +func (sw *sharedWriteBuf) Close() { + sw.mu.Lock() + defer sw.mu.Unlock() + + if sw.closed { + return + } + sw.closed = true + sw.cond.Broadcast() +} diff --git a/drpcmanager/mux_writer_test.go b/drpcmanager/mux_writer_test.go new file mode 100644 index 0000000..288631b --- /dev/null +++ b/drpcmanager/mux_writer_test.go @@ -0,0 +1,83 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcwire" +) + +func TestMuxWriter_WriteFrame(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + fr := drpcwire.Frame{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + Done: true, + } + + assert.NoError(t, w.WriteFrame(fr)) + + data := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, fr.Data) + assert.Equal(t, got.ID.Stream, fr.ID.Stream) + assert.Equal(t, got.ID.Message, fr.ID.Message) + assert.Equal(t, got.Kind, fr.Kind) + assert.Equal(t, got.Done, fr.Done) +} + +func TestMuxWriter_WriteFrameIsolatesData(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + data := []byte("hello") + fr := drpcwire.Frame{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + Done: true, + } + + assert.NoError(t, w.WriteFrame(fr)) + + // Mutate the original source buffer after WriteFrame. + data[0] = 'j' + + // The serialized data in the shared buffer should be unaffected because + // AppendFrame copies the bytes during serialization. 
+ buf := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(buf) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, []byte("hello")) +} + +func TestMuxWriter_FlushNoop(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + assert.NoError(t, w.Flush()) +} + +func TestMuxWriter_Empty(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + assert.That(t, w.Empty()) +} + +func TestMuxWriter_WriteFrameAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + sw.Close() + + err := w.WriteFrame(drpcwire.Frame{}) + assert.Error(t, err) +} diff --git a/drpcmanager/registry.go b/drpcmanager/registry.go new file mode 100644 index 0000000..32b50cb --- /dev/null +++ b/drpcmanager/registry.go @@ -0,0 +1,89 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcstream" +) + +// streamRegistry is a thread-safe map of stream IDs to stream objects. +// It is used by the Manager to track all active streams for lifecycle +// management and packet routing. +type streamRegistry struct { + mu sync.RWMutex + streams map[uint64]*drpcstream.Stream + closed bool +} + +func newStreamRegistry() *streamRegistry { + return &streamRegistry{ + streams: make(map[uint64]*drpcstream.Stream), + } +} + +// Register adds a stream to the registry. It returns an error if the registry +// is closed or if a stream with the same ID is already registered. +func (r *streamRegistry) Register(id uint64, stream *drpcstream.Stream) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return managerClosed.New("register") + } + if _, ok := r.streams[id]; ok { + return managerClosed.New("duplicate stream id") + } + r.streams[id] = stream + return nil +} + +// Unregister removes a stream from the registry. It is a no-op if the stream +// is not registered or if the registry has been closed. 
+func (r *streamRegistry) Unregister(id uint64) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.streams != nil { + delete(r.streams, id) + } +} + +// Get returns the stream for the given ID and whether it was found. +func (r *streamRegistry) Get(id uint64) (*drpcstream.Stream, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + s, ok := r.streams[id] + return s, ok +} + +// Close marks the registry as closed, preventing future Register calls. +// It does not cancel any streams. +func (r *streamRegistry) Close() { + r.mu.Lock() + defer r.mu.Unlock() + + r.closed = true +} + +// ForEach calls fn for each registered stream. The function is called with +// the stream ID and stream pointer. The registry is read-locked during iteration. +func (r *streamRegistry) ForEach(fn func(uint64, *drpcstream.Stream)) { + r.mu.RLock() + defer r.mu.RUnlock() + + for id, s := range r.streams { + fn(id, s) + } +} + +// Len returns the number of registered streams. +func (r *streamRegistry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.streams) +} diff --git a/drpcmanager/registry_test.go b/drpcmanager/registry_test.go new file mode 100644 index 0000000..d27359b --- /dev/null +++ b/drpcmanager/registry_test.go @@ -0,0 +1,136 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "context" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcstream" +) + +func testStream(id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, &muxWriter{sw: newSharedWriteBuf()}) +} + +func TestStreamRegistry_RegisterAndGet(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s) +} + +func TestStreamRegistry_GetMissing(t *testing.T) { + reg := newStreamRegistry() + + got, ok := reg.Get(42) + assert.That(t, !ok) + assert.Nil(t, got) +} + +func TestStreamRegistry_Unregister(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + assert.Equal(t, reg.Len(), 1) + + reg.Unregister(1) + + _, ok := reg.Get(1) + assert.That(t, !ok) + assert.Equal(t, reg.Len(), 0) +} + +func TestStreamRegistry_UnregisterIdempotent(t *testing.T) { + reg := newStreamRegistry() + + // must not panic when unregistering a non-existent ID + reg.Unregister(99) +} + +func TestStreamRegistry_DuplicateRegister(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(1) + + assert.NoError(t, reg.Register(1, s1)) + assert.Error(t, reg.Register(1, s2)) + + // original stream is still registered + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s1) +} + +func TestStreamRegistry_RegisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + reg.Close() + + err := reg.Register(1, testStream(1)) + assert.Error(t, err) +} + +func TestStreamRegistry_UnregisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + assert.NoError(t, reg.Register(1, s)) + + reg.Close() + + // must not panic + reg.Unregister(1) +} + +func TestStreamRegistry_Len(t *testing.T) { + reg := newStreamRegistry() + assert.Equal(t, reg.Len(), 0) + + assert.NoError(t, reg.Register(1, testStream(1))) + 
assert.Equal(t, reg.Len(), 1) + + assert.NoError(t, reg.Register(2, testStream(2))) + assert.Equal(t, reg.Len(), 2) + + reg.Unregister(1) + assert.Equal(t, reg.Len(), 1) +} + +func TestStreamRegistry_ForEach(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(2) + s3 := testStream(3) + + assert.NoError(t, reg.Register(1, s1)) + assert.NoError(t, reg.Register(2, s2)) + assert.NoError(t, reg.Register(3, s3)) + + seen := make(map[uint64]*drpcstream.Stream) + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + seen[id] = s + }) + + assert.Equal(t, len(seen), 3) + assert.Equal(t, seen[1], s1) + assert.Equal(t, seen[2], s2) + assert.Equal(t, seen[3], s3) +} + +func TestStreamRegistry_ForEach_Empty(t *testing.T) { + reg := newStreamRegistry() + + count := 0 + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + count++ + }) + + assert.Equal(t, count, 0) +} diff --git a/drpcserver/server.go b/drpcserver/server.go index 815dc2a..2992488 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -119,7 +119,12 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } man := drpcmanager.NewWithOptions(tr, s.opts.Manager) - defer func() { err = errs.Combine(err, man.Close()) }() + + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() cache := drpccache.New() defer cache.Clear() @@ -131,9 +136,15 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { if err != nil { return errs.Wrap(err) } - if err := s.handleRPC(stream, rpc); err != nil { - return errs.Wrap(err) - } + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() } } diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go index db68864..fd59d4c 100644 --- a/drpcstream/pktbuf.go +++ b/drpcstream/pktbuf.go @@ -4,16 +4,29 @@ package drpcstream import ( + "io" "sync" ) +// pktBufPool 
recycles byte slices used on the read path to avoid per-packet +// allocations. Buffers flow: AcquirePacketBuf → reader → Put (zero-copy +// ownership transfer) → Get → consumer → recycle → pool. +var pktBufPool = sync.Pool{} + +// AcquirePacketBuf returns a byte slice from the shared packet buffer pool +// for use as a read buffer. Returns nil if the pool is empty. +func AcquirePacketBuf() []byte { + if v := pktBufPool.Get(); v != nil { + return v.([]byte)[:0] + } + return nil +} + type packetBuffer struct { mu sync.Mutex cond sync.Cond err error - data []byte - set bool - held bool + data [][]byte } func (pb *packetBuffer) init() { @@ -24,62 +37,62 @@ func (pb *packetBuffer) Close(err error) { pb.mu.Lock() defer pb.mu.Unlock() - for pb.held { - pb.cond.Wait() - } - if pb.err == nil { - pb.data = nil - pb.set = false + // Preserve already-queued messages on graceful close so readers can + // drain them before seeing EOF. + if err != io.EOF { + for i := range pb.data { + pktBufPool.Put(pb.data[i][:0]) + pb.data[i] = nil + } + pb.data = pb.data[:0] + } pb.err = err pb.cond.Broadcast() } } +// Put takes ownership of data. The caller must not use data after calling Put. +// If the buffer is closed, data is returned to the pool. 
func (pb *packetBuffer) Put(data []byte) { pb.mu.Lock() defer pb.mu.Unlock() - for pb.set && pb.err == nil { - pb.cond.Wait() - } if pb.err != nil { + pktBufPool.Put(data[:0]) return } - pb.data = data - pb.set = true - pb.held = false + pb.data = append(pb.data, data) pb.cond.Broadcast() - - for pb.set || pb.held { - pb.cond.Wait() - } } func (pb *packetBuffer) Get() ([]byte, error) { pb.mu.Lock() defer pb.mu.Unlock() - for !pb.set && pb.err == nil { + for len(pb.data) == 0 && pb.err == nil { pb.cond.Wait() } - if pb.err != nil { + if len(pb.data) == 0 { return nil, pb.err } - pb.held = true - pb.cond.Broadcast() + data := pb.data[0] + n := copy(pb.data, pb.data[1:]) + pb.data[n] = nil + pb.data = pb.data[:n] + return data, nil +} - return pb.data, nil +// recycle returns a buffer obtained from Get back to the pool. +// Call this after the data has been fully consumed (e.g. after Unmarshal). +func (pb *packetBuffer) recycle(buf []byte) { + if buf != nil { + pktBufPool.Put(buf[:0]) + } } func (pb *packetBuffer) Done() { - pb.mu.Lock() - defer pb.mu.Unlock() - - pb.data = nil - pb.set = false - pb.held = false - pb.cond.Broadcast() + // Kept for backward compatibility with stream callers. } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 2f7aabd..f0dd9a4 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -53,7 +53,7 @@ type Stream struct { flush sync.Once id drpcwire.ID - wr *drpcwire.Writer + wr drpcwire.StreamWriter pbuf packetBuffer wbuf []byte @@ -72,7 +72,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. 
-func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr drpcwire.StreamWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -80,7 +80,7 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions(ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -99,7 +99,7 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O task: task, id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + wr: wr, } // initialize the packet buffer @@ -221,17 +221,19 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) - if s.sigs.term.IsSet() { + // Put must always be called for message packets to manage buffer + // ownership. Put returns the buffer to the pool if the stream is closed. + if pkt.Kind == drpcwire.KindMessage { + s.pbuf.Put(pkt.Data) return nil } - s.log("HANDLE", pkt.String) - - if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + if s.sigs.term.IsSet() { return nil } + s.log("HANDLE", pkt.String) + s.mu.Lock() defer s.mu.Unlock() @@ -457,8 +459,7 @@ func (s *Stream) RawRecv() (data []byte, err error) { if err != nil { return nil, err } - data = append([]byte(nil), data...) - s.pbuf.Done() + // Transfer buffer ownership to the caller without recycling. 
return data, nil } @@ -510,7 +511,7 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + s.pbuf.recycle(data) return err } diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac397..ef801d3 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -143,7 +143,10 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { pkt.Control = pkt.Control || fr.Control switch { - case fr.ID.Less(r.id): + case fr.ID.Stream == 0 || fr.ID.Message == 0: + return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) + + case fr.ID.Stream == r.id.Stream && fr.ID.Less(r.id): return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) case r.id != fr.ID || pkt.ID == ID{}: diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b..ec0cd3f 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -154,6 +154,51 @@ func TestReader(t *testing.T) { Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, Error: "id monotonicity violation", }, + + { // cross-stream: frames from a lower stream after a higher stream are allowed (multiplexing) + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose, Done: true}, + }, + }, + + { // cross-stream: interleaved single-frame packets from multiple streams + Packets: []Packet{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke}, + 
{Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage}, + {Data: nil, ID: ID{Stream: 1, Message: 3}, Kind: KindCloseSend}, + {Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend}, + }, + Frames: []Frame{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke, Done: true}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke, Done: true}, + {Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 3}, Kind: KindCloseSend, Done: true}, + {Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend, Done: true}, + }, + }, + + { // cross-stream: within-stream monotonicity is still enforced + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("c"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, // replay on stream 2 + }, + Error: "id monotonicity violation", + }, } for _, tc := range cases { diff --git a/drpcwire/writer.go b/drpcwire/writer.go index fe909cf..dfd6d9c 100644 --- a/drpcwire/writer.go +++ b/drpcwire/writer.go @@ -12,6 +12,13 @@ import ( "storj.io/drpc/drpcdebug" ) +// StreamWriter is the interface for writing frames to a stream. 
+type StreamWriter interface { + WriteFrame(fr Frame) error + Flush() error + Empty() bool +} + // // Writer // diff --git a/go.mod b/go.mod index a66c0f9..61bc019 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,16 @@ require ( github.com/zeebo/errs v1.4.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 google.golang.org/grpc v1.57.2 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.33.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.33.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e9d3c41..1201203 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -6,20 +7,30 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod 
h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/zeebo/assert v1.3.1 h1:vukIABvugfNMZMQO1ABsyQDJDTVQbn+LWSMy1ol1h6A= github.com/zeebo/assert v1.3.1/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/xerrors 
v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= @@ -27,9 +38,10 @@ google.golang.org/grpc v1.57.2 h1:uw37EN34aMFFXB2QPW7Tq6tdTbind1GpRxw5aOX3a5k= google.golang.org/grpc v1.57.2/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/integration/Makefile b/internal/integration/Makefile new file mode 100644 index 0000000..4430126 --- /dev/null +++ b/internal/integration/Makefile @@ -0,0 +1,68 @@ +GO ?= go +COUNT ?= 1 +BENCH ?= . 
+RUN ?= ^$$ +TIMEOUT ?= 300s +PROFILES_DIR ?= /tmp/drpc-profiles +CPU_PROFILE ?= $(PROFILES_DIR)/cpu.pprof +MEM_PROFILE ?= $(PROFILES_DIR)/mem.pprof +BLOCK_PROFILE ?= $(PROFILES_DIR)/block.pprof +MUTEX_PROFILE ?= $(PROFILES_DIR)/mutex.pprof +TRACE_PROFILE ?= $(PROFILES_DIR)/trace.out +BLOCK_PROFILE_RATE ?= 1 +MEM_PROFILE_RATE ?= 4096 +MUTEX_PROFILE_FRACTION ?= 1 + +# Run benchmarks once +.PHONY: bench +bench: + $(GO) test -bench=$(BENCH) -benchmem -count=$(COUNT) -timeout=$(TIMEOUT) . + +# Run benchmarks with count=10 for use with benchstat. +# Usage: make bench-stat > before.txt (then after changes) make bench-stat > after.txt +# benchstat before.txt after.txt +.PHONY: bench-stat +bench-stat: + $(GO) test -bench=$(BENCH) -benchmem -count=10 -timeout=$(TIMEOUT) . + +.PHONY: bench-parallel-streams +bench-parallel-streams: + $(GO) test -run=$(RUN) -bench=BenchmarkParallelActiveStreams -benchmem -count=$(COUNT) -timeout=$(TIMEOUT) . + +# Collect benchmark profiles in one run. +# Example: +# make bench-profile BENCH='BenchmarkParallelActiveStreams/size=1KB' COUNT=1 +.PHONY: bench-profile +bench-profile: + mkdir -p $(PROFILES_DIR) + $(GO) test -run=$(RUN) -bench=$(BENCH) -benchmem -count=$(COUNT) -timeout=$(TIMEOUT) \ + -cpuprofile=$(CPU_PROFILE) \ + -memprofile=$(MEM_PROFILE) -memprofilerate=$(MEM_PROFILE_RATE) \ + -blockprofile=$(BLOCK_PROFILE) -blockprofilerate=$(BLOCK_PROFILE_RATE) \ + -mutexprofile=$(MUTEX_PROFILE) -mutexprofilefraction=$(MUTEX_PROFILE_FRACTION) \ + -trace=$(TRACE_PROFILE) . 
+ +.PHONY: profile-top-cpu +profile-top-cpu: + $(GO) tool pprof -top $(CPU_PROFILE) + +.PHONY: profile-top-alloc-space +profile-top-alloc-space: + $(GO) tool pprof -top -sample_index=alloc_space $(MEM_PROFILE) + +.PHONY: profile-top-inuse-space +profile-top-inuse-space: + $(GO) tool pprof -top -sample_index=inuse_space $(MEM_PROFILE) + +.PHONY: profile-top-block +profile-top-block: + $(GO) tool pprof -top $(BLOCK_PROFILE) + +.PHONY: profile-top-mutex +profile-top-mutex: + $(GO) tool pprof -top $(MUTEX_PROFILE) + +.PHONY: gen-proto +gen-proto: + $(MAKE) -C ../.. install-protoc-plugin + $(GO) generate . diff --git a/internal/integration/benchmark_test.go b/internal/integration/benchmark_test.go new file mode 100644 index 0000000..e4edf8f --- /dev/null +++ b/internal/integration/benchmark_test.go @@ -0,0 +1,338 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package integration + +import ( + "context" + "errors" + "io" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "google.golang.org/protobuf/proto" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpctest" +) + +var echoServer = impl{ + Method1Fn: func(_ context.Context, in *In) (*Out, error) { + return &Out{Out: in.In, Data: in.Data}, nil + }, + + Method2Fn: func(stream DRPCService_Method2Stream) error { + var last *In + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + last = in + } + if last == nil { + last = &In{} + } + return stream.SendAndClose(&Out{Out: last.In, Data: last.Data}) + }, + + Method3Fn: func(in *In, stream DRPCService_Method3Stream) error { + out := &Out{Out: 1, Data: in.Data} + for i := int64(0); i < in.In; i++ { + if err := stream.Send(out); err != nil { + return err + } + } + return nil + }, + + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + if err := 
stream.Send(&Out{Out: in.In, Data: in.Data}); err != nil { + return err + } + } + }, +} + +var sizes = []struct { + name string + data []byte +}{ + {"small", nil}, + {"1KB", make([]byte, 1024)}, + {"8KB", make([]byte, 8192)}, +} + +var concurrencies = []int{1, 10, 100} +var activeStreamCounts = []int{2, 8, 32} + +const parallelActiveStreamsWriteLatency = 200 * time.Microsecond + +func BenchmarkUnary(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, c := range concurrencies { + b.Run("concurrent="+strconv.Itoa(c), func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.SetParallelism(c) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := client.Method1(context.Background(), in); err != nil { + b.Error(err) + } + } + }) + }) + } + }) + } +} + +func BenchmarkInputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method2(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + } + + if _, err := stream.CloseAndRecv(); err != nil { + b.Fatal(err) + } + }) + } +} + +func BenchmarkOutputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: int64(b.N), Data: sz.data} + stream, err := client.Method3(context.Background(), in) + if err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + b.ResetTimer() + + var last *Out + for i := 0; i < b.N; i++ { + last, err = stream.Recv() + if err != 
nil { + b.Fatal(err) + } + } + if last != nil { + b.SetBytes(int64(proto.Size(last))) + } + }) + } +} + +func BenchmarkBidiStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + if _, err := stream.Recv(); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkParallelActiveStreams(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, 0) +} + +func BenchmarkParallelActiveStreamsWriteLatency(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, parallelActiveStreamsWriteLatency) +} + +func benchmarkParallelActiveStreamsModes(b *testing.B, writeLatency time.Duration) { + modes := []struct { + name string + oneConnPerStream bool + useTCP bool + }{ + {name: "mux_single_connection", oneConnPerStream: false}, + {name: "one_connection_per_stream", oneConnPerStream: true}, + {name: "mux_tcp", oneConnPerStream: false, useTCP: true}, + {name: "one_conn_tcp", oneConnPerStream: true, useTCP: true}, + } + + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, streamCount := range activeStreamCounts { + b.Run("streams="+strconv.Itoa(streamCount), func(b *testing.B) { + for _, mode := range modes { + b.Run(mode.name, func(b *testing.B) { + benchmarkParallelActiveStreams(b, sz.data, streamCount, mode.oneConnPerStream, writeLatency, mode.useTCP) + }) + } + }) + } + }) + } +} + +func benchmarkParallelActiveStreams( + b *testing.B, payload []byte, streamCount int, oneConnPerStream bool, writeLatency time.Duration, useTCP bool, +) { + ctx := drpctest.NewTracker(b) + defer ctx.Close() + + type bidiClient interface { + Send(*In) error + 
Recv() (*Out, error) + Close() error + } + + type worker struct { + stream bidiClient + close func() + } + workers := make([]worker, 0, streamCount) + + createConn := createRawConnectionWithWriteDelay + if useTCP { + createConn = createTCPConnectionWithWriteDelay + } + + var sharedConn *drpcconn.Conn + if !oneConnPerStream { + sharedConn = createConn(b, echoServer, ctx, writeLatency) + } + + for i := 0; i < streamCount; i++ { + conn := sharedConn + if oneConnPerStream { + conn = createConn(b, echoServer, ctx, writeLatency) + } + + client := NewDRPCServiceClient(conn) + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatalf("create stream %d: %v", i, err) + } + + s := stream + c := conn + workers = append(workers, worker{ + stream: s, + close: func() { + _ = s.Close() + if oneConnPerStream { + _ = c.Close() + } + }, + }) + } + + input := &In{In: 1, Data: payload} + b.SetBytes(int64(proto.Size(input))) + b.ReportAllocs() + + var sent atomic.Int64 + errCh := make(chan error, 1) + start := make(chan struct{}) + var wg sync.WaitGroup + + b.ResetTimer() + for _, w := range workers { + wg.Add(1) + go func(w worker) { + defer wg.Done() + msg := &In{In: 1, Data: payload} + <-start + + for { + n := sent.Add(1) + if n > int64(b.N) { + return + } + if err := w.stream.Send(msg); err != nil { + select { + case errCh <- err: + default: + } + return + } + if _, err := w.stream.Recv(); err != nil { + select { + case errCh <- err: + default: + } + return + } + } + }(w) + } + close(start) + wg.Wait() + b.StopTimer() + + select { + case err := <-errCh: + b.Fatal(err) + default: + } + + for _, w := range workers { + w.close() + } + if sharedConn != nil { + _ = sharedConn.Close() + } +} diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a..7780958 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -10,6 +10,7 @@ import ( "strconv" "sync" "testing" + "time" 
"github.com/zeebo/assert" "google.golang.org/grpc/codes" @@ -40,18 +41,86 @@ func in(n int64) *In { return &In{In: n} } func out(n int64) *Out { return &Out{Out: n} } func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { - c1, c2 := net.Pipe() + return createRawConnectionWithWriteDelay(t, server, ctx, 0) +} + +func createRawConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, writeDelay, false) +} + +func createTCPConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, writeDelay, true) +} + +func createConnectionWithTransport( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, + writeDelay time.Duration, useTCP bool, +) *drpcconn.Conn { + var serverConn, clientConn net.Conn + if useTCP { + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { lis.Close() }) + + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := lis.Accept() + ch <- result{c, err} + }() + + c, err := net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatal(err) + } + r := <-ch + if r.err != nil { + t.Fatal(r.err) + } + serverConn, clientConn = r.conn, c + t.Cleanup(func() { + serverConn.Close() + clientConn.Close() + }) + } else { + c1, c2 := net.Pipe() + serverConn, clientConn = c1, c2 + } + + if writeDelay > 0 { + serverConn = &writeLatencyConn{Conn: serverConn, writeDelay: writeDelay} + clientConn = &writeLatencyConn{Conn: clientConn, writeDelay: writeDelay} + } mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, server)) srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return 
drpcconn.NewWithOptions(c2, drpcconn.Options{ + ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, serverConn) }) + return drpcconn.NewWithOptions(clientConn, drpcconn.Options{ Manager: drpcmanager.Options{ SoftCancel: true, }, }) } +type writeLatencyConn struct { + net.Conn + writeDelay time.Duration +} + +func (c *writeLatencyConn) Write(p []byte) (int, error) { + time.Sleep(c.writeDelay) + return c.Conn.Write(p) +} + func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient, func()) { ctx := drpctest.NewTracker(t) conn := createRawConnection(t, server, ctx) diff --git a/internal/integration/customservice/service.pb.go b/internal/integration/customservice/service.pb.go index 251b45f..336bb29 100644 --- a/internal/integration/customservice/service.pb.go +++ b/internal/integration/customservice/service.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v4.24.4 +// protoc-gen-go v1.36.5 +// protoc v5.29.3 // source: service.proto package service @@ -14,6 +14,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -24,21 +25,18 @@ const ( ) type In struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *In) Reset() { *x = In{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[0] + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *In) String() string { @@ -49,7 +47,7 @@ func (*In) ProtoMessage() {} func (x *In) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -79,21 +77,18 @@ func (x *In) GetData() []byte { } type Out struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *Out) Reset() { *x = Out{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Out) String() string { @@ -104,7 +99,7 @@ func (*Out) ProtoMessage() {} func (x *Out) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -135,7 +130,7 @@ func (x *Out) GetData() []byte { var File_service_proto protoreflect.FileDescriptor -var file_service_proto_rawDesc = []byte{ +var file_service_proto_rawDesc = string([]byte{ 0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x28, 0x0a, 
0x02, 0x49, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x6e, 0x12, 0x12, @@ -158,22 +153,22 @@ var file_service_proto_rawDesc = []byte{ 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +}) var ( file_service_proto_rawDescOnce sync.Once - file_service_proto_rawDescData = file_service_proto_rawDesc + file_service_proto_rawDescData []byte ) func file_service_proto_rawDescGZIP() []byte { file_service_proto_rawDescOnce.Do(func() { - file_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_service_proto_rawDescData) + file_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), len(file_service_proto_rawDesc))) }) return file_service_proto_rawDescData } var file_service_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_service_proto_goTypes = []interface{}{ +var file_service_proto_goTypes = []any{ (*In)(nil), // 0: service.In (*Out)(nil), // 1: service.Out } @@ -198,37 +193,11 @@ func file_service_proto_init() { if File_service_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*In); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Out); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_service_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), 
len(file_service_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, @@ -239,7 +208,6 @@ func file_service_proto_init() { MessageInfos: file_service_proto_msgTypes, }.Build() File_service_proto = out.File - file_service_proto_rawDesc = nil file_service_proto_goTypes = nil file_service_proto_depIdxs = nil } diff --git a/internal/integration/customservice/service_drpc.pb.go b/internal/integration/customservice/service_drpc.pb.go index 9991940..7ae7bcd 100644 --- a/internal/integration/customservice/service_drpc.pb.go +++ b/internal/integration/customservice/service_drpc.pb.go @@ -1,12 +1,12 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: (devel) +// protoc-gen-go-drpc version: v0.0.35-0.20260221163524-0f9d4319a40d+dirty // source: service.proto package service import ( context "context" - errors "errors" + errors "github.com/cockroachdb/errors" drpc "storj.io/drpc" drpcerr "storj.io/drpc/drpcerr" customencoding "storj.io/drpc/internal/integration/customencoding" @@ -280,6 +280,7 @@ type DRPCService_Method2Stream interface { drpc.Stream SendAndClose(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method2Stream struct { @@ -305,7 +306,7 @@ func (x *drpcService_Method2Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method2Stream) RecvMsg(m *In) error { +func (x *drpcService_Method2Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } @@ -330,6 +331,7 @@ type DRPCService_Method4Stream interface { drpc.Stream Send(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method4Stream struct { @@ -352,6 +354,6 @@ func (x *drpcService_Method4Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method4Stream) RecvMsg(m *In) error { +func (x *drpcService_Method4Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } diff --git 
a/internal/integration/doc.go b/internal/integration/doc.go index d23a99f..4f2c23a 100644 --- a/internal/integration/doc.go +++ b/internal/integration/doc.go @@ -4,6 +4,6 @@ // Package integration holds integration tests for drpc. package integration -//go:generate protoc --go_out=paths=source_relative:service/. --go-drpc_out=paths=source_relative:service/. service.proto -//go:generate protoc --gogo_out=paths=source_relative:gogoservice/. --go-drpc_out=paths=source_relative,protolib=github.com/gogo/protobuf:gogoservice/. service.proto -//go:generate protoc --go_out=paths=source_relative:customservice/. --go-drpc_out=paths=source_relative,protolib=storj.io/drpc/internal/integration/customencoding:customservice/. service.proto +//go:generate protoc --go_out=paths=source_relative:service/. --go-drpc_out=paths=source_relative,generate-adapters=false:service/. service.proto +//go:generate protoc --gogo_out=paths=source_relative:gogoservice/. --go-drpc_out=paths=source_relative,protolib=github.com/gogo/protobuf,generate-adapters=false:gogoservice/. service.proto +//go:generate protoc --go_out=paths=source_relative:customservice/. --go-drpc_out=paths=source_relative,protolib=storj.io/drpc/internal/integration/customencoding,generate-adapters=false:customservice/. 
service.proto diff --git a/internal/integration/go.mod b/internal/integration/go.mod index df01970..efe284e 100644 --- a/internal/integration/go.mod +++ b/internal/integration/go.mod @@ -3,6 +3,7 @@ module storj.io/drpc/internal/integration go 1.25.0 require ( + github.com/cockroachdb/errors v1.12.0 github.com/gogo/protobuf v1.3.2 github.com/zeebo/assert v1.3.1 github.com/zeebo/errs v1.4.0 @@ -13,7 +14,16 @@ require ( ) require ( + github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect + github.com/cockroachdb/redact v1.1.5 // indirect + github.com/getsentry/sentry-go v0.27.0 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect ) diff --git a/internal/integration/go.sum b/internal/integration/go.sum index d06e526..79c4117 100644 --- a/internal/integration/go.sum +++ b/internal/integration/go.sum @@ -1,5 +1,16 @@ +github.com/cockroachdb/errors v1.12.0 h1:d7oCs6vuIMUQRVbi6jWWWEJZahLCfJpnJSVobd1/sUo= +github.com/cockroachdb/errors v1.12.0/go.mod h1:SvzfYNNBshAVbZ8wzNc/UPK3w1vf0dKDUP41ucAIf7g= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= +github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30= +github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew 
v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= +github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= +github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= +github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -10,8 +21,19 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/internal/integration/gogoservice/service_drpc.pb.go b/internal/integration/gogoservice/service_drpc.pb.go index 1a61701..b65e6f2 100644 --- a/internal/integration/gogoservice/service_drpc.pb.go +++ b/internal/integration/gogoservice/service_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: (devel) +// protoc-gen-go-drpc version: v0.0.35-0.20260221163524-0f9d4319a40d+dirty // source: service.proto package service @@ -7,7 +7,7 @@ package service import ( bytes "bytes" context "context" - errors "errors" + errors "github.com/cockroachdb/errors" jsonpb "github.com/gogo/protobuf/jsonpb" proto "github.com/gogo/protobuf/proto" drpc "storj.io/drpc" @@ -287,6 +287,7 @@ type DRPCService_Method2Stream interface { drpc.Stream SendAndClose(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method2Stream struct { @@ -312,7 +313,7 @@ func (x *drpcService_Method2Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method2Stream) RecvMsg(m *In) error { +func (x *drpcService_Method2Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } @@ -337,6 +338,7 @@ type DRPCService_Method4Stream interface { drpc.Stream Send(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method4Stream struct { @@ -359,6 +361,6 @@ func (x *drpcService_Method4Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method4Stream) RecvMsg(m *In) error { +func (x 
*drpcService_Method4Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } diff --git a/internal/integration/service/service.pb.go b/internal/integration/service/service.pb.go index 251b45f..336bb29 100644 --- a/internal/integration/service/service.pb.go +++ b/internal/integration/service/service.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v4.24.4 +// protoc-gen-go v1.36.5 +// protoc v5.29.3 // source: service.proto package service @@ -14,6 +14,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -24,21 +25,18 @@ const ( ) type In struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *In) Reset() { *x = In{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *In) String() string { @@ -49,7 +47,7 @@ func (*In) ProtoMessage() {} func (x *In) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -79,21 +77,18 @@ func (x *In) GetData() []byte { } type Out struct { - state protoimpl.MessageState - 
sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *Out) Reset() { *x = Out{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Out) String() string { @@ -104,7 +99,7 @@ func (*Out) ProtoMessage() {} func (x *Out) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -135,7 +130,7 @@ func (x *Out) GetData() []byte { var File_service_proto protoreflect.FileDescriptor -var file_service_proto_rawDesc = []byte{ +var file_service_proto_rawDesc = string([]byte{ 0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x28, 0x0a, 0x02, 0x49, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x6e, 0x12, 0x12, @@ -158,22 +153,22 @@ var file_service_proto_rawDesc = []byte{ 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +}) var ( file_service_proto_rawDescOnce sync.Once - file_service_proto_rawDescData = file_service_proto_rawDesc + 
file_service_proto_rawDescData []byte ) func file_service_proto_rawDescGZIP() []byte { file_service_proto_rawDescOnce.Do(func() { - file_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_service_proto_rawDescData) + file_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), len(file_service_proto_rawDesc))) }) return file_service_proto_rawDescData } var file_service_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_service_proto_goTypes = []interface{}{ +var file_service_proto_goTypes = []any{ (*In)(nil), // 0: service.In (*Out)(nil), // 1: service.Out } @@ -198,37 +193,11 @@ func file_service_proto_init() { if File_service_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*In); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Out); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_service_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), len(file_service_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, @@ -239,7 +208,6 @@ func file_service_proto_init() { MessageInfos: file_service_proto_msgTypes, }.Build() File_service_proto = out.File - file_service_proto_rawDesc = nil file_service_proto_goTypes = nil file_service_proto_depIdxs = nil } diff --git a/internal/integration/service/service_drpc.pb.go b/internal/integration/service/service_drpc.pb.go index 7287578..0b859e6 100644 --- 
a/internal/integration/service/service_drpc.pb.go +++ b/internal/integration/service/service_drpc.pb.go @@ -1,12 +1,12 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: (devel) +// protoc-gen-go-drpc version: v0.0.35-0.20260221163524-0f9d4319a40d+dirty // source: service.proto package service import ( context "context" - errors "errors" + errors "github.com/cockroachdb/errors" protojson "google.golang.org/protobuf/encoding/protojson" proto "google.golang.org/protobuf/proto" drpc "storj.io/drpc" @@ -285,6 +285,7 @@ type DRPCService_Method2Stream interface { drpc.Stream SendAndClose(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method2Stream struct { @@ -310,7 +311,7 @@ func (x *drpcService_Method2Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method2Stream) RecvMsg(m *In) error { +func (x *drpcService_Method2Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } @@ -335,6 +336,7 @@ type DRPCService_Method4Stream interface { drpc.Stream Send(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method4Stream struct { @@ -357,6 +359,6 @@ func (x *drpcService_Method4Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method4Stream) RecvMsg(m *In) error { +func (x *drpcService_Method4Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } From 12d9d05d8c34cf54f98c2d79cb989116310a801a Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Tue, 3 Mar 2026 11:14:01 +0000 Subject: [PATCH 2/2] drpc: fix concurrent large message corruption in multiplexed streams With stream multiplexing, multiple streams write concurrently to a shared buffer. The stream's rawWriteLocked used to split large messages into multiple frames (via SplitData) and call WriteFrame for each chunk. 
Each WriteFrame acquires the shared mutex independently, so frames from different streams can interleave in the buffer. The reader on the other side doesn't handle interleaved frames from different streams mid-packet. When it sees a frame from a different stream, it resets the partial packet, silently corrupting data for messages larger than SplitSize. The fix changes the StreamWriter interface from WriteFrame(Frame) to WritePacket(Packet). The stream hands off the full message data in a single call, and the writer serializes it atomically under one mutex hold. rawWriteLocked no longer splits messages into frames, so there is nothing to interleave. Splitting may have been useful before multiplexing. The manageWriter goroutine already batches all pending data from the shared buffer into a single transport write, so splitting at the stream level adds no value. If we ever need to limit per-write size, that belongs in the writer implementation, not in the stream's rawWrite path. --- drpcmanager/frame_queue_test.go | 20 ++++---- drpcmanager/manager_test.go | 82 ++++++++++++++++++++++++++++++++- drpcmanager/mux_writer.go | 31 +++++++------ drpcmanager/mux_writer_test.go | 31 ++++++------- drpcstream/stream.go | 60 ++++++++++++------------ drpcwire/writer.go | 6 ++- 6 files changed, 155 insertions(+), 75 deletions(-) diff --git a/drpcmanager/frame_queue_test.go b/drpcmanager/frame_queue_test.go index dafd604..b175e6b 100644 --- a/drpcmanager/frame_queue_test.go +++ b/drpcmanager/frame_queue_test.go @@ -7,21 +7,19 @@ import ( "testing" "github.com/zeebo/assert" - "storj.io/drpc/drpcwire" ) func TestSharedWriteBuf_AppendDrain(t *testing.T) { sw := newSharedWriteBuf() - fr := drpcwire.Frame{ + pkt := drpcwire.Packet{ Data: []byte("hello"), ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, - Done: true, } - assert.NoError(t, sw.Append(fr)) + assert.NoError(t, sw.Append(pkt)) // Drain should return serialized bytes. 
data := sw.Drain(nil) @@ -31,11 +29,11 @@ func TestSharedWriteBuf_AppendDrain(t *testing.T) { _, got, ok, err := drpcwire.ParseFrame(data) assert.NoError(t, err) assert.That(t, ok) - assert.DeepEqual(t, got.Data, fr.Data) - assert.Equal(t, got.ID.Stream, fr.ID.Stream) - assert.Equal(t, got.ID.Message, fr.ID.Message) - assert.Equal(t, got.Kind, fr.Kind) - assert.Equal(t, got.Done, fr.Done) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) } func TestSharedWriteBuf_CloseIdempotent(t *testing.T) { @@ -48,7 +46,7 @@ func TestSharedWriteBuf_AppendAfterClose(t *testing.T) { sw := newSharedWriteBuf() sw.Close() - err := sw.Append(drpcwire.Frame{}) + err := sw.Append(drpcwire.Packet{}) assert.Error(t, err) } @@ -64,7 +62,7 @@ func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) { }() // Append should wake the blocked WaitAndDrain. - assert.NoError(t, sw.Append(drpcwire.Frame{Data: []byte("a")})) + assert.NoError(t, sw.Append(drpcwire.Packet{Data: []byte("a")})) <-done } diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index cbca79e..52c0c1e 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -15,7 +15,7 @@ import ( "github.com/zeebo/assert" grpcmetadata "google.golang.org/grpc/metadata" "storj.io/drpc/drpcmetadata" - + "storj.io/drpc/drpcstream" "storj.io/drpc/drpctest" "storj.io/drpc/drpcwire" ) @@ -245,6 +245,86 @@ func TestNewServerStreamUnreadMessageDoesNotBlockOtherStreams(t *testing.T) { defer func() { _ = srvStream2.Close() }() } +// TestConcurrentLargeMessages verifies that two streams writing messages larger +// than SplitSize concurrently do not corrupt each other's data. Before the switch +// to WritePacket, rawWriteLocked split messages into multiple frames and each +// frame was appended to the shared write buffer independently. 
Frames from +// different streams could interleave in the buffer, and the reader reset partial +// packets when it saw a frame from a different stream, silently corrupting data. +func TestConcurrentLargeMessages(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + // Use a small SplitSize so that any stream-level splitting would break even + // these small messages into multiple frames, making interleaving likely. + streamOpts := drpcstream.Options{SplitSize: 5} + + cman := NewWithOptions(cconn, Options{Stream: streamOpts}) + defer func() { _ = cman.Close() }() + + sman := NewWithOptions(sconn, Options{Stream: streamOpts}) + defer func() { _ = sman.Close() }() + + // Create two client streams; their messages are sent concurrently below. + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + msg1 := []byte("AAAAAAAAAAAAAAAAAAAA") // 20 bytes, 4x the SplitSize of 5 + msg2 := []byte("BBBBBBBBBBBBBBBBBBBB") // 20 bytes, 4x the SplitSize of 5 + + // Send invokes first (these are small, no splitting). + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + // Accept both server streams before sending messages, so the streams are + // registered and the reader can route packets. 
+ srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + // Write messages concurrently from both streams. 
With SplitSize=5, each + // 20-byte message would span 4 frames if split. Their bytes must not interleave. + ready := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, msg1)) + }() + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream2.RawWrite(drpcwire.KindMessage, msg2)) + }() + close(ready) + wg.Wait() + + // Read from both server streams and verify correctness. + got1, err := srvStream1.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got1, msg1) + + got2, err := srvStream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got2, msg2) +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } diff --git a/drpcmanager/mux_writer.go b/drpcmanager/mux_writer.go index 9794779..b36c25e 100644 --- a/drpcmanager/mux_writer.go +++ b/drpcmanager/mux_writer.go @@ -9,23 +9,20 @@ import ( "storj.io/drpc/drpcwire" ) -// muxWriter implements drpcwire.StreamWriter by serializing frame bytes into a +// muxWriter implements drpcwire.StreamWriter by serializing packet bytes into a // shared write buffer. The manageWriter goroutine drains the buffer and writes // directly to the transport. // -// Compared to the previous frameQueue approach, this avoids: -// - copying frame payload into an intermediate queue slot, -// - drpcwire.Writer mutex overhead in the writer goroutine. -// -// Frames are serialized (via AppendFrame) into the shared buffer under a -// short-held mutex. The frame's Data slice is consumed before WriteFrame -// returns, so callers may safely reuse their buffers afterward. +// The entire packet is serialized as a single frame (via AppendFrame) under one +// mutex hold, so frames from concurrent streams never interleave on the wire. +// The packet's Data slice is consumed (copied) before WritePacket returns, so +// callers may safely reuse their buffers afterward.
type muxWriter struct { sw *sharedWriteBuf } -func (w *muxWriter) WriteFrame(fr drpcwire.Frame) error { - return w.sw.Append(fr) +func (w *muxWriter) WritePacket(pkt drpcwire.Packet) error { + return w.sw.Append(pkt) } // Flush is a no-op because the manageWriter goroutine flushes to the @@ -50,15 +47,21 @@ func newSharedWriteBuf() *sharedWriteBuf { return sw } -// Append serializes fr into the shared buffer. The frame's Data slice is -// consumed (copied by AppendFrame) before Append returns. -func (sw *sharedWriteBuf) Append(fr drpcwire.Frame) error { +// Append serializes pkt as a single frame into the shared buffer. The packet's +// Data slice is consumed (copied by AppendFrame) before Append returns. +func (sw *sharedWriteBuf) Append(pkt drpcwire.Packet) error { sw.mu.Lock() if sw.closed { sw.mu.Unlock() return managerClosed.New("enqueue") } - sw.buf = drpcwire.AppendFrame(sw.buf, fr) + sw.buf = drpcwire.AppendFrame(sw.buf, drpcwire.Frame{ + Data: pkt.Data, + ID: pkt.ID, + Kind: pkt.Kind, + Control: pkt.Control, + Done: true, + }) sw.mu.Unlock() sw.cond.Signal() diff --git a/drpcmanager/mux_writer_test.go b/drpcmanager/mux_writer_test.go index 288631b..50c4e1c 100644 --- a/drpcmanager/mux_writer_test.go +++ b/drpcmanager/mux_writer_test.go @@ -7,49 +7,46 @@ import ( "testing" "github.com/zeebo/assert" - "storj.io/drpc/drpcwire" ) -func TestMuxWriter_WriteFrame(t *testing.T) { +func TestMuxWriter_WritePacket(t *testing.T) { sw := newSharedWriteBuf() w := &muxWriter{sw: sw} - fr := drpcwire.Frame{ + pkt := drpcwire.Packet{ Data: []byte("hello"), ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, - Done: true, } - assert.NoError(t, w.WriteFrame(fr)) + assert.NoError(t, w.WritePacket(pkt)) data := sw.Drain(nil) _, got, ok, err := drpcwire.ParseFrame(data) assert.NoError(t, err) assert.That(t, ok) - assert.DeepEqual(t, got.Data, fr.Data) - assert.Equal(t, got.ID.Stream, fr.ID.Stream) - assert.Equal(t, got.ID.Message, fr.ID.Message) - 
assert.Equal(t, got.Kind, fr.Kind) - assert.Equal(t, got.Done, fr.Done) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) } -func TestMuxWriter_WriteFrameIsolatesData(t *testing.T) { +func TestMuxWriter_WritePacketIsolatesData(t *testing.T) { sw := newSharedWriteBuf() w := &muxWriter{sw: sw} data := []byte("hello") - fr := drpcwire.Frame{ + pkt := drpcwire.Packet{ Data: data, ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, - Done: true, } - assert.NoError(t, w.WriteFrame(fr)) + assert.NoError(t, w.WritePacket(pkt)) - // Mutate the original source buffer after WriteFrame. + // Mutate the original source buffer after WritePacket. data[0] = 'j' // The serialized data in the shared buffer should be unaffected because @@ -73,11 +70,11 @@ func TestMuxWriter_Empty(t *testing.T) { assert.That(t, w.Empty()) } -func TestMuxWriter_WriteFrameAfterClose(t *testing.T) { +func TestMuxWriter_WritePacketAfterClose(t *testing.T) { sw := newSharedWriteBuf() w := &muxWriter{sw: sw} sw.Close() - err := w.WriteFrame(drpcwire.Frame{}) + err := w.WritePacket(drpcwire.Packet{}) assert.Error(t, err) } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index f0dd9a4..7f5a54b 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -11,7 +11,6 @@ import ( "sync" "github.com/zeebo/errs" - "storj.io/drpc" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcdebug" @@ -80,7 +79,9 @@ func New(ctx context.Context, sid uint64, wr drpcwire.StreamWriter) *Stream { // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. 
-func NewWithOptions(ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options) *Stream { +func NewWithOptions( + ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options, +) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -312,26 +313,28 @@ func (s *Stream) checkCancelError(err error) error { return err } -// newFrameLocked bumps the internal message id and returns a frame. It must be +// nextID bumps the internal message id and returns the new ID. It must be // called under a mutex. -func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { +func (s *Stream) nextID() drpcwire.ID { s.id.Message++ - return drpcwire.Frame{ID: s.id, Kind: kind} + return s.id } // sendPacketLocked sends the packet in a single write and flushes. It does not // check for any conditions to stop it from writing and is meant for internal // stream use to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { - fr := s.newFrameLocked(kind) - fr.Data = data - fr.Control = control - fr.Done = true + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + Control: control, + } drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) - s.log("SEND", fr.String) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { + if err := s.wr.WritePacket(pkt); err != nil { return errs.Wrap(err) } if err := s.wr.Flush(); err != nil { @@ -374,29 +377,26 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. 
func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { - fr := s.newFrameLocked(kind) - n := s.opts.SplitSize - - for { - switch { - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } + switch { + case s.sigs.send.IsSet(): + return s.sigs.send.Err() + case s.sigs.term.IsSet(): + return s.sigs.term.Err() + } - fr.Data, data = drpcwire.SplitData(data, n) - fr.Done = len(data) == 0 + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + } - drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(fr.Data))) - s.log("SEND", fr.String) + drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) - } else if fr.Done { - return nil - } + if err := s.wr.WritePacket(pkt); err != nil { + return s.checkCancelError(errs.Wrap(err)) } + return nil } // RawFlush flushes any buffers of data. diff --git a/drpcwire/writer.go b/drpcwire/writer.go index dfd6d9c..bfcf0ac 100644 --- a/drpcwire/writer.go +++ b/drpcwire/writer.go @@ -12,9 +12,11 @@ import ( "storj.io/drpc/drpcdebug" ) -// StreamWriter is the interface for writing frames to a stream. +// StreamWriter is the interface used by streams for writing packets. Each call +// to WritePacket must serialize the full data atomically so that frames from +// concurrent writers on different streams do not interleave on the wire. type StreamWriter interface { - WriteFrame(fr Frame) error + WritePacket(pkt Packet) error Flush() error Empty() bool }