Skip to content

Commit a1494db

Browse files
committed
Add Int2/Uint2 ConvertDType support, improve boolean naming
- Register Int2/Uint2 → {Int8,Uint8,Int32,Int64,Float32,Float64} converters via execConvertPackedSubByte with valuesPerByte=4
- Add unpackInt2Bits and unpackUint2Bits for 2-bit packed data
- Handle Int2/Uint2 in unpackWeightsToBuffer alongside Int4/Uint4
- Register mutableBytes and fillBuffer for Int2/Uint2
- Rename unpackedPooled → isUnpackedOwned, idxPooled → isIdxOwned for clarity (all buffers are pooled; the bool tracks ownership)
1 parent 18d38b4 commit a1494db

3 files changed

Lines changed: 101 additions & 8 deletions

File tree

backends/simplego/exec_convert_dtype.go

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,30 @@ func init() {
5757
convertDTypePairMap.Register(dtypes.Uint4, dtypes.Int64, priorityTyped, execConvertPackedSubByte[int64](unpackUint4Nibbles, 2))
5858
convertDTypePairMap.Register(dtypes.Uint4, dtypes.Uint8, priorityTyped, execConvertPackedSubByte[uint8](unpackUint4Nibbles, 2))
5959

60+
// Register sub-byte type conversions (Int2, Uint2).
61+
// Each byte packs 4 values (2 bits each). Bit layout: bits 0-1 = first value,
62+
// bits 2-3 = second, bits 4-5 = third, bits 6-7 = fourth.
63+
convertDTypePairMap.Register(dtypes.Int2, dtypes.Float32, priorityTyped, execConvertPackedSubByte[float32](unpackInt2Bits, 4))
64+
convertDTypePairMap.Register(dtypes.Int2, dtypes.Float64, priorityTyped, execConvertPackedSubByte[float64](unpackInt2Bits, 4))
65+
convertDTypePairMap.Register(dtypes.Int2, dtypes.Int32, priorityTyped, execConvertPackedSubByte[int32](unpackInt2Bits, 4))
66+
convertDTypePairMap.Register(dtypes.Int2, dtypes.Int64, priorityTyped, execConvertPackedSubByte[int64](unpackInt2Bits, 4))
67+
convertDTypePairMap.Register(dtypes.Int2, dtypes.Int8, priorityTyped, execConvertPackedSubByte[int8](unpackInt2Bits, 4))
68+
convertDTypePairMap.Register(dtypes.Uint2, dtypes.Float32, priorityTyped, execConvertPackedSubByte[float32](unpackUint2Bits, 4))
69+
convertDTypePairMap.Register(dtypes.Uint2, dtypes.Float64, priorityTyped, execConvertPackedSubByte[float64](unpackUint2Bits, 4))
70+
convertDTypePairMap.Register(dtypes.Uint2, dtypes.Int32, priorityTyped, execConvertPackedSubByte[int32](unpackUint2Bits, 4))
71+
convertDTypePairMap.Register(dtypes.Uint2, dtypes.Int64, priorityTyped, execConvertPackedSubByte[int64](unpackUint2Bits, 4))
72+
convertDTypePairMap.Register(dtypes.Uint2, dtypes.Uint8, priorityTyped, execConvertPackedSubByte[uint8](unpackUint2Bits, 4))
73+
6074
// Register mutableBytes and fillBuffer for sub-byte types.
61-
// Packed Int4/Uint4 buffers use []byte as the Go storage type.
75+
// Packed sub-byte buffers use []byte as the Go storage type.
6276
mutableBytesDTypeMap.Register(dtypes.Int4, priorityTyped, mutableBytesGeneric[byte])
6377
mutableBytesDTypeMap.Register(dtypes.Uint4, priorityTyped, mutableBytesGeneric[byte])
78+
mutableBytesDTypeMap.Register(dtypes.Int2, priorityTyped, mutableBytesGeneric[byte])
79+
mutableBytesDTypeMap.Register(dtypes.Uint2, priorityTyped, mutableBytesGeneric[byte])
6480
fillBufferDTypeMap.Register(dtypes.Int4, priorityTyped, fillBufferGeneric[byte])
6581
fillBufferDTypeMap.Register(dtypes.Uint4, priorityTyped, fillBufferGeneric[byte])
82+
fillBufferDTypeMap.Register(dtypes.Int2, priorityTyped, fillBufferGeneric[byte])
83+
fillBufferDTypeMap.Register(dtypes.Uint2, priorityTyped, fillBufferGeneric[byte])
6684

6785
// Manually register bool x bfloat16 conversion functions.
6886
convertDTypePairMap.Register(dtypes.BFloat16, dtypes.Bool, priorityTyped, execConvertDTypeBFloat16ToBool)
@@ -225,7 +243,34 @@ func unpackUint4Nibbles(packed []byte, dst []int8) {
225243
}
226244
}
227245

