From 337d1e3d4116fc982d076165471dae10c6ba560d Mon Sep 17 00:00:00 2001 From: mike andrews Date: Thu, 30 Apr 2026 18:07:00 -0400 Subject: [PATCH] jsonrpc: serialize writes to client connection --- jsonrpc/client.go | 3 ++ jsonrpc/client_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/jsonrpc/client.go b/jsonrpc/client.go index a40b3a3..57bc55a 100644 --- a/jsonrpc/client.go +++ b/jsonrpc/client.go @@ -47,6 +47,7 @@ type pendingRequest struct { type Client struct { opts *Options conn net.Conn + writeMu sync.Mutex msgID int msgIDMu sync.Mutex pendingRequests map[int]pendingRequest @@ -188,8 +189,10 @@ func (c *Client) Method( } c.pendingRequestsMu.Unlock() + c.writeMu.Lock() _ = c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) _, err = c.conn.Write(msg) + c.writeMu.Unlock() if err != nil { c.Close() return SocketError(fmt.Errorf("Failed to write to socket: %w", err)) diff --git a/jsonrpc/client_test.go b/jsonrpc/client_test.go index 61600cb..52f3886 100644 --- a/jsonrpc/client_test.go +++ b/jsonrpc/client_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "net" + "sync" "testing" "time" @@ -85,6 +86,55 @@ func (s *testSuite) TestMethodSync() { require.Empty(s.T(), client.pendingRequests) } +func (s *testSuite) TestMethodConcurrentWritesSerialized() { + s.server.OnRequest.Set(func(conn net.Conn, req *types.Request) *types.Response { + return &types.Response{ + JSONRPC: types.JSONRPC, + ID: &req.ID, + Result: test.ToResult(req.ID), + } + }) + + var conn *concurrentWriteDetectingConn + client, err := Connect(&Options{ + Dial: func() (net.Conn, error) { + netConn, err := s.dial() + if err != nil { + return nil, err + } + conn = newConcurrentWriteDetectingConn(netConn) + return conn, nil + }, + }) + require.NoError(s.T(), err) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + const numRequests = 32 + start := make(chan struct{}) + errs := make(chan error, numRequests) + var wg sync.WaitGroup + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + var response int + errs <- client.MethodBlocking(ctx, &response, "method") + }() + } + close(start) + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(s.T(), err) + } + require.Zero(s.T(), conn.concurrentWrites()) +} + // Server returns an error nested in a `{ message: ...}` object. func (s *testSuite) TestMethodBlockingNestedErrByServer() { client, err := Connect(&Options{ @@ -260,3 +310,41 @@ func (s *testSuite) TestNotification() { s.T().Fatal("timeout") } } + +type concurrentWriteDetectingConn struct { + net.Conn + + inWrite chan struct{} + mu sync.Mutex + count int +} + +func newConcurrentWriteDetectingConn(conn net.Conn) *concurrentWriteDetectingConn { + return &concurrentWriteDetectingConn{ + Conn: conn, + inWrite: make(chan struct{}, 1), + } +} + +func (c *concurrentWriteDetectingConn) Write(b []byte) (int, error) { + select { + case c.inWrite <- struct{}{}: + defer func() { + <-c.inWrite + }() + default: + c.mu.Lock() + c.count++ + c.mu.Unlock() + return 0, fmt.Errorf("concurrent write detected") + } + + time.Sleep(5 * time.Millisecond) + return c.Conn.Write(b) +} + +func (c *concurrentWriteDetectingConn) concurrentWrites() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.count +}