Skip to content

Commit b0311c1

Browse files
CUDA: fix padding of GQA to power of 2 in FA (ggml-org#19115)
1 parent 8f80d1b commit b0311c1

3 files changed

Lines changed: 64 additions & 52 deletions

File tree

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
629629
template<int D, int ncols1, int ncols2> // D == head size
630630
__launch_bounds__(D, 1)
631631
static __global__ void flash_attn_stream_k_fixup(
632-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
633-
const int nbatch_fa) {
632+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
633+
const int ne11, const int ne12, const int nbatch_fa) {
634634
constexpr int ncols = ncols1*ncols2;
635635

636636
const int bidx0 = blockIdx.x;
@@ -641,12 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
641641

642642
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
643643

644-
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
645-
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
646-
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
644+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
647645

648-
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
649-
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
646+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
647+
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
648+
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
649+
650+
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
651+
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
650652

651653
const bool did_not_have_any_data = kbc0 == kbc0_stop;
652654
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -655,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
655657
return;
656658
}
657659

658-
const int sequence = kbc0 / (iter_k*iter_j*iter_z);
659-
const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
660-
const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
660+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
661+
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
662+
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
663+
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
664+
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
665+
666+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
661667

662-
if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
668+
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
663669
return;
664670
}
665671

666-
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
672+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
667673

668674
// Load the partial result that needs a fixup:
669675
float dst_val = 0.0f;
@@ -682,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
682688
int bidx = bidx0 - 1;
683689
int kbc_stop = kbc0;
684690
while(true) {
685-
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
691+
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
686692
if (kbc == kbc_stop) { // Did not have any data.
687693
bidx--;
688694
kbc_stop = kbc;
@@ -883,9 +889,10 @@ void launch_fattn(
883889
}
884890
}
885891

886-
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
887-
const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
888-
const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
892+
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
893+
const int gqa_ratio = Q->ne[2] / K->ne[2];
894+
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
895+
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
889896

890897
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
891898
// Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -960,7 +967,7 @@ void launch_fattn(
960967

961968
blocks_num.x = ntiles_x;
962969
blocks_num.y = parallel_blocks;
963-
blocks_num.z = ntiles_z*Q->ne[3];
970+
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
964971

965972
if (parallel_blocks > 1) {
966973
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -1014,7 +1021,7 @@ void launch_fattn(
10141021

10151022
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
10161023
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
1017-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
1024+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
10181025
}
10191026
} else if (parallel_blocks > 1) {
10201027
const dim3 block_dim_combine(DV, 1, 1);

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -933,14 +933,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
933933
const float logit_softcap,
934934
const uint3 ne01,
935935
const int ne02,
936+
const int gqa_ratio,
936937
const int ne11,
937938
const int stride_Q1,
938939
const int stride_Q2,
939940
const int stride_K,
940941
const int stride_V,
941942
const int stride_mask,
942943
const int jt,
943-
const int zt,
944+
const int zt_gqa,
944945
const int kb0_start,
945946
const int kb0_stop) {
946947
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1023,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
10231024
const int j = jc / ncols2;
10241025
const int c = jc % ncols2;
10251026

1026-
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
1027+
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
10271028
#pragma unroll
10281029
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
10291030
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1409,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
14091410
const int j_dst = jc_dst / ncols2;
14101411
const int c_dst = jc_dst % ncols2;
14111412

1412-
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
1413+
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
14131414
continue;
14141415
}
14151416

@@ -1448,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
14481449
}
14491450
#else
14501451
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
1451-
scale, slope, logit_softcap, ne01, ne02,
1452+
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
14521453
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
14531454
jt, kb0_start, kb0_stop);
14541455
NO_DEVICE_CODE;
@@ -1521,13 +1522,13 @@ static __global__ void flash_attn_ext_f16(
15211522

15221523
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
15231524

1524-
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1525-
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1526-
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
1525+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1526+
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1527+
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
15271528

15281529
// kbc == k block continuous, current index in continuous ijk space.
1529-
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
1530-
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
1530+
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1531+
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
15311532

15321533
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
15331534
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1538,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
15381539
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
15391540

15401541
while (kbc < kbc_stop && kb0_stop == iter_k) {
1541-
const int sequence = kbc / (iter_k*iter_j*iter_z);
1542-
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
1543-
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1542+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
1543+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1544+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1545+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1546+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
15441547

1545-
const int head0 = zt * ncols2;
1548+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
15461549

1547-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1548-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1550+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1551+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
15491552
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
15501553
(const half *) (mask + nb33*(sequence % ne33));
1551-
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1554+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
15521555

1553-
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1554-
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1556+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1557+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
15551558

1556-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1559+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
15571560

15581561
if (KV_max) {
15591562
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1563,12 +1566,12 @@ static __global__ void flash_attn_ext_f16(
15631566
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
15641567
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
15651568
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1566-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
1569+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
15671570
} else {
15681571
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
15691572
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
15701573
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1571-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
1574+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
15721575
}
15731576

15741577
kbc += iter_k;
@@ -1582,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
15821585
return;
15831586
}
15841587

1585-
const int sequence = kbc / (iter_k*iter_j*iter_z);
1586-
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
1587-
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1588+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
1589+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1590+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1591+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1592+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
15881593

1589-
const int head0 = zt * ncols2;
1594+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
15901595

1591-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1592-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1596+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1597+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
15931598
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
15941599
(const half *) (mask + nb33*(sequence % ne33));
1595-
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1600+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
15961601

1597-
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1598-
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1602+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1603+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
15991604

1600-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1605+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
16011606

16021607
if (KV_max) {
16031608
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1607,7 +1612,7 @@ static __global__ void flash_attn_ext_f16(
16071612
constexpr bool needs_fixup = false;
16081613
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
16091614
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1610-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
1615+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
16111616
#else
16121617
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
16131618
max_bias, m0, m1, n_head_log2, logit_softcap,

tests/test-backend-ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8216,8 +8216,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
82168216
for (int nh : { 4, }) {
82178217
for (int nr3 : { 1, 3, }) {
82188218
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
8219-
for (int nr2 : { 1, 4, 16 }) {
8220-
if (nr2 == 16 && hsk != 128) continue;
8219+
for (int nr2 : { 1, 4, 12 }) {
8220+
if (nr2 == 12 && hsk != 128) continue;
82218221
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
82228222
for (int kv : { 113, 512, 1024, }) {
82238223
if (nr2 != 1 && kv != 512) continue;

0 commit comments

Comments
 (0)