diff --git a/drpcconn/conn_test.go b/drpcconn/conn_test.go index 9f74516..e7402b6 100644 --- a/drpcconn/conn_test.go +++ b/drpcconn/conn_test.go @@ -43,9 +43,9 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { wr := drpcwire.NewWriter(ps, 64) rd := drpcwire.NewReader(ps) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // Message - pkt, _ := rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // Message + pkt, _ := rd.ReadFrame() // CloseSend _ = wr.WritePacket(drpcwire.Packet{ Data: []byte("qux"), @@ -54,8 +54,8 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) { }) _ = wr.Flush() - _, _ = rd.ReadPacket() // Close - <-invokeDone // wait for invoke to return + _, _ = rd.ReadFrame() // Close + <-invokeDone // wait for invoke to return // ensure that any later packets are dropped by writing one // before closing the transport. @@ -98,7 +98,7 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { wr := drpcwire.NewWriter(ps, 64) rd := drpcwire.NewReader(ps) - md, err := rd.ReadPacket() // Metadata + md, err := rd.ReadFrame() // Metadata assert.NoError(t, err) assert.Equal(t, md.Kind, drpcwire.KindInvokeMetadata) metadata, err := drpcmetadata.Decode(md.Data) @@ -110,9 +110,9 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { "common-key": "common-value2", }) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // Message - pkt, _ := rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // Message + pkt, _ := rd.ReadFrame() // CloseSend _ = wr.WritePacket(drpcwire.Packet{ Data: []byte("qux"), @@ -121,7 +121,7 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) { }) _ = wr.Flush() - _, _ = rd.ReadPacket() // Close + _, _ = rd.ReadFrame() // Close }) conn := New(pc) @@ -154,7 +154,7 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { ctx.Run(func(ctx context.Context) { rd := drpcwire.NewReader(ps) - md, err := rd.ReadPacket() // Metadata + md, err := rd.ReadFrame() // Metadata assert.NoError(t, err) assert.Equal(t, md.Kind, drpcwire.KindInvokeMetadata) metadata, err := drpcmetadata.Decode(md.Data) @@ -164,8 +164,8 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) { "drpc-key": "drpc-value", }) - _, _ = rd.ReadPacket() // Invoke - _, _ = rd.ReadPacket() // CloseSend + _, _ = rd.ReadFrame() // Invoke + _, _ = rd.ReadFrame() // CloseSend }) conn := New(pc) diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index ae1cd7a..2207ea3 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -72,6 +72,9 @@ type Manager struct { rd *drpcwire.Reader opts Options + lastFrameID drpcwire.ID + lastFrameKind drpcwire.Kind + sem drpcsignal.Chan // held by the active stream sbuf streamBuffer // largest stream id created pkts chan drpcwire.Packet // channel for invoke packets @@ -213,27 +216,15 @@ func (m *Manager) terminate(err error) { // manage reader // -// manageReader is always reading a packet and dispatching it to the appropriate -// stream or queue. It sets the read signal when it exits so that one can wait -// to ensure that no one is reading on the reader. It sets the term signal if -// there is any error reading packets. +// manageReader reads the frame and dispatches them to the appropriate stream or +// queue. It sets the read signal when it exits so that one can wait to ensure +// that no one is reading on the reader. It sets the term signal if there is any +// error reading frames. func (m *Manager) manageReader() { defer m.sigs.read.Set(nil) - var pkt drpcwire.Packet - var err error - var run int - for !m.sigs.term.IsSet() { - // if we have a run of "small" packets, drop the buffer to release - // memory so that a burst of large packets does not cause eternally - // large heap usage. - if run > 10 { - pkt.Data = nil - run = 0 - } - - pkt, err = m.rd.ReadPacketUsing(pkt.Data[:0]) + incomingFrame, err := m.rd.ReadFrame() if err != nil { if isConnectionReset(err) { err = drpc.ClosedError.Wrap(err) @@ -242,36 +233,36 @@ func (m *Manager) manageReader() { return } - if len(pkt.Data) < cap(pkt.Data)/4 { - run++ - } else { - run = 0 - } + m.log("READ", incomingFrame.String) - m.log("READ", pkt.String) + if ok := m.checkStreamMonotonicity(incomingFrame); !ok { + m.terminate(managerClosed.Wrap(drpc.ProtocolError.New("id monotonicity violation"))) + return + } switch curr := m.sbuf.Get(); { - // if the packet is for the current stream, deliver it. - case curr != nil && pkt.ID.Stream == curr.ID(): - if err := curr.HandlePacket(pkt); err != nil { + // If the frame is for the current stream, deliver it. + case curr != nil && incomingFrame.ID.Stream == curr.ID(): + if err := curr.HandleFrame(incomingFrame); err != nil { m.terminate(managerClosed.Wrap(err)) return } - // if an old message has been sent, just ignore it. - case curr != nil && pkt.ID.Stream < curr.ID(): + // If a frame arrives for an old stream, just ignore it. + case curr != nil && incomingFrame.ID.Stream < curr.ID(): - // if any invoke sequence is being sent, close any old unterminated - // stream and forward it to be handled. - case pkt.Kind == drpcwire.KindInvoke || pkt.Kind == drpcwire.KindInvokeMetadata: + // If an invoke sequence is being sent for a new stream, close any + // old unterminated stream and forward it to be handled. + case incomingFrame.Kind == drpcwire.KindInvoke || incomingFrame.Kind == drpcwire.KindInvokeMetadata: if curr != nil && !curr.IsTerminated() { curr.Cancel(context.Canceled) } + pkt := drpcwire.Packet{ID: incomingFrame.ID, Kind: incomingFrame.Kind, Data: incomingFrame.Data} select { case m.pkts <- pkt: // Wait for NewServerStream to finish stream creation (including - // sbuf.Set) before reading the next packet. This guarantees curr + // sbuf.Set) before reading the next frame. This guarantees curr // is set for subsequent non-invoke packets. m.pdone.Recv() @@ -280,18 +271,28 @@ func (m *Manager) manageReader() { } default: - // A non-invoke packet arrived for a stream that doesn't exist yet - // (curr is nil or pkt.ID.Stream > curr.ID). The first packet of a - // new stream must be KindInvoke or KindInvokeMetadata. + // A non-invoke frame arrived for a stream that doesn't exist yet + // (curr is nil or incomingFrame.ID.Stream > curr.ID). The first + // frame of a new stream must be KindInvoke or KindInvokeMetadata. m.terminate(managerClosed.Wrap(drpc.ProtocolError.New( - "first packet of a new stream must be Invoke, got %v (ID:%v)", - pkt.Kind, - pkt.ID))) + "first frame of a new stream must be Invoke, got %v (ID:%v)", + incomingFrame.Kind, + incomingFrame.ID))) return } } } +func (m *Manager) checkStreamMonotonicity(incomingFrame drpcwire.Frame) bool { + ok := incomingFrame.ID.Stream >= m.lastFrameID.Stream + m.lastFrameKind = incomingFrame.Kind + m.lastFrameID = incomingFrame.ID + if incomingFrame.Done { + m.lastFrameID.Message += 1 + } + return ok +} + // // manage streams // diff --git a/drpcmanager/manager_test.go b/drpcmanager/manager_test.go index 58e8320..fd70611 100644 --- a/drpcmanager/manager_test.go +++ b/drpcmanager/manager_test.go @@ -193,6 +193,71 @@ func waitForClosed(t *testing.T, man *Manager) { } } +// +// manageReader tests +// + +// Global frame monotonicity: a frame with an ID lower than the last seen +// frame causes the manager to terminate with a protocol error. +func TestManageReader_GlobalMonotonicity_SameStream(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + // Consume the invoke and drain messages so HandleFrame doesn't block. + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + for { + if _, err := stream.RawRecv(); err != nil { + return + } + } + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 5, "ok", true), + createFrame(drpcwire.KindMessage, 1, 4, "bad", true), + ) + + waitForClosed(t, man) +} + +// Cross-stream monotonicity: after seeing stream 2, a frame for stream 1 +// with a higher message ID is still rejected because {1,x} < {2,y}. +func TestManageReader_GlobalMonotonicity_CrossStream(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + // Consume both invokes so manageReader can proceed. + ctx.Run(func(ctx context.Context) { + _, _, _ = man.NewServerStream(ctx) + _, _, _ = man.NewServerStream(ctx) + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc1", true), + createFrame(drpcwire.KindInvoke, 2, 1, "rpc2", true), + createFrame(drpcwire.KindMessage, 1, 4, "bad", true), + ) + + waitForClosed(t, man) +} + // Invoke replay: after [s1,m1,invoke,done=true], lastFrameID is bumped to // {1,2}. A replayed [s1,m1,invoke] is caught by the monotonicity check. func TestManageReader_InvokeReplayBlocked(t *testing.T) { @@ -218,6 +283,274 @@ func TestManageReader_InvokeReplayBlocked(t *testing.T) { waitForClosed(t, man) } +// Non-done frames don't bump the message ID, so continuation frames with +// the same ID are accepted. +func TestManageReader_ContinuationFramesAccepted(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "hel", false), + createFrame(drpcwire.KindMessage, 1, 2, "lo", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello")) +} + +// Old-stream frames are silently ignored on the client side when the local +// stream ID has advanced past the incoming frame's stream ID. +func TestManageReader_OldStreamFramesIgnored(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + cman := NewWithOptions(cconn, Options{SoftCancel: true}) + defer func() { _ = cman.Close() }() + + // Drain all client writes so nothing blocks, and write server + // responses once we've seen enough data. + ctx.Run(func(ctx context.Context) { + buf := make([]byte, 4096) + for { + _, err := sconn.Read(buf) + if err != nil { + return + } + } + }) + + // Create stream 1 on the client, then cancel it so the client + // advances to stream 2. + subctx, cancel := context.WithCancel(ctx) + _, err := cman.NewClientStream(subctx, "rpc1") + assert.NoError(t, err) + cancel() + <-cman.Unblocked() + + stream2, err := cman.NewClientStream(ctx, "rpc2") + assert.NoError(t, err) + + // Write an old-stream frame (s1) then the real response for s2. + // The s1 frame should be silently ignored by the client manager. + writeFrames(t, sconn, + createFrame(drpcwire.KindMessage, 1, 1, "old", true), + createFrame(drpcwire.KindMessage, 2, 1, "new", true), + ) + + data, err := stream2.RawRecv() + assert.NoError(t, err) + assert.DeepEqual(t, data, []byte("new")) + + _ = stream2.Close() +} + +// The first frame for a new stream must be KindInvoke or KindInvokeMetadata. +// A non-invoke kind causes a protocol error. +func TestManageReader_FirstFrameMustBeInvoke(t *testing.T) { + for _, kind := range []drpcwire.Kind{ + drpcwire.KindMessage, + drpcwire.KindCancel, + drpcwire.KindClose, + drpcwire.KindCloseSend, + drpcwire.KindError, + } { + t.Run(kind.String(), func(t *testing.T) { + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + writeFrames(t, cconn, + createFrame(kind, 1, 1, "", true), + ) + + waitForClosed(t, man) + }) + } +} + +// A valid invoke sequence: Invoke → Message. +// Metadata encoding is covered separately by TestDrpcMetadata. +func TestManageReader_ValidInvokeSequence(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, rpc, err := man.NewServerStream(ctx) + assert.NoError(t, err) + assert.Equal(t, rpc, "myrpc") + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "myrpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "payload", true), + ) + + assert.DeepEqual(t, <-recv, []byte("payload")) +} + +// Multi-frame message delivered through manager to stream: frames are +// assembled by the stream into a single packet. +func TestManageReader_MultiFrameDelivery(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "hel", false), + createFrame(drpcwire.KindMessage, 1, 2, "lo ", false), + createFrame(drpcwire.KindMessage, 1, 2, "world", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello world")) +} + +// When a higher message ID arrives mid-assembly, the partial data is +// discarded and only the new message is delivered. +func TestManageReader_HigherMsgDiscardsInProgress(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "discard", false), + createFrame(drpcwire.KindMessage, 1, 3, "kept", true), + ) + + assert.DeepEqual(t, <-recv, []byte("kept")) +} + +// A continuation frame with a different kind than the first frame of the +// packet causes the manager to terminate with a protocol error. +func TestManageReader_KindChangeWithinPacket(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + for { + if _, err := stream.RawRecv(); err != nil { + return + } + } + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "data", false), + createFrame(drpcwire.KindClose, 1, 2, "", true), + ) + + waitForClosed(t, man) +} + +// Multi-frame assembly works correctly when the message ID is greater than +// the previous message (e.g., on the server side where invoke consumed +// earlier IDs). +func TestManageReader_MultiFrameWithSkippedMessageID(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 3, "hel", false), + createFrame(drpcwire.KindMessage, 1, 3, "lo", true), + ) + + assert.DeepEqual(t, <-recv, []byte("hello")) +} + // A second invoke for the same stream ID is rejected — the stream treats // it as a protocol error, terminating the manager. func TestManageReader_InvokeOnExistingStream(t *testing.T) { @@ -246,6 +579,43 @@ func TestManageReader_InvokeOnExistingStream(t *testing.T) { assert.That(t, drpc.ProtocolError.Has(man.sigs.term.Err())) } +// When a non-invoke frame arrives before the stream is created (e.g., +// NewServerStream hasn't returned yet), manageReader waits for the stream +// and retries. +func TestManageReader_WaitsForStreamCreation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + cconn, sconn := net.Pipe() + defer func() { _ = cconn.Close() }() + defer func() { _ = sconn.Close() }() + + man := New(sconn) + defer func() { _ = man.Close() }() + + // Write invoke + message immediately. The message arrives before + // NewServerStream creates the stream, exercising the default/wait path. + writeFrames(t, cconn, + createFrame(drpcwire.KindInvoke, 1, 1, "rpc", true), + createFrame(drpcwire.KindMessage, 1, 2, "data", true), + ) + + // Small delay to let manageReader process both frames. + time.Sleep(10 * time.Millisecond) + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + stream, _, err := man.NewServerStream(ctx) + assert.NoError(t, err) + + data, err := stream.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.DeepEqual(t, <-recv, []byte("data")) +} + type blockingTransport chan struct{} func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 29ccd63..7fd769e 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -23,7 +23,7 @@ import ( // Options controls configuration settings for a stream. type Options struct { - // SplitSize controls the default size we split packets into frames. + // SplitSize controls the default size we split data packets into frames. SplitSize int // ManualFlush controls if the stream will automatically flush after every @@ -52,6 +52,11 @@ type Stream struct { read inspectMutex flush sync.Once + assembling bool + pktBuf []byte + pktKind drpcwire.Kind + nextMessageID uint64 + id drpcwire.ID wr *drpcwire.Writer pbuf packetBuffer @@ -98,6 +103,8 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O fin: drpcopts.GetStreamFin(&opts.Internal), task: task, + nextMessageID: 1, + id: drpcwire.ID{Stream: sid}, wr: wr.Reset(), } @@ -211,24 +218,65 @@ func (s *Stream) IsFinished() bool { return s.sigs.fin.IsSet() } func (s *Stream) SetManualFlush(mf bool) { s.opts.ManualFlush = mf } // -// packet handler +// frame handler // -// HandlePacket advances the stream state machine by inspecting the packet. It -// returns any major errors that should terminate the transport the stream is -// operating on as well as a boolean indicating if the stream expects more -// packets. -func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { - if pkt.ID.Stream != s.id.Stream { +// HandleFrame processes an incoming frame, assembling multi-frame packets +// and dispatching complete packets to the stream state machine. +func (s *Stream) HandleFrame(fr drpcwire.Frame) (err error) { + if s.sigs.term.IsSet() { return nil } - drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) + if fr.ID.Stream != s.ID() { + return drpc.ProtocolError.New("frame doesn't belong to this stream (fr: %v)", fr.ID) + } - if s.sigs.term.IsSet() { + if fr.ID.Message < s.nextMessageID { + return drpc.ProtocolError.New( + "id monotonicity violation: frame %v has message ID less than expected %v", fr.ID, s.nextMessageID) + } else if fr.ID.Message > s.nextMessageID || !s.assembling { + s.pktBuf = s.pktBuf[:0] + s.assembling = true + s.nextMessageID = fr.ID.Message + } else if fr.Kind != s.pktKind { + return drpc.ProtocolError.New("frame kind change within packet: got %v, expected %v", fr.Kind, s.pktKind) + } + + // TODO(shubham): add buf reuse + s.pktBuf = append(s.pktBuf, fr.Data...) + + s.pktKind = fr.Kind + + if s.opts.MaximumBufferSize > 0 && len(s.pktBuf) > s.opts.MaximumBufferSize { + return drpc.ProtocolError.New("data overflow (len:%d)", len(s.pktBuf)) + } + + if !fr.Done { return nil } + s.assembling = false + s.nextMessageID = fr.ID.Message + 1 + + err = s.handlePacket(drpcwire.Packet{ + ID: fr.ID, + Kind: fr.Kind, + Control: fr.Control, + Data: s.pktBuf, + }) + + // TODO(shubham): add buf reuse + s.pktBuf = nil + return err +} + +// handlePacket advances the stream state machine by inspecting the packet. It +// returns any major errors that should terminate the transport the stream is +// operating on. +func (s *Stream) handlePacket(pkt drpcwire.Packet) (err error) { + drpcopts.GetStreamStats(&s.opts.Internal).AddRead(uint64(len(pkt.Data))) + s.log("HANDLE", pkt.String) if pkt.Kind == drpcwire.KindMessage { @@ -240,7 +288,7 @@ func (s *Stream) HandlePacket(pkt drpcwire.Packet) (err error) { defer s.mu.Unlock() switch pkt.Kind { - case drpcwire.KindInvoke: + case drpcwire.KindInvoke, drpcwire.KindInvokeMetadata: err := drpc.ProtocolError.New("invoke on existing stream") s.terminate(err) return err @@ -375,6 +423,7 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) { // rawWriteLocked does the body of RawWrite assuming the caller is holding the // appropriate locks. +// TODO(shubham): can we merge this with sendPacketLocked? func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) { fr := s.newFrameLocked(kind) n := s.opts.SplitSize diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index 3cf4ca3..a62ab87 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "io" + "strings" "testing" "github.com/zeebo/assert" @@ -18,16 +19,23 @@ import ( "storj.io/drpc/drpcwire" ) +// handleFrame is a helper that sends a single-frame packet to the stream. +// It constructs a frame with the given kind, matching the stream's ID, +// using the provided message ID, done=true. +func handleFrame(st *Stream, kind drpcwire.Kind, mid uint64) error { + return st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: st.ID(), Message: mid}, + Kind: kind, + Done: true, + }) +} + func TestStream_StateTransitions(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() any := errors.New("any sentinel error") - handlePacket := func(st *Stream, kind drpcwire.Kind) error { - return st.HandlePacket(drpcwire.Packet{Kind: kind}) - } - checkErrs := func(t *testing.T, exp interface{}, got error) { t.Helper() @@ -81,32 +89,32 @@ func TestStream_StateTransitions(t *testing.T) { }, { // recv close - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindClose) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindClose, 1) }, Send: &drpc.ClosedError, Recv: io.EOF, }, { // recv error - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindError) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindError, 1) }, Send: io.EOF, Recv: any, }, { // recv closesend - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindCloseSend) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindCloseSend, 1) }, Send: nil, Recv: io.EOF, }, } for _, test := range cases { - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) assert.NoError(t, test.Op(st)) checkErrs(t, test.Send, st.RawWrite(drpcwire.KindMessage, nil)) if test.Recv == nil { - ctx.Run(func(ctx context.Context) { _ = handlePacket(st, drpcwire.KindMessage) }) + ctx.Run(func(ctx context.Context) { _ = handleFrame(st, drpcwire.KindMessage, 2) }) } _, err := st.RawRecv() checkErrs(t, test.Recv, err) @@ -117,10 +125,6 @@ func TestStream_Unblocks(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() - handlePacket := func(st *Stream, kind drpcwire.Kind) error { - return st.HandlePacket(drpcwire.Packet{Kind: kind}) - } - cases := []struct { Op func(st *Stream) error }{ @@ -141,20 +145,20 @@ func TestStream_Unblocks(t *testing.T) { }, { // recv close - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindClose) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindClose, 1) }, }, { // recv error - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindError) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindError, 1) }, }, { // recv closesend - Op: func(st *Stream) error { return handlePacket(st, drpcwire.KindCloseSend) }, + Op: func(st *Stream) error { return handleFrame(st, drpcwire.KindCloseSend, 1) }, }, } for _, test := range cases { - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) ctx.Run(func(ctx context.Context) { _, _ = st.RawRecv() }) assert.NoError(t, test.Op(st)) @@ -200,20 +204,6 @@ func TestStream_ConcurrentCloseCancel(t *testing.T) { assert.That(t, errors.Is(<-errch, context.Canceled)) } -func TestStream_Control(t *testing.T) { - st := New(context.Background(), 0, drpcwire.NewWriter(io.Discard, 0)) - - // N.B. the stream will return nil on any HandlePacket calls after the - // stream has been terminated for any reason, including if an invalid - // packet has been sent. the order of these two assertions is important! - - // an invalid packet is not an error if the control bit is set - assert.NoError(t, st.HandlePacket(drpcwire.Packet{Control: true})) - - // an invalid packet is an error if the control bit it not set - assert.That(t, drpc.InternalError.Has(st.HandlePacket(drpcwire.Packet{}))) -} - func TestStream_CorkUntilFirstRead(t *testing.T) { run := func() { ctx := drpctest.NewTracker(t) @@ -234,10 +224,11 @@ func TestStream_CorkUntilFirstRead(t *testing.T) { errch <- err }) ctx.Run(func(ctx context.Context) { - errch <- st.HandlePacket(drpcwire.Packet{ + errch <- st.HandleFrame(drpcwire.Frame{ Data: []byte("read"), ID: drpcwire.ID{Message: 1}, Kind: drpcwire.KindMessage, + Done: true, }) }) @@ -266,20 +257,24 @@ func TestStream_PacketBufferReuse(t *testing.T) { defer ctx.Close() defer ctx.Wait() - buf := make([]byte, 20) - st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0)) + data := make([]byte, 20) + mid := uint64(1) + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) ctx.Run(func(ctx context.Context) { for !st.IsTerminated() { - err := st.HandlePacket(drpcwire.Packet{ - Data: buf, + err := st.HandleFrame(drpcwire.Frame{ + Data: data, + ID: drpcwire.ID{Stream: 1, Message: mid}, Kind: drpcwire.KindMessage, + Done: true, }) if err != nil { return } - for i := range buf { - buf[i]++ + mid++ + for i := range data { + data[i]++ } } }) @@ -327,3 +322,354 @@ func TestStream_SendCancelBusyDuringBlockedClose(t *testing.T) { assert.NoError(t, err) assert.That(t, busy) } + +// +// HandleFrame tests +// + +// A frame routed to the wrong stream is a protocol error. +func TestHandleFrame_WrongStreamID(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 2, Message: 1}, + Kind: drpcwire.KindMessage, + Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "doesn't belong")) +} + +// A frame with a message ID lower than a previously completed message is rejected. +func TestHandleFrame_MessageMonotonicity(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + // Close packet buffer so KindMessage delivery doesn't block. + st.pbuf.Close(io.EOF) + + // m3 completes, nextMessageID becomes 4. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Done: true, + })) + + // m2 < 4 → error. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { + // On the client side, the first message received will have ID 1. But on the + // server side, invoke is consumed by the manager. The first frame reaching + // the stream could have msg > 1 (e.g., msg=2). nextMessageID=1, so 2 > 1 + // makes this a valid frame. + for _, messageID := range []uint64{1, 2} { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + // Close the packet buffer so KindMessage Put doesn't block. + st.pbuf.Close(io.EOF) + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: messageID}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.NoError(t, err) + } +} + +// When a higher message ID arrives mid-assembly, the in-progress data is +// silently discarded and a new packet begins. +func TestHandleFrame_HigherMsgDiscardsInProgress(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // Start accumulating m1 (done=false doesn't call Put, so no blocking). + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("discard"), Done: false, + })) + + // Launch receiver before sending done frame to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + // m2 arrives, m1 data should be silently discarded. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, Data: []byte("kept"), Done: true, + })) + + // Verify only m2's data was delivered. + assert.DeepEqual(t, <-recv, []byte("kept")) +} + +// Continuation frames (same message ID, mid-assembly) must carry the same +// kind as the first frame. A kind change mid-packet is a protocol error. +func TestHandleFrame_KindChangeWithinPacket(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: false, + })) + + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindError, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "kind change")) +} + +// Multiple continuation frames for the same message accumulate data correctly. +func TestHandleFrame_MultiFrameDataAccumulation(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // Continuation frames (done=false) don't call Put, so no blocking. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("hel"), Done: false, + })) + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("lo "), Done: false, + })) + + // Launch receiver before the final frame to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("world"), Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("hello world")) +} + +// Multi-frame assembly works when the message ID is greater than nextMessageID +// (e.g., on the server side where invoke consumed earlier message IDs). +// Continuation frames must accumulate data, not reset on each frame. +func TestHandleFrame_MultiFrameWithSkippedMessageID(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // msg=3 is greater than nextMessageID=1. Continuation frames for the + // same message must still accumulate correctly. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte("hel"), Done: false, + })) + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte("lo"), Done: false, + })) + + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte(" world"), Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("hello world")) +} + +// Once a message completes (done=true), the same message ID is rejected. +func TestHandleFrame_DonePreventsReplay(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + // Close packet buffer so KindMessage delivery doesn't block. + st.pbuf.Close(io.EOF) + + // m1 completes → nextMessageID becomes 2. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + })) + + // Same message ID again → error. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +// Kind consistency is only enforced within a packet (continuation frames), not +// across messages. A multi-frame KindMessage followed by a KindClose for the +// next message should be accepted without error. +func TestHandleFrame_MultiFrameThenNextMessage(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("ab"), Done: false, + })) + + // Launch receiver before done frame. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("cd"), Done: true, + })) + assert.DeepEqual(t, <-recv, []byte("abcd")) + + // Message 2 with a different kind — should not trigger kind check. + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindClose, Done: true, + })) + + // Close triggers EOF on recv. + ctx.Run(func(ctx context.Context) { + _, err := st.RawRecv() + assert.That(t, errors.Is(err, io.EOF)) + }) + ctx.Wait() +} + +// Invoke and InvokeMetadata frames are rejected on an already-created stream. +func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + err := handleFrame(st, drpcwire.KindInvoke, 1) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "invoke on existing stream")) +} + +func TestHandleFrame_InvokeMetadataOnExistingStream(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + err := handleFrame(st, drpcwire.KindInvokeMetadata, 1) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "invoke on existing stream")) +} + +// Frames arriving after the stream is terminated are silently ignored. +func TestHandleFrame_AfterTerminated(t *testing.T) { + st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) + + // Terminate the stream via cancel. + st.Cancel(context.Canceled) + + // Frames after termination are silently ignored. + err := st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, + }) + assert.NoError(t, err) +} + +// A completed KindMessage frame delivers its data through RawRecv. +func TestHandleFrame_MessageDeliveredViaRecv(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) + + // Launch receiver before sending to avoid Put blocking. + recv := make(chan []byte, 1) + ctx.Run(func(ctx context.Context) { + data, err := st.RawRecv() + assert.NoError(t, err) + recv <- data + }) + + assert.NoError(t, st.HandleFrame(drpcwire.Frame{ + ID: drpcwire.ID{Stream: 1, Message: 1}, + Kind: drpcwire.KindMessage, + Data: []byte("payload"), + Done: true, + })) + + assert.DeepEqual(t, <-recv, []byte("payload")) +} + +// +// Write-side tests +// + +func TestRawWrite_NonMessageSingleFrame(t *testing.T) { + // Non-KindMessage kinds must produce a single frame (n=0 in + // rawWriteLocked means default 64KB, effectively no split for + // small payloads). Verify they produce exactly one frame with Done=true. + kinds := []drpcwire.Kind{ + drpcwire.KindInvoke, + drpcwire.KindError, + drpcwire.KindCancel, + drpcwire.KindClose, + drpcwire.KindCloseSend, + drpcwire.KindInvokeMetadata, + } + + for _, kind := range kinds { + var buf bytes.Buffer + st := New(context.Background(), 1, drpcwire.NewWriter(&buf, 0)) + + assert.NoError(t, st.RawWrite(kind, []byte("data"))) + assert.NoError(t, st.RawFlush()) + var err error + + // Parse all frames from the buffer — should be exactly one. + data := buf.Bytes() + var frames []drpcwire.Frame + for len(data) > 0 { + var fr drpcwire.Frame + var ok bool + data, fr, ok, err = drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + frames = append(frames, fr) + } + assert.Equal(t, len(frames), 1) + assert.That(t, frames[0].Done) + assert.Equal(t, frames[0].Kind, kind) + } +} + +func TestRawWrite_MessageRespectsSplitSize(t *testing.T) { + var buf bytes.Buffer + st := NewWithOptions(context.Background(), 1, + drpcwire.NewWriter(&buf, 0), + Options{SplitSize: 5}, + ) + + // "helloworld" is 10 bytes, split at 5 → 2 frames. + assert.NoError(t, st.RawWrite(drpcwire.KindMessage, []byte("helloworld"))) + assert.NoError(t, st.RawFlush()) + var err error + + data := buf.Bytes() + var frames []drpcwire.Frame + for len(data) > 0 { + var fr drpcwire.Frame + var ok bool + data, fr, ok, err = drpcwire.ParseFrame(data) + assert.NoError(t, err) + assert.That(t, ok) + frames = append(frames, fr) + } + assert.Equal(t, len(frames), 2) + assert.That(t, !frames[0].Done) + assert.That(t, frames[1].Done) + assert.DeepEqual(t, frames[0].Data, []byte("hello")) + assert.DeepEqual(t, frames[1].Data, []byte("world")) +} diff --git a/drpcwire/reader.go b/drpcwire/reader.go index c9ac397..d5ab580 100644 --- a/drpcwire/reader.go +++ b/drpcwire/reader.go @@ -16,13 +16,12 @@ type ReaderOptions struct { MaximumBufferSize int } -// Reader reconstructs packets from frames read from an io.Reader. +// Reader reads frames from an io.Reader. type Reader struct { opts ReaderOptions r io.Reader curr []byte buf []byte - id ID rerr error } @@ -35,12 +34,12 @@ type Reader struct { // 9: maximum varint data length const maxFrameOverhead = 1 + 9 + 9 + 9 -// NewReader constructs a Reader to read Packets from the io.Reader. +// NewReader constructs a Reader to read Frames from the io.Reader. func NewReader(r io.Reader) *Reader { return NewReaderWithOptions(r, ReaderOptions{}) } -// NewReaderWithOptions constructs a Reader to read Packets from +// NewReaderWithOptions constructs a Reader to read Frames from // the io.Reader. It uses the provided options to manage buffering. func NewReaderWithOptions(r io.Reader, opts ReaderOptions) *Reader { if opts.MaximumBufferSize == 0 { @@ -50,10 +49,9 @@ func NewReaderWithOptions(r io.Reader, opts ReaderOptions) *Reader { return &Reader{ opts: opts, r: r, - // Err on the side of a smaller buffer since ReadPacket will lazily + // Err on the side of a smaller buffer since ReadFrame will lazily // grow this buffer. curr: make([]byte, 0, 4096), - id: ID{Stream: 1, Message: 1}, } } @@ -76,29 +74,14 @@ func (r *Reader) read(p []byte) (n int, err error) { return 0, drpc.InternalError.Wrap(io.ErrNoProgress) } -// ReadPacket reads a packet from the io.Reader. It is equivalent to -// calling ReadPacketUsing(nil). -func (r *Reader) ReadPacket() (pkt Packet, err error) { - return r.ReadPacketUsing(nil) -} - -// ReadPacketUsing reads a packet from the io.Reader. IDs read from -// frames must be monotonically increasing. When a new ID is read, the -// old data is discarded. This allows for easier asynchronous interrupts. -// If the amount of data in the Packet becomes too large, an error is -// returned. The returned packet's Data field is constructed by appending -// to the provided buf after it has been resliced to be zero length. -func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { - pkt.Data = buf[:0] - - var fr Frame - var ok bool - +// ReadFrame reads a single frame from the io.Reader. +func (r *Reader) ReadFrame() (fr Frame, err error) { for { + var ok bool r.curr, fr, ok, err = ParseFrame(r.curr) switch { case err != nil: - return Packet{}, drpc.ProtocolError.Wrap(err) + return Frame{}, drpc.ProtocolError.Wrap(err) case !ok: // r.curr doesn't have enough data for a full frame, so prepend @@ -115,62 +98,28 @@ func (r *Reader) ReadPacketUsing(buf []byte) (pkt Packet, err error) { n, err := r.read(r.buf[len(r.buf):cap(r.buf)]) if err != nil { - return Packet{}, err + return Frame{}, err } ncap := uint(len(r.buf) + n) if ncap > uint(cap(r.buf)) { - return Packet{}, drpc.ProtocolError.New("data overflow") + return Frame{}, drpc.ProtocolError.New("data overflow") } r.buf = r.buf[:ncap] if len(r.buf)-maxFrameOverhead > r.opts.MaximumBufferSize { - return Packet{}, drpc.ProtocolError.New("data overflow") + return Frame{}, drpc.ProtocolError.New("data overflow") } r.curr = r.buf continue } - // since we got a packet, signal that we need to restore buf with - // whatever remains in r.curr the next time we don't have a packet. + // since we got a frame, signal that we need to restore buf with + // whatever remains in r.curr the next time we don't have a frame. if len(r.buf) > 0 { r.buf = r.buf[:0] } - - // If any frames are set to control, then the whole packet is - // considered to be control. - pkt.Control = pkt.Control || fr.Control - - switch { - case fr.ID.Less(r.id): - return Packet{}, drpc.ProtocolError.New("id monotonicity violation (fr:%v r:%v)", fr.ID, r.id) - - case r.id != fr.ID || pkt.ID == ID{}: - r.id = fr.ID - - pkt = Packet{ - Data: pkt.Data[:0], - ID: fr.ID, - Kind: fr.Kind, - Control: fr.Control, - } - - case fr.Kind != pkt.Kind: - return Packet{}, drpc.ProtocolError.New("packet kind change (fr:%v pkt:%v)", fr.Kind, pkt.Kind) - } - - pkt.Data = append(pkt.Data, fr.Data...) - - switch { - case len(pkt.Data) > r.opts.MaximumBufferSize: - return Packet{}, drpc.ProtocolError.New("data overflow (len:%v)", len(pkt.Data)) - - case fr.Done: - // increment the message id so that we do not accept any frames - // with the same id. - r.id.Message++ - return pkt, nil - } + return fr, nil } } diff --git a/drpcwire/reader_test.go b/drpcwire/reader_test.go index d57145b..4a551ff 100644 --- a/drpcwire/reader_test.go +++ b/drpcwire/reader_test.go @@ -15,176 +15,145 @@ import ( "github.com/zeebo/assert" ) -func TestReader(t *testing.T) { - type testCase struct { - Packets []Packet - Frames []Frame - Error string - Options ReaderOptions - } - - p := func(kind Kind, id uint64, control bool, data string) Packet { - return Packet{ - Data: []byte(data), - ID: ID{Stream: 1, Message: id}, - Kind: kind, - Control: control, - } - } - - f := func(kind Kind, id uint64, data string, done, control bool) Frame { +func TestReadFrame(t *testing.T) { + f := func(kind Kind, sid, mid uint64, data string, done, control bool) Frame { return Frame{ Data: []byte(data), - ID: ID{Stream: 1, Message: id}, + ID: ID{Stream: sid, Message: mid}, Kind: kind, Done: done, Control: control, } } - m := func(pkt Packet, frames ...Frame) testCase { - return testCase{ - Packets: []Packet{pkt}, - Frames: frames, + t.Run("SingleFrame", func(t *testing.T) { + fr := f(KindMessage, 1, 1, "hello", true, false) + + buf := AppendFrame(nil, fr) + rd := NewReader(bytes.NewReader(buf)) + + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, fr) + + _, err = rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("MultipleFrames", func(t *testing.T) { + // Frames are returned individually even when they share a message + // ID and have done=false. Reader does no assembly — that's the + // stream's job. + frames := []Frame{ + f(KindMessage, 1, 1, "hello", false, false), + f(KindMessage, 1, 1, " ", false, false), + f(KindMessage, 1, 1, "world", true, false), + f(KindClose, 1, 2, "", true, false), } - } - megaFrames := make([]Frame, 0, 10*1024) - for i := 0; i < 10*1024; i++ { - megaFrames = append(megaFrames, f(KindMessage, 1, strings.Repeat("X", 1024), false, false)) - } - megaFrames = append(megaFrames, f(KindMessage, 1, "", true, false)) - - // 1 more than the maximum frame overhead is the minimum required to overflow - const overFrame = maxFrameOverhead + 1 - - cases := []testCase{ - m(p(KindMessage, 1, false, "hello world"), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, false), - f(KindMessage, 1, "world", true, false)), - - m(p(KindMessage, 1, true, "hello world"), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, true), - f(KindMessage, 1, "world", true, false)), - - m(p(KindClose, 2, false, ""), - f(KindMessage, 1, "hello", false, false), - f(KindMessage, 1, " ", false, false), - f(KindClose, 2, "", true, false)), - - { - Packets: []Packet{ - p(KindClose, 2, false, ""), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", false, false), - f(KindClose, 2, "", true, false), - f(KindMessage, 1, "1", true, false), - }, - Error: "id monotonicity violation", - }, - - { // a single frame that's too large - Frames: []Frame{f(KindMessage, 1, strings.Repeat("X", 4<<20+overFrame), true, false)}, - Error: "data overflow", - }, - - { // a single frame that's too large with limited size - Frames: []Frame{f(KindMessage, 1, strings.Repeat("X", 1000+overFrame), true, false)}, - Error: "data overflow", - Options: ReaderOptions{MaximumBufferSize: 1000}, - }, - - { // multiple frames that make too large a packet - Frames: megaFrames, - Error: "data overflow", - }, - - { // multiple frames that make too large a packet with limited size - Frames: []Frame{ - f(KindMessage, 1, strings.Repeat("X", 500), false, false), - f(KindMessage, 1, strings.Repeat("X", 400), false, false), - f(KindMessage, 1, strings.Repeat("X", 100), false, false), - f(KindMessage, 1, strings.Repeat("X", overFrame), true, false), - }, - Error: "data overflow", - Options: ReaderOptions{MaximumBufferSize: 1000}, - }, - - { // Control bit is preserved - Packets: []Packet{ - p(KindClose, 2, false, ""), - p(KindMessage, 3, true, "ab"), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", false, false), - f(KindClose, 2, "", true, false), - f(KindMessage, 3, "a", false, true), - f(KindMessage, 3, "b", true, false), - }, - }, - - { // packet kind changes - Frames: []Frame{ - f(KindMessage, 1, "", false, false), - f(KindClose, 1, "", false, false), - }, - Error: "packet kind change", - }, - - { // id monotonicity from id reuse - Packets: []Packet{ - p(KindMessage, 1, false, "1"), - }, - Frames: []Frame{ - f(KindMessage, 1, "1", true, false), - f(KindMessage, 1, "2", true, false), - }, - Error: "id monotonicity violation", - }, - - { // message id zero is not allowed - Frames: []Frame{{ID: ID{Stream: 1, Message: 0}}}, - Error: "id monotonicity violation", - }, - - { // stream id zero is not allowed - Frames: []Frame{{ID: ID{Stream: 0, Message: 1}}}, - Error: "id monotonicity violation", - }, - } + var buf []byte + for _, fr := range frames { + buf = AppendFrame(buf, fr) + } + + rd := NewReader(bytes.NewReader(buf)) + for _, exp := range frames { + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, exp) + } + + _, err := rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("NoMonotonicity", func(t *testing.T) { + // Reader no longer enforces monotonicity. Frames with decreasing + // IDs should be returned without error. + frames := []Frame{ + f(KindMessage, 1, 5, "a", true, false), + f(KindMessage, 1, 3, "b", true, false), + } - for _, tc := range cases { var buf []byte - for _, fr := range tc.Frames { + for _, fr := range frames { buf = AppendFrame(buf, fr) } - rd := NewReaderWithOptions(bytes.NewReader(buf), tc.Options) - for _, expPkt := range tc.Packets { - pkt, err := rd.ReadPacket() + rd := NewReader(bytes.NewReader(buf)) + for _, exp := range frames { + got, err := rd.ReadFrame() assert.NoError(t, err) - assert.DeepEqual(t, expPkt, pkt) + assert.DeepEqual(t, got, exp) } + }) + + t.Run("BufferOverflow_SingleLargeFrame", func(t *testing.T) { + // 1 more than the maximum frame overhead is the minimum required to overflow. + const overFrame = maxFrameOverhead + 1 + fr := f(KindMessage, 1, 1, strings.Repeat("X", 4<<20+overFrame), true, false) - _, err := rd.ReadPacket() + buf := AppendFrame(nil, fr) + rd := NewReader(bytes.NewReader(buf)) + + _, err := rd.ReadFrame() assert.Error(t, err) - if tc.Error != "" { - assert.That(t, strings.Contains(err.Error(), tc.Error)) - } else { - assert.Equal(t, err, io.EOF) - } - } + assert.That(t, strings.Contains(err.Error(), "data overflow")) + }) + + t.Run("BufferOverflow_CustomLimit", func(t *testing.T) { + const overFrame = maxFrameOverhead + 1 + fr := f(KindMessage, 1, 1, strings.Repeat("X", 1000+overFrame), true, false) + + buf := AppendFrame(nil, fr) + rd := NewReaderWithOptions(bytes.NewReader(buf), ReaderOptions{MaximumBufferSize: 1000}) + + _, err := rd.ReadFrame() + assert.Error(t, err) + assert.That(t, strings.Contains(err.Error(), "data overflow")) + }) + + t.Run("ErrorWithData", func(t *testing.T) { + // If the underlying reader returns data and an error together, + // the frame should still be parsed from the data. + rd := NewReader(readerFunc(func(b []byte) (int, error) { + out := AppendFrame(b[:0:8], Frame{ + Data: []byte("test"), + ID: ID{1, 1}, + Kind: KindMessage, + Done: true, + }) + return len(out), io.EOF + })) + + got, err := rd.ReadFrame() + assert.NoError(t, err) + assert.DeepEqual(t, got, Frame{ + Data: []byte("test"), + ID: ID{1, 1}, + Kind: KindMessage, + Done: true, + }) + + _, err = rd.ReadFrame() + assert.That(t, errors.Is(err, io.EOF)) + }) + + t.Run("ErrorNoProgress", func(t *testing.T) { + rd := NewReader(readerFunc(func(b []byte) (int, error) { + return 0, nil + })) + + _, err := rd.ReadFrame() + assert.That(t, errors.Is(err, io.ErrNoProgress)) + }) } -func TestReaderRandomized(t *testing.T) { +func TestReadFrame_Randomized(t *testing.T) { seed := time.Now().UnixNano() t.Log("seed:", seed) rng := rand.New(rand.NewSource(seed)) - // create a function to get a predefined sequence of bytes bid := 0 get := func(n int) []byte { out := make([]byte, n) @@ -195,75 +164,40 @@ func TestReaderRandomized(t *testing.T) { return out } - // construct a random sequence of frames of different sizes - // to attempt to capture any bugs from buffer management + // Build a random sequence of frames with varying sizes. + var frames []Frame var buf []byte mid := uint64(1) done := false for i := 0; i < 1000; i++ { - buf = AppendFrame(buf, Frame{ + data := get(rng.Intn(8192)) + fr := Frame{ ID: ID{Stream: 1, Message: mid}, - Data: get(rng.Intn(8192)), + Data: data, Done: done, - }) + } + frames = append(frames, fr) + buf = AppendFrame(buf, fr) if done { mid++ } - done = rng.Intn(10) == 0 } - // read all of the packets back which should have the - // exact sequence of bytes, so we reset bid to generate - // the sequence again. + // ReadFrame should return each frame individually. bid = 0 r := NewReader(bytes.NewBuffer(buf)) - for i := 1; ; i++ { - pkt, err := r.ReadPacket() - if errors.Is(err, io.EOF) { - break - } + for _, exp := range frames { + got, err := r.ReadFrame() assert.NoError(t, err) - assert.Equal(t, pkt.ID.Message, i) - assert.Equal(t, pkt.Data, get(len(pkt.Data))) + assert.Equal(t, got.ID, exp.ID) + assert.Equal(t, got.Done, exp.Done) + assert.Equal(t, got.Data, get(len(exp.Data))) } } type readerFunc func([]byte) (int, error) func (fn readerFunc) Read(p []byte) (int, error) { return fn(p) } - -func TestReaderErrorWithData(t *testing.T) { - r := NewReader(readerFunc(func(b []byte) (int, error) { - out := AppendFrame(b[:0:8], Frame{ - Data: []byte("test"), - ID: ID{1, 1}, - Kind: KindMessage, - Done: true, - }) - return len(out), io.EOF - })) - - pkt, err := r.ReadPacket() - assert.NoError(t, err) - assert.Equal(t, pkt, Packet{ - Data: []byte("test"), - ID: ID{1, 1}, - Kind: KindMessage, - Control: false, - }) - - _, err = r.ReadPacket() - assert.Equal(t, err, io.EOF) -} - -func TestReaderErrorNoProgress(t *testing.T) { - r := NewReader(readerFunc(func(b []byte) (int, error) { - return 0, nil - })) - - _, err := r.ReadPacket() - assert.That(t, errors.Is(err, io.ErrNoProgress)) -}