Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,22 +621,29 @@ def __post_init__(self):
# assert os.path.isdir(self.model)

assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
if is_plugin_mode():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we follow the atom config post-init here?

# plugin mode
assert (
self.plugin_config is not None
), "plugin_config is required in plugin mode"
self.hf_config = self.plugin_config.model_config.hf_config
else:
self.hf_config = get_hf_config(self.model)

self.generation_config = get_generation_config(self.model)
if self.generation_config is not None:
if (
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = (
[eos_ids] if isinstance(eos_ids, int) else eos_ids
)
if not hasattr(self.hf_config, "rope_parameters"):
# Compatible with both transformers < 5
rope_params = getattr(self.hf_config, "rope_scaling", {})
if rope_params is None:
rope_params = {}
rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None)
rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default")
rope_params = getattr(self.hf_config, "rope_scaling", {}) or {}
rope_params["rope_theta"] = self.hf_config.rope_theta
rope_params["rope_type"] = getattr(rope_params, "rope_type", "default")
self.hf_config.rope_parameters = rope_params
Comment on lines 642 to 646
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") is incorrect because rope_params is a dict; getattr(...) will always return the default and ignore an existing "rope_type" key. Use rope_params.get("rope_type", "default") (or read from self.hf_config) so non-default rope settings are preserved.

Copilot uses AI. Check for mistakes.

self.generation_config = get_generation_config(self.model)
if self.generation_config is not None:
if (
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids
self.quant_config = get_quant_config(self.hf_config)
hf_config_max_position_embeddings = getattr(
self.hf_config, "max_position_embeddings", 8192
Expand Down
16 changes: 11 additions & 5 deletions atom/model_ops/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .base_attention import BaseAttention
from atom.plugin.prepare import is_plugin_mode, is_sglang
from atom.models.utils import maybe_prefix
from atom.utils import envs


class RadixAttention(BaseAttention):
Expand Down Expand Up @@ -47,7 +48,6 @@ def __init__(
prefix=prefix,
**kwargs,
)
self.rotary_emb = rotary_emb

if is_sglang():
from sglang.srt.layers.radix_attention import RadixAttention
Expand All @@ -64,6 +64,8 @@ def __init__(
raise NotImplementedError(
"RadixAttention is only supported for plugin mode for sglang for now"
)
# if True, saving the KV cache is done inside the RoPE op instead
self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM

def forward_impl_plugin_mode(
self,
Expand All @@ -82,10 +84,14 @@ def forward_impl_plugin_mode(
# for sglang, forward_batch is required
forward_batch = kwargs.get("forward_batch", None)
assert forward_batch is not None, "forward_batch is required for sglang"
if self.rotary_emb is not None:
assert positions is not None, "positions is required for ROPE"
query, key = self.rotary_emb(positions, query, key)
return self.attn(q=query, k=key, v=value, forward_batch=forward_batch)
# forward_batch contains the field attn_backend, which will find the backend registered in ATOM
return self.attn(
Comment on lines +87 to +88
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: filed attn_backend should be field attn_backend.

Copilot uses AI. Check for mistakes.
query,
key,
value,
forward_batch=forward_batch,
save_kv_cache=not self.use_aiter_rope_fused_qknorm,
)
else:
raise NotImplementedError(
"RadixAttention is only supported for plugin mode for sglang for now"
Expand Down
82 changes: 74 additions & 8 deletions atom/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size

# from atom.model_ops.rotary_embedding import get_rope
from aiter.rotary_embedding import get_rope
from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg
from atom.config import Config, QuantizationConfig
from atom.model_ops.activation import SiluAndMul

Expand Down Expand Up @@ -34,6 +34,7 @@
from atom.utils import envs
from torch import nn
from atom.model_loader.loader import load_model_in_plugin_mode
from atom.plugin.prepare import is_sglang

# import torch.distributed as dist
from transformers import PretrainedConfig
Expand All @@ -42,6 +43,7 @@
ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = (
envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION
)
ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM


class Qwen3MoeMLP(nn.Module):
Expand Down Expand Up @@ -228,6 +230,65 @@ def __init__(

self.kv_cache_dtype = kv_cache_dtype
self.layer_num = layer_num
self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)

Comment on lines +233 to +235
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.k_scale/self.v_scale are created as CPU tensors in __init__, but are later passed alongside GPU kv-cache buffers in plugin mode. This will cause device mismatch errors when the fused path is enabled. Consider registering them as buffers and initializing/moving them to the same device as the model (or creating them on-demand on qkv.device).

Copilot uses AI. Check for mistakes.
def forward_sgl_plugin_mode(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @ZhiweiYan-96 @zejunchen-zejun

This should be refined: the execution details should be contained inside the radix attention instead of the model forward.

self,
positions: torch.Tensor,
qkv: torch.Tensor,
**model_kwargs: dict[str, Any] | None,
):
if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE:
forward_batch = model_kwargs.get("forward_batch", None)
assert forward_batch is not None, "forward_batch is required for sglang"
k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
self.layer_num
)
block_size = 1024 # Default fallback
if hasattr(forward_batch, "attn_backend") and hasattr(
forward_batch.attn_backend, "page_size"
):
block_size = forward_batch.attn_backend.page_size
elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr(
forward_batch.token_to_kv_pool.allocator, "page_size"
):
block_size = forward_batch.token_to_kv_pool.allocator.page_size
elif hasattr(forward_batch.token_to_kv_pool, "page_size"):
block_size = forward_batch.token_to_kv_pool.page_size
x = 16 // k_buffer.element_size()
aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg(
kv_cache=(k_buffer, v_buffer),
cache_loc=forward_batch.out_cache_loc,
k_scale=self.k_scale,
v_scale=self.v_scale,
return_kv=True,
use_shuffle_layout=True,
block_size=block_size,
x=x,
)
q, k, v = self.rotary_emb(
qkv,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.num_heads,
self.num_kv_heads,
self.q_norm.eps,
fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg,
)
else:
q, k, v = torch.split(
qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1
)
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)

q, k = self.rotary_emb(positions, q, k)

attn_output = self.attn(q, k, v, positions=positions, **model_kwargs)
return attn_output

def forward(
self,
Expand All @@ -245,13 +306,18 @@ def forward(
query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv
)
else:
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)
if is_sglang():
attn_output = self.forward_sgl_plugin_mode(
positions, qkv, **model_kwargs
)
else:
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)

attn_output = self.attn(
query=q, key=k, value=v, positions=positions, **model_kwargs
)
attn_output = self.attn(
query=q, key=k, value=v, positions=positions, **model_kwargs
)
output = self.o_proj(attn_output)
return output

Expand All @@ -265,7 +331,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No
self.hidden_size = config.hidden_size
rope_params = config.rope_parameters
rope_theta = rope_params["rope_theta"]
rope_scaling = rope_params
rope_scaling = None if rope_params["rope_type"] == "default" else rope_params
kv_cache_dtype = atom_config.kv_cache_dtype
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# DecoderLayers are created with `make_layers` which passes the prefix
Expand Down
Empty file.
Loading