Skip to content

Commit 5250386

Browse files
committed
Address PR review: buffer pool, inline shifts, GGML refs, naming
- Use buffer pool + ConvertDType for weight unpacking and index conversion instead of ad-hoc allocations (unpackWeightsToBuffer, convertIndicesToInt64)
- Inline shift operations following binary ops pattern to eliminate per-element closure overhead (shiftLeftOp, shiftRightArithmeticOp, shiftRightLogicalUnsignedOp, shiftRightLogicalSignedOp)
- Rename parallelTileCount → quantizedDenseParallelTileCount
- Simplify numIndices calculation (last dim pre-validated as 1)
- Use tgtIsUint8 GoType check in exec_bitcast instead of Bits() < 8
- Add GGML format references and doc links to fused_ops.go
- Add "Follow Existing Patterns" guidance to AGENTS.md
1 parent 66fe00e commit 5250386

7 files changed

Lines changed: 214 additions & 106 deletions

File tree

.agents/AGENTS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ an error, to simplify the code. But everywhere else, use standard Go error handl
8282
- Use `any` instead of `interface{}`.
8383
- Organize tests in hierarchies using `t.Run()` to group related tests.
8484

85+
### Follow Existing Patterns
86+
87+
Before writing new code, read neighboring files in the same package to understand the established
88+
patterns (buffer management, dtype dispatch, parallelization, etc.). Reuse existing infrastructure
89+
rather than writing ad-hoc implementations. When in doubt, match the style and approach of the
90+
closest existing operation.
91+
8592
### Copyright Notes
8693

8794
Normal code files are prefixed with the following copyright line:

backends/fused_ops.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ type FusedOps interface {
322322
// QuantizedEmbeddingLookup performs a quantized embedding lookup (row gather)
323323
// with on-the-fly dequantization.
324324
//
325-
// This is the quantized analogue of embedding lookup, inspired by
325+
// This is the quantized analogue of Gather for embedding lookups, inspired by
326326
// llama.cpp's ggml_get_rows. For now it is only implemented for the GGML
327327
// quantization scheme, but could be extended for others if/when needed.
328328
//

backends/simplego/exec_bitcast.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
package simplego
44

55
import (
6+
"reflect"
7+
68
"github.com/gomlx/gomlx/backends"
79
)
810

@@ -41,7 +43,8 @@ func execBitcast(backend *Backend, node *Node, inputs []*Buffer, inputsOwned []b
4143
// target use the same underlying Go storage type. Sub-byte types
4244
// (Int2, Uint2, Int4, Uint4) all store as []uint8.
4345
_, srcIsUint8 := src.flat.([]uint8)
44-
canReuse = srcIsUint8 && targetDType.Bits() < 8
46+
tgtIsUint8 := targetDType.GoType().Kind() == reflect.Uint8
47+
canReuse = srcIsUint8 && tgtIsUint8
4548
}
4649
}
4750
if canReuse {

backends/simplego/exec_fused_quantized.go

Lines changed: 75 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
"github.com/gomlx/gomlx/backends"
99
"github.com/gomlx/gomlx/pkg/core/dtypes"
10+
"github.com/gomlx/gomlx/pkg/core/shapes"
1011
"github.com/pkg/errors"
1112
)
1213

@@ -73,9 +74,16 @@ func execFusedQuantizedDense(backend *Backend, node *Node, inputs []*Buffer, inp
7374
zeroPoints = zeroPointsBuf.flat.([]float32)
7475
}
7576

76-
// For packed sub-byte weights (from Bitcast), unpack nibbles before processing.
77-
// Packed buffers have len(flat) < shape.Size() (2 nibbles per byte).
78-
wFlat := unpackWeightsToInt8(wBuf)
77+
// For packed sub-byte weights (from Bitcast), unpack nibbles via the buffer pool
78+
// and ConvertDType infrastructure. Non-sub-byte types pass through unchanged.
79+
unpackedBuf, unpackedPooled, err := unpackWeightsToBuffer(backend, wBuf)
80+
if err != nil {
81+
return nil, err
82+
}
83+
if unpackedPooled {
84+
defer backend.putBuffer(unpackedBuf)
85+
}
86+
wFlat := unpackedBuf.flat
7987

