From d321f3a076e0ac0f274413658a89fefb59c54cc9 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Tue, 20 Jan 2026 05:18:11 +0000 Subject: [PATCH] [CK_TILE] Fix Int32 Overflow in Deterministic FMHA BWD --- example/ck_tile/01_fmha/fmha_bwd.hpp | 4 ++-- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 16 +++++++-------- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 20 +++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index d1b55168e35..180d039cd44 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -189,7 +189,7 @@ struct fmha_bwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; - ck_tile::index_t nhead_stride_dq_acc; + ck_tile::long_index_t nhead_stride_dq_acc; ck_tile::index_t nhead_stride_dq; ck_tile::index_t nhead_stride_dk; ck_tile::index_t nhead_stride_dv; @@ -202,7 +202,7 @@ struct fmha_bwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; - ck_tile::index_t batch_stride_dq_acc; + ck_tile::long_index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dq; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index d62b908e33f..f41f0668e54 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -287,9 +287,7 @@ bwd_result fmha_bwd_run(mode_enum mode, ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor dq_acc_host( - i_perm - ? 
std::array{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q} - : std::array{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q}); + std::array{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q}); if(init_method == "ui" || init_method == "0") { @@ -433,6 +431,7 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); + const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -444,6 +443,8 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q; const ck_tile::index_t nhead_stride_dbias = (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); + const auto nhead_stride_dq_acc = + static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); @@ -456,8 +457,7 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t split_stride_dq_acc = - (shape_batch * nhead * shape_seqlen_q * hdim_q); + const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc; const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { if(drop_prefs) @@ -513,7 +513,7 @@ bwd_result fmha_bwd_run(mode_enum mode, stride_o, stride_randval, stride_do, - stride_q, // stride_dq_acc + hdim_q, // 
stride_dq_acc stride_q, // stride_dq stride_dk, stride_dv, @@ -526,7 +526,7 @@ bwd_result fmha_bwd_run(mode_enum mode, nhead_stride_randval, nhead_stride_do, nhead_stride_lsed, - nhead_stride_q, // nhead_stride_dq_acc + nhead_stride_dq_acc, nhead_stride_q, // nhead_stride_dq nhead_stride_k, // nhead_stride_dk nhead_stride_v, // nhead_stride_dv @@ -539,7 +539,7 @@ bwd_result fmha_bwd_run(mode_enum mode, batch_stride_randval, batch_stride_do, batch_stride_lsed, - batch_stride_q, // batch_stride_dq_acc + batch_stride_dq_acc, batch_stride_q, // batch_stride_dq batch_stride_dk, batch_stride_dv, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 5b491465b38..06b0d76a0d6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -171,7 +171,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; - ck_tile::index_t nhead_stride_dq_acc; + ck_tile::long_index_t nhead_stride_dq_acc; ck_tile::index_t nhead_stride_dk; ck_tile::index_t nhead_stride_dv; }; @@ -294,7 +294,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; - ck_tile::index_t batch_stride_dq_acc; + ck_tile::long_index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; }; @@ -377,7 +377,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, + ck_tile::long_index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, @@ -388,7 +388,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, - 
ck_tile::index_t batch_stride_dq_acc, + ck_tile::long_index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, @@ -549,7 +549,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, + ck_tile::long_index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, @@ -1574,7 +1574,7 @@ struct FmhaBwdConvertQGradKernel ck_tile::index_t stride_dq; ck_tile::index_t stride_dq_acc; ck_tile::index_t nhead_stride_dq; - ck_tile::index_t nhead_stride_dq_acc; + ck_tile::long_index_t nhead_stride_dq_acc; }; struct FmhaBwdConvertQGradDeterministicKargs @@ -1589,7 +1589,7 @@ struct FmhaBwdConvertQGradKernel FmhaBwdConvertQGradEmptyKargs<0>> { ck_tile::index_t batch_stride_dq; - ck_tile::index_t batch_stride_dq_acc; + ck_tile::long_index_t batch_stride_dq_acc; }; struct FmhaBwdConvertQGradGroupModeKargs @@ -1620,9 +1620,9 @@ struct FmhaBwdConvertQGradKernel ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, - ck_tile::index_t nhead_stride_dq_acc, + ck_tile::long_index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, - ck_tile::index_t batch_stride_dq_acc, + ck_tile::long_index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc) { Kargs kargs{{dq_acc_ptr, @@ -1660,7 +1660,7 @@ struct FmhaBwdConvertQGradKernel ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, - ck_tile::index_t nhead_stride_dq_acc, + ck_tile::long_index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc) { Kargs kargs{{dq_acc_ptr,