From 4b54e46005f8ae28f6f3a6f605d61f0616b779d0 Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Wed, 11 Feb 2026 19:00:58 +0800 Subject: [PATCH 01/11] register attn backend to sgl from ATOM enable mla --- atom/config.py | 31 +- atom/model_ops/__init__.py | 1 + atom/model_ops/linear.py | 11 + atom/model_ops/moe.py | 43 +- atom/model_ops/radix_attention.py | 21 +- atom/models/deepseek_v2.py | 809 +++++++- atom/models/qwen3_moe.py | 93 +- atom/plugin/attention_backend/__init__.py | 0 .../attention_backend/sgl_attn_backend.py | 1773 +++++++++++++++++ atom/plugin/register.py | 6 +- atom/utils/envs.py | 1 + 11 files changed, 2749 insertions(+), 40 deletions(-) create mode 100644 atom/plugin/attention_backend/__init__.py create mode 100644 atom/plugin/attention_backend/sgl_attn_backend.py diff --git a/atom/config.py b/atom/config.py index 1d433987d..36312efdc 100644 --- a/atom/config.py +++ b/atom/config.py @@ -817,22 +817,27 @@ def __post_init__(self): # assert os.path.isdir(self.model) assert 1 <= self.tensor_parallel_size <= 8 - self.hf_config = get_hf_config(self.model) + if is_plugin_mode(): + # plugin mode + assert ( + self.plugin_config is not None + ), "plugin_config is required in plugin mode" + self.hf_config = self.plugin_config.model_config.hf_config + else: + self.hf_config = get_hf_config(self.model) + + self.generation_config = get_generation_config(self.model) + if self.generation_config is not None: + if ( + eos_ids := getattr(self.generation_config, "eos_token_id", None) + ) is not None: + self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 - rope_params = getattr(self.hf_config, "rope_scaling", {}) - if rope_params is None: - rope_params = {} - rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None) - rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default") + rope_params = getattr(self.hf_config, "rope_scaling", 
{}) or {} + rope_params["rope_theta"] = self.hf_config.rope_theta + rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") self.hf_config.rope_parameters = rope_params - - self.generation_config = get_generation_config(self.model) - if self.generation_config is not None: - if ( - eos_ids := getattr(self.generation_config, "eos_token_id", None) - ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids self.quant_config = QuantizationConfig(self.hf_config) hf_config_max_position_embeddings = getattr( self.hf_config, "max_position_embeddings", 8192 diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py index 4b6c0b545..4e0c0c258 100644 --- a/atom/model_ops/__init__.py +++ b/atom/model_ops/__init__.py @@ -5,6 +5,7 @@ # it can be assigned to different attention ops. # By default, PagedAttention is used. # For sglang, RadixAttention will be assigned to Attention +# see register.py for details. Attention = PagedAttention __all__ = [ diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index ed7459614..7555fcbc3 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -388,10 +388,21 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + _diag_forward_counter = 0 + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 ) -> torch.Tensor: + if LinearBase._diag_forward_counter < 5: + LinearBase._diag_forward_counter += 1 + print( + f"[DIAG][LinearBase.forward] prefix={self.prefix} " + f"quant_type={self.quant_type} " + f"w_dtype={self.weight.dtype} " + f"w_scale_shape={tuple(self.weight_scale.shape) if self.weight_scale is not None else None} " + f"x shape={tuple(x.shape)}" + ) if self.quant_type.value == QuantType.No.value: y = tgemm.mm( x, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 0745df342..4b35b4e5c 
100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1401,8 +1401,7 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, ) else: - # Direct kernel call for non-EP/DP cases - return rocm_asm_moe_impl( + return torch.ops.aiter.rocm_aiter_fused_moe( x, layer.w13_weight, layer.w2_weight, @@ -1611,7 +1610,20 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: def _process_block_quant(self, layer: nn.Module) -> None: assert self.quant_config["is_dynamic"] + print( + f"[DIAG][Fp8MoE._process_block_quant] BEFORE normalize: " + f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} " + f"w13_scale dtype={layer.w13_weight_scale.dtype} shape={tuple(layer.w13_weight_scale.shape)} " + f"need_normalize={self.need_normalize_e4m3fn_to_e4m3fnuz}" + ) self._normalize_weights_and_scales(layer) + print( + f"[DIAG][Fp8MoE._process_block_quant] AFTER normalize: " + f"w13 dtype={layer.w13_weight.dtype} " + f"w13_scale min={layer.w13_weight_scale.data.min().item():.6f} " + f"max={layer.w13_weight_scale.data.max().item():.6f} " + f"mean={layer.w13_weight_scale.data.float().mean().item():.6f}" + ) if not self.need_normalize_e4m3fn_to_e4m3fnuz: layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False) @@ -1624,6 +1636,7 @@ def _process_block_quant(self, layer: nn.Module) -> None: ) shuffle_weights(layer.w13_weight, layer.w2_weight) + print(f"[DIAG][Fp8MoE._process_block_quant] DONE shuffle") def _process_channel_quant(self, layer: nn.Module) -> None: """PTPTC""" @@ -1693,13 +1706,26 @@ def get_fused_moe_quant_config( a2_scale=layer.w2_input_scale, per_act_token_quant=True, ) + elif self.block_quant: + if self.quant_type == QuantType.per_1x128: + block_shape = [128, 128] + elif self.quant_type == QuantType.per_1x32: + block_shape = [1, 32] + else: + block_shape = None + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, 
+ a2_scale=layer.w2_input_scale, + block_shape=block_shape, + ) else: return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - block_shape=None, ) @mark_trace(prefix="fp8_moe", torch_compile=False) @@ -1741,6 +1767,17 @@ def apply( # per_Tensor doesn't support num_local_tokens, so fallback to # rocm_aiter_fused_moe when using per-tensor or no modular kernel. if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: + if not hasattr(self, "_diag_apply_printed"): + self._diag_apply_printed = True + print( + f"[DIAG][Fp8MoE.apply] rocm_aiter path: " + f"quant_type={self.quant_type} " + f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} " + f"w13_scale shape={tuple(layer.w13_weight_scale.shape)} " + f"w13_scale min={layer.w13_weight_scale.data.float().min().item():.6f} " + f"max={layer.w13_weight_scale.data.float().max().item():.6f} " + f"x dtype={x.dtype} shape={tuple(x.shape)}" + ) return torch.ops.aiter.rocm_aiter_fused_moe( x, layer.w13_weight, diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 25388b384..c85b0251e 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -9,6 +9,7 @@ from .base_attention import BaseAttention from atom.plugin.prepare import is_plugin_mode, is_sglang from atom.models.utils import maybe_prefix +from atom.utils import envs class RadixAttention(BaseAttention): @@ -47,23 +48,35 @@ def __init__( prefix=prefix, **kwargs, ) - self.rotary_emb = rotary_emb if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention + _v_head_dim = mla_modules.kv_lora_rank if (use_mla and mla_modules is not None) else head_dim + self.attn = RadixAttention( num_heads=num_heads, head_dim=head_dim, scaling=scale, num_kv_heads=num_kv_heads, layer_id=layer_num, + v_head_dim=_v_head_dim, prefix=maybe_prefix(prefix, "attn"), ) + if 
self.attn.k_scale is None: + self.attn.k_scale = torch.nn.Parameter( + torch.tensor([1.0], dtype=torch.float32), requires_grad=False + ) + if self.attn.v_scale is None: + self.attn.v_scale = torch.nn.Parameter( + torch.tensor([1.0], dtype=torch.float32), requires_grad=False + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" ) + # if True, save cache will be done in rope + self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM def forward_impl_plugin_mode( self, @@ -82,10 +95,8 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - if self.rotary_emb is not None: - assert positions is not None, "positions is required for ROPE" - query, key = self.rotary_emb(positions, query, key) - return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) + # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM + return self.attn(query, key, value, forward_batch=forward_batch, save_kv_cache=not self.use_aiter_rope_fused_qknorm) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f052..fe84de8ed 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -23,8 +23,9 @@ # limitations under the License. 
"""Inference-only DeepseekV2/DeepseekV3 model.""" +import json import logging -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Iterable, Any import torch from aiter import ( @@ -38,7 +39,7 @@ top_k_per_row_prefill, ) from aiter.dist.communication_op import tensor_model_parallel_all_reduce -from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size +from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size, get_tp_group from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from aiter.ops.triton.fused_fp8_quant import ( @@ -70,6 +71,7 @@ RowParallelLinear, use_triton_gemm, ) +from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE, _has_module, quark_post_load_weights from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import ( is_rocm_aiter_fuse_routed_scaling_factor, @@ -89,9 +91,24 @@ from atom.utils.forward_context import get_forward_context from torch import nn from transformers import PretrainedConfig +from atom.plugin.prepare import is_sglang # from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 +from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context +from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.configs.model_config import is_deepseek_nsa +from sglang.srt.models.deepseek_common.utils import _use_aiter_gfx95,_use_aiter,_is_gfx95_supported +from sglang.srt.layers.quantization.rocm_mxfp4_utils import batched_gemm_afp4wfp4_pre_quant +from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + per_tensor_quant_mla_fp8, + 
per_token_group_quant_mla_deep_gemm_masked_fp8, +) logger = logging.getLogger("atom") if use_triton_gemm(): @@ -113,6 +130,34 @@ gemm_a8w8_blockscale_preshuffle = None gemm_a16w8_blockscale_preshuffle = None + +from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 + +from sglang.srt.utils.custom_op import register_custom_op + +# TODO(yuwei): remove this wrapper after sgl-kernel registers its own fake/meta impl +# Wrap bmm_fp8 as a custom op so torch.compile does not trace into +# torch.cuda.current_blas_handle() (which returns a non-Tensor). +@register_custom_op(mutates_args=["out"]) +def _bmm_fp8_op( + A: torch.Tensor, + B: torch.Tensor, + out: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) + +def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + _bmm_fp8_op(A, B, out, A_scale, B_scale) + return out + ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION @@ -1431,11 +1476,542 @@ def __init__( self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True - def forward( + # for sglang + self.use_nsa = is_deepseek_nsa(config) + self.use_deep_gemm_bmm = False + self.alt_stream = None + # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( + # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) + # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + self.w_kc, self.w_vc = None, None + self.w_scale = None + self.w_scale_k = None + self.w_scale_v = None + # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( + # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) + # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + + def _forward_sgl_prepare( + self, + positions: 
torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None + ) -> torch.Tensor: + # supplementary code, port from forward_common + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + + forward_batch = model_kwargs.get("forward_batch", None) + zero_allocator = model_kwargs.get("zero_allocator", None) + llama_4_scaling = model_kwargs.get("llama_4_scaling", None) + q_lora = None + topk_indices = None + # #region agent log + try: + _pos_0 = int(positions.shape[0]) + _hs_0 = int(hidden_states.shape[0]) if hasattr(hidden_states, "shape") else -1 + _tp = int(get_tensor_model_parallel_world_size()) if forward_batch is not None else -1 + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "A", "location": "deepseek_v2.py:_forward_sgl_prepare_entry", "message": "prepare_entry", "data": {"positions_dim0": _pos_0, "hidden_states_dim0": _hs_0, "tp_world": _tp, "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + if self.q_lora_rank is not None: + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"positions={tuple(positions.shape)} hidden_states={tuple(hidden_states.shape)} " + f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " + f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)}" + ) + # qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) + # q, latent_cache = torch.split( + # qkv_lora, + # [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + # dim=-1 + # ) + + q, latent_cache = ( + get_attn_tp_context() + .fetch_qkv_latent() + .split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + ) + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + 
f"fetched q={tuple(q.shape)} latent={tuple(latent_cache.shape)} " + f"positions={tuple(positions.shape)} tp_world={get_tensor_model_parallel_world_size()}" + ) + # #region agent log + try: + _q0, _p0, _tp = int(q.shape[0]), int(positions.shape[0]), int(get_tensor_model_parallel_world_size()) + _fallback_cond = _q0 != _p0 and _tp > 1 + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "B,C", "location": "deepseek_v2.py:after_fetch", "message": "after_fetch", "data": {"q_dim0": _q0, "positions_dim0": _p0, "tp_world": _tp, "fallback_will_run": _fallback_cond}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + + if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: + qkv_lora = torch.cat([q, latent_cache], dim=-1) + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] < positions.shape[0]: + raise RuntimeError( + f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, expected {positions.shape[0]}" + ) + qkv_lora = qkv_lora[: positions.shape[0]] + q, latent_cache = torch.split( + qkv_lora, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"after_fallback_gather q={tuple(q.shape)} latent={tuple(latent_cache.shape)}" + ) + # #region agent log + try: + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "C", "location": "deepseek_v2.py:after_fallback", "message": "after_fallback", "data": {"q_dim0": int(q.shape[0]), "positions_dim0": int(positions.shape[0])}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + + k_nope = latent_cache[..., : self.kv_lora_rank] + + # overlap qk norm + if self.alt_stream is not None and 
get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q = self.q_a_layernorm(q) + with torch.cuda.stream(self.alt_stream): + k_nope = self.kv_a_layernorm(k_nope) + current_stream.wait_stream(self.alt_stream) + else: + # if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8: + # q, _, k_nope, *_ = fused_rms_mxfp4_quant( + # q, + # self.q_a_layernorm.weight, + # self.q_a_layernorm.variance_epsilon, + # k_nope, + # self.kv_a_layernorm.weight, + # self.kv_a_layernorm.variance_epsilon, + # ) + # else: + q_lora = None + _use_aiter_gfx95 = False + if ( + _use_aiter_gfx95 + and + self.q_b_proj.weight.dtype == torch.float8_e4m3fn + ): + if self.use_nsa: + q_quanted, q_lora, k_nope, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=True, + ) + q = q_quanted + else: + q, _, k_nope, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=False, + ) + + else: + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) + + # q_lora needed by indexer + if self.use_nsa: + if q_lora is None: + q_lora = q + + # overlap q_b_proj and indexer during decode + if ( + self.alt_stream is not None + and get_is_capture_mode() + and forward_batch.forward_mode.is_decode_or_idle() + and q_lora is not None + ): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + with torch.cuda.stream(self.alt_stream): + k_nope = k_nope.unsqueeze(1) + q = self.q_b_proj(q).view( + -1, self.num_local_heads, self.qk_head_dim + ) + topk_indices = 
self.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=self.layer_num, + ) + current_stream.wait_stream(self.alt_stream) + else: + k_nope = k_nope.unsqueeze(1) + q = self.q_b_proj(q).view(-1, self.num_local_heads, self.qk_head_dim) + if q_lora is not None: + topk_indices = self.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=self.layer_num, + ) + else: + q = self.q_proj(hidden_states).view( + -1, self.num_local_heads, self.qk_head_dim + ) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + k_nope = latent_cache[..., : self.kv_lora_rank] + k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1) + + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) + # #region agent log + try: + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "D,E", "location": "deepseek_v2.py:before_rope_assert", "message": "before_rope", "data": {"q_pe_dim0": int(q_pe.shape[0]), "k_pe_dim0": int(k_pe.shape[0]), "positions_dim0": int(positions.shape[0]), "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " + f"positions={tuple(positions.shape)}" + ) + + _is_hip= True + if self.use_deep_gemm_bmm: + q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1)) + ) + q_nope_out = q_nope.new_empty( + (self.num_local_heads, aligned_m, self.kv_lora_rank) + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (q_nope_val, q_nope_scale), + (self.w_kc, self.w_scale_k), + 
q_nope_out, + masked_m, + expected_m, + ) + q_nope_out = q_nope_out[:, :expected_m, :] + elif _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm + if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8: + x = q_nope.transpose(0, 1) + q_nope_out = torch.empty( + x.shape[0], + x.shape[1], + self.w_kc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.w_kc.transpose(-2, -1), + self.w_scale_k.transpose(-2, -1), + torch.bfloat16, + q_nope_out, + ) + else: + if (_use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn) or ( + get_is_capture_mode() and self.w_kc.dtype == torch.float8_e4m3fnuz + ): + # fp8 Triton kernel: always on gfx950, + # cudagraph-only on gfx942 (hides launch overhead) + q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=q_nope, + WQ=self.w_kc.transpose(-1, -2), + w_scale=self.mla_attn.w_scale, + group_size=128, + YQ=None, # allocate (B, M, N) + transpose_bm=False, # (B, M, N) + transpose_bm_in=True, # (M, B, K) + dtype=torch.bfloat16, + ) + + else: + q_nope_out = torch.bmm( + q_nope.to(torch.bfloat16).transpose(0, 1), + self.w_kc.to(torch.bfloat16) * self.w_scale, + ) + + elif self.w_kc.dtype == torch.float8_e4m3fn: + # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612 + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( + q_nope.transpose(0, 1), + ( + torch.zeros((1,), dtype=torch.float32, device=q_nope.device) + # if _is_cublas_ge_129 + # else zero_allocator.allocate(1) + ), + ) + q_nope_out = bmm_fp8( + q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 + ) + else: + q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) + + q_nope_out = q_nope_out.transpose(0, 1) + + if ( + self.rotary_emb is not None + # and (not self._fuse_rope_for_trtllm_mla(forward_batch)) + and (not _use_aiter or not _is_gfx95_supported or self.use_nsa) + ): + # Optional hard check during debugging + assert q_pe.shape[0] == 
positions.shape[0], ( + f"q_pe tokens {q_pe.shape[0]} != positions {positions.shape[0]}" + ) + assert k_pe.shape[0] == positions.shape[0], ( + f"k_pe tokens {k_pe.shape[0]} != positions {positions.shape[0]}" + ) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if nsa_use_prefill_cp(forward_batch): + # support allgather+rerrange + k_nope, k_pe = self.rebuild_cp_kv_cache( + latent_cache, forward_batch, k_nope, k_pe + ) + # end forward prepare + return ( + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + topk_indices, + llama_4_scaling, + ) + + def _forward_sgl_core( + self, + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + topk_indices, + llama_4_scaling, + ): + # 1) build q/k for radix attention path + _is_hip = True + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) + + if llama_4_scaling is not None: + q = q * llama_4_scaling + + # 2) attention core + attn_output = self.mla_attn( + q, + k, + k_nope, + forward_batch=forward_batch, + save_kv_cache=True, + **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), + ) + attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) + + # 3) up-proj by w_vc (port from sglang forward_absorb_core) + if self.use_deep_gemm_bmm: + attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(attn_output.transpose(0, 1)) + ) + attn_bmm_output = attn_output.new_empty( + (self.num_local_heads, aligned_m, self.v_head_dim) + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (attn_output_val, attn_output_scale), + (self.w_vc, self.w_scale_v), + attn_bmm_output, + masked_m, + expected_m, + ) + attn_bmm_output = ( + attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2) + ) + + elif _is_hip: + if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8: + x = attn_output.transpose(0, 1) + y = torch.empty( + 
x.shape[0], + x.shape[1], + self.w_vc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.w_vc.transpose(-2, -1), + self.w_scale_v.transpose(-2, -1), + torch.bfloat16, + y, + ) + attn_bmm_output = y.transpose(0, 1).flatten(1, 2) + else: + if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn: + y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=attn_output, + WQ=self.w_vc.transpose(-1, -2), + w_scale=self.w_scale, + group_size=128, + YQ=None, + transpose_bm=False, + transpose_bm_in=True, + dtype=torch.bfloat16, + ) + else: + y = torch.bmm( + attn_output.to(torch.bfloat16).transpose(0, 1), + self.w_vc.to(torch.bfloat16) * self.w_scale, + ) + attn_bmm_output = y.transpose(0, 1).flatten(1, 2) + + elif self.w_vc.dtype == torch.float8_e4m3fn: + attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( + attn_output.transpose(0, 1), + torch.zeros((1,), dtype=torch.float32, device=attn_output.device), + ) + attn_bmm_output = bmm_fp8( + attn_output_val, + self.w_vc, + attn_output_scale, + self.w_scale, + torch.bfloat16, + ) + attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + + else: + attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) + attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + + output = self.o_proj(attn_bmm_output) + return output + + def prepare_qkv_latent( + self, + hidden_states: torch.Tensor, + forward_batch, + ): + assert self.q_lora_rank is not None + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + print( + f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " + f"hidden_states={tuple(hidden_states.shape)} " + f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " + f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + f"positions={None if getattr(forward_batch, 'positions', None) is None else 
tuple(forward_batch.positions.shape)}" + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) + print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") + + # Fallback: when communicator does not enable input_scattered gather, + # force qkv latent token dimension to align with positions. + # Use positions.shape[0] (actual input token count) instead of + # seq_lens_sum (total KV cache length, wrong for decode mode). + expected_tokens = 0 + if hasattr(forward_batch, "positions") and forward_batch.positions is not None: + expected_tokens = int(forward_batch.positions.shape[0]) + if expected_tokens <= 0: + expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) + + if ( + expected_tokens > 0 + and qkv_lora.shape[0] != expected_tokens + and get_tensor_model_parallel_world_size() > 1 + ): + print( + f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " + f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " + f"tp_world={get_tensor_model_parallel_world_size()}" + ) + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] > expected_tokens: + qkv_lora = qkv_lora[:expected_tokens] + elif qkv_lora.shape[0] < expected_tokens: + raise RuntimeError( + f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " + f"expected {expected_tokens}" + ) + print( + f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" + ) + return qkv_lora + + + def forward_sgl_plugin_mode( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: + # forward_absorb_prepare sglang + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + forward_batch = model_kwargs.get("forward_batch", None) + if forward_batch is None: + raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") + + attn_tp_context = 
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **model_kwargs: dict[str, Any] | None
    ) -> torch.Tensor:
        """Dispatch MLA attention to the SGLang-plugin path or the native ATOM path.

        Both branches take the same (positions, hidden_states, **model_kwargs)
        and return the attention output tensor.
        """
        if is_sglang():
            # SGLang plugin mode: forward_batch etc. arrive via model_kwargs.
            attn_output = self.forward_sgl_plugin_mode(positions, hidden_states, **model_kwargs)
        else:
            attn_output = self.forward_common(positions, hidden_states, **model_kwargs)
        return attn_output

    def process_weights_after_loading(self) -> None:
        """Post-load weight processing hook.

        Only active in SGLang plugin mode; otherwise a no-op.
        """
        # only for sglang plugin mode
        if not is_sglang():
            return
        self._process_mla_kv_b_proj_after_loading_sgl()

    def _process_mla_kv_b_proj_after_loading_sgl(self) -> None:
        """Split the loaded kv_b_proj weight into the w_kc / w_vc matrices used by
        SGLang's MLA kernels, dequantizing or re-quantizing as needed.

        Handles: AWQ-packed weights, fp8 (block- or channel-quantized, with the
        ROCm e4m3fnuz normalization), int8 block quantization, the quark mxfp4
        path, and the DeepGEMM BMM fp8 path (which keeps block scales instead of
        dequantizing).
        """
        # lazy imports: only needed for sglang plugin path
        from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz
        from sglang.srt.layers.quantization.fp8_utils import (
            block_quant_dequant,
            block_quant_to_tensor_quant,
            channel_quant_to_tensor_quant,
            inverse_transform_scale_ue8m0,
        )
        from sglang.srt.layers.quantization.int8_utils import (
            block_dequant as int8_block_dequant,
        )
        from sglang.srt.layers.deep_gemm_wrapper import (
            ENABLE_JIT_DEEPGEMM,
            DEEPGEMM_BLACKWELL,
        )
        from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0
        from sglang.srt.models.deepseek_common.utils import (
            _is_cpu,
            _is_cpu_amx_available,
            _is_cuda,
            _is_fp8_fnuz,
            _is_hip,
            _is_npu,
            _use_aiter_gfx95,
            awq_dequantize_func,
        )
        from sglang.srt.utils import bind_or_assign, get_bool_env_var

        # read kv_b_proj weight (awq compatible): AWQ layers expose qweight/
        # scales/qzeros instead of a plain .weight.
        if hasattr(self.kv_b_proj, "qweight"):
            awq_dequantize_f = awq_dequantize_func()
            if awq_dequantize_f is None:
                raise ValueError("AWQ dequantize function is not supported for current device")
            w = awq_dequantize_f(
                self.kv_b_proj.qweight,
                self.kv_b_proj.scales,
                self.kv_b_proj.qzeros,
            ).T
        else:
            w = self.kv_b_proj.weight

        # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes
        # into them (weight_loader_process view-casts a detached copy, leaving
        # the nn.Parameter as fnuz). At this point LinearBase's
        # process_weights_after_loading hasn't run yet (parent module iterates
        # before child in named_modules). View-cast back to fn so the
        # normalize path works correctly.
        if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz:
            w = w.view(torch.float8_e4m3fn)

        use_deep_gemm_bmm = False
        block_scale = None
        weight_block_size = None

        # Derive weight_block_size from ATOM's quant_type system
        from aiter import QuantType as _AiterQuantType
        _atom_qt = getattr(self.kv_b_proj, "quant_type", None)
        if _atom_qt == _AiterQuantType.per_1x128:
            weight_block_size = [128, 128]
        elif _atom_qt == _AiterQuantType.per_1x32:
            weight_block_size = [1, 32]

        # fp8 path
        if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
            if weight_block_size is not None:
                # Block-quantized fp8: need per-block scales under either name.
                assert hasattr(self.kv_b_proj, "weight_scale_inv") or hasattr(self.kv_b_proj, "weight_scale")
                weight_scale = (
                    self.kv_b_proj.weight_scale
                    if hasattr(self.kv_b_proj, "weight_scale")
                    else self.kv_b_proj.weight_scale_inv
                )

                if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn:
                    # ROCm-only fnuz format: rescale fn weights into fnuz range.
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=w,
                        weight_scale=weight_scale,
                        input_scale=None,
                    )
                else:
                    weight = w

                # Undo UE8M0 scale packing if the loader applied it for DeepGEMM.
                if (
                    should_deepgemm_weight_requant_ue8m0(
                        weight_block_size=weight_block_size
                    )
                    and getattr(weight_scale, "format_ue8m0", False)
                ):
                    weight_scale = inverse_transform_scale_ue8m0(weight_scale, mn=weight.shape[-2])

                if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128:
                    if (
                        ENABLE_JIT_DEEPGEMM
                        and not DEEPGEMM_BLACKWELL
                        and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                    ):
                        # Keep fp8 weight + block scales for DeepGEMM BMM kernels.
                        block_scale = weight_scale
                        use_deep_gemm_bmm = True
                    else:
                        # Dequantize to bf16 for the generic BMM path.
                        w = block_quant_dequant(
                            weight,
                            weight_scale,
                            weight_block_size,
                            torch.bfloat16,
                        )
                else:
                    # Collapse block scales into a single tensor-wide scale.
                    w, scale = block_quant_to_tensor_quant(weight, weight_scale, weight_block_size)
                    self.w_scale = scale
            else:
                # Channel-quantized fp8 (no block size derived).
                if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz:
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=w,
                        weight_scale=self.kv_b_proj.weight_scale,
                        input_scale=None,
                    )
                else:
                    weight = w
                    weight_scale = self.kv_b_proj.weight_scale

                w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                self.w_scale = scale

        # int8 path
        if w.dtype == torch.int8:
            if weight_block_size is not None:
                assert hasattr(self.kv_b_proj, "weight_scale_inv")
                w = int8_block_dequant(
                    w,
                    self.kv_b_proj.weight_scale_inv,
                    weight_block_size,
                ).to(torch.bfloat16)
            else:
                # Per-tensor int8: single scalar scale.
                w = w.to(torch.bfloat16) * self.kv_b_proj.weight_scale.to(torch.bfloat16)

        # split to kc/vc: rows of kv_b_proj interleave nope-K and V per head.
        w_kc, w_vc = w.unflatten(
            0, (-1, self.qk_nope_head_dim + self.v_head_dim)
        ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)

        # quark fp4 special path (ATOM utility)
        # NOTE(review): quark_post_load_weights is not among the imports visible
        # here — confirm it is imported at module level, otherwise this branch
        # raises NameError on gfx95 + quark configs.
        quant_method = getattr(self.kv_b_proj, "quant_method", None)
        quant_config = getattr(quant_method, "quant_config", None)
        if _use_aiter_gfx95 and quant_config is not None and quant_config.get_name() == "quark":
            w_kc, self.w_scale_k, w_vc, self.w_scale_v = quark_post_load_weights(self, w, "mxfp4")

        if not use_deep_gemm_bmm:
            # transpose(1,2).contiguous().transpose(1,2) keeps the logical shape
            # but forces the memory layout the kernels expect.
            self.w_kc = bind_or_assign(
                self.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            )
            w_vc = w_vc.contiguous().transpose(1, 2)
            if _is_npu:
                w_vc = w_vc.contiguous()
            self.w_vc = bind_or_assign(self.w_vc, w_vc)

            if hasattr(self.kv_b_proj, "weight_scale") and self.w_scale is None:
                self.w_scale = bind_or_assign(self.w_scale, self.kv_b_proj.weight_scale)
                if _is_hip:
                    # HIP fp8 range differs by a factor of 2 from CUDA fp8.
                    self.w_scale *= 2.0

            if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
                # CPU AMX path works on pre-scaled bf16 weights.
                self.w_kc = self.w_kc.to(torch.bfloat16) * self.w_scale
                self.w_vc = self.w_vc.to(torch.bfloat16) * self.w_scale
        else:
            # DeepGEMM BMM: split the block scales the same way as the weights.
            num_tiles_k = self.qk_nope_head_dim // weight_block_size[1]
            num_tiles_n = self.v_head_dim // weight_block_size[0]
            ws_kc, ws_vc = block_scale.unflatten(0, (-1, (num_tiles_k + num_tiles_n))).split(
                [num_tiles_k, num_tiles_n], dim=1
            )

            self.w_scale_k = bind_or_assign(self.w_scale_k, ws_kc.transpose(1, 2).contiguous())
            self.w_scale_v = bind_or_assign(self.w_scale_v, ws_vc.contiguous())
            self.w_kc = bind_or_assign(self.w_kc, w_kc.transpose(1, 2).contiguous())
            self.w_vc = bind_or_assign(self.w_vc, w_vc.contiguous())
            self.use_deep_gemm_bmm = True
