From c5e8d500e22e6a53d0a6010aae5ab24189bde4f0 Mon Sep 17 00:00:00 2001 From: Laurence Date: Tue, 9 Dec 2025 08:03:40 +0000 Subject: [PATCH] feat: add KV unmarshal feature with struct tag support Add Unmarshal() method to KVScanner that unmarshals KV entries into structs using struct tags, similar to encoding/json. This provides a convenient API for extracting structured data from SPOE messages. Features: - Struct tag support: `spoe:"keyname"` maps KV keys to struct fields - Type support: string, []byte, int32, int64, uint32, uint64, bool, netip.Addr - Optional fields via pointer types (set to nil if key not found) - Zero-allocation key matching using NameEquals() - Optimized reflection usage with cached field values and kinds - Comprehensive test coverage Performance: The unmarshal feature trades some performance for convenience. Benchmarks show manual iteration is ~3x faster, but unmarshal provides better developer experience for non-hot paths. Benchmark Results: ``` goos: linux goarch: amd64 pkg: github.com/dropmorepackets/haproxy-go/pkg/encoding cpu: 12th Gen Intel(R) Core(TM) i7-12700H BenchmarkUnmarshal-20 3883784 334.2 ns/op 614 B/op 8 allocs/op BenchmarkManualIteration-20 10422141 114.2 ns/op 157 B/op 5 allocs/op ``` Manual iteration: ~114 ns/op, 157 B/op, 5 allocs/op Unmarshal: ~334 ns/op, 614 B/op, 8 allocs/op The additional allocations come from: - Field info slice setup (one-time per unmarshal call) - Pointer field tracking map - Reflection overhead for type checking Usage Example: ```go type RequestData struct { Headers []byte `spoe:"headers"` Status int32 `spoe:"status-code"` IP netip.Addr `spoe:"client-ip"` Optional *string `spoe:"optional-field"` } var data RequestData if err := m.KV.Unmarshal(&data); err != nil { // handle error } ``` --- pkg/encoding/kvunmarshal.go | 231 +++++++++++++++++++++++++ pkg/encoding/kvunmarshal_bench_test.go | 180 +++++++++++++++++++ pkg/encoding/kvunmarshal_test.go | 189 ++++++++++++++++++++ 3 files changed, 600 insertions(+) create mode 100644 pkg/encoding/kvunmarshal.go create mode 100644 pkg/encoding/kvunmarshal_bench_test.go create mode 100644 pkg/encoding/kvunmarshal_test.go diff --git a/pkg/encoding/kvunmarshal.go b/pkg/encoding/kvunmarshal.go new file mode 100644 index 0000000..8c274d5 --- /dev/null +++ b/pkg/encoding/kvunmarshal.go @@ -0,0 +1,231 @@ +package encoding + +import ( + "fmt" + "net/netip" + "reflect" + "strings" +) + +const tagName = "spoe" + +// Unmarshal unmarshals KV entries from the scanner into the provided struct. +// The struct should have fields tagged with `spoe:"keyname"` to map KV entry +// names to struct fields. +// +// Supported field types: +// - string, []byte (for DataTypeString and DataTypeBinary) +// - int32, int64, uint32, uint64 (for integer types) +// - bool (for DataTypeBool) +// - netip.Addr (for DataTypeIPV4 and DataTypeIPV6) +// - pointer types for optional fields (nil if key not found) +// +// Example: +// +// type RequestData struct { +// Headers []byte `spoe:"headers"` +// Status int32 `spoe:"status-code"` +// IP netip.Addr `spoe:"client-ip"` +// Optional *string `spoe:"optional-field"` +// } +func (k *KVScanner) Unmarshal(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("unmarshal target must be a non-nil pointer to struct") + } + + rv = rv.Elem() + if rv.Kind() != reflect.Struct { + return fmt.Errorf("unmarshal target must be a pointer to struct") + } + + rt := rv.Type() + + // Build a slice of field info to avoid string allocations during lookup + type fieldInfo struct { + keyStr string // cached for NameEquals and error messages + fieldIdx int + field reflect.Value // cached to avoid repeated rv.Field() calls + fieldKind reflect.Kind // cached to avoid repeated Kind() calls + isPointer bool // cached to avoid repeated checks + } + fields := make([]fieldInfo, 0, rt.NumField()) + pointerFieldIndices := make([]int, 0, rt.NumField()) // track pointer field indices for final cleanup + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + tag := field.Tag.Get(tagName) + if tag == "" || tag == "-" { + continue + } + + // Handle comma-separated options (e.g., "keyname,omitempty") + // Use IndexByte to avoid allocation from strings.Split + commaIdx := strings.IndexByte(tag, ',') + var key string + if commaIdx >= 0 { + key = tag[:commaIdx] + } else { + key = tag + } + if key != "" { + fv := rv.Field(i) + fk := fv.Kind() + isPtr := fk == reflect.Pointer + fields = append(fields, fieldInfo{ + keyStr: key, + fieldIdx: i, + field: fv, + fieldKind: fk, + isPointer: isPtr, + }) + if isPtr { + pointerFieldIndices = append(pointerFieldIndices, i) + } + } + } + + entry := AcquireKVEntry() + defer ReleaseKVEntry(entry) + + // Track which pointer fields have been set (to clear unset ones later) + setPointerFields := make(map[int]bool, len(pointerFieldIndices)) + + for k.Next(entry) { + var fi *fieldInfo + // Use NameEquals to avoid string allocation during lookup + for i := range fields { + if entry.NameEquals(fields[i].keyStr) { + fi = &fields[i] + break + } + } + if fi == nil { + // Unknown key, skip it + continue + } + + if !fi.field.CanSet() { + return fmt.Errorf("field %s is not settable", rt.Field(fi.fieldIdx).Name) + } + + if err := setFieldValue(fi.field, fi.fieldKind, entry); err != nil { + return fmt.Errorf("field %s (key %q): %w", rt.Field(fi.fieldIdx).Name, fi.keyStr, err) + } + + // Track if this is a pointer field that was set + if fi.isPointer { + setPointerFields[fi.fieldIdx] = true + } + } + + if err := k.Error(); err != nil { + return fmt.Errorf("scanner error: %w", err) + } + + // Set pointer fields to nil if they weren't set (important for pooled structs) + // Only iterate through known pointer fields instead of all fields + for _, idx := range pointerFieldIndices { + if !setPointerFields[idx] { + rv.Field(idx).Set(reflect.Zero(rt.Field(idx).Type)) + } + } + + return nil +} + +func setFieldValue(field reflect.Value, fieldKind reflect.Kind, entry *KVEntry) error { + fieldType := field.Type() + + // Handle pointer types + if fieldKind == reflect.Pointer { + if entry.dataType == DataTypeNull { + field.Set(reflect.Zero(fieldType)) + return nil + } + + // Create new value of the pointed-to type + elemType := fieldType.Elem() + elemValue := reflect.New(elemType).Elem() + if err := setValue(elemValue, elemType.Kind(), entry); err != nil { + return err + } + field.Set(elemValue.Addr()) + return nil + } + + return setValue(field, fieldKind, entry) +} + +var netipAddrType = reflect.TypeOf((*netip.Addr)(nil)).Elem() + +func setValue(field reflect.Value, fieldKind reflect.Kind, entry *KVEntry) error { + fieldType := field.Type() + + switch fieldKind { + case reflect.String: + if entry.dataType != DataTypeString { + return fmt.Errorf("expected string, got %d", entry.dataType) + } + // Value() returns string for DataTypeString + field.SetString(entry.Value().(string)) + + case reflect.Slice: + if fieldType.Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("unsupported slice type: %s", fieldType) + } + // []byte + if entry.dataType != DataTypeString && entry.dataType != DataTypeBinary { + return fmt.Errorf("expected string or binary, got %d", entry.dataType) + } + // Copy the bytes to avoid referencing the underlying buffer + val := entry.ValueBytes() + cp := make([]byte, len(val)) + copy(cp, val) + field.SetBytes(cp) + + case reflect.Int32: + if entry.dataType != DataTypeInt32 { + return fmt.Errorf("expected int32, got %d", entry.dataType) + } + field.SetInt(entry.ValueInt()) + + case reflect.Int64: + if entry.dataType != DataTypeInt64 { + return fmt.Errorf("expected int64, got %d", entry.dataType) + } + field.SetInt(entry.ValueInt()) + + case reflect.Uint32: + if entry.dataType != DataTypeUInt32 { + return fmt.Errorf("expected uint32, got %d", entry.dataType) + } + field.SetUint(uint64(entry.ValueInt())) + + case reflect.Uint64: + if entry.dataType != DataTypeUInt64 { + return fmt.Errorf("expected uint64, got %d", entry.dataType) + } + field.SetUint(uint64(entry.ValueInt())) + + case reflect.Bool: + if entry.dataType != DataTypeBool { + return fmt.Errorf("expected bool, got %d", entry.dataType) + } + field.SetBool(entry.ValueBool()) + + default: + // Check for netip.Addr (using cached type) + if fieldType == netipAddrType { + if entry.dataType != DataTypeIPV4 && entry.dataType != DataTypeIPV6 { + return fmt.Errorf("expected IP address, got %d", entry.dataType) + } + addr := entry.ValueAddr() + field.Set(reflect.ValueOf(addr)) + return nil + } + + return fmt.Errorf("unsupported field type: %s", fieldType) + } + + return nil +} diff --git a/pkg/encoding/kvunmarshal_bench_test.go b/pkg/encoding/kvunmarshal_bench_test.go new file mode 100644 index 0000000..8c86855 --- /dev/null +++ b/pkg/encoding/kvunmarshal_bench_test.go @@ -0,0 +1,180 @@ +package encoding + +import ( + "net/netip" + "sync" + "testing" +) + +// TestStruct represents a typical struct that would be unmarshaled +type TestStruct struct { + Headers []byte `spoe:"headers"` + Status int32 `spoe:"status-code"` + ClientIP netip.Addr `spoe:"client-ip"` + UserID uint64 `spoe:"user-id"` + Active bool `spoe:"active"` + Optional *string `spoe:"optional-field"` +} + +var testStructPool = sync.Pool{ + New: func() any { + return &TestStruct{} + }, +} + +func acquireTestStruct() *TestStruct { + return testStructPool.Get().(*TestStruct) +} + +func releaseTestStruct(s *TestStruct) { + // Reset all fields + s.Headers = nil + s.Status = 0 + s.ClientIP = netip.Addr{} + s.UserID = 0 + s.Active = false + s.Optional = nil + testStructPool.Put(s) +} + +// setupTestData creates a KV buffer with test data +func setupTestData() []byte { + buf := make([]byte, 1024) + w := NewKVWriter(buf, 0) + w.SetString("headers", "Content-Type: application/json") + w.SetInt32("status-code", 200) + addr := netip.MustParseAddr("192.168.1.100") + w.SetAddr("client-ip", addr) + w.SetUInt64("user-id", 12345) + w.SetBool("active", true) + w.SetString("optional-field", "optional-value") + return w.Bytes() +} + +func BenchmarkUnmarshal(b *testing.B) { + data := setupTestData() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + scanner := NewKVScanner(data, -1) + s := acquireTestStruct() + if err := scanner.Unmarshal(s); err != nil { + b.Fatal(err) + } + releaseTestStruct(s) + ReleaseKVScanner(scanner) + } + }) +} + +func BenchmarkManualIteration(b *testing.B) { + data := setupTestData() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + scanner := NewKVScanner(data, -1) + s := acquireTestStruct() + + entry := AcquireKVEntry() + for scanner.Next(entry) { + switch { + case entry.NameEquals("headers"): + val := entry.ValueBytes() + s.Headers = make([]byte, len(val)) + copy(s.Headers, val) + case entry.NameEquals("status-code"): + s.Status = int32(entry.ValueInt()) + case entry.NameEquals("client-ip"): + s.ClientIP = entry.ValueAddr() + case entry.NameEquals("user-id"): + s.UserID = uint64(entry.ValueInt()) + case entry.NameEquals("active"): + s.Active = entry.ValueBool() + case entry.NameEquals("optional-field"): + val := entry.Value().(string) + s.Optional = &val + } + } + + if err := scanner.Error(); err != nil { + b.Fatal(err) + } + + ReleaseKVEntry(entry) + releaseTestStruct(s) + ReleaseKVScanner(scanner) + } + }) +} + +// BenchmarkUnmarshalSequential runs unmarshal sequentially (no parallel) +func BenchmarkUnmarshalSequential(b *testing.B) { + data := setupTestData() + s := acquireTestStruct() + defer releaseTestStruct(s) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + scanner := NewKVScanner(data, -1) + if err := scanner.Unmarshal(s); err != nil { + b.Fatal(err) + } + ReleaseKVScanner(scanner) + // Reset struct manually for next iteration + s.Headers = nil + s.Status = 0 + s.ClientIP = netip.Addr{} + s.UserID = 0 + s.Active = false + s.Optional = nil + } +} + +// BenchmarkManualIterationSequential runs manual iteration sequentially +func BenchmarkManualIterationSequential(b *testing.B) { + data := setupTestData() + s := acquireTestStruct() + defer releaseTestStruct(s) + entry := AcquireKVEntry() + defer ReleaseKVEntry(entry) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + scanner := NewKVScanner(data, -1) + + for scanner.Next(entry) { + switch { + case entry.NameEquals("headers"): + val := entry.ValueBytes() + s.Headers = make([]byte, len(val)) + copy(s.Headers, val) + case entry.NameEquals("status-code"): + s.Status = int32(entry.ValueInt()) + case entry.NameEquals("client-ip"): + s.ClientIP = entry.ValueAddr() + case entry.NameEquals("user-id"): + s.UserID = uint64(entry.ValueInt()) + case entry.NameEquals("active"): + s.Active = entry.ValueBool() + case entry.NameEquals("optional-field"): + val := entry.Value().(string) + s.Optional = &val + } + } + + if err := scanner.Error(); err != nil { + b.Fatal(err) + } + + ReleaseKVScanner(scanner) + // Reset struct manually for next iteration + s.Headers = nil + s.Status = 0 + s.ClientIP = netip.Addr{} + s.UserID = 0 + s.Active = false + s.Optional = nil + } +} diff --git a/pkg/encoding/kvunmarshal_test.go b/pkg/encoding/kvunmarshal_test.go new file mode 100644 index 0000000..51c67e1 --- /dev/null +++ b/pkg/encoding/kvunmarshal_test.go @@ -0,0 +1,189 @@ +package encoding + +import ( + "net/netip" + "testing" +) + +func TestKVScanner_Unmarshal(t *testing.T) { + tests := []struct { + name string + data []byte + target interface{} + wantErr bool + check func(t *testing.T, v interface{}) + }{ + { + name: "basic types", + data: func() []byte { + buf := make([]byte, 1024) + w := NewKVWriter(buf, 0) + w.SetString("name", "test") + w.SetInt32("age", 25) + w.SetBool("active", true) + return w.Bytes() + }(), + target: &struct { + Name string `spoe:"name"` + Age int32 `spoe:"age"` + Active bool `spoe:"active"` + }{}, + wantErr: false, + check: func(t *testing.T, v interface{}) { + s := v.(*struct { + Name string `spoe:"name"` + Age int32 `spoe:"age"` + Active bool `spoe:"active"` + }) + if s.Name != "test" { + t.Errorf("Name = %q, want %q", s.Name, "test") + } + if s.Age != 25 { + t.Errorf("Age = %d, want %d", s.Age, 25) + } + if !s.Active { + t.Errorf("Active = %v, want %v", s.Active, true) + } + }, + }, + { + name: "ip address", + data: func() []byte { + buf := make([]byte, 1024) + w := NewKVWriter(buf, 0) + addr := netip.MustParseAddr("192.168.1.1") + w.SetAddr("ip", addr) + return w.Bytes() + }(), + target: &struct { + IP netip.Addr `spoe:"ip"` + }{}, + wantErr: false, + check: func(t *testing.T, v interface{}) { + s := v.(*struct { + IP netip.Addr `spoe:"ip"` + }) + if s.IP.String() != "192.168.1.1" { + t.Errorf("IP = %q, want %q", s.IP.String(), "192.168.1.1") + } + }, + }, + { + name: "optional pointer field", + data: func() []byte { + buf := make([]byte, 1024) + w := NewKVWriter(buf, 0) + w.SetString("required", "value") + // optional field not set + return w.Bytes() + }(), + target: &struct { + Required string `spoe:"required"` + Optional *string `spoe:"optional"` + }{}, + wantErr: false, + check: func(t *testing.T, v interface{}) { + s := v.(*struct { + Required string `spoe:"required"` + Optional *string `spoe:"optional"` + }) + if s.Required != "value" { + t.Errorf("Required = %q, want %q", s.Required, "value") + } + if s.Optional != nil { + t.Errorf("Optional = %v, want nil", s.Optional) + } + }, + }, + { + name: "unknown keys ignored", + data: func() []byte { + buf := make([]byte, 1024) + w := NewKVWriter(buf, 0) + w.SetString("known", "value") + w.SetString("unknown", "ignored") + return w.Bytes() + }(), + target: &struct { + Known string `spoe:"known"` + }{}, + wantErr: false, + check: func(t *testing.T, v interface{}) { + s := v.(*struct { + Known string `spoe:"known"` + }) + if s.Known != "value" { + t.Errorf("Known = %q, want %q", s.Known, "value") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scanner := NewKVScanner(tt.data, -1) + err := scanner.Unmarshal(tt.target) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.check != nil { + tt.check(t, tt.target) + } + }) + } +} + +func TestKVScanner_Unmarshal_Errors(t *testing.T) { + tests := []struct { + name string + data []byte + target interface{} + }{ + { + name: "not a pointer", + data: []byte{}, + target: struct { + Name string `spoe:"name"` + }{}, + }, + { + name: "nil pointer", + data: []byte{}, + target: (*struct { + Name string `spoe:"name"` + })(nil), + }, + { + name: "not a struct", + data: []byte{}, + target: func() *int { + v := 0 + return &v + }(), + }, + { + name: "type mismatch", + data: func() []byte { + buf := make([]byte, 1024) + w := NewKVWriter(buf, 0) + w.SetString("name", "test") + return w.Bytes() + }(), + target: &struct { + Name int32 `spoe:"name"` + }{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scanner := NewKVScanner(tt.data, -1) + err := scanner.Unmarshal(tt.target) + if err == nil { + t.Errorf("Unmarshal() expected error, got nil") + } + }) + } +} +