Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
from typing import Optional

import aiter
Expand All @@ -15,6 +16,8 @@

from .attention_mla import MLAModules

logger = logging.getLogger("atom")

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode
Comment on lines +19 to 22
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`logger = ...` is executed before the subsequent `from atom.plugin...` imports, which will trigger Ruff E402 (module level import not at top of file). Move the logger initialization below all imports (or move these imports above the logger assignment) so lint passes.

Copilot uses AI. Check for mistakes.

Expand Down Expand Up @@ -122,38 +125,58 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
use_triton_attn = self.sliding_window != -1 or self.head_dim != 128
self.use_triton_attn = use_triton_attn

_fused_ok = False
if (
self.rotary_emb is not None
and self.q_norm is not None
and self.k_norm is not None
):
fused_qk_norm_rope_cache_quant_shuffle(
qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.q_norm.eps,
qw=self.q_norm.weight,
kw=self.k_norm.weight,
cos_sin_cache=self.rotary_emb.cos_sin_cache,
is_neox_style=self.rotary_emb.is_neox_style,
pos_ids=position,
k_cache=k_cache,
v_cache=v_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=(
"auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
),
k_scale=k_scale,
v_scale=v_scale,
)
qkv_backup = qkv.clone()
try:
Comment on lines 133 to +135
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`qkv_backup = qkv.clone()` is outside the `try:` and assumes `qkv` is always a Tensor. Since `qkv` is an optional argument, a `None` (or clone failure) would raise before reaching the `except`, preventing the intended fallback. Consider adding `qkv is not None` to the fused-path condition and/or moving backup creation inside the `try:` with a guarded restore only when the backup exists.

Copilot uses AI. Check for mistakes.
fused_qk_norm_rope_cache_quant_shuffle(
qkv,
Comment on lines +134 to +137
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On CK-free builds where the fused kernel consistently fails, this will still `clone()` and then raise/catch an exception on every `rope_cache()` call, which is very expensive on the decode hot path. Cache the failure (e.g., a class/instance flag) and skip the fused attempt entirely after the first failure; similarly, after the first successful fused call you can avoid taking a backup on subsequent calls.

Copilot uses AI. Check for mistakes.
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.q_norm.eps,
qw=self.q_norm.weight,
kw=self.k_norm.weight,
cos_sin_cache=self.rotary_emb.cos_sin_cache,
is_neox_style=self.rotary_emb.is_neox_style,
pos_ids=position,
k_cache=k_cache,
v_cache=v_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=(
"auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
),
k_scale=k_scale,
v_scale=v_scale,
)

qkv = qkv.view(qkv.shape[0], -1, self.head_dim)
q, k, v = qkv.split(
[self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
)
elif use_triton_attn and self.rotary_emb is not None:
qkv = qkv.view(qkv.shape[0], -1, self.head_dim)
q, k, v = qkv.split(
[self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
)
_fused_ok = True
except Exception as e:
if not getattr(PagedAttentionImpl, "_fused_rope_warned", False):
logger.warning(
"fused_qk_norm_rope_cache_quant_shuffle failed (%s), "
"falling back to non-fused path",
e,
)
PagedAttentionImpl._fused_rope_warned = True
qkv.copy_(qkv_backup)
del qkv_backup

if (
not _fused_ok
and use_triton_attn
and self.rotary_emb is not None
and self.q_norm is None
):
k_scale = v_scale = self.kv_scale

q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache(
Expand All @@ -176,7 +199,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
k_out=k,
output_zeros=False,
)
else:
elif not _fused_ok:
# for asm paged attention
asm_layout = True
if use_triton_attn:
Expand Down
1 change: 0 additions & 1 deletion atom/model_ops/attentions/aiter_attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import itertools
from typing import Type

import aiter
Expand Down