"q_lora_rank") and config.q_lora_rank is not None: self.packed_modules_mapping = { @@ -1873,6 +2651,10 @@ def __init__( self.model.make_empty_intermediate_tensors ) + if is_sglang(): + from sglang.srt.configs.model_config import is_deepseek_nsa + get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -1882,9 +2664,11 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any] | None ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids, positions, intermediate_tensors, inputs_embeds, + **model_kwargs, ) return hidden_states @@ -1912,6 +2696,19 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + + # lazy import to avoid circular import issue since model_loader also imports model.. + from atom.model_loader.loader import load_model_in_plugin_mode + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." 
    def forward_sgl_plugin_mode(
        self,
        positions: torch.Tensor,
        qkv: torch.Tensor,
        **model_kwargs: dict[str, Any] | None,
    ):
        """Qwen3-MoE attention body for SGLang plugin mode.

        When the fused-rope env flag is on, a single aiter kernel applies
        qk-norm + RoPE and writes K/V straight into the SGLang KV cache
        (shuffle layout); otherwise qkv is split and normed/rotated step by
        step. Either way the result is fed to ``self.attn``.
        """
        if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE:
            forward_batch = model_kwargs.get("forward_batch", None)
            assert forward_batch is not None, "forward_batch is required for sglang"
            k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_num)
            # Probe several possible owners of page_size, newest first.
            block_size = 1024  # Default fallback
            if hasattr(forward_batch, 'attn_backend') and hasattr(forward_batch.attn_backend, 'page_size'):
                block_size = forward_batch.attn_backend.page_size
            elif hasattr(forward_batch.token_to_kv_pool, 'allocator') and hasattr(forward_batch.token_to_kv_pool.allocator, 'page_size'):
                block_size = forward_batch.token_to_kv_pool.allocator.page_size
            elif hasattr(forward_batch.token_to_kv_pool, 'page_size'):
                block_size = forward_batch.token_to_kv_pool.page_size
            # x = elements per 16-byte chunk, matching the shuffle cache layout.
            x = 16 // k_buffer.element_size()
            aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg(
                kv_cache = (k_buffer, v_buffer),
                cache_loc = forward_batch.out_cache_loc,
                k_scale = self.k_scale,
                v_scale = self.v_scale,
                return_kv = True,
                use_shuffle_layout = True,
                block_size = block_size,
                x = x,
            )
            # Fused kernel: qk-norm + RoPE + KV-cache store in one call.
            q, k, v = self.rotary_emb(
                qkv,
                self.q_norm.weight,
                self.k_norm.weight,
                positions,
                self.num_heads,
                self.num_kv_heads,
                self.q_norm.eps,
                fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg,
            )
        else:
            q, k, v = torch.split(
                qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1
            )
            # Add qk-norm
            q = self.q_norm(q)
            k = self.k_norm(k)

            q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(
            q, k, v, positions=positions, **model_kwargs
        )
        return attn_output
