From 1898402a7a7409e08b3dadb74d7b5aa865a2bb2b Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 23 Mar 2026 06:09:33 +0000 Subject: [PATCH 1/2] drpcmanager: increase coverage of wire protocol at manager level --- drpcmanager/manager_test.go | 370 ++++++++++++++++++++++++++++++++++++ 1 file changed, 370 insertions(+) diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 58e8320..4d05355 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -193,6 +193,71 @@ func waitForClosed(t *testing.T, man *Manager) { } } +// +// manageReader tests +// + +// Global frame monotonicity: a frame with an ID lower than the last seen +// frame causes the manager to terminate with a protocol error. +func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + // Consume the invoke and drain messages so HandleFrame doesn't block. + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + for { + if _, err := stream.RawRecv(); err != nil { + return + } + } + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 5, "ok", true), + createFrame(drpcwire.KindMessage, 1, 4, "bad", true), + ) + + waitForClosed(t, man) +} + +// Cross-stream monotonicity: after seeing stream 2, a frame for stream 1 +// with a higher message ID is still rejected because {1,x} < {2,y}. +func TestManageReader_GlobalMonotonicity_CrossStream(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + // Consume both invokes so manageReader can proceed. + ctx.Run(func(ctx context.Context) { + _, _, _ = man.NewServerStream(ctx) + _, _, _ = man.NewServerStream(ctx) + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc1", true), + createFrame(drpcwire.KindInvoke, 2, 1, "rpc2", true), + createFrame(drpcwire.KindMessage, 1, 4, "bad", true), + ) + + waitForClosed(t, man) +} + // Invoke replay: after [s1,m1,invoke,done=true], lastFrameID is bumped to // {1,2}. A replayed [s1,m1,invoke] is caught by the monotonicity check. func TestManageReader_InvokeReplayBlocked(t *testing.T) { @@ -218,6 +283,274 @@ func TestManageReader_InvokeReplayBlocked(t *testing.T) { waitForClosed(t, man) } +// Non-done frames don't bump the message ID, so continuation frames with +// the same ID are accepted. +func TestManageReader_ContinuationFramesAccepted(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "hel", false), + createFrame(drpcwire.KindMessage, 1, 2, "lo", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello")) +} + +// Old-stream frames are silently ignored on the client side when the local +// stream ID has advanced past the incoming frame's stream ID. +func TestManageReader_OldStreamFramesIgnored(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := NewWithOptions(cconn, Options{SoftCancel: true}) + defer func() { _ = cman.Close() }() + + // Drain all client writes so nothing blocks, and write server + // responses once we've seen enough data. + ctx.Run(func(ctx context.Context) { + buf := make([]byte, 4096) + for { + _, err := sconn.Read(buf) + if err != nil { + return + } + } + }) + + // Create stream 1 on the client, then cancel it so the client + // advances to stream 2. + subctx, cancel := context.WithCancel(ctx) + _, err := cman.NewClientStream(subctx, "rpc1") + assert.NoError(t, err) + cancel() + <-cman.Unblocked() + + stream2, err := cman.NewClientStream(ctx, "rpc2") + assert.NoError(t, err) + + // Write an old-stream frame (s1) then the real response for s2. + // The s1 frame should be silently ignored by the client manager. + writeFrames(t, sconn, + createFrame(drpcwire.KindMessage, 1, 1, "old", true), + createFrame(drpcwire.KindMessage, 2, 1, "new", true), + ) + + data, err := stream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("new")) + + _ = stream2.Close() +} + +// The first frame for a new stream must be KindInvoke or KindInvokeMetadata. +// A non-invoke kind causes a protocol error. +func Disabled_TestManageReader_FirstFrameMustBeInvoke(t *testing.T) { + for _, kind := range []drpcwire.Kind{ + drpcwire.KindMessage, + drpcwire.KindCancel, + drpcwire.KindClose, + drpcwire.KindCloseSend, + drpcwire.KindError, + } { + t.Run(kind.String(), func(t *testing.T) { + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + writeFrames(t, cconn, + createFrame(kind, 1, 1, "", true), + ) + + waitForClosed(t, man) + }) + } +} + +// A valid invoke sequence: Invoke → Message. +// Metadata encoding is covered separately by TestDrpcMetadata. +func TestManageReader_ValidInvokeSequence(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, rpc, err := man.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, rpc, "myrpc") + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "myrpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "payload", true), + ) + + assert.DeepEqual(t, <-recv, []byte("payload")) +} + +// Multi-frame message delivered through manager to stream: frames are +// assembled by the stream into a single packet. +func TestManageReader_MultiFrameDelivery(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "hel", false), + createFrame(drpcwire.KindMessage, 1, 2, "lo ", false), + createFrame(drpcwire.KindMessage, 1, 2, "world", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello world")) +} + +// When a higher message ID arrives mid-assembly, the partial data is +// discarded and only the new message is delivered. +func TestManageReader_HigherMsgDiscardsInProgress(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "discard", false), + createFrame(drpcwire.KindMessage, 1, 3, "kept", true), + ) + + assert.DeepEqual(t, <-recv, []byte("kept")) +} + +// A continuation frame with a different kind than the first frame of the +// packet causes the manager to terminate with a protocol error. +func TestManageReader_KindChangeWithinPacket(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + for { + if _, err := stream.RawRecv(); err != nil { + return + } + } + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "data", false), + createFrame(drpcwire.KindClose, 1, 2, "", true), + ) + + waitForClosed(t, man) +} + +// Multi-frame assembly works correctly when the message ID is greater than +// the previous message (e.g., on the server side where invoke consumed +// earlier IDs). +func TestManageReader_MultiFrameWithSkippedMessageID(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 3, "hel", false), + createFrame(drpcwire.KindMessage, 1, 3, "lo", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello")) +} + // A second invoke for the same stream ID is rejected — the stream treats // it as a protocol error, terminating the manager. func TestManageReader_InvokeOnExistingStream(t *testing.T) { @@ -246,6 +579,43 @@ func TestManageReader_InvokeOnExistingStream(t *testing.T) { assert.That(t, drpc.ProtocolError.Has(man.sigs.term.Err())) } +// When a non-invoke frame arrives before the stream is created (e.g., +// NewServerStream hasn't returned yet), manageReader waits for the stream +// and retries. +func TestManageReader_WaitsForStreamCreation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + // Write invoke + message immediately. The message arrives before + // NewServerStream creates the stream, exercising the default/wait path. + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "data", true), + ) + + // Small delay to let manageReader process both frames. + time.Sleep(10 * time.Millisecond) + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.DeepEqual(t, <-recv, []byte("data")) +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } From 53e0b30d92872798cb62b24989282ea4226b1af5 Mon Sep 17 00:00:00 2001 From: Shubham Dhama Date: Mon, 23 Mar 2026 06:16:35 +0000 Subject: [PATCH 2/2] *: move frame assembly from reader to stream Previously, Reader assembled wire frames into complete packets before handing them to the manager. This change makes Reader return individual frames (ReadFrame), and the stream handles frame assembly itself (HandleFrame). The manager now enforces global frame ID monotonicity and other validation that are beyond a stream's scope. This is groundwork for stream multiplexing, where frames from different streams will be interleaved on the wire and must be routed to the correct stream before assembly. --- drpcconn/conn_test.go | 26 +-- drpcmanager/manager.go | 77 +++---- drpcmanager/manager_test.go | 2 +- drpcstream/stream.go | 71 +++++- drpcstream/stream_test.go | 422 ++++++++++++++++++++++++++++++++---- drpcwire/reader.go | 79 ++----- drpcwire/reader_test.go | 322 +++++++++++---------------- 7 files changed, 639 insertions(+), 360 deletions(-) diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516..e7402b6 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -43,9 +43,9 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { wr := drpcwire.NewWriter(ps, 64) rd := drpcwire.NewReader(ps) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // Message - pkt, _ := rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // Message + pkt, _ := rd.ReadFrame() // CloseSend _ = wr.WritePacket(drpcwire.Packet{ Data: []byte("qux"), @@ -54,8 +54,8 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { }) _ = wr.Flush() - _, _ = rd.ReadPacket() // Close - <-invokeDone // wait for invoke to return + _, _ = rd.ReadFrame() // Close + <-invokeDone // wait for invoke to return // ensure that any later packets are dropped by writing one // before closing the transport. @@ -98,7 +98,7 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { wr := drpcwire.NewWriter(ps, 64) rd := drpcwire.NewReader(ps) - md, err := rd.ReadPacket() // Metadata + md, err := rd.ReadFrame() // Metadata assert.NoError(t, err) assert.Equal(t, md.Kind, drpcwire.KindInvokeMetadata) metadata, err := drpcmetadata.Decode(md.Data) @@ -110,9 +110,9 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { "common-key": "common-value2", }) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // Message - pkt, _ := rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // Message + pkt, _ := rd.ReadFrame() // CloseSend _ = wr.WritePacket(drpcwire.Packet{ Data: []byte("qux"), @@ -121,7 +121,7 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { }) _ = wr.Flush() - _, _ = rd.ReadPacket() // Close + _, _ = rd.ReadFrame() // Close }) conn := New(pc) @@ -154,7 +154,7 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { ctx.Run(func(ctx context.Context) { rd := drpcwire.NewReader(ps) - md, err := rd.ReadPacket() // Metadata + md, err := rd.ReadFrame() // Metadata assert.NoError(t, err) assert.Equal(t, md.Kind, drpcwire.KindInvokeMetadata) metadata, err := drpcmetadata.Decode(md.Data) @@ -164,8 +164,8 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { "drpc-key": "drpc-value", }) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // CloseSend }) conn := New(pc) diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index ae1cd7a..2207ea3 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -72,6 +72,9 @@ type Manager struct { rd *drpcwire.Reader opts Options + lastFrameID drpcwire.ID + lastFrameKind drpcwire.Kind + sem drpcsignal.Chan // held by the active stream sbuf streamBuffer // largest stream id created pkts chan drpcwire.Packet // channel for invoke packets @@ -213,27 +216,15 @@ func (m *Manager) terminate(err error) { // manage reader // -// manageReader is always reading a packet and dispatching it to the appropriate -// stream or queue. It sets the read signal when it exits so that one can wait -// to ensure that no one is reading on the reader. It sets the term signal if -// there is any error reading packets. +// manageReader reads the frame and dispatches them to the appropriate stream or +// queue. It sets the read signal when it exits so that one can wait to ensure +// that no one is reading on the reader. It sets the term signal if there is any +// error reading frames. func (m *Manager) manageReader() { defer m.sigs.read.Set(nil) - var pkt drpcwire.Packet - var err error - var run int - for !m.sigs.term.IsSet() { - // if we have a run of "small" packets, drop the buffer to release - // memory so that a burst of large packets does not cause eternally - // large heap usage. - if run > 10 { - pkt.Data = nil - run = 0 - } - - pkt, err = m.rd.ReadPacketUsing(pkt.Data[:0]) + incomingFrame, err := m.rd.ReadFrame() if err != nil { if isConnectionReset(err) { err = drpc.ClosedError.Wrap(err) @@ -242,36 +233,36 @@ func (m *Manager) manageReader() { return } - if len(pkt.Data) < cap(pkt.Data)/4 { - run++ - } else { - run = 0 - } + m.log("READ", incomingFrame.String) - m.log("READ", pkt.String) + if ok := m.checkStreamMonotonicity(incomingFrame); !ok { + m.terminate(managerClosed.Wrap(drpc.ProtocolError.New("id monotonicity violation"))) + return + } switch curr := m.sbuf.Get(); { - // if the packet is for the current stream, deliver it. - case curr != nil && pkt.ID.Stream == curr.ID(): - if err := curr.HandlePacket(pkt); err != nil { + // If the frame is for the current stream, deliver it. + case curr != nil && incomingFrame.ID.Stream == curr.ID(): + if err := curr.HandleFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } - // if an old message has been sent, just ignore it. - case curr != nil && pkt.ID.Stream < curr.ID(): + // If a frame arrives for an old stream, just ignore it. + case curr != nil && incomingFrame.ID.Stream < curr.ID(): - // if any invoke sequence is being sent, close any old unterminated - // stream and forward it to be handled. - case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: + // If an invoke sequence is being sent for a new stream, close any + // old unterminated stream and forward it to be handled. + case incomingFrame.Kind == drpcwire.KindInvoke || incomingFrame.Kind == drpcwire.KindInvokeMetadata: if curr != nil && !curr.IsTerminated() { curr.Cancel(context.Canceled) } + pkt := drpcwire.Packet{ID: incomingFrame.ID, Kind: incomingFrame.Kind, Data: incomingFrame.Data} select { case m.pkts <- pkt: // Wait for NewServerStream to finish stream creation (including - // sbuf.Set) before reading the next packet. This guarantees curr + // sbuf.Set) before reading the next frame. This guarantees curr // is set for subsequent non-invoke packets. m.pdone.Recv() @@ -280,18 +271,28 @@ func (m *Manager) manageReader() { } default: - // A non-invoke packet arrived for a stream that doesn't exist yet - // (curr is nil or pkt.ID.Stream > curr.ID). The first packet of a - // new stream must be KindInvoke or KindInvokeMetadata. + // A non-invoke frame arrived for a stream that doesn't exist yet + // (curr is nil or incomingFrame.ID.Stream > curr.ID). The first + // frame of a new stream must be KindInvoke or KindInvokeMetadata. m.terminate(managerClosed.Wrap(drpc.ProtocolError.New( - "first packet of a new stream must be Invoke, got %v (ID:%v)", - pkt.Kind, - pkt.ID))) + "first frame of a new stream must be Invoke, got %v (ID:%v)", + incomingFrame.Kind, + incomingFrame.ID))) return } } } +func (m *Manager) checkStreamMonotonicity(incomingFrame drpcwire.Frame) bool { + ok := incomingFrame.ID.Stream >= m.lastFrameID.Stream + m.lastFrameKind = incomingFrame.Kind + m.lastFrameID = incomingFrame.ID + if incomingFrame.Done { + m.lastFrameID.Message += 1 + } + return ok +} + // // manage streams // diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 4d05355..fd70611 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -366,7 +366,7 @@ func TestManageReader_OldStreamFramesIgnored(t *testing.T) { // The first frame for a new stream must be KindInvoke or KindInvokeMetadata. // A non-invoke kind causes a protocol error. -func Disabled_TestManageReader_FirstFrameMustBeInvoke(t *testing.T) { +func TestManageReader_FirstFrameMustBeInvoke(t *testing.T) { for _, kind := range []drpcwire.Kind{ drpcwire.KindMessage, drpcwire.KindCancel, diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 29ccd63..7fd769e 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -23,7 +23,7 @@ import ( // Options controls configuration settings for a stream. type Options struct { - // SplitSize controls the default size we split packets into frames. + // SplitSize controls the default size we split data packets into frames. SplitSize int // ManualFlush controls if the stream will automatically flush after every @@ -52,6 +52,11 @@ type Stream struct { read inspectMutex flush sync.Once + assembling bool + pktBuf []byte + pktKind drpcwire.Kind + nextMessageID uint64 + id drpcwire.ID wr *drpcwire.Writer pbuf packetBuffer @@ -98,6 +103,8 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O fin: drpcopts.GetStreamFin(&opts.Internal), task: task, + nextMessageID: 1, + id: drpcwire.ID{Stream: sid}, wr: wr.Reset(), } @@ -211,24 +218,65 @@ func (s *Stream) IsFinished() bool { return s.sigs.fin.IsSet() } func (s *Stream) SetManualFlush(mf bool) { s.opts.ManualFlush = mf } // -// packet handler +// frame handler // -// HandlePacket advances the stream state machine by inspecting the packet. It -// returns any major errors that should terminate the transport the stream is -// operating on as well as a boolean indicating if the stream expects more -// packets. -func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { - if pkt.ID.Stream != s.id.Stream { +// HandleFrame processes an incoming frame, assembling multi-frame packets +// and dispatching complete packets to the stream state machine. +func (s *Stream) HandleFrame(fr drpcwire.Frame) (err error) { + if s.sigs.term.IsSet() { return nil } - drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) + if fr.ID.Stream != s.ID() { + return drpc.ProtocolError.New("frame doesn't belong to this stream (fr: %v)", fr.ID) + } - if s.sigs.term.IsSet() { + if fr.ID.Message < s.nextMessageID { + return drpc.ProtocolError.New( + "id monotonicity violation: frame %v has message ID less than expected %v", fr.ID, s.nextMessageID) + } else if fr.ID.Message > s.nextMessageID || !s.assembling { + s.pktBuf = s.pktBuf[:0] + s.assembling = true + s.nextMessageID = fr.ID.Message + } else if fr.Kind != s.pktKind { + return drpc.ProtocolError.New("frame kind change within packet: got %v, expected %v", fr.Kind, s.pktKind) + } + + // TODO(shubham): add buf reuse + s.pktBuf = append(s.pktBuf, fr.Data...) + + s.pktKind = fr.Kind + + if s.opts.MaximumBufferSize > 0 && len(s.pktBuf) > s.opts.MaximumBufferSize { + return drpc.ProtocolError.New("data overflow (len:%d)", len(s.pktBuf)) + } + + if !fr.Done { return nil } + s.assembling = false + s.nextMessageID = fr.ID.Message + 1 + + err = s.handlePacket(drpcwire.Packet{ + ID: fr.ID, + Kind: fr.Kind, + Control: fr.Control, + Data: s.pktBuf, + }) + + // TODO(shubham): add buf reuse + s.pktBuf = nil + return err +} + +// handlePacket advances the stream state machine by inspecting the packet. It +// returns any major errors that should terminate the transport the stream is +// operating on. +func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { + drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) + s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { @@ -240,7 +288,7 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { defer s.mu.Unlock() switch pkt.Kind { - case drpcwire.KindInvoke: + case drpcwire.KindInvoke, drpcwire.KindInvokeMetadata: err := drpc.ProtocolError.New("invoke on existing stream") s.terminate(err) return err @@ -375,6 +423,7 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. +// TODO(shubham): can we merge this with sendPacketLocked? func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { fr := s.newFrameLocked(kind) n := s.opts.SplitSize diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 3cf4ca3..a62ab87 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "io" + "strings" "testing" "github.com/zeebo/assert" @@ -18,16 +19,23 @@ import ( "storj.io/drpc/drpcwire" ) +// handleFrame is a helper that sends a single-frame packet to the stream. +// It constructs a frame with the given kind, matching the stream's ID, +// using the provided message ID, done=true. +func handleFrame(st *Stream, kind drpcwire.Kind, mid uint64) error { + return st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: st.ID(), Message: mid}, + Kind: kind, + Done: true, + }) +} + func TestStream_StateTransitions(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() any := errors.New("any sentinel error") - handlePacket := func(st *Stream, kind drpcwire.Kind) error { - return st.HandlePacket(drpcwire.Packet{Kind: kind}) - } - checkErrs := func(t *testing.T, exp interface{}, got error) { t.Helper() @@ -81,32 +89,32 @@ func TestStream_StateTransitions(t *testing.T) { }, { // recv close - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindClose) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindClose, 1) }, Send: &drpc.ClosedError, Recv: io.EOF, }, { // recv error - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindError) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindError, 1) }, Send: io.EOF, Recv: any, }, { // recv closesend - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindCloseSend) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindCloseSend, 1) }, Send: nil, Recv: io.EOF, }, } for _, test := range cases { - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) if test.Recv == nil { - ctx.Run(func(ctx context.Context) { _ = handlePacket(st, drpcwire.KindMessage) }) + ctx.Run(func(ctx context.Context) { _ = handleFrame(st, drpcwire.KindMessage, 2) }) } _, err := st.RawRecv() checkErrs(t, test.Recv, err) @@ -117,10 +125,6 @@ func TestStream_Unblocks(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - handlePacket := func(st *Stream, kind drpcwire.Kind) error { - return st.HandlePacket(drpcwire.Packet{Kind: kind}) - } - cases := []struct { Op func(st *Stream) error }{ @@ -141,20 +145,20 @@ func TestStream_Unblocks(t *testing.T) { }, { // recv close - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindClose) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindClose, 1) }, }, { // recv error - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindError) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindError, 1) }, }, { // recv closesend - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindCloseSend) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindCloseSend, 1) }, }, } for _, test := range cases { - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -200,20 +204,6 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { assert.That(t, errors.Is(<-errch, context.Canceled)) } -func TestStream_Control(t *testing.T) { - st := New(context.Background(), 0, drpcwire.NewWriter(io.Discard, 0)) - - // N.B. the stream will return nil on any HandlePacket calls after the - // stream has been terminated for any reason, including if an invalid - // packet has been sent. the order of these two assertions is important! - - // an invalid packet is not an error if the control bit is set - assert.NoError(t, st.HandlePacket(drpcwire.Packet{Control: true})) - - // an invalid packet is an error if the control bit it not set - assert.That(t, drpc.InternalError.Has(st.HandlePacket(drpcwire.Packet{}))) -} - func TestStream_CorkUntilFirstRead(t *testing.T) { run := func() { ctx := drpctest.NewTracker(t) @@ -234,10 +224,11 @@ func TestStream_CorkUntilFirstRead(t *testing.T) { errch <- err }) ctx.Run(func(ctx context.Context) { - errch <- st.HandlePacket(drpcwire.Packet{ + errch <- st.HandleFrame(drpcwire.Frame{ Data: []byte("read"), ID: drpcwire.ID{Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) }) @@ -266,20 +257,24 @@ func TestStream_PacketBufferReuse(t *testing.T) { defer ctx.Close() defer ctx.Wait() - buf := make([]byte, 20) - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + data := make([]byte, 20) + mid := uint64(1) + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { - err := st.HandlePacket(drpcwire.Packet{ - Data: buf, + err := st.HandleFrame(drpcwire.Frame{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: mid}, Kind: drpcwire.KindMessage, + Done: true, }) if err != nil { return } - for i := range buf { - buf[i]++ + mid++ + for i := range data { + data[i]++ } } }) @@ -327,3 +322,354 @@ func TestStream_SendCancelBusyDuringBlockedClose(t *testing.T) { assert.NoError(t, err) assert.That(t, busy) } + +// +// HandleFrame tests +// + +// A frame routed to the wrong stream is a protocol error. +func TestHandleFrame_WrongStreamID(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 2, Message: 1}, + Kind: drpcwire.KindMessage, + Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "doesn't belong")) +} + +// A frame with a message ID lower than a previously completed message is rejected. +func TestHandleFrame_MessageMonotonicity(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + // Close packet buffer so KindMessage delivery doesn't block. + st.pbuf.Close(io.EOF) + + // m3 completes, nextMessageID becomes 4. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Done: true, + })) + + // m2 < 4 → error. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { + // On the client side, the first message received will have ID 1. But on the + // server side, invoke is consumed by the manager. The first frame reaching + // the stream could have msg > 1 (e.g., msg=2). nextMessageID=1, so 2 > 1 + // makes this a valid frame. + for _, messageID := range []uint64{1, 2} { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + // Close the packet buffer so KindMessage Put doesn't block. + st.pbuf.Close(io.EOF) + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: messageID}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.NoError(t, err) + } +} + +// When a higher message ID arrives mid-assembly, the in-progress data is +// silently discarded and a new packet begins. +func TestHandleFrame_HigherMsgDiscardsInProgress(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // Start accumulating m1 (done=false doesn't call Put, so no blocking). + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("discard"), Done: false, + })) + + // Launch receiver before sending done frame to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + // m2 arrives, m1 data should be silently discarded. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, Data: []byte("kept"), Done: true, + })) + + // Verify only m2's data was delivered. + assert.DeepEqual(t, <-recv, []byte("kept")) +} + +// Continuation frames (same message ID, mid-assembly) must carry the same +// kind as the first frame. A kind change mid-packet is a protocol error. +func TestHandleFrame_KindChangeWithinPacket(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: false, + })) + + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindError, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "kind change")) +} + +// Multiple continuation frames for the same message accumulate data correctly. +func TestHandleFrame_MultiFrameDataAccumulation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // Continuation frames (done=false) don't call Put, so no blocking. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("hel"), Done: false, + })) + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("lo "), Done: false, + })) + + // Launch receiver before the final frame to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("world"), Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("hello world")) +} + +// Multi-frame assembly works when the message ID is greater than nextMessageID +// (e.g., on the server side where invoke consumed earlier message IDs). +// Continuation frames must accumulate data, not reset on each frame. +func TestHandleFrame_MultiFrameWithSkippedMessageID(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // msg=3 is greater than nextMessageID=1. Continuation frames for the + // same message must still accumulate correctly. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte("hel"), Done: false, + })) + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte("lo"), Done: false, + })) + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte(" world"), Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("hello world")) +} + +// Once a message completes (done=true), the same message ID is rejected. +func TestHandleFrame_DonePreventsReplay(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + // Close packet buffer so KindMessage delivery doesn't block. + st.pbuf.Close(io.EOF) + + // m1 completes → nextMessageID becomes 2. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + })) + + // Same message ID again → error. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +// Kind consistency is only enforced within a packet (continuation frames), not +// across messages. A multi-frame KindMessage followed by a KindClose for the +// next message should be accepted without error. +func TestHandleFrame_MultiFrameThenNextMessage(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("ab"), Done: false, + })) + + // Launch receiver before done frame. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("cd"), Done: true, + })) + assert.DeepEqual(t, <-recv, []byte("abcd")) + + // Message 2 with a different kind — should not trigger kind check. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindClose, Done: true, + })) + + // Close triggers EOF on recv. + ctx.Run(func(ctx context.Context) { + _, err := st.RawRecv() + assert.That(t, errors.Is(err, io.EOF)) + }) + ctx.Wait() +} + +// Invoke and InvokeMetadata frames are rejected on an already-created stream. +func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + err := handleFrame(st, drpcwire.KindInvoke, 1) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "invoke on existing stream")) +} + +func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "invoke on existing stream")) +} + +// Frames arriving after the stream is terminated are silently ignored. +func TestHandleFrame_AfterTerminated(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + // Terminate the stream via cancel. + st.Cancel(context.Canceled) + + // Frames after termination are silently ignored. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.NoError(t, err) +} + +// A completed KindMessage frame delivers its data through RawRecv. +func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // Launch receiver before sending to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, + Kind: drpcwire.KindMessage, + Data: []byte("payload"), + Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("payload")) +} + +// +// Write-side tests +// + +func TestRawWrite_NonMessageSingleFrame(t *testing.T) { + // Non-KindMessage kinds must produce a single frame (n=0 in + // rawWriteLocked means default 64KB, effectively no split for + // small payloads). Verify they produce exactly one frame with Done=true. + kinds := []drpcwire.Kind{ + drpcwire.KindInvoke, + drpcwire.KindError, + drpcwire.KindCancel, + drpcwire.KindClose, + drpcwire.KindCloseSend, + drpcwire.KindInvokeMetadata, + } + + for _, kind := range kinds { + var buf bytes.Buffer + st := New(context.Background(), 1, drpcwire.NewWriter(&buf, 0)) + + assert.NoError(t, st.RawWrite(kind, []byte("data"))) + assert.NoError(t, st.RawFlush()) + var err error + + // Parse all frames from the buffer — should be exactly one. + data := buf.Bytes() + var frames []drpcwire.Frame + for len(data) > 0 { + var fr drpcwire.Frame + var ok bool + data, fr, ok, err = drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + frames = append(frames, fr) + } + assert.Equal(t, len(frames), 1) + assert.That(t, frames[0].Done) + assert.Equal(t, frames[0].Kind, kind) + } +} + +func TestRawWrite_MessageRespectsSplitSize(t *testing.T) { + var buf bytes.Buffer + st := NewWithOptions(context.Background(), 1, + drpcwire.NewWriter(&buf, 0), + Options{SplitSize: 5}, + ) + + // "helloworld" is 10 bytes, split at 5 → 2 frames. + assert.NoError(t, st.RawWrite(drpcwire.KindMessage, []byte("helloworld"))) + assert.NoError(t, st.RawFlush()) + var err error + + data := buf.Bytes() + var frames []drpcwire.Frame + for len(data) > 0 { + var fr drpcwire.Frame + var ok bool + data, fr, ok, err = drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + frames = append(frames, fr) + } + assert.Equal(t, len(frames), 2) + assert.That(t, !frames[0].Done) + assert.That(t, frames[1].Done) + assert.DeepEqual(t, frames[0].Data, []byte("hello")) + assert.DeepEqual(t, frames[1].Data, []byte("world")) +} diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac397..d5ab580 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -16,13 +16,12 @@ type ReaderOptions struct { MaximumBufferSize int } -// Reader reconstructs packets from frames read from an io.Reader. +// Reader reads frames from an io.Reader. type Reader struct { opts ReaderOptions r io.Reader curr []byte buf []byte - id ID rerr error } @@ -35,12 +34,12 @@ type Reader struct { // 9: maximum varint data length const maxFrameOverhead = 1 + 9 + 9 + 9 -// NewReader constructs a Reader to read Packets from the io.Reader. +// NewReader constructs a Reader to read Frames from the io.Reader. func NewReader(r io.Reader) *Reader { return NewReaderWithOptions(r, ReaderOptions{}) } -// NewReaderWithOptions constructs a Reader to read Packets from +// NewReaderWithOptions constructs a Reader to read Frames from // the io.Reader. It uses the provided options to manage buffering. func NewReaderWithOptions(r io.Reader, opts ReaderOptions) *Reader { if opts.MaximumBufferSize == 0 { @@ -50,10 +49,9 @@ func NewReaderWithOptions(r io.Reader, opts ReaderOptions) *Reader { return &Reader{ opts: opts, r: r, - // Err on the side of a smaller buffer since ReadPacket will lazily + // Err on the side of a smaller buffer since ReadFrame will lazily // grow this buffer. curr: make([]byte, 0, 4096), - id: ID{Stream: 1, Message: 1}, } } @@ -76,29 +74,14 @@ func (r *Reader) read(p []byte) (n int, err error) { return 0, drpc.InternalError.Wrap(io.ErrNoProgress) } -// ReadPacket reads a packet from the io.Reader. It is equivalent to -// calling ReadPacketUsing(nil). -func (r *Reader) ReadPacket() (pkt Packet, err error) { - return r.ReadPacketUsing(nil) -} - -// ReadPacketUsing reads a packet from the io.Reader. IDs read from -// frames must be monotonically increasing. When a new ID is read, the -// old data is discarded. This allows for easier asynchronous interrupts. -// If the amount of data in the Packet becomes too large, an error is -// returned. The returned packet's Data field is constructed by appending -// to the provided buf after it has been resliced to be zero length. -func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { - pkt.Data = buf[:0] - - var fr Frame - var ok bool - +// ReadFrame reads a single frame from the io.Reader. +func (r *Reader) ReadFrame() (fr Frame, err error) { for { + var ok bool r.curr, fr, ok, err = ParseFrame(r.curr) switch { case err != nil: - return Packet{}, drpc.ProtocolError.Wrap(err) + return Frame{}, drpc.ProtocolError.Wrap(err) case !ok: // r.curr doesn't have enough data for a full frame, so prepend @@ -115,62 +98,28 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { n, err := r.read(r.buf[len(r.buf):cap(r.buf)]) if err != nil { - return Packet{}, err + return Frame{}, err } ncap := uint(len(r.buf) + n) if ncap > uint(cap(r.buf)) { - return Packet{}, drpc.ProtocolError.New("data overflow") + return Frame{}, drpc.ProtocolError.New("data overflow") } r.buf = r.buf[:ncap] if len(r.buf)-maxFrameOverhead > r.opts.MaximumBufferSize { - return Packet{}, drpc.ProtocolError.New("data overflow") + return Frame{}, drpc.ProtocolError.New("data overflow") } r.curr = r.buf continue } - // since we got a packet, signal that we need to restore buf with - // whatever remains in r.curr the next time we don't have a packet. + // since we got a frame, signal that we need to restore buf with + // whatever remains in r.curr the next time we don't have a frame. if len(r.buf) > 0 { r.buf = r.buf[:0] } - - // If any frames are set to control, then the whole packet is - // considered to be control. - pkt.Control = pkt.Control || fr.Control - - switch { - case fr.ID.Less(r.id): - return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) - - case r.id != fr.ID || pkt.ID == ID{}: - r.id = fr.ID - - pkt = Packet{ - Data: pkt.Data[:0], - ID: fr.ID, - Kind: fr.Kind, - Control: fr.Control, - } - - case fr.Kind != pkt.Kind: - return Packet{}, drpc.ProtocolError.New("packet kind change (fr:%v pkt:%v)", fr.Kind, pkt.Kind) - } - - pkt.Data = append(pkt.Data, fr.Data...) - - switch { - case len(pkt.Data) > r.opts.MaximumBufferSize: - return Packet{}, drpc.ProtocolError.New("data overflow (len:%v)", len(pkt.Data)) - - case fr.Done: - // increment the message id so that we do not accept any frames - // with the same id. - r.id.Message++ - return pkt, nil - } + return fr, nil } } diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b..4a551ff 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -15,176 +15,145 @@ import ( "github.com/zeebo/assert" ) -func TestReader(t *testing.T) { - type testCase struct { - Packets []Packet - Frames []Frame - Error string - Options ReaderOptions - } - - p := func(kind Kind, id uint64, control bool, data string) Packet { - return Packet{ - Data: []byte(data), - ID: ID{Stream: 1, Message: id}, - Kind: kind, - Control: control, - } - } - - f := func(kind Kind, id uint64, data string, done, control bool) Frame { +func TestReadFrame(t *testing.T) { + f := func(kind Kind, sid, mid uint64, data string, done, control bool) Frame { return Frame{ Data: []byte(data), - ID: ID{Stream: 1, Message: id}, + ID: ID{Stream: sid, Message: mid}, Kind: kind, Done: done, Control: control, } } - m := func(pkt Packet, frames ...Frame) testCase { - return testCase{ - Packets: []Packet{pkt}, - Frames: frames, + t.Run("SingleFrame", func(t *testing.T) { + fr := f(KindMessage, 1, 1, "hello", true, false) + + buf := AppendFrame(nil, fr) + rd := NewReader(bytes.NewReader(buf)) + + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, fr) + + _, err = rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("MultipleFrames", func(t *testing.T) { + // Frames are returned individually even when they share a message + // ID and have done=false. Reader does no assembly — that's the + // stream's job. + frames := []Frame{ + f(KindMessage, 1, 1, "hello", false, false), + f(KindMessage, 1, 1, " ", false, false), + f(KindMessage, 1, 1, "world", true, false), + f(KindClose, 1, 2, "", true, false), } - } - megaFrames := make([]Frame, 0, 10*1024) - for i := 0; i < 10*1024; i++ { - megaFrames = append(megaFrames, f(KindMessage, 1, strings.Repeat("X", 1024), false, false)) - } - megaFrames = append(megaFrames, f(KindMessage, 1, "", true, false)) - - // 1 more than the maximum frame overhead is the minimum required to overflow - const overFrame = maxFrameOverhead + 1 - - cases := []testCase{ - m(p(KindMessage, 1, false, "hello world"), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, false), - f(KindMessage, 1, "world", true, false)), - - m(p(KindMessage, 1, true, "hello world"), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, true), - f(KindMessage, 1, "world", true, false)), - - m(p(KindClose, 2, false, ""), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, false), - f(KindClose, 2, "", true, false)), - - { - Packets: []Packet{ - p(KindClose, 2, false, ""), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", false, false), - f(KindClose, 2, "", true, false), - f(KindMessage, 1, "1", true, false), - }, - Error: "id monotonicity violation", - }, - - { // a single frame that's too large - Frames: []Frame{f(KindMessage, 1, strings.Repeat("X", 4<<20+overFrame), true, false)}, - Error: "data overflow", - }, - - { // a single frame that's too large with limited size - Frames: []Frame{f(KindMessage, 1, strings.Repeat("X", 1000+overFrame), true, false)}, - Error: "data overflow", - Options: ReaderOptions{MaximumBufferSize: 1000}, - }, - - { // multiple frames that make too large a packet - Frames: megaFrames, - Error: "data overflow", - }, - - { // multiple frames that make too large a packet with limited size - Frames: []Frame{ - f(KindMessage, 1, strings.Repeat("X", 500), false, false), - f(KindMessage, 1, strings.Repeat("X", 400), false, false), - f(KindMessage, 1, strings.Repeat("X", 100), false, false), - f(KindMessage, 1, strings.Repeat("X", overFrame), true, false), - }, - Error: "data overflow", - Options: ReaderOptions{MaximumBufferSize: 1000}, - }, - - { // Control bit is preserved - Packets: []Packet{ - p(KindClose, 2, false, ""), - p(KindMessage, 3, true, "ab"), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", false, false), - f(KindClose, 2, "", true, false), - f(KindMessage, 3, "a", false, true), - f(KindMessage, 3, "b", true, false), - }, - }, - - { // packet kind changes - Frames: []Frame{ - f(KindMessage, 1, "", false, false), - f(KindClose, 1, "", false, false), - }, - Error: "packet kind change", - }, - - { // id monotonicity from id reuse - Packets: []Packet{ - p(KindMessage, 1, false, "1"), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", true, false), - f(KindMessage, 1, "2", true, false), - }, - Error: "id monotonicity violation", - }, - - { // message id zero is not allowed - Frames: []Frame{{ID: ID{Stream: 1, Message: 0}}}, - Error: "id monotonicity violation", - }, - - { // stream id zero is not allowed - Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, - Error: "id monotonicity violation", - }, - } + var buf []byte + for _, fr := range frames { + buf = AppendFrame(buf, fr) + } + + rd := NewReader(bytes.NewReader(buf)) + for _, exp := range frames { + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, exp) + } + + _, err := rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("NoMonotonicity", func(t *testing.T) { + // Reader no longer enforces monotonicity. Frames with decreasing + // IDs should be returned without error. + frames := []Frame{ + f(KindMessage, 1, 5, "a", true, false), + f(KindMessage, 1, 3, "b", true, false), + } - for _, tc := range cases { var buf []byte - for _, fr := range tc.Frames { + for _, fr := range frames { buf = AppendFrame(buf, fr) } - rd := NewReaderWithOptions(bytes.NewReader(buf), tc.Options) - for _, expPkt := range tc.Packets { - pkt, err := rd.ReadPacket() + rd := NewReader(bytes.NewReader(buf)) + for _, exp := range frames { + got, err := rd.ReadFrame() assert.NoError(t, err) - assert.DeepEqual(t, expPkt, pkt) + assert.DeepEqual(t, got, exp) } + }) + + t.Run("BufferOverflow_SingleLargeFrame", func(t *testing.T) { + // 1 more than the maximum frame overhead is the minimum required to overflow. + const overFrame = maxFrameOverhead + 1 + fr := f(KindMessage, 1, 1, strings.Repeat("X", 4<<20+overFrame), true, false) - _, err := rd.ReadPacket() + buf := AppendFrame(nil, fr) + rd := NewReader(bytes.NewReader(buf)) + + _, err := rd.ReadFrame() assert.Error(t, err) - if tc.Error != "" { - assert.That(t, strings.Contains(err.Error(), tc.Error)) - } else { - assert.Equal(t, err, io.EOF) - } - } + assert.That(t, strings.Contains(err.Error(), "data overflow")) + }) + + t.Run("BufferOverflow_CustomLimit", func(t *testing.T) { + const overFrame = maxFrameOverhead + 1 + fr := f(KindMessage, 1, 1, strings.Repeat("X", 1000+overFrame), true, false) + + buf := AppendFrame(nil, fr) + rd := NewReaderWithOptions(bytes.NewReader(buf), ReaderOptions{MaximumBufferSize: 1000}) + + _, err := rd.ReadFrame() + assert.Error(t, err) + assert.That(t, strings.Contains(err.Error(), "data overflow")) + }) + + t.Run("ErrorWithData", func(t *testing.T) { + // If the underlying reader returns data and an error together, + // the frame should still be parsed from the data. + rd := NewReader(readerFunc(func(b []byte) (int, error) { + out := AppendFrame(b[:0:8], Frame{ + Data: []byte("test"), + ID: ID{1, 1}, + Kind: KindMessage, + Done: true, + }) + return len(out), io.EOF + })) + + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, Frame{ + Data: []byte("test"), + ID: ID{1, 1}, + Kind: KindMessage, + Done: true, + }) + + _, err = rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("ErrorNoProgress", func(t *testing.T) { + rd := NewReader(readerFunc(func(b []byte) (int, error) { + return 0, nil + })) + + _, err := rd.ReadFrame() + assert.That(t, errors.Is(err, io.ErrNoProgress)) + }) } -func TestReaderRandomized(t *testing.T) { +func TestReadFrame_Randomized(t *testing.T) { seed := time.Now().UnixNano() t.Log("seed:", seed) rng := rand.New(rand.NewSource(seed)) - // create a function to get a predefined sequence of bytes bid := 0 get := func(n int) []byte { out := make([]byte, n) @@ -195,75 +164,40 @@ func TestReaderRandomized(t *testing.T) { return out } - // construct a random sequence of frames of different sizes - // to attempt to capture any bugs from buffer management + // Build a random sequence of frames with varying sizes. + var frames []Frame var buf []byte mid := uint64(1) done := false for i := 0; i < 1000; i++ { - buf = AppendFrame(buf, Frame{ + data := get(rng.Intn(8192)) + fr := Frame{ ID: ID{Stream: 1, Message: mid}, - Data: get(rng.Intn(8192)), + Data: data, Done: done, - }) + } + frames = append(frames, fr) + buf = AppendFrame(buf, fr) if done { mid++ } - done = rng.Intn(10) == 0 } - // read all of the packets back which should have the - // exact sequence of bytes, so we reset bid to generate - // the sequence again. + // ReadFrame should return each frame individually. bid = 0 r := NewReader(bytes.NewBuffer(buf)) - for i := 1; ; i++ { - pkt, err := r.ReadPacket() - if errors.Is(err, io.EOF) { - break - } + for _, exp := range frames { + got, err := r.ReadFrame() assert.NoError(t, err) - assert.Equal(t, pkt.ID.Message, i) - assert.Equal(t, pkt.Data, get(len(pkt.Data))) + assert.Equal(t, got.ID, exp.ID) + assert.Equal(t, got.Done, exp.Done) + assert.Equal(t, got.Data, get(len(exp.Data))) } } type readerFunc func([]byte) (int, error) func (fn readerFunc) Read(p []byte) (int, error) { return fn(p) } - -func TestReaderErrorWithData(t *testing.T) { - r := NewReader(readerFunc(func(b []byte) (int, error) { - out := AppendFrame(b[:0:8], Frame{ - Data: []byte("test"), - ID: ID{1, 1}, - Kind: KindMessage, - Done: true, - }) - return len(out), io.EOF - })) - - pkt, err := r.ReadPacket() - assert.NoError(t, err) - assert.Equal(t, pkt, Packet{ - Data: []byte("test"), - ID: ID{1, 1}, - Kind: KindMessage, - Control: false, - }) - - _, err = r.ReadPacket() - assert.Equal(t, err, io.EOF) -} - -func TestReaderErrorNoProgress(t *testing.T) { - r := NewReader(readerFunc(func(b []byte) (int, error) { - return 0, nil - })) - - _, err := r.ReadPacket() - assert.That(t, errors.Is(err, io.ErrNoProgress)) -}