diff --git a/.gitignore b/.gitignore index 00856f8..b9cc35c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ cmd/protoc-gen-go-drpc/protoc-gen-go-drpc /WORKSPACE BUILD.bazel MODULE.bazel* +debug/* \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..f2c1b09 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +This fork of DRPC library is maintained and used by [CockroachDB](https://github.com/cockroachdb/cockroach) and is customized for CockroachDB's needs. \ No newline at end of file diff --git a/Makefile b/Makefile index b5f49ad..de2214f 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,10 @@ lint: staticcheck $(PKG) golangci-lint run +.PHONY: install-protoc-plugin +install-protoc-plugin: + $(GO) install ./cmd/protoc-gen-go-drpc/ + .PHONY: gen-bazel gen-bazel: @echo "Generating WORKSPACE" diff --git a/cmd/protoc-gen-go-drpc/main.go b/cmd/protoc-gen-go-drpc/main.go index 39ddfbb..3cd2769 100644 --- a/cmd/protoc-gen-go-drpc/main.go +++ b/cmd/protoc-gen-go-drpc/main.go @@ -16,8 +16,9 @@ import ( ) type config struct { - protolib string - json bool + protolib string + json bool + generateAdapters bool } func main() { @@ -25,6 +26,7 @@ func main() { var conf config flags.StringVar(&conf.protolib, "protolib", "google.golang.org/protobuf", "which protobuf library to use for encoding") flags.BoolVar(&conf.json, "json", true, "generate encoders with json support") + flags.BoolVar(&conf.generateAdapters, "generate-adapters", true, "generate gRPC/DRPC adapter and RPC interface code") protogen.Options{ ParamFunc: flags.Set, @@ -55,7 +57,7 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File, conf config) { d.generateEncoding(conf) for _, service := range file.Services { - d.generateService(service) + d.generateService(service, conf) } } @@ -267,7 +269,7 @@ func (d *drpc) generateEncoding(conf config) { // service generation // -func (d *drpc) generateService(service *protogen.Service) { +func (d *drpc) 
generateService(service *protogen.Service, conf config) { // Client interface d.P("type ", d.ClientIface(service), " interface {") d.P("DRPCConn() ", d.Ident("storj.io/drpc", "Conn")) @@ -294,7 +296,7 @@ func (d *drpc) generateService(service *protogen.Service) { d.P("func (c *", d.ClientImpl(service), ") DRPCConn() ", d.Ident("storj.io/drpc", "Conn"), "{ return c.cc }") d.P() for _, method := range service.Methods { - d.generateClientMethod(method) + d.generateClientMethod(method, conf) } // Server interface @@ -339,11 +341,13 @@ func (d *drpc) generateService(service *protogen.Service) { // Server methods for _, method := range service.Methods { - d.generateServerMethod(method) + d.generateServerMethod(method, conf) } - d.generateServiceRPCInterfaces(service) - d.generateServiceAdapters(service) + if conf.generateAdapters { + d.generateServiceRPCInterfaces(service) + d.generateServiceAdapters(service) + } } // @@ -362,7 +366,7 @@ func (d *drpc) generateClientSignature(method *protogen.Method) string { return fmt.Sprintf("%s(ctx %s%s) (%s, error)", method.GoName, d.Ident("context", "Context"), reqArg, respName) } -func (d *drpc) generateClientMethod(method *protogen.Method) { +func (d *drpc) generateClientMethod(method *protogen.Method, conf config) { recvType := d.ClientImpl(method.Parent) outType := d.OutputType(method) inType := d.InputType(method) @@ -408,7 +412,9 @@ func (d *drpc) generateClientMethod(method *protogen.Method) { d.P("}") d.P() - d.generateRPCClientInterface(method) + if conf.generateAdapters { + d.generateRPCClientInterface(method) + } d.P("type ", d.ClientStreamImpl(method), " struct {") d.P(d.Ident("storj.io/drpc", "Stream")) @@ -510,7 +516,7 @@ func (d *drpc) generateServerReceiver(method *protogen.Method) { d.P(")") } -func (d *drpc) generateServerMethod(method *protogen.Method) { +func (d *drpc) generateServerMethod(method *protogen.Method, conf config) { genSend := method.Desc.IsStreamingServer() genSendAndClose := 
!method.Desc.IsStreamingServer() genRecv := method.Desc.IsStreamingClient() @@ -531,7 +537,9 @@ func (d *drpc) generateServerMethod(method *protogen.Method) { d.P("}") d.P() - d.generateRPCServerInterface(method) + if conf.generateAdapters { + d.generateRPCServerInterface(method) + } d.P("type ", d.ServerStreamImpl(method), " struct {") d.P(d.Ident("storj.io/drpc", "Stream")) diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 636f346..33a50a4 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -31,10 +31,9 @@ type Options struct { // Conn is a drpc client connection. type Conn struct { - tr drpc.Transport - man *drpcmanager.Manager - mu sync.Mutex - wbuf []byte + tr drpc.Transport + man *drpcmanager.Manager + mu sync.Mutex // protects stats stats map[string]*drpcstats.Stats } @@ -92,16 +91,16 @@ func (c *Conn) Transport() drpc.Transport { return c.tr } // Closed returns a channel that is closed once the connection is closed. func (c *Conn) Closed() <-chan struct{} { return c.man.Closed() } -// Unblocked returns a channel that is closed once the connection is no longer -// blocked by a previously canceled Invoke or NewStream call. It should not -// be called concurrently with Invoke or NewStream. +// Unblocked returns a channel that is closed when the connection is available +// for new streams. With multiplexing enabled, this always returns an +// already-closed channel. func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() } // Close closes the connection. func (c *Conn) Close() (err error) { return c.man.Close() } // Invoke issues the rpc on the transport serializing in, waits for a response, and -// deserializes it into out. Only one Invoke or Stream may be open at a time. +// deserializes it into out. Multiple Invoke or Stream calls may be open concurrently. 
func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) { defer func() { err = drpc.ToRPCErr(err) }() @@ -117,18 +116,13 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou } defer func() { err = errs.Combine(err, stream.Close()) }() - // we have to protect c.wbuf here even though the manager only allows one - // stream at a time because the stream may async close allowing another - // concurrent call to Invoke to proceed. - c.mu.Lock() - defer c.mu.Unlock() - - c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0]) + // Per-call buffer allocation for concurrent access. + data, err := drpcenc.MarshalAppend(in, enc, nil) if err != nil { return err } - if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil { + if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil { return err } return nil @@ -155,8 +149,8 @@ func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string return nil } -// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may -// be open at a time. +// NewStream begins a streaming rpc on the connection. Multiple Invoke or Stream calls +// may be open concurrently. 
func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) { defer func() { err = drpc.ToRPCErr(err) }() diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516..2550c3b 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -180,7 +180,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { }) s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{}) assert.NoError(t, err) - _ = s.CloseSend() + + assert.NoError(t, s.CloseSend()) + + ctx.Wait() } func TestConn_encodeMetadata(t *testing.T) { diff --git a/drpcmanager/frame_queue_test.go b/drpcmanager/frame_queue_test.go new file mode 100644 index 0000000..b175e6b --- /dev/null +++ b/drpcmanager/frame_queue_test.go @@ -0,0 +1,82 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + "storj.io/drpc/drpcwire" +) + +func TestSharedWriteBuf_AppendDrain(t *testing.T) { + sw := newSharedWriteBuf() + + pkt := drpcwire.Packet{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, sw.Append(pkt)) + + // Drain should return serialized bytes. + data := sw.Drain(nil) + assert.That(t, len(data) > 0) + + // Parse the frame back out to verify correctness. 
+ _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) +} + +func TestSharedWriteBuf_CloseIdempotent(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + sw.Close() // must not panic +} + +func TestSharedWriteBuf_AppendAfterClose(t *testing.T) { + sw := newSharedWriteBuf() + sw.Close() + + err := sw.Append(drpcwire.Packet{}) + assert.Error(t, err) +} + +func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + data, ok := sw.WaitAndDrain(nil) + assert.That(t, ok) + assert.That(t, len(data) > 0) + }() + + // Append should wake the blocked WaitAndDrain. + assert.NoError(t, sw.Append(drpcwire.Packet{Data: []byte("a")})) + <-done +} + +func TestSharedWriteBuf_WaitAndDrainCloseEmpty(t *testing.T) { + sw := newSharedWriteBuf() + + done := make(chan struct{}) + go func() { + defer close(done) + _, ok := sw.WaitAndDrain(nil) + assert.That(t, !ok) + }() + + // Close on empty buffer should return ok=false. + sw.Close() + <-done +} diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 1680a65..8a26476 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -10,6 +10,8 @@ import ( "io" "net" "strings" + "sync" + "sync/atomic" "syscall" "time" @@ -68,30 +70,26 @@ type Options struct { // to the appropriate stream. 
type Manager struct { tr drpc.Transport - wr *drpcwire.Writer rd *drpcwire.Reader opts Options - sem drpcsignal.Chan // held by the active stream - sbuf streamBuffer // largest stream id created - pkts chan drpcwire.Packet // channel for invoke packets - pdone drpcsignal.Chan // signals when a packets buffers can be reused - sfin chan struct{} // shared signal for stream finished - streams chan streamInfo // channel to signal that a stream should start + sw *sharedWriteBuf // shared write buffer for the writer goroutine + reg *streamRegistry // tracks all active streams by ID + streamID atomic.Uint64 // next stream ID for client streams + wg sync.WaitGroup // tracks active manageStream goroutines + pkts chan drpcwire.Packet // channel for invoke packets + pdone drpcsignal.Chan // signals when packet buffers can be reused + metaMu sync.Mutex + meta map[uint64]map[string]string // invoke metadata buffered by stream ID sigs struct { - term drpcsignal.Signal // set when the manager should start terminating - stream drpcsignal.Signal // set when the manage streams goroutine is done - read drpcsignal.Signal // set after the goroutine reading from the transport is done - tport drpcsignal.Signal // set after the transport has been closed + term drpcsignal.Signal // set when the manager should start terminating + write drpcsignal.Signal // set when the writer goroutine is done + read drpcsignal.Signal // set after the goroutine reading from the transport is done + tport drpcsignal.Signal // set after the transport has been closed } } -type streamInfo struct { - ctx context.Context - stream *drpcstream.Stream -} - // New returns a new Manager for the transport. 
func New(tr drpc.Transport) *Manager { return NewWithOptions(tr, Options{}) @@ -102,31 +100,25 @@ func New(tr drpc.Transport) *Manager { func NewWithOptions(tr drpc.Transport, opts Options) *Manager { m := &Manager{ tr: tr, - wr: drpcwire.NewWriter(tr, opts.WriterBufferSize), rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), opts: opts, - pkts: make(chan drpcwire.Packet), - sfin: make(chan struct{}, 1), - streams: make(chan streamInfo), + pkts: make(chan drpcwire.Packet), + meta: make(map[uint64]map[string]string), } - // initialize the stream buffer - m.sbuf.init() - - // this semaphore controls the number of concurrent streams. it MUST be 1. - m.sem.Make(1) - // a buffer of size 1 allows the consumer of the packet to signal it is done // without having to coordinate with the sender of the packet. m.pdone.Make(1) // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) - drpcopts.SetStreamFin(&m.opts.Stream.Internal, m.sfin) + + m.sw = newSharedWriteBuf() + m.reg = newStreamRegistry() go m.manageReader() - go m.manageStreams() + go m.manageWriter() return m } @@ -144,71 +136,44 @@ func (m *Manager) log(what string, cb func() string) { // helpers // -// acquireSemaphore attempts to acquire the semaphore protecting streams. If the -// context is canceled or the manager is terminated, it returns an error. 
-func (m *Manager) acquireSemaphore(ctx context.Context) error { - if err, ok := m.sigs.term.Get(); ok { - return err - } else if err := ctx.Err(); err != nil { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case m.sem.Get() <- struct{}{}: - if err := m.waitForPreviousStream(ctx); err != nil { - m.sem.Recv() - return err - } - return nil - } -} - -// waitForPreviousStream will, if there was a previous stream, ensure it is -// Closed and then wait until it is in the Finished state, where it will no -// longer make any reads or writes on the transport. It exits early if the -// context is canceled or the manager is terminated. -func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) { - prev := m.sbuf.Get() - if prev == nil { - return nil - } - - // if the stream is not finished yet, we need to wait for it to be - // finished before letting the next stream to start. - if prev.IsFinished() { - return nil - } - - m.log("WAIT", prev.String) - - select { - case <-ctx.Done(): - return ctx.Err() - - case <-m.sigs.term.Signal(): - return m.sigs.term.Err() - - case <-prev.Finished(): - return nil - } -} - // terminate puts the Manager into a terminal state and closes any resources // that need to be closed to signal the state change. func (m *Manager) terminate(err error) { if m.sigs.term.Set(err) { m.log("TERM", func() string { return fmt.Sprint(err) }) m.sigs.tport.Set(m.tr.Close()) - m.sbuf.Close() + m.sw.Close() + m.metaMu.Lock() + for id := range m.meta { + delete(m.meta, id) + } + m.metaMu.Unlock() + // Cancel all active streams so they get a clear error. 
+ cancelErr := err + if errors.Is(cancelErr, io.EOF) { + cancelErr = context.Canceled + } + m.reg.ForEach(func(_ uint64, s *drpcstream.Stream) { + s.Cancel(cancelErr) + }) + m.reg.Close() } } +func (m *Manager) putMetadata(streamID uint64, metadata map[string]string) { + m.metaMu.Lock() + defer m.metaMu.Unlock() + m.meta[streamID] = metadata +} + +func (m *Manager) popMetadata(streamID uint64) map[string]string { + m.metaMu.Lock() + defer m.metaMu.Unlock() + metadata := m.meta[streamID] + delete(m.meta, streamID) + return metadata +} + // // manage reader // @@ -250,46 +215,56 @@ func (m *Manager) manageReader() { m.log("READ", pkt.String) - again: - switch curr := m.sbuf.Get(); { - // if the packet is for the current stream, deliver it. - case curr != nil && pkt.ID.Stream == curr.ID(): - if err := curr.HandlePacket(pkt); err != nil { + stream, ok := m.reg.Get(pkt.ID.Stream) + + switch { + // if the packet is for a registered stream, deliver it. + case ok && stream != nil: + if err := stream.HandlePacket(pkt); err != nil { m.terminate(managerClosed.Wrap(err)) return } - - // if an old message has been sent, just ignore it. - case curr != nil && pkt.ID.Stream < curr.ID(): - - // if any invoke sequence is being sent, close any old unterminated - // stream and forward it to be handled. - case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) + // For message packets, HandlePacket transferred ownership of + // pkt.Data to the stream's packetBuffer. Acquire a fresh buffer + // from the pool so the next ReadPacketUsing doesn't allocate. + if pkt.Kind == drpcwire.KindMessage { + pkt.Data = drpcstream.AcquirePacketBuf() } + // if any invoke sequence is being sent, forward it to be handled. 
+ case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: select { case m.pkts <- pkt: m.pdone.Recv() - case <-m.sigs.term.Signal(): return } - // a non-invoke packet should be delivered to some stream so we wait for - // a new stream to be created and try again. like an invoke, we - // implicitly close any previous stream. + // silently drop packet for an unregistered stream default: - if curr != nil && !curr.IsTerminated() { - curr.Cancel(context.Canceled) - } + m.log("DROP", pkt.String) + } + } +} - if !m.sbuf.Wait(curr.ID()) { - return - } - goto again +// manageWriter drains the shared write buffer and writes pre-serialized +// bytes directly to the transport. It blocks on the sharedWriteBuf's +// condition variable until data is available, and naturally batches +// frames that accumulate while the previous write is in flight. +func (m *Manager) manageWriter() { + defer m.sigs.write.Set(nil) + + var spare []byte + for { + data, ok := m.sw.WaitAndDrain(spare[:0:cap(spare)]) + if !ok { + return } + if _, err := m.tr.Write(data); err != nil { + m.terminate(managerClosed.Wrap(err)) + return + } + spare = data } } @@ -306,37 +281,25 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind, rpc string) ( drpcopts.SetStreamStats(&opts.Internal, cb(rpc)) } - stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) - select { - case m.streams <- streamInfo{ctx: ctx, stream: stream}: - m.sbuf.Set(stream) - m.log("STREAM", stream.String) - return stream, nil + stream := drpcstream.NewWithOptions(ctx, sid, &muxWriter{sw: m.sw}, opts) - case <-m.sigs.term.Signal(): - return nil, m.sigs.term.Err() + if err := m.reg.Register(sid, stream); err != nil { + return nil, err } -} -// manageStreams reads from the streams channel for stream infos and runs the -// manageStream function on them. 
-func (m *Manager) manageStreams() { - defer m.sigs.stream.Set(nil) + m.wg.Add(1) + go m.manageStream(ctx, stream) - for { - select { - case si := <-m.streams: - m.manageStream(si.ctx, si.stream) - - case <-m.sigs.term.Signal(): - return - } - } + m.log("STREAM", stream.String) + return stream, nil } // manageStream watches the context and the stream and returns when the stream // is finished, canceling the stream if the context is canceled. func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { + defer m.wg.Done() + defer m.reg.Unregister(stream.ID()) + select { case <-m.sigs.term.Signal(): err := m.sigs.term.Err() @@ -344,47 +307,34 @@ func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { err = context.Canceled } stream.Cancel(err) - <-m.sfin - m.sem.Recv() + <-stream.Finished() - case <-m.sfin: - m.sem.Recv() + case <-stream.Finished(): + // stream finished naturally case <-ctx.Done(): m.log("CANCEL", stream.String) if m.opts.SoftCancel { - // allow a new stream to begin. - m.sem.Recv() - - // attempt to send the soft cancel. if it fails or if the stream is - // busy sending something else, then we have to hard cancel. + // Best-effort send KindCancel, never terminate connection. if busy, err := stream.SendCancel(ctx.Err()); err != nil { - m.terminate(err) + m.log("CANCEL_ERR", func() string { + return fmt.Sprintf("%s: %v", stream.String(), err) + }) } else if busy { - m.log("BUSY", stream.String) - m.terminate(ctx.Err()) + m.log("CANCEL_BUSY", stream.String) } stream.Cancel(ctx.Err()) - - // wait for the stream to signal that it is finished. - <-m.sfin + <-stream.Finished() } else { - // If the stream isn't already finished, we have to terminate the - // transport to do an active cancel. If it is already finished, - // there is no need. + // Hard cancel: terminate connection if stream not finished. 
if !stream.Cancel(ctx.Err()) { m.log("UNFIN", stream.String) m.terminate(ctx.Err()) } else { m.log("CLEAN", stream.String) } - - // wait for the stream to signal that it is finished. - <-m.sfin - - // allow a new stream to begin. - m.sem.Recv() + <-stream.Finished() } } } @@ -398,15 +348,10 @@ func (m *Manager) Closed() <-chan struct{} { return m.sigs.term.Signal() } -// Unblocked returns a channel that is closed when the manager is no longer -// blocked from creating a new stream due to a previous stream's soft cancel. It -// should not be called concurrently with NewClientStream or NewServerStream and -// the return result is only valid until the next call to NewClientStream or -// NewServerStream. +// Unblocked returns a channel that is closed when the manager is available for +// new streams. With multiplexing enabled, the connection is never blocked, so +// this always returns an already-closed channel. func (m *Manager) Unblocked() <-chan struct{} { - if prev := m.sbuf.Get(); prev != nil { - return prev.Context().Done() - } return closedCh } @@ -414,7 +359,8 @@ func (m *Manager) Unblocked() <-chan struct{} { func (m *Manager) Close() error { m.terminate(managerClosed.New("Close called")) - m.sigs.stream.Wait() + m.wg.Wait() // wait for all stream goroutines + m.sigs.write.Wait() m.sigs.read.Wait() m.sigs.tport.Wait() @@ -423,28 +369,21 @@ func (m *Manager) Close() error { // NewClientStream starts a stream on the managed transport for use by a client. func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) { - if err := m.acquireSemaphore(ctx); err != nil { + if err, ok := m.sigs.term.Get(); ok { return nil, err } - - return m.newStream(ctx, m.sbuf.Get().ID()+1, "cli", rpc) + sid := m.streamID.Add(1) + return m.newStream(ctx, sid, "cli", rpc) } // NewServerStream starts a stream on the managed transport for use by a server. 
// It does this by waiting for the client to issue an invoke message and // returning the details. func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Stream, rpc string, err error) { - if err := m.acquireSemaphore(ctx); err != nil { + if err, ok := m.sigs.term.Get(); ok { return nil, "", err } - defer func() { - if err != nil { - m.sem.Recv() - } - }() - var meta map[string]string - var metaID uint64 var timeoutCh <-chan time.Time // set up the timeout on the context if necessary. @@ -467,38 +406,40 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea case pkt := <-m.pkts: switch pkt.Kind { - // keep track of any metadata being sent before an invoke so that we - // can include it if the stream id matches the eventual invoke. case drpcwire.KindInvokeMetadata: - meta, err = drpcmetadata.Decode(pkt.Data) - m.pdone.Send() - + metadata, err := drpcmetadata.Decode(pkt.Data) if err != nil { + m.pdone.Send() return nil, "", err } - metaID = pkt.ID.Stream + m.putMetadata(pkt.ID.Stream, metadata) + m.pdone.Send() case drpcwire.KindInvoke: rpc = string(pkt.Data) - m.pdone.Send() + streamCtx := ctx - if metaID == pkt.ID.Stream { + if metadata := m.popMetadata(pkt.ID.Stream); metadata != nil { if m.opts.GRPCMetadataCompatMode { // Populate incoming metadata as grpc metadata in the // context. This is a short-term fix that will enable us // to send and receive grpc metadata when DRPC is enabled, // without any changes in the calling code. - grpcMeta := make(map[string][]string, len(meta)) - for k, v := range meta { + grpcMeta := make(map[string][]string, len(metadata)) + for k, v := range metadata { grpcMeta[k] = []string{v} } - ctx = grpcmetadata.NewIncomingContext(ctx, grpcMeta) + streamCtx = grpcmetadata.NewIncomingContext(streamCtx, grpcMeta) } else { // Add metadata to the incoming context. 
- ctx = drpcmetadata.NewIncomingContext(ctx, meta) + streamCtx = drpcmetadata.NewIncomingContext(streamCtx, metadata) } } - stream, err := m.newStream(ctx, pkt.ID.Stream, "srv", rpc) + stream, err := m.newStream(streamCtx, pkt.ID.Stream, "srv", rpc) + // Ack the invoke only after stream registration so subsequent + // message packets cannot be dropped for an unknown stream ID. + // Always ack, even on error, so the reader goroutine does not block. + m.pdone.Send() return stream, rpc, err default: diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 5918113..52c0c1e 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -15,20 +15,11 @@ import ( "github.com/zeebo/assert" grpcmetadata "google.golang.org/grpc/metadata" "storj.io/drpc/drpcmetadata" - + "storj.io/drpc/drpcstream" "storj.io/drpc/drpctest" "storj.io/drpc/drpcwire" ) -func closed(ch <-chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - func TestTimeout(t *testing.T) { tr := make(blockingTransport) man := NewWithOptions(tr, Options{ @@ -69,10 +60,8 @@ func TestDrpcMetadata(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -129,10 +118,8 @@ func TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -161,6 +148,183 @@ func 
TestDrpcMetadataWithGRPCMetadataCompatMode(t *testing.T) { ctx.Wait() } +func TestDrpcMetadataInterleavedAcrossStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := New(cconn) + defer func() { _ = cman.Close() }() + + sman := NewWithOptions(sconn, Options{ + GRPCMetadataCompatMode: false, + }) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + metadata1 := map[string]string{"stream": "one"} + metadata2 := map[string]string{"stream": "two"} + + buf1, err := drpcmetadata.Encode(nil, metadata1) + assert.NoError(t, err) + buf2, err := drpcmetadata.Encode(nil, metadata2) + assert.NoError(t, err) + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvokeMetadata, buf1)) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvokeMetadata, buf2)) + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + got1, ok := drpcmetadata.GetFromIncomingContext(srvStream1.Context()) + assert.That(t, ok) + assert.Equal(t, metadata1, got1) + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + got2, ok := drpcmetadata.GetFromIncomingContext(srvStream2.Context()) + assert.That(t, ok) + assert.Equal(t, metadata2, got2) +} + +func TestNewServerStreamUnreadMessageDoesNotBlockOtherStreams(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, 
sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := New(cconn) + defer func() { _ = cman.Close() }() + + sman := New(sconn) + defer func() { _ = sman.Close() }() + + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, []byte("message-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + // Do not read the first stream's message. The manager must still be able to + // accept and register additional streams. + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + srvStream2, rpc2, err := sman.NewServerStream(timeoutCtx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() +} + +// TestConcurrentLargeMessages verifies that two streams writing messages larger +// than SplitSize concurrently do not corrupt each other's data. With the current +// implementation, rawWriteLocked splits messages into multiple frames and each +// frame is appended to the shared write buffer independently. Frames from +// different streams can interleave in the buffer, and the reader resets partial +// packets when it sees a frame from a different stream, silently corrupting data. 
+func TestConcurrentLargeMessages(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + // Use a small SplitSize to force even small messages to be split into + // multiple frames, making interleaving likely. + streamOpts := drpcstream.Options{SplitSize: 5} + + cman := NewWithOptions(cconn, Options{Stream: streamOpts}) + defer func() { _ = cman.Close() }() + + sman := NewWithOptions(sconn, Options{Stream: streamOpts}) + defer func() { _ = sman.Close() }() + + // Create two client streams and send invoke + message concurrently. + stream1, err := cman.NewClientStream(ctx, "rpc-1") + assert.NoError(t, err) + defer func() { _ = stream1.Close() }() + + stream2, err := cman.NewClientStream(ctx, "rpc-2") + assert.NoError(t, err) + defer func() { _ = stream2.Close() }() + + msg1 := []byte("AAAAAAAAAAAAAAAAAAAA") // 20 bytes, split into 4 frames of 5 bytes + msg2 := []byte("BBBBBBBBBBBBBBBBBBBB") // 20 bytes, split into 4 frames of 5 bytes + + // Send invokes first (these are small, no splitting). + assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1"))) + assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2"))) + + // Accept both server streams before sending messages, so the streams are + // registered and the reader can route packets. + srvStream1, rpc1, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-1", rpc1) + defer func() { _ = srvStream1.Close() }() + + srvStream2, rpc2, err := sman.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, "rpc-2", rpc2) + defer func() { _ = srvStream2.Close() }() + + // Write messages concurrently from both streams. With SplitSize=5, each + // 20-byte message becomes 4 frames. The frames should not interleave. 
+ ready := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, msg1)) + }() + go func() { + defer wg.Done() + <-ready + assert.NoError(t, stream2.RawWrite(drpcwire.KindMessage, msg2)) + }() + close(ready) + wg.Wait() + + // Read from both server streams and verify correctness. + got1, err := srvStream1.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got1, msg1) + + got2, err := srvStream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, got2, msg2) +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } @@ -189,10 +353,8 @@ func TestUnblocked_NoCancel(t *testing.T) { assert.NoError(t, stream.RawWrite(drpcwire.KindInvoke, []byte("invoke"))) assert.NoError(t, stream.RawWrite(drpcwire.KindMessage, []byte("message"))) assert.NoError(t, stream.RawFlush()) - assert.That(t, !closed(cman.Unblocked())) assert.NoError(t, stream.Close()) - assert.That(t, closed(cman.Unblocked())) }) ctx.Run(func(ctx context.Context) { @@ -230,17 +392,20 @@ func TestUnblocked_SoftCancel(t *testing.T) { if softCancel { assert.NoError(t, err) } else if i > 0 { + // Hard cancel terminates the connection, so subsequent streams fail. assert.Error(t, err) return + } else { + assert.NoError(t, err) } defer func() { _ = stream.Close() }() - assert.That(t, !closed(man.Unblocked())) cancel() // temporary unblock writing to allow the stream to finish soft cancel tr.setWriteOpen(true) - <-man.Unblocked() + // With multiplexing, we wait for the stream to finish instead of Unblocked(). + <-stream.Finished() tr.setWriteOpen(false) }() } diff --git a/drpcmanager/mux_writer.go b/drpcmanager/mux_writer.go new file mode 100644 index 0000000..b36c25e --- /dev/null +++ b/drpcmanager/mux_writer.go @@ -0,0 +1,110 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "sync" + + "storj.io/drpc/drpcwire" +) + +// muxWriter implements drpcwire.StreamWriter by serializing packet bytes into a +// shared write buffer. The manageWriter goroutine drains the buffer and writes +// directly to the transport. +// +// The entire packet is serialized as a single frame (via AppendFrame) under one +// mutex hold, so frames from concurrent streams never interleave on the wire. +// The packet's Data slice is consumed (copied) before WritePacket returns, so +// callers may safely reuse their buffers afterward. +type muxWriter struct { + sw *sharedWriteBuf +} + +func (w *muxWriter) WritePacket(pkt drpcwire.Packet) error { + return w.sw.Append(pkt) +} + +// Flush is a no-op because the manageWriter goroutine flushes to the +// transport after draining the shared buffer. +func (w *muxWriter) Flush() error { return nil } + +func (w *muxWriter) Empty() bool { return true } + +// sharedWriteBuf collects serialized frame bytes from multiple concurrent +// producers. A single consumer (manageWriter) drains the buffer and writes +// the pre-serialized bytes to the transport. +type sharedWriteBuf struct { + mu sync.Mutex + cond *sync.Cond + buf []byte + closed bool +} + +func newSharedWriteBuf() *sharedWriteBuf { + sw := &sharedWriteBuf{} + sw.cond = sync.NewCond(&sw.mu) + return sw +} + +// Append serializes pkt as a single frame into the shared buffer. The packet's +// Data slice is consumed (copied by AppendFrame) before Append returns. +func (sw *sharedWriteBuf) Append(pkt drpcwire.Packet) error { + sw.mu.Lock() + if sw.closed { + sw.mu.Unlock() + return managerClosed.New("enqueue") + } + sw.buf = drpcwire.AppendFrame(sw.buf, drpcwire.Frame{ + Data: pkt.Data, + ID: pkt.ID, + Kind: pkt.Kind, + Control: pkt.Control, + Done: true, + }) + sw.mu.Unlock() + + sw.cond.Signal() + return nil +} + +// Drain swaps out accumulated bytes, giving the caller ownership of the +// returned slice. 
The internal buffer is replaced with spare (reset to zero +// length) so producers can continue appending without allocation. +func (sw *sharedWriteBuf) Drain(spare []byte) []byte { + sw.mu.Lock() + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data +} + +// WaitAndDrain blocks until data is available or the buffer is closed. +// Returns the accumulated bytes and true if data was available, or nil and +// false if the buffer is closed and empty. +func (sw *sharedWriteBuf) WaitAndDrain(spare []byte) ([]byte, bool) { + sw.mu.Lock() + for len(sw.buf) == 0 && !sw.closed { + sw.cond.Wait() + } + if sw.closed && len(sw.buf) == 0 { + sw.mu.Unlock() + return nil, false + } + data := sw.buf + sw.buf = spare + sw.mu.Unlock() + return data, true +} + +// Close marks the buffer as closed and wakes the consumer. +func (sw *sharedWriteBuf) Close() { + sw.mu.Lock() + defer sw.mu.Unlock() + + if sw.closed { + return + } + sw.closed = true + sw.cond.Broadcast() +} diff --git a/drpcmanager/mux_writer_test.go b/drpcmanager/mux_writer_test.go new file mode 100644 index 0000000..50c4e1c --- /dev/null +++ b/drpcmanager/mux_writer_test.go @@ -0,0 +1,80 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "testing" + + "github.com/zeebo/assert" + "storj.io/drpc/drpcwire" +) + +func TestMuxWriter_WritePacket(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + pkt := drpcwire.Packet{ + Data: []byte("hello"), + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, w.WritePacket(pkt)) + + data := sw.Drain(nil) + _, got, ok, err := drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + assert.DeepEqual(t, got.Data, pkt.Data) + assert.Equal(t, got.ID.Stream, pkt.ID.Stream) + assert.Equal(t, got.ID.Message, pkt.ID.Message) + assert.Equal(t, got.Kind, pkt.Kind) + assert.Equal(t, got.Done, true) +} + +func TestMuxWriter_WritePacketIsolatesData(t *testing.T) { + sw := newSharedWriteBuf() + w := &muxWriter{sw: sw} + + data := []byte("hello") + pkt := drpcwire.Packet{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: 2}, + Kind: drpcwire.KindMessage, + } + + assert.NoError(t, w.WritePacket(pkt)) + + // Mutate the original source buffer after WritePacket. + data[0] = 'j' + + // The serialized data in the shared buffer should be unaffected because + // AppendFrame copies the bytes during serialization. 
+	buf := sw.Drain(nil)
+	_, got, ok, err := drpcwire.ParseFrame(buf)
+	assert.NoError(t, err)
+	assert.That(t, ok)
+	assert.DeepEqual(t, got.Data, []byte("hello"))
+}
+
+func TestMuxWriter_FlushNoop(t *testing.T) {
+	sw := newSharedWriteBuf()
+	w := &muxWriter{sw: sw}
+	assert.NoError(t, w.Flush())
+}
+
+func TestMuxWriter_Empty(t *testing.T) {
+	sw := newSharedWriteBuf()
+	w := &muxWriter{sw: sw}
+	assert.That(t, w.Empty())
+}
+
+func TestMuxWriter_WritePacketAfterClose(t *testing.T) {
+	sw := newSharedWriteBuf()
+	w := &muxWriter{sw: sw}
+	sw.Close()
+
+	err := w.WritePacket(drpcwire.Packet{})
+	assert.Error(t, err)
+}
diff --git a/drpcmanager/registry.go b/drpcmanager/registry.go
new file mode 100644
index 0000000..32b50cb
--- /dev/null
+++ b/drpcmanager/registry.go
@@ -0,0 +1,89 @@
+// Copyright (C) 2026 Cockroach Labs.
+// See LICENSE for copying information.
+
+package drpcmanager
+
+import (
+	"sync"
+
+	"storj.io/drpc/drpcstream"
+)
+
+// streamRegistry is a thread-safe map of stream IDs to stream objects.
+// It is used by the Manager to track all active streams for lifecycle
+// management and packet routing.
+type streamRegistry struct {
+	mu      sync.RWMutex
+	streams map[uint64]*drpcstream.Stream
+	closed  bool
+}
+
+func newStreamRegistry() *streamRegistry {
+	return &streamRegistry{
+		streams: make(map[uint64]*drpcstream.Stream),
+	}
+}
+
+// Register adds a stream to the registry. It returns an error if the registry
+// is closed or if a stream with the same ID is already registered.
+func (r *streamRegistry) Register(id uint64, stream *drpcstream.Stream) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if r.closed {
+		return managerClosed.New("register")
+	}
+	if _, ok := r.streams[id]; ok {
+		return managerClosed.New("duplicate stream id")
+	}
+	r.streams[id] = stream
+	return nil
+}
+
+// Unregister removes a stream from the registry. It is a no-op if the stream
+// is not registered; it still removes the stream even after the registry closes.
+func (r *streamRegistry) Unregister(id uint64) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.streams != nil { + delete(r.streams, id) + } +} + +// Get returns the stream for the given ID and whether it was found. +func (r *streamRegistry) Get(id uint64) (*drpcstream.Stream, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + s, ok := r.streams[id] + return s, ok +} + +// Close marks the registry as closed, preventing future Register calls. +// It does not cancel any streams. +func (r *streamRegistry) Close() { + r.mu.Lock() + defer r.mu.Unlock() + + r.closed = true +} + +// ForEach calls fn for each registered stream. The function is called with +// the stream ID and stream pointer. The registry is read-locked during iteration. +func (r *streamRegistry) ForEach(fn func(uint64, *drpcstream.Stream)) { + r.mu.RLock() + defer r.mu.RUnlock() + + for id, s := range r.streams { + fn(id, s) + } +} + +// Len returns the number of registered streams. +func (r *streamRegistry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.streams) +} diff --git a/drpcmanager/registry_test.go b/drpcmanager/registry_test.go new file mode 100644 index 0000000..d27359b --- /dev/null +++ b/drpcmanager/registry_test.go @@ -0,0 +1,136 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. 
+ +package drpcmanager + +import ( + "context" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc/drpcstream" +) + +func testStream(id uint64) *drpcstream.Stream { + return drpcstream.New(context.Background(), id, &muxWriter{sw: newSharedWriteBuf()}) +} + +func TestStreamRegistry_RegisterAndGet(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s) +} + +func TestStreamRegistry_GetMissing(t *testing.T) { + reg := newStreamRegistry() + + got, ok := reg.Get(42) + assert.That(t, !ok) + assert.Nil(t, got) +} + +func TestStreamRegistry_Unregister(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + + assert.NoError(t, reg.Register(1, s)) + assert.Equal(t, reg.Len(), 1) + + reg.Unregister(1) + + _, ok := reg.Get(1) + assert.That(t, !ok) + assert.Equal(t, reg.Len(), 0) +} + +func TestStreamRegistry_UnregisterIdempotent(t *testing.T) { + reg := newStreamRegistry() + + // must not panic when unregistering a non-existent ID + reg.Unregister(99) +} + +func TestStreamRegistry_DuplicateRegister(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(1) + + assert.NoError(t, reg.Register(1, s1)) + assert.Error(t, reg.Register(1, s2)) + + // original stream is still registered + got, ok := reg.Get(1) + assert.That(t, ok) + assert.Equal(t, got, s1) +} + +func TestStreamRegistry_RegisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + reg.Close() + + err := reg.Register(1, testStream(1)) + assert.Error(t, err) +} + +func TestStreamRegistry_UnregisterAfterClose(t *testing.T) { + reg := newStreamRegistry() + s := testStream(1) + assert.NoError(t, reg.Register(1, s)) + + reg.Close() + + // must not panic + reg.Unregister(1) +} + +func TestStreamRegistry_Len(t *testing.T) { + reg := newStreamRegistry() + assert.Equal(t, reg.Len(), 0) + + assert.NoError(t, reg.Register(1, testStream(1))) + 
assert.Equal(t, reg.Len(), 1) + + assert.NoError(t, reg.Register(2, testStream(2))) + assert.Equal(t, reg.Len(), 2) + + reg.Unregister(1) + assert.Equal(t, reg.Len(), 1) +} + +func TestStreamRegistry_ForEach(t *testing.T) { + reg := newStreamRegistry() + s1 := testStream(1) + s2 := testStream(2) + s3 := testStream(3) + + assert.NoError(t, reg.Register(1, s1)) + assert.NoError(t, reg.Register(2, s2)) + assert.NoError(t, reg.Register(3, s3)) + + seen := make(map[uint64]*drpcstream.Stream) + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + seen[id] = s + }) + + assert.Equal(t, len(seen), 3) + assert.Equal(t, seen[1], s1) + assert.Equal(t, seen[2], s2) + assert.Equal(t, seen[3], s3) +} + +func TestStreamRegistry_ForEach_Empty(t *testing.T) { + reg := newStreamRegistry() + + count := 0 + reg.ForEach(func(id uint64, s *drpcstream.Stream) { + count++ + }) + + assert.Equal(t, count, 0) +} diff --git a/drpcserver/server.go b/drpcserver/server.go index 815dc2a..2992488 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -119,7 +119,12 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } man := drpcmanager.NewWithOptions(tr, s.opts.Manager) - defer func() { err = errs.Combine(err, man.Close()) }() + + var wg sync.WaitGroup + defer func() { + wg.Wait() + err = errs.Combine(err, man.Close()) + }() cache := drpccache.New() defer cache.Clear() @@ -131,9 +136,15 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { if err != nil { return errs.Wrap(err) } - if err := s.handleRPC(stream, rpc); err != nil { - return errs.Wrap(err) - } + wg.Add(1) + go func() { + defer wg.Done() + if err := s.handleRPC(stream, rpc); err != nil { + if s.opts.Log != nil { + s.opts.Log(err) + } + } + }() } } diff --git a/drpcstream/pktbuf.go b/drpcstream/pktbuf.go index db68864..fd59d4c 100644 --- a/drpcstream/pktbuf.go +++ b/drpcstream/pktbuf.go @@ -4,16 +4,29 @@ package drpcstream import ( + "io" "sync" ) +// pktBufPool 
recycles byte slices used on the read path to avoid per-packet +// allocations. Buffers flow: AcquirePacketBuf → reader → Put (zero-copy +// ownership transfer) → Get → consumer → recycle → pool. +var pktBufPool = sync.Pool{} + +// AcquirePacketBuf returns a byte slice from the shared packet buffer pool +// for use as a read buffer. Returns nil if the pool is empty. +func AcquirePacketBuf() []byte { + if v := pktBufPool.Get(); v != nil { + return v.([]byte)[:0] + } + return nil +} + type packetBuffer struct { mu sync.Mutex cond sync.Cond err error - data []byte - set bool - held bool + data [][]byte } func (pb *packetBuffer) init() { @@ -24,62 +37,62 @@ func (pb *packetBuffer) Close(err error) { pb.mu.Lock() defer pb.mu.Unlock() - for pb.held { - pb.cond.Wait() - } - if pb.err == nil { - pb.data = nil - pb.set = false + // Preserve already-queued messages on graceful close so readers can + // drain them before seeing EOF. + if err != io.EOF { + for i := range pb.data { + pktBufPool.Put(pb.data[i][:0]) + pb.data[i] = nil + } + pb.data = pb.data[:0] + } pb.err = err pb.cond.Broadcast() } } +// Put takes ownership of data. The caller must not use data after calling Put. +// If the buffer is closed, data is returned to the pool. 
func (pb *packetBuffer) Put(data []byte) { pb.mu.Lock() defer pb.mu.Unlock() - for pb.set && pb.err == nil { - pb.cond.Wait() - } if pb.err != nil { + pktBufPool.Put(data[:0]) return } - pb.data = data - pb.set = true - pb.held = false + pb.data = append(pb.data, data) pb.cond.Broadcast() - - for pb.set || pb.held { - pb.cond.Wait() - } } func (pb *packetBuffer) Get() ([]byte, error) { pb.mu.Lock() defer pb.mu.Unlock() - for !pb.set && pb.err == nil { + for len(pb.data) == 0 && pb.err == nil { pb.cond.Wait() } - if pb.err != nil { + if len(pb.data) == 0 { return nil, pb.err } - pb.held = true - pb.cond.Broadcast() + data := pb.data[0] + n := copy(pb.data, pb.data[1:]) + pb.data[n] = nil + pb.data = pb.data[:n] + return data, nil +} - return pb.data, nil +// recycle returns a buffer obtained from Get back to the pool. +// Call this after the data has been fully consumed (e.g. after Unmarshal). +func (pb *packetBuffer) recycle(buf []byte) { + if buf != nil { + pktBufPool.Put(buf[:0]) + } } func (pb *packetBuffer) Done() { - pb.mu.Lock() - defer pb.mu.Unlock() - - pb.data = nil - pb.set = false - pb.held = false - pb.cond.Broadcast() + // Kept for backward compatibility with stream callers. } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 2f7aabd..7f5a54b 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -11,7 +11,6 @@ import ( "sync" "github.com/zeebo/errs" - "storj.io/drpc" "storj.io/drpc/drpcctx" "storj.io/drpc/drpcdebug" @@ -53,7 +52,7 @@ type Stream struct { flush sync.Once id drpcwire.ID - wr *drpcwire.Writer + wr drpcwire.StreamWriter pbuf packetBuffer wbuf []byte @@ -72,7 +71,7 @@ var _ drpc.Stream = (*Stream)(nil) // New returns a new stream bound to the context with the given stream id and // will use the writer to write messages on. It is important use monotonically // increasing stream ids within a single transport. 
-func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { +func New(ctx context.Context, sid uint64, wr drpcwire.StreamWriter) *Stream { return NewWithOptions(ctx, sid, wr, Options{}) } @@ -80,7 +79,9 @@ func New(ctx context.Context, sid uint64, wr *drpcwire.Writer) *Stream { // stream id and will use the writer to write messages on. It is important use // monotonically increasing stream ids within a single transport. The options // are used to control details of how the Stream operates. -func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts Options) *Stream { +func NewWithOptions( + ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options, +) *Stream { var task *trace.Task if trace.IsEnabled() { kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal) @@ -99,7 +100,7 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O task: task, id: drpcwire.ID{Stream: sid}, - wr: wr.Reset(), + wr: wr, } // initialize the packet buffer @@ -221,17 +222,19 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) - if s.sigs.term.IsSet() { + // Put must always be called for message packets to manage buffer + // ownership. Put returns the buffer to the pool if the stream is closed. + if pkt.Kind == drpcwire.KindMessage { + s.pbuf.Put(pkt.Data) return nil } - s.log("HANDLE", pkt.String) - - if pkt.Kind == drpcwire.KindMessage { - s.pbuf.Put(pkt.Data) + if s.sigs.term.IsSet() { return nil } + s.log("HANDLE", pkt.String) + s.mu.Lock() defer s.mu.Unlock() @@ -310,26 +313,28 @@ func (s *Stream) checkCancelError(err error) error { return err } -// newFrameLocked bumps the internal message id and returns a frame. It must be +// nextID bumps the internal message id and returns the new ID. It must be // called under a mutex. 
-func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame { +func (s *Stream) nextID() drpcwire.ID { s.id.Message++ - return drpcwire.Frame{ID: s.id, Kind: kind} + return s.id } // sendPacketLocked sends the packet in a single write and flushes. It does not // check for any conditions to stop it from writing and is meant for internal // stream use to do things like signal errors or closes to the remote side. func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) { - fr := s.newFrameLocked(kind) - fr.Data = data - fr.Control = control - fr.Done = true + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + Control: control, + } drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) - s.log("SEND", fr.String) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { + if err := s.wr.WritePacket(pkt); err != nil { return errs.Wrap(err) } if err := s.wr.Flush(); err != nil { @@ -372,29 +377,26 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. 
func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { - fr := s.newFrameLocked(kind) - n := s.opts.SplitSize - - for { - switch { - case s.sigs.send.IsSet(): - return s.sigs.send.Err() - case s.sigs.term.IsSet(): - return s.sigs.term.Err() - } + switch { + case s.sigs.send.IsSet(): + return s.sigs.send.Err() + case s.sigs.term.IsSet(): + return s.sigs.term.Err() + } - fr.Data, data = drpcwire.SplitData(data, n) - fr.Done = len(data) == 0 + pkt := drpcwire.Packet{ + ID: s.nextID(), + Kind: kind, + Data: data, + } - drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(fr.Data))) - s.log("SEND", fr.String) + drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data))) + s.log("SEND", pkt.String) - if err := s.wr.WriteFrame(fr); err != nil { - return s.checkCancelError(errs.Wrap(err)) - } else if fr.Done { - return nil - } + if err := s.wr.WritePacket(pkt); err != nil { + return s.checkCancelError(errs.Wrap(err)) } + return nil } // RawFlush flushes any buffers of data. @@ -457,8 +459,7 @@ func (s *Stream) RawRecv() (data []byte, err error) { if err != nil { return nil, err } - data = append([]byte(nil), data...) - s.pbuf.Done() + // Transfer buffer ownership to the caller without recycling. 
return data, nil } @@ -510,7 +511,7 @@ func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) { return err } err = enc.Unmarshal(data, msg) - s.pbuf.Done() + s.pbuf.recycle(data) return err } diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac397..ef801d3 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -143,7 +143,10 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { pkt.Control = pkt.Control || fr.Control switch { - case fr.ID.Less(r.id): + case fr.ID.Stream == 0 || fr.ID.Message == 0: + return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) + + case fr.ID.Stream == r.id.Stream && fr.ID.Less(r.id): return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) case r.id != fr.ID || pkt.ID == ID{}: diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b..ec0cd3f 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -154,6 +154,51 @@ func TestReader(t *testing.T) { Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, Error: "id monotonicity violation", }, + + { // cross-stream: frames from a lower stream after a higher stream are allowed (multiplexing) + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 2}, Kind: KindClose, Done: true}, + }, + }, + + { // cross-stream: interleaved single-frame packets from multiple streams + Packets: []Packet{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke}, + 
{Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage}, + {Data: nil, ID: ID{Stream: 1, Message: 3}, Kind: KindCloseSend}, + {Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend}, + }, + Frames: []Frame{ + {Data: []byte("s1m1"), ID: ID{Stream: 1, Message: 1}, Kind: KindInvoke, Done: true}, + {Data: []byte("s2m1"), ID: ID{Stream: 2, Message: 1}, Kind: KindInvoke, Done: true}, + {Data: []byte("s1m2"), ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true}, + {Data: []byte("s2m2"), ID: ID{Stream: 2, Message: 2}, Kind: KindMessage, Done: true}, + {Data: nil, ID: ID{Stream: 1, Message: 3}, Kind: KindCloseSend, Done: true}, + {Data: nil, ID: ID{Stream: 2, Message: 3}, Kind: KindCloseSend, Done: true}, + }, + }, + + { // cross-stream: within-stream monotonicity is still enforced + Packets: []Packet{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage}, + }, + Frames: []Frame{ + {Data: []byte("a"), ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("b"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, + {Data: []byte("c"), ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Done: true}, // replay on stream 2 + }, + Error: "id monotonicity violation", + }, } for _, tc := range cases { diff --git a/drpcwire/writer.go b/drpcwire/writer.go index fe909cf..bfcf0ac 100644 --- a/drpcwire/writer.go +++ b/drpcwire/writer.go @@ -12,6 +12,15 @@ import ( "storj.io/drpc/drpcdebug" ) +// StreamWriter is the interface used by streams for writing packets. Each call +// to WritePacket must serialize the full data atomically so that frames from +// concurrent writers on different streams do not interleave on the wire. 
+type StreamWriter interface { + WritePacket(pkt Packet) error + Flush() error + Empty() bool +} + // // Writer // diff --git a/go.mod b/go.mod index a66c0f9..61bc019 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,16 @@ require ( github.com/zeebo/errs v1.4.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 google.golang.org/grpc v1.57.2 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.33.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.33.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e9d3c41..1201203 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -6,20 +7,30 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod 
h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/zeebo/assert v1.3.1 h1:vukIABvugfNMZMQO1ABsyQDJDTVQbn+LWSMy1ol1h6A= github.com/zeebo/assert v1.3.1/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/xerrors 
v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= @@ -27,9 +38,10 @@ google.golang.org/grpc v1.57.2 h1:uw37EN34aMFFXB2QPW7Tq6tdTbind1GpRxw5aOX3a5k= google.golang.org/grpc v1.57.2/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/integration/Makefile b/internal/integration/Makefile new file mode 100644 index 0000000..4430126 --- /dev/null +++ b/internal/integration/Makefile @@ -0,0 +1,68 @@ +GO ?= go +COUNT ?= 1 +BENCH ?= . 
+RUN ?= ^$$ +TIMEOUT ?= 300s +PROFILES_DIR ?= /tmp/drpc-profiles +CPU_PROFILE ?= $(PROFILES_DIR)/cpu.pprof +MEM_PROFILE ?= $(PROFILES_DIR)/mem.pprof +BLOCK_PROFILE ?= $(PROFILES_DIR)/block.pprof +MUTEX_PROFILE ?= $(PROFILES_DIR)/mutex.pprof +TRACE_PROFILE ?= $(PROFILES_DIR)/trace.out +BLOCK_PROFILE_RATE ?= 1 +MEM_PROFILE_RATE ?= 4096 +MUTEX_PROFILE_FRACTION ?= 1 + +# Run benchmarks once +.PHONY: bench +bench: + $(GO) test -bench=$(BENCH) -benchmem -count=$(COUNT) -timeout=$(TIMEOUT) . + +# Run benchmarks with count=10 for use with benchstat. +# Usage: make bench-stat > before.txt (then after changes) make bench-stat > after.txt +# benchstat before.txt after.txt +.PHONY: bench-stat +bench-stat: + $(GO) test -bench=$(BENCH) -benchmem -count=10 -timeout=$(TIMEOUT) . + +.PHONY: bench-parallel-streams +bench-parallel-streams: + $(GO) test -run=$(RUN) -bench=BenchmarkParallelActiveStreams -benchmem -count=$(COUNT) -timeout=$(TIMEOUT) . + +# Collect benchmark profiles in one run. +# Example: +# make bench-profile BENCH='BenchmarkParallelActiveStreams/size=1KB' COUNT=1 +.PHONY: bench-profile +bench-profile: + mkdir -p $(PROFILES_DIR) + $(GO) test -run=$(RUN) -bench=$(BENCH) -benchmem -count=$(COUNT) -timeout=$(TIMEOUT) \ + -cpuprofile=$(CPU_PROFILE) \ + -memprofile=$(MEM_PROFILE) -memprofilerate=$(MEM_PROFILE_RATE) \ + -blockprofile=$(BLOCK_PROFILE) -blockprofilerate=$(BLOCK_PROFILE_RATE) \ + -mutexprofile=$(MUTEX_PROFILE) -mutexprofilefraction=$(MUTEX_PROFILE_FRACTION) \ + -trace=$(TRACE_PROFILE) . 
+ +.PHONY: profile-top-cpu +profile-top-cpu: + $(GO) tool pprof -top $(CPU_PROFILE) + +.PHONY: profile-top-alloc-space +profile-top-alloc-space: + $(GO) tool pprof -top -sample_index=alloc_space $(MEM_PROFILE) + +.PHONY: profile-top-inuse-space +profile-top-inuse-space: + $(GO) tool pprof -top -sample_index=inuse_space $(MEM_PROFILE) + +.PHONY: profile-top-block +profile-top-block: + $(GO) tool pprof -top $(BLOCK_PROFILE) + +.PHONY: profile-top-mutex +profile-top-mutex: + $(GO) tool pprof -top $(MUTEX_PROFILE) + +.PHONY: gen-proto +gen-proto: + $(MAKE) -C ../.. install-protoc-plugin + $(GO) generate . diff --git a/internal/integration/benchmark_test.go b/internal/integration/benchmark_test.go new file mode 100644 index 0000000..e4edf8f --- /dev/null +++ b/internal/integration/benchmark_test.go @@ -0,0 +1,338 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package integration + +import ( + "context" + "errors" + "io" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "google.golang.org/protobuf/proto" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpctest" +) + +var echoServer = impl{ + Method1Fn: func(_ context.Context, in *In) (*Out, error) { + return &Out{Out: in.In, Data: in.Data}, nil + }, + + Method2Fn: func(stream DRPCService_Method2Stream) error { + var last *In + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + last = in + } + if last == nil { + last = &In{} + } + return stream.SendAndClose(&Out{Out: last.In, Data: last.Data}) + }, + + Method3Fn: func(in *In, stream DRPCService_Method3Stream) error { + out := &Out{Out: 1, Data: in.Data} + for i := int64(0); i < in.In; i++ { + if err := stream.Send(out); err != nil { + return err + } + } + return nil + }, + + Method4Fn: func(stream DRPCService_Method4Stream) error { + for { + in, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + if err := 
stream.Send(&Out{Out: in.In, Data: in.Data}); err != nil { + return err + } + } + }, +} + +var sizes = []struct { + name string + data []byte +}{ + {"small", nil}, + {"1KB", make([]byte, 1024)}, + {"8KB", make([]byte, 8192)}, +} + +var concurrencies = []int{1, 10, 100} +var activeStreamCounts = []int{2, 8, 32} + +const parallelActiveStreamsWriteLatency = 200 * time.Microsecond + +func BenchmarkUnary(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, c := range concurrencies { + b.Run("concurrent="+strconv.Itoa(c), func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.SetParallelism(c) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := client.Method1(context.Background(), in); err != nil { + b.Error(err) + } + } + }) + }) + } + }) + } +} + +func BenchmarkInputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method2(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + } + + if _, err := stream.CloseAndRecv(); err != nil { + b.Fatal(err) + } + }) + } +} + +func BenchmarkOutputStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: int64(b.N), Data: sz.data} + stream, err := client.Method3(context.Background(), in) + if err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + b.ResetTimer() + + var last *Out + for i := 0; i < b.N; i++ { + last, err = stream.Recv() + if err != 
nil { + b.Fatal(err) + } + } + if last != nil { + b.SetBytes(int64(proto.Size(last))) + } + }) + } +} + +func BenchmarkBidiStream(b *testing.B) { + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + client, cleanup := createConnection(b, echoServer) + defer cleanup() + + in := &In{In: 1, Data: sz.data} + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatal(err) + } + + b.SetBytes(int64(proto.Size(in))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := stream.Send(in); err != nil { + b.Fatal(err) + } + if _, err := stream.Recv(); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkParallelActiveStreams(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, 0) +} + +func BenchmarkParallelActiveStreamsWriteLatency(b *testing.B) { + benchmarkParallelActiveStreamsModes(b, parallelActiveStreamsWriteLatency) +} + +func benchmarkParallelActiveStreamsModes(b *testing.B, writeLatency time.Duration) { + modes := []struct { + name string + oneConnPerStream bool + useTCP bool + }{ + {name: "mux_single_connection", oneConnPerStream: false}, + {name: "one_connection_per_stream", oneConnPerStream: true}, + {name: "mux_tcp", oneConnPerStream: false, useTCP: true}, + {name: "one_conn_tcp", oneConnPerStream: true, useTCP: true}, + } + + for _, sz := range sizes { + b.Run("size="+sz.name, func(b *testing.B) { + for _, streamCount := range activeStreamCounts { + b.Run("streams="+strconv.Itoa(streamCount), func(b *testing.B) { + for _, mode := range modes { + b.Run(mode.name, func(b *testing.B) { + benchmarkParallelActiveStreams(b, sz.data, streamCount, mode.oneConnPerStream, writeLatency, mode.useTCP) + }) + } + }) + } + }) + } +} + +func benchmarkParallelActiveStreams( + b *testing.B, payload []byte, streamCount int, oneConnPerStream bool, writeLatency time.Duration, useTCP bool, +) { + ctx := drpctest.NewTracker(b) + defer ctx.Close() + + type bidiClient interface { + Send(*In) error + 
Recv() (*Out, error) + Close() error + } + + type worker struct { + stream bidiClient + close func() + } + workers := make([]worker, 0, streamCount) + + createConn := createRawConnectionWithWriteDelay + if useTCP { + createConn = createTCPConnectionWithWriteDelay + } + + var sharedConn *drpcconn.Conn + if !oneConnPerStream { + sharedConn = createConn(b, echoServer, ctx, writeLatency) + } + + for i := 0; i < streamCount; i++ { + conn := sharedConn + if oneConnPerStream { + conn = createConn(b, echoServer, ctx, writeLatency) + } + + client := NewDRPCServiceClient(conn) + stream, err := client.Method4(context.Background()) + if err != nil { + b.Fatalf("create stream %d: %v", i, err) + } + + s := stream + c := conn + workers = append(workers, worker{ + stream: s, + close: func() { + _ = s.Close() + if oneConnPerStream { + _ = c.Close() + } + }, + }) + } + + input := &In{In: 1, Data: payload} + b.SetBytes(int64(proto.Size(input))) + b.ReportAllocs() + + var sent atomic.Int64 + errCh := make(chan error, 1) + start := make(chan struct{}) + var wg sync.WaitGroup + + b.ResetTimer() + for _, w := range workers { + wg.Add(1) + go func(w worker) { + defer wg.Done() + msg := &In{In: 1, Data: payload} + <-start + + for { + n := sent.Add(1) + if n > int64(b.N) { + return + } + if err := w.stream.Send(msg); err != nil { + select { + case errCh <- err: + default: + } + return + } + if _, err := w.stream.Recv(); err != nil { + select { + case errCh <- err: + default: + } + return + } + } + }(w) + } + close(start) + wg.Wait() + b.StopTimer() + + select { + case err := <-errCh: + b.Fatal(err) + default: + } + + for _, w := range workers { + w.close() + } + if sharedConn != nil { + _ = sharedConn.Close() + } +} diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 5acb61a..7780958 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -10,6 +10,7 @@ import ( "strconv" "sync" "testing" + "time" 
"github.com/zeebo/assert" "google.golang.org/grpc/codes" @@ -40,18 +41,86 @@ func in(n int64) *In { return &In{In: n} } func out(n int64) *Out { return &Out{Out: n} } func createRawConnection(t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker) *drpcconn.Conn { - c1, c2 := net.Pipe() + return createRawConnectionWithWriteDelay(t, server, ctx, 0) +} + +func createRawConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, writeDelay, false) +} + +func createTCPConnectionWithWriteDelay( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, writeDelay time.Duration, +) *drpcconn.Conn { + return createConnectionWithTransport(t, server, ctx, writeDelay, true) +} + +func createConnectionWithTransport( + t testing.TB, server DRPCServiceServer, ctx *drpctest.Tracker, + writeDelay time.Duration, useTCP bool, +) *drpcconn.Conn { + var serverConn, clientConn net.Conn + if useTCP { + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { lis.Close() }) + + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := lis.Accept() + ch <- result{c, err} + }() + + c, err := net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatal(err) + } + r := <-ch + if r.err != nil { + t.Fatal(r.err) + } + serverConn, clientConn = r.conn, c + t.Cleanup(func() { + serverConn.Close() + clientConn.Close() + }) + } else { + c1, c2 := net.Pipe() + serverConn, clientConn = c1, c2 + } + + if writeDelay > 0 { + serverConn = &writeLatencyConn{Conn: serverConn, writeDelay: writeDelay} + clientConn = &writeLatencyConn{Conn: clientConn, writeDelay: writeDelay} + } mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, server)) srv := drpcserver.New(mux) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - return 
drpcconn.NewWithOptions(c2, drpcconn.Options{ + ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, serverConn) }) + return drpcconn.NewWithOptions(clientConn, drpcconn.Options{ Manager: drpcmanager.Options{ SoftCancel: true, }, }) } +type writeLatencyConn struct { + net.Conn + writeDelay time.Duration +} + +func (c *writeLatencyConn) Write(p []byte) (int, error) { + time.Sleep(c.writeDelay) + return c.Conn.Write(p) +} + func createConnection(t testing.TB, server DRPCServiceServer) (DRPCServiceClient, func()) { ctx := drpctest.NewTracker(t) conn := createRawConnection(t, server, ctx) diff --git a/internal/integration/customservice/service.pb.go b/internal/integration/customservice/service.pb.go index 251b45f..336bb29 100644 --- a/internal/integration/customservice/service.pb.go +++ b/internal/integration/customservice/service.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v4.24.4 +// protoc-gen-go v1.36.5 +// protoc v5.29.3 // source: service.proto package service @@ -14,6 +14,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -24,21 +25,18 @@ const ( ) type In struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *In) Reset() { *x = In{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[0] + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *In) String() string { @@ -49,7 +47,7 @@ func (*In) ProtoMessage() {} func (x *In) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -79,21 +77,18 @@ func (x *In) GetData() []byte { } type Out struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *Out) Reset() { *x = Out{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Out) String() string { @@ -104,7 +99,7 @@ func (*Out) ProtoMessage() {} func (x *Out) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -135,7 +130,7 @@ func (x *Out) GetData() []byte { var File_service_proto protoreflect.FileDescriptor -var file_service_proto_rawDesc = []byte{ +var file_service_proto_rawDesc = string([]byte{ 0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x28, 0x0a, 
0x02, 0x49, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x6e, 0x12, 0x12, @@ -158,22 +153,22 @@ var file_service_proto_rawDesc = []byte{ 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +}) var ( file_service_proto_rawDescOnce sync.Once - file_service_proto_rawDescData = file_service_proto_rawDesc + file_service_proto_rawDescData []byte ) func file_service_proto_rawDescGZIP() []byte { file_service_proto_rawDescOnce.Do(func() { - file_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_service_proto_rawDescData) + file_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), len(file_service_proto_rawDesc))) }) return file_service_proto_rawDescData } var file_service_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_service_proto_goTypes = []interface{}{ +var file_service_proto_goTypes = []any{ (*In)(nil), // 0: service.In (*Out)(nil), // 1: service.Out } @@ -198,37 +193,11 @@ func file_service_proto_init() { if File_service_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*In); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Out); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_service_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), 
len(file_service_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, @@ -239,7 +208,6 @@ func file_service_proto_init() { MessageInfos: file_service_proto_msgTypes, }.Build() File_service_proto = out.File - file_service_proto_rawDesc = nil file_service_proto_goTypes = nil file_service_proto_depIdxs = nil } diff --git a/internal/integration/customservice/service_drpc.pb.go b/internal/integration/customservice/service_drpc.pb.go index 9991940..7ae7bcd 100644 --- a/internal/integration/customservice/service_drpc.pb.go +++ b/internal/integration/customservice/service_drpc.pb.go @@ -1,12 +1,12 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: (devel) +// protoc-gen-go-drpc version: v0.0.35-0.20260221163524-0f9d4319a40d+dirty // source: service.proto package service import ( context "context" - errors "errors" + errors "github.com/cockroachdb/errors" drpc "storj.io/drpc" drpcerr "storj.io/drpc/drpcerr" customencoding "storj.io/drpc/internal/integration/customencoding" @@ -280,6 +280,7 @@ type DRPCService_Method2Stream interface { drpc.Stream SendAndClose(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method2Stream struct { @@ -305,7 +306,7 @@ func (x *drpcService_Method2Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method2Stream) RecvMsg(m *In) error { +func (x *drpcService_Method2Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } @@ -330,6 +331,7 @@ type DRPCService_Method4Stream interface { drpc.Stream Send(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method4Stream struct { @@ -352,6 +354,6 @@ func (x *drpcService_Method4Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method4Stream) RecvMsg(m *In) error { +func (x *drpcService_Method4Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } diff --git 
a/internal/integration/doc.go b/internal/integration/doc.go index d23a99f..4f2c23a 100644 --- a/internal/integration/doc.go +++ b/internal/integration/doc.go @@ -4,6 +4,6 @@ // Package integration holds integration tests for drpc. package integration -//go:generate protoc --go_out=paths=source_relative:service/. --go-drpc_out=paths=source_relative:service/. service.proto -//go:generate protoc --gogo_out=paths=source_relative:gogoservice/. --go-drpc_out=paths=source_relative,protolib=github.com/gogo/protobuf:gogoservice/. service.proto -//go:generate protoc --go_out=paths=source_relative:customservice/. --go-drpc_out=paths=source_relative,protolib=storj.io/drpc/internal/integration/customencoding:customservice/. service.proto +//go:generate protoc --go_out=paths=source_relative:service/. --go-drpc_out=paths=source_relative,generate-adapters=false:service/. service.proto +//go:generate protoc --gogo_out=paths=source_relative:gogoservice/. --go-drpc_out=paths=source_relative,protolib=github.com/gogo/protobuf,generate-adapters=false:gogoservice/. service.proto +//go:generate protoc --go_out=paths=source_relative:customservice/. --go-drpc_out=paths=source_relative,protolib=storj.io/drpc/internal/integration/customencoding,generate-adapters=false:customservice/. 
service.proto diff --git a/internal/integration/go.mod b/internal/integration/go.mod index df01970..efe284e 100644 --- a/internal/integration/go.mod +++ b/internal/integration/go.mod @@ -3,6 +3,7 @@ module storj.io/drpc/internal/integration go 1.25.0 require ( + github.com/cockroachdb/errors v1.12.0 github.com/gogo/protobuf v1.3.2 github.com/zeebo/assert v1.3.1 github.com/zeebo/errs v1.4.0 @@ -13,7 +14,16 @@ require ( ) require ( + github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect + github.com/cockroachdb/redact v1.1.5 // indirect + github.com/getsentry/sentry-go v0.27.0 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect ) diff --git a/internal/integration/go.sum b/internal/integration/go.sum index d06e526..79c4117 100644 --- a/internal/integration/go.sum +++ b/internal/integration/go.sum @@ -1,5 +1,16 @@ +github.com/cockroachdb/errors v1.12.0 h1:d7oCs6vuIMUQRVbi6jWWWEJZahLCfJpnJSVobd1/sUo= +github.com/cockroachdb/errors v1.12.0/go.mod h1:SvzfYNNBshAVbZ8wzNc/UPK3w1vf0dKDUP41ucAIf7g= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= +github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30= +github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew 
v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= +github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= +github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= +github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -10,8 +21,19 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/internal/integration/gogoservice/service_drpc.pb.go b/internal/integration/gogoservice/service_drpc.pb.go index 1a61701..b65e6f2 100644 --- a/internal/integration/gogoservice/service_drpc.pb.go +++ b/internal/integration/gogoservice/service_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: (devel) +// protoc-gen-go-drpc version: v0.0.35-0.20260221163524-0f9d4319a40d+dirty // source: service.proto package service @@ -7,7 +7,7 @@ package service import ( bytes "bytes" context "context" - errors "errors" + errors "github.com/cockroachdb/errors" jsonpb "github.com/gogo/protobuf/jsonpb" proto "github.com/gogo/protobuf/proto" drpc "storj.io/drpc" @@ -287,6 +287,7 @@ type DRPCService_Method2Stream interface { drpc.Stream SendAndClose(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method2Stream struct { @@ -312,7 +313,7 @@ func (x *drpcService_Method2Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method2Stream) RecvMsg(m *In) error { +func (x *drpcService_Method2Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } @@ -337,6 +338,7 @@ type DRPCService_Method4Stream interface { drpc.Stream Send(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method4Stream struct { @@ -359,6 +361,6 @@ func (x *drpcService_Method4Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method4Stream) RecvMsg(m *In) error { +func (x 
*drpcService_Method4Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } diff --git a/internal/integration/service/service.pb.go b/internal/integration/service/service.pb.go index 251b45f..336bb29 100644 --- a/internal/integration/service/service.pb.go +++ b/internal/integration/service/service.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v4.24.4 +// protoc-gen-go v1.36.5 +// protoc v5.29.3 // source: service.proto package service @@ -14,6 +14,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -24,21 +25,18 @@ const ( ) type In struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - In int64 `protobuf:"varint,1,opt,name=in,proto3" json:"in,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *In) Reset() { *x = In{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *In) String() string { @@ -49,7 +47,7 @@ func (*In) ProtoMessage() {} func (x *In) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -79,21 +77,18 @@ func (x *In) GetData() []byte { } type Out struct { - state protoimpl.MessageState - 
sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - Out int64 `protobuf:"varint,1,opt,name=out,proto3" json:"out,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *Out) Reset() { *x = Out{} - if protoimpl.UnsafeEnabled { - mi := &file_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Out) String() string { @@ -104,7 +99,7 @@ func (*Out) ProtoMessage() {} func (x *Out) ProtoReflect() protoreflect.Message { mi := &file_service_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -135,7 +130,7 @@ func (x *Out) GetData() []byte { var File_service_proto protoreflect.FileDescriptor -var file_service_proto_rawDesc = []byte{ +var file_service_proto_rawDesc = string([]byte{ 0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x28, 0x0a, 0x02, 0x49, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x6e, 0x12, 0x12, @@ -158,22 +153,22 @@ var file_service_proto_rawDesc = []byte{ 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +}) var ( file_service_proto_rawDescOnce sync.Once - file_service_proto_rawDescData = file_service_proto_rawDesc + 
file_service_proto_rawDescData []byte ) func file_service_proto_rawDescGZIP() []byte { file_service_proto_rawDescOnce.Do(func() { - file_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_service_proto_rawDescData) + file_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), len(file_service_proto_rawDesc))) }) return file_service_proto_rawDescData } var file_service_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_service_proto_goTypes = []interface{}{ +var file_service_proto_goTypes = []any{ (*In)(nil), // 0: service.In (*Out)(nil), // 1: service.Out } @@ -198,37 +193,11 @@ func file_service_proto_init() { if File_service_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*In); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Out); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_service_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_service_proto_rawDesc), len(file_service_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, @@ -239,7 +208,6 @@ func file_service_proto_init() { MessageInfos: file_service_proto_msgTypes, }.Build() File_service_proto = out.File - file_service_proto_rawDesc = nil file_service_proto_goTypes = nil file_service_proto_depIdxs = nil } diff --git a/internal/integration/service/service_drpc.pb.go b/internal/integration/service/service_drpc.pb.go index 7287578..0b859e6 100644 --- 
a/internal/integration/service/service_drpc.pb.go +++ b/internal/integration/service/service_drpc.pb.go @@ -1,12 +1,12 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: (devel) +// protoc-gen-go-drpc version: v0.0.35-0.20260221163524-0f9d4319a40d+dirty // source: service.proto package service import ( context "context" - errors "errors" + errors "github.com/cockroachdb/errors" protojson "google.golang.org/protobuf/encoding/protojson" proto "google.golang.org/protobuf/proto" drpc "storj.io/drpc" @@ -285,6 +285,7 @@ type DRPCService_Method2Stream interface { drpc.Stream SendAndClose(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method2Stream struct { @@ -310,7 +311,7 @@ func (x *drpcService_Method2Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method2Stream) RecvMsg(m *In) error { +func (x *drpcService_Method2Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) } @@ -335,6 +336,7 @@ type DRPCService_Method4Stream interface { drpc.Stream Send(*Out) error Recv() (*In, error) + RecvMsg(interface{}) error } type drpcService_Method4Stream struct { @@ -357,6 +359,6 @@ func (x *drpcService_Method4Stream) Recv() (*In, error) { return m, nil } -func (x *drpcService_Method4Stream) RecvMsg(m *In) error { +func (x *drpcService_Method4Stream) RecvMsg(m interface{}) error { return x.MsgRecv(m, drpcEncoding_File_service_proto{}) }