77
88 "github.com/gomlx/gomlx/backends"
99 "github.com/gomlx/gomlx/pkg/core/dtypes"
10+ "github.com/gomlx/gomlx/pkg/core/shapes"
1011 "github.com/pkg/errors"
1112)
1213
@@ -73,9 +74,16 @@ func execFusedQuantizedDense(backend *Backend, node *Node, inputs []*Buffer, inp
7374 zeroPoints = zeroPointsBuf .flat .([]float32 )
7475 }
7576
76- // For packed sub-byte weights (from Bitcast), unpack nibbles before processing.
77- // Packed buffers have len(flat) < shape.Size() (2 nibbles per byte).
78- wFlat := unpackWeightsToInt8 (wBuf )
77+ // For packed sub-byte weights (from Bitcast), unpack nibbles via the buffer pool
78+ // and ConvertDType infrastructure. Non-sub-byte types pass through unchanged.
79+ unpackedBuf , unpackedPooled , err := unpackWeightsToBuffer (backend , wBuf )
80+ if err != nil {
81+ return nil , err
82+ }
83+ if unpackedPooled {
84+ defer backend .putBuffer (unpackedBuf )
85+ }
86+ wFlat := unpackedBuf .flat
7987
8088 switch data .scheme {
8189 case backends .QuantNF4 :
@@ -105,22 +113,37 @@ func execFusedQuantizedDense(backend *Backend, node *Node, inputs []*Buffer, inp
105113 return output , nil
106114}
107115
108- // unpackWeightsToInt8 unpacks sub-byte weight data (Int4, Uint4) from packed
109- // []byte storage into []int8 (one value per element) for the matmul kernel.
110- // For non-sub-byte types, returns the flat data as-is.
111- func unpackWeightsToInt8 (wBuf * Buffer ) any {
112- var unpackFn unpackNibblesFn
116+ // unpackWeightsToBuffer unpacks sub-byte weight data (Int4, Uint4) into a pooled
117+ // buffer using the ConvertDType infrastructure. For non-sub-byte types, returns the
118+ // original buffer unchanged.
119+ //
120+ // Returns the (possibly new) buffer, whether it was allocated from the pool
121+ // (caller must putBuffer), and any error.
122+ func unpackWeightsToBuffer (backend * Backend , wBuf * Buffer ) (* Buffer , bool , error ) {
123+ var targetDType dtypes.DType
113124 switch wBuf .shape .DType {
114- case dtypes .Uint4 :
115- unpackFn = unpackUint4Nibbles
116125 case dtypes .Int4 :
117- unpackFn = unpackInt4Nibbles
126+ targetDType = dtypes .Int8
127+ case dtypes .Uint4 :
128+ targetDType = dtypes .Uint8
118129 default :
119- return wBuf . flat
130+ return wBuf , false , nil
120131 }
121- unpacked := make ([]int8 , wBuf .shape .Size ())
122- unpackFn (wBuf .flat .([]byte ), unpacked )
123- return unpacked
132+
133+ outBuf , err := backend .getBuffer (targetDType , wBuf .shape .Size ())
134+ if err != nil {
135+ return nil , false , err
136+ }
137+ outBuf .shape = shapes .Make (targetDType , wBuf .shape .Dimensions ... )
138+
139+ convertFnAny , err := convertDTypePairMap .Get (wBuf .shape .DType , targetDType )
140+ if err != nil {
141+ backend .putBuffer (outBuf )
142+ return nil , false , err
143+ }
144+ convertFn := convertFnAny .(convertFnType )
145+ convertFn (wBuf , outBuf )
146+ return outBuf , true , nil
124147}
125148
126149// execQuantizedEmbeddingLookup performs quantized embedding lookup.
@@ -146,50 +169,61 @@ func execQuantizedEmbeddingLookup(backend *Backend, node *Node, inputs []*Buffer
146169 return nil , err
147170 }
148171
149- numIndices := indicesBuf .shape .Size () / indicesBuf .shape .Dimensions [indicesBuf .shape .Rank ()- 1 ]
172+ // Last dim is pre-validated to be 1, so total elements == number of indices.
173+ numIndices := indicesBuf .shape .Size ()
150174
151- indices , err := quantGatherIntSliceOfFlat (indicesBuf .flat , numIndices )
175+ // Convert indices to int64 via the buffer pool and ConvertDType infrastructure.
176+ idxBuf , idxPooled , err := convertIndicesToInt64 (backend , indicesBuf )
152177 if err != nil {
153178 return nil , errors .Wrapf (err , "QuantizedEmbeddingLookup" )
154179 }
155- vocabSize := dataBuf .shape .Dimensions [0 ]
156- for i , rowIdx := range indices {
180+ if idxPooled {
181+ defer backend .putBuffer (idxBuf )
182+ }
183+ indices := idxBuf .flat .([]int64 )
184+
185+ vocabSize := int64 (dataBuf .shape .Dimensions [0 ])
186+ for i , rowIdx := range indices [:numIndices ] {
157187 if rowIdx < 0 || rowIdx >= vocabSize {
158188 return nil , errors .Errorf ("QuantizedEmbeddingLookup: index %d out of range [0, %d)" , rowIdx , vocabSize )
159189 }
160- rowData := dataBytes [rowIdx * bytesPerRow : (rowIdx + 1 )* bytesPerRow ]
190+ rowStart := rowIdx * int64 (bytesPerRow )
191+ rowData := dataBytes [rowStart : rowStart + int64 (bytesPerRow )]
161192 dequantFn (rowData , out [i * K :(i + 1 )* K ])
162193 }
163194
164195 return output , nil
165196}
166197
167- // quantGatherIntSliceOfFlat converts a flat index slice ([]int32, []int64, or []int) to []int.
168- func quantGatherIntSliceOfFlat (flat any , n int ) ([]int , error ) {
169- switch s := flat .(type ) {
170- case []int32 :
171- return convertToIntSlice (s , n ), nil
172- case []int64 :
173- return convertToIntSlice (s , n ), nil
174- case []int :
175- return s [:n ], nil
176- default :
177- return nil , errors .Errorf ("unsupported indices type %T" , flat )
198+ // convertIndicesToInt64 converts an integer index buffer to int64 via the buffer
199+ // pool and ConvertDType infrastructure. If the buffer is already int64, it is
200+ // returned as-is.
201+ //
202+ // Returns the (possibly new) buffer, whether it was allocated from the pool
203+ // (caller must putBuffer), and any error.
204+ func convertIndicesToInt64 (backend * Backend , indicesBuf * Buffer ) (* Buffer , bool , error ) {
205+ if indicesBuf .shape .DType == dtypes .Int64 {
206+ return indicesBuf , false , nil
178207 }
179- }
208+ outBuf , err := backend .getBuffer (dtypes .Int64 , indicesBuf .shape .Size ())
209+ if err != nil {
210+ return nil , false , err
211+ }
212+ outBuf .shape = shapes .Make (dtypes .Int64 , indicesBuf .shape .Dimensions ... )
180213
181- // convertToIntSlice converts the first n elements of an integer slice to []int.
182- func convertToIntSlice [T int32 | int64 ](s []T , n int ) []int {
183- out := make ([]int , n )
184- for i := range n {
185- out [i ] = int (s [i ])
214+ convertFnAny , err := convertDTypePairMap .Get (indicesBuf .shape .DType , dtypes .Int64 )
215+ if err != nil {
216+ backend .putBuffer (outBuf )
217+ return nil , false , err
186218 }
187- return out
219+ convertFn := convertFnAny .(convertFnType )
220+ convertFn (indicesBuf , outBuf )
221+ return outBuf , true , nil
188222}
189223
190- // parallelTileCount returns the number of parallel work units that
224+ // quantizedDenseParallelTileCount returns the number of parallel work units that
191225// quantizedDenseParallel will dispatch for the given dimensions.
192- func parallelTileCount (backend * Backend , M , K , N int ) int {
226+ func quantizedDenseParallelTileCount (backend * Backend , M , K , N int ) int {
193227 totalWork := M * K * N
194228 if backend == nil || ! backend .workers .IsEnabled () || totalWork <= minParallelizeChunk {
195229 return M
@@ -202,7 +236,7 @@ func parallelTileCount(backend *Backend, M, K, N int) int {
202236}
203237
204238// quantizedDenseParallel parallelizes over M rows, or tiles over N columns when M=1.
205- // workerIdx is a dense index in [0, parallelTileCount ) identifying the work unit.
239+ // workerIdx is a dense index in [0, quantizedDenseParallelTileCount ) identifying the work unit.
206240func quantizedDenseParallel (backend * Backend , M , K , N int , rowFn func (workerIdx , m , nStart , nEnd int )) {
207241 totalWork := M * K * N
208242 if backend == nil || ! backend .workers .IsEnabled () || totalWork <= minParallelizeChunk {
0 commit comments