-
Notifications
You must be signed in to change notification settings - Fork 267
[CK_TILE][FMHA] Fix uninitialized sink_size in mask_info::decode() and filter redundant no-mask+sink instances #3504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
9a47f21
57d5599
9de34c9
84c5f7b
b7da954
6542982
f20b16e
e459f43
692ffd9
d739507
2cc4fea
7975c94
4c37878
7425250
26f3a0b
ac46bdc
5844a1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -880,6 +880,11 @@ def get_fwd_splitkv_blobs( | |||||
|
|
||||||
| factories = get_factories_for_targets(targets, get_factory) | ||||||
|
|
||||||
| all_mask_keys = list(get_mask_map("simplified").keys()) + list( | ||||||
| get_mask_map("generic").keys() | ||||||
| ) | ||||||
| no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key] | ||||||
|
|
||||||
| for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): | ||||||
| d = factory.get_hdim_tile_size_dict(dtype) | ||||||
| if d is None: | ||||||
|
|
@@ -899,6 +904,10 @@ def get_fwd_splitkv_blobs( | |||||
| or pipeline.F_logits == "f" | ||||||
| ): | ||||||
| continue | ||||||
| # sink_size is only meaningful when mask is applied | ||||||
|
||||||
| # sink_size is only meaningful when mask is applied | |
| # sink_size is only meaningful when masking is enabled, so disallow sink when no mask is applied |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -642,6 +642,11 @@ def get_fwd_blobs( | |||||
|
|
||||||
| factories = get_factories_for_targets(targets, get_factory) | ||||||
|
|
||||||
| all_mask_keys = list(get_mask_map("simplified").keys()) + list( | ||||||
| get_mask_map("generic").keys() | ||||||
| ) | ||||||
| no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key] | ||||||
|
|
||||||
| for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): | ||||||
| d = factory.get_hdim_tile_size_dict(dtype) | ||||||
| if d is None: | ||||||
|
|
@@ -666,6 +671,10 @@ def get_fwd_blobs( | |||||
| or pipeline.F_logits == "f" | ||||||
| ): | ||||||
| continue | ||||||
| # sink_size is only meaningful when mask is applied | ||||||
|
||||||
| # sink_size is only meaningful when mask is applied | |
| # sink_size is only meaningful when masking is enabled; disallow sink when no mask is used |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,7 @@ struct BlockFmhaPipelineProblem | |
| static constexpr auto QScaleEnum = Traits::QScaleEnum; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); | ||
|
||
| }; | ||
|
|
||
| template <typename QDataType_, | ||
|
|
@@ -174,6 +175,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem | |
| static constexpr bool kIsPagedKV = Traits::kIsPagedKV; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); | ||
|
||
| }; | ||
|
|
||
| template <typename QDataType_, | ||
|
|
@@ -228,6 +230,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem | |
| static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); | ||
|
||
| }; | ||
|
|
||
| // extract tile size attributes to remove dependency on traits | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).