Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def scheck(self) -> str:
assert False

def seqtune(self, max_bm0: int) -> str:
if self.bm0 == max_bm0 or self.bm0 == 64:
if self.bm0 == max_bm0:
return "true/*fall back to largest tile*/"
else:
return f"a.seqlen_q <= {self.bm0}"
Expand Down Expand Up @@ -847,11 +847,6 @@ def check_hdim_tile(
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
and kernel_ctx.tile.F_bm0 != 128
)
or (
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
and kernel_ctx.pipeline.tag != "qr_async"
and kernel_ctx.tile.F_bk0 == 64
)
):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
Expand Down Expand Up @@ -947,7 +942,6 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ struct BlockFmhaPipelineQRKSVSAsync
{
if(num_total_loop <= 0)
{
buffer_load_fence(0); // rocm-7.1.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
if constexpr(kStoreLSE)
{
auto lse =
Expand Down