config.rope_parameters rope_theta = rope_params["rope_theta"] - rope_scaling = rope_params + rope_scaling = None if rope_params["rope_type"] == "default" else rope_params kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix @@ -348,7 +408,7 @@ def forward( @support_torch_compile -class Qwen3MoeModel(nn.Module): +class Qwen3MoeModel(torch.nn.Module): def __init__( self, atom_config: Config, @@ -522,3 +582,14 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." 
+ ) + return loaded_weights_record \ No newline at end of file diff --git a/atom/plugin/attention_backend/__init__.py b/atom/plugin/attention_backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py new file mode 100644 index 000000000..2ececc348 --- /dev/null +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -0,0 +1,1773 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +import sglang.srt.layers.attention.aiter_backend as _sglang_aiter +from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import get_bool_env_var + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInput + +try: + from aiter import ( + flash_attn_varlen_func, + dtypes, + get_pa_metadata_info_v1, + get_pa_metadata_v1, + pa_fwd_asm, + pa_persistent_fwd, + mla_decode_fwd, + ) +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
@triton.jit
def reshape_and_cache_shuffle_kernel(
    key_ptr,  # [num_tokens, num_kv_heads, head_size]
    value_ptr,  # [num_tokens, num_kv_heads, head_size]
    key_cache_ptr,  # [num_blocks, num_kv_heads, head_size // x, block_size, x]
    value_cache_ptr,  # [num_blocks, num_kv_heads, block_size // x, head_size, x]
    slot_mapping_ptr,  # [num_tokens]
    k_scale_ptr,
    v_scale_ptr,
    x,
    k_stride0,
    v_stride0,
    block_size,
    head_size,
    num_kv_heads,
    BLOCK_SIZE: tl.constexpr,
    QUANT: tl.constexpr,
):
    # Scatter one (token, kv-head) vector into the shuffled paged KV cache.
    # Grid: (num_tokens, num_kv_heads); BLOCK_SIZE == head_size so each program
    # moves a whole head vector at once.
    tid = tl.program_id(0)
    head_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    src_offset_k = tid * k_stride0 + head_id * head_size
    src_offset_v = tid * v_stride0 + head_id * head_size
    slot_id = tl.load(slot_mapping_ptr + tid)
    # Negative slot means "no cache slot assigned" — skip the token.
    if slot_id < 0:
        return
    block_id = slot_id // block_size
    block_offset = slot_id % block_size
    # Base of this (block, head) tile; both caches have the same tile size.
    dst_offset = (
        block_id * num_kv_heads * head_size * block_size
        + head_id * head_size * block_size
    )
    # K layout groups x contiguous head elements per token slot;
    # V layout groups x token slots per head element (see shape comments above).
    dst_k_shuffle_offset = (
        dst_offset + offset // x * block_size * x + block_offset * x + offset % x
    )
    dst_v_shuffle_offset = (
        dst_offset
        + block_offset // x * head_size * x
        + offset * x
        + block_offset % x
    )
    k_val = tl.load(key_ptr + src_offset_k + offset)
    v_val = tl.load(value_ptr + src_offset_v + offset)
    if QUANT:
        # fp8 cache: divide by the (scalar) scales and cast to the cache dtype.
        k_scale = tl.load(k_scale_ptr)
        v_scale = tl.load(v_scale_ptr)
        k_dtype = key_cache_ptr.type.element_ty
        v_dtype = value_cache_ptr.type.element_ty
        k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype)
        v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype)
    tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val)
    tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val)


