Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ cmd/protoc-gen-go-drpc/protoc-gen-go-drpc
/WORKSPACE
BUILD.bazel
MODULE.bazel*
debug/*
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This fork of DRPC library is maintained and used by [CockroachDB](https://github.com/cockroachdb/cockroach) and is customized for CockroachDB's needs.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ lint:
staticcheck $(PKG)
golangci-lint run

.PHONY: install-protoc-plugin
install-protoc-plugin:
$(GO) install ./cmd/protoc-gen-go-drpc/

.PHONY: gen-bazel
gen-bazel:
@echo "Generating WORKSPACE"
Expand Down
32 changes: 20 additions & 12 deletions cmd/protoc-gen-go-drpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ import (
)

type config struct {
protolib string
json bool
protolib string
json bool
generateAdapters bool
}

func main() {
var flags flag.FlagSet
var conf config
flags.StringVar(&conf.protolib, "protolib", "google.golang.org/protobuf", "which protobuf library to use for encoding")
flags.BoolVar(&conf.json, "json", true, "generate encoders with json support")
flags.BoolVar(&conf.generateAdapters, "generate-adapters", true, "generate gRPC/DRPC adapter and RPC interface code")

protogen.Options{
ParamFunc: flags.Set,
Expand Down Expand Up @@ -55,7 +57,7 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File, conf config) {

d.generateEncoding(conf)
for _, service := range file.Services {
d.generateService(service)
d.generateService(service, conf)
}
}

Expand Down Expand Up @@ -267,7 +269,7 @@ func (d *drpc) generateEncoding(conf config) {
// service generation
//

func (d *drpc) generateService(service *protogen.Service) {
func (d *drpc) generateService(service *protogen.Service, conf config) {
// Client interface
d.P("type ", d.ClientIface(service), " interface {")
d.P("DRPCConn() ", d.Ident("storj.io/drpc", "Conn"))
Expand All @@ -294,7 +296,7 @@ func (d *drpc) generateService(service *protogen.Service) {
d.P("func (c *", d.ClientImpl(service), ") DRPCConn() ", d.Ident("storj.io/drpc", "Conn"), "{ return c.cc }")
d.P()
for _, method := range service.Methods {
d.generateClientMethod(method)
d.generateClientMethod(method, conf)
}

// Server interface
Expand Down Expand Up @@ -339,11 +341,13 @@ func (d *drpc) generateService(service *protogen.Service) {

// Server methods
for _, method := range service.Methods {
d.generateServerMethod(method)
d.generateServerMethod(method, conf)
}

d.generateServiceRPCInterfaces(service)
d.generateServiceAdapters(service)
if conf.generateAdapters {
d.generateServiceRPCInterfaces(service)
d.generateServiceAdapters(service)
}
}

//
Expand All @@ -362,7 +366,7 @@ func (d *drpc) generateClientSignature(method *protogen.Method) string {
return fmt.Sprintf("%s(ctx %s%s) (%s, error)", method.GoName, d.Ident("context", "Context"), reqArg, respName)
}

func (d *drpc) generateClientMethod(method *protogen.Method) {
func (d *drpc) generateClientMethod(method *protogen.Method, conf config) {
recvType := d.ClientImpl(method.Parent)
outType := d.OutputType(method)
inType := d.InputType(method)
Expand Down Expand Up @@ -408,7 +412,9 @@ func (d *drpc) generateClientMethod(method *protogen.Method) {
d.P("}")
d.P()

d.generateRPCClientInterface(method)
if conf.generateAdapters {
d.generateRPCClientInterface(method)
}

d.P("type ", d.ClientStreamImpl(method), " struct {")
d.P(d.Ident("storj.io/drpc", "Stream"))
Expand Down Expand Up @@ -510,7 +516,7 @@ func (d *drpc) generateServerReceiver(method *protogen.Method) {
d.P(")")
}

func (d *drpc) generateServerMethod(method *protogen.Method) {
func (d *drpc) generateServerMethod(method *protogen.Method, conf config) {
genSend := method.Desc.IsStreamingServer()
genSendAndClose := !method.Desc.IsStreamingServer()
genRecv := method.Desc.IsStreamingClient()
Expand All @@ -531,7 +537,9 @@ func (d *drpc) generateServerMethod(method *protogen.Method) {
d.P("}")
d.P()

d.generateRPCServerInterface(method)
if conf.generateAdapters {
d.generateRPCServerInterface(method)
}

d.P("type ", d.ServerStreamImpl(method), " struct {")
d.P(d.Ident("storj.io/drpc", "Stream"))
Expand Down
30 changes: 12 additions & 18 deletions drpcconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ type Options struct {

// Conn is a drpc client connection.
type Conn struct {
tr drpc.Transport
man *drpcmanager.Manager
mu sync.Mutex
wbuf []byte
tr drpc.Transport
man *drpcmanager.Manager
mu sync.Mutex // protects stats

stats map[string]*drpcstats.Stats
}
Expand Down Expand Up @@ -92,16 +91,16 @@ func (c *Conn) Transport() drpc.Transport { return c.tr }
// Closed returns a channel that is closed once the connection is closed.
func (c *Conn) Closed() <-chan struct{} { return c.man.Closed() }

// Unblocked returns a channel that is closed once the connection is no longer
// blocked by a previously canceled Invoke or NewStream call. It should not
// be called concurrently with Invoke or NewStream.
// Unblocked returns a channel that is closed when the connection is available
// for new streams. With multiplexing enabled, this always returns an
// already-closed channel.
func (c *Conn) Unblocked() <-chan struct{} { return c.man.Unblocked() }

// Close closes the connection.
func (c *Conn) Close() (err error) { return c.man.Close() }

// Invoke issues the rpc on the transport serializing in, waits for a response, and
// deserializes it into out. Only one Invoke or Stream may be open at a time.
// deserializes it into out. Multiple Invoke or Stream calls may be open concurrently.
func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) {
defer func() { err = drpc.ToRPCErr(err) }()

Expand All @@ -117,18 +116,13 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou
}
defer func() { err = errs.Combine(err, stream.Close()) }()

// we have to protect c.wbuf here even though the manager only allows one
// stream at a time because the stream may async close allowing another
// concurrent call to Invoke to proceed.
c.mu.Lock()
defer c.mu.Unlock()

c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0])
// Per-call buffer allocation for concurrent access.
data, err := drpcenc.MarshalAppend(in, enc, nil)
if err != nil {
return err
}

if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil {
if err := c.doInvoke(stream, enc, rpc, data, metadata, out); err != nil {
return err
}
return nil
Expand All @@ -155,8 +149,8 @@ func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string
return nil
}

// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may
// be open at a time.
// NewStream begins a streaming rpc on the connection. Multiple Invoke or Stream calls
// may be open concurrently.
func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) {
defer func() { err = drpc.ToRPCErr(err) }()

Expand Down
5 changes: 4 additions & 1 deletion drpcconn/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) {
})
s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{})
assert.NoError(t, err)
_ = s.CloseSend()

assert.NoError(t, s.CloseSend())

ctx.Wait()
}

func TestConn_encodeMetadata(t *testing.T) {
Expand Down
82 changes: 82 additions & 0 deletions drpcmanager/frame_queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (C) 2026 Cockroach Labs.
// See LICENSE for copying information.

package drpcmanager

import (
"testing"

"github.com/zeebo/assert"
"storj.io/drpc/drpcwire"
)

func TestSharedWriteBuf_AppendDrain(t *testing.T) {
sw := newSharedWriteBuf()

pkt := drpcwire.Packet{
Data: []byte("hello"),
ID: drpcwire.ID{Stream: 1, Message: 2},
Kind: drpcwire.KindMessage,
}

assert.NoError(t, sw.Append(pkt))

// Drain should return serialized bytes.
data := sw.Drain(nil)
assert.That(t, len(data) > 0)

// Parse the frame back out to verify correctness.
_, got, ok, err := drpcwire.ParseFrame(data)
assert.NoError(t, err)
assert.That(t, ok)
assert.DeepEqual(t, got.Data, pkt.Data)
assert.Equal(t, got.ID.Stream, pkt.ID.Stream)
assert.Equal(t, got.ID.Message, pkt.ID.Message)
assert.Equal(t, got.Kind, pkt.Kind)
assert.Equal(t, got.Done, true)
}

func TestSharedWriteBuf_CloseIdempotent(t *testing.T) {
sw := newSharedWriteBuf()
sw.Close()
sw.Close() // must not panic
}

func TestSharedWriteBuf_AppendAfterClose(t *testing.T) {
sw := newSharedWriteBuf()
sw.Close()

err := sw.Append(drpcwire.Packet{})
assert.Error(t, err)
}

func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) {
sw := newSharedWriteBuf()

done := make(chan struct{})
go func() {
defer close(done)
data, ok := sw.WaitAndDrain(nil)
assert.That(t, ok)
assert.That(t, len(data) > 0)
}()

// Append should wake the blocked WaitAndDrain.
assert.NoError(t, sw.Append(drpcwire.Packet{Data: []byte("a")}))
<-done
}

func TestSharedWriteBuf_WaitAndDrainCloseEmpty(t *testing.T) {
sw := newSharedWriteBuf()

done := make(chan struct{})
go func() {
defer close(done)
_, ok := sw.WaitAndDrain(nil)
assert.That(t, !ok)
}()

// Close on empty buffer should return ok=false.
sw.Close()
<-done
}
Loading
Loading