diff --git a/tok/index/helper.go b/tok/index/helper.go index f909a19c91d..cb7aeb60a95 100644 --- a/tok/index/helper.go +++ b/tok/index/helper.go @@ -8,14 +8,13 @@ package index import ( "encoding/binary" "math" - "reflect" "unsafe" c "github.com/dgraph-io/dgraph/v25/tok/constraints" "github.com/golang/glog" ) -// BytesAsFloatArray[T c.Float](encoded) converts encoded into a []T, +// BytesAsFloatArray converts encoded into a []T, // where T is either float32 or float64, depending on the value of floatBits. // Let floatBytes = floatBits/8. If len(encoded) % floatBytes is // not 0, it will ignore any trailing bytes, and simply convert floatBytes @@ -40,10 +39,8 @@ func BytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, floatBits int) { *retVal = make([]T, len(encoded)/floatBytes) } *retVal = (*retVal)[:0] - header := (*reflect.SliceHeader)(unsafe.Pointer(retVal)) - header.Data = uintptr(unsafe.Pointer(&encoded[0])) - header.Len = len(encoded) / floatBytes - header.Cap = len(encoded) / floatBytes + floatSlice := unsafe.Slice((*T)(unsafe.Pointer(&encoded[0])), len(encoded)/floatBytes) + *retVal = append(*retVal, floatSlice...) } func BytesToFloat[T c.Float](encoded []byte, floatBits int) T { diff --git a/tok/index/helper_test.go b/tok/index/helper_test.go index 2ce0b47d5da..f16c5f0acbc 100644 --- a/tok/index/helper_test.go +++ b/tok/index/helper_test.go @@ -17,6 +17,7 @@ import ( "github.com/dgraph-io/dgraph/v25/protos/pb" c "github.com/dgraph-io/dgraph/v25/tok/constraints" + "github.com/stretchr/testify/require" "github.com/viterin/vek/vek32" "google.golang.org/protobuf/proto" ) @@ -263,13 +264,13 @@ func BenchmarkEncodeDecodeUint64Matrix(b *testing.B) { }) } -func dotProductT[T c.Float](a, b []T, floatBits int) { - var dotProduct T +func dotProductT[T c.Float](a, b []T) { + var product T if len(a) != len(b) { return } for i := range a { - dotProduct += a[i] * b[i] + product += a[i] * b[i] } } @@ -295,7 +296,7 @@ func BenchmarkDotProduct(b *testing.B) { b.Run(fmt.Sprintf("vek:size=%d", len(data)), func(b *testing.B) { temp := make([]float32, num) - BytesAsFloatArray[float32](data, &temp, 32) + BytesAsFloatArray(data, &temp, 32) for k := 0; k < b.N; k++ { vek32.Dot(temp, temp) } @@ -305,7 +306,7 @@ func BenchmarkDotProduct(b *testing.B) { func(b *testing.B) { temp := make([]float32, num) - BytesAsFloatArray[float32](data, &temp, 32) + BytesAsFloatArray(data, &temp, 32) for k := 0; k < b.N; k++ { dotProduct(temp, temp) } @@ -316,9 +317,9 @@ func BenchmarkDotProduct(b *testing.B) { func(b *testing.B) { temp := make([]float32, num) - BytesAsFloatArray[float32](data, &temp, 32) + BytesAsFloatArray(data, &temp, 32) for k := 0; k < b.N; k++ { - dotProductT[float32](temp, temp, 32) + dotProductT(temp, temp) } }) } @@ -379,7 +380,7 @@ func littleEndianBytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, float } } -func BenchmarkFloatConverstion(b *testing.B) { +func BenchmarkFloatConversion(b *testing.B) { num := 1500 data := make([]byte, 64*num) _, err := rand.Read(data) @@ -391,7 +392,7 @@ func BenchmarkFloatConverstion(b *testing.B) { func(b *testing.B) { temp := make([]float32, num) for k := 0; k < b.N; k++ { - BytesAsFloatArray[float32](data, &temp, 64) + BytesAsFloatArray(data, &temp, 64) } }) @@ -399,7 +400,7 @@ func BenchmarkFloatConverstion(b *testing.B) { func(b *testing.B) { temp := make([]float32, num) for k := 0; k < b.N; k++ { - pointerFloatConversion[float32](data, &temp, 64) + pointerFloatConversion(data, &temp, 64) } }) @@ -407,7 +408,25 @@ func BenchmarkFloatConverstion(b *testing.B) { func(b *testing.B) { temp := make([]float32, num) for k := 0; k < b.N; k++ { - littleEndianBytesAsFloatArray[float32](data, &temp, 64) + littleEndianBytesAsFloatArray(data, &temp, 64) } }) } + +func TestBytesAsFloatArray(t *testing.T) { + var out32 []float32 + in32 := []float32{0.1, 1.123456789, 50000.00005} + + inf32 := unsafe.Slice((*byte)(unsafe.Pointer(&in32[0])), len(in32)*int(unsafe.Sizeof(in32[0]))) + + BytesAsFloatArray(inf32, &out32, 32) + require.Equal(t, in32, out32) + + var out64 []float64 + in64 := []float64{0.1, 1.123456789, 50000.00005} + + inf64 := unsafe.Slice((*byte)(unsafe.Pointer(&in64[0])), len(in64)*int(unsafe.Sizeof(in64[0]))) + + BytesAsFloatArray(inf64, &out64, 64) + require.Equal(t, in64, out64) +}