Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jsonrpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type pendingRequest struct {
type Client struct {
opts *Options
conn net.Conn
writeMu sync.Mutex
msgID int
msgIDMu sync.Mutex
pendingRequests map[int]pendingRequest
Expand Down Expand Up @@ -188,8 +189,10 @@ func (c *Client) Method(
}
c.pendingRequestsMu.Unlock()

c.writeMu.Lock()
_ = c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err = c.conn.Write(msg)
c.writeMu.Unlock()
if err != nil {
c.Close()
return SocketError(fmt.Errorf("Failed to write to socket: %w", err))
Expand Down
88 changes: 88 additions & 0 deletions jsonrpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -85,6 +86,55 @@ func (s *testSuite) TestMethodSync() {
require.Empty(s.T(), client.pendingRequests)
}

func (s *testSuite) TestMethodConcurrentWritesSerialized() {
s.server.OnRequest.Set(func(conn net.Conn, req *types.Request) *types.Response {
return &types.Response{
JSONRPC: types.JSONRPC,
ID: &req.ID,
Result: test.ToResult(req.ID),
}
})

var conn *concurrentWriteDetectingConn
client, err := Connect(&Options{
Dial: func() (net.Conn, error) {
netConn, err := s.dial()
if err != nil {
return nil, err
}
conn = newConcurrentWriteDetectingConn(netConn)
return conn, nil
},
})
require.NoError(s.T(), err)
defer client.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

const numRequests = 32
start := make(chan struct{})
errs := make(chan error, numRequests)
var wg sync.WaitGroup
for i := 0; i < numRequests; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
var response int
errs <- client.MethodBlocking(ctx, &response, "method")
}()
}
close(start)
wg.Wait()
close(errs)

for err := range errs {
require.NoError(s.T(), err)
}
require.Zero(s.T(), conn.concurrentWrites())
}

// Server returns an error nested in a `{ message: ...}` object.
func (s *testSuite) TestMethodBlockingNestedErrByServer() {
client, err := Connect(&Options{
Expand Down Expand Up @@ -260,3 +310,41 @@ func (s *testSuite) TestNotification() {
s.T().Fatal("timeout")
}
}

type concurrentWriteDetectingConn struct {
net.Conn

inWrite chan struct{}
mu sync.Mutex
count int
}

func newConcurrentWriteDetectingConn(conn net.Conn) *concurrentWriteDetectingConn {
return &concurrentWriteDetectingConn{
Conn: conn,
inWrite: make(chan struct{}, 1),
}
}

func (c *concurrentWriteDetectingConn) Write(b []byte) (int, error) {
select {
case c.inWrite <- struct{}{}:
defer func() {
<-c.inWrite
}()
default:
c.mu.Lock()
c.count++
c.mu.Unlock()
return 0, fmt.Errorf("concurrent write detected")
}

time.Sleep(5 * time.Millisecond)
return c.Conn.Write(b)
}

func (c *concurrentWriteDetectingConn) concurrentWrites() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.count
}
Loading