Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
"qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch",
}

PIPELINE_ENUM_MAP = {
Expand All @@ -145,6 +146,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
"qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH",
}

BOOL_MAP = {
Expand Down
55 changes: 50 additions & 5 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def scheck(self) -> str:
return "true" # always support
else:
return "true"
elif self.pipeline_tag in ["qr", "qs"]:
elif self.pipeline_tag in ["qr", "qs", "qr_wholek_prefetch"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
Expand All @@ -315,7 +315,9 @@ def scheck(self) -> str:
assert False

def seqtune(self, max_bm0: int) -> str:
if self.bm0 == max_bm0:
if (
self.bm0 == max_bm0 or self.bm0 == 64
): # 64 is the smallest bm0 tile (used e.g. by wholek_prefetch) and serves as a generic fallback
return "true/*fall back to largest tile*/"
else:
return f"a.seqlen_q <= {self.bm0}"
Expand All @@ -329,7 +331,7 @@ def skcheck(self) -> str:
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
elif self.pipeline_tag in ["qr", "qs", "qr_wholek_prefetch"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
Expand All @@ -350,7 +352,13 @@ def dcheck(self) -> str:
return f"a.hdim_q % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
elif self.pipeline_tag in [
"qr",
"qs",
"qr_async_trload",
"qr_async_trload_v3",
"qr_wholek_prefetch",
]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == "t":
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
Expand All @@ -367,7 +375,13 @@ def dvcheck(self) -> str:
return f"a.hdim_v % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
elif self.pipeline_tag in [
"qr",
"qs",
"qr_async_trload",
"qr_async_trload_v3",
"qr_wholek_prefetch",
]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == "t":
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
Expand Down Expand Up @@ -942,6 +956,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
Expand Down Expand Up @@ -1012,6 +1027,32 @@ def get_pipelines(
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
if (
(hdim, hdim_v) == (128, 128)
and logits == "f"
and bias == "no"
and dropout == "f"
and sink == "f"
):
pipelines.append(
FmhaFwdPipeline(
"qr_wholek_prefetch",
"row",
"f",
Copy link
Contributor

@poyenc poyenc Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add padding for both the seqlen_q and seqlen_k dimensions?

"f",
"f",
"f",
logits,
bias,
lse,
dropout,
qscale,
mask,
skip,
"f",
sink,
)
)
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
Expand Down Expand Up @@ -1320,6 +1361,10 @@ def get_fwd_blobs(
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if pipeline.tag == "qr_wholek_prefetch" and (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this check into CompatibilityRuleFactoryGfx9.

(hdim, hdim_v) == (128, 128) and tile.F_bm0 == 128
):
continue
problem_ctx = ProblemContext(
dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum
QRKSVS_ASYNC,
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_WHOLEK_PREFETCH,
QRKSVS_ASYNC_TRLOAD_V3,
};

Expand Down Expand Up @@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
static constexpr const char* name = "qr_async_trload";
};

// Specialization that maps BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH to its
// string name, "qr_wholek_prefetch".
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH>
{
static constexpr const char* name = "qr_wholek_prefetch";
};

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
}
}();

static constexpr const char* name = "qr_async";
static constexpr const char* name = "qr_wholek_prefetch";

using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

Expand Down