8088
switch data.scheme {
8189
case backends.QuantNF4:
@@ -105,22 +113,37 @@ func execFusedQuantizedDense(backend *Backend, node *Node, inputs []*Buffer, inp
105113
return output, nil
106114
}
107115

108-
// unpackWeightsToInt8 unpacks sub-byte weight data (Int4, Uint4) from packed
109-
// []byte storage into []int8 (one value per element) for the matmul kernel.
110-
// For non-sub-byte types, returns the flat data as-is.
111-
func unpackWeightsToInt8(wBuf *Buffer) any {
112-
var unpackFn unpackNibblesFn
116+
// unpackWeightsToBuffer unpacks sub-byte weight data (Int4, Uint4) into a pooled
117+
// buffer using the ConvertDType infrastructure. For non-sub-byte types, returns the
118+
// original buffer unchanged.
119+
//
120+
// Returns the (possibly new) buffer, whether it was allocated from the pool
121+
// (caller must putBuffer), and any error.
122+
func unpackWeightsToBuffer(backend *Backend, wBuf *Buffer) (*Buffer, bool, error) {
123+
var targetDType dtypes.DType
113124
switch wBuf.shape.DType {
114-
case dtypes.Uint4:
115-
unpackFn = unpackUint4Nibbles
116125
case dtypes.Int4:
117-
unpackFn = unpackInt4Nibbles
126+
targetDType = dtypes.Int8
127+
case dtypes.Uint4:
128+
targetDType = dtypes.Uint8
118129
default:
119-
return wBuf.flat
130+
return wBuf, false, nil
120131
}
121-
unpacked := make([]int8, wBuf.shape.Size())
122-
unpackFn(wBuf.flat.([]byte), unpacked)
123-
return unpacked
132+
133+
outBuf, err := backend.getBuffer(targetDType, wBuf.shape.Size())
134+
if err != nil {
135+
return nil, false, err
136+
}
137+
outBuf.shape = shapes.Make(targetDType, wBuf.shape.Dimensions...)
138+
139+
convertFnAny, err := convertDTypePairMap.Get(wBuf.shape.DType, targetDType)
140+
if err != nil {
141+
backend.putBuffer(outBuf)
142+
return nil, false, err
143+
}
144+
convertFn := convertFnAny.(convertFnType)
145+
convertFn(wBuf, outBuf)
146+
return outBuf, true, nil
124147
}
125148

126149
// execQuantizedEmbeddingLookup performs quantized embedding lookup.
@@ -146,50 +169,61 @@ func execQuantizedEmbeddingLookup(backend *Backend, node *Node, inputs []*Buffer
146169
return nil, err
147170
}
148171

149-
numIndices := indicesBuf.shape.Size() / indicesBuf.shape.Dimensions[indicesBuf.shape.Rank()-1]
172+
// Last dim is pre-validated to be 1, so total elements == number of indices.
173+
numIndices := indicesBuf.shape.Size()
150174

151-
indices, err := quantGatherIntSliceOfFlat(indicesBuf.flat, numIndices)
175+
// Convert indices to int64 via the buffer pool and ConvertDType infrastructure.
176+
idxBuf, idxPooled, err := convertIndicesToInt64(backend, indicesBuf)
152177
if err != nil {
153178
return nil, errors.Wrapf(err, "QuantizedEmbeddingLookup")
154179
}
155-
vocabSize := dataBuf.shape.Dimensions[0]
156-
for i, rowIdx := range indices {
180+
if idxPooled {
181+
defer backend.putBuffer(idxBuf)
182+
}
183+
indices := idxBuf.flat.([]int64)
184+
185+
vocabSize := int64(dataBuf.shape.Dimensions[0])
186+
for i, rowIdx := range indices[:numIndices] {
157187
if rowIdx < 0 || rowIdx >= vocabSize {
158188
return nil, errors.Errorf("QuantizedEmbeddingLookup: index %d out of range [0, %d)", rowIdx, vocabSize)
159189
}
160-
rowData := dataBytes[rowIdx*bytesPerRow : (rowIdx+1)*bytesPerRow]
190+
rowStart := rowIdx * int64(bytesPerRow)
191+
rowData := dataBytes[rowStart : rowStart+int64(bytesPerRow)]
161192
dequantFn(rowData, out[i*K:(i+1)*K])
162193
}
163194

164195
return output, nil
165196
}
166197

167-
// quantGatherIntSliceOfFlat converts a flat index slice ([]int32, []int64, or []int) to []int.
168-
func quantGatherIntSliceOfFlat(flat any, n int) ([]int, error) {
169-
switch s := flat.(type) {
170-
case []int32:
171-
return convertToIntSlice(s, n), nil
172-
case []int64:
173-
return convertToIntSlice(s, n), nil
174-
case []int:
175-
return s[:n], nil
176-
default:
177-
return nil, errors.Errorf("unsupported indices type %T", flat)
198+
// convertIndicesToInt64 converts an integer index buffer to int64 via the buffer
199+
// pool and ConvertDType infrastructure. If the buffer is already int64, it is
200+
// returned as-is.
201+
//
202+
// Returns the (possibly new) buffer, whether it was allocated from the pool
203+
// (caller must putBuffer), and any error.
204+
func convertIndicesToInt64(backend *Backend, indicesBuf *Buffer) (*Buffer, bool, error) {
205+
if indicesBuf.shape.DType == dtypes.Int64 {
206+
return indicesBuf, false, nil
178207
}
179-
}
208+
outBuf, err := backend.getBuffer(dtypes.Int64, indicesBuf.shape.Size())
209+
if err != nil {
210+
return nil, false, err
211+
}
212+
outBuf.shape = shapes.Make(dtypes.Int64, indicesBuf.shape.Dimensions...)
180213

181-
// convertToIntSlice converts the first n elements of an integer slice to []int.
182-
func convertToIntSlice[T int32 | int64](s []T, n int) []int {
183-
out := make([]int, n)
184-
for i := range n {
185-
out[i] = int(s[i])
214+
convertFnAny, err := convertDTypePairMap.Get(indicesBuf.shape.DType, dtypes.Int64)
215+
if err != nil {
216+
backend.putBuffer(outBuf)
217+
return nil, false, err
186218
}
187-
return out
219+
convertFn := convertFnAny.(convertFnType)
220+
convertFn(indicesBuf, outBuf)
221+
return outBuf, true, nil
188222
}
189223

190-
// parallelTileCount returns the number of parallel work units that
224+
// quantizedDenseParallelTileCount returns the number of parallel work units that
191225
// quantizedDenseParallel will dispatch for the given dimensions.
192-
func parallelTileCount(backend *Backend, M, K, N int) int {
226+
func quantizedDenseParallelTileCount(backend *Backend, M, K, N int) int {
193227
totalWork := M * K * N
194228
if backend == nil || !backend.workers.IsEnabled() || totalWork <= minParallelizeChunk {
195229
return M
@@ -202,7 +236,7 @@ func parallelTileCount(backend *Backend, M, K, N int) int {
202236
}
203237

204238
// quantizedDenseParallel parallelizes over M rows, or tiles over N columns when M=1.
205-
// workerIdx is a dense index in [0, parallelTileCount) identifying the work unit.
239+
// workerIdx is a dense index in [0, quantizedDenseParallelTileCount) identifying the work unit.
206240
func quantizedDenseParallel(backend *Backend, M, K, N int, rowFn func(workerIdx, m, nStart, nEnd int)) {
207241
totalWork := M * K * N
208242
if backend == nil || !backend.workers.IsEnabled() || totalWork <= minParallelizeChunk {

backends/simplego/exec_fused_quantized_ggml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func quantizedDenseGGML(backend *Backend, x []float32, weights []uint8, bias, ou
103103
}
104104

105105
// Pre-allocate per-worker scratch buffers to avoid heap allocation per tile invocation.
106-
numWorkers := parallelTileCount(backend, M, K, N)
106+
numWorkers := quantizedDenseParallelTileCount(backend, M, K, N)
107107
scratchBufs := make([][]float32, numWorkers)
108108
for i := range scratchBufs {
109109
scratchBufs[i] = make([]float32, K)

0 commit comments

Comments (0)