Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
6862777
add block scale parameters to kernel
ltqin Nov 28, 2025
7260af6
add block scale to kernel
ltqin Nov 30, 2025
39104a0
add smoke test
ltqin Dec 1, 2025
356c3c9
format
ltqin Dec 1, 2025
42b5aa4
Revert "format"
ltqin Dec 1, 2025
2954847
only format my code
ltqin Dec 1, 2025
49a280b
format py
ltqin Dec 1, 2025
fa90ef7
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Dec 1, 2025
326d8c3
fix auto not allowed in function prototype
ltqin Dec 1, 2025
43feee2
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Dec 1, 2025
4484de9
change instance tttt to ttff
ltqin Dec 2, 2025
2709f16
fix structured binding issue
ltqin Dec 2, 2025
259fd92
change s_acc elementwise op
ltqin Dec 3, 2025
0d462a2
async pipeline add block scale
ltqin Dec 3, 2025
01b232c
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Dec 4, 2025
ca7d542
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Dec 18, 2025
163fd5e
add quantization P using shift exp2
ltqin Jan 6, 2026
01ccfc5
precompute (m - shift) once per row
ltqin Jan 6, 2026
806e96b
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Jan 6, 2026
b08c25b
change blk scale seqstrt ptr name
ltqin Jan 7, 2026
86ad0d1
fix some name
ltqin Jan 7, 2026
9506bbc
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Jan 15, 2026
99b720d
fix for deduction guide
ltqin Jan 15, 2026
6d61a5e
fix some comments
ltqin Jan 15, 2026
bbe780e
add P scale to qr_ksvs_pipeline
ltqin Jan 16, 2026
e8c31d7
add comment to idx_identity
ltqin Jan 16, 2026
fb9a3f2
change the method of calculating descale block index
ltqin Jan 16, 2026
162cda7
unify naming style: use block_scale_ as name prefix
ltqin Jan 16, 2026
6d19a46
unify naming style
ltqin Jan 16, 2026
1e95db8
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
ltqin Jan 19, 2026
142abcb
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
illsilin Jan 19, 2026
1995182
Merge branch 'develop' into ck_tile/fmha_fwd_block_scale
poyenc Jan 21, 2026
536cc56
update the CHANGELOG.md
ltqin Jan 21, 2026
1df0a12
Add FP8 block scale quantization support for FMHA forward kernel
ltqin Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, including the qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
* Added FP8 block scale quantization for FMHA forward kernel.

### Changed

Expand Down
2 changes: 2 additions & 0 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str:
# Lookup table from a quant-scale option name ("no", "pertensor",
# "blockscale") to the fully qualified name of the corresponding
# ck_tile::BlockAttentionQuantScaleEnum enumerator, stored as a string.
# NOTE(review): presumably these strings are pasted verbatim into generated
# C++ kernel instances -- confirm at the use site.
QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
}

# Lookup table from a quant-scale option name to the matching
# quant_scale_enum enumerator name, stored as a string. The keys mirror those
# of QSCALE_MAP, so both tables must be extended together when a new mode is
# added.
# NOTE(review): presumably used to emit the runtime dispatch condition in the
# generated C++ -- confirm at the use site.
QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
}

BIAS_MAP = {
Expand Down
7 changes: 5 additions & 2 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ def get_pipelines(
# no need lse/dropout kernels
for logits, qscale, mask, bias, sink in itertools.product(
["t", "f"],
["no", "pertensor"],
["no", "pertensor", "blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
Expand Down Expand Up @@ -1146,7 +1146,10 @@ def get_pipelines(
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
["f"],
["no", "pertensor", "blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
Expand Down
26 changes: 26 additions & 0 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* block_scale_seqstart_q_ptr;
const void* block_scale_seqstart_k_ptr;
const void* sink_ptr;

ck_tile::index_t seqlen_q;
Expand Down Expand Up @@ -257,13 +259,19 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale;

ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
Expand All @@ -276,6 +284,9 @@ struct fmha_fwd_args

std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;

ck_tile::index_t block_scale_size_q;
ck_tile::index_t block_scale_size_kv;
};

struct fmha_fwd_pagedkv_args
Expand Down Expand Up @@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.block_scale_seqstart_q_ptr,
args.block_scale_seqstart_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
Expand All @@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.sink_size,
Expand All @@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_size_q,
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
Expand Down Expand Up @@ -679,20 +697,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.batch_stride_q_descale,
args.batch_stride_k_descale,
args.batch_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_size_q,
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
Expand Down
Loading