From 496c69326ab78a9db14ef91a723e3c36083d9b54 Mon Sep 17 00:00:00 2001 From: JL-underdog Date: Wed, 15 Oct 2025 10:34:44 +0000 Subject: [PATCH 01/10] enable wholek_prefetch Signed-off-by: JL-underdog --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 ++ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 21 +++++++++++++------ .../pipeline/block_fmha_pipeline_enum.hpp | 8 +++++++ ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 2 +- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 802c9e51d79..cfea40e63ff 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -117,6 +117,7 @@ def get_mask_check_map(mask : str): "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_wholek_prefetch" : "ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch", } PIPELINE_ENUM_MAP = { @@ -126,6 +127,7 @@ def get_mask_check_map(mask : str): "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr_wholek_prefetch" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index da0c9ca9316..017d82f11f7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -244,14 +244,14 @@ def scheck(self) -> str: if self.pipeline_tag in ['qr_async', 'qr_async_trload']: if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs','qr_wholek_prefetch']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @property def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + if self.bm0 == 128 or self.bm0 == 64: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true else: return f'a.seqlen_q <= {self.bm0}' @@ -261,7 +261,7 @@ def skcheck(self) -> str: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs','qr_wholek_prefetch']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag == 'qr_async_trload': @@ -275,7 +275,7 @@ def dcheck(self) -> str: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload','qr_wholek_prefetch']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -287,7 +287,7 @@ def dvcheck(self) -> str: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload','qr_wholek_prefetch']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -543,6 +543,7 @@ def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1,CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], @@ -591,6 +592,9 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + if (hdim, hdim_v) == (128, 128) and logits == "f" and bias == "no" and dropout == "f": + pipelines.append(FmhaFwdPipeline('qr_wholek_prefetch', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_wholek_prefetch', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: @@ -628,7 +632,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.tag != 'qr_wholek_prefetch' and (pipeline.F_spad != 't' or pipeline.F_skpad != 't'): # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue if (hdim, hdim_v) == (192, 128): @@ -641,6 +646,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): continue + if pipeline.tag == 'qr_wholek_prefetch' and (((hdim, hdim_v) == (128, 128) and tile.F_bm0 == 128)): + continue + if pipeline.tag in ['qr_async','qr','qr_async_trload'] and (((hdim, hdim_v) == (128, 128) and tile.F_bm0 == 64)): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 45a1c8f4b87..e7437e701e8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_WHOLEK_PREFETCH, }; template @@ -39,4 +40,11 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qr_async_trload"; }; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_wholek_prefetch"; +}; + + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 074a94613c4..81644e102bc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -106,7 +106,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch } }(); - static constexpr const char* name = "qr_async"; + static constexpr const char* name = "qr_ks_vs_whole_k_prefetch"; using DropoutType = std::conditional_t; From 8799f16ad8167d27185ec5a9190b306685369390 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 30 Oct 2025 13:02:15 +0800 Subject: [PATCH 02/10] Update fmha_fwd.py --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b38306371e6..9100e4cadef 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -911,8 +911,6 @@ def get_fwd_blobs( continue if pipeline.tag == "qr_wholek_prefetch" and (((hdim, hdim_v) == (128, 128) and tile.F_bm0 == 128)): continue - if pipeline.tag in ['qr_async','qr','qr_async_trload'] and (((hdim, hdim_v) == (128, 128) and tile.F_bm0 == 64)): - continue # logits_soft_cap is only allowed if no bias if not ( (pipeline.F_logits == "t" and pipeline.F_bias == "no") From b47650cf4d6af6150a4f53f4c3ce7e4aa097b534 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 30 Oct 2025 14:54:58 +0800 Subject: [PATCH 03/10] Update block_fmha_pipeline_enum.hpp --- include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index e7437e701e8..c79dca9845a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -46,5 +46,4 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qr_wholek_prefetch"; }; - } // namespace ck_tile From 6c7e4630038ab7601fcee0b3e225a9ee28b5c727 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 15 Jan 2026 10:00:56 +0800 Subject: [PATCH 04/10] Update fmha_fwd.py --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 3e4df75c576..5b211bc92ca 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1323,8 +1323,10 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) - ): - if pipeline.tag == "qr_wholek_prefetch" and (((hdim, hdim_v) == (128, 128) and tile.F_bm0 == 128)): + ): + if pipeline.tag == "qr_wholek_prefetch" and ( + (hdim, hdim_v) == (128, 128) and tile.F_bm0 == 128 + ): continue problem_ctx = ProblemContext( dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v From e41098ecee604bff5a0142a13f3ff6f2ffe99592 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 14 Jan 2026 20:20:00 -0600 Subject: [PATCH 05/10] replace squant to qscale Signed-off-by: Linjun-AMD --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 64 +++++++++++++++++-- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 1da5a4fbe34..22bec5de1b9 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -135,7 +135,7 @@ def get_mask_cpp_check_expr(mask: str) -> str: "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", - "qr_wholek_prefetch" : "ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch", + "qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch", } PIPELINE_ENUM_MAP = { @@ -146,7 +146,7 @@ def get_mask_cpp_check_expr(mask: str) -> str: "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", - "qr_wholek_prefetch" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH", + "qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 5b211bc92ca..ea54f87a1cc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -350,7 +350,13 @@ def dcheck(self) -> str: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3", "qr_wholek_prefetch"]: + elif self.pipeline_tag in [ + "qr", + "qs", + "qr_async_trload", + "qr_async_trload_v3", + "qr_wholek_prefetch", + ]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -367,7 +373,13 @@ def dvcheck(self) -> str: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3", "qr_wholek_prefetch"]: + elif self.pipeline_tag in [ + "qr", + "qs", + "qr_async_trload", + "qr_async_trload_v3", + "qr_wholek_prefetch", + ]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -1013,9 +1025,51 @@ def get_pipelines( else: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - if (hdim, hdim_v) == (128, 128) and logits == "f" and bias == "no" and dropout == "f" and sink == "f": - pipelines.append(FmhaFwdPipeline("qr_wholek_prefetch", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, 'f', sink)) - pipelines.append(FmhaFwdPipeline("qr_wholek_prefetch", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, 'f', sink)) + if ( + (hdim, hdim_v) == (128, 128) + and logits == "f" + and bias == "no" + and dropout == "f" + and sink == "f" + ): + pipelines.append( + FmhaFwdPipeline( + "qr_wholek_prefetch", + "row", + "f", + "f", + "f", + "f", + logits, + bias, + lse, + dropout, + qscale, + mask, + skip, + "f", + sink, + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_wholek_prefetch", + "row", + "f", + "t", + "f", + "f", + logits, + bias, + lse, + dropout, + qscale, + mask, + skip, + "f", + sink, + ) + ) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: From d43ea4a69cf341f3268004a2955a490754b8085f Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 15 Jan 2026 13:15:45 +0800 Subject: [PATCH 06/10] Update fmha_fwd.py --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ea54f87a1cc..498d9df5f54 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1051,25 +1051,6 @@ def get_pipelines( sink, ) ) - pipelines.append( - FmhaFwdPipeline( - "qr_wholek_prefetch", - "row", - "f", - "t", - "f", - "f", - logits, - bias, - lse, - dropout, - qscale, - mask, - skip, - "f", - sink, - ) - ) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: From 7c44495ca447bdc3b3d80ac526071e5d4d7099eb Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 15 Jan 2026 15:14:58 +0800 Subject: [PATCH 07/10] Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 498d9df5f54..2593b43d332 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,7 @@ def scheck(self) -> str: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0 or self.bm0 == 64: + if self.bm0 == max_bm0 or self.bm0 == 64: # 64 is the smallest bm0 tile (used e.g. by wholek_prefetch) and serves as a generic fallback return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" From 8ec3ca31783d31dfd286a1f3de17184772434a44 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 15 Jan 2026 15:15:15 +0800 Subject: [PATCH 08/10] Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index ffbddddb55c..0f6be701464 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -108,7 +108,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch } }(); - static constexpr const char* name = "qr_ks_vs_whole_k_prefetch"; + static constexpr const char* name = "qr_wholek_prefetch"; using DropoutType = std::conditional_t; From ecc0d6eacecb801dc2e6bf384d23ae3beb14cc76 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 15 Jan 2026 15:15:33 +0800 Subject: [PATCH 09/10] Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 2593b43d332..0af2dd4ec1b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -954,7 +954,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1,CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')), + FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], From 19a3ab1c7320491ebbd2b4f122db2faa896c0968 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 15 Jan 2026 15:19:31 +0800 Subject: [PATCH 10/10] Improve readability of seqtune method Refactor seqtune method for better readability. --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 0af2dd4ec1b..f2c7be7f004 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,9 @@ def scheck(self) -> str: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0 or self.bm0 == 64: # 64 is the smallest bm0 tile (used e.g. by wholek_prefetch) and serves as a generic fallback + if ( + self.bm0 == max_bm0 or self.bm0 == 64 + ): # 64 is the smallest bm0 tile (used e.g. by wholek_prefetch) and serves as a generic fallback return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}"