From a4aa35e579c32f091b74d926c756acbf84091db4 Mon Sep 17 00:00:00 2001 From: shamaton Date: Thu, 5 Feb 2026 12:26:39 +0900 Subject: [PATCH 1/2] Add embedded struct support with optimized fast path (#54) --- internal/common/common.go | 154 ++++++ internal/decoding/struct.go | 226 +++++++-- internal/encoding/struct.go | 246 ++++++--- internal/stream/decoding/struct.go | 225 +++++++-- internal/stream/encoding/struct.go | 159 ++++-- msgpack_test.go | 773 +++++++++++++++++++++++++++++ 6 files changed, 1602 insertions(+), 181 deletions(-) diff --git a/internal/common/common.go b/internal/common/common.go index ffd3767..ba4d6a5 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -9,6 +9,160 @@ import ( type Common struct { } +// FieldInfo holds information about a struct field including its path for embedded structs +type FieldInfo struct { + Path []int // path to reach this field (indices for embedded structs) + Name string // field name or tag + Omit bool // omitempty flag + Tagged bool // tag name explicitly set + OmitPaths [][]int // paths to embedded fields with omitempty +} + +// CollectFields collects all fields from a struct, expanding embedded structs +// following the same rules as encoding/json +func (c *Common) CollectFields(t reflect.Type, path []int) []FieldInfo { + return c.collectFields(t, path, nil) +} + +func (c *Common) collectFields(t reflect.Type, path []int, omitPaths [][]int) []FieldInfo { + var fields []FieldInfo + var embedded []FieldInfo // embedded fields to process later (lower priority) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Check field visibility and get omitempty flag + public, omit, name := c.CheckField(field) + if !public { + continue + } + + // Get tag to check if embedded + tag := field.Tag.Get("msgpack") + // Extract just the name part (before comma if any) + tagName := tag + for j, ch := range tag { + if ch == ',' { + tagName = tag[:j] + break + } + } + + // Check if this is an embedded struct + isEmbedded := field.Anonymous && (tag == "" || tagName == "") + tagged := tagName != "" + + if isEmbedded { + // Get the actual type (dereference pointer if needed) + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + // If it's a struct, expand its fields + if fieldType.Kind() == reflect.Struct { + newPath := append(append([]int{}, path...), i) + nextOmitPaths := omitPaths + if omit { + nextOmitPaths = appendOmitPath(omitPaths, newPath) + } + embeddedFields := c.collectFields(fieldType, newPath, nextOmitPaths) + embedded = append(embedded, embeddedFields...) + continue + } + } + + // Regular field or embedded non-struct + newPath := append(append([]int{}, path...), i) + fields = append(fields, FieldInfo{ + Path: newPath, + Name: name, + Omit: omit, + Tagged: tagged, + OmitPaths: omitPaths, + }) + } + + // Add embedded fields after regular fields (they have lower priority) + fields = append(fields, embedded...) + + // Remove duplicates and handle ambiguous fields + return c.deduplicateFields(fields) +} + +func appendOmitPath(paths [][]int, path []int) [][]int { + if len(paths) == 0 { + return [][]int{path} + } + newPaths := make([][]int, len(paths)+1) + copy(newPaths, paths) + newPaths[len(paths)] = path + return newPaths +} + +// deduplicateFields removes duplicate fields and handles ambiguous fields +// following encoding/json behavior +func (c *Common) deduplicateFields(fields []FieldInfo) []FieldInfo { + // Group fields by name and depth, preserving order + type fieldAtDepth struct { + field FieldInfo + depth int + } + fieldsByName := make(map[string][]fieldAtDepth) + var seenNames []string // To preserve order + + for _, f := range fields { + if _, seen := fieldsByName[f.Name]; !seen { + seenNames = append(seenNames, f.Name) + } + fieldsByName[f.Name] = append(fieldsByName[f.Name], fieldAtDepth{ + field: f, + depth: len(f.Path), + }) + } + + var result []FieldInfo + for _, name := range seenNames { + fieldsWithDepth := fieldsByName[name] + + // Find minimum depth + minDepth := fieldsWithDepth[0].depth + for _, fd := range fieldsWithDepth { + if fd.depth < minDepth { + minDepth = fd.depth + } + } + + // Count fields at minimum depth + var fieldsAtMinDepth []FieldInfo + for _, fd := range fieldsWithDepth { + if fd.depth == minDepth { + fieldsAtMinDepth = append(fieldsAtMinDepth, fd.field) + } + } + + // If there's exactly one field at minimum depth, use it + if len(fieldsAtMinDepth) == 1 { + result = append(result, fieldsAtMinDepth[0]) + continue + } + + // Prefer the tagged field if exactly one is tagged at minimum depth + var taggedFields []FieldInfo + for _, f := range fieldsAtMinDepth { + if f.Tagged { + taggedFields = append(taggedFields, f) + } + } + if len(taggedFields) == 1 { + result = append(result, taggedFields[0]) + } + // else: ambiguous field, skip it (following encoding/json behavior) + } + + return result +} + // CheckField returns flag whether should encode/decode or not and field name func (c *Common) CheckField(field reflect.StructField) (public, omit bool, name string) { // A to Z diff --git a/internal/decoding/struct.go b/internal/decoding/struct.go index edc71f3..459dfd8 100644 --- a/internal/decoding/struct.go +++ b/internal/decoding/struct.go @@ -9,18 +9,53 @@ import ( ) type structCacheTypeMap struct { - keys [][]byte - indexes []int + keys [][]byte + + // fast path detection + hasEmbedded bool + + // fast path (hasEmbedded == false): direct field access + simpleIndexes []int + + // embedded path (hasEmbedded == true): path-based access + indexes [][]int // field path (support for embedded structs) } type structCacheTypeArray struct { - m []int + // fast path detection + hasEmbedded bool + + // fast path (hasEmbedded == false): direct field access + simpleIndexes []int + + // embedded path (hasEmbedded == true): path-based access + indexes [][]int // field path (support for embedded structs) } // struct cache map var mapSCTM = sync.Map{} var mapSCTA = sync.Map{} +// getFieldByPath returns the field value by following the path of indices. +// The bool indicates whether the path was reachable (no nil pointer in the path). +func getFieldByPath(rv reflect.Value, path []int, allowAlloc bool) (reflect.Value, bool) { + for _, idx := range path { + // Handle pointer indirection if needed + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + if !allowAlloc { + return reflect.Value{}, false + } + // Allocate new value if pointer is nil + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + rv = rv.Field(idx) + } + return rv, true +} + func (d *decoder) setStruct(rv reflect.Value, offset int, k reflect.Kind) (int, error) { /* if d.isDateTime(offset) { @@ -70,26 +105,66 @@ func (d *decoder) setStructFromArray(rv reflect.Value, offset int, k reflect.Kin cache, findCache := mapSCTA.Load(rv.Type()) if !findCache { scta = &structCacheTypeArray{} - for i := 0; i < rv.NumField(); i++ { - if ok, _, _ := d.CheckField(rv.Type().Field(i)); ok { - scta.m = append(scta.m, i) + fields := d.CollectFields(rv.Type(), nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } + } + scta.hasEmbedded = hasEmbedded + + for _, field := range fields { + if hasEmbedded { + scta.indexes = append(scta.indexes, field.Path) + } else { + scta.simpleIndexes = append(scta.simpleIndexes, field.Path[0]) } } mapSCTA.Store(rv.Type(), scta) } else { scta = cache.(*structCacheTypeArray) } + // set value - for i := 0; i < l; i++ { - if i < len(scta.m) { - o, err = d.decode(rv.Field(scta.m[i]), o) - if err != nil { - return 0, err + if scta.hasEmbedded { + for i := 0; i < l; i++ { + if i < len(scta.indexes) { + allowAlloc := !d.isCodeNil(d.data[o]) + fieldValue, ok := getFieldByPath(rv, scta.indexes[i], allowAlloc) + if ok { + o, err = d.decode(fieldValue, o) + if err != nil { + return 0, err + } + } else { + o, err = d.jumpOffset(o) + if err != nil { + return 0, err + } + } + } else { + o, err = d.jumpOffset(o) + if err != nil { + return 0, err + } } - } else { - o, err = d.jumpOffset(o) - if err != nil { - return 0, err + } + } else { + for i := 0; i < l; i++ { + if i < len(scta.simpleIndexes) { + o, err = d.decode(rv.Field(scta.simpleIndexes[i]), o) + if err != nil { + return 0, err + } + } else { + o, err = d.jumpOffset(o) + if err != nil { + return 0, err + } } } } @@ -111,10 +186,24 @@ func (d *decoder) setStructFromMap(rv reflect.Value, offset int, k reflect.Kind) cache, cacheFind := mapSCTM.Load(rv.Type()) if !cacheFind { sctm = &structCacheTypeMap{} - for i := 0; i < rv.NumField(); i++ { - if ok, _, name := d.CheckField(rv.Type().Field(i)); ok { - sctm.keys = append(sctm.keys, []byte(name)) - sctm.indexes = append(sctm.indexes, i) + fields := d.CollectFields(rv.Type(), nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } + } + sctm.hasEmbedded = hasEmbedded + + for _, field := range fields { + sctm.keys = append(sctm.keys, []byte(field.Name)) + if hasEmbedded { + sctm.indexes = append(sctm.indexes, field.Path) + } else { + sctm.simpleIndexes = append(sctm.simpleIndexes, field.Path[0]) } } mapSCTM.Store(rv.Type(), sctm) @@ -122,42 +211,93 @@ func (d *decoder) setStructFromMap(rv reflect.Value, offset int, k reflect.Kind) sctm = cache.(*structCacheTypeMap) } - for i := 0; i < l; i++ { - dataKey, o2, err := d.asStringByte(o, k) - if err != nil { - return 0, err - } - - fieldIndex := -1 - for keyIndex, keyBytes := range sctm.keys { - if len(keyBytes) != len(dataKey) { - continue + if sctm.hasEmbedded { + for i := 0; i < l; i++ { + dataKey, o2, err := d.asStringByte(o, k) + if err != nil { + return 0, err } - fieldIndex = sctm.indexes[keyIndex] - for dataIndex := range dataKey { - if dataKey[dataIndex] != keyBytes[dataIndex] { - fieldIndex = -1 + fieldPath := []int(nil) + for keyIndex, keyBytes := range sctm.keys { + if len(keyBytes) != len(dataKey) { + continue + } + + found := true + for dataIndex := range dataKey { + if dataKey[dataIndex] != keyBytes[dataIndex] { + found = false + break + } + } + if found { + fieldPath = sctm.indexes[keyIndex] break } } - if fieldIndex >= 0 { - break + + if fieldPath != nil { + allowAlloc := !d.isCodeNil(d.data[o2]) + fieldValue, ok := getFieldByPath(rv, fieldPath, allowAlloc) + if ok { + o2, err = d.decode(fieldValue, o2) + if err != nil { + return 0, err + } + } else { + o2, err = d.jumpOffset(o2) + if err != nil { + return 0, err + } + } + } else { + o2, err = d.jumpOffset(o2) + if err != nil { + return 0, err + } } + o = o2 } - - if fieldIndex >= 0 { - o2, err = d.decode(rv.Field(fieldIndex), o2) + } else { + for i := 0; i < l; i++ { + dataKey, o2, err := d.asStringByte(o, k) if err != nil { return 0, err } - } else { - o2, err = d.jumpOffset(o2) - if err != nil { - return 0, err + + fieldIndex := -1 + for keyIndex, keyBytes := range sctm.keys { + if len(keyBytes) != len(dataKey) { + continue + } + + found := true + for dataIndex := range dataKey { + if dataKey[dataIndex] != keyBytes[dataIndex] { + found = false + break + } + } + if found { + fieldIndex = sctm.simpleIndexes[keyIndex] + break + } + } + + if fieldIndex >= 0 { + o2, err = d.decode(rv.Field(fieldIndex), o2) + if err != nil { + return 0, err + } + } else { + o2, err = d.jumpOffset(o2) + if err != nil { + return 0, err + } } + o = o2 } - o = o2 } return o, nil } diff --git a/internal/encoding/struct.go b/internal/encoding/struct.go index 8c41be6..1c820e3 100644 --- a/internal/encoding/struct.go +++ b/internal/encoding/struct.go @@ -10,10 +10,21 @@ import ( ) type structCache struct { - indexes []int - names []string - omits []bool - noOmit bool + // common fields + names []string + omits []bool + noOmit bool + + // fast path detection + hasEmbedded bool + + // fast path (hasEmbedded == false): direct field access + simpleIndexes []int + + // embedded path (hasEmbedded == true): path-based access + indexes [][]int // field path (support for embedded structs) + omitPaths [][][]int // embedded omitempty parent paths + common.Common } @@ -22,6 +33,33 @@ var cachemap = sync.Map{} type structCalcFunc func(rv reflect.Value) (int, error) type structWriteFunc func(rv reflect.Value, offset int) int +// getFieldByPath returns the field value by following the path of indices. +// The bool indicates whether the path was reachable (no nil pointer in the path). +func getFieldByPath(rv reflect.Value, path []int) (reflect.Value, bool) { + for _, idx := range path { + // Handle pointer indirection if needed + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + // Return invalid value if pointer is nil + return reflect.Value{}, false + } + rv = rv.Elem() + } + rv = rv.Field(idx) + } + return rv, true +} + +func shouldOmitByParent(rv reflect.Value, omitPaths [][]int) bool { + for _, path := range omitPaths { + parentValue, ok := getFieldByPath(rv, path) + if !ok || parentValue.IsZero() { + return true + } + } + return false +} + func (e *encoder) getStructCalc(typ reflect.Type) structCalcFunc { for j := range extCoders { @@ -61,35 +99,58 @@ func (e *encoder) calcStructArray(rv reflect.Value) (int, error) { cache, find := cachemap.Load(t) var c *structCache if !find { - num := rv.NumField() - c = &structCache{ - indexes: make([]int, 0, num), - names: make([]string, 0, num), - omits: make([]bool, 0, num), + c = &structCache{} + fields := e.CollectFields(t, nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } } + c.hasEmbedded = hasEmbedded + omitCount := 0 - for i := 0; i < num; i++ { - field := t.Field(i) - if ok, omit, name := e.CheckField(field); ok { - size, err := e.calcSize(rv.Field(i)) - if err != nil { - return 0, err - } - ret += size - c.indexes = append(c.indexes, i) - c.names = append(c.names, name) - c.omits = append(c.omits, omit) - if omit { - omitCount++ - } + for _, field := range fields { + c.names = append(c.names, field.Name) + c.omits = append(c.omits, field.Omit) + if hasEmbedded { + c.indexes = append(c.indexes, field.Path) + c.omitPaths = append(c.omitPaths, field.OmitPaths) + } else { + c.simpleIndexes = append(c.simpleIndexes, field.Path[0]) + } + if field.Omit { + omitCount++ } } c.noOmit = omitCount == 0 cachemap.Store(t, c) } else { c = cache.(*structCache) - for i := 0; i < len(c.indexes); i++ { - size, err := e.calcSize(rv.Field(c.indexes[i])) + } + + // calculate size based on path type + var numFields int + if c.hasEmbedded { + numFields = len(c.indexes) + for i := 0; i < numFields; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + fieldValue = reflect.Value{} + } + size, err := e.calcSize(fieldValue) + if err != nil { + return 0, err + } + ret += size + } + } else { + numFields = len(c.simpleIndexes) + for i := 0; i < numFields; i++ { + size, err := e.calcSize(rv.Field(c.simpleIndexes[i])) if err != nil { return 0, err } @@ -98,7 +159,7 @@ func (e *encoder) calcStructArray(rv reflect.Value) (int, error) { } // format size - size, err := e.calcLength(len(c.indexes)) + size, err := e.calcLength(numFields) if err != nil { return 0, err } @@ -111,39 +172,59 @@ func (e *encoder) calcStructMap(rv reflect.Value) (int, error) { t := rv.Type() cache, find := cachemap.Load(t) var c *structCache - var l int if !find { - num := rv.NumField() - c = &structCache{ - indexes: make([]int, 0, num), - names: make([]string, 0, num), - omits: make([]bool, 0, num), + c = &structCache{} + fields := e.CollectFields(t, nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } } + c.hasEmbedded = hasEmbedded + omitCount := 0 - for i := 0; i < num; i++ { - if ok, omit, name := e.CheckField(rv.Type().Field(i)); ok { - size, err := e.calcSizeWithOmitEmpty(rv.Field(i), name, omit) - if err != nil { - return 0, err - } - ret += size - c.indexes = append(c.indexes, i) - c.names = append(c.names, name) - c.omits = append(c.omits, omit) - if omit { - omitCount++ - } - if size > 0 { - l++ - } + for _, field := range fields { + c.names = append(c.names, field.Name) + c.omits = append(c.omits, field.Omit) + if hasEmbedded { + c.indexes = append(c.indexes, field.Path) + c.omitPaths = append(c.omitPaths, field.OmitPaths) + } else { + c.simpleIndexes = append(c.simpleIndexes, field.Path[0]) + } + if field.Omit { + omitCount++ } } c.noOmit = omitCount == 0 cachemap.Store(t, c) } else { c = cache.(*structCache) + } + + l := 0 + if c.hasEmbedded { for i := 0; i < len(c.indexes); i++ { - size, err := e.calcSizeWithOmitEmpty(rv.Field(c.indexes[i]), c.names[i], c.omits[i]) + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + continue + } + size, err := e.calcSizeWithOmitEmpty(fieldValue, c.names[i], c.omits[i]) + if err != nil { + return 0, err + } + ret += size + if size > 0 { + l++ + } + } + } else { + for i := 0; i < len(c.simpleIndexes); i++ { + size, err := e.calcSizeWithOmitEmpty(rv.Field(c.simpleIndexes[i]), c.names[i], c.omits[i]) if err != nil { return 0, err } @@ -155,7 +236,7 @@ func (e *encoder) calcStructMap(rv reflect.Value) (int, error) { } // format size - size, err := e.calcLength(len(c.indexes)) + size, err := e.calcLength(l) if err != nil { return 0, err } @@ -218,7 +299,13 @@ func (e *encoder) writeStructArray(rv reflect.Value, offset int) int { c := cache.(*structCache) // write format - num := len(c.indexes) + var num int + if c.hasEmbedded { + num = len(c.indexes) + } else { + num = len(c.simpleIndexes) + } + if num <= 0x0f { offset = e.setByte1Int(def.FixArray+num, offset) } else if num <= math.MaxUint16 { @@ -229,8 +316,18 @@ func (e *encoder) writeStructArray(rv reflect.Value, offset int) int { offset = e.setByte4Int(num, offset) } - for i := 0; i < num; i++ { - offset = e.create(rv.Field(c.indexes[i]), offset) + if c.hasEmbedded { + for i := 0; i < num; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + fieldValue = reflect.Value{} + } + offset = e.create(fieldValue, offset) + } + } else { + for i := 0; i < num; i++ { + offset = e.create(rv.Field(c.simpleIndexes[i]), offset) + } } return offset } @@ -241,14 +338,22 @@ func (e *encoder) writeStructMap(rv reflect.Value, offset int) int { c := cache.(*structCache) // format size - num := len(c.indexes) l := 0 - if c.noOmit { - l = num + if c.hasEmbedded { + num := len(c.indexes) + for i := 0; i < num; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + continue + } + if c.noOmit || !c.omits[i] || !fieldValue.IsZero() { + l++ + } + } } else { + num := len(c.simpleIndexes) for i := 0; i < num; i++ { - irv := rv.Field(c.indexes[i]) - if !c.omits[i] || !irv.IsZero() { + if c.noOmit || !c.omits[i] || !rv.Field(c.simpleIndexes[i]).IsZero() { l++ } } @@ -264,11 +369,26 @@ func (e *encoder) writeStructMap(rv reflect.Value, offset int) int { offset = e.setByte4Int(l, offset) } - for i := 0; i < num; i++ { - irv := rv.Field(c.indexes[i]) - if !c.omits[i] || !irv.IsZero() { - offset = e.writeString(c.names[i], offset) - offset = e.create(irv, offset) + if c.hasEmbedded { + num := len(c.indexes) + for i := 0; i < num; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + continue + } + if c.noOmit || !c.omits[i] || !fieldValue.IsZero() { + offset = e.writeString(c.names[i], offset) + offset = e.create(fieldValue, offset) + } + } + } else { + num := len(c.simpleIndexes) + for i := 0; i < num; i++ { + fieldValue := rv.Field(c.simpleIndexes[i]) + if c.noOmit || !c.omits[i] || !fieldValue.IsZero() { + offset = e.writeString(c.names[i], offset) + offset = e.create(fieldValue, offset) + } } } return offset diff --git a/internal/stream/decoding/struct.go b/internal/stream/decoding/struct.go index b905cd5..3adaab3 100644 --- a/internal/stream/decoding/struct.go +++ b/internal/stream/decoding/struct.go @@ -9,18 +9,53 @@ import ( ) type structCacheTypeMap struct { - keys [][]byte - indexes []int + keys [][]byte + + // fast path detection + hasEmbedded bool + + // fast path (hasEmbedded == false): direct field access + simpleIndexes []int + + // embedded path (hasEmbedded == true): path-based access + indexes [][]int // field path (support for embedded structs) } type structCacheTypeArray struct { - m []int + // fast path detection + hasEmbedded bool + + // fast path (hasEmbedded == false): direct field access + simpleIndexes []int + + // embedded path (hasEmbedded == true): path-based access + indexes [][]int // field path (support for embedded structs) } // struct cache map var mapSCTM = sync.Map{} var mapSCTA = sync.Map{} +// getFieldByPath returns the field value by following the path of indices. +// The bool indicates whether the path was reachable (no nil pointer in the path). +func getFieldByPath(rv reflect.Value, path []int, allowAlloc bool) (reflect.Value, bool) { + for _, idx := range path { + // Handle pointer indirection if needed + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + if !allowAlloc { + return reflect.Value{}, false + } + // Allocate new value if pointer is nil + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + rv = rv.Field(idx) + } + return rv, true +} + func (d *decoder) setStruct(code byte, rv reflect.Value, k reflect.Kind) error { if len(extCoders) > 0 { innerType, data, err := d.readIfExtType(code) @@ -63,26 +98,67 @@ func (d *decoder) setStructFromArray(code byte, rv reflect.Value, k reflect.Kind cache, findCache := mapSCTA.Load(rv.Type()) if !findCache { scta = &structCacheTypeArray{} - for i := 0; i < rv.NumField(); i++ { - if ok, _, _ := d.CheckField(rv.Type().Field(i)); ok { - scta.m = append(scta.m, i) + fields := d.CollectFields(rv.Type(), nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } + } + scta.hasEmbedded = hasEmbedded + + for _, field := range fields { + if hasEmbedded { + scta.indexes = append(scta.indexes, field.Path) + } else { + scta.simpleIndexes = append(scta.simpleIndexes, field.Path[0]) } } mapSCTA.Store(rv.Type(), scta) } else { scta = cache.(*structCacheTypeArray) } + // set value - for i := 0; i < l; i++ { - if i < len(scta.m) { - err = d.decode(rv.Field(scta.m[i])) - if err != nil { - return err + if scta.hasEmbedded { + for i := 0; i < l; i++ { + if i < len(scta.indexes) { + code, err := d.readSize1() + if err != nil { + return err + } + allowAlloc := !d.isCodeNil(code) + fieldValue, ok := getFieldByPath(rv, scta.indexes[i], allowAlloc) + if ok { + err = d.decodeWithCode(code, fieldValue) + if err != nil { + return err + } + } else if !d.isCodeNil(code) { + return d.errorTemplate(code, k) + } + } else { + err = d.jumpOffset() + if err != nil { + return err + } } - } else { - err = d.jumpOffset() - if err != nil { - return err + } + } else { + for i := 0; i < l; i++ { + if i < len(scta.simpleIndexes) { + err = d.decode(rv.Field(scta.simpleIndexes[i])) + if err != nil { + return err + } + } else { + err = d.jumpOffset() + if err != nil { + return err + } } } } @@ -100,10 +176,24 @@ func (d *decoder) setStructFromMap(code byte, rv reflect.Value, k reflect.Kind) cache, cacheFind := mapSCTM.Load(rv.Type()) if !cacheFind { sctm = &structCacheTypeMap{} - for i := 0; i < rv.NumField(); i++ { - if ok, _, name := d.CheckField(rv.Type().Field(i)); ok { - sctm.keys = append(sctm.keys, []byte(name)) - sctm.indexes = append(sctm.indexes, i) + fields := d.CollectFields(rv.Type(), nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } + } + sctm.hasEmbedded = hasEmbedded + + for _, field := range fields { + sctm.keys = append(sctm.keys, []byte(field.Name)) + if hasEmbedded { + sctm.indexes = append(sctm.indexes, field.Path) + } else { + sctm.simpleIndexes = append(sctm.simpleIndexes, field.Path[0]) } } mapSCTM.Store(rv.Type(), sctm) @@ -111,39 +201,90 @@ func (d *decoder) setStructFromMap(code byte, rv reflect.Value, k reflect.Kind) sctm = cache.(*structCacheTypeMap) } - for i := 0; i < l; i++ { - dataKey, err := d.asStringByte(k) - if err != nil { - return err - } - - fieldIndex := -1 - for keyIndex, keyBytes := range sctm.keys { - if len(keyBytes) != len(dataKey) { - continue + if sctm.hasEmbedded { + for i := 0; i < l; i++ { + dataKey, err := d.asStringByte(k) + if err != nil { + return err } - fieldIndex = sctm.indexes[keyIndex] - for dataIndex := range dataKey { - if dataKey[dataIndex] != keyBytes[dataIndex] { - fieldIndex = -1 + fieldPath := []int(nil) + for keyIndex, keyBytes := range sctm.keys { + if len(keyBytes) != len(dataKey) { + continue + } + + found := true + for dataIndex := range dataKey { + if dataKey[dataIndex] != keyBytes[dataIndex] { + found = false + break + } + } + if found { + fieldPath = sctm.indexes[keyIndex] break } } - if fieldIndex >= 0 { - break + + if fieldPath != nil { + code, err := d.readSize1() + if err != nil { + return err + } + allowAlloc := !d.isCodeNil(code) + fieldValue, ok := getFieldByPath(rv, fieldPath, allowAlloc) + if ok { + err = d.decodeWithCode(code, fieldValue) + if err != nil { + return err + } + } else if !d.isCodeNil(code) { + return d.errorTemplate(code, k) + } + } else { + err = d.jumpOffset() + if err != nil { + return err + } } } - - if fieldIndex >= 0 { - err = d.decode(rv.Field(fieldIndex)) + } else { + for i := 0; i < l; i++ { + dataKey, err := d.asStringByte(k) if err != nil { return err } - } else { - err = d.jumpOffset() - if err != nil { - return err + + fieldIndex := -1 + for keyIndex, keyBytes := range sctm.keys { + if len(keyBytes) != len(dataKey) { + continue + } + + found := true + for dataIndex := range dataKey { + if dataKey[dataIndex] != keyBytes[dataIndex] { + found = false + break + } + } + if found { + fieldIndex = sctm.simpleIndexes[keyIndex] + break + } + } + + if fieldIndex >= 0 { + err = d.decode(rv.Field(fieldIndex)) + if err != nil { + return err + } + } else { + err = d.jumpOffset() + if err != nil { + return err + } } } } diff --git a/internal/stream/encoding/struct.go b/internal/stream/encoding/struct.go index db3cad1..2e7ba77 100644 --- a/internal/stream/encoding/struct.go +++ b/internal/stream/encoding/struct.go @@ -11,10 +11,21 @@ import ( ) type structCache struct { - indexes []int - names []string - omits []bool - noOmit bool + // common fields + names []string + omits []bool + noOmit bool + + // fast path detection + hasEmbedded bool + + // fast path (hasEmbedded == false): direct field access + simpleIndexes []int + + // embedded path (hasEmbedded == true): path-based access + indexes [][]int // field path (support for embedded structs) + omitPaths [][][]int // embedded omitempty parent paths + common.Common } @@ -22,6 +33,33 @@ var cachemap = sync.Map{} type structWriteFunc func(rv reflect.Value) error +// getFieldByPath returns the field value by following the path of indices. +// The bool indicates whether the path was reachable (no nil pointer in the path). +func getFieldByPath(rv reflect.Value, path []int) (reflect.Value, bool) { + for _, idx := range path { + // Handle pointer indirection if needed + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + // Return invalid value if pointer is nil + return reflect.Value{}, false + } + rv = rv.Elem() + } + rv = rv.Field(idx) + } + return rv, true +} + +func shouldOmitByParent(rv reflect.Value, omitPaths [][]int) bool { + for _, path := range omitPaths { + parentValue, ok := getFieldByPath(rv, path) + if !ok || parentValue.IsZero() { + return true + } + } + return false +} + func (e *encoder) getStructWriter(typ reflect.Type) structWriteFunc { for i := range extCoders { @@ -58,7 +96,13 @@ func (e *encoder) writeStructArray(rv reflect.Value) error { c := e.getStructCache(rv) // write format - num := len(c.indexes) + var num int + if c.hasEmbedded { + num = len(c.indexes) + } else { + num = len(c.simpleIndexes) + } + if num <= 0x0f { if err := e.setByte1Int(def.FixArray + num); err != nil { return err @@ -79,9 +123,21 @@ func (e *encoder) writeStructArray(rv reflect.Value) error { } } - for i := 0; i < num; i++ { - if err := e.create(rv.Field(c.indexes[i])); err != nil { - return err + if c.hasEmbedded { + for i := 0; i < num; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + fieldValue = reflect.Value{} + } + if err := e.create(fieldValue); err != nil { + return err + } + } + } else { + for i := 0; i < num; i++ { + if err := e.create(rv.Field(c.simpleIndexes[i])); err != nil { + return err + } } } return nil @@ -90,14 +146,22 @@ func (e *encoder) writeStructArray(rv reflect.Value) error { func (e *encoder) writeStructMap(rv reflect.Value) error { c := e.getStructCache(rv) - num := len(c.indexes) l := 0 - if c.noOmit { - l = num + if c.hasEmbedded { + num := len(c.indexes) + for i := 0; i < num; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + continue + } + if c.noOmit || !c.omits[i] || !fieldValue.IsZero() { + l++ + } + } } else { + num := len(c.simpleIndexes) for i := 0; i < num; i++ { - irv := rv.Field(c.indexes[i]) - if !c.omits[i] || !irv.IsZero() { + if c.noOmit || !c.omits[i] || !rv.Field(c.simpleIndexes[i]).IsZero() { l++ } } @@ -124,14 +188,33 @@ func (e *encoder) writeStructMap(rv reflect.Value) error { } } - for i := 0; i < num; i++ { - irv := rv.Field(c.indexes[i]) - if !c.omits[i] || !irv.IsZero() { - if err := e.writeString(c.names[i]); err != nil { - return err + if c.hasEmbedded { + num := len(c.indexes) + for i := 0; i < num; i++ { + fieldValue, ok := getFieldByPath(rv, c.indexes[i]) + if shouldOmitByParent(rv, c.omitPaths[i]) || !ok { + continue } - if err := e.create(irv); err != nil { - return err + if c.noOmit || !c.omits[i] || !fieldValue.IsZero() { + if err := e.writeString(c.names[i]); err != nil { + return err + } + if err := e.create(fieldValue); err != nil { + return err + } + } + } + } else { + num := len(c.simpleIndexes) + for i := 0; i < num; i++ { + fieldValue := rv.Field(c.simpleIndexes[i]) + if c.noOmit || !c.omits[i] || !fieldValue.IsZero() { + if err := e.writeString(c.names[i]); err != nil { + return err + } + if err := e.create(fieldValue); err != nil { + return err + } } } } @@ -145,21 +228,31 @@ func (e *encoder) getStructCache(rv reflect.Value) *structCache { return cache.(*structCache) } - num := rv.NumField() - c := &structCache{ - indexes: make([]int, 0, num), - names: make([]string, 0, num), - omits: make([]bool, 0, num), + c := &structCache{} + fields := e.CollectFields(t, nil) + + // detect embedded fields + hasEmbedded := false + for _, f := range fields { + if len(f.Path) > 1 || len(f.OmitPaths) > 0 { + hasEmbedded = true + break + } } + c.hasEmbedded = hasEmbedded + omitCount := 0 - for i := 0; i < num; i++ { - if ok, omit, name := e.CheckField(rv.Type().Field(i)); ok { - c.indexes = append(c.indexes, i) - c.names = append(c.names, name) - c.omits = append(c.omits, omit) - if omit { - omitCount++ - } + for _, field := range fields { + c.names = append(c.names, field.Name) + c.omits = append(c.omits, field.Omit) + if hasEmbedded { + c.indexes = append(c.indexes, field.Path) + c.omitPaths = append(c.omitPaths, field.OmitPaths) + } else { + c.simpleIndexes = append(c.simpleIndexes, field.Path[0]) + } + if field.Omit { + omitCount++ } } c.noOmit = omitCount == 0 diff --git a/msgpack_test.go b/msgpack_test.go index 618eaef..449590b 100644 --- a/msgpack_test.go +++ b/msgpack_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "encoding/json" "errors" "fmt" "math" @@ -1666,10 +1667,19 @@ func TestStruct(t *testing.T) { t.Run("Embedded", func(t *testing.T) { testEmbedded(t) }) + t.Run("EmbeddedStruct", func(t *testing.T) { + testEmbeddedStruct(t) + }) t.Run("Jump", func(t *testing.T) { testStructJump(t) }) t.Run("UseCase", func(t *testing.T) { + b := msgpack.StructAsArray + defer func() { + msgpack.StructAsArray = b + }() + testStructUseCase(t) + msgpack.StructAsArray = true testStructUseCase(t) }) } @@ -1695,6 +1705,769 @@ func testEmbedded(t *testing.T) { encdec(t, arg) } +func testEmbeddedStruct(t *testing.T) { + b := msgpack.StructAsArray + defer func() { + msgpack.StructAsArray = b + }() + + t.Run("SimpleEmbedded", func(t *testing.T) { + type Embedded struct { + EmbeddedField int + Name string + } + type Parent struct { + Embedded + ParentField string + } + + original := Parent{ + Embedded: Embedded{EmbeddedField: 42, Name: "test"}, + ParentField: "parent", + } + + // Test with msgpack + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Parent + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Parent]{ + v: original, + skipEq: true, + vc: func(p Parent) error { + // Verify fields are promoted + tu.Equal(t, p.EmbeddedField, 42) + tu.Equal(t, p.Name, "test") + tu.Equal(t, p.ParentField, "parent") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Parent + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.EmbeddedField != jsonDecoded.EmbeddedField || + msgDecoded.Name != jsonDecoded.Name || + msgDecoded.ParentField != jsonDecoded.ParentField { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("EmbeddedWithTag", func(t *testing.T) { + type Embedded struct { + Field int + } + type Parent struct { + Embedded `msgpack:"emb"` + Regular string + } + + original := Parent{ + Embedded: Embedded{Field: 99}, + Regular: "value", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Parent + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Parent]{ + v: original, + skipEq: true, + vc: func(p Parent) error { + // Tagged embedded struct should not be promoted + tu.Equal(t, p.Embedded.Field, 99) + tu.Equal(t, p.Regular, "value") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Parent + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.Embedded.Field != jsonDecoded.Embedded.Field || + msgDecoded.Regular != jsonDecoded.Regular { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("MultiLevelEmbedding", func(t *testing.T) { + type Deep struct { + DeepField int + } + type Middle struct { + Deep + MiddleField string + } + type Top struct { + Middle + TopField bool + } + + original := Top{ + Middle: Middle{ + Deep: Deep{DeepField: 100}, + MiddleField: "middle", + }, + TopField: true, + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Top + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Top]{ + v: original, + skipEq: true, + vc: func(p Top) error { + tu.Equal(t, p.DeepField, 100) + tu.Equal(t, p.MiddleField, "middle") + tu.Equal(t, p.TopField, true) + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Top + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.DeepField != jsonDecoded.DeepField || + msgDecoded.MiddleField != jsonDecoded.MiddleField || + msgDecoded.TopField != jsonDecoded.TopField { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("FieldShadowing", func(t *testing.T) { + type Base struct { + Name string + Age int + } + type Derived struct { + Base + Name string // Shadows Base.Name + } + + original := Derived{ + Base: Base{Name: "base", Age: 30}, + Name: "derived", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Derived + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + tu.Equal(t, p.Name, "derived") + tu.Equal(t, p.Age, 30) + tu.Equal(t, p.Base.Name, "") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Derived + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.Name != jsonDecoded.Name || msgDecoded.Age != jsonDecoded.Age { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("TaggedFieldPriority", func(t *testing.T) { + type Tagged struct { + Name string `msgpack:"Name" json:"Name"` + } + type Plain struct { + Name string + } + type Derived struct { + Tagged + Plain + } + + original := Derived{ + Tagged: Tagged{Name: "tagged"}, + Plain: Plain{Name: "plain"}, + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Derived + + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + tu.Equal(t, p.Tagged.Name, "tagged") + tu.Equal(t, p.Plain.Name, "") + msgDecoded = p + return nil + }, + }) + } + + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Derived + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.Tagged.Name != jsonDecoded.Tagged.Name || + msgDecoded.Plain.Name != jsonDecoded.Plain.Name { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("EmbeddedOmitEmpty", func(t *testing.T) { + type Embedded struct { + A int + B string + } + type Parent struct { + Embedded `msgpack:",omitempty"` + C int + } + + msgpack.StructAsArray = false + encdec(t, encdecArg[map[string]any]{ + v: Parent{Embedded: Embedded{}, C: 1}, + skipEq: true, + vc: func(p map[string]any) error { + if _, ok := p["A"]; ok { + return fmt.Errorf("embedded field A should be omitted") + } + if _, ok := p["B"]; ok { + return fmt.Errorf("embedded field B should be omitted") + } + if _, ok := p["C"]; !ok { + return fmt.Errorf("field C should be present") + } + return nil + }, + }) + + encdec(t, encdecArg[map[string]any]{ + v: Parent{Embedded: Embedded{A: 1, B: "b"}, C: 2}, + skipEq: true, + vc: func(p map[string]any) error { + if _, ok := p["A"]; !ok { + return fmt.Errorf("embedded field A should be present") + } + if _, ok := p["B"]; !ok { + return fmt.Errorf("embedded field B should be present") + } + if _, ok := p["C"]; !ok { + return fmt.Errorf("field C should be present") + } + return nil + }, + }) + + msgpack.StructAsArray = true + encdec(t, encdecArg[Parent]{ + v: Parent{Embedded: Embedded{}, C: 1}, + c: func(d []byte) bool { + return len(d) == 4 && + d[0] == def.FixArray+0x03 && + d[1] == 0x01 && + d[2] == def.Nil && + d[3] == def.Nil + }, + vc: func(p Parent) error { + if p.A != 0 || p.B != "" || p.C != 1 { + return fmt.Errorf("unexpected array decode: %+v", p) + } + return nil + }, + }) + + encdec(t, encdecArg[Parent]{ + v: Parent{Embedded: Embedded{A: 1, B: "b"}, C: 2}, + c: func(d []byte) bool { + return len(d) == 5 && + d[0] == def.FixArray+0x03 && + d[1] == 0x02 && + d[2] == 0x01 && + d[3] == 0xa1 && + d[4] == 'b' + }, + vc: func(p Parent) error { + if p.A != 1 || p.B != "b" || p.C != 2 { + return fmt.Errorf("unexpected array decode: %+v", p) + } + return nil + }, + }) + }) + + t.Run("AmbiguousFields", func(t *testing.T) { + type A struct { + Field string + } + type B struct { + Field string + } + type Derived struct { + A + B + } + + original := Derived{ + A: A{Field: "from A"}, + B: B{Field: "from B"}, + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + // Decode to map to check what fields are present + var msgDecoded map[string]interface{} + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[map[string]any]{ + v: original, + skipEq: true, + vc: func(p map[string]any) error { + // Ambiguous field should be omitted + if _, exists := p["Field"]; exists { + t.Errorf("Ambiguous field 'Field' should be omitted") + } + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonMap map[string]interface{} + _ = json.Unmarshal(jsonBytes, &jsonMap) + + if len(msgDecoded) != len(jsonMap) { + t.Errorf("Field count differs: msgpack=%d, json=%d", len(msgDecoded), len(jsonMap)) + } + }) + + t.Run("PointerEmbedded", func(t *testing.T) { + type Base struct { + Value int + } + type Derived struct { + *Base + Extra string + } + + original := Derived{ + Base: &Base{Value: 123}, + Extra: "extra", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Derived + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + if p.Base == nil { + return fmt.Errorf("base pointer should not be nil after unmarshal") + } + tu.Equal(t, p.Value, 123) + tu.Equal(t, p.Extra, "extra") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Derived + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.Value != jsonDecoded.Value || msgDecoded.Extra != jsonDecoded.Extra { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("PointerEmbeddedNil", func(t *testing.T) { + type Base struct { + Value int + } + type Derived struct { + *Base + Extra string + } + + original := Derived{ + Base: nil, + Extra: "extra", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + msgpack.StructAsArray = false + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + if p.Base != nil { + return fmt.Errorf("base pointer should be nil after unmarshal") + } + tu.Equal(t, p.Extra, "extra") + return nil + }, + }) + + msgpack.StructAsArray = true + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + if p.Base != nil { + return fmt.Errorf("base pointer should be nil after unmarshal") + } + tu.Equal(t, p.Extra, "extra") + return nil + }, + }) + /* + + var msgDecoded Derived + + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + if p.Base != nil { + return fmt.Errorf("base pointer should be nil after unmarshal") + } + tu.Equal(t, p.Extra, "extra") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Derived + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + tu.Equal(t, jsonDecoded, msgDecoded) + */ + + original.Base = &Base{Value: 123} + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + tu.Equal(t, p.Base.Value, 123) + tu.Equal(t, p.Extra, "extra") + return nil + }, + }) + } + + }) + + t.Run("TagSameAsStructName", func(t *testing.T) { + // Edge case: tag with same name as struct field + type Inner struct { + Value int + } + type Outer struct { + Inner `msgpack:"Inner"` + Other string + } + + original := Outer{ + Inner: Inner{Value: 456}, + Other: "test", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Outer + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Outer]{ + v: original, + skipEq: true, + vc: func(p Outer) error { + tu.Equal(t, p.Inner.Value, 456) + tu.Equal(t, p.Other, "test") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Outer + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.Inner.Value != jsonDecoded.Inner.Value || + msgDecoded.Other != jsonDecoded.Other { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("MixedEmbeddedAndRegular", func(t *testing.T) { + type Timestamp struct { + CreatedAt string + UpdatedAt string + } + type Document struct { + Timestamp + ID int + Content string + Title string + } + + original := Document{ + Timestamp: Timestamp{ + CreatedAt: "2024-01-01", + UpdatedAt: "2024-01-02", + }, + ID: 42, + Content: "test content", + Title: "test title", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Document + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Document]{ + v: original, + skipEq: true, + vc: func(p Document) error { + tu.Equal(t, p.CreatedAt, "2024-01-01") + tu.Equal(t, p.UpdatedAt, "2024-01-02") + tu.Equal(t, p.ID, 42) + tu.Equal(t, p.Content, "test content") + tu.Equal(t, p.Title, "test title") + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Document + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.CreatedAt != jsonDecoded.CreatedAt || + msgDecoded.UpdatedAt != jsonDecoded.UpdatedAt || + msgDecoded.ID != jsonDecoded.ID || + msgDecoded.Content != jsonDecoded.Content || + msgDecoded.Title != jsonDecoded.Title { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("EmbeddedWithPrivateFields", func(t *testing.T) { + type Base struct { + Public string + private int + } + type Derived struct { + Base + Extra string + } + + original := Derived{ + Base: Base{Public: "public", private: 123}, + Extra: "extra", + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded Derived + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[Derived]{ + v: original, + skipEq: true, + vc: func(p Derived) error { + // Only public fields should be marshaled + tu.Equal(t, p.Public, "public") + tu.Equal(t, p.Extra, "extra") + tu.Equal(t, p.private, 0) // private field should be zero value + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded Derived + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.Public != jsonDecoded.Public || msgDecoded.Extra != jsonDecoded.Extra { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) + + t.Run("ComplexNestedEmbedding", func(t *testing.T) { + type A struct { + FieldA string + } + type B struct { + A + FieldB int + } + type C struct { + B + FieldC bool + } + type D struct { + C + FieldD float64 + } + + original := D{ + C: C{ + B: B{ + A: A{FieldA: "a"}, + FieldB: 1, + }, + FieldC: true, + }, + FieldD: 3.14, + } + + _, err := msgpack.MarshalAsMap(original) + if err != nil { + t.Fatalf("msgpack.Marshal failed: %v", err) + } + + var msgDecoded D + + // Test stream encoding/decoding and cross-compatibility + for _, isArray := range []bool{false, true} { + msgpack.StructAsArray = isArray + encdec(t, encdecArg[D]{ + v: original, + skipEq: true, + vc: func(p D) error { + // All fields should be promoted to top level + tu.Equal(t, p.FieldA, "a") + tu.Equal(t, p.FieldB, 1) + tu.Equal(t, p.FieldC, true) + tu.Equal(t, p.FieldD, 3.14) + msgDecoded = p + return nil + }, + }) + } + + // Compare with JSON + jsonBytes, _ := json.Marshal(original) + var jsonDecoded D + _ = json.Unmarshal(jsonBytes, &jsonDecoded) + + if msgDecoded.FieldA != jsonDecoded.FieldA || + msgDecoded.FieldB != jsonDecoded.FieldB || + msgDecoded.FieldC != jsonDecoded.FieldC || + msgDecoded.FieldD != jsonDecoded.FieldD { + t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", + msgDecoded, jsonDecoded) + } + }) +} + func testStructTag(t *testing.T) { type vSt struct { One int `msgpack:"Three"` From 984f35fdcb5e857b052b36425612fc8c1bf9969d Mon Sep 17 00:00:00 2001 From: shamaton Date: Fri, 6 Feb 2026 08:28:07 +0900 Subject: [PATCH 2/2] ignore lint QF1008 --- msgpack_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/msgpack_test.go b/msgpack_test.go index 449590b..edb5818 100644 --- a/msgpack_test.go +++ b/msgpack_test.go @@ -1793,6 +1793,7 @@ func testEmbeddedStruct(t *testing.T) { skipEq: true, vc: func(p Parent) error { // Tagged embedded struct should not be promoted + //nolint:staticcheck // QF1008 for test tu.Equal(t, p.Embedded.Field, 99) tu.Equal(t, p.Regular, "value") msgDecoded = p @@ -1806,6 +1807,7 @@ func testEmbeddedStruct(t *testing.T) { var jsonDecoded Parent _ = json.Unmarshal(jsonBytes, &jsonDecoded) + //nolint:staticcheck // QF1008 for test if msgDecoded.Embedded.Field != jsonDecoded.Embedded.Field || msgDecoded.Regular != jsonDecoded.Regular { t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v", @@ -2230,6 +2232,7 @@ func testEmbeddedStruct(t *testing.T) { v: original, skipEq: true, vc: func(p Derived) error { + //nolint:staticcheck // QF1008 for test tu.Equal(t, p.Base.Value, 123) tu.Equal(t, p.Extra, "extra") return nil @@ -2268,6 +2271,7 @@ func testEmbeddedStruct(t *testing.T) { v: original, skipEq: true, vc: func(p Outer) error { + //nolint:staticcheck // QF1008 for test tu.Equal(t, p.Inner.Value, 456) tu.Equal(t, p.Other, "test") msgDecoded = p @@ -2281,6 +2285,7 @@ func testEmbeddedStruct(t *testing.T) { var jsonDecoded Outer _ = json.Unmarshal(jsonBytes, &jsonDecoded) + //nolint:staticcheck // QF1008 for test if msgDecoded.Inner.Value != jsonDecoded.Inner.Value || msgDecoded.Other != jsonDecoded.Other { t.Errorf("msgpack and json results differ:\nmsgpack: %+v\njson: %+v",