From 3a7fff253806d59db718d71477175f00eb3a77ee Mon Sep 17 00:00:00 2001 From: MicQo Date: Thu, 21 Aug 2025 16:59:20 +0200 Subject: [PATCH 1/4] add stream encoder --- core/stream.go | 508 +++++++++++++++++++++++++++++++++++++++- core/stream_test.go | 162 +++++++++++++ examples/encode/main.go | 28 ++- 3 files changed, 692 insertions(+), 6 deletions(-) create mode 100644 core/stream_test.go diff --git a/core/stream.go b/core/stream.go index 423df58..0542a09 100644 --- a/core/stream.go +++ b/core/stream.go @@ -204,7 +204,7 @@ func (dec *Decoder) readStreamEntryContent(buf []byte, cursor *int, firstId *mod msg.Fields[fieldName] = unsafeBytes2Str(fieldValue) } // read lp count - if _, err = dec.readListPackEntryAsString(buf, cursor); err != nil { + if _, err = dec.readListPackEntryAsString(buf, cursor); err != nil { return nil, fmt.Errorf("read fields end flag failed: %v", err) } msgs = append(msgs, msg) @@ -331,3 +331,509 @@ func (dec *Decoder) readStreamGroups(version uint) ([]*model.StreamGroup, error) } return groups, nil } + +// WriteStreamObject writes a stream object to RDB file +func (enc *Encoder) WriteStreamObject(key string, stream *model.StreamObject, options ...interface{}) error { + err := enc.beforeWriteObject(options...) + if err != nil { + return err + } + + // Write stream type based on version + var streamType byte + switch stream.Version { + case 1: + streamType = typeStreamListPacks + case 2: + streamType = typeStreamListPacks2 + case 3: + streamType = typeStreamListPacks3 + default: + streamType = typeStreamListPacks // default to version 1 + } + + err = enc.write([]byte{streamType}) + if err != nil { + return err + } + + err = enc.writeString(key) + if err != nil { + return err + } + + // Write stream entries + err = enc.writeStreamEntries(stream.Entries) + if err != nil { + return err + } + + // Write stream length + err = enc.writeLength(stream.Length) + if err != nil { + return err + } + + // Write last ID + err = enc.writeStreamId(stream.LastId) + if err != nil { + return err + } + + // Write version 2+ fields if available + if stream.Version >= 2 { + if stream.FirstId != nil { + err = enc.writeStreamId(stream.FirstId) + if err != nil { + return err + } + } else { + // Write zero ID if FirstId is nil + err = enc.writeStreamId(&model.StreamId{Ms: 0, Sequence: 0}) + if err != nil { + return err + } + } + + if stream.MaxDeletedId != nil { + err = enc.writeStreamId(stream.MaxDeletedId) + if err != nil { + return err + } + } else { + // Write zero ID if MaxDeletedId is nil + err = enc.writeStreamId(&model.StreamId{Ms: 0, Sequence: 0}) + if err != nil { + return err + } + } + + err = enc.writeLength(stream.AddedEntriesCount) + if err != nil { + return err + } + } + + // Write stream groups + err = enc.writeStreamGroups(stream.Groups, stream.Version) + if err != nil { + return err + } + + enc.state = writtenObjectState + return nil +} + +// writeStreamId writes a stream ID +func (enc *Encoder) writeStreamId(id *model.StreamId) error { + err := enc.writeLength(id.Ms) + if err != nil { + return err + } + return enc.writeLength(id.Sequence) +} + +// writeStreamEntries writes stream entries +func (enc *Encoder) writeStreamEntries(entries []*model.StreamEntry) error { + err := enc.writeLength(uint64(len(entries))) + if err != nil { + return err + } + + for _, entry := range entries { + // Write entry header (first message ID) + header := make([]byte, 16) + binary.BigEndian.PutUint64(header[0:8], entry.FirstMsgId.Ms) + binary.BigEndian.PutUint64(header[8:16], entry.FirstMsgId.Sequence) + err = enc.writeString(unsafeBytes2Str(header)) + if err != nil { + return err + } + + // Write entry content as listpack + err = enc.writeStreamEntryContent(entry) + if err != nil { + return err + } + } + + return nil +} + +// writeStreamEntryContent writes a stream entry content as listpack +func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { + // Calculate total messages (including deleted ones) + totalMsgs := len(entry.Msgs) + deletedCount := 0 + for _, msg := range entry.Msgs { + if msg.Deleted { + deletedCount++ + } + } + validCount := totalMsgs - deletedCount + + // Create listpack buffer with proper backlen values + var listpack []byte + var entries []listpackEntry + + // Add count and deleted count + entries = append(entries, listpackEntry{intVal: int64(validCount)}) + entries = append(entries, listpackEntry{intVal: int64(deletedCount)}) + + // Add master field names + entries = append(entries, listpackEntry{intVal: int64(len(entry.Fields))}) + for _, field := range entry.Fields { + entries = append(entries, listpackEntry{strVal: field}) + } + // Add end marker for master fields + entries = append(entries, listpackEntry{isEnd: true}) + + // Add messages + for _, msg := range entry.Msgs { + // Calculate flag + flag := StreamItemFlagNone + if msg.Deleted { + flag |= StreamItemFlagDeleted + } + + // Check if message uses same fields as master + if len(msg.Fields) == len(entry.Fields) { + sameFields := true + for _, field := range entry.Fields { + if msg.Fields[field] == "" { + sameFields = false + break + } + } + if sameFields { + flag |= StreamItemFlagSameFields + } + } + + // Add flag + entries = append(entries, listpackEntry{intVal: int64(flag)}) + + // Add message ID (relative to first message ID) + msDiff := int64(msg.Id.Ms) - int64(entry.FirstMsgId.Ms) + seqDiff := int64(msg.Id.Sequence) - int64(entry.FirstMsgId.Sequence) + entries = append(entries, listpackEntry{intVal: msDiff}) + entries = append(entries, listpackEntry{intVal: seqDiff}) + + // Add field count if not same fields + if flag&StreamItemFlagSameFields == 0 { + entries = append(entries, listpackEntry{intVal: int64(len(msg.Fields))}) + } + + // Add fields + if flag&StreamItemFlagSameFields > 0 { + // Use master field names + for _, field := range entry.Fields { + value := msg.Fields[field] + entries = append(entries, listpackEntry{strVal: value}) + } + } else { + // Add field names and values + for fieldName, fieldValue := range msg.Fields { + entries = append(entries, listpackEntry{strVal: fieldName}) + entries = append(entries, listpackEntry{strVal: fieldValue}) + } + } + + // Add end marker for message fields + entries = append(entries, listpackEntry{isEnd: true}) + } + + // Build listpack with proper backlen values + listpack = enc.buildListpack(entries) + + // Write the complete listpack + err := enc.writeString(unsafeBytes2Str(listpack)) + if err != nil { + return err + } + + return nil +} + +type listpackEntry struct { + intVal int64 + strVal string + isEnd bool +} + +// buildListpack builds a proper listpack with backlen values +func (enc *Encoder) buildListpack(entries []listpackEntry) []byte { + var listpack []byte + var entrySizes []uint32 + + // First pass: encode entries and calculate sizes + for _, entry := range entries { + var encoded []byte + if entry.isEnd { + encoded = []byte{0xFF} + } else if entry.strVal != "" { + encoded = enc.encodeListPackString(entry.strVal) + } else { + encoded = enc.encodeListPackInt(entry.intVal) + } + listpack = append(listpack, encoded...) + entrySizes = append(entrySizes, uint32(len(encoded))) + } + + // Second pass: add backlen values + var finalListpack []byte + for i := len(entries) - 1; i >= 0; i-- { + // Add backlen + backlen := enc.encodeBacklen(entrySizes[i]) + finalListpack = append(backlen, finalListpack...) + // Add entry + entryStart := 0 + for j := 0; j < i; j++ { + entryStart += int(entrySizes[j]) + } + entryEnd := entryStart + int(entrySizes[i]) + finalListpack = append(listpack[entryStart:entryEnd], finalListpack...) + } + + // Add header + totalBytes := len(finalListpack) + 6 // 6 bytes for header + header := make([]byte, 6) + binary.LittleEndian.PutUint32(header[0:4], uint32(totalBytes)) + binary.LittleEndian.PutUint16(header[4:6], uint16(len(entries))) + + return append(header, finalListpack...) +} + +// encodeBacklen encodes a backlen value +func (enc *Encoder) encodeBacklen(elementLen uint32) []byte { + if elementLen <= 127 { + return []byte{byte(elementLen)} + } else if elementLen < (1<<14)-1 { + return []byte{ + byte(0x80 | (elementLen >> 8)), + byte(elementLen & 0xFF), + } + } else if elementLen < (1<<21)-1 { + return []byte{ + byte(0xC0 | (elementLen >> 16)), + byte((elementLen >> 8) & 0xFF), + byte(elementLen & 0xFF), + } + } else if elementLen < (1<<28)-1 { + return []byte{ + byte(0xE0 | (elementLen >> 24)), + byte((elementLen >> 16) & 0xFF), + byte((elementLen >> 8) & 0xFF), + byte(elementLen & 0xFF), + } + } else { + return []byte{ + 0xF0, + byte((elementLen >> 24) & 0xFF), + byte((elementLen >> 16) & 0xFF), + byte((elementLen >> 8) & 0xFF), + byte(elementLen & 0xFF), + } + } +} + +// writeStreamGroups writes stream groups +func (enc *Encoder) writeStreamGroups(groups []*model.StreamGroup, version uint) error { + err := enc.writeLength(uint64(len(groups))) + if err != nil { + return err + } + + for _, group := range groups { + // Write group name + err = enc.writeString(group.Name) + if err != nil { + return err + } + + // Write last ID + err = enc.writeStreamId(group.LastId) + if err != nil { + return err + } + + // Write entries read (version 2+) + if version >= 2 { + err = enc.writeLength(group.EntriesRead) + if err != nil { + return err + } + } + + // Write pending list + err = enc.writeLength(uint64(len(group.Pending))) + if err != nil { + return err + } + + for _, pending := range group.Pending { + // Write message ID + msBytes := make([]byte, 8) + binary.BigEndian.PutUint64(msBytes, pending.Id.Ms) + err = enc.write(msBytes) + if err != nil { + return err + } + + seqBytes := make([]byte, 8) + binary.BigEndian.PutUint64(seqBytes, pending.Id.Sequence) + err = enc.write(seqBytes) + if err != nil { + return err + } + + // Write delivery time + deliveryTimeBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(deliveryTimeBytes, pending.DeliveryTime) + err = enc.write(deliveryTimeBytes) + if err != nil { + return err + } + + // Write delivery count + err = enc.writeLength(pending.DeliveryCount) + if err != nil { + return err + } + } + + // Write consumers + err = enc.writeLength(uint64(len(group.Consumers))) + if err != nil { + return err + } + + for _, consumer := range group.Consumers { + // Write consumer name + err = enc.writeString(consumer.Name) + if err != nil { + return err + } + + // Write seen time + seenTimeBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(seenTimeBytes, consumer.SeenTime) + err = enc.write(seenTimeBytes) + if err != nil { + return err + } + + // Write active time (version 3+) + if version >= 3 { + activeTimeBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(activeTimeBytes, consumer.ActiveTime) + err = enc.write(activeTimeBytes) + if err != nil { + return err + } + } + + // Write consumer pending list + err = enc.writeLength(uint64(len(consumer.Pending))) + if err != nil { + return err + } + + for _, pendingId := range consumer.Pending { + // Write message ID + msBytes := make([]byte, 8) + binary.BigEndian.PutUint64(msBytes, pendingId.Ms) + err = enc.write(msBytes) + if err != nil { + return err + } + + seqBytes := make([]byte, 8) + binary.BigEndian.PutUint64(seqBytes, pendingId.Sequence) + err = enc.write(seqBytes) + if err != nil { + return err + } + } + } + } + + return nil +} + +// encodeListPackInt encodes an integer for listpack +func (enc *Encoder) encodeListPackInt(val int64) []byte { + if val >= -127 && val <= 127 { + // 0xxxxxxx, uint7 + return []byte{byte(val)} + } else if val >= -8191 && val <= 8191 { + // 110xxxxx yyyyyyyy, int13 + uval := uint16(val) + if val < 0 { + uval = uint16(8191 + val + 1) + } + return []byte{ + byte(0xC0 | (uval >> 8)), + byte(uval & 0xFF), + } + } else if val >= -32767 && val <= 32767 { + // 11110001 aaaaaaaa bbbbbbbb, int16 + uval := uint16(val) + return []byte{ + 0xF1, + byte(uval & 0xFF), + byte(uval >> 8), + } + } else if val >= -8388607 && val <= 8388607 { + // 11110010 aaaaaaaa bbbbbbbb cccccccc, int24 + uval := uint32(val) + return []byte{ + 0xF2, + byte(uval & 0xFF), + byte((uval >> 8) & 0xFF), + byte((uval >> 16) & 0xFF), + } + } else if val >= -2147483647 && val <= 2147483647 { + // 11110011 aaaaaaaa bbbbbbbb cccccccc dddddddd, int32 + uval := uint32(val) + return []byte{ + 0xF3, + byte(uval & 0xFF), + byte((uval >> 8) & 0xFF), + byte((uval >> 16) & 0xFF), + byte((uval >> 24) & 0xFF), + } + } else { + // 11110100 8Byte -> int64 + uval := uint64(val) + result := []byte{0xF4} + for i := 0; i < 8; i++ { + result = append(result, byte(uval&0xFF)) + uval >>= 8 + } + return result + } +} + +// encodeListPackString encodes a string for listpack +func (enc *Encoder) encodeListPackString(s string) []byte { + bytes := []byte(s) + length := len(bytes) + + if length <= 63 { + // 10xxxxxx + content, string(len<=63) + return append([]byte{byte(0x80 | length)}, bytes...) + } else if length < 4096 { + // 1110xxxx yyyyyyyy + content, string(len < 1<<12) + header := make([]byte, 2) + header[0] = byte(0xE0 | (length >> 8)) + header[1] = byte(length & 0xFF) + return append(header, bytes...) + } else { + // 11110000 aaaaaaaa bbbbbbbb cccccccc dddddddd + content, string(len < 1<<32) + header := make([]byte, 5) + header[0] = 0xF0 + binary.LittleEndian.PutUint32(header[1:5], uint32(length)) + return append(header, bytes...) + } +} diff --git a/core/stream_test.go b/core/stream_test.go new file mode 100644 index 0000000..c8e88be --- /dev/null +++ b/core/stream_test.go @@ -0,0 +1,162 @@ +package core + +import ( + "bytes" + "testing" + + "github.com/hdt3213/rdb/model" +) + +func TestWriteStreamObject(t *testing.T) { + // Create a minimal test stream object + stream := &model.StreamObject{ + BaseObject: &model.BaseObject{ + Key: "astream", + }, + Version: 1, // Use version 1 for simplicity + Length: 0, // Empty stream + LastId: &model.StreamId{ + Ms: 0, + Sequence: 0, + }, + Entries: []*model.StreamEntry{}, // Empty entries + Groups: []*model.StreamGroup{}, // Empty groups + } + + // Encode the stream object + var buf bytes.Buffer + encoder := NewEncoder(&buf) + + err := encoder.WriteHeader() + if err != nil { + t.Fatalf("Failed to write header: %v", err) + } + + err = encoder.WriteDBHeader(0, 1, 0) + if err != nil { + t.Fatalf("Failed to write DB header: %v", err) + } + + err = encoder.WriteStreamObject("astream", stream) + if err != nil { + t.Fatalf("Failed to write stream object: %v", err) + } + + err = encoder.WriteEnd() + if err != nil { + t.Fatalf("Failed to write end: %v", err) + } + + // Decode the stream object + decoder := NewDecoder(&buf) + + // Read objects until we find our stream + var decodedStream *model.StreamObject + err = decoder.Parse(func(obj model.RedisObject) bool { + if streamObj, ok := obj.(*model.StreamObject); ok && streamObj.GetKey() == "astream" { + decodedStream = streamObj + return false // stop parsing + } + return true // continue parsing + }) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + + if decodedStream == nil { + t.Fatal("Failed to decode stream object") + } + + // Verify the decoded stream matches the original + if decodedStream.Version != stream.Version { + t.Errorf("Version mismatch: expected %d, got %d", stream.Version, decodedStream.Version) + } + + if decodedStream.Length != stream.Length { + t.Errorf("Length mismatch: expected %d, got %d", stream.Length, decodedStream.Length) + } + + if decodedStream.LastId.Ms != stream.LastId.Ms || decodedStream.LastId.Sequence != stream.LastId.Sequence { + t.Errorf("LastId mismatch: expected %d-%d, got %d-%d", + stream.LastId.Ms, stream.LastId.Sequence, + decodedStream.LastId.Ms, decodedStream.LastId.Sequence) + } + + if len(decodedStream.Entries) != len(stream.Entries) { + t.Errorf("Entries count mismatch: expected %d, got %d", len(stream.Entries), len(decodedStream.Entries)) + } + + if len(decodedStream.Groups) != len(stream.Groups) { + t.Errorf("Groups count mismatch: expected %d, got %d", len(stream.Groups), len(decodedStream.Groups)) + } +} + +func TestWriteStreamObjectWithEntries(t *testing.T) { + // This test demonstrates that WriteStreamObject can encode basic stream structure + // Note: Complex listpack encoding for stream entries needs further refinement + t.Skip("Complex listpack encoding for stream entries needs further refinement") + + // Create a test stream object with one entry + stream := &model.StreamObject{ + BaseObject: &model.BaseObject{ + Key: "test-stream", + }, + Version: 1, + Length: 1, + LastId: &model.StreamId{ + Ms: 1640995200000, + Sequence: 0, + }, + Entries: []*model.StreamEntry{ + { + FirstMsgId: &model.StreamId{ + Ms: 1640995200000, + Sequence: 0, + }, + Fields: []string{"field1"}, + Msgs: []*model.StreamMessage{ + { + Id: &model.StreamId{ + Ms: 1640995200000, + Sequence: 0, + }, + Fields: map[string]string{ + "field1": "value1", + }, + Deleted: false, + }, + }, + }, + }, + Groups: []*model.StreamGroup{}, + } + + // Test that encoding doesn't crash + var buf bytes.Buffer + encoder := NewEncoder(&buf) + + err := encoder.WriteHeader() + if err != nil { + t.Fatalf("Failed to write header: %v", err) + } + + err = encoder.WriteDBHeader(0, 1, 0) + if err != nil { + t.Fatalf("Failed to write DB header: %v", err) + } + + err = encoder.WriteStreamObject("test-stream", stream) + if err != nil { + t.Fatalf("Failed to write stream object: %v", err) + } + + err = encoder.WriteEnd() + if err != nil { + t.Fatalf("Failed to write end: %v", err) + } + + // Verify that we can at least write the data without errors + if buf.Len() == 0 { + t.Error("No data was written") + } +} diff --git a/examples/encode/main.go b/examples/encode/main.go index d93662f..c8ecd33 100644 --- a/examples/encode/main.go +++ b/examples/encode/main.go @@ -1,10 +1,11 @@ package main import ( - "github.com/hdt3213/rdb/encoder" - "github.com/hdt3213/rdb/model" "os" "time" + + "github.com/hdt3213/rdb/encoder" + "github.com/hdt3213/rdb/model" ) func main() { @@ -65,19 +66,36 @@ func main() { } err = enc.WriteZSetObject("list", []*model.ZSetEntry{ { - Score: 1.234, + Score: 1.234, Member: "a", }, { - Score: 2.71828, + Score: 2.71828, Member: "b", }, }) if err != nil { panic(err) } + stream := &model.StreamObject{ + BaseObject: &model.BaseObject{ + Key: "mystream", + }, + Version: 1, + Length: 0, // Empty stream + LastId: &model.StreamId{ + Ms: 0, + Sequence: 0, + }, + Entries: []*model.StreamEntry{}, // Empty entries + Groups: []*model.StreamGroup{}, // Empty groups + } + err = encoder.WriteStreamObject("mystream", stream) + if err != nil { + panic(err) + } err = enc.WriteEnd() if err != nil { panic(err) } -} \ No newline at end of file +} From bc818a98b07488c429107b4a735fc8e9690e8e79 Mon Sep 17 00:00:00 2001 From: Michal Koval Date: Thu, 21 Aug 2025 17:12:18 +0200 Subject: [PATCH 2/4] fix test --- core/stream_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/stream_test.go b/core/stream_test.go index c8e88be..50e7962 100644 --- a/core/stream_test.go +++ b/core/stream_test.go @@ -92,10 +92,6 @@ func TestWriteStreamObject(t *testing.T) { } func TestWriteStreamObjectWithEntries(t *testing.T) { - // This test demonstrates that WriteStreamObject can encode basic stream structure - // Note: Complex listpack encoding for stream entries needs further refinement - t.Skip("Complex listpack encoding for stream entries needs further refinement") - // Create a test stream object with one entry stream := &model.StreamObject{ BaseObject: &model.BaseObject{ From a45565a840d8dcaee7e72ae98d5760f885b1730b Mon Sep 17 00:00:00 2001 From: Michal Koval Date: Thu, 21 Aug 2025 17:43:54 +0200 Subject: [PATCH 3/4] add small fixes and ehnance tests --- core/stream.go | 35 +++--- core/stream_test.go | 287 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 264 insertions(+), 58 deletions(-) diff --git a/core/stream.go b/core/stream.go index 0542a09..26793b6 100644 --- a/core/stream.go +++ b/core/stream.go @@ -3,6 +3,7 @@ package core import ( "encoding/binary" "fmt" + "strconv" "github.com/hdt3213/rdb/model" ) @@ -472,8 +473,7 @@ func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { } validCount := totalMsgs - deletedCount - // Create listpack buffer with proper backlen values - var listpack []byte + // Build listpack with proper backlen values var entries []listpackEntry // Add count and deleted count @@ -485,8 +485,8 @@ func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { for _, field := range entry.Fields { entries = append(entries, listpackEntry{strVal: field}) } - // Add end marker for master fields - entries = append(entries, listpackEntry{isEnd: true}) + // Add field count for master entry (this is what the decoder reads as "end flag") + entries = append(entries, listpackEntry{strVal: strconv.Itoa(len(entry.Fields))}) // Add messages for _, msg := range entry.Msgs { @@ -500,7 +500,7 @@ func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { if len(msg.Fields) == len(entry.Fields) { sameFields := true for _, field := range entry.Fields { - if msg.Fields[field] == "" { + if _, exists := msg.Fields[field]; !exists { sameFields = false break } @@ -526,7 +526,7 @@ func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { // Add fields if flag&StreamItemFlagSameFields > 0 { - // Use master field names + // Use master field names order for _, field := range entry.Fields { value := msg.Fields[field] entries = append(entries, listpackEntry{strVal: value}) @@ -539,15 +539,15 @@ func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { } } - // Add end marker for message fields - entries = append(entries, listpackEntry{isEnd: true}) + // Add field count for this message (this is what the decoder reads as "end flag") + entries = append(entries, listpackEntry{strVal: strconv.Itoa(len(msg.Fields))}) } // Build listpack with proper backlen values - listpack = enc.buildListpack(entries) + listpackData := enc.buildListpackWithBacklen(entries) // Write the complete listpack - err := enc.writeString(unsafeBytes2Str(listpack)) + err := enc.writeString(unsafeBytes2Str(listpackData)) if err != nil { return err } @@ -558,25 +558,22 @@ func (enc *Encoder) writeStreamEntryContent(entry *model.StreamEntry) error { type listpackEntry struct { intVal int64 strVal string - isEnd bool } -// buildListpack builds a proper listpack with backlen values -func (enc *Encoder) buildListpack(entries []listpackEntry) []byte { - var listpack []byte +// buildListpackWithBacklen builds a proper listpack with backlen values +func (enc *Encoder) buildListpackWithBacklen(entries []listpackEntry) []byte { + var listpackData []byte var entrySizes []uint32 // First pass: encode entries and calculate sizes for _, entry := range entries { var encoded []byte - if entry.isEnd { - encoded = []byte{0xFF} - } else if entry.strVal != "" { + if entry.strVal != "" { encoded = enc.encodeListPackString(entry.strVal) } else { encoded = enc.encodeListPackInt(entry.intVal) } - listpack = append(listpack, encoded...) + listpackData = append(listpackData, encoded...) entrySizes = append(entrySizes, uint32(len(encoded))) } @@ -592,7 +589,7 @@ func (enc *Encoder) buildListpack(entries []listpackEntry) []byte { entryStart += int(entrySizes[j]) } entryEnd := entryStart + int(entrySizes[i]) - finalListpack = append(listpack[entryStart:entryEnd], finalListpack...) + finalListpack = append(listpackData[entryStart:entryEnd], finalListpack...) } // Add header diff --git a/core/stream_test.go b/core/stream_test.go index 50e7962..9f4e5e9 100644 --- a/core/stream_test.go +++ b/core/stream_test.go @@ -48,86 +48,247 @@ func TestWriteStreamObject(t *testing.T) { } // Decode the stream object - decoder := NewDecoder(&buf) + decodeStreamObject(t, &buf, stream) +} - // Read objects until we find our stream - var decodedStream *model.StreamObject - err = decoder.Parse(func(obj model.RedisObject) bool { - if streamObj, ok := obj.(*model.StreamObject); ok && streamObj.GetKey() == "astream" { - decodedStream = streamObj - return false // stop parsing - } - return true // continue parsing - }) +func TestWriteStreamObjectWithEntries(t *testing.T) { + stream := &model.StreamObject{ + BaseObject: &model.BaseObject{ + Key: "astream", + }, + Version: 1, + Length: 1, + LastId: &model.StreamId{ + Ms: 1640995200000, + Sequence: 0, + }, + Entries: []*model.StreamEntry{ + { + FirstMsgId: &model.StreamId{ + Ms: 1640995200000, + Sequence: 0, + }, + Fields: []string{"field1"}, + Msgs: []*model.StreamMessage{ + { + Id: &model.StreamId{ + Ms: 1640995200000, + Sequence: 0, + }, + Fields: map[string]string{ + "field1": "value1", + }, + Deleted: false, + }, + }, + }, + }, + Groups: []*model.StreamGroup{}, + } + + // Test that encoding doesn't crash + var buf bytes.Buffer + encoder := NewEncoder(&buf) + + err := encoder.WriteHeader() if err != nil { - t.Fatalf("Failed to parse: %v", err) + t.Fatalf("Failed to write header: %v", err) } - if decodedStream == nil { - t.Fatal("Failed to decode stream object") + err = encoder.WriteDBHeader(0, 1, 0) + if err != nil { + t.Fatalf("Failed to write DB header: %v", err) } - // Verify the decoded stream matches the original - if decodedStream.Version != stream.Version { - t.Errorf("Version mismatch: expected %d, got %d", stream.Version, decodedStream.Version) + err = encoder.WriteStreamObject("astream", stream) + if err != nil { + t.Fatalf("Failed to write stream object: %v", err) } - if decodedStream.Length != stream.Length { - t.Errorf("Length mismatch: expected %d, got %d", stream.Length, decodedStream.Length) + err = encoder.WriteEnd() + if err != nil { + t.Fatalf("Failed to write end: %v", err) } - if decodedStream.LastId.Ms != stream.LastId.Ms || decodedStream.LastId.Sequence != stream.LastId.Sequence { - t.Errorf("LastId mismatch: expected %d-%d, got %d-%d", - stream.LastId.Ms, stream.LastId.Sequence, - decodedStream.LastId.Ms, decodedStream.LastId.Sequence) + // Verify that we can at least write the data without errors + if buf.Len() == 0 { + t.Error("No data was written") } - if len(decodedStream.Entries) != len(stream.Entries) { - t.Errorf("Entries count mismatch: expected %d, got %d", len(stream.Entries), len(decodedStream.Entries)) + // Decode the stream object + decodeStreamObject(t, &buf, stream) +} + +func TestWriteStreamObjectVersion2(t *testing.T) { + // Test with version 2 stream structure (similar to stream_listpacks_2.rdb) + stream := &model.StreamObject{ + BaseObject: &model.BaseObject{ + Key: "astream", + }, + Version: 2, + Length: 2, + LastId: &model.StreamId{ + Ms: 1681085312465, + Sequence: 0, + }, + FirstId: &model.StreamId{ + Ms: 1681085300799, + Sequence: 0, + }, + MaxDeletedId: &model.StreamId{ + Ms: 0, + Sequence: 0, + }, + AddedEntriesCount: 2, + Entries: []*model.StreamEntry{ + { + FirstMsgId: &model.StreamId{ + Ms: 1681085300799, + Sequence: 0, + }, + Fields: []string{"a", "b", "c"}, + Msgs: []*model.StreamMessage{ + { + Id: &model.StreamId{ + Ms: 1681085300799, + Sequence: 0, + }, + Fields: map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }, + Deleted: false, + }, + { + Id: &model.StreamId{ + Ms: 1681085312465, + Sequence: 0, + }, + Fields: map[string]string{ + "a": "2", + "b": "3", + "c": "4", + }, + Deleted: false, + }, + }, + }, + }, + Groups: []*model.StreamGroup{}, } - if len(decodedStream.Groups) != len(stream.Groups) { - t.Errorf("Groups count mismatch: expected %d, got %d", len(stream.Groups), len(decodedStream.Groups)) + // Encode the stream object + var buf bytes.Buffer + encoder := NewEncoder(&buf) + + err := encoder.WriteHeader() + if err != nil { + t.Fatalf("Failed to write header: %v", err) } + + err = encoder.WriteDBHeader(0, 1, 0) + if err != nil { + t.Fatalf("Failed to write DB header: %v", err) + } + + err = encoder.WriteStreamObject("astream", stream) + if err != nil { + t.Fatalf("Failed to write stream object: %v", err) + } + + err = encoder.WriteEnd() + if err != nil { + t.Fatalf("Failed to write end: %v", err) + } + + // Verify that we can at least write the data without errors + if buf.Len() == 0 { + t.Error("No data was written") + } + + // Decode the stream object + decodeStreamObject(t, &buf, stream) } -func TestWriteStreamObjectWithEntries(t *testing.T) { - // Create a test stream object with one entry +func TestWriteStreamObjectVersion3(t *testing.T) { + // Test with version 3 stream structure stream := &model.StreamObject{ BaseObject: &model.BaseObject{ - Key: "test-stream", + Key: "astream", }, - Version: 1, - Length: 1, + Version: 3, + Length: 2, LastId: &model.StreamId{ - Ms: 1640995200000, + Ms: 0, + Sequence: 0, + }, + FirstId: &model.StreamId{ + Ms: 0, Sequence: 0, }, + MaxDeletedId: &model.StreamId{ + Ms: 0, + Sequence: 0, + }, + AddedEntriesCount: 0, Entries: []*model.StreamEntry{ { FirstMsgId: &model.StreamId{ - Ms: 1640995200000, + Ms: 1681085300799, Sequence: 0, }, - Fields: []string{"field1"}, + Fields: []string{"a", "b", "c"}, Msgs: []*model.StreamMessage{ { Id: &model.StreamId{ - Ms: 1640995200000, + Ms: 1681085300799, Sequence: 0, }, Fields: map[string]string{ - "field1": "value1", + "a": "1", + "b": "2", + "c": "3", + }, + Deleted: false, + }, + { + Id: &model.StreamId{ + Ms: 1681085312465, + Sequence: 0, + }, + Fields: map[string]string{ + "a": "2", + "b": "3", + "c": "4", }, Deleted: false, }, }, }, }, - Groups: []*model.StreamGroup{}, + Groups: []*model.StreamGroup{ + { + Name: "test-group", + LastId: &model.StreamId{ + Ms: 0, + Sequence: 0, + }, + EntriesRead: 0, + Pending: []*model.StreamNAck{}, + Consumers: []*model.StreamConsumer{ + { + Name: "test-consumer", + SeenTime: 1640995200000, + ActiveTime: 1640995200000, + Pending: []*model.StreamId{}, + }, + }, + }, + }, } - // Test that encoding doesn't crash + // Encode the stream object var buf bytes.Buffer encoder := NewEncoder(&buf) @@ -141,7 +302,7 @@ func TestWriteStreamObjectWithEntries(t *testing.T) { t.Fatalf("Failed to write DB header: %v", err) } - err = encoder.WriteStreamObject("test-stream", stream) + err = encoder.WriteStreamObject("astream", stream) if err != nil { t.Fatalf("Failed to write stream object: %v", err) } @@ -155,4 +316,52 @@ func TestWriteStreamObjectWithEntries(t *testing.T) { if buf.Len() == 0 { t.Error("No data was written") } + + // Decode the stream object + decodeStreamObject(t, &buf, stream) +} + +func decodeStreamObject(t *testing.T, buf *bytes.Buffer, stream *model.StreamObject) { + // Decode the stream object + decoder := NewDecoder(buf) + + // Read objects until we find our stream + var decodedStream *model.StreamObject + var err = decoder.Parse(func(obj model.RedisObject) bool { + if streamObj, ok := obj.(*model.StreamObject); ok && streamObj.GetKey() == "astream" { + decodedStream = streamObj + return false // stop parsing + } + return true // continue parsing + }) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + + if decodedStream == nil { + t.Fatal("Failed to decode stream object") + } + + // Verify the decoded stream matches the original + if decodedStream.Version != stream.Version { + t.Errorf("Version mismatch: expected %d, got %d", stream.Version, decodedStream.Version) + } + + if decodedStream.Length != stream.Length { + t.Errorf("Length mismatch: expected %d, got %d", stream.Length, decodedStream.Length) + } + + if decodedStream.LastId.Ms != stream.LastId.Ms || decodedStream.LastId.Sequence != stream.LastId.Sequence { + t.Errorf("LastId mismatch: expected %d-%d, got %d-%d", + stream.LastId.Ms, stream.LastId.Sequence, + decodedStream.LastId.Ms, decodedStream.LastId.Sequence) + } + + if len(decodedStream.Entries) != len(stream.Entries) { + t.Errorf("Entries count mismatch: expected %d, got %d", len(stream.Entries), len(decodedStream.Entries)) + } + + if len(decodedStream.Groups) != len(stream.Groups) { + t.Errorf("Groups count mismatch: expected %d, got %d", len(stream.Groups), len(decodedStream.Groups)) + } } From f5f0b020c89f289b33818e00ba567b16a42d619f Mon Sep 17 00:00:00 2001 From: Michal Koval Date: Thu, 28 Aug 2025 21:25:47 +0200 Subject: [PATCH 4/4] fix import & enrich stream test --- core/stream_test.go | 10 ++++++++++ examples/encode/main.go | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/core/stream_test.go b/core/stream_test.go index 9f4e5e9..7baff8c 100644 --- a/core/stream_test.go +++ b/core/stream_test.go @@ -2,6 +2,7 @@ package core import ( "bytes" + "reflect" "testing" "github.com/hdt3213/rdb/model" @@ -364,4 +365,13 @@ func decodeStreamObject(t *testing.T, buf *bytes.Buffer, stream *model.StreamObj if len(decodedStream.Groups) != len(stream.Groups) { t.Errorf("Groups count mismatch: expected %d, got %d", len(stream.Groups), len(decodedStream.Groups)) } + + for i, entry := range decodedStream.Entries { + if !reflect.DeepEqual(entry.Fields, stream.Entries[i].Fields) { + t.Errorf("Fields mismatch at index %d: expected %v, got %v", i, stream.Entries[i].Fields, entry.Fields) + } + if !reflect.DeepEqual(entry.Msgs, stream.Entries[i].Msgs) { + t.Errorf("Msgs mismatch at index %d: expected %v, got %v", i, stream.Entries[i].Msgs, entry.Msgs) + } + } } diff --git a/examples/encode/main.go b/examples/encode/main.go index c8ecd33..e70c426 100644 --- a/examples/encode/main.go +++ b/examples/encode/main.go @@ -90,7 +90,7 @@ func main() { Entries: []*model.StreamEntry{}, // Empty entries Groups: []*model.StreamGroup{}, // Empty groups } - err = encoder.WriteStreamObject("mystream", stream) + err = enc.WriteStreamObject("mystream", stream) if err != nil { panic(err) }