diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index a3cfe2622a..22bec5de1b 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -135,6 +135,7 @@ def get_mask_cpp_check_expr(mask: str) -> str: "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", + "qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch", } PIPELINE_ENUM_MAP = { @@ -145,6 +146,7 @@ def get_mask_cpp_check_expr(mask: str) -> str: "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", + "qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..f2c7be7f00 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -306,7 +306,7 @@ def scheck(self) -> str: return "true" # always support else: return "true" - elif self.pipeline_tag in ["qr", "qs"]: + elif self.pipeline_tag in ["qr", "qs", "qr_wholek_prefetch"]: if self.spad == "t": return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: @@ -315,7 +315,9 @@ def scheck(self) -> str: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0: + if ( + self.bm0 == max_bm0 or self.bm0 == 64 + ): # 64 is the smallest bm0 tile (used e.g. by wholek_prefetch) and serves as a generic fallback return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" @@ -329,7 +331,7 @@ def skcheck(self) -> str: return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag in ["qr", "qs"]: + elif self.pipeline_tag in ["qr", "qs", "qr_wholek_prefetch"]: if self.skpad == "t": return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: @@ -350,7 +352,13 @@ def dcheck(self) -> str: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: + elif self.pipeline_tag in [ + "qr", + "qs", + "qr_async_trload", + "qr_async_trload_v3", + "qr_wholek_prefetch", + ]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -367,7 +375,13 @@ def dvcheck(self) -> str: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: + elif self.pipeline_tag in [ + "qr", + "qs", + "qr_async_trload", + "qr_async_trload_v3", + "qr_wholek_prefetch", + ]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -942,6 +956,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], @@ -1012,6 +1027,32 @@ def get_pipelines( else: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + if ( + (hdim, hdim_v) == (128, 128) + and logits == "f" + and bias == "no" + and dropout == "f" + and sink == "f" + ): + pipelines.append( + FmhaFwdPipeline( + "qr_wholek_prefetch", + "row", + "f", + "f", + "f", + "f", + logits, + bias, + lse, + dropout, + qscale, + mask, + skip, + "f", + sink, + ) + ) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: @@ -1320,6 +1361,10 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): + if pipeline.tag == "qr_wholek_prefetch" and ( + (hdim, hdim_v) == (128, 128) and tile.F_bm0 == 128 + ): + continue problem_ctx = ProblemContext( dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v ) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 659bdd995b..a5a9149cd0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_WHOLEK_PREFETCH, QRKSVS_ASYNC_TRLOAD_V3, }; @@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qr_async_trload"; }; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_wholek_prefetch"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 8114bb96c4..0f6be70146 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -108,7 +108,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch } }(); - static constexpr const char* name = "qr_async"; + static constexpr const char* name = "qr_wholek_prefetch"; using DropoutType = std::conditional_t;