-
Notifications
You must be signed in to change notification settings - Fork 42
Add CK-free fallback for fused QKNorm+RoPE+Cache #279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| # SPDX-License-Identifier: MIT | ||
| # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| import logging | ||
| from typing import Optional | ||
|
|
||
| import aiter | ||
|
|
@@ -15,6 +16,8 @@ | |
|
|
||
| from .attention_mla import MLAModules | ||
|
|
||
| logger = logging.getLogger("atom") | ||
|
|
||
| from atom.plugin.prepare import is_plugin_mode, is_vllm | ||
| from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode | ||
|
|
||
|
|
@@ -122,38 +125,58 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): | |
| use_triton_attn = self.sliding_window != -1 or self.head_dim != 128 | ||
| self.use_triton_attn = use_triton_attn | ||
|
|
||
| _fused_ok = False | ||
| if ( | ||
| self.rotary_emb is not None | ||
| and self.q_norm is not None | ||
| and self.k_norm is not None | ||
| ): | ||
| fused_qk_norm_rope_cache_quant_shuffle( | ||
| qkv, | ||
| num_heads_q=self.num_heads, | ||
| num_heads_k=self.num_kv_heads, | ||
| num_heads_v=self.num_kv_heads, | ||
| head_dim=self.head_dim, | ||
| eps=self.q_norm.eps, | ||
| qw=self.q_norm.weight, | ||
| kw=self.k_norm.weight, | ||
| cos_sin_cache=self.rotary_emb.cos_sin_cache, | ||
| is_neox_style=self.rotary_emb.is_neox_style, | ||
| pos_ids=position, | ||
| k_cache=k_cache, | ||
| v_cache=v_cache, | ||
| slot_mapping=attn_metadata.slot_mapping, | ||
| kv_cache_dtype=( | ||
| "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype | ||
| ), | ||
| k_scale=k_scale, | ||
| v_scale=v_scale, | ||
| ) | ||
| qkv_backup = qkv.clone() | ||
| try: | ||
|
Comment on lines
133
to
+135
|
||
| fused_qk_norm_rope_cache_quant_shuffle( | ||
| qkv, | ||
|
Comment on lines
+134
to
+137
|
||
| num_heads_q=self.num_heads, | ||
| num_heads_k=self.num_kv_heads, | ||
| num_heads_v=self.num_kv_heads, | ||
| head_dim=self.head_dim, | ||
| eps=self.q_norm.eps, | ||
| qw=self.q_norm.weight, | ||
| kw=self.k_norm.weight, | ||
| cos_sin_cache=self.rotary_emb.cos_sin_cache, | ||
| is_neox_style=self.rotary_emb.is_neox_style, | ||
| pos_ids=position, | ||
| k_cache=k_cache, | ||
| v_cache=v_cache, | ||
| slot_mapping=attn_metadata.slot_mapping, | ||
| kv_cache_dtype=( | ||
| "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype | ||
| ), | ||
| k_scale=k_scale, | ||
| v_scale=v_scale, | ||
| ) | ||
|
|
||
| qkv = qkv.view(qkv.shape[0], -1, self.head_dim) | ||
| q, k, v = qkv.split( | ||
| [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 | ||
| ) | ||
| elif use_triton_attn and self.rotary_emb is not None: | ||
| qkv = qkv.view(qkv.shape[0], -1, self.head_dim) | ||
| q, k, v = qkv.split( | ||
| [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 | ||
| ) | ||
| _fused_ok = True | ||
| except Exception as e: | ||
| if not getattr(PagedAttentionImpl, "_fused_rope_warned", False): | ||
| logger.warning( | ||
| "fused_qk_norm_rope_cache_quant_shuffle failed (%s), " | ||
| "falling back to non-fused path", | ||
| e, | ||
| ) | ||
| PagedAttentionImpl._fused_rope_warned = True | ||
| qkv.copy_(qkv_backup) | ||
| del qkv_backup | ||
|
|
||
| if ( | ||
| not _fused_ok | ||
| and use_triton_attn | ||
| and self.rotary_emb is not None | ||
| and self.q_norm is None | ||
| ): | ||
| k_scale = v_scale = self.kv_scale | ||
|
|
||
| q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( | ||
|
|
@@ -176,7 +199,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): | |
| k_out=k, | ||
| output_zeros=False, | ||
| ) | ||
| else: | ||
| elif not _fused_ok: | ||
| # for asm paged attention | ||
| asm_layout = True | ||
| if use_triton_attn: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logger = ...is executed before the subsequentfrom atom.plugin...imports, which will trigger Ruff E402 (module level import not at top of file). Move the logger initialization below all imports (or move these imports above the logger assignment) so lint passes.