@@ -933,14 +933,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float logit_softcap,
         const uint3 ne01,
         const int ne02,
+        const int gqa_ratio,
         const int ne11,
         const int stride_Q1,
         const int stride_Q2,
         const int stride_K,
         const int stride_V,
         const int stride_mask,
         const int jt,
-        const int zt,
+        const int zt_gqa,
         const int kb0_start,
         const int kb0_stop) {
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1023,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int j = jc / ncols2;
         const int c = jc % ncols2;
 
-        if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
+        if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
 #pragma unroll
             for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
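// Aside (not part of the patch): a minimal host-side sketch of why the bound
// changes from ne02 to gqa_ratio. Each ncols2-wide tile now spans Q heads
// within a single K/V head's GQA group, so the last tile of a group must mask
// columns past the group size. The sizes below are assumptions for
// illustration only.
#include <cstdio>

int main() {
    const int gqa_ratio = 6, ncols2 = 4; // assumed: 6 Q heads per K/V head, tiles of 4
    const int iter_z_gqa = (gqa_ratio + ncols2 - 1) / ncols2;
    for (int zt_gqa = 0; zt_gqa < iter_z_gqa; ++zt_gqa) {
        for (int c = 0; c < ncols2; ++c) {
            // Same predicate as the kernel's bounds check above.
            const bool valid = zt_gqa*ncols2 + c < gqa_ratio;
            printf("zt_gqa=%d c=%d valid=%d\n", zt_gqa, c, (int) valid);
        }
    }
    return 0;
}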
@@ -1409,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int j_dst = jc_dst / ncols2;
         const int c_dst = jc_dst % ncols2;
 
-        if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
+        if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
             continue;
         }
 
@@ -1448,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #else
     GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
-                     scale, slope, logit_softcap, ne01, ne02,
+                     scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
                      stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
                      jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
@@ -1521,13 +1522,13 @@ static __global__ void flash_attn_ext_f16(
 
     const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
 
-    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
-    const int iter_z = (ne02   + (ncols2    - 1)) / ncols2;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01.z    + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
-    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
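// Aside (not part of the patch): how kbc splits the flattened
// (sequence, K/V head, GQA tile, j tile, k block) space across CUDA blocks.
// A host-side sketch with assumed sizes; iter_k, iter_j, iter_z_gqa, ne12 and
// ne03 mirror the kernel's names, nblocks stands in for gridDim.x.
#include <cstdint>
#include <cstdio>

int main() {
    const int iter_k = 8, iter_j = 4, iter_z_gqa = 2; // assumed tile counts
    const int ne12 = 8, ne03 = 1;                     // assumed K/V heads, sequences
    const int nblocks = 13;                           // assumed grid size

    const int total = iter_k*iter_j*iter_z_gqa*ne12*ne03;
    for (int b = 0; b < nblocks; ++b) {
        // Same proportional split as the kernel: block b owns [kbc, kbc_stop).
        const int kbc      = int64_t(b + 0)*total / nblocks;
        const int kbc_stop = int64_t(b + 1)*total / nblocks;
        printf("block %2d: kbc in [%3d, %3d)\n", b, kbc, kbc_stop);
    }
    return 0;
}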
@@ -1538,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int sequence = kbc / (iter_k*iter_j*iter_z);
-        const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
-        const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+        // z_KV == K/V head index, zt_gqa == Q head start index per K/V head, jt == token position start index.
+        const int sequence = kbc / (iter_k*iter_j*iter_z_gqa*ne12);
+        const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12*sequence) / (iter_k*iter_j*iter_z_gqa);
+        const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12*sequence - iter_k*iter_j*iter_z_gqa*z_KV) / (iter_k*iter_j);
+        const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12*sequence - iter_k*iter_j*iter_z_gqa*z_KV - iter_k*iter_j*zt_gqa) / iter_k;
 
-        const int head0 = zt*ncols2;
+        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*head0);
-        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
             (const half *) (mask + nb33*(sequence % ne33));
-        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-        const half2 * V_h2    = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+        const half2 * V_h2    = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
         if (KV_max) {
             kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
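// Aside (not part of the patch): the chained subtractions above are a
// mixed-radix decomposition of kbc. An equivalent div/mod form, shown as a
// host-side check with assumed sizes (one sequence, so ne03 is omitted):
#include <cassert>

int main() {
    const int iter_k = 8, iter_j = 4, iter_z_gqa = 2, ne12 = 8; // assumed

    for (int kbc = 0; kbc < iter_k*iter_j*iter_z_gqa*ne12; ++kbc) {
        // Innermost radix first: k block, then jt, then zt_gqa, then z_KV.
        const int jt     = (kbc / iter_k)                     % iter_j;
        const int zt_gqa = (kbc / (iter_k*iter_j))            % iter_z_gqa;
        const int z_KV   = (kbc / (iter_k*iter_j*iter_z_gqa)) % ne12;

        // Recompose; the k-block remainder kbc % iter_k is what the kernel
        // consumes via kb0_start/kb0_stop.
        const int back = ((z_KV*iter_z_gqa + zt_gqa)*iter_j + jt)*iter_k;
        assert(back == kbc - kbc % iter_k);
    }
    return 0;
}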
@@ -1563,12 +1566,12 @@ static __global__ void flash_attn_ext_f16(
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -1582,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int sequence = kbc / (iter_k*iter_j*iter_z);
-    const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
-    const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa == Q head start index per K/V head, jt == token position start index.
+    const int sequence = kbc / (iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12*sequence) / (iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12*sequence - iter_k*iter_j*iter_z_gqa*z_KV) / (iter_k*iter_j);
+    const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12*sequence - iter_k*iter_j*iter_z_gqa*z_KV - iter_k*iter_j*zt_gqa) / iter_k;
 
-    const int head0 = zt*ncols2;
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*head0);
-    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
     const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
         (const half *) (mask + nb33*(sequence % ne33));
-    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-    const half2 * V_h2    = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+    const half2 * V_h2    = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+    const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
     if (KV_max) {
         kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
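// Aside (not part of the patch): with this ordering, Q heads are enumerated
// grouped by the K/V head they attend to, so zt_Q = z_KV*gqa_ratio +
// zt_gqa*ncols2 is the first Q head of the tile. Sketch with assumed sizes
// (32 Q heads, 8 K/V heads -> gqa_ratio = 4, ncols2 = 2):
#include <cstdio>

int main() {
    const int ne12 = 8, gqa_ratio = 4, ncols2 = 2; // assumed
    const int iter_z_gqa = (gqa_ratio + ncols2 - 1) / ncols2;

    for (int z_KV = 0; z_KV < ne12; ++z_KV) {
        for (int zt_gqa = 0; zt_gqa < iter_z_gqa; ++zt_gqa) {
            const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // first Q head of tile
            printf("z_KV=%d zt_gqa=%d -> Q heads [%d, %d)\n",
                   z_KV, zt_gqa, zt_Q, zt_Q + ncols2);
        }
    }
    return 0;
}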
@@ -1607,7 +1612,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+         ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
                      max_bias, m0, m1, n_head_log2, logit_softcap,