def reshape_and_cache_shuffle_triton(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scales: torch.Tensor,
    v_scales: torch.Tensor,
):
    """Write per-token K/V into the paged cache using the shuffle layout.

    Reinterprets the caches as 5-D shuffled views (via zero-copy ``view_as`` on
    meta-device templates) and launches one kernel program per
    (token, kv-head). Quantization is applied when ``kv_cache_dtype`` starts
    with "fp8".
    """
    num_tokens = slot_mapping.shape[0]
    _, num_kv_heads, head_size = key.shape
    # assumes key_cache is [num_blocks, block_size, num_kv_heads, head_size]
    # before reshuffling — TODO confirm against the pool's allocation.
    num_blocks, block_size, _, _ = key_cache.shape
    # x = elements per 16-byte chunk for this cache dtype.
    x = 16 // key_cache.element_size()
    k_cache_template = torch.empty(
        [num_blocks, num_kv_heads, head_size // x, block_size, x],
        dtype=key_cache.dtype,
        device="meta",
    )
    v_cache_template = torch.empty(
        [num_blocks, num_kv_heads, block_size // x, head_size, x],
        dtype=value_cache.dtype,
        device="meta",
    )
    new_key_cache = key_cache.view_as(k_cache_template)
    new_value_cache = value_cache.view_as(v_cache_template)
    QUANT = False
    if kv_cache_dtype.startswith("fp8"):
        QUANT = True
    grid = (
        num_tokens,
        num_kv_heads,
    )
    reshape_and_cache_shuffle_kernel[grid](
        key,
        value,
        new_key_cache,
        new_value_cache,
        slot_mapping,
        k_scales,
        v_scales,
        x,
        key.stride(0),
        value.stride(0),
        block_size,
        head_size,
        num_kv_heads,
        BLOCK_SIZE=head_size,
        QUANT=QUANT,
    )


@dataclass
class ForwardMetadata:
    # Per-forward-pass attention metadata shared by all layers.
    # Field order matters: init_forward_metadata constructs this positionally.
    # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode
    kv_indptr: Optional[torch.Tensor]
    kv_indices: Optional[torch.Tensor]
    qo_indptr: Optional[torch.Tensor]
    kv_last_page_len: Optional[torch.Tensor]
    max_q_len: Optional[int]
    max_kv_len: Optional[int]
    page_table: Optional[torch.Tensor]
    kv_lens: Optional[torch.Tensor]
    # mla
    work_metadata: Optional[torch.Tensor] = None
    work_info_set: Optional[torch.Tensor] = None
    work_indptr: Optional[torch.Tensor] = None
    reduce_indptr: Optional[torch.Tensor] = None
    reduce_final_map: Optional[torch.Tensor] = None
    reduce_partial_map: Optional[torch.Tensor] = None
    fp8_prefill_kv_indices: Optional[torch.Tensor] = None
    num_kv_splits: Optional[int] = None
    # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA)
    pa_metadata_qo_indptr: Optional[torch.Tensor] = None
    pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None
    pa_metadata_kv_indices: Optional[torch.Tensor] = None
    pa_metadata_context_lens: Optional[torch.Tensor] = None
    pa_metadata_max_qlen: Optional[int] = None
    pa_metadata_tp_q_head_num: Optional[int] = None
    # Prefill metadata for mha_batch_prefill_func (only used in prefill mode, non-MLA)
    # prefill_pages_kv_indptr: Optional[torch.Tensor] = None
    # prefill_kv_indices: Optional[torch.Tensor] = None
    # prefill_kv_last_page_lens: Optional[torch.Tensor] = None
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """ATOM-flavored SGLang attention backend built on AiterAttnBackend.

        Pre-allocates all persistent device buffers (page tables, indptrs,
        per-slot KV quant scales) so per-step metadata building can avoid
        allocations and host/device syncs.
        """
        super().__init__(model_runner, skip_prefill, kv_indptr_buf)
        # Hybrid models map only some layer ids to full attention; use the
        # first such layer's KV buffer to size the scale tensors below.
        mapping = getattr(
            model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None
        )

        if isinstance(mapping, dict) and mapping:
            first_full_attn_id = next(iter(mapping.keys()))
        else:
            first_full_attn_id = 0

        self.q_dtype = model_runner.dtype  # Save q dtype for pa_metadata building

        # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs]
        # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size]
        max_bs = model_runner.req_to_token_pool.size
        self.pa_decode_qo_indptr = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )
        self.seq_lens = torch.zeros(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.page_table = torch.zeros(
            (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=model_runner.device
        )
        # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes)
        self.strided_indices = torch.arange(
            0, self.max_context_len, self.page_size, device=model_runner.device
        )

        if not self.use_mla:
            # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes)
            max_num_blocks_per_seq = (self.max_context_len + self.page_size - 1) // self.page_size
            max_total_blocks = max_bs * max_num_blocks_per_seq
            self.pa_kv_indices = torch.zeros(
                max_total_blocks, dtype=torch.int32, device=self.device
            )
            # Pre-allocate pa_kv_indptr buffer (similar to self.kv_indptr, but dedicated for pa_persistent_fwd)
            self.pa_kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=self.device
            )
            # Pre-initialized batch indices [0, 1, 2, ..., max_bs-1] for Triton kernel
            self.pa_batch_indices = torch.arange(
                0, max_bs, dtype=torch.int32, device=self.device
            )

        self.logits_soft_cap = 0.0

        # Rebuilt by init_forward_metadata on every forward pass.
        self.forward_metadata: ForwardMetadata = None

        # Lazily-allocated dict of reusable pa_persistent_fwd metadata buffers.
        self.pa_metadata_buffers = None

        # Per-slot K/V quantization scales, initialized to 1.0 (no scaling).
        k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id)
        num_slots, num_kv_heads, _ = k_buffer.shape
        block_size = self.page_size
        num_blocks = num_slots // block_size
        max_total_tokens = num_blocks * block_size
        self.k_qscale = torch.ones(
            num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device
        )
        self.v_qscale = torch.ones(
            num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device
        )
        # pa_persistent_fwd decode path is only enabled for 1024-token pages.
        self.decode_using_pa_ps = self.page_size == 1024
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend.

        Builds a fresh ForwardMetadata for this batch, branching on
        decode-vs-extend and MLA-vs-MHA. The result is stored on
        ``self.forward_metadata`` and reused by every layer of the pass.
        """
        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info
        qo_indptr = None
        kv_last_page_len = None
        max_q_len = None
        page_table = None

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                # Flatten per-request token indices into (kv_indptr, kv_indices).
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.empty(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                # Speculative decoding supplies its own indices.
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            if self.use_mla:
                qo_indptr = self.qo_indptr_[: bs + 1]
                qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                kv_last_page_len = self.kv_last_page_len[:bs]
                max_q_len = 1

                work_metadata = None
                work_indptr = None
                work_info_set = None
                reduce_indptr = None
                reduce_final_map = None
                reduce_partial_map = None
                num_kv_splits = None

                # Persistent-scheduler MLA kernel needs precomputed work/reduce maps.
                if _sglang_aiter._use_mla_ps_kernel:
                    (
                        work_metadata,
                        work_indptr,
                        work_info_set,
                        reduce_indptr,
                        reduce_final_map,
                        reduce_partial_map,
                    ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs)

                    num_kv_splits = self.max_split_per_batch

                    self.make_mla_meta_data(
                        qo_indptr,
                        kv_indptr,
                        kv_last_page_len,
                        work_metadata,
                        work_info_set,
                        work_indptr,
                        reduce_indptr,
                        reduce_final_map,
                        reduce_partial_map,
                        max_q_len,
                        fast_mode=_sglang_aiter.fast_mode,
                        max_split_per_batch=num_kv_splits,
                        intra_batch_mode=_sglang_aiter.intra_batch_mode,
                    )

                self.forward_metadata = ForwardMetadata(
                    kv_indptr,
                    kv_indices,
                    qo_indptr,
                    kv_last_page_len,
                    max_q_len,
                    None,  # max_kv_len
                    None,  # page_table
                    None,  # kv_lens
                    work_metadata=work_metadata,
                    work_info_set=work_info_set,
                    work_indptr=work_indptr,
                    reduce_indptr=reduce_indptr,
                    reduce_final_map=reduce_final_map,
                    reduce_partial_map=reduce_partial_map,
                    num_kv_splits=num_kv_splits,
                )

            else:
                if self.decode_using_pa_ps:
                    # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction
                    seq_lens_cpu = forward_batch.seq_lens_cpu
                    if seq_lens_cpu is None:
                        seq_lens_cpu = forward_batch.seq_lens.cpu()

                    # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph)
                    page_table_persistent = self.page_table
                    seq_lens_persistent = self.seq_lens
                    seq_lens_persistent.fill_(0)
                    page_table_persistent.fill_(0)
                    seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True)
                    # +1 page of slack to cover the partially filled last page.
                    max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1
                    page_table = self.req_to_token[forward_batch.req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],]
                    # Token indices -> page indices before caching.
                    page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True)
                else:
                    page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :]

                self.forward_metadata = ForwardMetadata(
                    kv_indptr,
                    kv_indices,
                    None,  # qo_indptr not used in non-MLA mode
                    None,  # kv_last_page_len not used in non-MLA mode
                    1,  # max_q_len = 1 for decode mode
                    None,
                    page_table_persistent[:bs, :max_seq_pages] if self.decode_using_pa_ps else page_table,
                    seq_lens_persistent[:bs] if self.decode_using_pa_ps else forward_batch.seq_lens,
                )

                # Build pa_metadata for pa_persistent_fwd
                if self.decode_using_pa_ps:
                    self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head)
        else:
            prefix_lens = forward_batch.extend_prefix_lens

            if self.use_mla:
                self.mla_indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    forward_batch.seq_lens_sum,
                    forward_batch.extend_seq_lens,
                    forward_batch.extend_seq_lens.max().item(),
                    forward_batch.seq_lens.max().item(),
                    spec_info=None
                )

                max_q_len = self.mla_indices_updater_prefill.max_q_len
                qo_indptr = self.mla_indices_updater_prefill.qo_indptr

                work_metadata = None
                work_indptr = None
                work_info_set = None
                reduce_indptr = None
                reduce_final_map = None
                fp8_prefill_kv_indices = None
                reduce_partial_map = None

                from sglang.srt.utils import is_gfx95_supported
                _use_fp8_prefill_attn = (
                    get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported()
                )
                if _use_fp8_prefill_attn:
                    # Granularity of persistent-scheduler work items along q.
                    tile_q = 256
                    qlen_granularity = tile_q // (self.num_head // self.num_kv_head)
                    (
                        work_metadata,
                        work_indptr,
                        work_info_set,
                        reduce_indptr,
                        reduce_final_map,
                        reduce_partial_map
                    ) = self.make_mla_prefill_ps_meta_data_buffer(
                        bs, max_q_len, qlen_granularity
                    )

                    self.make_mla_prefill_ps_meta_data(
                        qo_indptr,
                        qo_indptr,
                        forward_batch.seq_lens,
                        work_metadata,
                        work_indptr,
                        work_info_set,
                        reduce_indptr,
                        reduce_final_map,
                        reduce_partial_map,
                        is_causal=True,
                    )

                    total_s = int(forward_batch.extend_seq_lens.sum())
                    fp8_prefill_kv_indices = torch.arange(
                        total_s, device=self.device, dtype=torch.int32
                    )

                self.forward_metadata = ForwardMetadata(
                    self.mla_indices_updater_prefill.kv_indptr,
                    self.mla_indices_updater_prefill.kv_indices,
                    qo_indptr,
                    self.kv_last_page_len[:bs],
                    max_q_len,
                    self.mla_indices_updater_prefill.max_kv_len,
                    None,
                    None,
                    work_metadata=work_metadata,
                    work_info_set=work_info_set,
                    work_indptr=work_indptr,
                    reduce_indptr=reduce_indptr,
                    reduce_final_map=reduce_final_map,
                    reduce_partial_map=reduce_partial_map,
                    fp8_prefill_kv_indices=fp8_prefill_kv_indices,
                )
            else:
                self.indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    forward_batch.seq_lens_sum,
                    prefix_lens,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=None,
                )
                # Get page_table for mha_batch_prefill_func
                page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :]
                self.forward_metadata = ForwardMetadata(
                    self.indices_updater_prefill.kv_indptr,
                    self.indices_updater_prefill.kv_indices,
                    self.qo_indptr[: bs + 1],  # qo_indptr is set by indices_updater_prefill
                    None,
                    self.indices_updater_prefill.max_q_len,
                    self.indices_updater_prefill.max_kv_len,
                    None,
                    forward_batch.seq_lens,
                )

        # NOTE(review): the non-MLA extend branch above stores page_table=None
        # in ForwardMetadata (7th positional arg), so this guard never fires for
        # it — confirm whether page_table was meant to be passed there, and
        # whether the final page-division below can double-apply with this one.
        if (forward_batch.forward_mode.is_extend() and
                not self.use_mla and
                self.forward_metadata.page_table is not None):
            if self.page_size > 1:
                seq_lens_cpu = forward_batch.seq_lens_cpu
                if seq_lens_cpu is None:
                    seq_lens_cpu = forward_batch.seq_lens.cpu()
                max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1
                self.forward_metadata.page_table = (
                    self.forward_metadata.page_table[:, self.strided_indices[:max_seq_pages]] // self.page_size
                )
            if self.decode_using_pa_ps:
                self._build_pa_metadata_for_prefill(forward_batch.batch_size)
        if not self.decode_using_pa_ps and self.page_size > 1 and self.forward_metadata.page_table is not None:
            self.forward_metadata.page_table = (
                self.forward_metadata.page_table[:, self.strided_indices] // self.page_size
            )
reduce_partial_map_type, + ): + """Allocate or reuse pa_metadata buffers.""" + if self.pa_metadata_buffers is None: + self.pa_metadata_buffers = {} + + def _get_size_val(size): + return size[0] if isinstance(size, tuple) else size + + # Allocate work_metadata_ptrs + size_val = _get_size_val(work_metadata_ptrs_size) + if ("work_metadata_ptrs" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val): + self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( + work_metadata_ptrs_size, dtype=work_metadata_ptrs_type, device=self.device + ) + + # Allocate work_indptr + size_val = _get_size_val(work_indptr_size) + if ("work_indptr" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["work_indptr"].shape[0] < size_val): + self.pa_metadata_buffers["work_indptr"] = torch.zeros( + work_indptr_size, dtype=work_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_indptr"].zero_() + + # Allocate work_info + size_val = _get_size_val(work_info_size) + if ("work_info" not in self.pa_metadata_buffers or + len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) or + self.pa_metadata_buffers["work_info"].shape[0] < size_val): + self.pa_metadata_buffers["work_info"] = torch.zeros( + work_info_size, dtype=work_info_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_info"].zero_() + + # Allocate reduce_indptr + size_val = _get_size_val(reduce_indptr_size) + if ("reduce_indptr" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val): + self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( + reduce_indptr_size, dtype=reduce_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_indptr"].zero_() + + # Allocate reduce_final_map + size_val = _get_size_val(reduce_final_map_size) + if ("reduce_final_map" not in self.pa_metadata_buffers or + len(self.pa_metadata_buffers["reduce_final_map"].shape) < 
len(reduce_final_map_size) or + self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val): + self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( + reduce_final_map_size, dtype=reduce_final_map_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_final_map"].zero_() + + # Allocate reduce_partial_map + reduce_partial_map_size_val = reduce_partial_map_size if isinstance(reduce_partial_map_size, int) else reduce_partial_map_size[0] + if ("reduce_partial_map" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["reduce_partial_map"].shape[0] < reduce_partial_map_size_val): + self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( + reduce_partial_map_size, dtype=reduce_partial_map_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_partial_map"].zero_() + + def _build_pa_metadata_for_decode( + self, + batch_size: int, + tp_q_head_num: Optional[int] = None, + ): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode. + + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. + The metadata can be reused across multiple layers in the same forward pass. + + Args: + batch_size: Batch size for the current forward pass + tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. 
+ """ + max_qlen = 1 + + # Use provided tp_q_head_num or default to self.num_head + if tp_q_head_num is None: + tp_q_head_num = self.num_head + + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + batch_size, + self.num_kv_head, + ) + # Allocate metadata buffers with reuse optimization for multi-layer forward passes + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] + + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) + # Note: kv_lens comes from self.seq_lens which is already int32 + context_lens = self.forward_metadata.kv_lens + + kernel_block_size = self.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync + # page_table shape: [batch_size, max_num_blocks_per_seq] + # Note: page_table comes from self.page_table which is already int32 and always set before this call + page_table = self.forward_metadata.page_table + + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) + create_flashinfer_kv_indices_triton[(batch_size,)]( 
+ page_table, + self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr + kv_indices = self.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // self.num_kv_head, + num_heads_k=self.num_kv_head, + is_causal=True, + work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + # Store computed values in ForwardMetadata for reuse in forward_decode + self.forward_metadata.pa_metadata_qo_indptr = qo_indptr + self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr + self.forward_metadata.pa_metadata_kv_indices = kv_indices + self.forward_metadata.pa_metadata_context_lens = context_lens + self.forward_metadata.pa_metadata_max_qlen = max_qlen + self.forward_metadata.pa_metadata_tp_q_head_num = tp_q_head_num + + def _build_pa_metadata_for_prefill(self, batch_size: int): + """Build metadata for mha_batch_prefill_func in prefill mode. + + This method prepares page-level metadata needed for mha_batch_prefill_func. + The metadata is computed once per forward pass and reused across all layers. 
+ """ + block_size = self.page_size + context_lens = self.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + # Page-level kv_indptr (reuse pa_kv_indptr buffer) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Build kv_indices from page_table using triton kernel + page_table = self.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # kv_indices = self.pa_kv_indices + + # Compute kv_last_page_lens for each sequence + # kv_last_page_lens = ((context_lens - 1) % block_size + 1).int() + + # Store in ForwardMetadata for reuse in forward_extend + # self.forward_metadata.prefill_pages_kv_indptr = pages_kv_indptr + # self.forward_metadata.prefill_kv_indices = kv_indices + # self.forward_metadata.prefill_kv_last_page_lens = kv_last_page_lens + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + # Always use preshuffle layout for pa_fwd_asm + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=self.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=self.device + ) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ) + + if self.use_mla and _sglang_aiter._use_mla_ps_kernel: + max_seqlen_qo = 1 + ( + self.work_metadata, + self.work_indptr, + 
self.work_info_set, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs) + elif self.use_mla: + self.work_metadata = None + self.work_indptr = None + self.work_info_set = None + self.reduce_indptr = None + self.reduce_final_map = None + self.reduce_partial_map = None + + if self.decode_using_pa_ps and not self.use_mla: + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + max_bs, + self.num_kv_head, + ) + + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None 
+ reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + else: + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, + None, + None, + None, + 1, + None, + page_table, + seq_lens_persistent, + ) + + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + return + else: + raise ValueError(f"Invalid mode: {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: 
+ kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + else: + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = 
(seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + + self.forward_metadata = ForwardMetadata( + None, + None, + None, + None, + 1, + None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + else: + raise ValueError("Invalid forward mode") + + def set_kv_buffer_with_layout_shuffle( + self, + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, + ): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + v_buffer = v_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + "auto", + k_scale, + v_scale, + ) + + def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + if self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + cache_loc, k, v, k_buffer, v_buffer, + layer.k_scale, layer.v_scale, self.page_size, + ) + + if self.use_mla: + return self._forward_extend_mla(q, k, v, layer, forward_batch) + else: + return self._forward_extend_mha(q, k, v, layer, forward_batch) + 
+ def _forward_extend_mha(self, q, k, v, layer, forward_batch): + """Non-MLA extend path: standard MHA with flash_attn_varlen_func.""" + seqlens_in_batch = forward_batch.seq_lens + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + if q.dtype != k.dtype and k.dtype == dtypes.fp8: + q = q.to(dtypes.fp8) + o = flash_attn_varlen_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=self.forward_metadata.max_q_len, + max_seqlen_k=self.forward_metadata.max_kv_len, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def _forward_extend_mla(self, q, k, v, layer, forward_batch): + """MLA extend path: ported from sglang aiter_backend forward_extend MLA logic.""" + max_q_len = self.forward_metadata.max_q_len + max_kv_len = self.forward_metadata.max_kv_len + kv_indptr = self.forward_metadata.kv_indptr + kv_indices = self.forward_metadata.kv_indices + qo_indptr = self.forward_metadata.qo_indptr + + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + kv_lora_rank = V_Buffer.shape[-1] + qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank + qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim + + assert len(q.shape) == 3 + assert len(k.shape) == 3 + assert len(v.shape) == 3 + + if ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + return self._forward_extend_mla_normal( + q, k, v, layer, forward_batch, + K_Buffer, V_Buffer, + kv_lora_rank, qk_rope_head_dim, 
qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ) + elif forward_batch.forward_mode.is_target_verify(): + return self._forward_extend_mla_target_verify( + q, layer, K_Buffer, qo_indptr, + ) + elif forward_batch.forward_mode.is_draft_extend(): + return self._forward_extend_mla_draft_extend( + q, layer, K_Buffer, qo_indptr, + ) + else: + raise ValueError( + f"Invalid forward mode for MLA extend: {forward_batch.forward_mode=}" + ) + + def _forward_extend_mla_normal( + self, q, k, v, layer, forward_batch, + K_Buffer, V_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ): + """Normal MLA extend (not target_verify, not draft_extend). + + Three sub-paths mirroring sglang aiter_backend: + 1) No prefix -> fp8 prefill kernel (mla_prefill_ps_asm_fwd) or flash_attn fallback + 2) Has prefix, absorbed weights differ -> decompress via kv_b_proj + flash_attn + 3) Has prefix, qk_head_dim matches -> mla_prefill_fwd kernel + """ + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + + if kv_indices.shape[0] == 0 or extend_no_prefix: + # --- Sub-path 1: no prefix, pure prefill --- + use_fp8_prefill = ( + self.forward_metadata.fp8_prefill_kv_indices is not None + ) + if use_fp8_prefill: + total_s = q.shape[0] + nhead = layer.tp_q_head_num + v_head_dim = layer.v_head_dim + + if q.dtype != dtypes.fp8: + q = q.to(dtypes.fp8) + if k.dtype != dtypes.fp8: + k = k.to(dtypes.fp8) + if v.dtype != dtypes.fp8: + v = v.to(dtypes.fp8) + one_scale = torch.ones( + (), dtype=torch.float32, device=q.device + ) + + kv_indptr_asm = qo_indptr + kv_indices_asm = self.forward_metadata.fp8_prefill_kv_indices + + tile_q = 256 + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + + logits = torch.empty( + (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), + 
dtype=torch.float32, + device=q.device, + ) + attn_lse = torch.empty( + (reduce_partial_map.size(0) * tile_q, nhead), + dtype=torch.float32, + device=q.device, + ) + final_lse = torch.empty( + (total_s, nhead), + dtype=torch.float32, + device=q.device, + ) + output = q.new_empty( + (total_s, nhead, v_head_dim), + dtype=self.input_dtype, + ) + + mla_prefill_ps_asm_fwd( + q, + k, + v, + qo_indptr, + kv_indptr_asm, + kv_indices_asm, + self.forward_metadata.work_indptr, + self.forward_metadata.work_info_set, + max_q_len, + layer.scaling, + True, + logits, + attn_lse, + output, + one_scale, + one_scale, + one_scale, + ) + mla_reduce_v1( + logits, + attn_lse, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + tile_q, + output, + final_lse, + ) + elif layer.qk_head_dim == (kv_lora_rank + qk_rope_head_dim) and mla_prefill_fwd is not None: + # Absorbed MLA: head_dim (576) exceeds CK limit (256), + # use mla_prefill_fwd which natively supports large MLA head dims. + # For no-prefix, use input k (bfloat16) directly instead of K_Buffer + # (which may be FP8). mla_prefill_fwd doesn't support FP8 KV. 
+ if layer.qk_head_dim != layer.v_head_dim: + output = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + ) + else: + output = torch.empty_like(q) + total_s = q.shape[0] + temp_kv_indices = torch.arange( + total_s, device=q.device, dtype=torch.int32 + ) + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.view(-1, 1, 1, layer.qk_head_dim), + output.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, + qo_indptr, + temp_kv_indices, + self.forward_metadata.kv_last_page_len, + max_q_len, + layer.scaling, + layer.logit_cap, + ) + else: + output = flash_attn_varlen_func( + q, + k, + v, + qo_indptr, + qo_indptr, + max_q_len, + max_q_len, + softmax_scale=layer.scaling, + causal=True, + ) + return output + + elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): + # --- Sub-path 2: has prefix, need kv_b_proj decompress --- + K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) + kvc, k_pe = torch.split( + K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1 + ) + + if self.kv_cache_dtype == dtypes.fp8: + dtype = q.dtype + kvc = kvc.to(dtype) + k_pe = k_pe.to(dtype) + + kvprefix = layer.kv_b_proj(kvc.contiguous())[0] + kvprefix = kvprefix.view( + -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim + ) + k_prefix, v_prefix = torch.split( + kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 + ) + k_prefix = torch.cat( + [ + k_prefix, + torch.broadcast_to( + k_pe, + (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]), + ), + ], + dim=-1, + ) + + assert ( + forward_batch.extend_prefix_lens.shape + == forward_batch.extend_seq_lens.shape + ) + + o = flash_attn_varlen_func( + q, + k_prefix, + v_prefix, + qo_indptr, + kv_indptr, + max_q_len, + max_kv_len, + softmax_scale=layer.scaling, + causal=True, + ) + return o + + else: + # --- Sub-path 3: has prefix, qk_head_dim == kv_lora_rank + qk_rope_head_dim --- + # Gather needed KV entries and cast to bf16 (K_Buffer may be FP8) + k_selected = 
torch.index_select(K_Buffer, 0, kv_indices) + if k_selected.dtype != q.dtype: + k_selected = k_selected.to(q.dtype) + compact_kv_indices = torch.arange( + k_selected.shape[0], device=q.device, dtype=torch.int32 + ) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + ) + else: + o = torch.empty_like(q) + + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_selected.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, + kv_indptr, + compact_kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + layer.scaling, + layer.logit_cap, + ) + return o + + def _forward_extend_mla_target_verify(self, q, layer, K_Buffer, qo_indptr): + """MLA target_verify path (speculative decoding verification).""" + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q, + K_Buffer.view(-1, 1, 1, layer.qk_head_dim), + o, + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + 
num_kv_splits=num_kv_splits, + ) + return o + + def _forward_extend_mla_draft_extend(self, q, layer, K_Buffer, qo_indptr): + """MLA draft_extend path (speculative decoding draft extension).""" + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q, + K_Buffer.view(-1, 1, 1, layer.qk_head_dim), + o, + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + return o + + + def forward_decode_pa( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + + 
if self.use_mla: + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_buffer.shape + num_blocks = num_slots // block_size + k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + + x = 16 // k_buffer.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = 
k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + pa_fwd_asm( + Q=q, + K=new_key_cache, + V=new_value_cache, + block_tables=self.forward_metadata.page_table, + context_lens=self.forward_metadata.kv_lens, + block_tables_stride0=self.forward_metadata.page_table.stride(0), + K_QScale=self.k_scale, + V_QScale=self.v_scale, + out_=o, + ) + return o + + def forward_decode_pa_ps( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + # Create o as 3D tensor [batch_size, num_heads, head_dim] for both MLA and pa_fwd_asm + # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) + # Use q.shape[0] instead of forward_batch.batch_size to be safe + batch_size = q.shape[0] + head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) + + if save_kv_cache: + if self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + + if self.use_mla: + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + + num_kv_splits = self.forward_metadata.num_kv_splits + + 
mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + num_slots, num_kv_heads, head_size = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + + + quant_dtype = dtypes.fp8 + x = 16 // quant_dtype.itemsize + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + # V: [num_blocks, block_size, num_kv_heads, head_size] -> [num_blocks, num_kv_heads, block_size // x, head_size, x] + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + + total_tokens = num_blocks * block_size + k_qscale = self.k_qscale[:, :total_tokens] + v_qscale = self.v_qscale[:, :total_tokens] + + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) + + + assert self.forward_metadata.pa_metadata_qo_indptr is not None, 
"pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_pages_kv_indptr is not None, "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_kv_indices is not None, "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_context_lens is not None, "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_max_qlen is not None, "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr + kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr + kv_indices = self.forward_metadata.pa_metadata_kv_indices + context_lens = self.forward_metadata.pa_metadata_context_lens + max_qlen = self.forward_metadata.pa_metadata_max_qlen + + + _, _ = pa_persistent_fwd( + Q=q, + K=new_key_cache, + V=new_value_cache, + output=o, + max_qlen=max_qlen, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + kv_indices=kv_indices, + context_lens=context_lens, + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + K_QScale=k_qscale, + V_QScale=v_qscale, + softmax_scale=layer.scaling, + mask=1, + ) + return o.view(-1, layer.tp_q_head_num * head_dim_out) + + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + if self.use_mla: + return self._forward_decode_mla(q, k, v, layer, forward_batch, save_kv_cache) + else: + if self.decode_using_pa_ps: + return self.forward_decode_pa_ps(q, k, v, layer, forward_batch, save_kv_cache) + else: + 
return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) + + def _forward_decode_mla(self, q, k, v, layer, forward_batch, save_kv_cache): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim), + dtype=self.input_dtype, + ) + else: + o = torch.empty_like(q, dtype=self.input_dtype) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + + if layer.layer_id == 0: + _q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + _k_view = k_buffer.view(-1, 1, 1, layer.qk_head_dim) + _o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) + print( + f"[MLA_DECODE_DBG] layer=0" + f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" + f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" + f" o={tuple(_o_view.shape)} o.dtype={_o_view.dtype}" + f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" + f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" + f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" + f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" + f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" + f" max_q_len={self.forward_metadata.max_q_len}" + f" sm_scale={layer.scaling}" + f" logit_cap={layer.logit_cap}" + f" k_scale={layer.k_scale}" + f" num_kv_splits={num_kv_splits}" + f" page_size={self.page_size}" + f" 
work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" + f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" + f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" + f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" + f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" + f" reduce_partial_map={tuple(reduce_partial_map.shape) if reduce_partial_map is not None else None}" + f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" + f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" + f" fast_mode={_sglang_aiter.fast_mode}" + , flush=True, + ) + + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + + return o diff --git a/atom/plugin/register.py b/atom/plugin/register.py index af2427fbf..55ee75bfb 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -3,6 +3,7 @@ from atom.models.qwen3 import Qwen3ForCausalLM from atom.models.qwen3_moe import Qwen3MoeForCausalLM from atom.models.glm4_moe import Glm4MoeForCausalLM +from atom.models.deepseek_v2 import DeepseekV3ForCausalLM from atom.config import Config from 
atom.plugin.prepare import is_vllm, is_sglang @@ -12,6 +13,7 @@ "Qwen3ForCausalLM": Qwen3ForCausalLM, "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, "Glm4MoeForCausalLM": Glm4MoeForCausalLM, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, } @@ -28,9 +30,9 @@ def _register_custom_attention_to_sglang() -> None: @register_attention_backend("aiter") def create_atom_backend(runner): - from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl - return AiterAttnBackend(runner) + return ATOMAttnBackendForSgl(runner) def register_ops_to_sglang(atom_config: Config) -> None: diff --git a/atom/utils/envs.py index ea543b5f7..ae3345146 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -75,6 +75,7 @@ "ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0" ).lower() == "1", + "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM", os.getenv("AITER_ROPE_FUSED_QKNORM", "0")) == "1", } From d2014329352f243a9c66d44e31e8a86f182032e3 Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Wed, 11 Feb 2026 19:03:24 +0800 Subject: [PATCH 02/11] make format happy --- atom/config.py | 4 +- atom/model_ops/radix_attention.py | 8 +- atom/models/qwen3_moe.py | 39 ++- .../attention_backend/sgl_attn_backend.py | 309 ++++++++++++------ 4 files changed, 237 insertions(+), 123 deletions(-) diff --git a/atom/config.py index 36312efdc..62d77209d 100644 --- a/atom/config.py +++ b/atom/config.py @@ -831,7 +831,9 @@ def __post_init__(self): if ( eos_ids := getattr(self.generation_config, "eos_token_id", None) ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids + self.stop_token_ids = ( + [eos_ids] if isinstance(eos_ids, int) else eos_ids + ) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} diff --git a/atom/model_ops/radix_attention.py 
b/atom/model_ops/radix_attention.py index c85b0251e..4340311b0 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -96,7 +96,13 @@ def forward_impl_plugin_mode( forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM - return self.attn(query, key, value, forward_batch=forward_batch, save_kv_cache=not self.use_aiter_rope_fused_qknorm) + return self.attn( + query, + key, + value, + forward_batch=forward_batch, + save_kv_cache=not self.use_aiter_rope_fused_qknorm, + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 71eec8b17..279845595 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -45,6 +45,7 @@ ) ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM + class Qwen3MoeMLP(nn.Module): def __init__( self, @@ -243,24 +244,30 @@ def forward_sgl_plugin_mode( if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: forward_batch = model_kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_num) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num + ) block_size = 1024 # Default fallback - if hasattr(forward_batch, 'attn_backend') and hasattr(forward_batch.attn_backend, 'page_size'): + if hasattr(forward_batch, "attn_backend") and hasattr( + forward_batch.attn_backend, "page_size" + ): block_size = forward_batch.attn_backend.page_size - elif hasattr(forward_batch.token_to_kv_pool, 'allocator') and hasattr(forward_batch.token_to_kv_pool.allocator, 'page_size'): + elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( + 
forward_batch.token_to_kv_pool.allocator, "page_size" + ): block_size = forward_batch.token_to_kv_pool.allocator.page_size - elif hasattr(forward_batch.token_to_kv_pool, 'page_size'): + elif hasattr(forward_batch.token_to_kv_pool, "page_size"): block_size = forward_batch.token_to_kv_pool.page_size x = 16 // k_buffer.element_size() aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( - kv_cache = (k_buffer, v_buffer), - cache_loc = forward_batch.out_cache_loc, - k_scale = self.k_scale, - v_scale = self.v_scale, - return_kv = True, - use_shuffle_layout = True, - block_size = block_size, - x = x, + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, ) q, k, v = self.rotary_emb( qkv, @@ -282,9 +289,7 @@ def forward_sgl_plugin_mode( q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn( - q, k, v, positions=positions, **model_kwargs - ) + attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) return attn_output def forward( @@ -304,7 +309,9 @@ def forward( ) else: if is_sglang(): - attn_output = self.forward_sgl_plugin_mode(positions, qkv, **model_kwargs) + attn_output = self.forward_sgl_plugin_mode( + positions, qkv, **model_kwargs + ) else: # Add qk-norm q = self.q_norm(q) diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 2ececc348..7f2216511 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -93,10 +93,7 @@ def reshape_and_cache_shuffle_kernel( dst_offset + offset // x * block_size * x + block_offset * x + offset % x ) dst_v_shuffle_offset = ( - dst_offset - + block_offset // x * head_size * x - + offset * x - + block_offset % x + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x ) k_val = tl.load(key_ptr + src_offset_k + 
offset) v_val = tl.load(value_ptr + src_offset_v + offset) @@ -110,6 +107,7 @@ reshape_and_cache_shuffle_kernel( tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + def reshape_and_cache_shuffle_triton( key: torch.Tensor, value: torch.Tensor, @@ -161,6 +159,7 @@ def reshape_and_cache_shuffle_triton( QUANT=QUANT, ) + @dataclass class ForwardMetadata: # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode @@ -214,7 +213,9 @@ def __init__( self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building -# assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." + # assert ( + #     not self.use_mla + # ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl." # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] @@ -226,7 +231,9 @@ def __init__( (max_bs,), dtype=torch.int32, device=model_runner.device ) self.page_table = torch.zeros( - (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=model_runner.device + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=model_runner.device, ) # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) self.strided_indices = torch.arange( @@ -235,7 +242,9 @@ def __init__( if not self.use_mla: # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) - max_num_blocks_per_seq = (self.max_context_len + self.page_size - 1) // self.page_size + max_num_blocks_per_seq = ( + self.max_context_len + self.page_size - 1 + ) // self.page_size max_total_blocks = max_bs * max_num_blocks_per_seq self.pa_kv_indices = torch.zeros( max_total_blocks, dtype=torch.int32, device=self.device @@ -251,13 +260,12 @@ def 
__init__( # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) - self.logits_soft_cap = 0.0 self.forward_metadata: ForwardMetadata = None - + self.pa_metadata_buffers = None - + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) num_slots, num_kv_heads, _ = k_buffer.shape block_size = self.page_size @@ -271,9 +279,6 @@ def __init__( ) self.decode_using_pa_ps = self.page_size == 1024 - - - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" bs = forward_batch.batch_size @@ -370,34 +375,53 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens_cpu = forward_batch.seq_lens_cpu if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.cpu() - + # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) page_table_persistent = self.page_table seq_lens_persistent = self.seq_lens seq_lens_persistent.fill_(0) page_table_persistent.fill_(0) - seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 - page_table = self.req_to_token[forward_batch.req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] - page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + seq_lens_persistent[:bs].copy_( + forward_batch.seq_lens, non_blocking=True + ) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) else: - page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] - + page_table = forward_batch.req_to_token_pool.req_to_token[ + 
forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, None, # qo_indptr not used in non-MLA mode None, # kv_last_page_len not used in non-MLA mode - 1, # max_q_len = 1 for decode mode + 1, # max_q_len = 1 for decode mode None, - page_table_persistent[:bs, :max_seq_pages] if self.decode_using_pa_ps else page_table, - seq_lens_persistent[:bs] if self.decode_using_pa_ps else forward_batch.seq_lens, + ( + page_table_persistent[:bs, :max_seq_pages] + if self.decode_using_pa_ps + else page_table + ), + ( + seq_lens_persistent[:bs] + if self.decode_using_pa_ps + else forward_batch.seq_lens + ), ) - + # Build pa_metadata for pa_persistent_fwd if self.decode_using_pa_ps: self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) - # return # Early return for non-MLA decode mode + # return # Early return for non-MLA decode mode else: prefix_lens = forward_batch.extend_prefix_lens @@ -488,11 +512,15 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): spec_info=None, ) # Get page_table for mha_batch_prefill_func - page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] self.forward_metadata = ForwardMetadata( self.indices_updater_prefill.kv_indptr, self.indices_updater_prefill.kv_indices, - self.qo_indptr[: bs + 1], # qo_indptr is set by indices_updater_prefill + self.qo_indptr[ + : bs + 1 + ], # qo_indptr is set by indices_updater_prefill None, self.indices_updater_prefill.max_q_len, self.indices_updater_prefill.max_kv_len, @@ -500,22 +528,34 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.seq_lens, ) - if (forward_batch.forward_mode.is_extend() and - not self.use_mla and - self.forward_metadata.page_table is not None): + if ( + forward_batch.forward_mode.is_extend() + and not self.use_mla + and self.forward_metadata.page_table is not None 
+ ): if self.page_size > 1: seq_lens_cpu = forward_batch.seq_lens_cpu if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.cpu() - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 self.forward_metadata.page_table = ( - self.forward_metadata.page_table[:, self.strided_indices[:max_seq_pages]] // self.page_size + self.forward_metadata.page_table[ + :, self.strided_indices[:max_seq_pages] + ] + // self.page_size ) if self.decode_using_pa_ps: self._build_pa_metadata_for_prefill(forward_batch.batch_size) - if not self.decode_using_pa_ps and self.page_size > 1 and self.forward_metadata.page_table is not None: + if ( + not self.decode_using_pa_ps + and self.page_size > 1 + and self.forward_metadata.page_table is not None + ): self.forward_metadata.page_table = ( - self.forward_metadata.page_table[:, self.strided_indices] // self.page_size + self.forward_metadata.page_table[:, self.strided_indices] + // self.page_size ) def _allocate_pa_metadata_buffers( @@ -536,90 +576,112 @@ def _allocate_pa_metadata_buffers( """Allocate or reuse pa_metadata buffers.""" if self.pa_metadata_buffers is None: self.pa_metadata_buffers = {} - + def _get_size_val(size): return size[0] if isinstance(size, tuple) else size - + # Allocate work_metadata_ptrs size_val = _get_size_val(work_metadata_ptrs_size) - if ("work_metadata_ptrs" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val): + if ( + "work_metadata_ptrs" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val + ): self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( - work_metadata_ptrs_size, dtype=work_metadata_ptrs_type, device=self.device + work_metadata_ptrs_size, + dtype=work_metadata_ptrs_type, + device=self.device, ) - + # Allocate work_indptr size_val = 
_get_size_val(work_indptr_size) - if ("work_indptr" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["work_indptr"].shape[0] < size_val): + if ( + "work_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_indptr"].shape[0] < size_val + ): self.pa_metadata_buffers["work_indptr"] = torch.zeros( work_indptr_size, dtype=work_indptr_type, device=self.device ) else: self.pa_metadata_buffers["work_indptr"].zero_() - + # Allocate work_info size_val = _get_size_val(work_info_size) - if ("work_info" not in self.pa_metadata_buffers or - len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) or - self.pa_metadata_buffers["work_info"].shape[0] < size_val): + if ( + "work_info" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) + or self.pa_metadata_buffers["work_info"].shape[0] < size_val + ): self.pa_metadata_buffers["work_info"] = torch.zeros( work_info_size, dtype=work_info_type, device=self.device ) else: self.pa_metadata_buffers["work_info"].zero_() - + # Allocate reduce_indptr size_val = _get_size_val(reduce_indptr_size) - if ("reduce_indptr" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val): + if ( + "reduce_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val + ): self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( reduce_indptr_size, dtype=reduce_indptr_type, device=self.device ) else: self.pa_metadata_buffers["reduce_indptr"].zero_() - + # Allocate reduce_final_map size_val = _get_size_val(reduce_final_map_size) - if ("reduce_final_map" not in self.pa_metadata_buffers or - len(self.pa_metadata_buffers["reduce_final_map"].shape) < len(reduce_final_map_size) or - self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val): + if ( + "reduce_final_map" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["reduce_final_map"].shape) + < 
len(reduce_final_map_size) + or self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val + ): self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( reduce_final_map_size, dtype=reduce_final_map_type, device=self.device ) else: self.pa_metadata_buffers["reduce_final_map"].zero_() - + # Allocate reduce_partial_map - reduce_partial_map_size_val = reduce_partial_map_size if isinstance(reduce_partial_map_size, int) else reduce_partial_map_size[0] - if ("reduce_partial_map" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["reduce_partial_map"].shape[0] < reduce_partial_map_size_val): + reduce_partial_map_size_val = ( + reduce_partial_map_size + if isinstance(reduce_partial_map_size, int) + else reduce_partial_map_size[0] + ) + if ( + "reduce_partial_map" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_partial_map"].shape[0] + < reduce_partial_map_size_val + ): self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( - reduce_partial_map_size, dtype=reduce_partial_map_type, device=self.device + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=self.device, ) else: self.pa_metadata_buffers["reduce_partial_map"].zero_() def _build_pa_metadata_for_decode( - self, - batch_size: int, + self, + batch_size: int, tp_q_head_num: Optional[int] = None, ): """Build pa_metadata buffers for pa_persistent_fwd in decode mode. - + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. The metadata can be reused across multiple layers in the same forward pass. - + Args: batch_size: Batch size for the current forward pass tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. 
""" max_qlen = 1 - + # Use provided tp_q_head_num or default to self.num_head if tp_q_head_num is None: tp_q_head_num = self.num_head - + # kv_dtype_for_metadata = dtypes.fp8 ( (work_metadata_ptrs_size, work_metadata_ptrs_type), @@ -648,22 +710,22 @@ def _build_pa_metadata_for_decode( reduce_partial_map_type, ) qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] - + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) # Note: kv_lens comes from self.seq_lens which is already int32 context_lens = self.forward_metadata.kv_lens - + kernel_block_size = self.page_size num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync # page_table shape: [batch_size, max_num_blocks_per_seq] # Note: page_table comes from self.page_table which is already int32 and always set before this call page_table = self.forward_metadata.page_table - + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) create_flashinfer_kv_indices_triton[(batch_size,)]( page_table, @@ -719,7 +781,7 @@ def _build_pa_metadata_for_prefill(self, batch_size: int): # Page-level kv_indptr (reuse pa_kv_indptr buffer) pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - + # Build kv_indices from page_table using triton kernel page_table = self.forward_metadata.page_table create_flashinfer_kv_indices_triton[(batch_size,)]( @@ -759,11 +821,11 @@ def init_cuda_graph_state( # Always use preshuffle layout for pa_fwd_asm self.page_table = torch.zeros( - (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=self.device 
- ) - self.seq_lens = torch.zeros( - (max_bs,), dtype=torch.int32, device=self.device + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, ) + self.seq_lens = torch.zeros((max_bs,), dtype=torch.int32, device=self.device) self.strided_indices = torch.arange( 0, self.max_context_len, self.page_size, device=self.device ) @@ -1045,8 +1107,12 @@ def set_kv_buffer_with_layout_shuffle( num_slots, num_kv_heads, head_dim = k_buffer.shape num_blocks = num_slots // block_size num_slots_with_block = num_blocks * block_size - k_buffer = k_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) - v_buffer = v_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) reshape_and_cache_shuffle_triton( k, v, @@ -1466,7 +1532,16 @@ def forward_decode_pa( k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) if self.use_mla: k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) @@ -1505,13 +1580,19 @@ def forward_decode_pa( ) else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) block_size = self.page_size num_slots, num_kv_heads, head_size = k_buffer.shape num_blocks = num_slots // block_size - k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - v_buffer = 
v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) x = 16 // k_buffer.element_size() k_cache_template = torch.empty( @@ -1555,7 +1636,11 @@ def forward_decode_pa_ps( # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) # Use q.shape[0] instead of forward_batch.batch_size to be safe batch_size = q.shape[0] - head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + head_dim_out = ( + layer.v_head_dim + if layer.qk_head_dim != layer.v_head_dim + else layer.head_dim + ) o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) if save_kv_cache: @@ -1605,13 +1690,18 @@ def forward_decode_pa_ps( num_kv_splits=num_kv_splits, ) else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) num_slots, num_kv_heads, head_size = k_buffer.shape block_size = self.page_size num_blocks = num_slots // block_size - k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) quant_dtype = dtypes.fp8 x = 16 // quant_dtype.itemsize @@ -1628,27 +1718,35 @@ def forward_decode_pa_ps( ) new_key_cache = k_buffer.view_as(k_cache_template) new_value_cache = v_buffer.view_as(v_cache_template) - + total_tokens = num_blocks * block_size k_qscale = self.k_qscale[:, :total_tokens] v_qscale = self.v_qscale[:, 
:total_tokens] - + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) - - assert self.forward_metadata.pa_metadata_qo_indptr is not None, "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_pages_kv_indptr is not None, "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_kv_indices is not None, "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_context_lens is not None, "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_max_qlen is not None, "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" - + assert ( + self.forward_metadata.pa_metadata_qo_indptr is not None + ), "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_pages_kv_indptr is not None + ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_kv_indices is not None + ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_context_lens is not None + ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_max_qlen is not None + ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr kv_indices = self.forward_metadata.pa_metadata_kv_indices context_lens = self.forward_metadata.pa_metadata_context_lens max_qlen = self.forward_metadata.pa_metadata_max_qlen - - + _, _ = pa_persistent_fwd( Q=q, K=new_key_cache, @@ -1667,25 +1765,26 @@ def forward_decode_pa_ps( K_QScale=k_qscale, V_QScale=v_qscale, softmax_scale=layer.scaling, - 
mask=1, + mask=1, ) return o.view(-1, layer.tp_q_head_num * head_dim_out) - def forward_decode( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): if self.use_mla: return self._forward_decode_mla(q, k, v, layer, forward_batch, save_kv_cache) else: if self.decode_using_pa_ps: - return self.forward_decode_pa_ps(q, k, v, layer, forward_batch, save_kv_cache) + return self.forward_decode_pa_ps( + q, k, v, layer, forward_batch, save_kv_cache + ) else: return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) From 215aab85b19afe7caf3073619077c1912cf8c799 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 9 Mar 2026 16:13:10 +0000 Subject: [PATCH 03/11] arg parse for format launch --- atom/plugin/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 299b905c5..f811b64b2 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -127,10 +127,13 @@ def _generate_atom_config_from_sglang_config(config: Any): from sglang.srt.configs.load_config import LoadConfig from atom.config import Config, ParallelConfig, CompilationConfig + # Format1: sglang serve --model-path ... + # Format2: python3 -m sglang.launch_server --model-path ... 
+ args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:] # sglang has no global config variable like vllm, # so here construct the server args from sys.argv passed by users # this is the only way to get full arguments - server_args: ServerArgs = prepare_server_args(sys.argv[1:]) + server_args: ServerArgs = prepare_server_args(args_list) sgl_model_config = SglangModelConfig.from_server_args(server_args) sgl_model_opt_config = ModelOptConfig( @@ -222,7 +225,6 @@ def generate_atom_config_for_plugin_mode(config: Any = None): """ logger.info("Generate atom config for plugin mode from passed config") - atom_config = None from atom.plugin import is_vllm, is_sglang from atom.config import set_current_atom_config From 96af643f3560b4c71f56a596cea0d9bb9b2ad4c9 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 19 Mar 2026 14:46:53 +0000 Subject: [PATCH 04/11] remove print logging --- atom/models/deepseek_v2.py | 135 ++++++------------ .../attention_backend/sgl_attn_backend.py | 60 ++++---- 2 files changed, 67 insertions(+), 128 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index fe84de8ed..98f2f649d 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1507,30 +1507,7 @@ def _forward_sgl_prepare( llama_4_scaling = model_kwargs.get("llama_4_scaling", None) q_lora = None topk_indices = None - # #region agent log - try: - _pos_0 = int(positions.shape[0]) - _hs_0 = int(hidden_states.shape[0]) if hasattr(hidden_states, "shape") else -1 - _tp = int(get_tensor_model_parallel_world_size()) if forward_batch is not None else -1 - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "A", "location": "deepseek_v2.py:_forward_sgl_prepare_entry", "message": "prepare_entry", "data": {"positions_dim0": _pos_0, "hidden_states_dim0": _hs_0, "tp_world": _tp, "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": 
__import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion if self.q_lora_rank is not None: - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"positions={tuple(positions.shape)} hidden_states={tuple(hidden_states.shape)} " - f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " - f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)}" - ) - # qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) - # q, latent_cache = torch.split( - # qkv_lora, - # [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - # dim=-1 - # ) - q, latent_cache = ( get_attn_tp_context() .fetch_qkv_latent() @@ -1539,20 +1516,6 @@ def _forward_sgl_prepare( dim=-1, ) ) - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"fetched q={tuple(q.shape)} latent={tuple(latent_cache.shape)} " - f"positions={tuple(positions.shape)} tp_world={get_tensor_model_parallel_world_size()}" - ) - # #region agent log - try: - _q0, _p0, _tp = int(q.shape[0]), int(positions.shape[0]), int(get_tensor_model_parallel_world_size()) - _fallback_cond = _q0 != _p0 and _tp > 1 - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "B,C", "location": "deepseek_v2.py:after_fetch", "message": "after_fetch", "data": {"q_dim0": _q0, "positions_dim0": _p0, "tp_world": _tp, "fallback_will_run": _fallback_cond}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: qkv_lora = torch.cat([q, latent_cache], dim=-1) @@ -1567,17 +1530,6 @@ def _forward_sgl_prepare( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"after_fallback_gather q={tuple(q.shape)} 
latent={tuple(latent_cache.shape)}" - ) - # #region agent log - try: - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "C", "location": "deepseek_v2.py:after_fallback", "message": "after_fallback", "data": {"q_dim0": int(q.shape[0]), "positions_dim0": int(positions.shape[0])}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion k_nope = latent_cache[..., : self.kv_lora_rank] @@ -1687,18 +1639,11 @@ def _forward_sgl_prepare( q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) - # #region agent log - try: - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "D,E", "location": "deepseek_v2.py:before_rope_assert", "message": "before_rope", "data": {"q_pe_dim0": int(q_pe.shape[0]), "k_pe_dim0": int(k_pe.shape[0]), "positions_dim0": int(positions.shape[0]), "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " - f"positions={tuple(positions.shape)}" - ) + # print( + # f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + # f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " + # f"positions={tuple(positions.shape)}" + # ) _is_hip= True if self.use_deep_gemm_bmm: @@ -1743,7 +1688,7 @@ def _forward_sgl_prepare( q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( X=q_nope, WQ=self.w_kc.transpose(-1, -2), - w_scale=self.mla_attn.w_scale, + w_scale=self.w_scale, group_size=128, YQ=None, # allocate (B, M, N) transpose_bm=False, # (B, M, 
N) @@ -1924,15 +1869,15 @@ def prepare_qkv_latent( hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states - print( - f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " - f"hidden_states={tuple(hidden_states.shape)} " - f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " - f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " - f"positions={None if getattr(forward_batch, 'positions', None) is None else tuple(forward_batch.positions.shape)}" - ) + # print( + # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " + # f"hidden_states={tuple(hidden_states.shape)} " + # f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " + # f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + # f"positions={None if getattr(forward_batch, 'positions', None) is None else tuple(forward_batch.positions.shape)}" + # ) qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) - print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") + # print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") # Fallback: when communicator does not enable input_scattered gather, # force qkv latent token dimension to align with positions. 
@@ -1949,11 +1894,11 @@ def prepare_qkv_latent( and qkv_lora.shape[0] != expected_tokens and get_tensor_model_parallel_world_size() > 1 ): - print( - f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " - f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " - f"tp_world={get_tensor_model_parallel_world_size()}" - ) + # print( + # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " + # f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " + # f"tp_world={get_tensor_model_parallel_world_size()}" + # ) qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) if qkv_lora.shape[0] > expected_tokens: qkv_lora = qkv_lora[:expected_tokens] @@ -1962,9 +1907,9 @@ def prepare_qkv_latent( f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " f"expected {expected_tokens}" ) - print( - f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" - ) + # print( + # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" + # ) return qkv_lora @@ -1981,19 +1926,19 @@ def forward_sgl_plugin_mode( raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") attn_tp_context = get_attn_tp_context() - print( - f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " - f"positions={tuple(positions.shape)} " - f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " - f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " - f"input_ids={None if getattr(forward_batch, 'input_ids', None) is None else tuple(forward_batch.input_ids.shape)} " - f"allow_scatter={attn_tp_context.allow_input_scattered}" - ) + # print( + # f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " + # f"positions={tuple(positions.shape)} " + # f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else 
tuple(hidden_states.shape)} " + # f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + # f"input_ids={None if getattr(forward_batch, 'input_ids', None) is None else tuple(forward_batch.input_ids.shape)} " + # f"allow_scatter={attn_tp_context.allow_input_scattered}" + # ) with attn_tp_context.maybe_input_scattered(forward_batch): - print( - f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " - f"input_scattered={attn_tp_context.input_scattered}" - ) + # print( + # f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " + # f"input_scattered={attn_tp_context.input_scattered}" + # ) if self.q_lora_rank is not None: attn_tp_context.set_attn_inputs( AttentionInputs( @@ -2404,12 +2349,12 @@ def forward( **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: # Self Attention - print( - f"[MLA_DBG][decoder_layer][layer={self.layer_idx}] positions={tuple(positions.shape)} " - f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " - f"residual={None if residual is None else tuple(residual.shape)} " - f"fuse_input_norm_quant={self.fuse_input_norm_quant}" - ) + # print( + # f"[MLA_DBG][decoder_layer][layer={self.layer_idx}] positions={tuple(positions.shape)} " + # f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " + # f"residual={None if residual is None else tuple(residual.shape)} " + # f"fuse_input_norm_quant={self.fuse_input_norm_quant}" + # ) if self.fuse_input_norm_quant: assert self.quant_dtype is not None weight = self.input_layernorm.weight diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 7f2216511..fe935fe8e 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -213,13 +213,7 @@ def __init__( self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building -<<<<<<< HEAD # assert not self.use_mla, "MLA 
mode is not implemented yet in ATOMAttnBackendForSgl." -======= - assert ( - not self.use_mla - ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl." ->>>>>>> bfc8900 (make format happy) # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] @@ -1818,33 +1812,33 @@ def _forward_decode_mla(self, q, k, v, layer, forward_batch, save_kv_cache): _q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) _k_view = k_buffer.view(-1, 1, 1, layer.qk_head_dim) _o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) - print( - f"[MLA_DECODE_DBG] layer=0" - f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" - f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" - f" o={tuple(_o_view.shape)} o.dtype={_o_view.dtype}" - f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" - f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" - f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" - f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" - f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" - f" max_q_len={self.forward_metadata.max_q_len}" - f" sm_scale={layer.scaling}" - f" logit_cap={layer.logit_cap}" - f" k_scale={layer.k_scale}" - f" num_kv_splits={num_kv_splits}" - f" page_size={self.page_size}" - f" work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" - f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" - f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" - f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" - f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" - f" reduce_partial_map={tuple(reduce_partial_map.shape) 
if reduce_partial_map is not None else None}" - f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" - f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" - f" fast_mode={_sglang_aiter.fast_mode}" - , flush=True, - ) + # print( + # f"[MLA_DECODE_DBG] layer=0" + # f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" + # f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" + # f" o={tuple(_o_view.shape)} o.dtype={_o_view.dtype}" + # f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" + # f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" + # f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" + # f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" + # f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" + # f" max_q_len={self.forward_metadata.max_q_len}" + # f" sm_scale={layer.scaling}" + # f" logit_cap={layer.logit_cap}" + # f" k_scale={layer.k_scale}" + # f" num_kv_splits={num_kv_splits}" + # f" page_size={self.page_size}" + # f" work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" + # f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" + # f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" + # f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" + # f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" + # f" reduce_partial_map={tuple(reduce_partial_map.shape) if reduce_partial_map is not None else None}" + # f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" + # f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" + # f" fast_mode={_sglang_aiter.fast_mode}" + # , flush=True, + # ) mla_decode_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), From 29a9759735201f475f6e3c1c7a348d83e9de5d03 Mon Sep 17 00:00:00 2001 From: 
zhuyuhua-v Date: Fri, 20 Mar 2026 03:11:14 +0000 Subject: [PATCH 05/11] add Qwen3-235B support for sgl_oot Signed-off-by: zhuyuhua-v --- atom/plugin/sglang/oot/qwen3_moe.py | 97 +++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 atom/plugin/sglang/oot/qwen3_moe.py diff --git a/atom/plugin/sglang/oot/qwen3_moe.py b/atom/plugin/sglang/oot/qwen3_moe.py new file mode 100644 index 000000000..b998eaf50 --- /dev/null +++ b/atom/plugin/sglang/oot/qwen3_moe.py @@ -0,0 +1,97 @@ +"""ATOM Qwen model wrapper for SGLang external model loading. + +Registers Qwen3MoeForCausalLM and Qwen3ForCausalLM as external +model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's +built-in implementations with ATOM-optimized versions. +""" + +import logging +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +logger = logging.getLogger("atom.plugin.sglang.oot") + + +class Qwen3MoeForCausalLM(nn.Module): + """ATOM-backed Qwen3 MoE model for SGLang. + + This wrapper delegates model creation and weight loading to ATOM's + plugin system, while conforming to sglang's model interface + (forward signature, LogitsProcessorOutput return type, load_weights). 
+ """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + import atom + + self.model = atom.prepare_model(config=config, engine="sglang") + if self.model is None: + model_arch = getattr(config, "architectures", ["unknown"])[0] + raise ValueError( + f"ATOM failed to create model for architecture {model_arch}" + ) + + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=input_embeds, + forward_batch=forward_batch, + get_embedding=get_embedding, + pp_proxy_tensors=pp_proxy_tensors, + ) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + from atom.model_loader.loader import load_model_in_plugin_mode + + return load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." 
+ ) + + +class Qwen3ForCausalLM(Qwen3MoeForCausalLM): + pass + + +EntryClass = [Qwen3MoeForCausalLM, Qwen3ForCausalLM] From 2f9aa0b0c31efa3c77230ef462aa622a1116bbf8 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Fri, 20 Mar 2026 03:16:47 +0000 Subject: [PATCH 06/11] add Deepseek-R1 support for sgl_oot Signed-off-by: zhuyuhua-v --- atom/plugin/sglang/oot/deepseek_v2.py | 97 +++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 atom/plugin/sglang/oot/deepseek_v2.py diff --git a/atom/plugin/sglang/oot/deepseek_v2.py b/atom/plugin/sglang/oot/deepseek_v2.py new file mode 100644 index 000000000..e8f4444e2 --- /dev/null +++ b/atom/plugin/sglang/oot/deepseek_v2.py @@ -0,0 +1,97 @@ +"""ATOM DeepSeek model wrapper for SGLang external model loading. + +Registers DeepseekV3ForCausalLM and DeepseekV2ForCausalLM as external +model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's +built-in implementations with ATOM-optimized versions. +""" + +import logging +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +logger = logging.getLogger("atom.plugin.sglang.oot") + + +class DeepseekV2ForCausalLM(nn.Module): + """ATOM-backed DeepSeek v2/v3 model for SGLang. + + This wrapper delegates model creation and weight loading to ATOM's + plugin system, while conforming to sglang's model interface + (forward signature, LogitsProcessorOutput return type, load_weights). 
+ """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + import atom + + self.model = atom.prepare_model(config=config, engine="sglang") + if self.model is None: + model_arch = getattr(config, "architectures", ["unknown"])[0] + raise ValueError( + f"ATOM failed to create model for architecture {model_arch}" + ) + + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=input_embeds, + forward_batch=forward_batch, + get_embedding=get_embedding, + pp_proxy_tensors=pp_proxy_tensors, + ) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + from atom.model_loader.loader import load_model_in_plugin_mode + + return load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." 
+ ) + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] From bcb63f244d225b1755201edee981b089b908b91d Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 23 Mar 2026 07:28:08 +0000 Subject: [PATCH 07/11] unify base model wrapper Signed-off-by: zhuyuhua-v --- .../{deepseek_v2.py => base_model_wrapper.py} | 36 ++++--- atom/plugin/sglang/oot/qwen3_moe.py | 97 ------------------- 2 files changed, 22 insertions(+), 111 deletions(-) rename atom/plugin/sglang/oot/{deepseek_v2.py => base_model_wrapper.py} (74%) delete mode 100644 atom/plugin/sglang/oot/qwen3_moe.py diff --git a/atom/plugin/sglang/oot/deepseek_v2.py b/atom/plugin/sglang/oot/base_model_wrapper.py similarity index 74% rename from atom/plugin/sglang/oot/deepseek_v2.py rename to atom/plugin/sglang/oot/base_model_wrapper.py index e8f4444e2..97eff70d4 100644 --- a/atom/plugin/sglang/oot/deepseek_v2.py +++ b/atom/plugin/sglang/oot/base_model_wrapper.py @@ -1,8 +1,9 @@ -"""ATOM DeepSeek model wrapper for SGLang external model loading. +"""ATOM model wrappers for SGLang external model loading (OOT). -Registers DeepseekV3ForCausalLM and DeepseekV2ForCausalLM as external -model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's -built-in implementations with ATOM-optimized versions. +Registers model architecture classes via SGLANG_EXTERNAL_MODEL_PACKAGE, +replacing sglang's built-in implementations with ATOM-optimized versions. + +To add a new model, append its architecture class name to _MODEL_NAMES. """ import logging @@ -18,13 +19,20 @@ logger = logging.getLogger("atom.plugin.sglang.oot") +_MODEL_NAMES = [ + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3ForCausalLM", +] + -class DeepseekV2ForCausalLM(nn.Module): - """ATOM-backed DeepSeek v2/v3 model for SGLang. +class _AtomCausalLMBaseForSglangOOT(nn.Module): + """Base ATOM model wrapper conforming to sglang's model interface. 
- This wrapper delegates model creation and weight loading to ATOM's - plugin system, while conforming to sglang's model interface - (forward signature, LogitsProcessorOutput return type, load_weights). + Delegates model creation and weight loading to ATOM's plugin system, + while providing the forward signature and LogitsProcessorOutput return + type that sglang expects. """ def __init__( @@ -90,8 +98,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) -class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): - pass - - -EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] +EntryClass = [] +for _name in _MODEL_NAMES: + _cls = type(_name, (_AtomCausalLMBaseForSglangOOT,), {}) + globals()[_name] = _cls + EntryClass.append(_cls) diff --git a/atom/plugin/sglang/oot/qwen3_moe.py b/atom/plugin/sglang/oot/qwen3_moe.py deleted file mode 100644 index b998eaf50..000000000 --- a/atom/plugin/sglang/oot/qwen3_moe.py +++ /dev/null @@ -1,97 +0,0 @@ -"""ATOM Qwen model wrapper for SGLang external model loading. - -Registers Qwen3MoeForCausalLM and Qwen3ForCausalLM as external -model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's -built-in implementations with ATOM-optimized versions. -""" - -import logging -from typing import Iterable, Optional, Tuple, Union - -import torch -from torch import nn - -from sglang.srt.distributed import get_pp_group -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput -from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors - -logger = logging.getLogger("atom.plugin.sglang.oot") - - -class Qwen3MoeForCausalLM(nn.Module): - """ATOM-backed Qwen3 MoE model for SGLang. - - This wrapper delegates model creation and weight loading to ATOM's - plugin system, while conforming to sglang's model interface - (forward signature, LogitsProcessorOutput return type, load_weights). 
- """ - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - logger.info("Initializing ATOM backend for %s", self.__class__.__name__) - - self.pp_group = get_pp_group() - self.quant_config = quant_config - self.config = config - self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size - - import atom - - self.model = atom.prepare_model(config=config, engine="sglang") - if self.model is None: - model_arch = getattr(config, "architectures", ["unknown"])[0] - raise ValueError( - f"ATOM failed to create model for architecture {model_arch}" - ) - - self.logits_processor = LogitsProcessor(config) - - @torch.no_grad() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, - get_embedding: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=None, - inputs_embeds=input_embeds, - forward_batch=forward_batch, - get_embedding=get_embedding, - pp_proxy_tensors=pp_proxy_tensors, - ) - - if self.pp_group.is_last_rank: - return self.logits_processor( - input_ids, - hidden_states, - self.model.lm_head, - forward_batch, - ) - return hidden_states - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - from atom.model_loader.loader import load_model_in_plugin_mode - - return load_model_in_plugin_mode( - model=self.model, config=self.model.atom_config, prefix="model." 
- ) - - -class Qwen3ForCausalLM(Qwen3MoeForCausalLM): - pass - - -EntryClass = [Qwen3MoeForCausalLM, Qwen3ForCausalLM] From 24bc65f943028543f547fe2ec2faf02b8c0ff3c8 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 20 Mar 2026 05:40:58 +0000 Subject: [PATCH 08/11] enable fuse_rope_cat_and_cache_mla on gfx950 --- atom/model_ops/radix_attention.py | 3 ++- atom/models/deepseek_v2.py | 43 ++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 4340311b0..9ebffce1e 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -94,6 +94,7 @@ def forward_impl_plugin_mode( if is_sglang(): # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) + save_kv_cache = kwargs.get("save_kv_cache", not self.use_aiter_rope_fused_qknorm) assert forward_batch is not None, "forward_batch is required for sglang" # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM return self.attn( @@ -101,7 +102,7 @@ def forward_impl_plugin_mode( key, value, forward_batch=forward_batch, - save_kv_cache=not self.use_aiter_rope_fused_qknorm, + save_kv_cache=save_kv_cache, ) else: raise NotImplementedError( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 98f2f649d..aef5fc625 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -50,6 +50,7 @@ fused_reduce_rms_mxfp4_quant, fused_rms_mxfp4_quant, ) +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.rotary_embedding import get_rope from atom.config import ( @@ -1480,6 +1481,7 @@ def __init__( self.use_nsa = is_deepseek_nsa(config) self.use_deep_gemm_bmm = False self.alt_stream = None + self.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 # self.w_kc, self.w_vc = 
self.kv_b_proj.weight.data.unflatten( # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) @@ -1720,12 +1722,7 @@ def _forward_sgl_prepare( q_nope_out = q_nope_out.transpose(0, 1) - if ( - self.rotary_emb is not None - # and (not self._fuse_rope_for_trtllm_mla(forward_batch)) - and (not _use_aiter or not _is_gfx95_supported or self.use_nsa) - ): - # Optional hard check during debugging + if self.rotary_emb is not None and not self.use_fused_qk_rope_concat_and_cache_mla: assert q_pe.shape[0] == positions.shape[0], ( f"q_pe tokens {q_pe.shape[0]} != positions {positions.shape[0]}" ) @@ -1765,9 +1762,35 @@ def _forward_sgl_core( llama_4_scaling, ): # 1) build q/k for radix attention path - _is_hip = True - q = torch.cat([q_nope_out, q_pe], dim=-1) - k = torch.cat([k_nope, k_pe], dim=-1) + save_kv_cache = True + + if self.use_fused_qk_rope_concat_and_cache_mla: + cos = self.rotary_emb.cos_cache + sin = self.rotary_emb.sin_cache + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( + self.layer_num + ) + k_scale = self.mla_attn.attn.k_scale + + q, _, k_pe_roped, _ = fused_qk_rope_cat_and_cache_mla( + q_nope_out, + q_pe, + k_nope, + k_pe, + kv_cache, + forward_batch.out_cache_loc, + positions, + cos, + sin, + k_scale, + self.rotary_emb.is_neox_style, + q_out_dtype=q_nope_out.dtype, + ) + k = torch.cat([k_nope, k_pe_roped], dim=-1) + save_kv_cache = False + else: + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) if llama_4_scaling is not None: q = q * llama_4_scaling @@ -1778,7 +1801,7 @@ def _forward_sgl_core( k, k_nope, forward_batch=forward_batch, - save_kv_cache=True, + save_kv_cache=save_kv_cache, **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), ) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) From 288f81c8c1720e49c7af8cf00f3d848171385e73 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Fri, 20 Mar 2026 07:55:10 +0000 
Subject: [PATCH 09/11] add _is_hip = True --- atom/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index aef5fc625..ddbee0e74 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1807,6 +1807,7 @@ def _forward_sgl_core( attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) # 3) up-proj by w_vc (port from sglang forward_absorb_core) + _is_hip = True if self.use_deep_gemm_bmm: attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( per_token_group_quant_mla_deep_gemm_masked_fp8(attn_output.transpose(0, 1)) From d626960441ec25a78fe05a74b85c1e48aef2ab9c Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 23 Mar 2026 14:17:35 +0000 Subject: [PATCH 10/11] fix AiterTensor check Signed-off-by: zhuyuhua-v --- atom/model_ops/radix_attention.py | 6 ++++-- atom/plugin/attention_backend/sgl_attn_backend.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 9ebffce1e..d0a9f0453 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -65,11 +65,13 @@ def __init__( ) if self.attn.k_scale is None: self.attn.k_scale = torch.nn.Parameter( - torch.tensor([1.0], dtype=torch.float32), requires_grad=False + torch.tensor([1.0], dtype=torch.float32, device="cuda"), + requires_grad=False, ) if self.attn.v_scale is None: self.attn.v_scale = torch.nn.Parameter( - torch.tensor([1.0], dtype=torch.float32), requires_grad=False + torch.tensor([1.0], dtype=torch.float32, device="cuda"), + requires_grad=False, ) else: raise NotImplementedError( diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index fe935fe8e..4a46f3dfb 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -803,7 +803,7 @@ def 
init_cuda_graph_state( max_num_tokens: int, kv_indices_buf: Optional[torch.Tensor] = None, ): - self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int, device=self.device) if kv_indices_buf is None: self.cuda_graph_kv_indices = torch.zeros( (max_bs * self.max_context_len), From 710427fb85b1ca6e9b8cde3403f55c643f7c3a43 Mon Sep 17 00:00:00 2001 From: qichu Date: Wed, 25 Mar 2026 02:45:44 +0000 Subject: [PATCH 11/11] [Feature][Plugin] support GLM4.7 for sglang plugin --- .../attention_backend/sgl_attn_backend.py | 50 ++++--- atom/plugin/sglang/oot/glm4_moe.py | 131 ++++++++++++++++++ 2 files changed, 163 insertions(+), 18 deletions(-) create mode 100644 atom/plugin/sglang/oot/glm4_moe.py diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 4a46f3dfb..d02b97b84 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -273,6 +273,18 @@ def __init__( ) self.decode_using_pa_ps = self.page_size == 1024 + if not self.use_mla: + self._num_kv_heads = num_kv_heads + gqa_group_size = self.num_head // num_kv_heads + padded_group = 1 << (gqa_group_size - 1).bit_length() if gqa_group_size & (gqa_group_size - 1) else gqa_group_size + self._gqa_group_size = gqa_group_size + self._padded_gqa_group_size = padded_group + self._padded_q_heads = padded_group * num_kv_heads + self._need_q_pad = self._padded_q_heads != self.num_head + else: + self._padded_q_heads = self.num_head + self._need_q_pad = False + def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" bs = forward_batch.batch_size @@ -414,7 +426,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Build pa_metadata for pa_persistent_fwd if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, 
tp_q_head_num=self.num_head) + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self._padded_q_heads) # return # Early return for non-MLA decode mode else: prefix_lens = forward_batch.extend_prefix_lens @@ -970,7 +982,7 @@ def init_forward_metadata_capture_cuda_graph( ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self._padded_q_heads) return else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -1083,7 +1095,7 @@ def init_forward_metadata_replay_cuda_graph( ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self._padded_q_heads) else: raise ValueError("Invalid forward mode") @@ -1719,21 +1731,14 @@ def forward_decode_pa_ps( q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) - assert ( - self.forward_metadata.pa_metadata_qo_indptr is not None - ), "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_pages_kv_indptr is not None - ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_kv_indices is not None - ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_context_lens is not None - ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_max_qlen is not None - ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + if self._need_q_pad: + nkv = self._num_kv_heads + gs = self._gqa_group_size + pgs = self._padded_gqa_group_size + q = q.view(batch_size, nkv, gs, layer.head_dim) + q = torch.nn.functional.pad(q, (0, 0, 0, pgs - gs)) + q = q.reshape(batch_size, self._padded_q_heads, layer.head_dim) + o = q.new_empty((batch_size, self._padded_q_heads, 
head_dim_out))

         qo_indptr = self.forward_metadata.pa_metadata_qo_indptr
         kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr
         kv_indices = self.forward_metadata.pa_metadata_kv_indices
