diff --git a/drpcmanager/manager.go b/drpcmanager/manager.go index 2207ea3..2301a3d 100644 --- a/drpcmanager/manager.go +++ b/drpcmanager/manager.go @@ -75,12 +75,17 @@ type Manager struct { lastFrameID drpcwire.ID lastFrameKind drpcwire.Kind - sem drpcsignal.Chan // held by the active stream - sbuf streamBuffer // largest stream id created - pkts chan drpcwire.Packet // channel for invoke packets - pdone drpcsignal.Chan // signals when a packets buffers can be reused - sfin chan struct{} // shared signal for stream finished - streams chan streamInfo // channel to signal that a stream should start + sem drpcsignal.Chan // held by the active stream + sbuf streamBuffer // largest stream id created + sfin chan struct{} // shared signal for stream finished + streams chan streamInfo // channel to signal that a stream should start + + pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream + invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream + + // Below fields are owned by the manageReader goroutine, used in handleInvokeFrame. + metadata map[string]string // accumulated invoke metadata + pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets sigs struct { term drpcsignal.Signal // set when the manager should start terminating @@ -90,6 +95,14 @@ type Manager struct { } } +// invokeInfo carries the assembled invoke data from manageReader to +// NewServerStream. It is reused across invocations; call Reset between uses. +type invokeInfo struct { + sid uint64 + metadata map[string]string + data []byte // RPC name bytes from the KindInvoke packet +} + type streamInfo struct { ctx context.Context stream *drpcstream.Stream @@ -109,7 +122,8 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager { rd: drpcwire.NewReaderWithOptions(tr, opts.Reader), opts: opts, - pkts: make(chan drpcwire.Packet), + invokes: make(chan invokeInfo), + sfin: make(chan struct{}, 1), streams: make(chan streamInfo), } @@ -120,10 +134,12 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager { // this semaphore controls the number of concurrent streams. it MUST be 1. m.sem.Make(1) - // a buffer of size 1 allows the consumer of the packet to signal it is done - // without having to coordinate with the sender of the packet. + // a buffer of size 1 allows NewServerStream to signal it is done creating a + // new server stream without having to coordinate with manageReader. m.pdone.Make(1) + m.pa = drpcwire.NewPacketAssembler() + // set the internal stream options drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) drpcopts.SetStreamFin(&m.opts.Stream.Internal, m.sfin) @@ -257,16 +273,8 @@ func (m *Manager) manageReader() { if curr != nil && !curr.IsTerminated() { curr.Cancel(context.Canceled) } - - pkt := drpcwire.Packet{ID: incomingFrame.ID, Kind: incomingFrame.Kind, Data: incomingFrame.Data} - select { - case m.pkts <- pkt: - // Wait for NewServerStream to finish stream creation (including - // sbuf.Set) before reading the next frame. This guarantees curr - // is set for subsequent non-invoke packets. - m.pdone.Recv() - - case <-m.sigs.term.Signal(): + if err := m.handleInvokeFrame(incomingFrame); err != nil { + m.terminate(managerClosed.Wrap(err)) return } @@ -293,6 +301,43 @@ func (m *Manager) checkStreamMonotonicity(incomingFrame drpcwire.Frame) bool { return ok } +// handleInvokeFrame assembles invoke/metadata frames into complete packets and +// forwards the finished invoke info to NewServerStream via m.newServerStreamInfo. +// Metadata packets are accumulated; the invoke packet triggers the send. +func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { + pkt, packetReady, err := m.pa.AppendFrame(fr) + if err != nil { + return err + } + if !packetReady { + return nil + } + + // Metadata arrives before invoke; accumulate it and wait for the invoke. + if pkt.Kind == drpcwire.KindInvokeMetadata { + meta, err := drpcmetadata.Decode(pkt.Data) + if err != nil { + return err + } + m.metadata = meta + return nil + } + + // Invoke packet completes the sequence. Send to NewServerStream. + select { + case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}: + // Wait for NewServerStream to finish stream creation (including + // sbuf.Set) before reading the next frame. This guarantees curr + // is set for subsequent non-invoke packets. + m.pdone.Recv() + + m.pa.Reset() + m.metadata = nil + case <-m.sigs.term.Signal(): + } + return nil +} + // // manage streams // @@ -446,8 +491,6 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea } }() - var meta map[string]string - var metaID uint64 var timeoutCh <-chan time.Time // set up the timeout on the context if necessary. @@ -457,61 +500,40 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea timeoutCh = timer.C } - for { - select { - case <-timeoutCh: - return nil, "", context.DeadlineExceeded + select { + case <-timeoutCh: + return nil, "", context.DeadlineExceeded - case <-ctx.Done(): - return nil, "", ctx.Err() + case <-ctx.Done(): + return nil, "", ctx.Err() - case <-m.sigs.term.Signal(): - return nil, "", m.sigs.term.Err() - - case pkt := <-m.pkts: - switch pkt.Kind { - // keep track of any metadata being sent before an invoke so that we - // can include it if the stream id matches the eventual invoke. - case drpcwire.KindInvokeMetadata: - meta, err = drpcmetadata.Decode(pkt.Data) - m.pdone.Send() - - if err != nil { - return nil, "", err - } - metaID = pkt.ID.Stream - - case drpcwire.KindInvoke: - rpc = string(pkt.Data) - - if metaID == pkt.ID.Stream { - if m.opts.GRPCMetadataCompatMode { - // Populate incoming metadata as grpc metadata in the - // context. This is a short-term fix that will enable us - // to send and receive grpc metadata when DRPC is enabled, - // without any changes in the calling code. - grpcMeta := make(map[string][]string, len(meta)) - for k, v := range meta { - grpcMeta[k] = []string{v} - } - ctx = grpcmetadata.NewIncomingContext(ctx, grpcMeta) - } else { - // Add metadata to the incoming context. - ctx = drpcmetadata.NewIncomingContext(ctx, meta) - } + case <-m.sigs.term.Signal(): + return nil, "", m.sigs.term.Err() + + case pkt := <-m.invokes: + rpc = string(pkt.data) + if pkt.metadata != nil { + if m.opts.GRPCMetadataCompatMode { + // Populate incoming metadata as grpc metadata in the + // context. This is a short-term fix that will enable us + // to send and receive grpc metadata when DRPC is enabled, + // without any changes in the calling code. + grpcMeta := make(map[string][]string, len(pkt.metadata)) + for k, v := range pkt.metadata { + grpcMeta[k] = []string{v} } - stream, err := m.newStream(ctx, pkt.ID.Stream, drpc.StreamKindServer, rpc) - // Signal pdone only after stream registration so that - // manageReader sees the new stream via sbuf.Get() when it reads - // the next frame. - m.pdone.Send() - return stream, rpc, err - - default: - // this should never happen, but defensive. - m.pdone.Send() + ctx = grpcmetadata.NewIncomingContext(ctx, grpcMeta) + } else { + // Add metadata to the incoming context. + ctx = drpcmetadata.NewIncomingContext(ctx, pkt.metadata) } } + stream, err := m.newStream(ctx, pkt.sid, drpc.StreamKindServer, rpc) + // Signal pdone only after stream registration so that + // manageReader sees the new stream via sbuf.Get() when it reads + // the next frame. + m.pdone.Send() + return stream, rpc, err } } diff --git a/drpcstream/stream.go b/drpcstream/stream.go index 7fd769e..f275f6d 100644 --- a/drpcstream/stream.go +++ b/drpcstream/stream.go @@ -52,10 +52,7 @@ type Stream struct { read inspectMutex flush sync.Once - assembling bool - pktBuf []byte - pktKind drpcwire.Kind - nextMessageID uint64 + pa drpcwire.PacketAssembler id drpcwire.ID wr *drpcwire.Writer @@ -94,6 +91,9 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O } } + pa := drpcwire.NewPacketAssembler() + pa.SetStreamID(sid) + s := &Stream{ ctx: streamCtx{ Context: ctx, @@ -103,7 +103,7 @@ func NewWithOptions(ctx context.Context, sid uint64, wr *drpcwire.Writer, opts O fin: drpcopts.GetStreamFin(&opts.Internal), task: task, - nextMessageID: 1, + pa: pa, id: drpcwire.ID{Stream: sid}, wr: wr.Reset(), @@ -228,47 +228,14 @@ func (s *Stream) HandleFrame(fr drpcwire.Frame) (err error) { return nil } - if fr.ID.Stream != s.ID() { - return drpc.ProtocolError.New("frame doesn't belong to this stream (fr: %v)", fr.ID) - } - - if fr.ID.Message < s.nextMessageID { - return drpc.ProtocolError.New( - "id monotonicity violation: frame %v has message ID less than expected %v", fr.ID, s.nextMessageID) - } else if fr.ID.Message > s.nextMessageID || !s.assembling { - s.pktBuf = s.pktBuf[:0] - s.assembling = true - s.nextMessageID = fr.ID.Message - } else if fr.Kind != s.pktKind { - return drpc.ProtocolError.New("frame kind change within packet: got %v, expected %v", fr.Kind, s.pktKind) - } - - // TODO(shubham): add buf reuse - s.pktBuf = append(s.pktBuf, fr.Data...) - - s.pktKind = fr.Kind - - if s.opts.MaximumBufferSize > 0 && len(s.pktBuf) > s.opts.MaximumBufferSize { - return drpc.ProtocolError.New("data overflow (len:%d)", len(s.pktBuf)) + packet, packetReady, err := s.pa.AppendFrame(fr) + if err != nil { + return err } - - if !fr.Done { + if !packetReady { return nil } - - s.assembling = false - s.nextMessageID = fr.ID.Message + 1 - - err = s.handlePacket(drpcwire.Packet{ - ID: fr.ID, - Kind: fr.Kind, - Control: fr.Control, - Data: s.pktBuf, - }) - - // TODO(shubham): add buf reuse - s.pktBuf = nil - return err + return s.handlePacket(packet) } // handlePacket advances the stream state machine by inspecting the packet. It diff --git a/drpcstream/stream_test.go b/drpcstream/stream_test.go index a62ab87..62e4944 100644 --- a/drpcstream/stream_test.go +++ b/drpcstream/stream_test.go @@ -327,40 +327,6 @@ func TestStream_SendCancelBusyDuringBlockedClose(t *testing.T) { // HandleFrame tests // -// A frame routed to the wrong stream is a protocol error. -func TestHandleFrame_WrongStreamID(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) - - err := st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 2, Message: 1}, - Kind: drpcwire.KindMessage, - Done: true, - }) - assert.Error(t, err) - assert.That(t, drpc.ProtocolError.Has(err)) - assert.That(t, strings.Contains(err.Error(), "doesn't belong")) -} - -// A frame with a message ID lower than a previously completed message is rejected. -func TestHandleFrame_MessageMonotonicity(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) - // Close packet buffer so KindMessage delivery doesn't block. - st.pbuf.Close(io.EOF) - - // m3 completes, nextMessageID becomes 4. - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Done: true, - })) - - // m2 < 4 → error. - err := st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, Done: true, - }) - assert.Error(t, err) - assert.That(t, drpc.ProtocolError.Has(err)) - assert.That(t, strings.Contains(err.Error(), "monotonicity")) -} - func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { // On the client side, the first message received will have ID 1. But on the // server side, invoke is consumed by the manager. The first frame reaching @@ -377,174 +343,6 @@ func TestHandleFrame_FirstFrameOnFreshStream(t *testing.T) { } } -// When a higher message ID arrives mid-assembly, the in-progress data is -// silently discarded and a new packet begins. -func TestHandleFrame_HigherMsgDiscardsInProgress(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) - - // Start accumulating m1 (done=false doesn't call Put, so no blocking). - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("discard"), Done: false, - })) - - // Launch receiver before sending done frame to avoid Put blocking. - recv := make(chan []byte, 1) - ctx.Run(func(ctx context.Context) { - data, err := st.RawRecv() - assert.NoError(t, err) - recv <- data - }) - - // m2 arrives, m1 data should be silently discarded. - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindMessage, Data: []byte("kept"), Done: true, - })) - - // Verify only m2's data was delivered. - assert.DeepEqual(t, <-recv, []byte("kept")) -} - -// Continuation frames (same message ID, mid-assembly) must carry the same -// kind as the first frame. A kind change mid-packet is a protocol error. -func TestHandleFrame_KindChangeWithinPacket(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) - - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: false, - })) - - err := st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindError, Done: true, - }) - assert.Error(t, err) - assert.That(t, drpc.ProtocolError.Has(err)) - assert.That(t, strings.Contains(err.Error(), "kind change")) -} - -// Multiple continuation frames for the same message accumulate data correctly. -func TestHandleFrame_MultiFrameDataAccumulation(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) - - // Continuation frames (done=false) don't call Put, so no blocking. - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("hel"), Done: false, - })) - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("lo "), Done: false, - })) - - // Launch receiver before the final frame to avoid Put blocking. - recv := make(chan []byte, 1) - ctx.Run(func(ctx context.Context) { - data, err := st.RawRecv() - assert.NoError(t, err) - recv <- data - }) - - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("world"), Done: true, - })) - - assert.DeepEqual(t, <-recv, []byte("hello world")) -} - -// Multi-frame assembly works when the message ID is greater than nextMessageID -// (e.g., on the server side where invoke consumed earlier message IDs). -// Continuation frames must accumulate data, not reset on each frame. -func TestHandleFrame_MultiFrameWithSkippedMessageID(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) - - // msg=3 is greater than nextMessageID=1. Continuation frames for the - // same message must still accumulate correctly. - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte("hel"), Done: false, - })) - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte("lo"), Done: false, - })) - - recv := make(chan []byte, 1) - ctx.Run(func(ctx context.Context) { - data, err := st.RawRecv() - assert.NoError(t, err) - recv <- data - }) - - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 3}, Kind: drpcwire.KindMessage, Data: []byte(" world"), Done: true, - })) - - assert.DeepEqual(t, <-recv, []byte("hello world")) -} - -// Once a message completes (done=true), the same message ID is rejected. -func TestHandleFrame_DonePreventsReplay(t *testing.T) { - st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) - // Close packet buffer so KindMessage delivery doesn't block. - st.pbuf.Close(io.EOF) - - // m1 completes → nextMessageID becomes 2. - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, - })) - - // Same message ID again → error. - err := st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Done: true, - }) - assert.Error(t, err) - assert.That(t, drpc.ProtocolError.Has(err)) - assert.That(t, strings.Contains(err.Error(), "monotonicity")) -} - -// Kind consistency is only enforced within a packet (continuation frames), not -// across messages. A multi-frame KindMessage followed by a KindClose for the -// next message should be accepted without error. -func TestHandleFrame_MultiFrameThenNextMessage(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - st := New(ctx, 1, drpcwire.NewWriter(io.Discard, 0)) - - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("ab"), Done: false, - })) - - // Launch receiver before done frame. - recv := make(chan []byte, 1) - ctx.Run(func(ctx context.Context) { - data, err := st.RawRecv() - assert.NoError(t, err) - recv <- data - }) - - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 1}, Kind: drpcwire.KindMessage, Data: []byte("cd"), Done: true, - })) - assert.DeepEqual(t, <-recv, []byte("abcd")) - - // Message 2 with a different kind — should not trigger kind check. - assert.NoError(t, st.HandleFrame(drpcwire.Frame{ - ID: drpcwire.ID{Stream: 1, Message: 2}, Kind: drpcwire.KindClose, Done: true, - })) - - // Close triggers EOF on recv. - ctx.Run(func(ctx context.Context) { - _, err := st.RawRecv() - assert.That(t, errors.Is(err, io.EOF)) - }) - ctx.Wait() -} - // Invoke and InvokeMetadata frames are rejected on an already-created stream. func TestHandleFrame_InvokeOnExistingStream(t *testing.T) { st := New(context.Background(), 1, drpcwire.NewWriter(io.Discard, 0)) diff --git a/drpcwire/packet_assembler.go b/drpcwire/packet_assembler.go new file mode 100644 index 0000000..2cf7fae --- /dev/null +++ b/drpcwire/packet_assembler.go @@ -0,0 +1,89 @@ +package drpcwire + +import ( + "storj.io/drpc" +) + +// PacketAssembler assembles frames into complete packets, enforcing wire +// protocol invariants: +// - All frames must belong to the same stream ID (set explicitly via +// SetStreamID, or inferred from the first frame). +// - Message IDs must be monotonically increasing. +// - Frame kind must not change within a single packet (multi-frame). +// +// It is not safe for concurrent use. +type PacketAssembler struct { + pk Packet + assembling bool + streamInitialized bool +} + +// NewPacketAssembler returns a new PacketAssembler ready to assemble frames. +func NewPacketAssembler() PacketAssembler { + return PacketAssembler{ + pk: Packet{ + ID: ID{Stream: 0, Message: 1}, + }, + } +} + +// SetStreamID sets the expected stream ID. Frames for a different stream will +// be rejected. If not called, the stream ID is inferred from the first frame. +func (pa *PacketAssembler) SetStreamID(streamID uint64) { + pa.pk.ID.Stream = streamID + pa.streamInitialized = true +} + +// Reset clears all assembly state, preparing the assembler for a new stream. +func (pa *PacketAssembler) Reset() { + pa.pk = Packet{ + ID: ID{Stream: 0, Message: 1}, + } + pa.assembling = false + pa.streamInitialized = false +} + +// AppendFrame adds a frame to the in-progress packet. It returns the completed +// packet and true when a frame with Done=true is received. It returns false +// when more frames are needed to complete the packet. +func (pa *PacketAssembler) AppendFrame(fr Frame) (packet Packet, packetReady bool, err error) { + // Enforce stream ID consistency: infer from first frame or reject mismatches. + if !pa.streamInitialized { + pa.pk.ID.Stream = fr.ID.Stream + pa.streamInitialized = true + } else if fr.ID.Stream != pa.pk.ID.Stream { + return Packet{}, false, drpc.ProtocolError.New( + "frame stream mismatch: got stream %d, expected %d", fr.ID.Stream, pa.pk.ID.Stream) + } + + if fr.ID.Message < pa.pk.ID.Message { + return Packet{}, false, drpc.ProtocolError.New( + "message id monotonicity violation: got %v, expected >= %v", fr.ID.Message, pa.pk.ID.Message) + } else if fr.ID.Message > pa.pk.ID.Message || !pa.assembling { + // New message: reset the buffer and start assembling. + pa.pk.Data = pa.pk.Data[:0] + pa.assembling = true + pa.pk.ID.Message = fr.ID.Message + } else if fr.Kind != pa.pk.Kind { + return Packet{}, false, drpc.ProtocolError.New( + "frame kind changed mid-packet: got %v, expected %v", fr.Kind, pa.pk.Kind) + } + + // TODO(shubham): add buf reuse + pa.pk.Data = append(pa.pk.Data, fr.Data...) + pa.pk.Kind = fr.Kind + pa.pk.Control = fr.Control + + if !fr.Done { + return Packet{}, false, nil + } + + packet = pa.pk + + pa.assembling = false + pa.pk.ID.Message = fr.ID.Message + 1 + // Reuse the backing array: the caller must consume packet.Data before the + // next AppendFrame call, as it will be overwritten. + pa.pk.Data = pa.pk.Data[:0] + return packet, true, nil +} diff --git a/drpcwire/packet_assembler_test.go b/drpcwire/packet_assembler_test.go new file mode 100644 index 0000000..41cf70d --- /dev/null +++ b/drpcwire/packet_assembler_test.go @@ -0,0 +1,235 @@ +// Copyright (C) 2026 Cockroach Labs. +// See LICENSE for copying information. + +package drpcwire + +import ( + "strings" + "testing" + + "github.com/zeebo/assert" + + "storj.io/drpc" +) + +func TestPacketAssembler_WrongStreamID(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 2, Message: 1}, + Kind: KindMessage, + Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "frame stream mismatch")) +} + +func TestPacketAssembler_StreamIDInferredFromFirstFrame(t *testing.T) { + pa := NewPacketAssembler() + + // First frame sets the stream ID implicitly. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 5, Message: 1}, + Kind: KindMessage, + Done: true, + }) + assert.NoError(t, err) + + // Second frame for a different stream is rejected. + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 6, Message: 2}, + Kind: KindMessage, + Done: true, + }) + assert.Error(t, err) + assert.That(t, strings.Contains(err.Error(), "frame stream mismatch")) +} + +// A frame with a message ID lower than a previously completed message is rejected. +func TestPacketAssembler_MessageMonotonicity(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // m3 completes, next expected becomes 4. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Done: true, + }) + assert.NoError(t, err) + + // m2 < 4 → error. + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +// When a higher message ID arrives mid-assembly, the in-progress data is +// silently discarded and a new packet begins. +func TestPacketAssembler_HigherMsgDiscardsInProgress(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // Start accumulating m1. + _, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("discard"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + // m2 arrives, m1 data should be silently discarded. + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 2}, Kind: KindMessage, Data: []byte("kept"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("kept")) +} + +// Continuation frames (same message ID, mid-assembly) must carry the same +// kind as the first frame. A kind change mid-packet is a protocol error. +func TestPacketAssembler_KindChangeWithinPacket(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: false, + }) + assert.NoError(t, err) + + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindError, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "kind change")) +} + +// Multiple continuation frames for the same message accumulate data correctly. +func TestPacketAssembler_MultiFrameDataAccumulation(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + _, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("hel"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + _, ready, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("lo "), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("world"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("hello world")) +} + +// Multi-frame assembly works when the message ID is greater than the initial +// expected ID (e.g., on the server side where invoke consumed earlier message +// IDs). Continuation frames must accumulate data, not reset on each frame. +func TestPacketAssembler_MultiFrameWithSkippedMessageID(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // msg=3 is greater than initial expected message ID=1. + _, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Data: []byte("hel"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + _, ready, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Data: []byte("lo"), Done: false, + }) + assert.NoError(t, err) + assert.That(t, !ready) + + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 3}, Kind: KindMessage, Data: []byte(" world"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("hello world")) +} + +// Once a message completes (done=true), the same message ID is rejected. +func TestPacketAssembler_DonePreventsReplay(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // m1 completes → next expected becomes 2. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true, + }) + assert.NoError(t, err) + + // Same message ID again → error. + _, _, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true, + }) + assert.Error(t, err) + assert.That(t, drpc.ProtocolError.Has(err)) + assert.That(t, strings.Contains(err.Error(), "monotonicity")) +} + +// Kind consistency is only enforced within a packet (continuation frames), not +// across messages. A KindMessage followed by a KindClose for the next message +// should be accepted without error. +func TestPacketAssembler_KindChangeAcrossMessages(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // Multi-frame message 1 with KindMessage. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("ab"), Done: false, + }) + assert.NoError(t, err) + + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Data: []byte("cd"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("abcd")) + + // Message 2 with a different kind — should not trigger kind check. + pkt, ready, err = pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 2}, Kind: KindClose, Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.Equal(t, pkt.Kind, KindClose) +} + +// Reset clears all state so the assembler can be reused for a new stream. +func TestPacketAssembler_Reset(t *testing.T) { + pa := NewPacketAssembler() + pa.SetStreamID(1) + + // Complete a packet on stream 1. + _, _, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 1, Message: 1}, Kind: KindMessage, Done: true, + }) + assert.NoError(t, err) + + // After reset, stream ID is cleared and must be re-inferred. + pa.Reset() + + // A frame for stream 2 should now be accepted. + pkt, ready, err := pa.AppendFrame(Frame{ + ID: ID{Stream: 2, Message: 1}, Kind: KindMessage, Data: []byte("new"), Done: true, + }) + assert.NoError(t, err) + assert.That(t, ready) + assert.DeepEqual(t, pkt.Data, []byte("new")) + assert.Equal(t, pkt.ID.Stream, uint64(2)) +}