Skip to content

Commit 66d65ec

Browse files
authored
cuda: cap grid.y at 65535 in non-contiguous dequantize/convert kernels (ggml-org#19999)
1 parent 05728db commit 66d65ec

1 file changed

Lines changed: 32 additions & 32 deletions

File tree

ggml/src/ggml-cuda/convert.cu

Lines changed: 32 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
1616
return;
1717
}
1818

19-
const int64_t i01 = blockIdx.y;
20-
21-
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
22-
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
23-
const int64_t i02 = dm.y;
24-
const int64_t i03 = dm.x;
25-
26-
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
27-
28-
const int64_t ib = ibx0 + i00/qk; // block index
29-
const int64_t iqs = (i00%qk)/qr; // quant index
30-
const int64_t iybs = i00 - i00%qk; // y block start index
31-
const int64_t y_offset = qr == 1 ? 1 : qk/2;
32-
33-
// dequantize
34-
float2 v;
35-
dequantize_kernel(vx, ib, iqs, v);
36-
37-
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
38-
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
39-
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
19+
for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
20+
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
21+
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
22+
const int64_t i02 = dm.y;
23+
const int64_t i03 = dm.x;
24+
25+
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
26+
27+
const int64_t ib = ibx0 + i00/qk; // block index
28+
const int64_t iqs = (i00%qk)/qr; // quant index
29+
const int64_t iybs = i00 - i00%qk; // y block start index
30+
const int64_t y_offset = qr == 1 ? 1 : qk/2;
31+
32+
// dequantize
33+
float2 v;
34+
dequantize_kernel(vx, ib, iqs, v);
35+
36+
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
37+
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
38+
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
39+
}
4040
}
4141
}
4242

@@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y,
492492
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
493493
const int64_t ne0203 = ne02*ne03;
494494
const uint3 ne02_fdv = init_fastdiv_values(ne02);
495-
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
495+
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
496496
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
497497
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
498498
}
@@ -628,18 +628,18 @@ static __global__ void convert_unary(
628628
return;
629629
}
630630

631-
const int64_t i01 = blockIdx.y;
632-
633631
const src_t * x = (const src_t *) vx;
634632

635-
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
636-
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
637-
const int64_t i02 = dm.y;
638-
const int64_t i03 = dm.x;
633+
for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
634+
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
635+
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
636+
const int64_t i02 = dm.y;
637+
const int64_t i03 = dm.x;
639638

640-
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
641-
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
642-
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
639+
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
640+
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
641+
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
642+
}
643643
}
644644
}
645645

@@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y,
649649
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
650650
const int64_t ne0203 = ne02*ne03;
651651
const uint3 ne02_fdv = init_fastdiv_values(ne02);
652-
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
652+
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
653653
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
654654
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
655655
}

0 commit comments

Comments
 (0)