@@ -1761,6 +1766,15 @@ def forward_decode_pa_ps(
             softmax_scale=layer.scaling,
             mask=1,
         )
+
+        if self._need_q_pad:
+            nkv = self._num_kv_heads
+            gs = self._gqa_group_size
+            pgs = self._padded_gqa_group_size
+            o = o.view(batch_size, nkv, pgs, head_dim_out)
+            o = o[:, :, :gs, :].contiguous()
+            o = o.view(batch_size, layer.tp_q_head_num, head_dim_out)
+
         return o.view(-1, layer.tp_q_head_num * head_dim_out)
 
     def forward_decode(
diff --git a/atom/plugin/sglang/oot/glm4_moe.py b/atom/plugin/sglang/oot/glm4_moe.py
new file mode 100644
index 000000000..7b01c9a07
--- /dev/null
+++ b/atom/plugin/sglang/oot/glm4_moe.py
@@ -0,0 +1,141 @@
+"""ATOM GLM-4.7 model wrapper for SGLang external model loading.
+
+Registers Glm4MoeForCausalLM and GlmMoeDsaForCausalLM as external
+model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's
+built-in implementations with ATOM-optimized versions.
+"""
+
+import logging
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+logger = logging.getLogger("atom.plugin.sglang.oot")
+
+
+def _patch_rope_in_attention_layers(atom_model):
+    """Inject RoPE into each Glm4MoeAttention's inner RadixAttention.
+
+    The ATOM GLM-4.7 model has RoPE commented out in Glm4MoeAttention.forward().
+    This function wraps each ATOM RadixAttention.forward() to apply rotary_emb
+    before delegating, without modifying any ATOM source files.
+    """
+    from atom.models.glm4_moe import Glm4MoeAttention
+
+    patched = 0
+    for module in atom_model.modules():
+        if not isinstance(module, Glm4MoeAttention):
+            continue
+
+        inner_attn = module.attn
+        # Skip already-wrapped layers: wrapping twice would apply RoPE
+        # twice per forward pass and corrupt attention outputs.
+        if getattr(inner_attn, "_atom_rope_patched", False):
+            continue
+        original_forward = inner_attn.forward
+        rotary_emb = module.rotary_emb
+
+        def _make_rope_wrapper(orig_fwd, rope):
+            def forward_with_rope(query, key, value, positions=None, **kwargs):
+                if positions is not None:
+                    query, key = rope(positions, query, key)
+                return orig_fwd(query, key, value, positions, **kwargs)
+
+            return forward_with_rope
+
+        inner_attn.forward = _make_rope_wrapper(original_forward, rotary_emb)
+        inner_attn._atom_rope_patched = True
+        patched += 1
+
+    logger.info("Patched RoPE into %d Glm4MoeAttention layers", patched)
+
+
+class Glm4MoeForCausalLM(nn.Module):
+    """ATOM-backed GLM-4.7 MoE model for SGLang.
+
+    This wrapper delegates model creation and weight loading to ATOM's
+    plugin system, while conforming to sglang's model interface
+    (forward signature, LogitsProcessorOutput return type, load_weights).
+    """
+
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        logger.info("Initializing ATOM backend for %s", self.__class__.__name__)
+
+        self.pp_group = get_pp_group()
+        self.quant_config = quant_config
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
+
+        import atom
+
+        self.model = atom.prepare_model(config=config, engine="sglang")
+        if self.model is None:
+            model_arch = getattr(config, "architectures", ["unknown"])[0]
+            raise ValueError(
+                f"ATOM failed to create model for architecture {model_arch}"
+            )
+
+        _patch_rope_in_attention_layers(self.model)
+
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            intermediate_tensors=None,
+            inputs_embeds=input_embeds,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids,
+                hidden_states,
+                self.model.lm_head,
+                forward_batch,
+            )
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load checkpoint weights through ATOM's plugin-mode loader.
+
+        NOTE: the ``weights`` iterable supplied by sglang is intentionally
+        ignored; ATOM's loader reads the checkpoint from disk itself.
+        """
+        from atom.model_loader.loader import load_model_in_plugin_mode
+
+        return load_model_in_plugin_mode(
+            model=self.model, config=self.model.atom_config, prefix="model."
+        )
+
+
+class GlmMoeDsaForCausalLM(Glm4MoeForCausalLM):
+    pass
+
+
+EntryClass = [Glm4MoeForCausalLM, GlmMoeDsaForCausalLM]