Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/ck_tile/01_fmha/fmha_bwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::long_index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
Expand All @@ -202,7 +202,7 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::long_index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
Expand Down
16 changes: 8 additions & 8 deletions example/ck_tile/01_fmha/fmha_bwd_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this dq_acc layout change from (nsplit, B, H, S, D) to (B, H, nsplit, S, D) due to the fact that only batch_stride_dq_acc and nhead_stride_dq_acc are promoted to long_index_t?

In fact, in order to satisfy some aiter asm bwd requirement, our TE bwd had to make dq_acc layout as (nsplits, B, H, S, D). Can we also promote the nsplit_stride_dq_acc as well?

Copy link
Contributor Author

@DDEle DDEle Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Promoting nsplits_stride_dq_acc would introduce extra overhead as a thread block of the reduction & convert kernel needs access multiple splits in its hotloop (thus it is part of the ck tile window rather than simple pointer arithmetic at the beginning of the kernel).

The layout of dq_acc is kind of in the scope of kernel implementation and I guess it's better to change it accordingly.


if(init_method == "ui" || init_method == "0")
{
Expand Down Expand Up @@ -433,6 +431,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
Expand All @@ -444,6 +443,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
const auto nhead_stride_dq_acc =
static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
Expand All @@ -456,8 +457,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;

const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
Expand Down Expand Up @@ -513,7 +513,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
hdim_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
Expand All @@ -526,7 +526,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_dq_acc,
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
Expand All @@ -539,7 +539,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_dq_acc,
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
Expand Down
20 changes: 10 additions & 10 deletions include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::long_index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
};
Expand Down Expand Up @@ -294,7 +294,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::long_index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
};
Expand Down Expand Up @@ -377,7 +377,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
Expand All @@ -388,7 +388,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::long_index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
Expand Down Expand Up @@ -549,7 +549,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
Expand Down Expand Up @@ -1574,7 +1574,7 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::long_index_t nhead_stride_dq_acc;
};

struct FmhaBwdConvertQGradDeterministicKargs
Expand All @@ -1589,7 +1589,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs<0>>
{
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::long_index_t batch_stride_dq_acc;
};

struct FmhaBwdConvertQGradGroupModeKargs
Expand Down Expand Up @@ -1620,9 +1620,9 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::long_index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
Expand Down Expand Up @@ -1660,7 +1660,7 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
Expand Down