Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,11 @@ def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
False
return True

all_mask_keys = list(get_mask_map("simplified").keys()) + list(
get_mask_map("generic").keys()
)
no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key]

def check_feature(
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
Expand All @@ -821,6 +826,12 @@ def check_feature(
or kernel_ctx.pipeline.F_logits == "f"
):
return False
# sink_size is only meaningful when mask is applied
if (
kernel_ctx.pipeline.F_mask in no_mask_keys
and kernel_ctx.pipeline.F_sink == "t"
):
return False
Comment on lines +829 to +834
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).

Copilot uses AI. Check for mistakes.
return True

return [check_mode, check_hdim, check_feature]
Expand Down
9 changes: 9 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,11 @@ def get_fwd_splitkv_blobs(

factories = get_factories_for_targets(targets, get_factory)

all_mask_keys = list(get_mask_map("simplified").keys()) + list(
get_mask_map("generic").keys()
)
no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key]

for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
Expand All @@ -899,6 +904,10 @@ def get_fwd_splitkv_blobs(
or pipeline.F_logits == "f"
):
continue
# sink_size is only meaningful when mask is applied
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).

Suggested change
# sink_size is only meaningful when mask is applied
# sink_size is only meaningful when masking is enabled, so disallow sink when no mask is applied

Copilot uses AI. Check for mistakes.
if pipeline.F_mask in no_mask_keys and pipeline.F_sink == "t":
continue

k = Kernel(
F_arch=factory.arch,
F_idx=0,
Expand Down
9 changes: 9 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,11 @@ def get_fwd_blobs(

factories = get_factories_for_targets(targets, get_factory)

all_mask_keys = list(get_mask_map("simplified").keys()) + list(
get_mask_map("generic").keys()
)
no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key]

for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
Expand All @@ -666,6 +671,10 @@ def get_fwd_blobs(
or pipeline.F_logits == "f"
):
continue
# sink_size is only meaningful when mask is applied
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).

Suggested change
# sink_size is only meaningful when mask is applied
# sink_size is only meaningful when masking is enabled; disallow sink when no mask is used

Copilot uses AI. Check for mistakes.
if pipeline.F_mask in no_mask_keys and pipeline.F_sink == "t":
continue

k = FmhaFwdKernel(
F_arch=factory.arch,
F_idx=0,
Expand Down
5 changes: 4 additions & 1 deletion example/ck_tile/01_fmha/mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ struct mask_info
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
tmp.type = mask_enum::no_mask;
tmp.left = -1;
tmp.right = -1;
tmp.sink = 0;
}
else if(str == "1" || str == "t")
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct BlockFmhaPipelineProblem
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
static_assert(FmhaMask::IsMasking || !kHasSink);
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.

Copilot uses AI. Check for mistakes.
};

template <typename QDataType_,
Expand Down Expand Up @@ -174,6 +175,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
static_assert(FmhaMask::IsMasking || !kHasSink);
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.

Copilot uses AI. Check for mistakes.
};

template <typename QDataType_,
Expand Down Expand Up @@ -228,6 +230,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
static_assert(FmhaMask::IsMasking || !kHasSink);
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.

Copilot uses AI. Check for mistakes.
};

// extract tile size attributes to remove dependency on traits
Expand Down