228-
// execConvertPackedSubByte returns a converter for packed sub-byte types (Int4, Uint4).
246+
// unpackInt2Bits unpacks packed Int2 data ([]byte, 4 signed 2-bit values per byte)
// into dst []int8 (one value per element). Bit layout per byte:
// bits 0-1 = first, bits 2-3 = second, bits 4-5 = third, bits 6-7 = fourth.
// Signed range: [-2, 1] (values 2,3 sign-extend to -2,-1).
//
// Exactly min(len(dst), 4*len(packed)) values are written. This lets dst have a
// length that is not a multiple of 4 (the last packed byte is then only partly
// consumed) without indexing past the end of dst. Elements of dst beyond the
// unpacked range are left untouched.
func unpackInt2Bits(packed []byte, dst []int8) {
	count := len(dst)
	if limit := 4 * len(packed); limit < count {
		count = limit
	}
	for k := 0; k < count; k++ {
		// Value k lives in byte k/4, at bit offset 2*(k%4).
		v := int8((packed[k/4] >> (uint(k%4) * 2)) & 0x03)
		if v >= 2 {
			// Two's-complement sign extension of a 2-bit field: 2 → -2, 3 → -1.
			v -= 4
		}
		dst[k] = v
	}
}
261+
262+
// unpackUint2Bits unpacks packed Uint2 data ([]byte, 4 unsigned 2-bit values per byte)
// into dst []int8 (one value per element). Unsigned range: [0, 3].
// Bit layout per byte: bits 0-1 = first, bits 2-3 = second, bits 4-5 = third,
// bits 6-7 = fourth.
//
// Exactly min(len(dst), 4*len(packed)) values are written. This lets dst have a
// length that is not a multiple of 4 (the last packed byte is then only partly
// consumed) without indexing past the end of dst. Elements of dst beyond the
// unpacked range are left untouched.
func unpackUint2Bits(packed []byte, dst []int8) {
	count := len(dst)
	if limit := 4 * len(packed); limit < count {
		count = limit
	}
	for k := 0; k < count; k++ {
		// Value k lives in byte k/4, at bit offset 2*(k%4); no sign extension.
		dst[k] = int8((packed[k/4] >> (uint(k%4) * 2)) & 0x03)
	}
}
272+
273+
// execConvertPackedSubByte returns a converter for packed sub-byte types (Int4, Uint4, Int2, Uint2).
229274
// The unpackFn parameter selects the bit width and the signed vs unsigned interpretation of the packed values.
230275
// Sub-byte types are always stored packed as []byte.
231276
// valuesPerByte is the number of logical values per packed byte (e.g. 2 for 4-bit, 4 for 2-bit).

backends/simplego/exec_convert_dtype_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,54 @@ func TestConvertPackedInt4ToFloat32(t *testing.T) {
8484
assert.Equal(t, float32(-8), result[3])
8585
}
8686

