From 4201858f4248a1048e7dbb9772e52ce67a265ffd Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Mon, 6 Apr 2026 14:06:48 -0700 Subject: [PATCH 1/3] refactor(compute): extract shared matmul helpers to gpu_engine_matmul.go Consolidate repeated patterns from 14 quantized matmul methods into 6 shared helpers: uploadRawBytes, aShapeCheck2D, bweightShapeMKN, quantGemvResult, dequantSgemm, and sgemmNTOrFallback. Each original method is now a thin wrapper calling these helpers, reducing gpu_engine.go by 797 lines (net -557 across both files). Zero behavioral changes -- all method signatures remain identical. --- compute/gpu_engine.go | 1259 +++++++--------------------------- compute/gpu_engine_matmul.go | 240 +++++++ 2 files changed, 471 insertions(+), 1028 deletions(-) create mode 100644 compute/gpu_engine_matmul.go diff --git a/compute/gpu_engine.go b/compute/gpu_engine.go index 309e96a..34edc20 100644 --- a/compute/gpu_engine.go +++ b/compute/gpu_engine.go @@ -1216,59 +1216,28 @@ func (e *GPUEngine[T]) MatMulTransposeB(ctx context.Context, a, b *tensor.Tensor // matMulQ4 handles GPU Q4_0 dequant-GEMM: C = dequant(A_q4) * B. // Only supports unbatched 2D for now; batched Q4 falls back to CPU. func (e *GPUEngine[T]) matMulQ4(ctx context.Context, qs *tensor.Q4Storage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - m := aShape[len(aShape)-2] - k := aShape[len(aShape)-1] - n := bShape[len(bShape)-1] - - // Only handle unbatched 2D for now. - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if k%32 != 0 { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 32) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } e.setDevice() - // Use pre-uploaded Q4 GPU pointer if available; otherwise upload now. - var devA unsafe.Pointer - var freeA func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devA = ptr - freeA = func() {} // pre-uploaded; do not free - } else { - bpr := k / 32 - aBytes := qs.RawBytesGPU(bpr) - var err error - devA, err = e.pool.Alloc(e.deviceID, len(aBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeA = func() { e.pool.Free(e.deviceID, devA, len(aBytes)) } - if err := e.runtime.Memcpy(devA, unsafe.Pointer(&aBytes[0]), len(aBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeA() - return e.cpu.MatMul(ctx, a, b, dst...) - } + bpr := k / 32 + ptr, _, _ := qs.GPUPtr() + devA, freeA, err := e.uploadRawBytes(ptr, qs.RawBytesGPU(bpr)) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeA() - // Upload B (float32) to GPU. devB, cleanupB, err := getDevicePtr(e, b) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } defer cleanupB() - // Allocate output C. - cSize := m * n * int(unsafe.Sizeof(float32(0))) + cSize := m * n * f32Size devC, err := e.pool.Alloc(e.deviceID, cSize) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) @@ -1295,90 +1264,42 @@ func (e *GPUEngine[T]) matMulQ4BWeight(ctx context.Context, a *tensor.TensorNume if debugGPU { fmt.Fprintf(os.Stderr, "matMulQ4BWeight: aShape=%v bShape=%v GPUPtr=%v\n", a.Shape(), b.Shape(), func() bool { p, _, _ := qs.GPUPtr(); return p != nil }()) } - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // B must be 2D (virtual-transposed weight). 
- if len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Flatten A's batch dims: [batch..., m, k] -> [m_total, k] - k := aShape[len(aShape)-1] - m := 1 - for i := 0; i < len(aShape)-1; i++ { - m *= aShape[i] - } - n := bShape[1] // columns of B (after virtual transpose) - // Q4 original layout is [N, K]. Verify K is a multiple of 32. - if k%32 != 0 { + m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 32) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } - // Build output shape: [batch..., m_last, n] matching standard MatMul broadcast. - outShape := make([]int, len(aShape)) - copy(outShape, aShape[:len(aShape)-1]) - outShape[len(outShape)-1] = n - e.setDevice() - // Get Q4 device pointer (pre-uploaded or upload now). - var devQ4 unsafe.Pointer - var freeQ4 func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devQ4 = ptr - freeQ4 = func() {} - } else { - bpr := k / 32 - q4Bytes := qs.RawBytesGPU(bpr) - var err error - devQ4, err = e.pool.Alloc(e.deviceID, len(q4Bytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeQ4 = func() { e.pool.Free(e.deviceID, devQ4, len(q4Bytes)) } - if err := e.runtime.Memcpy(devQ4, unsafe.Pointer(&q4Bytes[0]), len(q4Bytes), gpuapi.MemcpyHostToDevice); err != nil { - freeQ4() - return e.cpu.MatMul(ctx, a, b, dst...) - } + bpr := k / 32 + ptr, _, _ := qs.GPUPtr() + devQ4, freeQ4, err := e.uploadRawBytes(ptr, qs.RawBytesGPU(bpr)) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeQ4() - // Upload A to GPU as F32. A's data is contiguous [m, k] regardless of - // original batch shape, so the kernel sees it as a flat 2D matrix. devA, cleanupA, err := getDevicePtr(e, a) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } defer cleanupA() - f32Size := int(unsafe.Sizeof(float32(0))) + q4DataOff := tensor.Q4GPUDataOffset(qs.NumBlocks()) if m == 1 { // GEMV fast path: C_temp[N, 1] = gemm_q4(B_q4[N,K], A^T[K,1]) - // A is [1, K], A^T is [K, 1] -- same data, just different shape. - cSize := n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult(outShape, n, func(devC unsafe.Pointer) error { + return e.kernels.GemmQ4F32(devQ4, devA, devC, n, k, 1, q4DataOff, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - - q4DataOff := tensor.Q4GPUDataOffset(qs.NumBlocks()) - if err := e.kernels.GemmQ4F32(devQ4, devA, devC, n, k, 1, q4DataOff, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, outShape, devC, n, dst...) + return result, nil } // General GEMM: C_temp[N, M] = gemm_q4(B_q4[N,K], A^T[K,M]) - // Transpose flattened A[M, K] -> A^T[K, M] on GPU. aFlat, err := e.Reshape(ctx, a, []int{m, k}) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) @@ -1399,8 +1320,7 @@ func (e *GPUEngine[T]) matMulQ4BWeight(ctx context.Context, a *tensor.TensorNume return e.cpu.MatMul(ctx, a, b, dst...) } - q4DataOff2 := tensor.Q4GPUDataOffset(qs.NumBlocks()) - if err := e.kernels.GemmQ4F32(devQ4, devAT, devCTemp, n, k, m, q4DataOff2, e.stream); err != nil { + if err := e.kernels.GemmQ4F32(devQ4, devAT, devCTemp, n, k, m, q4DataOff, e.stream); err != nil { e.pool.Free(e.deviceID, devCTemp, cTempSize) return e.cpu.MatMul(ctx, a, b, dst...) } @@ -1423,42 +1343,17 @@ func (e *GPUEngine[T]) matMulQ4BWeight(ctx context.Context, a *tensor.TensorNume // For GEMV (n==1, single-column B), uses fused dequant+GEMV kernel. 
// For general GEMM (n>1), dequantizes Q4_K to F32 on GPU then calls cuBLAS Sgemm. func (e *GPUEngine[T]) matMulQ4K(ctx context.Context, qs *tensor.Q4KStorage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 0) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } - // Only handle unbatched 2D for now. - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - m := aShape[0] - k := aShape[1] - n := bShape[1] - e.setDevice() - // Get Q4_K device pointer (pre-uploaded or upload now). - var devW unsafe.Pointer - var freeW func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devW = ptr - freeW = func() {} - } else { - rawBytes := qs.RawBytes() - var err error - devW, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeW = func() { e.pool.Free(e.deviceID, devW, len(rawBytes)) } - if err := e.runtime.Memcpy(devW, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeW() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := qs.GPUPtr() + devW, freeW, err := e.uploadRawBytes(ptr, qs.RawBytes()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeW() @@ -1470,65 +1365,27 @@ func (e *GPUEngine[T]) matMulQ4K(ctx context.Context, qs *tensor.Q4KStorage, a, } defer cleanupX() - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := m * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult([]int{m, n}, m*n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ4KF32(devW, devX, devY, m, k, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - - if err := e.kernels.GemvQ4KF32(devW, devX, devY, m, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, []int{m, n}, devY, m*n, dst...) + return result, nil } // General GEMM: dequantize Q4_K to F32 on GPU, then cuBLAS Sgemm. - // C[M,N] = dequant(A_q4k)[M,K] * B[K,N] - f32Size := int(unsafe.Sizeof(float32(0))) - - // Dequantize A to F32. - dequantSize := m * k * f32Size - devAF32, err := e.pool.Alloc(e.deviceID, dequantSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer e.pool.Free(e.deviceID, devAF32, dequantSize) - + var gpuDequant func(src, dst unsafe.Pointer, rows, cols int) error if k%256 == 0 { - if err := e.kernels.DequantQ4KF32(devW, devAF32, m, k, e.stream); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - } else { - dequant := make([]float32, m*k) - qs.Dequantize(dequant) - if err := e.runtime.Memcpy(devAF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) + gpuDequant = func(src, dst unsafe.Pointer, rows, cols int) error { + return e.kernels.DequantQ4KF32(src, dst, rows, cols, e.stream) } } - - // Upload B to GPU. - devB, cleanupB, err := getDevicePtr(e, b) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupB() - - // Allocate output C. - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.dequantSgemm(devW, m, k, n, gpuDequant, qs.Dequantize, b, []int{m, n}, "matMulQ4K", dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) 
} - - if err := e.blas.Sgemm(m, n, k, 1.0, devAF32, devB, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ4K: Sgemm: %w", err) - } - - return makeGPUResult[T](e, []int{m, n}, devC, m*n, dst...) + return result, nil } // matMulQ4KBWeight handles MatMul where B has Q4_K storage (virtual-transposed weight). @@ -1537,58 +1394,21 @@ func (e *GPUEngine[T]) matMulQ4K(ctx context.Context, qs *tensor.Q4KStorage, a, // on the Q4_K weight data, halving memory bandwidth vs separate dequant + GEMM. // For general GEMM (m>1), dequantizes Q4_K to F32 on GPU then calls cuBLAS SgemmNT. func (e *GPUEngine[T]) matMulQ4KBWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q4KStorage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // B must be 2D (virtual-transposed weight). - if len(bShape) > 2 { + m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 0) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } - // Flatten A's batch dims: [batch..., m, k] -> [m_total, k] - k := aShape[len(aShape)-1] - m := 1 - for i := 0; i < len(aShape)-1; i++ { - m *= aShape[i] - } - n := bShape[1] // columns of B (after virtual transpose) - - // Build output shape: [batch..., m_last, n]. - outShape := make([]int, len(aShape)) - copy(outShape, aShape[:len(aShape)-1]) - outShape[len(outShape)-1] = n - e.setDevice() - // Get Q4_K device pointer (pre-uploaded or upload now). - // Q4_K data is stored as [N, K] super-blocks. - var devQ4K unsafe.Pointer - var freeQ4K func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devQ4K = ptr - freeQ4K = func() {} - } else { - rawBytes := qs.RawBytes() - var err error - devQ4K, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeQ4K = func() { e.pool.Free(e.deviceID, devQ4K, len(rawBytes)) } - if err := e.runtime.Memcpy(devQ4K, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeQ4K() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := qs.GPUPtr() + devQ4K, freeQ4K, err := e.uploadRawBytes(ptr, qs.RawBytes()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeQ4K() // Fused GEMV path: y[n] = sum_k dequant(B_q4k[n, k]) * x[k], when m==1. - // Requires K % 256 == 0 for Q4_K super-block alignment. - // When K is not aligned, falls through to the general dequant+cuBLAS path. if m == 1 && k%256 == 0 { devX, cleanupX, err := getDevicePtr(e, a) if err != nil { @@ -1596,138 +1416,44 @@ func (e *GPUEngine[T]) matMulQ4KBWeight(ctx context.Context, a *tensor.TensorNum } defer cleanupX() - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := n * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult(outShape, n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ4KF32(devQ4K, devX, devY, n, k, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - - if err := e.kernels.GemvQ4KF32(devQ4K, devX, devY, n, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, outShape, devY, n, dst...) - } - - // General GEMM: dequantize Q4_K to F32 on GPU, then cuBLAS. - // Q4_K data is [N, K]. Dequantize gives F32 [N, K]. 
- // We need C[M,N] = A[M,K] * B^T where B = dequant(B_q4k)[N,K]. - // Use SgemmNT: C = A * B^T (A is [M,K], B is [N,K], C is [M,N]). - f32Size := int(unsafe.Sizeof(float32(0))) - - // Dequantize B to F32 [N, K]. - dequantSize := n * k * f32Size - devBF32, err := e.pool.Alloc(e.deviceID, dequantSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) + return result, nil } - defer e.pool.Free(e.deviceID, devBF32, dequantSize) + // General GEMM: dequantize Q4_K to F32, then SgemmNT. + var gpuDequant func(src, dst unsafe.Pointer, rows, cols int) error if k%256 == 0 { - // GPU dequant when K is super-block aligned. - if err := e.kernels.DequantQ4KF32(devQ4K, devBF32, n, k, e.stream); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - } else { - // CPU dequant for unaligned K (super-block boundary doesn't match row boundary). - dequant := make([]float32, n*k) - qs.Dequantize(dequant) - if err := e.runtime.Memcpy(devBF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - } - - // Upload A to GPU. - devA, cleanupA, err := getDevicePtr(e, a) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupA() - - // Allocate output C [M, N]. - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Use SgemmNT if available (avoids explicit transpose). - if ntBLAS, ok := e.blas.(gpuapi.BLASTransposeB); ok { - if err := ntBLAS.SgemmNT(m, n, k, 1.0, devA, devBF32, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ4KBWeight: SgemmNT: %w", err) + gpuDequant = func(src, dst unsafe.Pointer, rows, cols int) error { + return e.kernels.DequantQ4KF32(src, dst, rows, cols, e.stream) } - return makeGPUResult[T](e, outShape, devC, m*n, dst...) } - - // Fallback: transpose dequantized B then use Sgemm. - devBT, err := e.pool.Alloc(e.deviceID, dequantSize) + result, err := e.sgemmNTOrFallback(devQ4K, m, k, n, gpuDequant, qs.Dequantize, a, outShape, "matMulQ4KBWeight", dst...) if err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer e.pool.Free(e.deviceID, devBT, dequantSize) - - if err := e.kernels.Transpose2D(devBF32, devBT, n, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, cSize) return e.cpu.MatMul(ctx, a, b, dst...) } - - if err := e.blas.Sgemm(m, n, k, 1.0, devA, devBT, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ4KBWeight: Sgemm: %w", err) - } - - return makeGPUResult[T](e, outShape, devC, m*n, dst...) + return result, nil } // matMulQ6K handles GPU Q6_K dequant-GEMM when Q6_K storage is on A. // For GEMV (n==1, single-column B), uses fused dequant+GEMV kernel. // For general GEMM (n>1), dequantizes Q6_K to F32 on CPU then calls cuBLAS Sgemm. func (e *GPUEngine[T]) matMulQ6K(ctx context.Context, qs *tensor.Q6KStorage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Only handle unbatched 2D for now. - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - m := aShape[0] - k := aShape[1] - n := bShape[1] - - // K must be a multiple of 256 for Q6_K super-blocks. 
- if k%256 != 0 { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 256) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } e.setDevice() - // Get Q6_K device pointer (pre-uploaded or upload now). - var devW unsafe.Pointer - var freeW func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devW = ptr - freeW = func() {} - } else { - rawBytes := qs.RawBytes() - var err error - devW, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeW = func() { e.pool.Free(e.deviceID, devW, len(rawBytes)) } - if err := e.runtime.Memcpy(devW, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeW() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := qs.GPUPtr() + devW, freeW, err := e.uploadRawBytes(ptr, qs.RawBytes()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeW() @@ -1739,116 +1465,42 @@ func (e *GPUEngine[T]) matMulQ6K(ctx context.Context, qs *tensor.Q6KStorage, a, } defer cleanupX() - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := m * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult([]int{m, n}, m*n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ6KF32(devW, devX, devY, m, k, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - - if err := e.kernels.GemvQ6KF32(devW, devX, devY, m, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, []int{m, n}, devY, m*n, dst...) + return result, nil } // General GEMM: dequantize Q6_K to F32 on CPU, upload, cuBLAS Sgemm. - f32Size := int(unsafe.Sizeof(float32(0))) - dequant := make([]float32, m*k) - qs.Dequantize(dequant) - - dequantSize := m * k * f32Size - devAF32, err := e.pool.Alloc(e.deviceID, dequantSize) + result, err := e.dequantSgemm(devW, m, k, n, nil, qs.Dequantize, b, []int{m, n}, "matMulQ6K", dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - defer e.pool.Free(e.deviceID, devAF32, dequantSize) + return result, nil +} - if err := e.runtime.Memcpy(devAF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil { +// matMulQ6KBWeight handles MatMul where B has Q6_K storage (virtual-transposed weight). +// B's shape after virtual transpose is [K, N], but the Q6_K data is laid out as [N, K]. +// For GEMV (m==1, single-token decode), uses fused dequant+GEMV kernel directly +// on the Q6_K weight data, halving memory bandwidth vs separate dequant + GEMM. +// For general GEMM (m>1), dequantizes Q6_K to F32 on CPU then calls cuBLAS SgemmNT. +func (e *GPUEngine[T]) matMulQ6KBWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q6KStorage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 256) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } - devB, cleanupB, err := getDevicePtr(e, b) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupB() + e.setDevice() - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) + ptr, _, _ := qs.GPUPtr() + devQ6K, freeQ6K, err := e.uploadRawBytes(ptr, qs.RawBytes()) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) 
} - - if err := e.blas.Sgemm(m, n, k, 1.0, devAF32, devB, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ6K: Sgemm: %w", err) - } - - return makeGPUResult[T](e, []int{m, n}, devC, m*n, dst...) -} - -// matMulQ6KBWeight handles MatMul where B has Q6_K storage (virtual-transposed weight). -// B's shape after virtual transpose is [K, N], but the Q6_K data is laid out as [N, K]. -// For GEMV (m==1, single-token decode), uses fused dequant+GEMV kernel directly -// on the Q6_K weight data, halving memory bandwidth vs separate dequant + GEMM. -// For general GEMM (m>1), dequantizes Q6_K to F32 on CPU then calls cuBLAS SgemmNT. -func (e *GPUEngine[T]) matMulQ6KBWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q6KStorage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // B must be 2D (virtual-transposed weight). - if len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Flatten A's batch dims: [batch..., m, k] -> [m_total, k] - k := aShape[len(aShape)-1] - m := 1 - for i := 0; i < len(aShape)-1; i++ { - m *= aShape[i] - } - n := bShape[1] // columns of B (after virtual transpose) - - // K must be a multiple of 256 for Q6_K super-blocks. - if k%256 != 0 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Build output shape: [batch..., m_last, n]. - outShape := make([]int, len(aShape)) - copy(outShape, aShape[:len(aShape)-1]) - outShape[len(outShape)-1] = n - - e.setDevice() - - // Get Q6_K device pointer (pre-uploaded or upload now). - // Q6_K data is stored as [N, K] super-blocks. - var devQ6K unsafe.Pointer - var freeQ6K func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devQ6K = ptr - freeQ6K = func() {} - } else { - rawBytes := qs.RawBytes() - var err error - devQ6K, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeQ6K = func() { e.pool.Free(e.deviceID, devQ6K, len(rawBytes)) } - if err := e.runtime.Memcpy(devQ6K, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeQ6K() - return e.cpu.MatMul(ctx, a, b, dst...) - } - } - defer freeQ6K() + defer freeQ6K() // Fused GEMV path: y[n] = sum_k dequant(B_q6k[n, k]) * x[k], when m==1. if m == 1 { @@ -1858,560 +1510,225 @@ func (e *GPUEngine[T]) matMulQ6KBWeight(ctx context.Context, a *tensor.TensorNum } defer cleanupX() - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := n * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if err := e.kernels.GemvQ6KF32(devQ6K, devX, devY, n, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, outShape, devY, n, dst...) - } - - // General GEMM: dequantize Q6_K to F32 on GPU, then cuBLAS. - f32Size := int(unsafe.Sizeof(float32(0))) - dequantSize := n * k * f32Size - devBF32, err := e.pool.Alloc(e.deviceID, dequantSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer e.pool.Free(e.deviceID, devBF32, dequantSize) - - if err := e.kernels.DequantQ6KF32(devQ6K, devBF32, n, k, e.stream); err != nil { - // GPU dequant failed — fall back to CPU dequant + upload. 
- dequant := make([]float32, n*k) - qs.Dequantize(dequant) - if err := e.runtime.Memcpy(devBF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - } - - // Upload A to GPU. - devA, cleanupA, err := getDevicePtr(e, a) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupA() - - // Allocate output C [M, N]. - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Use SgemmNT if available (avoids explicit transpose). - if ntBLAS, ok := e.blas.(gpuapi.BLASTransposeB); ok { - if err := ntBLAS.SgemmNT(m, n, k, 1.0, devA, devBF32, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ6KBWeight: SgemmNT: %w", err) - } - return makeGPUResult[T](e, outShape, devC, m*n, dst...) - } - - // Fallback: transpose dequantized B then use Sgemm. - devBT, err := e.pool.Alloc(e.deviceID, dequantSize) - if err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer e.pool.Free(e.deviceID, devBT, dequantSize) - - if err := e.kernels.Transpose2D(devBF32, devBT, n, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if err := e.blas.Sgemm(m, n, k, 1.0, devA, devBT, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ6KBWeight: Sgemm: %w", err) - } - - return makeGPUResult[T](e, outShape, devC, m*n, dst...) -} - -// matMulQ5K handles GPU Q5_K dequant-GEMV when Q5_K storage is on A. -// For GEMV (n==1, single-column B), uses fused dequant+GEMV kernel. -// For general GEMM (n>1), falls back to CPU. -func (e *GPUEngine[T]) matMulQ5K(ctx context.Context, qs *tensor.Q5KStorage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - m := aShape[0] - k := aShape[1] - n := bShape[1] - - if k%256 != 0 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if n != 1 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - e.setDevice() - - var devW unsafe.Pointer - var freeW func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devW = ptr - freeW = func() {} - } else { - rawBytes := qs.RawBytes() - var err error - devW, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeW = func() { e.pool.Free(e.deviceID, devW, len(rawBytes)) } - if err := e.runtime.Memcpy(devW, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeW() - return e.cpu.MatMul(ctx, a, b, dst...) - } - } - defer freeW() - - devX, cleanupX, err := getDevicePtr(e, b) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupX() - - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := m * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if err := e.kernels.GemvQ5KF32(devW, devX, devY, m, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, []int{m, n}, devY, m*n, dst...) 
-} - -// matMulQ5KBWeight handles MatMul where B has Q5_K storage (virtual-transposed weight). -// B's shape after virtual transpose is [K, N], but the Q5_K data is laid out as [N, K]. -// For GEMV (m==1, single-token decode), uses fused dequant+GEMV kernel. -// For general GEMM (m>1), falls back to CPU. -func (e *GPUEngine[T]) matMulQ5KBWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q5KStorage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - k := aShape[len(aShape)-1] - m := 1 - for i := 0; i < len(aShape)-1; i++ { - m *= aShape[i] - } - n := bShape[1] - - if k%256 != 0 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - outShape := make([]int, len(aShape)) - copy(outShape, aShape[:len(aShape)-1]) - outShape[len(outShape)-1] = n - - e.setDevice() - - var devQ5K unsafe.Pointer - var freeQ5K func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devQ5K = ptr - freeQ5K = func() {} - } else { - rawBytes := qs.RawBytes() - var err error - devQ5K, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeQ5K = func() { e.pool.Free(e.deviceID, devQ5K, len(rawBytes)) } - if err := e.runtime.Memcpy(devQ5K, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeQ5K() - return e.cpu.MatMul(ctx, a, b, dst...) - } - } - defer freeQ5K() - - if m == 1 { - devX, cleanupX, err := getDevicePtr(e, a) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupX() - - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := n * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if err := e.kernels.GemvQ5KF32(devQ5K, devX, devY, n, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, outShape, devY, n, dst...) - } - - // General GEMM (M>1): dequantize Q5_K to F32 on GPU, then cuBLAS SgemmNT. - f32Size := int(unsafe.Sizeof(float32(0))) - dequantSize := n * k * f32Size - devBF32, err := e.pool.Alloc(e.deviceID, dequantSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer e.pool.Free(e.deviceID, devBF32, dequantSize) - - if err := e.kernels.DequantQ5KF32(devQ5K, devBF32, n, k, e.stream); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - devA, cleanupA, err := getDevicePtr(e, a) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupA() - - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if ntBLAS, ok := e.blas.(gpuapi.BLASTransposeB); ok { - if err := ntBLAS.SgemmNT(m, n, k, 1.0, devA, devBF32, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ5KBWeight: SgemmNT: %w", err) - } - return makeGPUResult[T](e, outShape, devC, m*n, dst...) - } - - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) -} - -// matMulQ5_0 handles GPU Q5_0 dequant-GEMM when Q5_0 storage is on A. -// For GEMV (n==1, single-column B), uses fused dequant+GEMV kernel. -// For general GEMM (n>1), dequantizes Q5_0 to F32 on CPU then calls cuBLAS Sgemm. 
-func (e *GPUEngine[T]) matMulQ5_0(ctx context.Context, qs *tensor.Q5_0Storage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - m := aShape[0] - k := aShape[1] - n := bShape[1] - - if k%32 != 0 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - e.setDevice() - - var devW unsafe.Pointer - var freeW func() - nBlocks := qs.NumBlocks() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devW = ptr - freeW = func() {} - } else { - rawBytes := qs.RawBytesGPU() - var err error - devW, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeW = func() { e.pool.Free(e.deviceID, devW, len(rawBytes)) } - if err := e.runtime.Memcpy(devW, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeW() - return e.cpu.MatMul(ctx, a, b, dst...) - } - } - defer freeW() - - qhOff := tensor.Q5_0GPUQhOffset(nBlocks) - qsOff := tensor.Q5_0GPUQsOffset(nBlocks) - - if n == 1 { - devX, cleanupX, err := getDevicePtr(e, b) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - defer cleanupX() - - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := m * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult(outShape, n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ6KF32(devQ6K, devX, devY, n, k, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - - if err := e.kernels.GemvQ5_0F32(devW, devX, devY, m, k, qhOff, qsOff, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, []int{m, n}, devY, m*n, dst...) + return result, nil } - f32Size := int(unsafe.Sizeof(float32(0))) - dequant := make([]float32, m*k) - qs.Dequantize(dequant) - - dequantSize := m * k * f32Size - devAF32, err := e.pool.Alloc(e.deviceID, dequantSize) + // General GEMM: dequantize Q6_K to F32, then SgemmNT. + gpuDequant := func(src, dst unsafe.Pointer, rows, cols int) error { + return e.kernels.DequantQ6KF32(src, dst, rows, cols, e.stream) + } + result, err := e.sgemmNTOrFallback(devQ6K, m, k, n, gpuDequant, qs.Dequantize, a, outShape, "matMulQ6KBWeight", dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - defer e.pool.Free(e.deviceID, devAF32, dequantSize) + return result, nil +} - if err := e.runtime.Memcpy(devAF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil { +// matMulQ5K handles GPU Q5_K dequant-GEMV when Q5_K storage is on A. +// For GEMV (n==1, single-column B), uses fused dequant+GEMV kernel. +// For general GEMM (n>1), falls back to CPU. +func (e *GPUEngine[T]) matMulQ5K(ctx context.Context, qs *tensor.Q5KStorage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 256) + if fallback || n != 1 { return e.cpu.MatMul(ctx, a, b, dst...) } - devB, cleanupB, err := getDevicePtr(e, b) + e.setDevice() + + ptr, _, _ := qs.GPUPtr() + devW, freeW, err := e.uploadRawBytes(ptr, qs.RawBytes()) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) 
} - defer cleanupB() + defer freeW() - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) + devX, cleanupX, err := getDevicePtr(e, b) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } + defer cleanupX() - if err := e.blas.Sgemm(m, n, k, 1.0, devAF32, devB, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ5_0: Sgemm: %w", err) + result, err := e.quantGemvResult([]int{m, n}, m*n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ5KF32(devW, devX, devY, m, k, e.stream) + }, dst...) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } - - return makeGPUResult[T](e, []int{m, n}, devC, m*n, dst...) + return result, nil } -// matMulQ5_0BWeight handles MatMul where B has Q5_0 storage (virtual-transposed weight). -func (e *GPUEngine[T]) matMulQ5_0BWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q5_0Storage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { +// matMulQ5KBWeight handles MatMul where B has Q5_K storage (virtual-transposed weight). +// B's shape after virtual transpose is [K, N], but the Q5_K data is laid out as [N, K]. +// For GEMV (m==1, single-token decode), uses fused dequant+GEMV kernel. +// For general GEMM (m>1), dequantizes Q5_K to F32 on GPU then calls cuBLAS SgemmNT. +func (e *GPUEngine[T]) matMulQ5KBWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q5KStorage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 256) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } - if len(bShape) > 2 { + e.setDevice() + + ptr, _, _ := qs.GPUPtr() + devQ5K, freeQ5K, err := e.uploadRawBytes(ptr, qs.RawBytes()) + if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } + defer freeQ5K() - k := aShape[len(aShape)-1] - m := 1 - for i := 0; i < len(aShape)-1; i++ { - m *= aShape[i] + // Fused GEMV path: y[n] = sum_k dequant(B_q5k[n, k]) * x[k], when m==1. + if m == 1 { + devX, cleanupX, err := getDevicePtr(e, a) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) + } + defer cleanupX() + + result, err := e.quantGemvResult(outShape, n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ5KF32(devQ5K, devX, devY, n, k, e.stream) + }, dst...) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) + } + return result, nil } - n := bShape[1] - if k%32 != 0 { + // General GEMM (M>1): dequantize Q5_K to F32 on GPU, then SgemmNT. + gpuDequant := func(src, dst unsafe.Pointer, rows, cols int) error { + return e.kernels.DequantQ5KF32(src, dst, rows, cols, e.stream) + } + result, err := e.sgemmNTOrFallback(devQ5K, m, k, n, gpuDequant, nil, a, outShape, "matMulQ5KBWeight", dst...) + if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } + return result, nil +} - outShape := make([]int, len(aShape)) - copy(outShape, aShape[:len(aShape)-1]) - outShape[len(outShape)-1] = n +// matMulQ5_0 handles GPU Q5_0 dequant-GEMM when Q5_0 storage is on A. +// For GEMV (n==1, single-column B), uses fused dequant+GEMV kernel. +// For general GEMM (n>1), dequantizes Q5_0 to F32 on CPU then calls cuBLAS Sgemm. 
+func (e *GPUEngine[T]) matMulQ5_0(ctx context.Context, qs *tensor.Q5_0Storage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 32) + if fallback { + return e.cpu.MatMul(ctx, a, b, dst...) + } e.setDevice() - var devQ5_0 unsafe.Pointer - var freeQ5_0 func() nBlocks := qs.NumBlocks() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devQ5_0 = ptr - freeQ5_0 = func() {} - } else { - rawBytes := qs.RawBytesGPU() - var err error - devQ5_0, err = e.pool.Alloc(e.deviceID, len(rawBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeQ5_0 = func() { e.pool.Free(e.deviceID, devQ5_0, len(rawBytes)) } - if err := e.runtime.Memcpy(devQ5_0, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeQ5_0() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := qs.GPUPtr() + devW, freeW, err := e.uploadRawBytes(ptr, qs.RawBytesGPU()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } - defer freeQ5_0() + defer freeW() qhOff := tensor.Q5_0GPUQhOffset(nBlocks) qsOff := tensor.Q5_0GPUQsOffset(nBlocks) - if m == 1 { - devX, cleanupX, err := getDevicePtr(e, a) + // Fused GEMV path. + if n == 1 { + devX, cleanupX, err := getDevicePtr(e, b) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } defer cleanupX() - f32Size := int(unsafe.Sizeof(float32(0))) - cSize := n * f32Size - devY, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult([]int{m, n}, m*n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ5_0F32(devW, devX, devY, m, k, qhOff, qsOff, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - - if err := e.kernels.GemvQ5_0F32(devQ5_0, devX, devY, n, k, qhOff, qsOff, e.stream); err != nil { - e.pool.Free(e.deviceID, devY, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, outShape, devY, n, dst...) + return result, nil } - // General GEMM (M>1): dequantize Q5_0 to F32 on GPU, then cuBLAS SgemmNT. - f32Size := int(unsafe.Sizeof(float32(0))) - dequantSize := n * k * f32Size - devBF32, err := e.pool.Alloc(e.deviceID, dequantSize) + // General GEMM: dequantize Q5_0 to F32 on CPU, then cuBLAS Sgemm. + result, err := e.dequantSgemm(devW, m, k, n, nil, qs.Dequantize, b, []int{m, n}, "matMulQ5_0", dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } - defer e.pool.Free(e.deviceID, devBF32, dequantSize) - - if dqErr := e.kernels.DequantQ5_0F32(devQ5_0, devBF32, n, k, e.stream); dqErr != nil { - // GPU dequant failed — fall back to CPU dequant + upload. - dequant := make([]float32, n*k) - qs.Dequantize(dequant) - if err := e.runtime.Memcpy(devBF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - } + return result, nil +} - devA, cleanupA, err := getDevicePtr(e, a) - if err != nil { +// matMulQ5_0BWeight handles MatMul where B has Q5_0 storage (virtual-transposed weight). +func (e *GPUEngine[T]) matMulQ5_0BWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q5_0Storage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 32) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) 
} - defer cleanupA() - cSize := m * n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) + e.setDevice() + + nBlocks := qs.NumBlocks() + ptr, _, _ := qs.GPUPtr() + devQ5_0, freeQ5_0, err := e.uploadRawBytes(ptr, qs.RawBytesGPU()) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } + defer freeQ5_0() - if ntBLAS, ok := e.blas.(gpuapi.BLASTransposeB); ok { - if err := ntBLAS.SgemmNT(m, n, k, 1.0, devA, devBF32, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ5_0BWeight: SgemmNT: %w", err) + qhOff := tensor.Q5_0GPUQhOffset(nBlocks) + qsOff := tensor.Q5_0GPUQsOffset(nBlocks) + + // Fused GEMV path. + if m == 1 { + devX, cleanupX, err := getDevicePtr(e, a) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } - return makeGPUResult[T](e, outShape, devC, m*n, dst...) - } + defer cleanupX() - devBT, err := e.pool.Alloc(e.deviceID, dequantSize) - if err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) + result, err := e.quantGemvResult(outShape, n, func(devY unsafe.Pointer) error { + return e.kernels.GemvQ5_0F32(devQ5_0, devX, devY, n, k, qhOff, qsOff, e.stream) + }, dst...) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) + } + return result, nil } - defer e.pool.Free(e.deviceID, devBT, dequantSize) - if err := e.kernels.Transpose2D(devBF32, devBT, n, k, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) + // General GEMM (M>1): dequantize Q5_0 to F32, then SgemmNT. + gpuDequant := func(src, dst unsafe.Pointer, rows, cols int) error { + return e.kernels.DequantQ5_0F32(src, dst, rows, cols, e.stream) } - - if err := e.blas.Sgemm(m, n, k, 1.0, devA, devBT, 0.0, devC); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return nil, fmt.Errorf("matMulQ5_0BWeight: Sgemm: %w", err) + result, err := e.sgemmNTOrFallback(devQ5_0, m, k, n, gpuDequant, qs.Dequantize, a, outShape, "matMulQ5_0BWeight", dst...) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } - - return makeGPUResult[T](e, outShape, devC, m*n, dst...) + return result, nil } // matMulQ8 handles GPU Q8_0 dequant-GEMM: C = dequant(A_q8) * B. // Only supports unbatched 2D for now; batched Q8 falls back to CPU. func (e *GPUEngine[T]) matMulQ8(ctx context.Context, qs *tensor.Q8Storage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - m := aShape[len(aShape)-2] - k := aShape[len(aShape)-1] - n := bShape[len(bShape)-1] - - // Only handle unbatched 2D for now. - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if k%32 != 0 { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 32) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } e.setDevice() - // Use pre-uploaded Q8 GPU pointer if available; otherwise upload now. - var devA unsafe.Pointer - var freeA func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devA = ptr - freeA = func() {} // pre-uploaded; do not free - } else { - aBytes := qs.RawBytes() - var err error - devA, err = e.pool.Alloc(e.deviceID, len(aBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) 
- } - freeA = func() { e.pool.Free(e.deviceID, devA, len(aBytes)) } - if err := e.runtime.Memcpy(devA, unsafe.Pointer(&aBytes[0]), len(aBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeA() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := qs.GPUPtr() + devA, freeA, err := e.uploadRawBytes(ptr, qs.RawBytes()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeA() - // Upload B (float32) to GPU. devB, cleanupB, err := getDevicePtr(e, b) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } defer cleanupB() - // Allocate output C. - cSize := m * n * int(unsafe.Sizeof(float32(0))) + cSize := m * n * f32Size devC, err := e.pool.Alloc(e.deviceID, cSize) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) @@ -2434,82 +1751,35 @@ func (e *GPUEngine[T]) matMulQ8(ctx context.Context, qs *tensor.Q8Storage, a, b // For GEMV (M=1), A^T[K,1] is just A's data as a column, and C_temp[N,1] // can be reshaped to [1, N] without a physical transpose. func (e *GPUEngine[T]) matMulQ8BWeight(ctx context.Context, a *tensor.TensorNumeric[T], qs *tensor.Q8Storage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // B must be 2D (virtual-transposed weight). - if len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - // Flatten A's batch dims: [batch..., m, k] -> [m_total, k] - k := aShape[len(aShape)-1] - m := 1 - for i := 0; i < len(aShape)-1; i++ { - m *= aShape[i] - } - n := bShape[1] // columns of B (after virtual transpose) - - // Q8 original layout is [N, K]. Verify K is a multiple of 32. - if k%32 != 0 { + m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 32) + if fallback { return e.cpu.MatMul(ctx, a, b, dst...) } - // Build output shape: [batch..., m_last, n] matching standard MatMul broadcast. - outShape := make([]int, len(aShape)) - copy(outShape, aShape[:len(aShape)-1]) - outShape[len(outShape)-1] = n - e.setDevice() - // Get Q8 device pointer (pre-uploaded or upload now). - var devQ8 unsafe.Pointer - var freeQ8 func() - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - devQ8 = ptr - freeQ8 = func() {} - } else { - q8Bytes := qs.RawBytes() - var err error - devQ8, err = e.pool.Alloc(e.deviceID, len(q8Bytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeQ8 = func() { e.pool.Free(e.deviceID, devQ8, len(q8Bytes)) } - if err := e.runtime.Memcpy(devQ8, unsafe.Pointer(&q8Bytes[0]), len(q8Bytes), gpuapi.MemcpyHostToDevice); err != nil { - freeQ8() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := qs.GPUPtr() + devQ8, freeQ8, err := e.uploadRawBytes(ptr, qs.RawBytes()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeQ8() - // Upload A to GPU as F32. devA, cleanupA, err := getDevicePtr(e, a) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) } defer cleanupA() - f32Size := int(unsafe.Sizeof(float32(0))) - if m == 1 { // GEMV fast path: C_temp[N, 1] = gemm_q8(B_q8[N,K], A^T[K,1]) - cSize := n * f32Size - devC, err := e.pool.Alloc(e.deviceID, cSize) + result, err := e.quantGemvResult(outShape, n, func(devC unsafe.Pointer) error { + return e.kernels.GemmQ8F32(devQ8, devA, devC, n, k, 1, e.stream) + }, dst...) if err != nil { return e.cpu.MatMul(ctx, a, b, dst...) 
} - - if err := e.kernels.GemmQ8F32(devQ8, devA, devC, n, k, 1, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, cSize) - return e.cpu.MatMul(ctx, a, b, dst...) - } - - return makeGPUResult[T](e, outShape, devC, n, dst...) + return result, nil } // General GEMM: C_temp[N, M] = gemm_q8(B_q8[N,K], A^T[K,M]) @@ -2538,7 +1808,6 @@ func (e *GPUEngine[T]) matMulQ8BWeight(ctx context.Context, a *tensor.TensorNume return e.cpu.MatMul(ctx, a, b, dst...) } - // C_temp is [N, M], transpose to [M, N], then reshape to outShape. cTempTensor, err := makeGPUResult[T](e, []int{n, m}, devCTemp, n*m) if err != nil { e.pool.Free(e.deviceID, devCTemp, cTempSize) @@ -2553,53 +1822,25 @@ func (e *GPUEngine[T]) matMulQ8BWeight(ctx context.Context, a *tensor.TensorNume } // matMulBF16 handles MatMul where A has BFloat16Storage. -// A is [M, K] in BF16, B is [K, N] in FP32 → C is [M, N] in FP32. +// A is [M, K] in BF16, B is [K, N] in FP32 -> C is [M, N] in FP32. // B's FP32 data is converted to BF16 on the fly, then MixedBF16Gemm // computes with BF16 inputs and FP32 output via cublasGemmEx. func (e *GPUEngine[T]) matMulBF16(ctx context.Context, bs *tensor.BFloat16Storage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - aShape := a.Shape() - bShape := b.Shape() - - if len(aShape) < 2 || len(bShape) < 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - if len(aShape) > 2 || len(bShape) > 2 { - return e.cpu.MatMul(ctx, a, b, dst...) - } - - if e.blas == nil { + m, k, n, fallback := aShapeCheck2D(a.Shape(), b.Shape(), 0) + if fallback || e.blas == nil { return e.cpu.MatMul(ctx, a, b, dst...) } - m := aShape[0] - k := aShape[1] - n := bShape[1] - e.setDevice() - // Get BF16 device pointer for A (pre-uploaded or upload now). - var devA unsafe.Pointer - var freeA func() - if ptr, _, _ := bs.GPUPtr(); ptr != nil { - devA = ptr - freeA = func() {} - } else { - aBytes := bs.RawBytes() - var err error - devA, err = e.pool.Alloc(e.deviceID, len(aBytes)) - if err != nil { - return e.cpu.MatMul(ctx, a, b, dst...) - } - freeA = func() { e.pool.Free(e.deviceID, devA, len(aBytes)) } - if err := e.runtime.Memcpy(devA, unsafe.Pointer(&aBytes[0]), len(aBytes), gpuapi.MemcpyHostToDevice); err != nil { - freeA() - return e.cpu.MatMul(ctx, a, b, dst...) - } + ptr, _, _ := bs.GPUPtr() + devA, freeA, err := e.uploadRawBytes(ptr, bs.RawBytes()) + if err != nil { + return e.cpu.MatMul(ctx, a, b, dst...) } defer freeA() // Convert B from FP32 to BF16 and upload. - // BFloat16Storage is Storage[float32], so T is float32 here. bData := b.Data() bF32 := *(*[]float32)(unsafe.Pointer(&bData)) bBF16 := tensor.NewBFloat16Storage(bF32) @@ -2614,7 +1855,6 @@ func (e *GPUEngine[T]) matMulBF16(ctx context.Context, bs *tensor.BFloat16Storag return e.cpu.MatMul(ctx, a, b, dst...) } - // Allocate FP32 output. cSize := m * n * f32Size devC, err := e.pool.Alloc(e.deviceID, cSize) if err != nil { @@ -2630,41 +1870,18 @@ func (e *GPUEngine[T]) matMulBF16(ctx context.Context, bs *tensor.BFloat16Storag } // matMulBF16BWeight handles MatMul where B has BFloat16Storage. -// A is [M, K] in FP32, B is [K, N] in BF16 → C is [M, N] in FP32. +// A is [M, K] in FP32, B is [K, N] in BF16 -> C is [M, N] in FP32. // A's FP32 data is converted to BF16 on the fly, then MixedBF16Gemm // computes with BF16 inputs and FP32 output via cublasGemmEx. 
 func (e *GPUEngine[T]) matMulBF16BWeight(ctx context.Context, a *tensor.TensorNumeric[T], bs *tensor.BFloat16Storage, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
-	aShape := a.Shape()
-	bShape := b.Shape()
-
-	if len(aShape) < 2 || len(bShape) < 2 {
-		return e.cpu.MatMul(ctx, a, b, dst...)
-	}
-	if len(bShape) > 2 {
-		return e.cpu.MatMul(ctx, a, b, dst...)
-	}
-
-	if e.blas == nil {
+	m, k, n, outShape, fallback := bweightShapeMKN(a.Shape(), b.Shape(), 0)
+	if fallback || e.blas == nil {
 		return e.cpu.MatMul(ctx, a, b, dst...)
 	}
 
-	// Flatten A's batch dims: [batch..., m, k] -> [m_total, k]
-	k := aShape[len(aShape)-1]
-	m := 1
-	for i := 0; i < len(aShape)-1; i++ {
-		m *= aShape[i]
-	}
-	n := bShape[1]
-
-	// Build output shape: [batch..., m_last, n].
-	outShape := make([]int, len(aShape))
-	copy(outShape, aShape[:len(aShape)-1])
-	outShape[len(outShape)-1] = n
-
 	e.setDevice()
 
 	// Convert A from FP32 to BF16 and upload.
-	// BFloat16Storage is Storage[float32], so T is float32 here.
 	aData := a.Data()
 	aF32 := *(*[]float32)(unsafe.Pointer(&aData))
 	aBF16 := tensor.NewBFloat16Storage(aF32)
@@ -2679,27 +1896,13 @@ func (e *GPUEngine[T]) matMulBF16BWeight(ctx context.Context, a *tensor.TensorNu
 		return e.cpu.MatMul(ctx, a, b, dst...)
 	}
 
-	// Get BF16 device pointer for B (pre-uploaded or upload now).
-	var devB unsafe.Pointer
-	var freeB func()
-	if ptr, _, _ := bs.GPUPtr(); ptr != nil {
-		devB = ptr
-		freeB = func() {}
-	} else {
-		bBytes := bs.RawBytes()
-		devB, err = e.pool.Alloc(e.deviceID, len(bBytes))
-		if err != nil {
-			return e.cpu.MatMul(ctx, a, b, dst...)
-		}
-		freeB = func() { e.pool.Free(e.deviceID, devB, len(bBytes)) }
-		if err := e.runtime.Memcpy(devB, unsafe.Pointer(&bBytes[0]), len(bBytes), gpuapi.MemcpyHostToDevice); err != nil {
-			freeB()
-			return e.cpu.MatMul(ctx, a, b, dst...)
-		}
+	ptr, _, _ := bs.GPUPtr()
+	devB, freeB, err := e.uploadRawBytes(ptr, bs.RawBytes())
+	if err != nil {
+		return e.cpu.MatMul(ctx, a, b, dst...)
 	}
 	defer freeB()
 
-	// Allocate FP32 output.
 	cSize := m * n * f32Size
 	devC, err := e.pool.Alloc(e.deviceID, cSize)
 	if err != nil {
diff --git a/compute/gpu_engine_matmul.go b/compute/gpu_engine_matmul.go
new file mode 100644
index 0000000..435d770
--- /dev/null
+++ b/compute/gpu_engine_matmul.go
@@ -0,0 +1,240 @@
+package compute
+
+// gpu_engine_matmul.go contains shared helper functions extracted from the 14
+// quantized matmul methods in gpu_engine.go. Each helper captures a recurring
+// pattern (upload, shape validation, GEMV output, dequant+GEMM) so that the
+// individual matmul methods can be thin wrappers.
+
+import (
+	"fmt"
+	"unsafe"
+
+	"github.com/zerfoo/ztensor/internal/gpuapi"
+	"github.com/zerfoo/ztensor/tensor"
+)
+
+// uploadRawBytes uploads raw quantized bytes to the GPU. If the storage already
+// has a GPU pointer (via GPUPtr()), it returns that pointer with a no-op free.
+// Otherwise it allocates device memory, copies the bytes, and returns a free
+// function the caller must defer.
+//
+// gpuPtr is the pointer returned by the storage's GPUPtr() method (nil if the
+// data is not pre-uploaded). rawBytes is the byte payload to upload (e.g. qs.RawBytes()).
+func (e *GPUEngine[T]) uploadRawBytes(gpuPtr unsafe.Pointer, rawBytes []byte) (devPtr unsafe.Pointer, free func(), err error) {
+	if gpuPtr != nil {
+		return gpuPtr, func() {}, nil
+	}
+	devPtr, err = e.pool.Alloc(e.deviceID, len(rawBytes))
+	if err != nil {
+		return nil, nil, err
+	}
+	free = func() { e.pool.Free(e.deviceID, devPtr, len(rawBytes)) }
+	if err := e.runtime.Memcpy(devPtr, unsafe.Pointer(&rawBytes[0]), len(rawBytes), gpuapi.MemcpyHostToDevice); err != nil {
+		free()
+		return nil, nil, err
+	}
+	return devPtr, free, nil
+}
+
+// aShapeCheck2D validates shapes for A-side quantized matmul methods.
+// Returns m, k, n and whether the caller should fall back to CPU.
+// Falls back when either operand is not exactly 2-D, or when kAlignment > 0
+// and k is not a multiple of kAlignment.
+func aShapeCheck2D(aShape, bShape []int, kAlignment int) (m, k, n int, fallback bool) {
+	if len(aShape) < 2 || len(bShape) < 2 {
+		return 0, 0, 0, true
+	}
+	if len(aShape) > 2 || len(bShape) > 2 {
+		return 0, 0, 0, true
+	}
+	m = aShape[0]
+	k = aShape[1]
+	n = bShape[1]
+	if kAlignment > 0 && k%kAlignment != 0 {
+		return 0, 0, 0, true
+	}
+	return m, k, n, false
+}
+
+// bweightShapeMKN validates shapes for BWeight (virtual-transposed weight)
+// matmul methods. It flattens A's batch dimensions and builds the output shape.
+// Returns m, k, n, outShape, and whether the caller should fall back to CPU.
+// Falls back when either operand has <2 dims, B has >2 dims, or kAlignment > 0
+// and k is not a multiple of kAlignment.
+func bweightShapeMKN(aShape, bShape []int, kAlignment int) (m, k, n int, outShape []int, fallback bool) {
+	if len(aShape) < 2 || len(bShape) < 2 {
+		return 0, 0, 0, nil, true
+	}
+	if len(bShape) > 2 {
+		return 0, 0, 0, nil, true
+	}
+	k = aShape[len(aShape)-1]
+	m = 1
+	for i := 0; i < len(aShape)-1; i++ {
+		m *= aShape[i]
+	}
+	n = bShape[1]
+	if kAlignment > 0 && k%kAlignment != 0 {
+		return 0, 0, 0, nil, true
+	}
+	outShape = make([]int, len(aShape))
+	copy(outShape, aShape[:len(aShape)-1])
+	outShape[len(outShape)-1] = n
+	return m, k, n, outShape, false
+}
+
+// quantGemvResult allocates a GPU output buffer of outElems float32s, calls
+// gemvFn to fill it, and wraps the result as a tensor with the given shape.
+// On any error it cleans up and returns (nil, err) so the caller can fall back.
+func (e *GPUEngine[T]) quantGemvResult(outShape []int, outElems int, gemvFn func(devY unsafe.Pointer) error, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
+	cSize := outElems * f32Size
+	devY, err := e.pool.Alloc(e.deviceID, cSize)
+	if err != nil {
+		return nil, err
+	}
+	if err := gemvFn(devY); err != nil {
+		e.pool.Free(e.deviceID, devY, cSize)
+		return nil, err
+	}
+	return makeGPUResult[T](e, outShape, devY, outElems, dst...)
+}
+
+// dequantSgemm dequantizes weight data to F32 on the GPU (or CPU as fallback),
+// uploads the other operand, allocates the output, and calls cuBLAS Sgemm.
+// C[m,n] = dequant(devW)[m,k] * B[k,n].
+//
+// devW is the quantized device pointer. dequantFn is the GPU dequantization
+// function; when it is nil or returns an error, cpuDequant is used instead.
+// cpuDequant fills a host F32 slice (may be nil if dequantFn always succeeds).
+func (e *GPUEngine[T]) dequantSgemm(
+	devW unsafe.Pointer, m, k, n int,
+	dequantFn func(src, dst unsafe.Pointer, rows, cols int) error,
+	cpuDequant func([]float32),
+	b *tensor.TensorNumeric[T],
+	outShape []int, name string,
+	dst ...*tensor.TensorNumeric[T],
+) (*tensor.TensorNumeric[T], error) {
+	dequantSize := m * k * f32Size
+	devAF32, err := e.pool.Alloc(e.deviceID, dequantSize)
+	if err != nil {
+		return nil, err
+	}
+	defer e.pool.Free(e.deviceID, devAF32, dequantSize)
+
+	dequanted := false
+	if dequantFn != nil {
+		if err := dequantFn(devW, devAF32, m, k); err == nil {
+			dequanted = true
+		}
+	}
+	if !dequanted {
+		if cpuDequant == nil {
+			return nil, fmt.Errorf("%s: no dequant path available", name)
+		}
+		dequant := make([]float32, m*k)
+		cpuDequant(dequant)
+		if err := e.runtime.Memcpy(devAF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil {
+			return nil, err
+		}
+	}
+
+	devB, cleanupB, err := getDevicePtr(e, b)
+	if err != nil {
+		return nil, err
+	}
+	defer cleanupB()
+
+	cSize := m * n * f32Size
+	devC, err := e.pool.Alloc(e.deviceID, cSize)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := e.blas.Sgemm(m, n, k, 1.0, devAF32, devB, 0.0, devC); err != nil {
+		e.pool.Free(e.deviceID, devC, cSize)
+		return nil, fmt.Errorf("%s: Sgemm: %w", name, err)
+	}
+
+	return makeGPUResult[T](e, outShape, devC, m*n, dst...)
+}
+
+// sgemmNTOrFallback dequantizes weight data to F32, then computes
+// C[m,n] = A[m,k] * dequant(B)[n,k]^T using SgemmNT if available,
+// or an explicit Transpose2D + Sgemm otherwise.
+//
+// devW is the quantized device pointer laid out as [n, k].
+// dequantFn is the GPU dequantization function (may be nil).
+// cpuDequant produces F32 data on the host (may be nil if dequantFn always succeeds).
+func (e *GPUEngine[T]) sgemmNTOrFallback(
+	devW unsafe.Pointer, m, k, n int,
+	dequantFn func(src, dst unsafe.Pointer, rows, cols int) error,
+	cpuDequant func([]float32),
+	a *tensor.TensorNumeric[T],
+	outShape []int, name string,
+	dst ...*tensor.TensorNumeric[T],
+) (*tensor.TensorNumeric[T], error) {
+	dequantSize := n * k * f32Size
+	devBF32, err := e.pool.Alloc(e.deviceID, dequantSize)
+	if err != nil {
+		return nil, err
+	}
+	defer e.pool.Free(e.deviceID, devBF32, dequantSize)
+
+	dequanted := false
+	if dequantFn != nil {
+		if err := dequantFn(devW, devBF32, n, k); err == nil {
+			dequanted = true
+		}
+	}
+	if !dequanted {
+		if cpuDequant == nil {
+			return nil, fmt.Errorf("%s: no dequant path available", name)
+		}
+		dequant := make([]float32, n*k)
+		cpuDequant(dequant)
+		if err := e.runtime.Memcpy(devBF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil {
+			return nil, err
+		}
+	}
+
+	devA, cleanupA, err := getDevicePtr(e, a)
+	if err != nil {
+		return nil, err
+	}
+	defer cleanupA()
+
+	cSize := m * n * f32Size
+	devC, err := e.pool.Alloc(e.deviceID, cSize)
+	if err != nil {
+		return nil, err
+	}
+
+	// Prefer SgemmNT (avoids explicit transpose).
+	if ntBLAS, ok := e.blas.(gpuapi.BLASTransposeB); ok {
+		if err := ntBLAS.SgemmNT(m, n, k, 1.0, devA, devBF32, 0.0, devC); err != nil {
+			e.pool.Free(e.deviceID, devC, cSize)
+			return nil, fmt.Errorf("%s: SgemmNT: %w", name, err)
+		}
+		return makeGPUResult[T](e, outShape, devC, m*n, dst...)
+	}
+
+	// Fallback: transpose dequantized B then use Sgemm.
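+	// Transpose2D flips the dequantized weight from its stored [n, k]
+	// row-major layout into [k, n] so that plain Sgemm, which has no
+	// transpose-B option on this path, can consume it directly.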
+	devBT, err := e.pool.Alloc(e.deviceID, dequantSize)
+	if err != nil {
+		e.pool.Free(e.deviceID, devC, cSize)
+		return nil, err
+	}
+	defer e.pool.Free(e.deviceID, devBT, dequantSize)
+
+	if err := e.kernels.Transpose2D(devBF32, devBT, n, k, e.stream); err != nil {
+		e.pool.Free(e.deviceID, devC, cSize)
+		return nil, err
+	}
+
+	if err := e.blas.Sgemm(m, n, k, 1.0, devA, devBT, 0.0, devC); err != nil {
+		e.pool.Free(e.deviceID, devC, cSize)
+		return nil, fmt.Errorf("%s: Sgemm: %w", name, err)
+	}
+
+	return makeGPUResult[T](e, outShape, devC, m*n, dst...)
+}

From 5dcbb87460b701f5a063101ecb36c0eb242603cf Mon Sep 17 00:00:00 2001
From: David Ndungu
Date: Mon, 6 Apr 2026 14:12:15 -0700
Subject: [PATCH 2/3] fix(kernels): update GemvQ5_0F32 test to match qhOffset/qsOffset signature

The GemvQ5_0F32 kernel was updated to accept qhOffset and qsOffset
parameters for the GPU-separated layout, but the test still called with
the old 6-arg signature. Fix all 3 call sites:

- TestGemvQ5_0F32_Parity
- TestGemvQ5_0F32_MultipleSizes
- BenchmarkGemvQ5_0F32_4096

Add q5_0ToGPULayout helper to convert standard block format to the
GPU-separated layout (scales | qh | qs) needed by the kernel.
---
 internal/cuda/kernels/gemv_q5_0_test.go | 92 +++++++++++++++++++------
 1 file changed, 70 insertions(+), 22 deletions(-)

diff --git a/internal/cuda/kernels/gemv_q5_0_test.go b/internal/cuda/kernels/gemv_q5_0_test.go
index 6fde7a5..02da073 100644
--- a/internal/cuda/kernels/gemv_q5_0_test.go
+++ b/internal/cuda/kernels/gemv_q5_0_test.go
@@ -9,6 +9,30 @@ import (
 	"github.com/zerfoo/ztensor/internal/cuda"
 )
 
+// q5_0ToGPULayout converts standard Q5_0 block data (22 bytes/block: [d(2)|qh(4)|qs(16)])
+// to the GPU-separated layout: scales(2*N, padded) | qh(4*N, padded) | qs(16*N).
+func q5_0ToGPULayout(raw []byte, nBlocks int) []byte {
+	const blockBytes = 22
+	scaleBytes := nBlocks * 2
+	paddedScaleBytes := (scaleBytes + 15) &^ 15
+	qhBytes := nBlocks * 4
+	paddedQhBytes := (qhBytes + 15) &^ 15
+	qsBytes := nBlocks * 16
+	total := paddedScaleBytes + paddedQhBytes + qsBytes
+
+	out := make([]byte, total)
+	for i := range nBlocks {
+		blockOff := i * blockBytes
+		// scale: 2 bytes at blockOff+0
+		copy(out[i*2:i*2+2], raw[blockOff:blockOff+2])
+		// qh: 4 bytes at blockOff+2
+		copy(out[paddedScaleBytes+i*4:paddedScaleBytes+i*4+4], raw[blockOff+2:blockOff+6])
+		// qs: 16 bytes at blockOff+6
+		copy(out[paddedScaleBytes+paddedQhBytes+i*16:paddedScaleBytes+paddedQhBytes+i*16+16], raw[blockOff+6:blockOff+22])
+	}
+	return out
+}
+
 // dequantizeQ5_0 dequantizes one Q5_0 block (22 bytes) into 32 float32 values.
 // Inlined here to avoid an import cycle with the tensor package.
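+// In the standard GGML Q5_0 scheme each of the 32 weights in a block decodes
+// as d*(q-16), where d is the block's fp16 scale and q is a 5-bit value built
+// from a low nibble in qs plus one high bit from qh.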
 func dequantizeQ5_0(raw []byte, dst []float32) {
@@ -153,12 +177,6 @@ func TestGemvQ5_0F32_Parity(t *testing.T) {
 	}
 	defer func() { _ = stream.Destroy() }()
 
-	devW, err := cuda.Malloc(len(raw))
-	if err != nil {
-		t.Fatalf("cuda.Malloc W: %v", err)
-	}
-	defer func() { _ = cuda.Free(devW) }()
-
 	devX, err := cuda.Malloc(K * 4)
 	if err != nil {
 		t.Fatalf("cuda.Malloc x: %v", err)
 	}
 	defer func() { _ = cuda.Free(devX) }()
@@ -171,14 +189,30 @@
 	}
 	defer func() { _ = cuda.Free(devY) }()
 
-	if err := cuda.Memcpy(devW, unsafe.Pointer(&raw[0]), len(raw), cuda.MemcpyHostToDevice); err != nil {
-		t.Fatalf("Memcpy W: %v", err)
-	}
 	if err := cuda.Memcpy(devX, unsafe.Pointer(&x[0]), K*4, cuda.MemcpyHostToDevice); err != nil {
 		t.Fatalf("Memcpy x: %v", err)
 	}
 
-	if err := GemvQ5_0F32(devW, devX, devY, M, K, stream.Ptr()); err != nil {
+	// Convert standard Q5_0 blocks to GPU-separated layout (scales | qh | qs)
+	// and compute region offsets for the kernel.
+	nBlocks := M * (K / 32)
+	gpuRaw := q5_0ToGPULayout(raw, nBlocks)
+	scaleBytes := nBlocks * 2
+	qhOffset := (scaleBytes + 15) &^ 15
+	qhBytes := nBlocks * 4
+	qsOffset := qhOffset + (qhBytes+15)&^15
+
+	// Re-upload GPU-layout data.
+	devWGPU, err := cuda.Malloc(len(gpuRaw))
+	if err != nil {
+		t.Fatalf("cuda.Malloc W GPU: %v", err)
+	}
+	defer func() { _ = cuda.Free(devWGPU) }()
+	if err := cuda.Memcpy(devWGPU, unsafe.Pointer(&gpuRaw[0]), len(gpuRaw), cuda.MemcpyHostToDevice); err != nil {
+		t.Fatalf("Memcpy W GPU: %v", err)
+	}
+
+	if err := GemvQ5_0F32(devWGPU, devX, devY, M, K, qhOffset, qsOffset, stream.Ptr()); err != nil {
 		t.Fatalf("GemvQ5_0F32: %v", err)
 	}
 
@@ -241,12 +275,6 @@ func TestGemvQ5_0F32_MultipleSizes(t *testing.T) {
 			}
 			defer func() { _ = stream.Destroy() }()
 
-			devW, err := cuda.Malloc(len(raw))
-			if err != nil {
-				t.Fatalf("cuda.Malloc W: %v", err)
-			}
-			defer func() { _ = cuda.Free(devW) }()
-
 			devX, err := cuda.Malloc(tc.K * 4)
 			if err != nil {
 				t.Fatalf("cuda.Malloc x: %v", err)
 			}
 			defer func() { _ = cuda.Free(devX) }()
@@ -259,14 +287,27 @@
 			}
 			defer func() { _ = cuda.Free(devY) }()
 
-			if err := cuda.Memcpy(devW, unsafe.Pointer(&raw[0]), len(raw), cuda.MemcpyHostToDevice); err != nil {
-				t.Fatalf("Memcpy W: %v", err)
+			// Convert to GPU-separated layout and compute offsets.
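+			// Worked example (sizes illustrative): tc.M=64, tc.K=128 gives
+			// nBlocks = 64*4 = 256, so scales take 512 bytes (already
+			// 16-aligned), qh starts at offset 512 and spans 1024 bytes,
+			// and qs starts at 512+1024 = 1536.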
+ nBlocks := tc.M * (tc.K / 32) + gpuRaw := q5_0ToGPULayout(raw, nBlocks) + scaleBytes := nBlocks * 2 + qhOffset := (scaleBytes + 15) &^ 15 + qhBytes := nBlocks * 4 + qsOffset := qhOffset + (qhBytes+15)&^15 + + devWGPU, err := cuda.Malloc(len(gpuRaw)) + if err != nil { + t.Fatalf("cuda.Malloc W GPU: %v", err) + } + defer func() { _ = cuda.Free(devWGPU) }() + if err := cuda.Memcpy(devWGPU, unsafe.Pointer(&gpuRaw[0]), len(gpuRaw), cuda.MemcpyHostToDevice); err != nil { + t.Fatalf("Memcpy W GPU: %v", err) } if err := cuda.Memcpy(devX, unsafe.Pointer(&x[0]), tc.K*4, cuda.MemcpyHostToDevice); err != nil { t.Fatalf("Memcpy x: %v", err) } - if err := GemvQ5_0F32(devW, devX, devY, tc.M, tc.K, stream.Ptr()); err != nil { + if err := GemvQ5_0F32(devWGPU, devX, devY, tc.M, tc.K, qhOffset, qsOffset, stream.Ptr()); err != nil { t.Fatalf("GemvQ5_0F32: %v", err) } @@ -318,19 +359,26 @@ func BenchmarkGemvQ5_0F32_4096(b *testing.B) { } defer func() { _ = stream.Destroy() }() - devW, _ := cuda.Malloc(len(raw)) + nBlocks := M * (K / 32) + gpuRaw := q5_0ToGPULayout(raw, nBlocks) + scaleBytes := nBlocks * 2 + qhOffset := (scaleBytes + 15) &^ 15 + qhBytes := nBlocks * 4 + qsOffset := qhOffset + (qhBytes+15)&^15 + + devW, _ := cuda.Malloc(len(gpuRaw)) defer func() { _ = cuda.Free(devW) }() devX, _ := cuda.Malloc(K * 4) defer func() { _ = cuda.Free(devX) }() devY, _ := cuda.Malloc(M * 4) defer func() { _ = cuda.Free(devY) }() - _ = cuda.Memcpy(devW, unsafe.Pointer(&raw[0]), len(raw), cuda.MemcpyHostToDevice) + _ = cuda.Memcpy(devW, unsafe.Pointer(&gpuRaw[0]), len(gpuRaw), cuda.MemcpyHostToDevice) _ = cuda.Memcpy(devX, unsafe.Pointer(&x[0]), K*4, cuda.MemcpyHostToDevice) b.ResetTimer() for b.Loop() { - _ = GemvQ5_0F32(devW, devX, devY, M, K, stream.Ptr()) + _ = GemvQ5_0F32(devW, devX, devY, M, K, qhOffset, qsOffset, stream.Ptr()) } _ = stream.Synchronize() From 601a677309e244a813bc4c8b4213dbf586a15a0d Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Mon, 6 Apr 2026 14:13:56 -0700 Subject: [PATCH 3/3] fix(ci): exclude metal and pjrt from go vet These packages use unsafe.Pointer for GPU/accelerator runtime bindings via purego/dlopen, same as cuda/hip/opencl. The pjrt package was added after the initial CI exclusion list. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a27c58f..b9f8128 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,5 +18,5 @@ jobs: # Run go vet on all packages except those with intentional # unsafe.Pointer usage for GPU runtime bindings via purego/dlopen. # These warnings are expected and documented in docs/QUALITY.md. - go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$') + go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$') - run: go test -race -timeout 300s ./...