87+
func TestConvertPackedInt2ToInt8(t *testing.T) {
88+
// Packed Int2 → Int8: unpacks 2-bit values with sign extension.
89+
// Byte 0b11_10_01_00 = 0xE4: values 0, 1, -2, -1.
90+
srcData := []byte{0xE4}
91+
srcShape := shapes.Make(dtypes.Int2, 4) // 4 Int2 elements packed in 1 byte
92+
srcBuf := &Buffer{shape: srcShape, flat: srcData, inUse: true}
93+
94+
dstShape := shapes.Make(dtypes.Int8, 4)
95+
dstBuf := &Buffer{shape: dstShape, flat: make([]int8, 4), inUse: true}
96+
97+
tmpAny, tmpErr := convertDTypePairMap.Get(dtypes.Int2, dtypes.Int8)
98+
if tmpErr != nil {
99+
panic(tmpErr)
100+
}
101+
convertFn := tmpAny.(convertFnType)
102+
convertFn(srcBuf, dstBuf)
103+
104+
result := dstBuf.flat.([]int8)
105+
assert.Equal(t, int8(0), result[0])
106+
assert.Equal(t, int8(1), result[1])
107+
assert.Equal(t, int8(-2), result[2])
108+
assert.Equal(t, int8(-1), result[3])
109+
}
110+
111+
func TestConvertPackedUint2ToUint8(t *testing.T) {
112+
// Packed Uint2 → Uint8: unpacks 2-bit values (no sign extension).
113+
// Byte 0b11_10_01_00 = 0xE4: values 0, 1, 2, 3.
114+
srcData := []byte{0xE4}
115+
srcShape := shapes.Make(dtypes.Uint2, 4) // 4 Uint2 elements packed in 1 byte
116+
srcBuf := &Buffer{shape: srcShape, flat: srcData, inUse: true}
117+
118+
dstShape := shapes.Make(dtypes.Uint8, 4)
119+
dstBuf := &Buffer{shape: dstShape, flat: make([]uint8, 4), inUse: true}
120+
121+
tmpAny, tmpErr := convertDTypePairMap.Get(dtypes.Uint2, dtypes.Uint8)
122+
if tmpErr != nil {
123+
panic(tmpErr)
124+
}
125+
convertFn := tmpAny.(convertFnType)
126+
convertFn(srcBuf, dstBuf)
127+
128+
result := dstBuf.flat.([]uint8)
129+
assert.Equal(t, uint8(0), result[0])
130+
assert.Equal(t, uint8(1), result[1])
131+
assert.Equal(t, uint8(2), result[2])
132+
assert.Equal(t, uint8(3), result[3])
133+
}
134+
87135
func TestExecSpecialOps_ConvertDType(t *testing.T) {
88136
// Test int32 to float32
89137
y0 := graph.MustExecOnce(backend, func(x *graph.Node) *graph.Node {

backends/simplego/exec_fused_quantized.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ func execFusedQuantizedDense(backend *Backend, node *Node, inputs []*Buffer, inp
7676

7777
// For packed sub-byte weights (from Bitcast), unpack the packed values via the buffer pool
7878
// and ConvertDType infrastructure. Non-sub-byte types pass through unchanged.
79-
unpackedBuf, unpackedPooled, err := unpackWeightsToBuffer(backend, wBuf)
79+
unpackedBuf, isUnpackedOwned, err := unpackWeightsToBuffer(backend, wBuf)
8080
if err != nil {
8181
return nil, err
8282
}
83-
if unpackedPooled {
83+
if isUnpackedOwned {
8484
defer backend.putBuffer(unpackedBuf)
8585
}
8686
wFlat := unpackedBuf.flat
@@ -122,9 +122,9 @@ func execFusedQuantizedDense(backend *Backend, node *Node, inputs []*Buffer, inp
122122
func unpackWeightsToBuffer(backend *Backend, wBuf *Buffer) (*Buffer, bool, error) {
123123
var targetDType dtypes.DType
124124
switch wBuf.shape.DType {
125-
case dtypes.Int4:
125+
case dtypes.Int4, dtypes.Int2:
126126
targetDType = dtypes.Int8
127-
case dtypes.Uint4:
127+
case dtypes.Uint4, dtypes.Uint2:
128128
targetDType = dtypes.Uint8
129129
default:
130130
return wBuf, false, nil
@@ -173,11 +173,11 @@ func execQuantizedEmbeddingLookup(backend *Backend, node *Node, inputs []*Buffer
173173
numIndices := indicesBuf.shape.Size()
174174

175175
// Convert indices to int64 via the buffer pool and ConvertDType infrastructure.
176-
idxBuf, idxPooled, err := convertIndicesToInt64(backend, indicesBuf)
176+
idxBuf, isIdxOwned, err := convertIndicesToInt64(backend, indicesBuf)
177177
if err != nil {
178178
return nil, errors.Wrapf(err, "QuantizedEmbeddingLookup")
179179
}
180-
if idxPooled {
180+
if isIdxOwned {
181181
defer backend.putBuffer(idxBuf)
182182
}
183183
indices := idxBuf.flat.([]int64)

0 commit comments

Comments
 (0)